From d0c99479dba2a7985aa97c42364949d0398a602e Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Tue, 27 Aug 2024 03:30:00 +0300 Subject: [PATCH 01/83] NOISSUE - Remove CID tracking (#218) * remove cid tracking Signed-off-by: Sammy Oina * remove unused code Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- cmd/agent/main.go | 2 +- internal/logger/protohandler.go | 19 +++++++++++-------- manager/agentEventsLogs.go | 29 +++-------------------------- manager/service.go | 4 ---- manager/service_test.go | 2 -- 5 files changed, 15 insertions(+), 41 deletions(-) diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 06b1e1b22..12d52a63b 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -54,7 +54,7 @@ func main() { log.Println(err) return } - handler := agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}) + handler := agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}, cfg.ID) logger := slog.New(handler) eventSvc, err := events.New(svcName, cfg.ID, manager.ManagerVsockPort) diff --git a/internal/logger/protohandler.go b/internal/logger/protohandler.go index db0b9a5e0..8647daf43 100644 --- a/internal/logger/protohandler.go +++ b/internal/logger/protohandler.go @@ -15,17 +15,19 @@ import ( var _ slog.Handler = (*handler)(nil) type handler struct { - opts slog.HandlerOptions - w io.Writer + opts slog.HandlerOptions + w io.Writer + cmpID string } -func NewProtoHandler(w io.Writer, opts *slog.HandlerOptions) slog.Handler { +func NewProtoHandler(w io.Writer, opts *slog.HandlerOptions, cmpID string) slog.Handler { if opts == nil { opts = &slog.HandlerOptions{} } return &handler{ - opts: *opts, - w: w, + opts: *opts, + w: w, + cmpID: cmpID, } } @@ -62,9 +64,10 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { agentLog := manager.ClientStreamMessage{ Message: &manager.ClientStreamMessage_AgentLog{ AgentLog: &manager.AgentLog{ - Timestamp: timestamp, - Message: chunk, - Level: level, + Timestamp: timestamp, + Message: chunk, + Level: level, + ComputationId: h.cmpID, }, }, } diff --git a/manager/agentEventsLogs.go b/manager/agentEventsLogs.go index 6d8b7b235..ee341f41c 100644 --- a/manager/agentEventsLogs.go +++ b/manager/agentEventsLogs.go @@ -5,10 +5,7 @@ package manager import ( "fmt" "net" - "regexp" - "strconv" - "github.com/absmach/magistrala/pkg/errors" "github.com/mdlayher/vsock" "github.com/ultravioletrs/cocos/pkg/manager" "google.golang.org/protobuf/proto" @@ -19,8 +16,6 @@ const ( messageSize int = 1024 ) -var errFailedToParseCID = errors.New("failed to parse cid from remote address") - // RetrieveAgentEventsLogs Retrieve and forward agent logs and events via vsock. func (ms *managerService) RetrieveAgentEventsLogs() { l, err := vsock.Listen(ManagerVsockPort, nil) @@ -49,22 +44,18 @@ func (ms *managerService) handleConnections(conn net.Conn) { ms.logger.Warn(err.Error()) return } - cmpID, err := ms.computationIDFromAddress(conn.RemoteAddr().String()) - if err != nil { - ms.logger.Warn(err.Error()) - continue - } var message manager.ClientStreamMessage if err := proto.Unmarshal(b[:n], &message); err != nil { ms.logger.Warn(err.Error()) continue } + cmpID := "" switch mes := message.Message.(type) { case *manager.ClientStreamMessage_AgentEvent: - mes.AgentEvent.ComputationId = cmpID + cmpID = mes.AgentEvent.ComputationId ms.eventsChan <- &manager.ClientStreamMessage{Message: mes} case *manager.ClientStreamMessage_AgentLog: - mes.AgentLog.ComputationId = cmpID + cmpID = mes.AgentLog.ComputationId ms.eventsChan <- &manager.ClientStreamMessage{Message: mes} default: ms.logger.Warn("Unexpected agent log or event type") @@ -73,17 +64,3 @@ func (ms *managerService) handleConnections(conn net.Conn) { ms.logger.Info(fmt.Sprintf("Agent Log/Event, Computation ID: %s, Message: %s", cmpID, message.String())) } } - -func (ms *managerService) computationIDFromAddress(address string) (string, error) { - re := regexp.MustCompile(`vm\((\d+)\)`) - matches := re.FindStringSubmatch(address) - - if len(matches) > 1 { - cid, err := strconv.Atoi(matches[1]) - if err != nil { - return "", err - } - return ms.agents[cid], nil - } - return "", errFailedToParseCID -} diff --git a/manager/service.go b/manager/service.go index af9225cb3..93740bfbd 100644 --- a/manager/service.go +++ b/manager/service.go @@ -63,7 +63,6 @@ type managerService struct { qemuCfg qemu.Config backendMeasurementBinaryPath string logger *slog.Logger - agents map[int]string // agent map of vsock cid to computationID. eventsChan chan *manager.ClientStreamMessage vms map[string]vm.VM vmFactory vm.Provider @@ -82,7 +81,6 @@ func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger, ms := &managerService{ qemuCfg: cfg, logger: logger, - agents: make(map[int]string), vms: make(map[string]vm.VM), eventsChan: eventsChan, vmFactory: vmFactory, @@ -147,8 +145,6 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) } ms.vms[c.Id] = cvm - ms.agents[ms.qemuCfg.VSockConfig.GuestCID] = c.Id - err = backoff.Retry(func() error { return cvm.SendAgentConfig(ac) }, backoff.NewExponentialBackOff()) diff --git a/manager/service_test.go b/manager/service_test.go index e01ac1f7b..85abc60c8 100644 --- a/manager/service_test.go +++ b/manager/service_test.go @@ -92,7 +92,6 @@ func TestRun(t *testing.T) { ms := &managerService{ qemuCfg: qemuCfg, logger: logger, - agents: make(map[int]string), vms: make(map[string]vm.VM), eventsChan: eventsChan, vmFactory: vmf.Execute, @@ -110,7 +109,6 @@ func TestRun(t *testing.T) { assert.NoError(t, err) assert.NotEmpty(t, port) assert.Len(t, ms.vms, 1) - assert.Len(t, ms.agents, 1) } vmf.AssertExpectations(t) From 7a2789fb5fda48282ef0c1d516aa8ba36421f5f1 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Wed, 28 Aug 2024 19:16:15 +0300 Subject: [PATCH 02/83] NOISSUE - Remove race condition returning before all go routines have completed (#221) * remove race condition returning before all go routines have completed Signed-off-by: Sammy Oina * refine Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- manager/service.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/manager/service.go b/manager/service.go index 93740bfbd..09ab27789 100644 --- a/manager/service.go +++ b/manager/service.go @@ -179,15 +179,17 @@ func getFreePort(minPort, maxPort int) (int, error) { } var wg sync.WaitGroup - portCh := make(chan int, maxPort-minPort+1) + portCh := make(chan int, 1) for port := minPort; port <= maxPort; port++ { wg.Add(1) go func(p int) { defer wg.Done() - if checkPortisFree(p) { - portCh <- p + select { + case portCh <- p: + default: + } } }(port) } @@ -197,12 +199,12 @@ func getFreePort(minPort, maxPort int) (int, error) { close(portCh) }() - select { - case port := <-portCh: - return port, nil - default: + port, ok := <-portCh + if !ok { return 0, fmt.Errorf("failed to find free port in range %d-%d", minPort, maxPort) } + + return port, nil } func checkPortisFree(port int) bool { From bdfc5fd06d0eba3dfa6338bf5d9b1db47301b776 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Thu, 29 Aug 2024 00:11:49 +0300 Subject: [PATCH 03/83] run manager using systemd (#213) Signed-off-by: Sammy Oina --- .env | 28 --------------- Makefile | 27 +++++++++++++- cocos-manager.env | 58 ++++++++++++++++++++++++++++++ init/systemd/cocos-manager.service | 12 +++++++ 4 files changed, 96 insertions(+), 29 deletions(-) delete mode 100644 .env create mode 100644 cocos-manager.env create mode 100644 init/systemd/cocos-manager.service diff --git a/.env b/.env deleted file mode 100644 index 35bb02016..000000000 --- a/.env +++ /dev/null @@ -1,28 +0,0 @@ -## Jaeger -COCOS_JAEGER_PORT=6831 -COCOS_JAEGER_FRONTEND=16686 -COCOS_JAEGER_COLLECTOR=14268 -COCOS_JAEGER_CONFIGS=5778 -COCOS_JAEGER_URL=http://jaeger:4318 -COCOS_JAEGER_TRACE_RATIO=1.0 -COCOS_JAEGER_COLLECTOR_OTLP_ENABLED=true -COCOS_JAEGER_OLTP_HTTP_PORT=4318 - -## Core Services - -### Manager -MANAGER_GRPC_HOST="" -MANAGER_GRPC_PORT=7003 -MANAGER_GRPC_SERVER_CERT="" -MANAGER_GRPC_SERVER_KEY="" -AGENT_GRPC_URL="localhost:7002" -AGENT_GRPC_TIMEOUT="" -AGENT_GRPC_CA_CERTS="" -AGENT_GRPC_CLIENT_TLS="" -MANAGER_INSTANCE_ID="" -MANAGER_LOG_LEVEL=debug -MANAGER_QEMU_USE_SUDO=false -MANAGER_QEMU_ENABLE_SEV=false -MANAGER_QEMU_SEV_CBITPOS=51 -MANAGER_QEMU_OVMF_CODE_FILE=/usr/share/OVMF/OVMF_CODE.fd -MANAGER_QEMU_OVMF_VARS_FILE=/usr/share/OVMF/OVMF_VARS.fd diff --git a/Makefile b/Makefile index 7e81c2b81..85b325bc2 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,11 @@ VERSION ?= $(shell git describe --abbrev=0 --tags --always) COMMIT ?= $(shell git rev-parse HEAD) TIME ?= $(shell date +%F_%T) EMBED_ENABLED ?= 0 +INSTALL_DIR ?= /usr/local/bin +CONFIG_DIR ?= /etc/cocos +SERVICE_NAME ?= cocos-manager +SERVICE_DIR ?= /etc/systemd/system +SERVICE_FILE = init/systemd/$(SERVICE_NAME).service define compile_service CGO_ENABLED=$(CGO_ENABLED) GOOS=$(GOOS) GOARCH=$(GOARCH) GOARM=$(GOARM) \ @@ -18,7 +23,7 @@ define compile_service -o ${BUILD_DIR}/cocos-$(1) cmd/$(1)/main.go endef -.PHONY: all $(SERVICES) $(BACKEND_INFO) +.PHONY: all $(SERVICES) $(BACKEND_INFO) install clean all: $(SERVICES) @@ -34,3 +39,23 @@ protoc: mocks: go generate ./... + +install: $(SERVICES) + install -d $(INSTALL_DIR) + install $(BUILD_DIR)/cocos-cli $(INSTALL_DIR)/cocos-cli + install $(BUILD_DIR)/cocos-manager $(INSTALL_DIR)/cocos-manager + install -d $(CONFIG_DIR) + install cocos-manager.env $(CONFIG_DIR)/cocos-manager.env + +clean: + rm -rf $(BUILD_DIR) + +run: install_service + sudo systemctl start $(SERVICE_NAME).service + +stop: + sudo systemctl stop $(SERVICE_NAME).service + +install_service: + sudo install -m 644 $(SERVICE_FILE) $(SERVICE_DIR)/$(SERVICE_NAME).service + sudo systemctl daemon-reload diff --git a/cocos-manager.env b/cocos-manager.env new file mode 100644 index 000000000..adbac897a --- /dev/null +++ b/cocos-manager.env @@ -0,0 +1,58 @@ +# Environment Configuration for Cocos + +# Jaeger Tracing +COCOS_JAEGER_URL=http://localhost:4318 +COCOS_JAEGER_TRACE_RATIO=1.0 + +# Manager Service Configuration +MANAGER_INSTANCE_ID= +MANAGER_BACKEND_MEASUREMENT_BINARY=../../build +MANAGER_GRPC_CLIENT_CERT= +MANAGER_GRPC_CLIENT_KEY= +MANAGER_GRPC_SERVER_CA_CERTS= +MANAGER_GRPC_URL=localhost:7001 +MANAGER_GRPC_TIMEOUT=60s + +# QEMU Configuration +MANAGER_QEMU_MEMORY_SIZE=25G +MANAGER_QEMU_MEMORY_SLOTS=5 +MANAGER_QEMU_MAX_MEMORY=30G +MANAGER_QEMU_OVMF_CODE_IF=pflash +MANAGER_QEMU_OVMF_CODE_FORMAT=raw +MANAGER_QEMU_OVMF_CODE_UNIT=0 +MANAGER_QEMU_OVMF_CODE_FILE=/usr/share/OVMF/x64/OVMF_CODE.fd +MANAGER_QEMU_OVMF_CODE_READONLY=on +MANAGER_QEMU_OVMF_VARS_IF=pflash +MANAGER_QEMU_OVMF_VARS_FORMAT=raw +MANAGER_QEMU_OVMF_VARS_UNIT=1 +MANAGER_QEMU_OVMF_VARS_FILE=/usr/share/OVMF/x64/OVMF_VARS.fd +MANAGER_QEMU_NETDEV_ID=vmnic +MANAGER_QEMU_HOST_FWD_AGENT=7020 +MANAGER_QEMU_GUEST_FWD_AGENT=7002 +MANAGER_QEMU_VIRTIO_NET_PCI_DISABLE_LEGACY=on +MANAGER_QEMU_VIRTIO_NET_PCI_IOMMU_PLATFORM=true +MANAGER_QEMU_VIRTIO_NET_PCI_ADDR=0x2 +MANAGER_QEMU_VIRTIO_NET_PCI_ROMFILE= +MANAGER_QEMU_DISK_IMG_KERNEL_FILE=/home/sammyk/Documents/cocos-ai/cmd/manager/img/bzImage +MANAGER_QEMU_DISK_IMG_ROOTFS_FILE=/home/sammyk/Documents/cocos-ai/cmd/manager/img/rootfs.cpio.gz +MANAGER_QEMU_SEV_ID=sev0 +MANAGER_QEMU_SEV_CBITPOS=51 +MANAGER_QEMU_SEV_REDUCED_PHYS_BITS=1 +MANAGER_QEMU_HOST_DATA= +MANAGER_QEMU_VSOCK_ID=vhost-vsock-pci0 +MANAGER_QEMU_VSOCK_GUEST_CID=3 +MANAGER_QEMU_VSOCK_VNC=0 +MANAGER_QEMU_BIN_PATH=qemu-system-x86_64 +MANAGER_QEMU_USE_SUDO=true +MANAGER_QEMU_ENABLE_SEV=false +MANAGER_QEMU_ENABLE_SEV_SNP=false +MANAGER_QEMU_ENABLE_KVM=true +MANAGER_QEMU_MACHINE=q35 +MANAGER_QEMU_CPU=EPYC +MANAGER_QEMU_SMP_COUNT=4 +MANAGER_QEMU_SMP_MAXCPUS=16 +MANAGER_QEMU_MEM_ID=ram1 +MANAGER_QEMU_KERNEL_HASH=false +MANAGER_QEMU_NO_GRAPHIC=true +MANAGER_QEMU_MONITOR=pty +MANAGER_QEMU_HOST_FWD_RANGE=6100-6200 diff --git a/init/systemd/cocos-manager.service b/init/systemd/cocos-manager.service new file mode 100644 index 000000000..913a7349b --- /dev/null +++ b/init/systemd/cocos-manager.service @@ -0,0 +1,12 @@ +[Unit] +Description=Cocos Manager +After=network.target + +[Service] +ExecStart=cocos-manager +Restart=on-failure +RestartSec=5s +EnvironmentFile=/etc/cocos/cocos-manager.env + +[Install] +WantedBy=multi-user.target From 742bba5f00e4fb1c075677b04d102079717e22d8 Mon Sep 17 00:00:00 2001 From: b1ackd0t <28790446+rodneyosodo@users.noreply.github.com> Date: Thu, 29 Aug 2024 23:32:31 +0300 Subject: [PATCH 04/83] NOISSUE - Add Dockerfile For IRIS Example (#220) * feat(Docker): Add Dockerfile for testing Add Dockerfile for testing linear regression algorithm Signed-off-by: Rodney Osodo * fix(docs): Update docker linear regression example Resolves https://github.com/ultravioletrs/cocos/pull/220#discussion_r1732974631 --------- Signed-off-by: Rodney Osodo --- .gitignore | 1 + test/manual/algo/Dockerfile | 15 +++++++++ test/manual/algo/README.md | 67 +++++++++++++++++++++++-------------- 3 files changed, 57 insertions(+), 26 deletions(-) create mode 100644 test/manual/algo/Dockerfile diff --git a/.gitignore b/.gitignore index 4c481b13d..f2818ed77 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ cmd/manager/img dist/ results.zip *.spec +*.tar diff --git a/test/manual/algo/Dockerfile b/test/manual/algo/Dockerfile new file mode 100644 index 000000000..c46f74a95 --- /dev/null +++ b/test/manual/algo/Dockerfile @@ -0,0 +1,15 @@ +FROM python:3.9-slim + +# set the working directory in the container +WORKDIR /cocos +RUN mkdir /cocos/results +RUN mkdir /cocos/datasets + +COPY ./requirements.txt /cocos/requirements.txt +COPY ./lin_reg.py /cocos/lin_reg.py + +# install dependencies +RUN pip install -r requirements.txt + +# command to be run when the docker container is started +CMD ["python3", "/cocos/lin_reg.py"] diff --git a/test/manual/algo/README.md b/test/manual/algo/README.md index e416556f4..8534d4138 100644 --- a/test/manual/algo/README.md +++ b/test/manual/algo/README.md @@ -49,7 +49,17 @@ This command is run from the root directory of the project. This will start the In another window, you can run the following command: ```bash -sudo MANAGER_QEMU_SMP_MAXCPUS=4 MANAGER_GRPC_URL=localhost:7001 MANAGER_LOG_LEVEL=debug MANAGER_QEMU_USE_SUDO=false MANAGER_QEMU_ENABLE_SEV=false MANAGER_QEMU_SEV_CBITPOS=51 MANAGER_QEMU_ENABLE_SEV_SNP=false MANAGER_QEMU_OVMF_CODE_FILE=/usr/share/edk2/x64/OVMF_CODE.fd MANAGER_QEMU_OVMF_VARS_FILE=/usr/share/edk2/x64/OVMF_VARS.fd go run main.go +sudo \ +MANAGER_QEMU_SMP_MAXCPUS=4 \ +MANAGER_GRPC_URL=localhost:7001 \ +MANAGER_LOG_LEVEL=debug \ +MANAGER_QEMU_USE_SUDO=false \ +MANAGER_QEMU_ENABLE_SEV=false \ +MANAGER_QEMU_SEV_CBITPOS=51 \ +MANAGER_QEMU_ENABLE_SEV_SNP=false \ +MANAGER_QEMU_OVMF_CODE_FILE=/usr/share/edk2/x64/OVMF_CODE.fd \ +MANAGER_QEMU_OVMF_VARS_FILE=/usr/share/edk2/x64/OVMF_VARS.fd \ +go run main.go ``` This command is run from the [manager main directory](../../../cmd/manager/). This will start the manager. Make sure you have already built the [qemu image](../../../hal/linux/README.md). @@ -99,6 +109,7 @@ For addition example, you can use the following command: ## Docker Example Here we will use the docker with the linear regression example (`lin_reg.py`). Throughout the example, we assume that our current working directory is the directory in which the `cocos` repository is cloned. For example: + ```bash # ls cocos @@ -106,36 +117,34 @@ cocos The docker image must have a `cocos` directory containing the `datasets` and `results` directories. The Agent will run this image inside the SVM and will mount the datasets and results onto the `/cocos/datasets` and `/cocos/results` directories inside the image. The docker image must also contain the command that will be run when the docker container is run. -The first step is to create a docker file. Use your favorite editor to create a file named `Dockerfile` in the current working directory and write in it the following code: +Run the build command and then save the docker image as a `tar` file. ```bash -FROM python:3.9-slim - -# set the working directory in the container -WORKDIR /cocos -RUN mkdir /cocos/results -RUN mkdir /cocos/datasets - -COPY ./cocos/test/manual/algo/requirements.txt /cocos/requirements.txt -COPY ./cocos/test/manual/algo/lin_reg.py /cocos/lin_reg.py - -# install dependencies -RUN pip install -r requirements.txt - -# command to be run when the docker container is started -CMD ["python3", "/cocos/lin_reg.py"] +cd test/manual/algo/ +docker build -t linreg . +docker save linreg > linreg.tar ``` -Next, run the build command and then save the docker image as a `tar` file. +To run the examples in the secure VM (SVM) by the Agent, you can use the following command in cocos root directory `/cocos`: + ```bash -docker build -t linreg . -docker save linreg > linreg.tar +go run ./test/computations/main.go ./test/manual/algo/linreg.tar public.pem false ./test/manual/data/iris.csv ``` -In another window, you can run the following command: +In another window, you can run the following command in the `cmd/manager` directory: ```bash -sudo MANAGER_QEMU_SMP_MAXCPUS=4 MANAGER_GRPC_URL=localhost:7001 MANAGER_LOG_LEVEL=debug MANAGER_QEMU_USE_SUDO=false MANAGER_QEMU_ENABLE_SEV=false MANAGER_QEMU_SEV_CBITPOS=51 MANAGER_QEMU_ENABLE_SEV_SNP=false MANAGER_QEMU_OVMF_CODE_FILE=/usr/share/edk2/x64/OVMF_CODE.fd MANAGER_QEMU_OVMF_VARS_FILE=/usr/share/edk2/x64/OVMF_VARS.fd go run main.go +sudo \ +MANAGER_QEMU_SMP_MAXCPUS=4 \ +MANAGER_GRPC_URL=localhost:7001 \ +MANAGER_LOG_LEVEL=debug \ +MANAGER_QEMU_USE_SUDO=false \ +MANAGER_QEMU_ENABLE_SEV=false \ +MANAGER_QEMU_SEV_CBITPOS=51 \ +MANAGER_QEMU_ENABLE_SEV_SNP=false \ +MANAGER_QEMU_OVMF_CODE_FILE=/usr/share/edk2/x64/OVMF_CODE.fd \ +MANAGER_QEMU_OVMF_VARS_FILE=/usr/share/edk2/x64/OVMF_VARS.fd \ +go run main.go ``` This command is run from the [manager main directory](../../../cmd/manager/). This will start the manager. Make sure you have already built the [qemu image](../../../hal/linux/README.md). @@ -143,7 +152,7 @@ This command is run from the [manager main directory](../../../cmd/manager/). Th In another window, specify what kind of algorithm you want the Agent to run (docker): ```bash -./cocos/build/cocos-cli algo ./linreg.tar ./cocos/private.pem -a docker +./cocos/build/cocos-cli algo ./test/manual/algo/linreg.tar ./cocos/private.pem -a docker ``` make sure you have built the cocos-cli. This will upload the docker image. @@ -151,21 +160,27 @@ make sure you have built the cocos-cli. This will upload the docker image. Next we need to upload the dataset ```bash -./cocos/build/cocos-cli data ./cocos/test/manual/data/iris.csv ./cocos/private.pem +./cocos/build/cocos-cli data ./test/manual/data/iris.csv ./cocos/private.pem ``` After some time when the results are ready, you can run the following command to get the results: ```bash -./cocos/build/cocos-cli results ./cocos/private.pem +./cocos/build/cocos-cli results ./private.pem ``` This will return the results of the algorithm. +Unzip the results + +```bash +unzip results.zip -d results +``` + To make inference on the results, you can use the following command: ```bash -python3 ./cocos/test/manual/algo/lin_reg.py predict result.zip ./cocos/test/manual/data +python3 ./test/manual/algo/lin_reg.py predict results/model.bin test/manual/data/ ``` ## Wasm Example From dc349e1f1faee19e46989930cb41afb4ee1d19ff Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:29:39 +0300 Subject: [PATCH 05/83] NOISSUE - V-Sock reconnect for agent (#215) * vsock reconnect Signed-off-by: Sammy Oina * use backoff Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- cmd/agent/main.go | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 12d52a63b..22621491a 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -9,8 +9,10 @@ import ( "io" "log" "log/slog" + "time" "github.com/absmach/magistrala/pkg/prometheus" + "github.com/cenkalti/backoff/v4" "github.com/google/go-sev-guest/client" "github.com/mdlayher/vsock" "github.com/ultravioletrs/cocos/agent" @@ -32,6 +34,7 @@ import ( const ( svcName = "agent" defSvcGRPCPort = "7002" + retryInterval = 5 * time.Second ) func main() { @@ -43,7 +46,7 @@ func main() { log.Fatalf("failed to read agent configuration from vsock %s", err.Error()) } - conn, err := vsock.Dial(vsock.Host, manager.ManagerVsockPort, nil) + conn, err := dialVsock() if err != nil { log.Fatal(err) } @@ -95,6 +98,22 @@ func main() { gs := grpcserver.New(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger, qp, authSvc) + g.Go(func() error { + for { + if _, err := io.Copy(io.Discard, conn); err != nil { + log.Printf("vsock connection lost: %v, reconnecting...", err) + conn.Close() + conn, err = dialVsock() + if err != nil { + log.Fatal("failed to reconnect: ", err) + } + handler = agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}) + logger = slog.New(handler) + } + time.Sleep(retryInterval) + } + }) + g.Go(func() error { return gs.Start() }) @@ -158,3 +177,23 @@ func readConfig() (agent.Computation, error) { } return ac, nil } + +func dialVsock() (*vsock.Conn, error) { + var conn *vsock.Conn + var err error + + err = backoff.Retry(func() error { + conn, err = vsock.Dial(vsock.Host, manager.ManagerVsockPort, nil) + if err == nil { + log.Println("vsock connection established") + return nil + } + log.Printf("vsock connection failed, retrying in %s... Error: %v", retryInterval, err) + return err + }, backoff.NewExponentialBackOff()) + if err != nil { + return nil, err + } + + return conn, nil +} From 5383f4465b0d39ac2d590f9325ac8160db4455df Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:30:51 +0300 Subject: [PATCH 06/83] NOISSUE - Exit on network failures only (#227) Signed-off-by: Sammy Oina --- cmd/manager/main.go | 2 +- manager/api/grpc/client.go | 182 +++++++++++++++++++++++-------------- 2 files changed, 117 insertions(+), 67 deletions(-) diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 64e19cc8a..616919b0d 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -109,7 +109,7 @@ func main() { return } - mc := managerapi.NewClient(pc, svc, eventsChan) + mc := managerapi.NewClient(pc, svc, eventsChan, logger) g.Go(func() error { return mc.Process(ctx, cancel) diff --git a/manager/api/grpc/client.go b/manager/api/grpc/client.go index 9e1c8306e..6ac481802 100644 --- a/manager/api/grpc/client.go +++ b/manager/api/grpc/client.go @@ -5,6 +5,7 @@ package grpc import ( "bytes" "context" + "log/slog" "github.com/absmach/magistrala/pkg/errors" "github.com/ultravioletrs/cocos/manager" @@ -22,14 +23,16 @@ type ManagerClient struct { stream pkgmanager.ManagerService_ProcessClient svc manager.Service responses chan *pkgmanager.ClientStreamMessage + logger *slog.Logger } // NewClient returns new gRPC client instance. -func NewClient(stream pkgmanager.ManagerService_ProcessClient, svc manager.Service, responses chan *pkgmanager.ClientStreamMessage) ManagerClient { +func NewClient(stream pkgmanager.ManagerService_ProcessClient, svc manager.Service, responses chan *pkgmanager.ClientStreamMessage, logger *slog.Logger) ManagerClient { return ManagerClient{ stream: stream, svc: svc, responses: responses, + logger: logger, } } @@ -37,79 +40,126 @@ func (client ManagerClient) Process(ctx context.Context, cancel context.CancelFu eg, ctx := errgroup.WithContext(ctx) eg.Go(func() error { - var runReqBuffer bytes.Buffer - for { + return client.handleIncomingMessages(ctx) + }) + + eg.Go(func() error { + return client.handleOutgoingMessages(ctx) + }) + + return eg.Wait() +} + +func (client ManagerClient) handleIncomingMessages(ctx context.Context) error { + var runReqBuffer bytes.Buffer + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: req, err := client.stream.Recv() if err != nil { return err } - - switch mes := req.Message.(type) { - case *pkgmanager.ServerStreamMessage_RunReqChunks: - if len(mes.RunReqChunks.Data) == 0 { - var runReq pkgmanager.ComputationRunReq - if err = proto.Unmarshal(runReqBuffer.Bytes(), &runReq); err != nil { - return errors.Wrap(err, errCorruptedManifest) - } - port, err := client.svc.Run(ctx, &runReq) - if err != nil { - return err - } - runRes := &pkgmanager.ClientStreamMessage_RunRes{ - RunRes: &pkgmanager.RunResponse{ - AgentPort: port, - ComputationId: runReq.Id, - }, - } - if err := client.stream.Send(&pkgmanager.ClientStreamMessage{Message: runRes}); err != nil { - return err - } - } - if _, err := runReqBuffer.Write(mes.RunReqChunks.Data); err != nil { - return err - } - - case *pkgmanager.ServerStreamMessage_TerminateReq: - cancel() - return errors.Wrap(errTerminationFromServer, errors.New(mes.TerminateReq.Message)) - case *pkgmanager.ServerStreamMessage_StopComputation: - msg := &pkgmanager.ClientStreamMessage_StopComputationRes{StopComputationRes: &pkgmanager.StopComputationResponse{ - ComputationId: mes.StopComputation.ComputationId, - }} - if err := client.svc.Stop(ctx, mes.StopComputation.ComputationId); err != nil { - msg.StopComputationRes.Message = err.Error() - } - if err := client.stream.Send(&pkgmanager.ClientStreamMessage{Message: msg}); err != nil { - return err - } - case *pkgmanager.ServerStreamMessage_BackendInfoReq: - res, err := client.svc.FetchBackendInfo() - if err != nil { - return err - } - info := &pkgmanager.ClientStreamMessage_BackendInfo{BackendInfo: &pkgmanager.BackendInfo{ - Info: res, - Id: mes.BackendInfoReq.Id, - }} - if err := client.stream.Send(&pkgmanager.ClientStreamMessage{Message: info}); err != nil { - return err - } + if err := client.processIncomingMessage(ctx, req, &runReqBuffer); err != nil { + return err } } - }) + } +} - eg.Go(func() error { - for { - select { - case <-ctx.Done(): - return nil - case mes := <-client.responses: - if err := client.stream.Send(mes); err != nil { - return err - } +func (client ManagerClient) processIncomingMessage(ctx context.Context, req *pkgmanager.ServerStreamMessage, runReqBuffer *bytes.Buffer) error { + switch mes := req.Message.(type) { + case *pkgmanager.ServerStreamMessage_RunReqChunks: + return client.handleRunReqChunks(ctx, mes, runReqBuffer) + case *pkgmanager.ServerStreamMessage_TerminateReq: + return client.handleTerminateReq(mes) + case *pkgmanager.ServerStreamMessage_StopComputation: + go client.handleStopComputation(ctx, mes) + case *pkgmanager.ServerStreamMessage_BackendInfoReq: + go client.handleBackendInfoReq(ctx, mes) + default: + return errors.New("unknown message type") + } + return nil +} + +func (client ManagerClient) handleRunReqChunks(ctx context.Context, mes *pkgmanager.ServerStreamMessage_RunReqChunks, runReqBuffer *bytes.Buffer) error { + if len(mes.RunReqChunks.Data) == 0 { + var runReq pkgmanager.ComputationRunReq + if err := proto.Unmarshal(runReqBuffer.Bytes(), &runReq); err != nil { + return errors.Wrap(err, errCorruptedManifest) + } + go client.executeRun(ctx, &runReq) + } + _, err := runReqBuffer.Write(mes.RunReqChunks.Data) + return err +} + +func (client ManagerClient) executeRun(ctx context.Context, runReq *pkgmanager.ComputationRunReq) { + port, err := client.svc.Run(ctx, runReq) + if err != nil { + client.logger.Warn(err.Error()) + return + } + runRes := &pkgmanager.ClientStreamMessage_RunRes{ + RunRes: &pkgmanager.RunResponse{ + AgentPort: port, + ComputationId: runReq.Id, + }, + } + client.sendMessage(&pkgmanager.ClientStreamMessage{Message: runRes}) +} + +func (client ManagerClient) handleTerminateReq(mes *pkgmanager.ServerStreamMessage_TerminateReq) error { + return errors.Wrap(errTerminationFromServer, errors.New(mes.TerminateReq.Message)) +} + +func (client ManagerClient) handleStopComputation(ctx context.Context, mes *pkgmanager.ServerStreamMessage_StopComputation) { + msg := &pkgmanager.ClientStreamMessage_StopComputationRes{ + StopComputationRes: &pkgmanager.StopComputationResponse{ + ComputationId: mes.StopComputation.ComputationId, + }, + } + if err := client.svc.Stop(ctx, mes.StopComputation.ComputationId); err != nil { + msg.StopComputationRes.Message = err.Error() + } + client.sendMessage(&pkgmanager.ClientStreamMessage{Message: msg}) +} + +func (client ManagerClient) handleBackendInfoReq(ctx context.Context, mes *pkgmanager.ServerStreamMessage_BackendInfoReq) { + res, err := client.svc.FetchBackendInfo() + if err != nil { + client.logger.Warn(err.Error()) + return + } + info := &pkgmanager.ClientStreamMessage_BackendInfo{ + BackendInfo: &pkgmanager.BackendInfo{ + Info: res, + Id: mes.BackendInfoReq.Id, + }, + } + client.sendMessage(&pkgmanager.ClientStreamMessage{Message: info}) +} + +func (client ManagerClient) handleOutgoingMessages(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case mes := <-client.responses: + if err := client.stream.Send(mes); err != nil { + return err } } - }) + } +} - return eg.Wait() +func (client ManagerClient) sendMessage(mes *pkgmanager.ClientStreamMessage) { + select { + case client.responses <- mes: + return + default: + client.logger.Warn("failed to send message to client") + } } From e572793295a0429480d9a53b0c6041fe28d2bdb0 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:32:23 +0300 Subject: [PATCH 07/83] exit with error code (#225) Signed-off-by: Sammy Oina --- cmd/agent/main.go | 8 ++++++++ cmd/manager/main.go | 9 +++++++++ 2 files changed, 17 insertions(+) diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 22621491a..cf439fa26 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -11,6 +11,7 @@ import ( "log/slog" "time" + mglog "github.com/absmach/magistrala/logger" "github.com/absmach/magistrala/pkg/prometheus" "github.com/cenkalti/backoff/v4" "github.com/google/go-sev-guest/client" @@ -52,9 +53,13 @@ func main() { } defer conn.Close() + var exitCode int + defer mglog.ExitWithError(&exitCode) + var level slog.Level if err := level.UnmarshalText([]byte(cfg.AgentConfig.LogLevel)); err != nil { log.Println(err) + exitCode = 1 return } handler := agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}, cfg.ID) @@ -63,6 +68,7 @@ func main() { eventSvc, err := events.New(svcName, cfg.ID, manager.ManagerVsockPort) if err != nil { logger.Error(fmt.Sprintf("failed to create events service %s", err.Error())) + exitCode = 1 return } defer eventSvc.Close() @@ -70,6 +76,7 @@ func main() { qp, err := quoteprovider.GetQuoteProvider() if err != nil { logger.Error(fmt.Sprintf("failed to create quote provider %s", err.Error())) + exitCode = 1 return } @@ -93,6 +100,7 @@ func main() { authSvc, err := auth.New(cfg) if err != nil { logger.Error(fmt.Sprintf("failed to create auth service %s", err.Error())) + exitCode = 1 return } diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 616919b0d..47413128e 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -57,9 +57,13 @@ func main() { log.Fatalf(err.Error()) } + var exitCode int + defer mglog.ExitWithError(&exitCode) + if cfg.InstanceID == "" { if cfg.InstanceID, err = uuid.New().ID(); err != nil { logger.Error(fmt.Sprintf("Failed to generate instance ID: %s", err)) + exitCode = 1 return } } @@ -78,6 +82,7 @@ func main() { qemuCfg := qemu.Config{} if err := env.ParseWithOptions(&qemuCfg, env.Options{Prefix: envPrefixQemu}); err != nil { logger.Error(fmt.Sprintf("failed to load QEMU configuration: %s", err)) + exitCode = 1 return } args := qemuCfg.ConstructQemuArgs() @@ -86,12 +91,14 @@ func main() { managerGRPCConfig := grpc.Config{} if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixGRPC}); err != nil { logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)) + exitCode = 1 return } managerGRPCClient, managerClient, err := managergrpc.NewManagerClient(managerGRPCConfig) if err != nil { logger.Error(err.Error()) + exitCode = 1 return } defer managerGRPCClient.Close() @@ -99,6 +106,7 @@ func main() { pc, err := managerClient.Process(ctx) if err != nil { logger.Error(err.Error()) + exitCode = 1 return } @@ -106,6 +114,7 @@ func main() { svc, err := newService(logger, tracer, qemuCfg, eventsChan, cfg.BackendMeasurementBinary) if err != nil { logger.Error(err.Error()) + exitCode = 1 return } From 9ca045b06a329654f70e65f4a0c69b91451883ce Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Fri, 30 Aug 2024 19:08:11 +0300 Subject: [PATCH 08/83] COCOS-214 - Improve manager resiliance by tracking vms on restart (#219) * track hanging vm processes Signed-off-by: SammyOina * fix lint Signed-off-by: SammyOina * fix run test Signed-off-by: SammyOina * fix stop computation Signed-off-by: SammyOina * shutdown gracefully Signed-off-by: SammyOina * check if process still exists Signed-off-by: Sammy Oina * fix lint Signed-off-by: Sammy Oina * use const Signed-off-by: Sammy Oina --------- Signed-off-by: SammyOina Signed-off-by: Sammy Oina --- cmd/agent/main.go | 2 +- cmd/manager/main.go | 17 ++++++ manager/qemu/config.go | 2 + manager/qemu/mocks/persistence.go | 96 +++++++++++++++++++++++++++++++ manager/qemu/persistence.go | 89 ++++++++++++++++++++++++++++ manager/qemu/vm.go | 27 ++++++++- manager/service.go | 89 +++++++++++++++++++++++++++- manager/service_test.go | 26 ++++++--- manager/vm/mocks/vm.go | 36 ++++++++++++ manager/vm/vm.go | 2 + 10 files changed, 372 insertions(+), 14 deletions(-) create mode 100644 manager/qemu/mocks/persistence.go create mode 100644 manager/qemu/persistence.go diff --git a/cmd/agent/main.go b/cmd/agent/main.go index cf439fa26..7076163e1 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -115,7 +115,7 @@ func main() { if err != nil { log.Fatal("failed to reconnect: ", err) } - handler = agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}) + handler = agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}, cfg.ID) logger = slog.New(handler) } time.Sleep(retryInterval) diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 47413128e..d258fa7c0 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -10,7 +10,9 @@ import ( "log/slog" "net/url" "os" + "os/signal" "strings" + "syscall" mglog "github.com/absmach/magistrala/logger" "github.com/absmach/magistrala/pkg/jaeger" @@ -120,6 +122,21 @@ func main() { mc := managerapi.NewClient(pc, svc, eventsChan, logger) + g.Go(func() error { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(ch) + + select { + case <-ch: + logger.Info("Received signal, shutting down...") + cancel() + return nil + case <-ctx.Done(): + return ctx.Err() + } + }) + g.Go(func() error { return mc.Process(ctx, cancel) }) diff --git a/manager/qemu/config.go b/manager/qemu/config.go index f77b8f801..9dca5d5b7 100644 --- a/manager/qemu/config.go +++ b/manager/qemu/config.go @@ -7,6 +7,8 @@ import ( "strconv" ) +const BaseGuestCID = 3 + type MemoryConfig struct { Size string `env:"MEMORY_SIZE" envDefault:"2048M"` Slots int `env:"MEMORY_SLOTS" envDefault:"5"` diff --git a/manager/qemu/mocks/persistence.go b/manager/qemu/mocks/persistence.go new file mode 100644 index 000000000..812ba4ef4 --- /dev/null +++ b/manager/qemu/mocks/persistence.go @@ -0,0 +1,96 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + qemu "github.com/ultravioletrs/cocos/manager/qemu" +) + +// Persistence is an autogenerated mock type for the Persistence type +type Persistence struct { + mock.Mock +} + +// DeleteVM provides a mock function with given fields: id +func (_m *Persistence) DeleteVM(id string) error { + ret := _m.Called(id) + + if len(ret) == 0 { + panic("no return value specified for DeleteVM") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// LoadVMs provides a mock function with given fields: +func (_m *Persistence) LoadVMs() ([]qemu.VMState, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for LoadVMs") + } + + var r0 []qemu.VMState + var r1 error + if rf, ok := ret.Get(0).(func() ([]qemu.VMState, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() []qemu.VMState); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]qemu.VMState) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SaveVM provides a mock function with given fields: state +func (_m *Persistence) SaveVM(state qemu.VMState) error { + ret := _m.Called(state) + + if len(ret) == 0 { + panic("no return value specified for SaveVM") + } + + var r0 error + if rf, ok := ret.Get(0).(func(qemu.VMState) error); ok { + r0 = rf(state) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewPersistence creates a new instance of Persistence. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewPersistence(t interface { + mock.TestingT + Cleanup(func()) +}) *Persistence { + mock := &Persistence{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/manager/qemu/persistence.go b/manager/qemu/persistence.go new file mode 100644 index 000000000..097ef4439 --- /dev/null +++ b/manager/qemu/persistence.go @@ -0,0 +1,89 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package qemu + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" +) + +const jsonExt = ".json" + +type VMState struct { + ID string + Config Config + PID int +} + +type FilePersistence struct { + dir string + lock sync.Mutex +} + +// Persistence is an interface for saving and loading VM states. +// +//go:generate mockery --name Persistence --output=./mocks --filename persistence.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0" +type Persistence interface { + SaveVM(state VMState) error + LoadVMs() ([]VMState, error) + DeleteVM(id string) error +} + +func NewFilePersistence(dir string) (Persistence, error) { + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, err + } + return &FilePersistence{dir: dir}, nil +} + +func (fp *FilePersistence) SaveVM(state VMState) error { + fp.lock.Lock() + defer fp.lock.Unlock() + + data, err := json.Marshal(state) + if err != nil { + return err + } + + return os.WriteFile(filepath.Join(fp.dir, state.ID+jsonExt), data, 0o644) +} + +func (fp *FilePersistence) LoadVMs() ([]VMState, error) { + fp.lock.Lock() + defer fp.lock.Unlock() + + files, err := os.ReadDir(fp.dir) + if err != nil { + return nil, err + } + + var states []VMState + for _, file := range files { + if filepath.Ext(file.Name()) != jsonExt { + continue + } + + data, err := os.ReadFile(filepath.Join(fp.dir, file.Name())) + if err != nil { + return nil, err + } + + var state VMState + if err := json.Unmarshal(data, &state); err != nil { + return nil, err + } + + states = append(states, state) + } + + return states, nil +} + +func (fp *FilePersistence) DeleteVM(id string) error { + fp.lock.Lock() + defer fp.lock.Unlock() + + return os.Remove(filepath.Join(fp.dir, id+jsonExt)) +} diff --git a/manager/qemu/vm.go b/manager/qemu/vm.go index f0f94bb54..9b7328929 100644 --- a/manager/qemu/vm.go +++ b/manager/qemu/vm.go @@ -4,6 +4,7 @@ package qemu import ( "fmt" + "os" "os/exec" "github.com/gofrs/uuid" @@ -38,9 +39,9 @@ func (v *qemuVM) Start() error { if err != nil { return err } - qemuCfg := v.config - qemuCfg.NetDevConfig.ID = fmt.Sprintf("%s-%s", qemuCfg.NetDevConfig.ID, id) - qemuCfg.SevConfig.ID = fmt.Sprintf("%s-%s", qemuCfg.SevConfig.ID, id) + + v.config.NetDevConfig.ID = fmt.Sprintf("%s-%s", v.config.NetDevConfig.ID, id) + v.config.SevConfig.ID = fmt.Sprintf("%s-%s", v.config.SevConfig.ID, id) exe, args, err := v.executableAndArgs() if err != nil { @@ -58,6 +59,26 @@ func (v *qemuVM) Stop() error { return v.cmd.Process.Kill() } +func (v *qemuVM) SetProcess(pid int) error { + process, err := os.FindProcess(pid) + if err != nil { + return err + } + + exe, args, err := v.executableAndArgs() + if err != nil { + return err + } + + v.cmd = exec.Command(exe, args...) + v.cmd.Process = process + return nil +} + +func (v *qemuVM) GetProcess() int { + return v.cmd.Process.Pid +} + func (v *qemuVM) executableAndArgs() (string, []string, error) { exe, err := exec.LookPath(v.config.QemuBinPath) if err != nil { diff --git a/manager/service.go b/manager/service.go index 09ab27789..000a2bd48 100644 --- a/manager/service.go +++ b/manager/service.go @@ -9,9 +9,11 @@ import ( "fmt" "log/slog" "net" + "os" "regexp" "strconv" "sync" + "syscall" "github.com/absmach/magistrala/pkg/errors" "github.com/cenkalti/backoff/v4" @@ -23,7 +25,10 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) -const hashLength = 32 +const ( + hashLength = 32 + persistenceDir = "/tmp/cocos" +) var ( // ErrMalformedEntity indicates malformed entity specification (e.g. @@ -68,6 +73,7 @@ type managerService struct { vmFactory vm.Provider portRangeMin int portRangeMax int + persistence qemu.Persistence } var _ Service = (*managerService)(nil) @@ -78,6 +84,12 @@ func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger, if err != nil { return nil, err } + + persistence, err := qemu.NewFilePersistence(persistenceDir) + if err != nil { + return nil, err + } + ms := &managerService{ qemuCfg: cfg, logger: logger, @@ -87,7 +99,13 @@ func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger, backendMeasurementBinaryPath: backendMeasurementBinPath, portRangeMin: start, portRangeMax: end, + persistence: persistence, } + + if err := ms.restoreVMs(); err != nil { + return nil, err + } + return ms, nil } @@ -127,6 +145,7 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) return "", errors.Wrap(ErrFailedToAllocatePort, err) } ms.qemuCfg.HostFwdAgent = agentPort + ms.qemuCfg.VSockConfig.GuestCID = qemu.BaseGuestCID + len(ms.vms) ch, err := computationHash(ac) if err != nil { @@ -145,13 +164,24 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) } ms.vms[c.Id] = cvm + pid := cvm.GetProcess() + + state := qemu.VMState{ + ID: c.Id, + Config: ms.qemuCfg, + PID: pid, + } + if err := ms.persistence.SaveVM(state); err != nil { + ms.logger.Error("Failed to persist VM state", "error", err) + } + err = backoff.Retry(func() error { return cvm.SendAgentConfig(ac) }, backoff.NewExponentialBackOff()) if err != nil { return "", err } - ms.qemuCfg.VSockConfig.GuestCID++ + ms.qemuCfg.VSockConfig.Vnc++ ms.publishEvent("vm-provision", c.Id, "complete", json.RawMessage{}) @@ -169,6 +199,11 @@ func (ms *managerService) Stop(ctx context.Context, computationID string) error return err } delete(ms.vms, computationID) + + if err := ms.persistence.DeleteVM(computationID); err != nil { + ms.logger.Error("Failed to delete persisted VM state", "error", err) + } + defer ms.publishEvent("stop-computation", computationID, "complete", json.RawMessage{}) return nil } @@ -264,3 +299,53 @@ func decodeRange(input string) (int, int, error) { return start, end, nil } + +func (ms *managerService) restoreVMs() error { + states, err := ms.persistence.LoadVMs() + if err != nil { + return err + } + + for _, state := range states { + exists, err := processExists(state.PID) + if err != nil { + ms.logger.Warn("Failed to check process existence", "computation", state.ID, "pid", state.PID, "error", err) + continue + } + + if !exists { + if err := ms.persistence.DeleteVM(state.ID); err != nil { + ms.logger.Error("Failed to delete persisted VM state", "computation", state.ID, "error", err) + } + ms.logger.Info("Deleted persisted state for non-existent process", "computation", state.ID, "pid", state.PID) + continue + } + + cvm := ms.vmFactory(state.Config, ms.eventsChan, state.ID) + + if err = cvm.SetProcess(state.PID); err != nil { + ms.logger.Warn("Failed to reattach to process", "computation", state.ID, "pid", state.PID, "error", err) + continue + } + + ms.vms[state.ID] = cvm + ms.logger.Info("Successfully restored VM state", "id", state.ID, "computationId", state.ID, "pid", state.PID) + } + + return nil +} + +func processExists(pid int) (bool, error) { + process, err := os.FindProcess(pid) + if err != nil { + return false, err + } + + if err = process.Signal(syscall.Signal(0)); err == nil { + return true, nil + } + if err == syscall.ESRCH { + return false, nil + } + return false, err +} diff --git a/manager/service_test.go b/manager/service_test.go index 85abc60c8..ee1148444 100644 --- a/manager/service_test.go +++ b/manager/service_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/ultravioletrs/cocos/manager/qemu" + persistenceMocks "github.com/ultravioletrs/cocos/manager/qemu/mocks" "github.com/ultravioletrs/cocos/manager/vm" "github.com/ultravioletrs/cocos/manager/vm/mocks" "github.com/ultravioletrs/cocos/pkg/manager" @@ -35,6 +36,7 @@ func TestNew(t *testing.T) { func TestRun(t *testing.T) { vmf := new(mocks.Provider) vmMock := new(mocks.VM) + persistence := new(persistenceMocks.Persistence) vmf.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return(vmMock) tests := []struct { name string @@ -79,6 +81,9 @@ func TestRun(t *testing.T) { } vmMock.On("SendAgentConfig", mock.Anything).Return(nil) + vmMock.On("GetProcess").Return(1234) + + persistence.On("SaveVM", mock.Anything).Return(nil) qemuCfg := qemu.Config{ VSockConfig: qemu.VSockConfig{ @@ -90,11 +95,12 @@ func TestRun(t *testing.T) { eventsChan := make(chan *manager.ClientStreamMessage, 10) ms := &managerService{ - qemuCfg: qemuCfg, - logger: logger, - vms: make(map[string]vm.VM), - eventsChan: eventsChan, - vmFactory: vmf.Execute, + qemuCfg: qemuCfg, + logger: logger, + vms: make(map[string]vm.VM), + eventsChan: eventsChan, + vmFactory: vmf.Execute, + persistence: persistence, } ctx := context.Background() @@ -123,6 +129,7 @@ func TestRun(t *testing.T) { func TestStop(t *testing.T) { vmf := new(mocks.Provider) vmMock := new(mocks.VM) + persistence := new(persistenceMocks.Persistence) vmf.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return(vmMock) tests := []struct { @@ -160,9 +167,10 @@ func TestStop(t *testing.T) { logger := slog.Default() eventsChan := make(chan *manager.ClientStreamMessage, 10) ms := &managerService{ - logger: logger, - vms: make(map[string]vm.VM), - eventsChan: eventsChan, + logger: logger, + vms: make(map[string]vm.VM), + eventsChan: eventsChan, + persistence: persistence, } vmMock := new(mocks.VM) @@ -172,6 +180,8 @@ func TestStop(t *testing.T) { vmMock.On("Stop").Return(assert.AnError).Once() } + persistence.On("DeleteVM", tt.computationID).Return(nil) + if tt.initialVMCount > 0 { ms.vms[tt.computationID] = vmMock } diff --git a/manager/vm/mocks/vm.go b/manager/vm/mocks/vm.go index ce200d994..67e455120 100644 --- a/manager/vm/mocks/vm.go +++ b/manager/vm/mocks/vm.go @@ -15,6 +15,24 @@ type VM struct { mock.Mock } +// GetProcess provides a mock function with given fields: +func (_m *VM) GetProcess() int { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetProcess") + } + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + // SendAgentConfig provides a mock function with given fields: ac func (_m *VM) SendAgentConfig(ac agent.Computation) error { ret := _m.Called(ac) @@ -33,6 +51,24 @@ func (_m *VM) SendAgentConfig(ac agent.Computation) error { return r0 } +// SetProcess provides a mock function with given fields: pid +func (_m *VM) SetProcess(pid int) error { + ret := _m.Called(pid) + + if len(ret) == 0 { + panic("no return value specified for SetProcess") + } + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(pid) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Start provides a mock function with given fields: func (_m *VM) Start() error { ret := _m.Called() diff --git a/manager/vm/vm.go b/manager/vm/vm.go index aa87e3323..856883b1c 100644 --- a/manager/vm/vm.go +++ b/manager/vm/vm.go @@ -14,6 +14,8 @@ type VM interface { Start() error Stop() error SendAgentConfig(ac agent.Computation) error + SetProcess(pid int) error + GetProcess() int } //go:generate mockery --name Provider --output=./mocks --filename provider.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0" From 7ba34b93bc1ca31fdb7afe434fa206ce0deac345 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Tue, 3 Sep 2024 12:29:07 +0300 Subject: [PATCH 09/83] NOISSUE - Streamline message processing to prevent potential message loss (#228) * fix dropping of message response from manager Signed-off-by: Sammy Oina * remove change Signed-off-by: Sammy Oina * simplify Signed-off-by: Sammy Oina * add message send timeout Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- cmd/manager/main.go | 9 +++++---- manager/api/grpc/client.go | 32 ++++++++++++++++++-------------- manager/api/grpc/server.go | 1 + manager/service.go | 24 ++++++++++++------------ 4 files changed, 36 insertions(+), 30 deletions(-) diff --git a/cmd/manager/main.go b/cmd/manager/main.go index d258fa7c0..87c4218ba 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -32,9 +32,10 @@ import ( ) const ( - svcName = "manager" - envPrefixGRPC = "MANAGER_GRPC_" - envPrefixQemu = "MANAGER_QEMU_" + svcName = "manager" + envPrefixGRPC = "MANAGER_GRPC_" + envPrefixQemu = "MANAGER_QEMU_" + clientBufferSize = 100 ) type config struct { @@ -112,7 +113,7 @@ func main() { return } - eventsChan := make(chan *pkgmanager.ClientStreamMessage) + eventsChan := make(chan *pkgmanager.ClientStreamMessage, clientBufferSize) svc, err := newService(logger, tracer, qemuCfg, eventsChan, cfg.BackendMeasurementBinary) if err != nil { logger.Error(err.Error()) diff --git a/manager/api/grpc/client.go b/manager/api/grpc/client.go index 6ac481802..59a6312cf 100644 --- a/manager/api/grpc/client.go +++ b/manager/api/grpc/client.go @@ -6,6 +6,7 @@ import ( "bytes" "context" "log/slog" + "time" "github.com/absmach/magistrala/pkg/errors" "github.com/ultravioletrs/cocos/manager" @@ -17,22 +18,23 @@ import ( var ( errTerminationFromServer = errors.New("server requested client termination") errCorruptedManifest = errors.New("received manifest may be corrupted") + sendTimeout = 5 * time.Second ) type ManagerClient struct { - stream pkgmanager.ManagerService_ProcessClient - svc manager.Service - responses chan *pkgmanager.ClientStreamMessage - logger *slog.Logger + stream pkgmanager.ManagerService_ProcessClient + svc manager.Service + messageQueue chan *pkgmanager.ClientStreamMessage + logger *slog.Logger } // NewClient returns new gRPC client instance. -func NewClient(stream pkgmanager.ManagerService_ProcessClient, svc manager.Service, responses chan *pkgmanager.ClientStreamMessage, logger *slog.Logger) ManagerClient { +func NewClient(stream pkgmanager.ManagerService_ProcessClient, svc manager.Service, messageQueue chan *pkgmanager.ClientStreamMessage, logger *slog.Logger) ManagerClient { return ManagerClient{ - stream: stream, - svc: svc, - responses: responses, - logger: logger, + stream: stream, + svc: svc, + messageQueue: messageQueue, + logger: logger, } } @@ -147,7 +149,7 @@ func (client ManagerClient) handleOutgoingMessages(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() - case mes := <-client.responses: + case mes := <-client.messageQueue: if err := client.stream.Send(mes); err != nil { return err } @@ -156,10 +158,12 @@ func (client ManagerClient) handleOutgoingMessages(ctx context.Context) error { } func (client ManagerClient) sendMessage(mes *pkgmanager.ClientStreamMessage) { + ctx, cancel := context.WithTimeout(context.Background(), sendTimeout) + defer cancel() + select { - case client.responses <- mes: - return - default: - client.logger.Warn("failed to send message to client") + case client.messageQueue <- mes: + case <-ctx.Done(): + client.logger.Warn("Failed to send message: timeout exceeded") } } diff --git a/manager/api/grpc/server.go b/manager/api/grpc/server.go index a5a7abb98..d2612350c 100644 --- a/manager/api/grpc/server.go +++ b/manager/api/grpc/server.go @@ -41,6 +41,7 @@ func NewServer(incoming chan *manager.ClientStreamMessage, svc Service) manager. func (s *grpcServer) Process(stream manager.ManagerService_ProcessServer) error { runReqChan := make(chan *manager.ServerStreamMessage) + defer close(runReqChan) client, ok := peer.FromContext(stream.Context()) if ok { go s.svc.Run(client.Addr.String(), runReqChan, client.AuthInfo) diff --git a/manager/service.go b/manager/service.go index 000a2bd48..d1f3b1db6 100644 --- a/manager/service.go +++ b/manager/service.go @@ -65,6 +65,7 @@ type Service interface { } type managerService struct { + mu sync.Mutex qemuCfg qemu.Config backendMeasurementBinaryPath string logger *slog.Logger @@ -162,7 +163,9 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) ms.publishEvent("vm-provision", c.Id, "failed", json.RawMessage{}) return "", err } + ms.mu.Lock() ms.vms[c.Id] = cvm + ms.mu.Unlock() pid := cvm.GetProcess() @@ -189,6 +192,8 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) } func (ms *managerService) Stop(ctx context.Context, computationID string) error { + ms.mu.Lock() + defer ms.mu.Unlock() cvm, ok := ms.vms[computationID] if !ok { defer ms.publishEvent("stop-computation", computationID, "failed", json.RawMessage{}) @@ -307,13 +312,7 @@ func (ms *managerService) restoreVMs() error { } for _, state := range states { - exists, err := processExists(state.PID) - if err != nil { - ms.logger.Warn("Failed to check process existence", "computation", state.ID, "pid", state.PID, "error", err) - continue - } - - if !exists { + if !ms.processExists(state.PID) { if err := ms.persistence.DeleteVM(state.ID); err != nil { ms.logger.Error("Failed to delete persisted VM state", "computation", state.ID, "error", err) } @@ -335,17 +334,18 @@ func (ms *managerService) restoreVMs() error { return nil } -func processExists(pid int) (bool, error) { +func (ms *managerService) processExists(pid int) bool { process, err := os.FindProcess(pid) if err != nil { - return false, err + ms.logger.Warn("Failed to find process", "pid", pid, "error", err) + return false } if err = process.Signal(syscall.Signal(0)); err == nil { - return true, nil + return true } if err == syscall.ESRCH { - return false, nil + return false } - return false, err + return false } From 00980639d5d4b8d3ca169df263e00f760b34a82a Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Thu, 5 Sep 2024 13:27:06 +0300 Subject: [PATCH 10/83] NOISSUE - Remove run channel (#231) Signed-off-by: Sammy Oina --- manager/api/grpc/server.go | 92 +++++++++++++++++++++----------------- manager/setup_test.go | 11 +++-- test/computations/main.go | 9 ++-- 3 files changed, 64 insertions(+), 48 deletions(-) diff --git a/manager/api/grpc/server.go b/manager/api/grpc/server.go index d2612350c..ba79c8fca 100644 --- a/manager/api/grpc/server.go +++ b/manager/api/grpc/server.go @@ -4,6 +4,7 @@ package grpc import ( "bytes" + "context" "errors" "io" @@ -28,7 +29,7 @@ type grpcServer struct { } type Service interface { - Run(ipAddress string, runReqChan chan *manager.ServerStreamMessage, authInfo credentials.AuthInfo) + Run(ctx context.Context, ipAddress string, sendMessage func(*manager.ServerStreamMessage) error, authInfo credentials.AuthInfo) } // NewServer returns new AuthServiceServer instance. @@ -40,12 +41,11 @@ func NewServer(incoming chan *manager.ClientStreamMessage, svc Service) manager. } func (s *grpcServer) Process(stream manager.ManagerService_ProcessServer) error { - runReqChan := make(chan *manager.ServerStreamMessage) - defer close(runReqChan) client, ok := peer.FromContext(stream.Context()) - if ok { - go s.svc.Run(client.Addr.String(), runReqChan, client.AuthInfo) + if !ok { + return errors.New("failed to get peer info") } + eg, ctx := errgroup.WithContext(stream.Context()) eg.Go(func() error { @@ -60,45 +60,53 @@ func (s *grpcServer) Process(stream manager.ManagerService_ProcessServer) error }) eg.Go(func() error { - for { - select { - case <-ctx.Done(): - return nil - case req := <-runReqChan: - switch msg := req.Message.(type) { - case *manager.ServerStreamMessage_RunReq: - data, err := proto.Marshal(msg.RunReq) - if err != nil { - return err - } - dataBuffer := bytes.NewBuffer(data) - buf := make([]byte, bufferSize) - for { - n, err := dataBuffer.Read(buf) - chunk := &manager.ServerStreamMessage{ - Message: &manager.ServerStreamMessage_RunReqChunks{ - RunReqChunks: &manager.RunReqChunks{ - Data: buf[:n], - }, - }, - } - - if err := stream.Send(chunk); err != nil { - return err - } - - if err == io.EOF { - break - } - } - - default: - if err := stream.Send(req); err != nil { - return err - } - } + sendMessage := func(msg *manager.ServerStreamMessage) error { + switch m := msg.Message.(type) { + case *manager.ServerStreamMessage_RunReq: + return s.sendRunReqInChunks(stream, m.RunReq) + default: + return stream.Send(msg) } } + + s.svc.Run(ctx, client.Addr.String(), sendMessage, client.AuthInfo) + return nil }) + return eg.Wait() } + +func (s *grpcServer) sendRunReqInChunks(stream manager.ManagerService_ProcessServer, runReq *manager.ComputationRunReq) error { + data, err := proto.Marshal(runReq) + if err != nil { + return err + } + + dataBuffer := bytes.NewBuffer(data) + buf := make([]byte, bufferSize) + + for { + n, err := dataBuffer.Read(buf) + if err != nil && err != io.EOF { + return err + } + + chunk := &manager.ServerStreamMessage{ + Message: &manager.ServerStreamMessage_RunReqChunks{ + RunReqChunks: &manager.RunReqChunks{ + Data: buf[:n], + }, + }, + } + + if err := stream.Send(chunk); err != nil { + return err + } + + if err == io.EOF { + break + } + } + + return nil +} diff --git a/manager/setup_test.go b/manager/setup_test.go index 442bf4249..d4966175a 100644 --- a/manager/setup_test.go +++ b/manager/setup_test.go @@ -64,7 +64,7 @@ func bufDialer(context.Context, string) (net.Conn, error) { return lis.Dial() } -func (s *svc) Run(ipAddress string, runReqChan chan *manager.ServerStreamMessage, authInfo credentials.AuthInfo) { +func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage func(*manager.ServerStreamMessage) error, authInfo credentials.AuthInfo) { privKey, err := rsa.GenerateKey(rand.Reader, keyBitSize) if err != nil { s.t.Fatalf("Error generating public key: %v", err) @@ -82,10 +82,12 @@ func (s *svc) Run(ipAddress string, runReqChan chan *manager.ServerStreamMessage go func() { time.Sleep(time.Millisecond * 100) - runReqChan <- &manager.ServerStreamMessage{ + if err := sendMessage(&manager.ServerStreamMessage{ Message: &manager.ServerStreamMessage_TerminateReq{ TerminateReq: &manager.Terminate{Message: "test terminate"}, }, + }); err != nil { + s.t.Fatalf("failed to send terminate request: %s", err) } }() @@ -105,7 +107,8 @@ func (s *svc) Run(ipAddress string, runReqChan chan *manager.ServerStreamMessage pubPem, _ := pem.Decode(pubPemBytes) algoHash := sha3.Sum256(algo) dataHash := sha3.Sum256(data) - runReqChan <- &manager.ServerStreamMessage{ + + if err := sendMessage(&manager.ServerStreamMessage{ Message: &manager.ServerStreamMessage_RunReq{ RunReq: &manager.ComputationRunReq{ Id: "1", @@ -121,6 +124,8 @@ func (s *svc) Run(ipAddress string, runReqChan chan *manager.ServerStreamMessage }, }, }, + }); err != nil { + s.t.Fatalf("failed to send run request: %s", err) } }() } diff --git a/test/computations/main.go b/test/computations/main.go index bcb422ee0..e083688a8 100644 --- a/test/computations/main.go +++ b/test/computations/main.go @@ -42,8 +42,8 @@ type svc struct { logger *slog.Logger } -func (s *svc) Run(ipAdress string, reqChan chan *manager.ServerStreamMessage, auth credentials.AuthInfo) { - s.logger.Debug(fmt.Sprintf("received who am on ip address %s", ipAdress)) +func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage func(*manager.ServerStreamMessage) error, authInfo credentials.AuthInfo) { + s.logger.Debug(fmt.Sprintf("received who am on ip address %s", ipAddress)) pubKey, err := os.ReadFile(pubKeyFile) if err != nil { @@ -73,7 +73,7 @@ func (s *svc) Run(ipAdress string, reqChan chan *manager.ServerStreamMessage, au return } - reqChan <- &manager.ServerStreamMessage{ + if err := sendMessage(&manager.ServerStreamMessage{ Message: &manager.ServerStreamMessage_RunReq{ RunReq: &manager.ComputationRunReq{ Id: "1", @@ -89,6 +89,9 @@ func (s *svc) Run(ipAdress string, reqChan chan *manager.ServerStreamMessage, au }, }, }, + }); err != nil { + s.logger.Error(fmt.Sprintf("failed to send run request: %s", err)) + return } } From f848afeefd44e7b3f8e3ba8f9bf7f8de800e7c5e Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Thu, 5 Sep 2024 15:32:04 +0300 Subject: [PATCH 11/83] NOISSUE - Define sendFunc type (#232) Signed-off-by: Sammy Oina --- manager/api/grpc/server.go | 4 +++- manager/setup_test.go | 2 +- test/computations/main.go | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/manager/api/grpc/server.go b/manager/api/grpc/server.go index ba79c8fca..3c5d1455e 100644 --- a/manager/api/grpc/server.go +++ b/manager/api/grpc/server.go @@ -22,6 +22,8 @@ var ( const bufferSize = 1024 * 1024 // 1 MB +type SendFunc func(*manager.ServerStreamMessage) error + type grpcServer struct { manager.UnimplementedManagerServiceServer incoming chan *manager.ClientStreamMessage @@ -29,7 +31,7 @@ type grpcServer struct { } type Service interface { - Run(ctx context.Context, ipAddress string, sendMessage func(*manager.ServerStreamMessage) error, authInfo credentials.AuthInfo) + Run(ctx context.Context, ipAddress string, sendMessage SendFunc, authInfo credentials.AuthInfo) } // NewServer returns new AuthServiceServer instance. diff --git a/manager/setup_test.go b/manager/setup_test.go index d4966175a..bf478ed1a 100644 --- a/manager/setup_test.go +++ b/manager/setup_test.go @@ -64,7 +64,7 @@ func bufDialer(context.Context, string) (net.Conn, error) { return lis.Dial() } -func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage func(*manager.ServerStreamMessage) error, authInfo credentials.AuthInfo) { +func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage managergrpc.SendFunc, authInfo credentials.AuthInfo) { privKey, err := rsa.GenerateKey(rand.Reader, keyBitSize) if err != nil { s.t.Fatalf("Error generating public key: %v", err) diff --git a/test/computations/main.go b/test/computations/main.go index e083688a8..905b00e00 100644 --- a/test/computations/main.go +++ b/test/computations/main.go @@ -42,7 +42,7 @@ type svc struct { logger *slog.Logger } -func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage func(*manager.ServerStreamMessage) error, authInfo credentials.AuthInfo) { +func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage managergrpc.SendFunc, authInfo credentials.AuthInfo) { s.logger.Debug(fmt.Sprintf("received who am on ip address %s", ipAddress)) pubKey, err := os.ReadFile(pubKeyFile) From 51b129c3a2150e20e5bb3e7a5d3840e6109d184f Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Fri, 6 Sep 2024 13:53:48 +0300 Subject: [PATCH 12/83] NOISSUE - Flush Docker logs (#229) * flush docker logs Signed-off-by: Sammy Oina * show logs in realtime Signed-off-by: Sammy Oina * add tty Signed-off-by: Sammy Oina * remove duplicate Signed-off-by: Sammy Oina * python3 Signed-off-by: Sammy Oina * error check Signed-off-by: Sammy Oina * remove capitalization Signed-off-by: SammyOina --------- Signed-off-by: Sammy Oina Signed-off-by: SammyOina --- agent/algorithm/docker/docker.go | 57 ++++++++++++++++++-------------- test/manual/algo/Dockerfile | 2 +- 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/agent/algorithm/docker/docker.go b/agent/algorithm/docker/docker.go index 0b467da9d..ab562017e 100644 --- a/agent/algorithm/docker/docker.go +++ b/agent/algorithm/docker/docker.go @@ -3,6 +3,7 @@ package docker import ( + "bufio" "context" "fmt" "io" @@ -89,7 +90,10 @@ func (d *docker) Run() error { // Create and start the container. respContainer, err := cli.ContainerCreate(ctx, &container.Config{ - Image: dockerImageName, + Image: dockerImageName, + Tty: true, + AttachStdout: true, + AttachStderr: true, }, &container.HostConfig{ Mounts: []mount.Mount{ { @@ -112,35 +116,37 @@ func (d *docker) Run() error { return fmt.Errorf("could not start a Docker container: %v", err) } - statusCh, errCh := cli.ContainerWait(ctx, respContainer.ID, container.WaitConditionNotRunning) - select { - case err := <-errCh: - if err != nil { - return fmt.Errorf("could not wait for a Docker container: %v", err) - } - case <-statusCh: - } - - stdout, err := cli.ContainerLogs(ctx, respContainer.ID, container.LogsOptions{ShowStdout: true}) + stdout, err := cli.ContainerLogs(ctx, respContainer.ID, container.LogsOptions{ShowStdout: true, Follow: true}) if err != nil { return fmt.Errorf("could not read stdout from the container: %v", err) } defer stdout.Close() - err = writeToOut(stdout, d.stdout) - if err != nil { - d.logger.Warn(fmt.Sprintf("could not write to stdout: %v", err)) - } + go func() { + if err := writeToOut(stdout, d.stdout); err != nil { + d.logger.Warn(fmt.Sprintf("could not write to stdout: %v", err)) + } + }() - stderr, err := cli.ContainerLogs(ctx, respContainer.ID, container.LogsOptions{ShowStderr: true}) + stderr, err := cli.ContainerLogs(ctx, respContainer.ID, container.LogsOptions{ShowStderr: true, Follow: true}) if err != nil { d.logger.Warn(fmt.Sprintf("could not read stderr from the container: %v", err)) } defer stderr.Close() - err = writeToOut(stderr, d.stderr) - if err != nil { - d.logger.Warn(fmt.Sprintf("could not write to stderr: %v", err)) + go func() { + if err := writeToOut(stderr, d.stderr); err != nil { + d.logger.Warn(fmt.Sprintf("could not write to stderr: %v", err)) + } + }() + + statusCh, errCh := cli.ContainerWait(ctx, respContainer.ID, container.WaitConditionNotRunning) + select { + case err := <-errCh: + if err != nil { + return fmt.Errorf("could not wait for a Docker container: %v", err) + } + case <-statusCh: } defer func() { @@ -157,13 +163,14 @@ func (d *docker) Run() error { } func writeToOut(readCloser io.ReadCloser, ioWriter io.Writer) error { - content, err := io.ReadAll(readCloser) - if err != nil { - return fmt.Errorf("could not convert content from the container: %v", err) + scanner := bufio.NewScanner(readCloser) + for scanner.Scan() { + if _, err := ioWriter.Write(scanner.Bytes()); err != nil { + return fmt.Errorf("error writing to output: %v", err) + } } - - if _, err := ioWriter.Write(content); err != nil { - return fmt.Errorf("could not write to output: %v", err) + if err := scanner.Err(); err != nil { + return fmt.Errorf("error reading container logs error: %v", err) } return nil diff --git a/test/manual/algo/Dockerfile b/test/manual/algo/Dockerfile index c46f74a95..682a5551b 100644 --- a/test/manual/algo/Dockerfile +++ b/test/manual/algo/Dockerfile @@ -12,4 +12,4 @@ COPY ./lin_reg.py /cocos/lin_reg.py RUN pip install -r requirements.txt # command to be run when the docker container is started -CMD ["python3", "/cocos/lin_reg.py"] +CMD ["python3", "-u", "/cocos/lin_reg.py"] From c2a4b44769b30070e7b3f0a92f278b76ffeceecb Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Fri, 6 Sep 2024 18:02:30 +0300 Subject: [PATCH 13/83] NOISSUE - Cache and retry message sending (#222) * cache and retry message sending Signed-off-by: Sammy Oina * cache and retry message sending Signed-off-by: SammyOina * remove safeconn Signed-off-by: Sammy Oina * simplify retry Signed-off-by: Sammy Oina * debug disconnect Signed-off-by: Sammy Oina * remove debug Signed-off-by: Sammy Oina * simplify Signed-off-by: SammyOina --------- Signed-off-by: Sammy Oina Signed-off-by: SammyOina --- agent/events/events.go | 70 ++++++++++++++++++++++------- agent/events/mocks/events.go | 17 +------ cmd/agent/main.go | 3 +- go.mod | 3 ++ go.sum | 16 ++++++- internal/logger/protohandler.go | 80 ++++++++++++++++++++++++--------- 6 files changed, 135 insertions(+), 54 deletions(-) diff --git a/agent/events/events.go b/agent/events/events.go index c81a582cb..b62bbfd1d 100644 --- a/agent/events/events.go +++ b/agent/events/events.go @@ -4,6 +4,7 @@ package events import ( "encoding/json" + "sync" "time" "github.com/mdlayher/vsock" @@ -12,10 +13,15 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) +const retryInterval = 5 * time.Second + type service struct { - service string - computationID string - conn *vsock.Conn + service string + computationID string + conn *vsock.Conn + cachedMessages [][]byte + mutex sync.Mutex + stopRetry chan struct{} } type AgentEvent struct { @@ -30,19 +36,21 @@ type AgentEvent struct { //go:generate mockery --name Service --output=./mocks --filename events.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0" type Service interface { SendEvent(event, status string, details json.RawMessage) error - Close() error + Close() } -func New(svc, computationID string, sockPort uint32) (Service, error) { - conn, err := vsock.Dial(vsock.Host, sockPort, nil) - if err != nil { - return nil, err +func New(svc, computationID string, conn *vsock.Conn) (Service, error) { + s := &service{ + service: svc, + computationID: computationID, + conn: conn, + cachedMessages: make([][]byte, 0), + stopRetry: make(chan struct{}), } - return &service{ - service: svc, - computationID: computationID, - conn: conn, - }, nil + + go s.periodicRetry() + + return s, nil } func (s *service) SendEvent(event, status string, details json.RawMessage) error { @@ -58,12 +66,44 @@ func (s *service) SendEvent(event, status string, details json.RawMessage) error if err != nil { return err } + + s.mutex.Lock() + defer s.mutex.Unlock() + if _, err := s.conn.Write(protoBody); err != nil { + s.cachedMessages = append(s.cachedMessages, protoBody) return err } + return nil } -func (s *service) Close() error { - return s.conn.Close() +func (s *service) periodicRetry() { + ticker := time.NewTicker(retryInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.retrySendCachedMessages() + case <-s.stopRetry: + return + } + } +} + +func (s *service) retrySendCachedMessages() { + s.mutex.Lock() + defer s.mutex.Unlock() + tmp := [][]byte{} + for _, msg := range s.cachedMessages { + if _, err := s.conn.Write(msg); err != nil { + tmp = append(tmp, msg) + } + } + s.cachedMessages = tmp +} + +func (s *service) Close() { + close(s.stopRetry) } diff --git a/agent/events/mocks/events.go b/agent/events/mocks/events.go index d719df59f..b10080dbf 100644 --- a/agent/events/mocks/events.go +++ b/agent/events/mocks/events.go @@ -17,21 +17,8 @@ type Service struct { } // Close provides a mock function with given fields: -func (_m *Service) Close() error { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for Close") - } - - var r0 error - if rf, ok := ret.Get(0).(func() error); ok { - r0 = rf() - } else { - r0 = ret.Error(0) - } - - return r0 +func (_m *Service) Close() { + _m.Called() } // SendEvent provides a mock function with given fields: event, status, details diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 7076163e1..c03cdb3af 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -62,10 +62,11 @@ func main() { exitCode = 1 return } + handler := agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}, cfg.ID) logger := slog.New(handler) - eventSvc, err := events.New(svcName, cfg.ID, manager.ManagerVsockPort) + eventSvc, err := events.New(svcName, cfg.ID, conn) if err != nil { logger.Error(fmt.Sprintf("failed to create events service %s", err.Error())) exitCode = 1 diff --git a/go.mod b/go.mod index 26431fdf3..695aeff88 100644 --- a/go.mod +++ b/go.mod @@ -24,12 +24,14 @@ require ( require ( github.com/Microsoft/go-winio v0.6.1 // indirect + github.com/containerd/log v0.1.0 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/morikuni/aec v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 // indirect @@ -39,6 +41,7 @@ require ( go.opentelemetry.io/otel/sdk v1.28.0 // indirect golang.org/x/mod v0.19.0 // indirect golang.org/x/tools v0.23.0 // indirect + gotest.tools/v3 v3.5.1 // indirect ) require ( diff --git a/go.sum b/go.sum index c88c1bc30..196dc3160 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/VividCortex/gohistogram v1.0.0 h1:6+hBz+qvs0JOrrNhhmR7lFxo5sINxBCGXrdtl/UvroE= @@ -12,6 +14,8 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK3 github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= +github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -71,6 +75,10 @@ github.com/mdlayher/vsock v1.2.1 h1:pC1mTJTvjo1r9n9fbm7S1j04rCgCzhCOS5DY0zqHlnQ= github.com/mdlayher/vsock v1.2.1/go.mod h1:NRfCibel++DgeMD8z/hP+PPTjlNJsdPOmxcnENvE+SE= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= +github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= +github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= +github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= +github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= @@ -92,6 +100,8 @@ github.com/prometheus/procfs v0.13.0/go.mod h1:cd4PFCR54QLnGKPaKGA6l+cfuNXtht43Z github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= @@ -135,8 +145,6 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -156,6 +164,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= @@ -179,3 +189,5 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= diff --git a/internal/logger/protohandler.go b/internal/logger/protohandler.go index 8647daf43..d06120001 100644 --- a/internal/logger/protohandler.go +++ b/internal/logger/protohandler.go @@ -6,32 +6,44 @@ import ( "context" "io" "log/slog" + "sync" + "time" "github.com/ultravioletrs/cocos/pkg/manager" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" ) +const retryInterval = 5 * time.Second + var _ slog.Handler = (*handler)(nil) type handler struct { - opts slog.HandlerOptions - w io.Writer - cmpID string + opts slog.HandlerOptions + w io.Writer + cmpID string + cachedMessages [][]byte + mutex sync.Mutex + stopRetry chan struct{} } -func NewProtoHandler(w io.Writer, opts *slog.HandlerOptions, cmpID string) slog.Handler { +func NewProtoHandler(conn io.Writer, opts *slog.HandlerOptions, cmpID string) slog.Handler { if opts == nil { opts = &slog.HandlerOptions{} } - return &handler{ - opts: *opts, - w: w, - cmpID: cmpID, + h := &handler{ + opts: *opts, + w: conn, + cmpID: cmpID, + cachedMessages: make([][]byte, 0), + stopRetry: make(chan struct{}), } + + go h.periodicRetry() + + return h } -// Enabled implements slog.Handler. func (h *handler) Enabled(_ context.Context, l slog.Level) bool { minLevel := slog.LevelInfo if h.opts.Level != nil { @@ -40,13 +52,11 @@ func (h *handler) Enabled(_ context.Context, l slog.Level) bool { return l >= minLevel } -// Handle implements slog.Handler. func (h *handler) Handle(_ context.Context, r slog.Record) error { message := r.Message timestamp := timestamppb.New(r.Time) level := r.Level.String() - // Calculate the number of chunks chunkSize := 500 numChunks := (len(message) + chunkSize - 1) / chunkSize @@ -57,10 +67,8 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { end = len(message) } - // Create a chunk of the message chunk := message[start:end] - // Create the agent log with the chunk agentLog := manager.ClientStreamMessage{ Message: &manager.ClientStreamMessage_AgentLog{ AgentLog: &manager.AgentLog{ @@ -72,27 +80,57 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { }, } - // Marshal the chunk to protobuf b, err := proto.Marshal(&agentLog) if err != nil { return err } - // Write the chunk to the writer - if _, err := h.w.Write(b); err != nil { - return err + h.mutex.Lock() + _, err = h.w.Write(b) + if err != nil { + h.cachedMessages = append(h.cachedMessages, b) } + h.mutex.Unlock() } return nil } -// WithAttrs implements slog.Handler. -func (*handler) WithAttrs(attrs []slog.Attr) slog.Handler { +func (h *handler) periodicRetry() { + ticker := time.NewTicker(retryInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + h.retrySendCachedMessages() + case <-h.stopRetry: + return + } + } +} + +func (h *handler) retrySendCachedMessages() { + h.mutex.Lock() + defer h.mutex.Unlock() + tmp := [][]byte{} + for _, msg := range h.cachedMessages { + if _, err := h.w.Write(msg); err != nil { + tmp = append(tmp, msg) + } + } + h.cachedMessages = tmp +} + +func (h *handler) WithAttrs(attrs []slog.Attr) slog.Handler { panic("unimplemented") } -// WithGroup implements slog.Handler. -func (*handler) WithGroup(name string) slog.Handler { +func (h *handler) WithGroup(name string) slog.Handler { panic("unimplemented") } + +func (h *handler) Close() error { + close(h.stopRetry) + return nil +} From 8db88ccbde2bb9d7c9e643a86764b2ca88c0647c Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Fri, 6 Sep 2024 18:49:05 +0300 Subject: [PATCH 14/83] NOISSUE - Fix handling of runreq chunks (#234) * fix handling of runreq chunks Signed-off-by: SammyOina * copy ovmf vars Signed-off-by: SammyOina * fix lint errors Signed-off-by: SammyOina --------- Signed-off-by: SammyOina --- .github/workflows/checkproto.yaml | 2 +- agent/agent.pb.go | 2 +- agent/agent_grpc.pb.go | 2 +- manager/api/grpc/client.go | 96 ++++++++++++++++----- manager/api/grpc/server.go | 18 +++- manager/manager.proto | 2 + manager/qemu/vm.go | 13 +++ pkg/manager/manager.pb.go | 136 +++++++++++++++++------------- pkg/manager/manager_grpc.pb.go | 2 +- 9 files changed, 186 insertions(+), 87 deletions(-) diff --git a/.github/workflows/checkproto.yaml b/.github/workflows/checkproto.yaml index 51be82135..f0b5d0702 100644 --- a/.github/workflows/checkproto.yaml +++ b/.github/workflows/checkproto.yaml @@ -33,7 +33,7 @@ jobs: - name: Set up protoc run: | - PROTOC_VERSION=27.2 + PROTOC_VERSION=27.3 PROTOC_GEN_VERSION=v1.34.2 PROTOC_GRPC_VERSION=v1.4.0 diff --git a/agent/agent.pb.go b/agent/agent.pb.go index e988b40bd..c84172f87 100644 --- a/agent/agent.pb.go +++ b/agent/agent.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.34.2 -// protoc v5.27.2 +// protoc v5.27.3 // source: agent/agent.proto package agent diff --git a/agent/agent_grpc.pb.go b/agent/agent_grpc.pb.go index ac2bab2ee..650a37545 100644 --- a/agent/agent_grpc.pb.go +++ b/agent/agent_grpc.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.4.0 -// - protoc v5.27.2 +// - protoc v5.27.3 // source: agent/agent.proto package agent diff --git a/manager/api/grpc/client.go b/manager/api/grpc/client.go index 59a6312cf..3b3808056 100644 --- a/manager/api/grpc/client.go +++ b/manager/api/grpc/client.go @@ -3,9 +3,9 @@ package grpc import ( - "bytes" "context" "log/slog" + "sync" "time" "github.com/absmach/magistrala/pkg/errors" @@ -22,19 +22,21 @@ var ( ) type ManagerClient struct { - stream pkgmanager.ManagerService_ProcessClient - svc manager.Service - messageQueue chan *pkgmanager.ClientStreamMessage - logger *slog.Logger + stream pkgmanager.ManagerService_ProcessClient + svc manager.Service + messageQueue chan *pkgmanager.ClientStreamMessage + logger *slog.Logger + runReqManager *runRequestManager } // NewClient returns new gRPC client instance. func NewClient(stream pkgmanager.ManagerService_ProcessClient, svc manager.Service, messageQueue chan *pkgmanager.ClientStreamMessage, logger *slog.Logger) ManagerClient { return ManagerClient{ - stream: stream, - svc: svc, - messageQueue: messageQueue, - logger: logger, + stream: stream, + svc: svc, + messageQueue: messageQueue, + logger: logger, + runReqManager: newRunRequestManager(), } } @@ -53,7 +55,6 @@ func (client ManagerClient) Process(ctx context.Context, cancel context.CancelFu } func (client ManagerClient) handleIncomingMessages(ctx context.Context) error { - var runReqBuffer bytes.Buffer for { select { case <-ctx.Done(): @@ -63,39 +64,42 @@ func (client ManagerClient) handleIncomingMessages(ctx context.Context) error { if err != nil { return err } - if err := client.processIncomingMessage(ctx, req, &runReqBuffer); err != nil { + if err := client.processIncomingMessage(ctx, req); err != nil { return err } } } } -func (client ManagerClient) processIncomingMessage(ctx context.Context, req *pkgmanager.ServerStreamMessage, runReqBuffer *bytes.Buffer) error { +func (client ManagerClient) processIncomingMessage(ctx context.Context, req *pkgmanager.ServerStreamMessage) error { switch mes := req.Message.(type) { case *pkgmanager.ServerStreamMessage_RunReqChunks: - return client.handleRunReqChunks(ctx, mes, runReqBuffer) + return client.handleRunReqChunks(ctx, mes) case *pkgmanager.ServerStreamMessage_TerminateReq: return client.handleTerminateReq(mes) case *pkgmanager.ServerStreamMessage_StopComputation: go client.handleStopComputation(ctx, mes) case *pkgmanager.ServerStreamMessage_BackendInfoReq: - go client.handleBackendInfoReq(ctx, mes) + go client.handleBackendInfoReq(mes) default: return errors.New("unknown message type") } return nil } -func (client ManagerClient) handleRunReqChunks(ctx context.Context, mes *pkgmanager.ServerStreamMessage_RunReqChunks, runReqBuffer *bytes.Buffer) error { - if len(mes.RunReqChunks.Data) == 0 { +func (client *ManagerClient) handleRunReqChunks(ctx context.Context, mes *pkgmanager.ServerStreamMessage_RunReqChunks) error { + buffer, complete := client.runReqManager.addChunk(mes.RunReqChunks.Id, mes.RunReqChunks.Data, mes.RunReqChunks.IsLast) + + if complete { var runReq pkgmanager.ComputationRunReq - if err := proto.Unmarshal(runReqBuffer.Bytes(), &runReq); err != nil { + if err := proto.Unmarshal(buffer, &runReq); err != nil { return errors.Wrap(err, errCorruptedManifest) } + go client.executeRun(ctx, &runReq) } - _, err := runReqBuffer.Write(mes.RunReqChunks.Data) - return err + + return nil } func (client ManagerClient) executeRun(ctx context.Context, runReq *pkgmanager.ComputationRunReq) { @@ -129,7 +133,7 @@ func (client ManagerClient) handleStopComputation(ctx context.Context, mes *pkgm client.sendMessage(&pkgmanager.ClientStreamMessage{Message: msg}) } -func (client ManagerClient) handleBackendInfoReq(ctx context.Context, mes *pkgmanager.ServerStreamMessage_BackendInfoReq) { +func (client ManagerClient) handleBackendInfoReq(mes *pkgmanager.ServerStreamMessage_BackendInfoReq) { res, err := client.svc.FetchBackendInfo() if err != nil { client.logger.Warn(err.Error()) @@ -167,3 +171,55 @@ func (client ManagerClient) sendMessage(mes *pkgmanager.ClientStreamMessage) { client.logger.Warn("Failed to send message: timeout exceeded") } } + +type runRequestManager struct { + requests map[string]*runRequest + mu sync.Mutex +} + +type runRequest struct { + buffer []byte + lastChunk time.Time + timer *time.Timer +} + +func newRunRequestManager() *runRequestManager { + return &runRequestManager{ + requests: make(map[string]*runRequest), + } +} + +func (m *runRequestManager) addChunk(id string, chunk []byte, isLast bool) ([]byte, bool) { + m.mu.Lock() + defer m.mu.Unlock() + + req, exists := m.requests[id] + if !exists { + req = &runRequest{ + buffer: make([]byte, 0), + lastChunk: time.Now(), + timer: time.AfterFunc(runReqTimeout, func() { m.timeoutRequest(id) }), + } + m.requests[id] = req + } + + req.buffer = append(req.buffer, chunk...) + req.lastChunk = time.Now() + req.timer.Reset(runReqTimeout) + + if isLast { + delete(m.requests, id) + req.timer.Stop() + return req.buffer, true + } + + return nil, false +} + +func (m *runRequestManager) timeoutRequest(id string) { + m.mu.Lock() + defer m.mu.Unlock() + + delete(m.requests, id) + // Log timeout or handle it as needed +} diff --git a/manager/api/grpc/server.go b/manager/api/grpc/server.go index 3c5d1455e..9e17c1e6b 100644 --- a/manager/api/grpc/server.go +++ b/manager/api/grpc/server.go @@ -7,6 +7,7 @@ import ( "context" "errors" "io" + "time" "github.com/ultravioletrs/cocos/pkg/manager" "golang.org/x/sync/errgroup" @@ -20,7 +21,10 @@ var ( ErrUnexpectedMsg = errors.New("unknown message type") ) -const bufferSize = 1024 * 1024 // 1 MB +const ( + bufferSize = 1024 * 1024 // 1 MB + runReqTimeout = 30 * time.Second +) type SendFunc func(*manager.ServerStreamMessage) error @@ -89,14 +93,20 @@ func (s *grpcServer) sendRunReqInChunks(stream manager.ManagerService_ProcessSer for { n, err := dataBuffer.Read(buf) - if err != nil && err != io.EOF { + isLast := false + + if err == io.EOF { + isLast = true + } else if err != nil { return err } chunk := &manager.ServerStreamMessage{ Message: &manager.ServerStreamMessage_RunReqChunks{ RunReqChunks: &manager.RunReqChunks{ - Data: buf[:n], + Id: runReq.Id, + Data: buf[:n], + IsLast: isLast, }, }, } @@ -105,7 +115,7 @@ func (s *grpcServer) sendRunReqInChunks(stream manager.ManagerService_ProcessSer return err } - if err == io.EOF { + if isLast { break } } diff --git a/manager/manager.proto b/manager/manager.proto index 65aa316f1..93d7a35db 100644 --- a/manager/manager.proto +++ b/manager/manager.proto @@ -74,6 +74,8 @@ message ServerStreamMessage { message RunReqChunks { bytes data = 1; + string id = 2; + bool is_last = 3; } message ComputationRunReq { diff --git a/manager/qemu/vm.go b/manager/qemu/vm.go index 9b7328929..5df98ddea 100644 --- a/manager/qemu/vm.go +++ b/manager/qemu/vm.go @@ -8,6 +8,7 @@ import ( "os/exec" "github.com/gofrs/uuid" + "github.com/ultravioletrs/cocos/internal" "github.com/ultravioletrs/cocos/manager/vm" "github.com/ultravioletrs/cocos/pkg/manager" ) @@ -16,6 +17,7 @@ const ( firmwareVars = "OVMF_VARS" KernelFile = "bzImage" rootfsFile = "rootfs.cpio" + tmpDir = "/tmp" ) type qemuVM struct { @@ -43,6 +45,17 @@ func (v *qemuVM) Start() error { v.config.NetDevConfig.ID = fmt.Sprintf("%s-%s", v.config.NetDevConfig.ID, id) v.config.SevConfig.ID = fmt.Sprintf("%s-%s", v.config.SevConfig.ID, id) + if !v.config.KernelHash { + // Copy firmware vars file. + srcFile := v.config.OVMFVarsConfig.File + dstFile := fmt.Sprintf("%s/%s-%s.fd", tmpDir, firmwareVars, id) + err = internal.CopyFile(srcFile, dstFile) + if err != nil { + return err + } + v.config.OVMFVarsConfig.File = dstFile + } + exe, args, err := v.executableAndArgs() if err != nil { return err diff --git a/pkg/manager/manager.pb.go b/pkg/manager/manager.pb.go index 80ccb028e..36d652222 100644 --- a/pkg/manager/manager.pb.go +++ b/pkg/manager/manager.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.34.2 -// protoc v5.27.2 +// protoc v5.27.3 // source: manager/manager.proto package manager @@ -692,7 +692,9 @@ type RunReqChunks struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"` + IsLast bool `protobuf:"varint,3,opt,name=is_last,json=isLast,proto3" json:"is_last,omitempty"` } func (x *RunReqChunks) Reset() { @@ -734,6 +736,20 @@ func (x *RunReqChunks) GetData() []byte { return nil } +func (x *RunReqChunks) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *RunReqChunks) GetIsLast() bool { + if x != nil { + return x.IsLast + } + return false +} + type ComputationRunReq struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1237,66 +1253,68 @@ var file_manager_manager_proto_rawDesc = []byte{ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, 0x42, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x71, 0x48, 0x00, 0x52, 0x0e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x71, 0x42, - 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x22, 0x0a, 0x0c, 0x52, 0x75, + 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x4b, 0x0a, 0x0c, 0x52, 0x75, 0x6e, 0x52, 0x65, 0x71, 0x43, 0x68, 0x75, 0x6e, 0x6b, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, - 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x22, 0xb6, - 0x02, 0x0a, 0x11, 0x43, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x75, - 0x6e, 0x52, 0x65, 0x71, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x63, - 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, - 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2c, 0x0a, 0x08, 0x64, 0x61, - 0x74, 0x61, 0x73, 0x65, 0x74, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x52, 0x08, - 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x73, 0x12, 0x30, 0x0a, 0x09, 0x61, 0x6c, 0x67, 0x6f, - 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x52, - 0x09, 0x61, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x42, 0x0a, 0x10, 0x72, 0x65, - 0x73, 0x75, 0x6c, 0x74, 0x5f, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x73, 0x18, 0x06, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, 0x52, - 0x65, 0x73, 0x75, 0x6c, 0x74, 0x43, 0x6f, 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x52, 0x0f, 0x72, - 0x65, 0x73, 0x75, 0x6c, 0x74, 0x43, 0x6f, 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x73, 0x12, 0x37, - 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x07, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, 0x41, - 0x67, 0x65, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x61, 0x67, 0x65, 0x6e, - 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x20, 0x0a, 0x0e, 0x42, 0x61, 0x63, 0x6b, 0x65, - 0x6e, 0x64, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x71, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x22, 0x2a, 0x0a, 0x0e, 0x52, 0x65, 0x73, - 0x75, 0x6c, 0x74, 0x43, 0x6f, 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x75, - 0x73, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x75, 0x73, - 0x65, 0x72, 0x4b, 0x65, 0x79, 0x22, 0x53, 0x0a, 0x07, 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, + 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x0e, + 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x17, + 0x0a, 0x07, 0x69, 0x73, 0x5f, 0x6c, 0x61, 0x73, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x06, 0x69, 0x73, 0x4c, 0x61, 0x73, 0x74, 0x22, 0xb6, 0x02, 0x0a, 0x11, 0x43, 0x6f, 0x6d, 0x70, + 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x75, 0x6e, 0x52, 0x65, 0x71, 0x12, 0x0e, 0x0a, + 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, + 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, + 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x2c, 0x0a, 0x08, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x73, 0x18, + 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, + 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x52, 0x08, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, + 0x73, 0x12, 0x30, 0x0a, 0x09, 0x61, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, 0x41, + 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x52, 0x09, 0x61, 0x6c, 0x67, 0x6f, 0x72, 0x69, + 0x74, 0x68, 0x6d, 0x12, 0x42, 0x0a, 0x10, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x5f, 0x63, 0x6f, + 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x43, 0x6f, + 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x52, 0x0f, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x43, 0x6f, + 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x73, 0x12, 0x37, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, + 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x0b, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x22, 0x20, 0x0a, 0x0e, 0x42, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x49, 0x6e, 0x66, 0x6f, 0x52, + 0x65, 0x71, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, + 0x69, 0x64, 0x22, 0x2a, 0x0a, 0x0e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x43, 0x6f, 0x6e, 0x73, + 0x75, 0x6d, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x75, 0x73, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x22, 0x53, + 0x0a, 0x07, 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x61, 0x73, + 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x68, 0x61, 0x73, 0x68, 0x12, 0x18, 0x0a, + 0x07, 0x75, 0x73, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, + 0x75, 0x73, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x6e, + 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x6e, + 0x61, 0x6d, 0x65, 0x22, 0x39, 0x0a, 0x09, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x61, 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x68, 0x61, 0x73, 0x68, 0x12, 0x18, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x75, 0x73, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x1a, - 0x0a, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x39, 0x0a, 0x09, 0x41, 0x6c, - 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x61, 0x73, 0x68, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x68, 0x61, 0x73, 0x68, 0x12, 0x18, 0x0a, 0x07, 0x75, - 0x73, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x75, 0x73, - 0x65, 0x72, 0x4b, 0x65, 0x79, 0x22, 0xf9, 0x01, 0x0a, 0x0b, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x6f, 0x73, - 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, 0x1b, 0x0a, - 0x09, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x08, 0x63, 0x65, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x19, 0x0a, 0x08, 0x6b, 0x65, - 0x79, 0x5f, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6b, 0x65, - 0x79, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x24, 0x0a, 0x0e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, - 0x63, 0x61, 0x5f, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x63, - 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x43, 0x61, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x24, 0x0a, 0x0e, 0x73, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x63, 0x61, 0x5f, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x61, 0x46, 0x69, 0x6c, - 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x6c, 0x6f, 0x67, 0x5f, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x07, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x6c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x21, - 0x0a, 0x0c, 0x61, 0x74, 0x74, 0x65, 0x73, 0x74, 0x65, 0x64, 0x5f, 0x74, 0x6c, 0x73, 0x18, 0x08, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x61, 0x74, 0x74, 0x65, 0x73, 0x74, 0x65, 0x64, 0x54, 0x6c, - 0x73, 0x32, 0x5d, 0x0a, 0x0e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x53, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x12, 0x4b, 0x0a, 0x07, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x12, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, - 0x74, 0x72, 0x65, 0x61, 0x6d, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x74, 0x72, - 0x65, 0x61, 0x6d, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, - 0x42, 0x0b, 0x5a, 0x09, 0x2e, 0x2f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x75, 0x73, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x22, 0xf9, + 0x01, 0x0a, 0x0b, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, + 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x6f, + 0x72, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x66, + 0x69, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x65, 0x72, 0x74, 0x46, + 0x69, 0x6c, 0x65, 0x12, 0x19, 0x0a, 0x08, 0x6b, 0x65, 0x79, 0x5f, 0x66, 0x69, 0x6c, 0x65, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6b, 0x65, 0x79, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x24, + 0x0a, 0x0e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x63, 0x61, 0x5f, 0x66, 0x69, 0x6c, 0x65, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x43, 0x61, + 0x46, 0x69, 0x6c, 0x65, 0x12, 0x24, 0x0a, 0x0e, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x63, + 0x61, 0x5f, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x43, 0x61, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x6c, 0x6f, + 0x67, 0x5f, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x6c, + 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x74, 0x74, 0x65, 0x73, + 0x74, 0x65, 0x64, 0x5f, 0x74, 0x6c, 0x73, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x61, + 0x74, 0x74, 0x65, 0x73, 0x74, 0x65, 0x64, 0x54, 0x6c, 0x73, 0x32, 0x5d, 0x0a, 0x0e, 0x4d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x72, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x4b, 0x0a, 0x07, + 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x72, 0x2e, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x0b, 0x5a, 0x09, 0x2e, 0x2f, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/pkg/manager/manager_grpc.pb.go b/pkg/manager/manager_grpc.pb.go index ee6b7b52c..0d707e05e 100644 --- a/pkg/manager/manager_grpc.pb.go +++ b/pkg/manager/manager_grpc.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.4.0 -// - protoc v5.27.2 +// - protoc v5.27.3 // source: manager/manager.proto package manager From 46d24f928a870ab637300faa9dd56f4641f8c62c Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Wed, 11 Sep 2024 15:26:46 +0300 Subject: [PATCH 15/83] NOISSUE - Add events for disconnection of agent (#233) * add events Signed-off-by: Sammy Oina * fix lint Signed-off-by: Sammy Oina * typo Signed-off-by: Sammy Oina * group logs Signed-off-by: Sammy Oina * fix error Signed-off-by: Sammy Oina * fix initialization of goroutine Signed-off-by: Sammy Oina * add comment Signed-off-by: SammyOina * update comment Signed-off-by: SammyOina * fix lint Signed-off-by: SammyOina * remove naked return Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina Signed-off-by: SammyOina --- .github/workflows/checkproto.yaml | 2 +- .github/workflows/hal.yml | 2 +- .github/workflows/main.yaml | 4 +- .golangci.yaml | 2 +- agent/api/grpc/interceptors.go | 8 ++-- cmd/manager/main.go | 2 +- go.mod | 2 +- manager/agentEventsLogs.go | 66 +++++++++++++++++++++++++------ manager/qemu/vm.go | 52 +++++++++++++++++++++++- manager/vm/logging.go | 2 +- manager/vm/logging_test.go | 2 +- manager/vm/mocks/vm.go | 18 +++++++++ manager/vm/vm.go | 1 + test/computations/main.go | 2 +- test/manual/agent-config/main.go | 6 +-- 15 files changed, 142 insertions(+), 29 deletions(-) diff --git a/.github/workflows/checkproto.yaml b/.github/workflows/checkproto.yaml index f0b5d0702..504f1eb4c 100644 --- a/.github/workflows/checkproto.yaml +++ b/.github/workflows/checkproto.yaml @@ -29,7 +29,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: 1.22.x + go-version: 1.23.x - name: Set up protoc run: | diff --git a/.github/workflows/hal.yml b/.github/workflows/hal.yml index bf98c63ac..606adbd76 100644 --- a/.github/workflows/hal.yml +++ b/.github/workflows/hal.yml @@ -18,7 +18,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: 1.22.x + go-version: 1.23.x cache-dependency-path: "go.sum" - name: Checkout cocos diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index daf9d65b0..d8ca4ee76 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -19,12 +19,12 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: 1.22.x + go-version: 1.23.x - name: golangci-lint uses: golangci/golangci-lint-action@v6 with: - version: v1.59.1 + version: v1.60 - name: Build run: | diff --git a/.golangci.yaml b/.golangci.yaml index 6a6291927..10b481304 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -68,7 +68,7 @@ linters: - errchkjson - errname - execinquery - - exportloopref + - copyloopvar - ginkgolinter - gocheckcompilerdirectives - gofumpt diff --git a/agent/api/grpc/interceptors.go b/agent/api/grpc/interceptors.go index 97250c7be..eb4a246d4 100644 --- a/agent/api/grpc/interceptors.go +++ b/agent/api/grpc/interceptors.go @@ -35,20 +35,20 @@ func (s *authInterceptor) AuthStreamInterceptor() grpc.StreamServerInterceptor { switch info.FullMethod { case agent.AgentService_Algo_FullMethodName: if _, err := s.auth.AuthenticateUser(stream.Context(), auth.AlgorithmProviderRole); err != nil { - return status.Errorf(codes.Unauthenticated, err.Error()) + return status.Errorf(codes.Unauthenticated, "%v", err.Error()) } return handler(srv, stream) case agent.AgentService_Data_FullMethodName: ctx, err := s.auth.AuthenticateUser(stream.Context(), auth.DataProviderRole) if err != nil { - return status.Errorf(codes.Unauthenticated, err.Error()) + return status.Errorf(codes.Unauthenticated, "%s", err.Error()) } wrapped := &wrappedServerStream{ServerStream: stream, ctx: ctx} return handler(srv, wrapped) case agent.AgentService_Result_FullMethodName: ctx, err := s.auth.AuthenticateUser(stream.Context(), auth.ConsumerRole) if err != nil { - return status.Errorf(codes.Unauthenticated, err.Error()) + return status.Errorf(codes.Unauthenticated, "%v", err.Error()) } wrapped := &wrappedServerStream{ServerStream: stream, ctx: ctx} return handler(srv, wrapped) @@ -64,7 +64,7 @@ func (s *authInterceptor) AuthUnaryInterceptor() grpc.UnaryServerInterceptor { case agent.AgentService_Result_FullMethodName: ctx, err := s.auth.AuthenticateUser(ctx, auth.ConsumerRole) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, err.Error()) + return nil, status.Errorf(codes.Unauthenticated, "%v", err.Error()) } return handler(ctx, req) default: diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 87c4218ba..9e16b9c8b 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -57,7 +57,7 @@ func main() { logger, err := mglog.New(os.Stdout, cfg.LogLevel) if err != nil { - log.Fatalf(err.Error()) + log.Fatal(err.Error()) } var exitCode int diff --git a/go.mod b/go.mod index 695aeff88..3aa5a2523 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/ultravioletrs/cocos -go 1.22.4 +go 1.23.0 require ( github.com/absmach/magistrala v0.14.1-0.20240709113739-04c359462746 diff --git a/manager/agentEventsLogs.go b/manager/agentEventsLogs.go index ee341f41c..d245765b0 100644 --- a/manager/agentEventsLogs.go +++ b/manager/agentEventsLogs.go @@ -5,10 +5,13 @@ package manager import ( "fmt" "net" + "regexp" + "strconv" "github.com/mdlayher/vsock" "github.com/ultravioletrs/cocos/pkg/manager" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" ) const ( @@ -16,6 +19,11 @@ const ( messageSize int = 1024 ) +var ( + errFailedToParseCID = fmt.Errorf("failed to parse computation ID") + errComputationNotFound = fmt.Errorf("computation not found") +) + // RetrieveAgentEventsLogs Retrieve and forward agent logs and events via vsock. func (ms *managerService) RetrieveAgentEventsLogs() { l, err := vsock.Listen(ManagerVsockPort, nil) @@ -42,6 +50,12 @@ func (ms *managerService) handleConnections(conn net.Conn) { n, err := conn.Read(b) if err != nil { ms.logger.Warn(err.Error()) + cmpID, err := ms.computationIDFromAddress(conn.RemoteAddr().String()) + if err != nil { + ms.logger.Warn(err.Error()) + continue + } + go ms.reportBrokenConnection(cmpID) return } var message manager.ClientStreamMessage @@ -49,18 +63,48 @@ func (ms *managerService) handleConnections(conn net.Conn) { ms.logger.Warn(err.Error()) continue } - cmpID := "" - switch mes := message.Message.(type) { - case *manager.ClientStreamMessage_AgentEvent: - cmpID = mes.AgentEvent.ComputationId - ms.eventsChan <- &manager.ClientStreamMessage{Message: mes} - case *manager.ClientStreamMessage_AgentLog: - cmpID = mes.AgentLog.ComputationId - ms.eventsChan <- &manager.ClientStreamMessage{Message: mes} - default: - ms.logger.Warn("Unexpected agent log or event type") + ms.eventsChan <- &message + + ms.logger.WithGroup("agent-events-logs").Info(message.String()) + } +} + +func (ms *managerService) computationIDFromAddress(address string) (string, error) { + re := regexp.MustCompile(`vm\((\d+)\)`) + matches := re.FindStringSubmatch(address) + + if len(matches) > 1 { + cid, err := strconv.Atoi(matches[1]) + if err != nil { + return "", err } + return ms.findComputationID(cid) + } + return "", errFailedToParseCID +} + +func (ms *managerService) findComputationID(cid int) (string, error) { + ms.mu.Lock() + defer ms.mu.Unlock() + for cmpID, vm := range ms.vms { + if vm.GetCID() == cid { + return cmpID, nil + } + } + + return "", errComputationNotFound +} - ms.logger.Info(fmt.Sprintf("Agent Log/Event, Computation ID: %s, Message: %s", cmpID, message.String())) +func (ms *managerService) reportBrokenConnection(cmpID string) { + ms.eventsChan <- &manager.ClientStreamMessage{ + Message: &manager.ClientStreamMessage_AgentEvent{ + AgentEvent: &manager.AgentEvent{ + EventType: "vm running", + ComputationId: cmpID, + Status: "disconnected", + Timestamp: timestamppb.Now(), + Originator: "manager", + }, + }, } } diff --git a/manager/qemu/vm.go b/manager/qemu/vm.go index 5df98ddea..127f451ba 100644 --- a/manager/qemu/vm.go +++ b/manager/qemu/vm.go @@ -6,11 +6,14 @@ import ( "fmt" "os" "os/exec" + "syscall" + "time" "github.com/gofrs/uuid" "github.com/ultravioletrs/cocos/internal" "github.com/ultravioletrs/cocos/manager/vm" "github.com/ultravioletrs/cocos/pkg/manager" + "google.golang.org/protobuf/types/known/timestamppb" ) const ( @@ -18,6 +21,7 @@ const ( KernelFile = "bzImage" rootfsFile = "rootfs.cpio" tmpDir = "/tmp" + interval = 5 * time.Second ) type qemuVM struct { @@ -35,7 +39,12 @@ func NewVM(config interface{}, logsChan chan *manager.ClientStreamMessage, compu } } -func (v *qemuVM) Start() error { +func (v *qemuVM) Start() (err error) { + defer func() { + if err == nil { + go v.checkVMProcessPeriodically() + } + }() // Create unique qemu device identifiers id, err := uuid.NewV4() if err != nil { @@ -107,3 +116,44 @@ func (v *qemuVM) executableAndArgs() (string, []string, error) { return exe, args, nil } + +func (v *qemuVM) checkVMProcessPeriodically() { + for { + if !processExists(v.GetProcess()) { + v.logsChan <- &manager.ClientStreamMessage{ + Message: &manager.ClientStreamMessage_AgentEvent{ + AgentEvent: &manager.AgentEvent{ + ComputationId: v.computationId, + EventType: "vm-running", + Status: "stopped", + Timestamp: timestamppb.Now(), + Originator: "manager", + }, + }, + } + break + } + time.Sleep(interval) + } +} + +func processExists(pid int) bool { + process, err := os.FindProcess(pid) + if err != nil { + return false + } + + // On Unix systems, FindProcess always succeeds and returns a Process for the given pid, regardless of whether the process exists. + // To test whether the process actually exists, see whether p.Signal(syscall.Signal(0)) reports an error. + if err = process.Signal(syscall.Signal(0)); err == nil { + return true + } + if err == syscall.ESRCH { + return false + } + return false +} + +func (v *qemuVM) GetCID() int { + return v.config.GuestCID +} diff --git a/manager/vm/logging.go b/manager/vm/logging.go index 925611d73..5bd371e33 100644 --- a/manager/vm/logging.go +++ b/manager/vm/logging.go @@ -120,7 +120,7 @@ func (s *Stderr) Write(p []byte) (n int, err error) { EventType: "vm-provision", Timestamp: timestamppb.Now(), Originator: "manager", - Status: "failed", + Status: "error", }, }, } diff --git a/manager/vm/logging_test.go b/manager/vm/logging_test.go index fd145d83e..163cd8c75 100644 --- a/manager/vm/logging_test.go +++ b/manager/vm/logging_test.go @@ -121,7 +121,7 @@ func TestStderrWrite(t *testing.T) { assert.NotNil(t, agentEvent) assert.Equal(t, "test-computation", agentEvent.ComputationId) assert.Equal(t, "vm-provision", agentEvent.EventType) - assert.Equal(t, "failed", agentEvent.Status) + assert.Equal(t, "error", agentEvent.Status) assert.NotNil(t, agentEvent.Timestamp) } case <-time.After(time.Second): diff --git a/manager/vm/mocks/vm.go b/manager/vm/mocks/vm.go index 67e455120..f1d977e22 100644 --- a/manager/vm/mocks/vm.go +++ b/manager/vm/mocks/vm.go @@ -15,6 +15,24 @@ type VM struct { mock.Mock } +// GetCID provides a mock function with given fields: +func (_m *VM) GetCID() int { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetCID") + } + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + // GetProcess provides a mock function with given fields: func (_m *VM) GetProcess() int { ret := _m.Called() diff --git a/manager/vm/vm.go b/manager/vm/vm.go index 856883b1c..b6d463ee3 100644 --- a/manager/vm/vm.go +++ b/manager/vm/vm.go @@ -16,6 +16,7 @@ type VM interface { SendAgentConfig(ac agent.Computation) error SetProcess(pid int) error GetProcess() int + GetCID() int } //go:generate mockery --name Provider --output=./mocks --filename provider.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0" diff --git a/test/computations/main.go b/test/computations/main.go index 905b00e00..f6befe973 100644 --- a/test/computations/main.go +++ b/test/computations/main.go @@ -117,7 +117,7 @@ func main() { logger, err := mglog.New(os.Stdout, "debug") if err != nil { - log.Fatalf(err.Error()) + log.Fatal(err.Error()) } go func() { diff --git a/test/manual/agent-config/main.go b/test/manual/agent-config/main.go index 075ac4919..d4c30f165 100644 --- a/test/manual/agent-config/main.go +++ b/test/manual/agent-config/main.go @@ -38,16 +38,16 @@ func main() { pubKey, err := os.ReadFile(pubKeyFile) if err != nil { - log.Fatalf(fmt.Sprintf("failed to read public key file: %s", err)) + log.Fatalf("failed to read public key file: %s", err) } pubPem, _ := pem.Decode(pubKey) algoHash, err := internal.Checksum(algoPath) if err != nil { - log.Fatalf(fmt.Sprintf("failed to calculate checksum: %s", err)) + log.Fatalf("failed to calculate checksum: %s", err) } dataHash, err := internal.Checksum(dataPath) if err != nil { - log.Fatalf(fmt.Sprintf("failed to calculate checksum: %s", err)) + log.Fatalf("failed to calculate checksum: %s", err) } l, err := vsock.Listen(manager.ManagerVsockPort, nil) From e0b828d0ae1bd94128eb1302e353f31f0f012e53 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Wed, 11 Sep 2024 17:28:07 +0300 Subject: [PATCH 16/83] use syslog (#237) Signed-off-by: Sammy Oina --- init/systemd/cocos-manager.service | 3 +++ 1 file changed, 3 insertions(+) diff --git a/init/systemd/cocos-manager.service b/init/systemd/cocos-manager.service index 913a7349b..3ecf3c54c 100644 --- a/init/systemd/cocos-manager.service +++ b/init/systemd/cocos-manager.service @@ -7,6 +7,9 @@ ExecStart=cocos-manager Restart=on-failure RestartSec=5s EnvironmentFile=/etc/cocos/cocos-manager.env +StandardOutput=syslog +StandardError=syslog +SyslogIdentifier=cocos-manager [Install] WantedBy=multi-user.target From 20ddb3aa29adb7817267bff8dc7b0f9e70d3a511 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Thu, 12 Sep 2024 16:55:53 +0300 Subject: [PATCH 17/83] restart always (#239) Signed-off-by: Sammy Oina --- init/systemd/cocos-manager.service | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/init/systemd/cocos-manager.service b/init/systemd/cocos-manager.service index 3ecf3c54c..5fdffeba5 100644 --- a/init/systemd/cocos-manager.service +++ b/init/systemd/cocos-manager.service @@ -4,7 +4,7 @@ After=network.target [Service] ExecStart=cocos-manager -Restart=on-failure +Restart=always RestartSec=5s EnvironmentFile=/etc/cocos/cocos-manager.env StandardOutput=syslog From e26deb98e412d5b685f199c9d75a72810802891b Mon Sep 17 00:00:00 2001 From: Smith Jilks <41241359+smithjilks@users.noreply.github.com> Date: Thu, 12 Sep 2024 17:54:09 +0300 Subject: [PATCH 18/83] COCOS-143 - Add agent service tests (#170) * Add agent service tests Signed-off-by: Jilks Smith * Update agent service tests * Fix agent service tests * Improve agent service test coverage * Improve agent service test coverage Signed-off-by: Jilks Smith * Fix tests Signed-off-by: Jilks Smith * Refactor and improve coverage Signed-off-by: Jilks Smith --------- Signed-off-by: Jilks Smith --- agent/algorithm/logging.go | 3 +- agent/algorithm/python/python.go | 6 +- agent/quoteprovider/mocks/QuoteProvider.go | 95 +++++ agent/service.go | 7 +- agent/service_test.go | 394 +++++++++++++++++++++ 5 files changed, 499 insertions(+), 6 deletions(-) create mode 100644 agent/quoteprovider/mocks/QuoteProvider.go create mode 100644 agent/service_test.go diff --git a/agent/algorithm/logging.go b/agent/algorithm/logging.go index ace31cf73..74196ebf5 100644 --- a/agent/algorithm/logging.go +++ b/agent/algorithm/logging.go @@ -4,6 +4,7 @@ package algorithm import ( "bytes" + "encoding/json" "io" "log/slog" @@ -65,7 +66,7 @@ func (s *Stderr) Write(p []byte) (n int, err error) { s.Logger.Error(string(buf[:n])) } - if err := s.EventSvc.SendEvent("algorithm-run", "error", nil); err != nil { + if err := s.EventSvc.SendEvent("algorithm-run", "error", json.RawMessage{}); err != nil { return len(p), err } diff --git a/agent/algorithm/python/python.go b/agent/algorithm/python/python.go index c739fcf4c..206133354 100644 --- a/agent/algorithm/python/python.go +++ b/agent/algorithm/python/python.go @@ -18,15 +18,15 @@ import ( const ( PyRuntime = "python3" - pyRuntimeKey = "python_runtime" + PyRuntimeKey = "python_runtime" ) func PythonRunTimeToContext(ctx context.Context, runtime string) context.Context { - return metadata.AppendToOutgoingContext(ctx, pyRuntimeKey, runtime) + return metadata.AppendToOutgoingContext(ctx, PyRuntimeKey, runtime) } func PythonRunTimeFromContext(ctx context.Context) string { - return metadata.ValueFromIncomingContext(ctx, pyRuntimeKey)[0] + return metadata.ValueFromIncomingContext(ctx, PyRuntimeKey)[0] } var _ algorithm.Algorithm = (*python)(nil) diff --git a/agent/quoteprovider/mocks/QuoteProvider.go b/agent/quoteprovider/mocks/QuoteProvider.go new file mode 100644 index 000000000..c62780239 --- /dev/null +++ b/agent/quoteprovider/mocks/QuoteProvider.go @@ -0,0 +1,95 @@ +// Code generated by mockery v2.45.0. DO NOT EDIT. + +package mocks + +import ( + sevsnp "github.com/google/go-sev-guest/proto/sevsnp" + mock "github.com/stretchr/testify/mock" +) + +// QuoteProvider is an autogenerated mock type for the QuoteProvider type +type QuoteProvider struct { + mock.Mock +} + +// GetRawQuote provides a mock function with given fields: reportData +func (_m *QuoteProvider) GetRawQuote(reportData [64]byte) ([]uint8, error) { + ret := _m.Called(reportData) + + if len(ret) == 0 { + panic("no return value specified for GetRawQuote") + } + + var r0 []uint8 + var r1 error + if rf, ok := ret.Get(0).(func([64]byte) ([]uint8, error)); ok { + return rf(reportData) + } + if rf, ok := ret.Get(0).(func([64]byte) []uint8); ok { + r0 = rf(reportData) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]uint8) + } + } + + if rf, ok := ret.Get(1).(func([64]byte) error); ok { + r1 = rf(reportData) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// IsSupported provides a mock function with given fields: +func (_m *QuoteProvider) IsSupported() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for IsSupported") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Product provides a mock function with given fields: +func (_m *QuoteProvider) Product() *sevsnp.SevProduct { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Product") + } + + var r0 *sevsnp.SevProduct + if rf, ok := ret.Get(0).(func() *sevsnp.SevProduct); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sevsnp.SevProduct) + } + } + + return r0 +} + +// NewQuoteProvider creates a new instance of QuoteProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewQuoteProvider(t interface { + mock.TestingT + Cleanup(func()) +}) *QuoteProvider { + mock := &QuoteProvider{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/agent/service.go b/agent/service.go index 9ec81a48e..d70a0ba8d 100644 --- a/agent/service.go +++ b/agent/service.go @@ -55,6 +55,8 @@ var ( ErrFileNameMismatch = errors.New("malformed data, filename does not match manifest") // ErrAllResultsConsumed indicates all results have been consumed. ErrAllResultsConsumed = errors.New("all results have been consumed by declared consumers") + // ErrAttestationFailed attestation failed. + ErrAttestationFailed = errors.New("failed to get raw quote") ) // Service specifies an API that must be fullfiled by the domain service @@ -124,7 +126,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { return fmt.Errorf("error getting current directory: %v", err) } - f, err := os.Create(filepath.Join(currentDir, "algorithm")) + f, err := os.Create(filepath.Join(currentDir, "algo")) if err != nil { return fmt.Errorf("error creating algorithm file: %v", err) } @@ -317,8 +319,9 @@ func (as *agentService) runComputation() { } func (as *agentService) publishEvent(status string, details json.RawMessage) func() { + st := as.sm.GetState().String() return func() { - if err := as.eventSvc.SendEvent(as.sm.State.String(), status, details); err != nil { + if err := as.eventSvc.SendEvent(st, status, details); err != nil { as.sm.logger.Warn(err.Error()) } } diff --git a/agent/service_test.go b/agent/service_test.go new file mode 100644 index 000000000..3fcacad45 --- /dev/null +++ b/agent/service_test.go @@ -0,0 +1,394 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package agent + +import ( + "context" + "crypto/rand" + "log" + "os" + "testing" + "time" + + mglog "github.com/absmach/magistrala/logger" + "github.com/absmach/magistrala/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/ultravioletrs/cocos/agent/algorithm" + "github.com/ultravioletrs/cocos/agent/algorithm/python" + "github.com/ultravioletrs/cocos/agent/events/mocks" + "github.com/ultravioletrs/cocos/agent/quoteprovider" + mocks2 "github.com/ultravioletrs/cocos/agent/quoteprovider/mocks" + "golang.org/x/crypto/sha3" + "google.golang.org/grpc/metadata" +) + +var ( + algoPath = "../test/manual/algo/lin_reg.py" + reqPath = "../test/manual/algo/requirements.txt" + dataPath = "../test/manual/data/iris.csv" +) + +const datasetFile = "iris.csv" + +func TestAlgo(t *testing.T) { + events := new(mocks.Service) + + evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil) + defer evCall.Unset() + + qp, err := quoteprovider.GetQuoteProvider() + require.NoError(t, err) + + algo, err := os.ReadFile(algoPath) + require.NoError(t, err) + + algoHash := sha3.Sum256(algo) + + reqFile, err := os.ReadFile(reqPath) + require.NoError(t, err) + + testCases := []struct { + name string + err error + algo Algorithm + algoType string + }{ + { + name: "Test Algo successfully", + algo: Algorithm{ + Algorithm: algo, + Hash: algoHash, + }, + algoType: "python", + err: nil, + }, + { + name: "Test Algo successfully with requirements file", + algo: Algorithm{ + Algorithm: algo, + Hash: algoHash, + Requirements: reqFile, + }, + algoType: "python", + err: nil, + }, + { + name: "Test Algo type binary successfully", + algo: Algorithm{ + Algorithm: algo, + Hash: algoHash, + }, + algoType: "bin", + err: nil, + }, + { + name: "Test Algo type wasm successfully", + algo: Algorithm{ + Algorithm: algo, + Hash: algoHash, + }, + algoType: "wasm", + err: nil, + }, + { + name: "Test Algo type docker successfully", + algo: Algorithm{ + Algorithm: algo, + Hash: algoHash, + }, + algoType: "docker", + err: nil, + }, + { + name: "Test algo hash mismatch", + algo: Algorithm{}, + algoType: "python", + err: ErrHashMismatch, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err = os.RemoveAll("datasets") + require.NoError(t, err) + + ctx := metadata.NewIncomingContext(context.Background(), + metadata.Pairs(algorithm.AlgoTypeKey, tc.algoType, python.PyRuntimeKey, python.PyRuntime), + ) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + svc := New(ctx, mglog.NewMock(), events, testComputation(t), qp) + + time.Sleep(300 * time.Millisecond) + + err = svc.Algo(ctx, tc.algo) + assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err) + }) + } +} + +func TestData(t *testing.T) { + events := new(mocks.Service) + + evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil) + defer evCall.Unset() + + qp, err := quoteprovider.GetQuoteProvider() + require.NoError(t, err) + + algo, err := os.ReadFile(algoPath) + require.NoError(t, err) + + algoHash := sha3.Sum256(algo) + + alg := Algorithm{ + Hash: algoHash, + Algorithm: algo, + } + + data, err := os.ReadFile(dataPath) + require.NoError(t, err) + + dataHash := sha3.Sum256(data) + + cases := []struct { + name string + data Dataset + err error + }{ + { + name: "Test data successfully", + data: Dataset{ + Hash: dataHash, + Dataset: data, + Filename: datasetFile, + }, + }, + { + name: "Test State not ready", + data: Dataset{ + Dataset: data, + Hash: dataHash, + Filename: datasetFile, + }, + err: ErrStateNotReady, + }, + { + name: "Test File name does not match manifest", + data: Dataset{ + Dataset: data, + Hash: dataHash, + Filename: "invalid", + }, + err: ErrFileNameMismatch, + }, + { + name: "Test dataset not declared in manifest", + data: Dataset{ + Filename: datasetFile, + }, + err: ErrUndeclaredDataset, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx := metadata.NewIncomingContext(context.Background(), + metadata.Pairs( + algorithm.AlgoTypeKey, "python", + python.PyRuntimeKey, python.PyRuntime), + ) + + if tc.err != ErrUndeclaredDataset { + ctx = IndexToContext(ctx, 0) + } + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + comp := testComputation(t) + + svc := New(ctx, mglog.NewMock(), events, comp, qp) + time.Sleep(300 * time.Millisecond) + + if tc.err != ErrStateNotReady { + _ = svc.Algo(ctx, alg) + time.Sleep(300 * time.Millisecond) + } + err = svc.Data(ctx, tc.data) + _ = os.RemoveAll("datasets") + _ = os.RemoveAll("results") + assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err) + }) + } +} + +func TestResult(t *testing.T) { + events := new(mocks.Service) + + evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil) + defer evCall.Unset() + + qp, err := quoteprovider.GetQuoteProvider() + require.NoError(t, err) + + cases := []struct { + name string + err error + setup func(svc *agentService) + ctxSetup func(ctx context.Context) context.Context + }{ + { + name: "Test results not ready", + err: ErrResultsNotReady, + setup: func(svc *agentService) { + }, + }, + { + name: "Test all results consumed", + err: ErrAllResultsConsumed, + setup: func(svc *agentService) { + svc.sm.SetState(resultsReady) + svc.computation.ResultConsumers = []ResultConsumer{} + }, + ctxSetup: func(ctx context.Context) context.Context { + return IndexToContext(ctx, 0) + }, + }, + { + name: "Test undeclared consumer", + err: ErrUndeclaredConsumer, + setup: func(svc *agentService) { + svc.sm.SetState(resultsReady) + svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("user")}} + }, + ctxSetup: func(ctx context.Context) context.Context { + return ctx + }, + }, + { + name: "Test results consumed and event sent", + err: nil, + setup: func(svc *agentService) { + svc.sm.SetState(resultsReady) + svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("key")}} + }, + ctxSetup: func(ctx context.Context) context.Context { + return IndexToContext(ctx, 0) + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx := metadata.NewIncomingContext(context.Background(), + metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime), + ) + + if tc.ctxSetup != nil { + ctx = tc.ctxSetup(ctx) + } + + svc := &agentService{ + sm: NewStateMachine(mglog.NewMock(), testComputation(t)), + eventSvc: events, + quoteProvider: qp, + computation: testComputation(t), + } + + go svc.sm.Start(ctx) + tc.setup(svc) + _, err := svc.Result(ctx) + _ = os.RemoveAll("datasets") + _ = os.RemoveAll("results") + + assert.ErrorIs(t, err, tc.err, "expected %v, got %v", tc.err, err) + }) + } +} + +func TestAttestation(t *testing.T) { + events := new(mocks.Service) + qp := new(mocks2.QuoteProvider) + + evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil) + defer evCall.Unset() + + cases := []struct { + name string + reportData [ReportDataSize]byte + rawQuote []uint8 + err error + }{ + { + name: "Test attestation successful", + reportData: generateReportData(), + rawQuote: make([]uint8, 0), + err: nil, + }, + { + name: "Test attestation failed", + reportData: generateReportData(), + rawQuote: nil, + err: ErrAttestationFailed, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx := metadata.NewIncomingContext(context.Background(), + metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime), + ) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + getQuote := qp.On("GetRawQuote", mock.Anything).Return(tc.rawQuote, tc.err) + if tc.err != ErrAttestationFailed { + getQuote = qp.On("GetRawQuote", mock.Anything).Return(tc.reportData, nil) + } + defer getQuote.Unset() + + svc := New(ctx, mglog.NewMock(), events, testComputation(t), qp) + time.Sleep(300 * time.Millisecond) + _, err := svc.Attestation(ctx, tc.reportData) + assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err) + }) + } +} + +func generateReportData() [ReportDataSize]byte { + bytes := make([]byte, ReportDataSize) + _, err := rand.Read(bytes) + if err != nil { + log.Fatalf("Failed to generate random bytes: %v", err) + } + return [64]byte(bytes) +} + +func testComputation(t *testing.T) Computation { + algo, err := os.ReadFile(algoPath) + require.NoError(t, err) + + algoHash := sha3.Sum256(algo) + + data, err := os.ReadFile(dataPath) + require.NoError(t, err) + + dataHash := sha3.Sum256(data) + + return Computation{ + ID: "1", + Name: "sample computation", + Description: "sample description", + Datasets: []Dataset{{Hash: dataHash, UserKey: []byte("key"), Dataset: data, Filename: datasetFile}}, + Algorithm: Algorithm{Hash: algoHash, UserKey: []byte("key"), Algorithm: algo}, + ResultConsumers: []ResultConsumer{{UserKey: []byte("key")}}, + AgentConfig: AgentConfig{ + Port: "7002", + LogLevel: "debug", + AttestedTls: false, + }, + } +} From 355f95771d4cd63fedb0287582d3e7a8a0e3fb4e Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Fri, 13 Sep 2024 15:10:19 +0300 Subject: [PATCH 19/83] NOISSUE - Use constants for log level (#240) * use constants for log level Signed-off-by: Sammy Oina * fix tests Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- manager/vm/logging.go | 5 +++-- manager/vm/logging_test.go | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/manager/vm/logging.go b/manager/vm/logging.go index 5bd371e33..3bf412c5d 100644 --- a/manager/vm/logging.go +++ b/manager/vm/logging.go @@ -6,6 +6,7 @@ import ( "bytes" "errors" "io" + "log/slog" "github.com/ultravioletrs/cocos/pkg/manager" "google.golang.org/protobuf/types/known/timestamppb" @@ -62,7 +63,7 @@ func (s *Stdout) Write(p []byte) (n int, err error) { AgentLog: &manager.AgentLog{ Message: string(buf[:n]), ComputationId: s.ComputationId, - Level: "debug", + Level: slog.LevelDebug.String(), Timestamp: timestamppb.Now(), }, }, @@ -101,7 +102,7 @@ func (s *Stderr) Write(p []byte) (n int, err error) { AgentLog: &manager.AgentLog{ Message: string(buf[:n]), ComputationId: s.ComputationId, - Level: "error", + Level: slog.LevelError.String(), Timestamp: timestamppb.Now(), }, }, diff --git a/manager/vm/logging_test.go b/manager/vm/logging_test.go index 163cd8c75..d7d2a579e 100644 --- a/manager/vm/logging_test.go +++ b/manager/vm/logging_test.go @@ -3,6 +3,7 @@ package vm import ( + "log/slog" "testing" "time" @@ -54,7 +55,7 @@ func TestStdoutWrite(t *testing.T) { agentLog := msg.GetAgentLog() assert.NotNil(t, agentLog) assert.Equal(t, "test-computation", agentLog.ComputationId) - assert.Equal(t, "debug", agentLog.Level) + assert.Equal(t, slog.LevelDebug.String(), agentLog.Level) assert.NotEmpty(t, agentLog.Message) assert.NotNil(t, agentLog.Timestamp) case <-time.After(time.Second): @@ -113,7 +114,7 @@ func TestStderrWrite(t *testing.T) { agentLog := msg.GetAgentLog() assert.NotNil(t, agentLog) assert.Equal(t, "test-computation", agentLog.ComputationId) - assert.Equal(t, "error", agentLog.Level) + assert.Equal(t, slog.LevelError.String(), agentLog.Level) assert.NotEmpty(t, agentLog.Message) assert.NotNil(t, agentLog.Timestamp) case *manager.ClientStreamMessage_AgentEvent: From c14a6338ccda681ed2dfd1c572ccd25b0cc8ae63 Mon Sep 17 00:00:00 2001 From: Washington Kigani Kamadi Date: Tue, 17 Sep 2024 16:58:15 +0300 Subject: [PATCH 20/83] NOISSUE - Enhance event status (#235) * enhance timeline Signed-off-by: WashingtonKK * fix: remove redundant event Signed-off-by: WashingtonKK * use constant Signed-off-by: WashingtonKK * lint Signed-off-by: WashingtonKK * use typed constant for status Signed-off-by: WashingtonKK * refactor status Signed-off-by: WashingtonKK * export agent status and state Signed-off-by: WashingtonKK * ehance event states Signed-off-by: WashingtonKK * fix tests Signed-off-by: WashingtonKK * use manager states and status Signed-off-by: WashingtonKK * move algo-run to agent package Signed-off-by: WashingtonKK * replace literal with constant Signed-off-by: WashingtonKK * replace manager variable with constant Signed-off-by: WashingtonKK --------- Signed-off-by: WashingtonKK --- agent/algorithm/logging.go | 8 ++- agent/algorithm/logging_test.go | 3 +- agent/service.go | 24 ++++----- agent/service_test.go | 6 +-- agent/state.go | 77 +++++++++++++++++------------ agent/state_string.go | 31 ++++++------ agent/state_test.go | 22 ++++----- agent/status_string.go | 28 +++++++++++ manager/qemu/vm.go | 4 +- manager/service.go | 20 ++++---- manager/vm/logging.go | 4 +- manager/vm/logging_test.go | 4 +- pkg/manager/manager_states.go | 21 ++++++++ pkg/manager/managerstate_string.go | 25 ++++++++++ pkg/manager/managerstatus_string.go | 25 ++++++++++ 15 files changed, 210 insertions(+), 92 deletions(-) create mode 100644 agent/status_string.go create mode 100644 pkg/manager/manager_states.go create mode 100644 pkg/manager/managerstate_string.go create mode 100644 pkg/manager/managerstatus_string.go diff --git a/agent/algorithm/logging.go b/agent/algorithm/logging.go index 74196ebf5..e1289511e 100644 --- a/agent/algorithm/logging.go +++ b/agent/algorithm/logging.go @@ -16,7 +16,11 @@ var ( _ io.Writer = &Stderr{} ) -const bufSize = 1024 +const ( + bufSize = 1024 + algorithmRun = "AlgorithmRun" + errorStatus = "Error" +) type Stdout struct { Logger *slog.Logger @@ -66,7 +70,7 @@ func (s *Stderr) Write(p []byte) (n int, err error) { s.Logger.Error(string(buf[:n])) } - if err := s.EventSvc.SendEvent("algorithm-run", "error", json.RawMessage{}); err != nil { + if err := s.EventSvc.SendEvent(algorithmRun, errorStatus, json.RawMessage{}); err != nil { return len(p), err } diff --git a/agent/algorithm/logging_test.go b/agent/algorithm/logging_test.go index fd9707143..91427028b 100644 --- a/agent/algorithm/logging_test.go +++ b/agent/algorithm/logging_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/ultravioletrs/cocos/agent/events/mocks" + "github.com/ultravioletrs/cocos/pkg/manager" ) func TestStdoutWrite(t *testing.T) { @@ -72,7 +73,7 @@ func TestStderrWrite(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockEventService := mocks.NewService(t) - mockEventService.On("SendEvent", "algorithm-run", "error", mock.Anything).Return(nil) + mockEventService.On("SendEvent", "AlgorithmRun", manager.Error.String(), mock.Anything).Return(nil) stderr := &Stderr{Logger: mglog.NewMock(), EventSvc: mockEventService} n, err := stderr.Write([]byte(tt.input)) diff --git a/agent/service.go b/agent/service.go index d70a0ba8d..00082acc0 100644 --- a/agent/service.go +++ b/agent/service.go @@ -92,14 +92,14 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp go svc.sm.Start(ctx) svc.sm.SendEvent(start) - svc.sm.StateFunctions[idle] = svc.publishEvent("in-progress", json.RawMessage{}) - svc.sm.StateFunctions[receivingManifest] = svc.publishEvent("in-progress", json.RawMessage{}) - svc.sm.StateFunctions[receivingAlgorithm] = svc.publishEvent("in-progress", json.RawMessage{}) - svc.sm.StateFunctions[receivingData] = svc.publishEvent("in-progress", json.RawMessage{}) - svc.sm.StateFunctions[resultsReady] = svc.publishEvent("in-progress", json.RawMessage{}) - svc.sm.StateFunctions[complete] = svc.publishEvent("in-progress", json.RawMessage{}) - svc.sm.StateFunctions[running] = svc.runComputation - svc.sm.StateFunctions[failed] = svc.publishEvent("failed", json.RawMessage{}) + svc.sm.StateFunctions[Idle] = svc.publishEvent(IdleState.String(), json.RawMessage{}) + svc.sm.StateFunctions[ReceivingManifest] = svc.publishEvent(InProgress.String(), json.RawMessage{}) + svc.sm.StateFunctions[ReceivingAlgorithm] = svc.publishEvent(InProgress.String(), json.RawMessage{}) + svc.sm.StateFunctions[ReceivingData] = svc.publishEvent(InProgress.String(), json.RawMessage{}) + svc.sm.StateFunctions[ConsumingResults] = svc.publishEvent(Ready.String(), json.RawMessage{}) + svc.sm.StateFunctions[Complete] = svc.publishEvent(Completed.String(), json.RawMessage{}) + svc.sm.StateFunctions[Running] = svc.runComputation + svc.sm.StateFunctions[Failed] = svc.publishEvent(Failed.String(), json.RawMessage{}) svc.computation = cmp @@ -108,7 +108,7 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp } func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { - if as.sm.GetState() != receivingAlgorithm { + if as.sm.GetState() != ReceivingAlgorithm { return ErrStateNotReady } if as.algorithm != nil { @@ -189,7 +189,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { } func (as *agentService) Data(ctx context.Context, dataset Dataset) error { - if as.sm.GetState() != receivingData { + if as.sm.GetState() != ReceivingData { return ErrStateNotReady } if len(as.computation.Datasets) == 0 { @@ -242,7 +242,7 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error { } func (as *agentService) Result(ctx context.Context) ([]byte, error) { - if as.sm.GetState() != resultsReady && as.sm.GetState() != failed { + if as.sm.GetState() != ConsumingResults && as.sm.GetState() != Failed { return []byte{}, ErrResultsNotReady } if len(as.computation.ResultConsumers) == 0 { @@ -254,7 +254,7 @@ func (as *agentService) Result(ctx context.Context) ([]byte, error) { } as.computation.ResultConsumers = slices.Delete(as.computation.ResultConsumers, index, index+1) - if len(as.computation.ResultConsumers) == 0 && as.sm.GetState() == resultsReady { + if len(as.computation.ResultConsumers) == 0 && as.sm.GetState() == ConsumingResults { as.sm.SendEvent(resultsConsumed) } diff --git a/agent/service_test.go b/agent/service_test.go index 3fcacad45..47f1b85fd 100644 --- a/agent/service_test.go +++ b/agent/service_test.go @@ -251,7 +251,7 @@ func TestResult(t *testing.T) { name: "Test all results consumed", err: ErrAllResultsConsumed, setup: func(svc *agentService) { - svc.sm.SetState(resultsReady) + svc.sm.SetState(ConsumingResults) svc.computation.ResultConsumers = []ResultConsumer{} }, ctxSetup: func(ctx context.Context) context.Context { @@ -262,7 +262,7 @@ func TestResult(t *testing.T) { name: "Test undeclared consumer", err: ErrUndeclaredConsumer, setup: func(svc *agentService) { - svc.sm.SetState(resultsReady) + svc.sm.SetState(ConsumingResults) svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("user")}} }, ctxSetup: func(ctx context.Context) context.Context { @@ -273,7 +273,7 @@ func TestResult(t *testing.T) { name: "Test results consumed and event sent", err: nil, setup: func(svc *agentService) { - svc.sm.SetState(resultsReady) + svc.sm.SetState(ConsumingResults) svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("key")}} }, ctxSetup: func(ctx context.Context) context.Context { diff --git a/agent/state.go b/agent/state.go index 61a388f74..cafee0b89 100644 --- a/agent/state.go +++ b/agent/state.go @@ -9,18 +9,31 @@ import ( "sync" ) -//go:generate stringer -type=state -type state uint8 +//go:generate stringer -type=State +type State uint8 const ( - idle state = iota - receivingManifest - receivingAlgorithm - receivingData - running - resultsReady - complete - failed + Idle State = iota + ReceivingManifest + ReceivingAlgorithm + ReceivingData + Running + ConsumingResults + Complete + Failed + AlgorithmRun +) + +//go:generate stringer -type=Status +type Status uint8 + +const ( + IdleState Status = iota + InProgress + Ready + Completed + Terminated + Error ) type event uint8 @@ -38,10 +51,10 @@ const ( // StateMachine represents the state machine. type StateMachine struct { mu sync.Mutex - State state + State State EventChan chan event - Transitions map[state]map[event]state - StateFunctions map[state]func() + Transitions map[State]map[event]State + StateFunctions map[State]func() logger *slog.Logger wg *sync.WaitGroup } @@ -49,37 +62,37 @@ type StateMachine struct { // NewStateMachine creates a new StateMachine. func NewStateMachine(logger *slog.Logger, cmp Computation) *StateMachine { sm := &StateMachine{ - State: idle, + State: Idle, EventChan: make(chan event), - Transitions: make(map[state]map[event]state), - StateFunctions: make(map[state]func()), + Transitions: make(map[State]map[event]State), + StateFunctions: make(map[State]func()), logger: logger, wg: &sync.WaitGroup{}, } - sm.Transitions[idle] = make(map[event]state) - sm.Transitions[idle][start] = receivingManifest + sm.Transitions[Idle] = make(map[event]State) + sm.Transitions[Idle][start] = ReceivingManifest - sm.Transitions[receivingManifest] = make(map[event]state) - sm.Transitions[receivingManifest][manifestReceived] = receivingAlgorithm + sm.Transitions[ReceivingManifest] = make(map[event]State) + sm.Transitions[ReceivingManifest][manifestReceived] = ReceivingAlgorithm - sm.Transitions[receivingAlgorithm] = make(map[event]state) + sm.Transitions[ReceivingAlgorithm] = make(map[event]State) switch len(cmp.Datasets) { case 0: - sm.Transitions[receivingAlgorithm][algorithmReceived] = running + sm.Transitions[ReceivingAlgorithm][algorithmReceived] = Running default: - sm.Transitions[receivingAlgorithm][algorithmReceived] = receivingData + sm.Transitions[ReceivingAlgorithm][algorithmReceived] = ReceivingData } - sm.Transitions[receivingData] = make(map[event]state) - sm.Transitions[receivingData][dataReceived] = running + sm.Transitions[ReceivingData] = make(map[event]State) + sm.Transitions[ReceivingData][dataReceived] = Running - sm.Transitions[running] = make(map[event]state) - sm.Transitions[running][runComplete] = resultsReady - sm.Transitions[running][runFailed] = failed + sm.Transitions[Running] = make(map[event]State) + sm.Transitions[Running][runComplete] = ConsumingResults + sm.Transitions[Running][runFailed] = Failed - sm.Transitions[resultsReady] = make(map[event]state) - sm.Transitions[resultsReady][resultsConsumed] = complete + sm.Transitions[ConsumingResults] = make(map[event]State) + sm.Transitions[ConsumingResults][resultsConsumed] = Complete return sm } @@ -118,13 +131,13 @@ func (sm *StateMachine) SendEvent(event event) { sm.EventChan <- event } -func (sm *StateMachine) GetState() state { +func (sm *StateMachine) GetState() State { sm.mu.Lock() defer sm.mu.Unlock() return sm.State } -func (sm *StateMachine) SetState(state state) { +func (sm *StateMachine) SetState(state State) { sm.mu.Lock() defer sm.mu.Unlock() sm.State = state diff --git a/agent/state_string.go b/agent/state_string.go index 0a4a5a4eb..b084ec0bc 100644 --- a/agent/state_string.go +++ b/agent/state_string.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type=state"; DO NOT EDIT. +// Code generated by "stringer -type=State"; DO NOT EDIT. package agent @@ -8,23 +8,24 @@ func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} - _ = x[idle-0] - _ = x[receivingManifest-1] - _ = x[receivingAlgorithm-2] - _ = x[receivingData-3] - _ = x[running-4] - _ = x[resultsReady-5] - _ = x[complete-6] - _ = x[failed-7] + _ = x[Idle-0] + _ = x[ReceivingManifest-1] + _ = x[ReceivingAlgorithm-2] + _ = x[ReceivingData-3] + _ = x[Running-4] + _ = x[ConsumingResults-5] + _ = x[Complete-6] + _ = x[Failed-7] + _ = x[AlgorithmRun-8] } -const _state_name = "idlereceivingManifestreceivingAlgorithmreceivingDatarunningresultsReadycompletefailed" +const _State_name = "IdleReceivingManifestReceivingAlgorithmReceivingDataRunningConsumingResultsCompleteFailedAlgorithmRun" -var _state_index = [...]uint8{0, 4, 21, 39, 52, 59, 71, 79, 85} +var _State_index = [...]uint8{0, 4, 21, 39, 52, 59, 75, 83, 89, 101} -func (i state) String() string { - if i >= state(len(_state_index)-1) { - return "state(" + strconv.FormatInt(int64(i), 10) + ")" +func (i State) String() string { + if i >= State(len(_State_index)-1) { + return "State(" + strconv.FormatInt(int64(i), 10) + ")" } - return _state_name[_state_index[i]:_state_index[i+1]] + return _State_name[_State_index[i]:_State_index[i+1]] } diff --git a/agent/state_test.go b/agent/state_test.go index 342d206c6..5ce786359 100644 --- a/agent/state_test.go +++ b/agent/state_test.go @@ -21,18 +21,18 @@ var cmp = Computation{ func TestStateMachineTransitions(t *testing.T) { cases := []struct { - fromState state + fromState State event event - expected state + expected State cmp Computation }{ - {idle, start, receivingManifest, cmp}, - {receivingManifest, manifestReceived, receivingAlgorithm, cmp}, - {receivingAlgorithm, algorithmReceived, receivingData, cmp}, - {receivingAlgorithm, algorithmReceived, running, Computation{}}, - {receivingData, dataReceived, running, cmp}, - {running, runComplete, resultsReady, cmp}, - {resultsReady, resultsConsumed, complete, cmp}, + {Idle, start, ReceivingManifest, cmp}, + {ReceivingManifest, manifestReceived, ReceivingAlgorithm, cmp}, + {ReceivingAlgorithm, algorithmReceived, ReceivingData, cmp}, + {ReceivingAlgorithm, algorithmReceived, Running, Computation{}}, + {ReceivingData, dataReceived, Running, cmp}, + {Running, runComplete, ConsumingResults, cmp}, + {ConsumingResults, resultsConsumed, Complete, cmp}, } for _, tc := range cases { @@ -61,11 +61,11 @@ func TestStateMachineInvalidTransition(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) go sm.Start(ctx) - sm.SetState(idle) + sm.SetState(Idle) sm.SendEvent(dataReceived) - if sm.State != idle { + if sm.State != Idle { t.Errorf("State should not change on an invalid event, but got %v", sm.State) } cancel() diff --git a/agent/status_string.go b/agent/status_string.go new file mode 100644 index 000000000..0b19bb195 --- /dev/null +++ b/agent/status_string.go @@ -0,0 +1,28 @@ +// Code generated by "stringer -type=Status"; DO NOT EDIT. + +package agent + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[IdleState-0] + _ = x[InProgress-1] + _ = x[Ready-2] + _ = x[Completed-3] + _ = x[Terminated-4] + _ = x[Error-5] +} + +const _Status_name = "IdleStateInProgressReadyCompletedTerminatedError" + +var _Status_index = [...]uint8{0, 9, 19, 24, 33, 43, 48} + +func (i Status) String() string { + if i >= Status(len(_Status_index)-1) { + return "Status(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _Status_name[_Status_index[i]:_Status_index[i+1]] +} diff --git a/manager/qemu/vm.go b/manager/qemu/vm.go index 127f451ba..c330ee257 100644 --- a/manager/qemu/vm.go +++ b/manager/qemu/vm.go @@ -124,8 +124,8 @@ func (v *qemuVM) checkVMProcessPeriodically() { Message: &manager.ClientStreamMessage_AgentEvent{ AgentEvent: &manager.AgentEvent{ ComputationId: v.computationId, - EventType: "vm-running", - Status: "stopped", + EventType: manager.VmRunning.String(), + Status: manager.Stopped.String(), Timestamp: timestamppb.Now(), Originator: "manager", }, diff --git a/manager/service.go b/manager/service.go index d1f3b1db6..9d97d3162 100644 --- a/manager/service.go +++ b/manager/service.go @@ -111,7 +111,7 @@ func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger, } func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) (string, error) { - ms.publishEvent("vm-provision", c.Id, "starting", json.RawMessage{}) + ms.publishEvent(manager.VmProvision.String(), c.Id, manager.Starting.String(), json.RawMessage{}) ac := agent.Computation{ ID: c.Id, Name: c.Name, @@ -130,7 +130,7 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) for _, data := range c.Datasets { if len(data.Hash) != hashLength { - ms.publishEvent("vm-provision", c.Id, "failed", json.RawMessage{}) + ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{}) return "", errInvalidHashLength } ac.Datasets = append(ac.Datasets, agent.Dataset{Hash: [hashLength]byte(data.Hash), UserKey: data.UserKey, Filename: data.Filename}) @@ -142,7 +142,7 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) agentPort, err := getFreePort(ms.portRangeMin, ms.portRangeMax) if err != nil { - ms.publishEvent("vm-provision", c.Id, "failed", json.RawMessage{}) + ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{}) return "", errors.Wrap(ErrFailedToAllocatePort, err) } ms.qemuCfg.HostFwdAgent = agentPort @@ -150,7 +150,7 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) ch, err := computationHash(ac) if err != nil { - ms.publishEvent("vm-provision", c.Id, "failed", json.RawMessage{}) + ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{}) return "", errors.Wrap(ErrFailedToCalculateHash, err) } @@ -158,9 +158,9 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) ms.qemuCfg.SevConfig.HostData = base64.StdEncoding.EncodeToString(ch[:]) cvm := ms.vmFactory(ms.qemuCfg, ms.eventsChan, c.Id) - ms.publishEvent("vm-provision", c.Id, "in-progress", json.RawMessage{}) + ms.publishEvent(manager.VmProvision.String(), c.Id, agent.InProgress.String(), json.RawMessage{}) if err = cvm.Start(); err != nil { - ms.publishEvent("vm-provision", c.Id, "failed", json.RawMessage{}) + ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{}) return "", err } ms.mu.Lock() @@ -187,7 +187,7 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) ms.qemuCfg.VSockConfig.Vnc++ - ms.publishEvent("vm-provision", c.Id, "complete", json.RawMessage{}) + ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Completed.String(), json.RawMessage{}) return fmt.Sprint(ms.qemuCfg.HostFwdAgent), nil } @@ -196,11 +196,11 @@ func (ms *managerService) Stop(ctx context.Context, computationID string) error defer ms.mu.Unlock() cvm, ok := ms.vms[computationID] if !ok { - defer ms.publishEvent("stop-computation", computationID, "failed", json.RawMessage{}) + defer ms.publishEvent(manager.StopComputationRun.String(), computationID, agent.Failed.String(), json.RawMessage{}) return ErrNotFound } if err := cvm.Stop(); err != nil { - defer ms.publishEvent("stop-computation", computationID, "failed", json.RawMessage{}) + defer ms.publishEvent(manager.StopComputationRun.String(), computationID, agent.Failed.String(), json.RawMessage{}) return err } delete(ms.vms, computationID) @@ -209,7 +209,7 @@ func (ms *managerService) Stop(ctx context.Context, computationID string) error ms.logger.Error("Failed to delete persisted VM state", "error", err) } - defer ms.publishEvent("stop-computation", computationID, "complete", json.RawMessage{}) + defer ms.publishEvent(manager.StopComputationRun.String(), computationID, agent.Completed.String(), json.RawMessage{}) return nil } diff --git a/manager/vm/logging.go b/manager/vm/logging.go index 3bf412c5d..5bd56f71e 100644 --- a/manager/vm/logging.go +++ b/manager/vm/logging.go @@ -118,10 +118,10 @@ func (s *Stderr) Write(p []byte) (n int, err error) { Message: &manager.ClientStreamMessage_AgentEvent{ AgentEvent: &manager.AgentEvent{ ComputationId: s.ComputationId, - EventType: "vm-provision", + EventType: manager.VmProvision.String(), Timestamp: timestamppb.Now(), Originator: "manager", - Status: "error", + Status: manager.Error.String(), }, }, } diff --git a/manager/vm/logging_test.go b/manager/vm/logging_test.go index d7d2a579e..b16982a72 100644 --- a/manager/vm/logging_test.go +++ b/manager/vm/logging_test.go @@ -121,8 +121,8 @@ func TestStderrWrite(t *testing.T) { agentEvent := msg.GetAgentEvent() assert.NotNil(t, agentEvent) assert.Equal(t, "test-computation", agentEvent.ComputationId) - assert.Equal(t, "vm-provision", agentEvent.EventType) - assert.Equal(t, "error", agentEvent.Status) + assert.Equal(t, manager.VmProvision.String(), agentEvent.EventType) + assert.Equal(t, manager.Error.String(), agentEvent.Status) assert.NotNil(t, agentEvent.Timestamp) } case <-time.After(time.Second): diff --git a/pkg/manager/manager_states.go b/pkg/manager/manager_states.go new file mode 100644 index 000000000..d972b02cc --- /dev/null +++ b/pkg/manager/manager_states.go @@ -0,0 +1,21 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package manager + +//go:generate stringer -type=ManagerState +type ManagerState uint8 + +const ( + VmProvision ManagerState = iota + StopComputationRun + VmRunning +) + +//go:generate stringer -type=ManagerStatus +type ManagerStatus uint8 + +const ( + Starting ManagerStatus = iota + Stopped + Error +) diff --git a/pkg/manager/managerstate_string.go b/pkg/manager/managerstate_string.go new file mode 100644 index 000000000..1f1347034 --- /dev/null +++ b/pkg/manager/managerstate_string.go @@ -0,0 +1,25 @@ +// Code generated by "stringer -type=ManagerState"; DO NOT EDIT. + +package manager + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[VmProvision-0] + _ = x[StopComputationRun-1] + _ = x[VmRunning-2] +} + +const _ManagerState_name = "VmProvisionStopComputationRunVmRunning" + +var _ManagerState_index = [...]uint8{0, 11, 29, 38} + +func (i ManagerState) String() string { + if i >= ManagerState(len(_ManagerState_index)-1) { + return "ManagerState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _ManagerState_name[_ManagerState_index[i]:_ManagerState_index[i+1]] +} diff --git a/pkg/manager/managerstatus_string.go b/pkg/manager/managerstatus_string.go new file mode 100644 index 000000000..0289df2da --- /dev/null +++ b/pkg/manager/managerstatus_string.go @@ -0,0 +1,25 @@ +// Code generated by "stringer -type=ManagerStatus"; DO NOT EDIT. + +package manager + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Starting-0] + _ = x[Stopped-1] + _ = x[Error-2] +} + +const _ManagerStatus_name = "StartingStoppedError" + +var _ManagerStatus_index = [...]uint8{0, 8, 15, 20} + +func (i ManagerStatus) String() string { + if i >= ManagerStatus(len(_ManagerStatus_index)-1) { + return "ManagerStatus(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _ManagerStatus_name[_ManagerStatus_index[i]:_ManagerStatus_index[i+1]] +} From 2f4ca414cb209e2653e8544665fbf760a0469fa3 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Tue, 17 Sep 2024 18:57:42 +0300 Subject: [PATCH 21/83] NOISSUE - Stop computation gracefully (#241) * stop gracefully Signed-off-by: Sammy Oina * use constant Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- internal/cmd.go | 11 ----------- manager/qemu/vm.go | 34 ++++++++++++++++++++++++++++------ 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/internal/cmd.go b/internal/cmd.go index 25275abf9..1b11fdf2b 100644 --- a/internal/cmd.go +++ b/internal/cmd.go @@ -62,14 +62,3 @@ func RunCmdOutput(command string, args ...string) (string, error) { return string(output), nil } - -// RunCmdStart starts the specified command and returns the *exec.Cmd for the running process. -func RunCmdStart(command string, args ...string) (*exec.Cmd, error) { - cmd := exec.Command(command, args...) - - if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("error starting command '%s': %s", cmd.String(), err) - } - - return cmd, nil -} diff --git a/manager/qemu/vm.go b/manager/qemu/vm.go index c330ee257..fc7bf23c4 100644 --- a/manager/qemu/vm.go +++ b/manager/qemu/vm.go @@ -17,11 +17,12 @@ import ( ) const ( - firmwareVars = "OVMF_VARS" - KernelFile = "bzImage" - rootfsFile = "rootfs.cpio" - tmpDir = "/tmp" - interval = 5 * time.Second + firmwareVars = "OVMF_VARS" + KernelFile = "bzImage" + rootfsFile = "rootfs.cpio" + tmpDir = "/tmp" + interval = 5 * time.Second + shutdownTimeout = 30 * time.Second ) type qemuVM struct { @@ -78,7 +79,28 @@ func (v *qemuVM) Start() (err error) { } func (v *qemuVM) Stop() error { - return v.cmd.Process.Kill() + err := v.cmd.Process.Signal(syscall.SIGTERM) + if err != nil { + return fmt.Errorf("failed to send SIGTERM: %v", err) + } + + done := make(chan error, 1) + go func() { + _, err := v.cmd.Process.Wait() + done <- err + }() + + select { + case err := <-done: + return err + case <-time.After(shutdownTimeout): + err := v.cmd.Process.Kill() + if err != nil { + return fmt.Errorf("failed to kill process: %v", err) + } + } + + return nil } func (v *qemuVM) SetProcess(pid int) error { From 1546fbc4c20435291539651379794d7c72736ac0 Mon Sep 17 00:00:00 2001 From: Washington Kigani Kamadi Date: Tue, 17 Sep 2024 19:01:30 +0300 Subject: [PATCH 22/83] NOISSUE - Use Constants for Run Events (#243) * enhance timeline Signed-off-by: WashingtonKK * fix: remove redundant event Signed-off-by: WashingtonKK * use constant Signed-off-by: WashingtonKK * lint Signed-off-by: WashingtonKK * use typed constant for status Signed-off-by: WashingtonKK * export agent status and state Signed-off-by: WashingtonKK * ehance event states Signed-off-by: WashingtonKK * use manager states and status Signed-off-by: WashingtonKK * move algo-run to agent package Signed-off-by: WashingtonKK * replace manager variable with constant Signed-off-by: WashingtonKK * add manager states Signed-off-by: WashingtonKK * remove typo Signed-off-by: WashingtonKK --------- Signed-off-by: WashingtonKK --- agent/service.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/agent/service.go b/agent/service.go index 00082acc0..4f6c6b517 100644 --- a/agent/service.go +++ b/agent/service.go @@ -271,7 +271,7 @@ func (as *agentService) Attestation(ctx context.Context, reportData [ReportDataS } func (as *agentService) runComputation() { - as.publishEvent("starting", json.RawMessage{})() + as.publishEvent(InProgress.String(), json.RawMessage{})() as.sm.logger.Debug("computation run started") defer func() { if as.runError != nil { @@ -284,7 +284,7 @@ func (as *agentService) runComputation() { if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil { as.runError = fmt.Errorf("error creating results directory: %s", err.Error()) as.sm.logger.Warn(as.runError.Error()) - as.publishEvent("failed", json.RawMessage{})() + as.publishEvent(Failed.String(), json.RawMessage{})() return } @@ -297,11 +297,11 @@ func (as *agentService) runComputation() { } }() - as.publishEvent("in-progress", json.RawMessage{})() + as.publishEvent(InProgress.String(), json.RawMessage{})() if err := as.algorithm.Run(); err != nil { as.runError = err as.sm.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error())) - as.publishEvent("failed", json.RawMessage{})() + as.publishEvent(Failed.String(), json.RawMessage{})() return } @@ -309,11 +309,11 @@ func (as *agentService) runComputation() { if err != nil { as.runError = err as.sm.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error())) - as.publishEvent("failed", json.RawMessage{})() + as.publishEvent(Failed.String(), json.RawMessage{})() return } - as.publishEvent("complete", json.RawMessage{})() + as.publishEvent(Completed.String(), json.RawMessage{})() as.result = results } From 4c09b4bea598ce36c503c5ab8aee2d8a81c5112b Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Thu, 19 Sep 2024 11:18:02 +0300 Subject: [PATCH 23/83] NOISSUE - Format log messages from agent (#244) * downgrade mod Signed-off-by: Sammy Oina * add fields to logging Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- go.mod | 2 +- manager/agentEventsLogs.go | 22 +++++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 3aa5a2523..a531c5fad 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/ultravioletrs/cocos -go 1.23.0 +go 1.22.5 require ( github.com/absmach/magistrala v0.14.1-0.20240709113739-04c359462746 diff --git a/manager/agentEventsLogs.go b/manager/agentEventsLogs.go index d245765b0..9e0cc3bf4 100644 --- a/manager/agentEventsLogs.go +++ b/manager/agentEventsLogs.go @@ -4,6 +4,7 @@ package manager import ( "fmt" + "log/slog" "net" "regexp" "strconv" @@ -65,7 +66,26 @@ func (ms *managerService) handleConnections(conn net.Conn) { } ms.eventsChan <- &message - ms.logger.WithGroup("agent-events-logs").Info(message.String()) + args := []any{} + + switch message.Message.(type) { + case *manager.ClientStreamMessage_AgentEvent: + args = append(args, slog.Group("agent-event", + slog.String("event-type", message.GetAgentEvent().GetEventType()), + slog.String("computation-id", message.GetAgentEvent().GetComputationId()), + slog.String("status", message.GetAgentEvent().GetStatus()), + slog.String("originator", message.GetAgentEvent().GetOriginator()), + slog.String("timestamp", message.GetAgentEvent().GetTimestamp().String()), + slog.String("details", string(message.GetAgentEvent().GetDetails())))) + case *manager.ClientStreamMessage_AgentLog: + args = append(args, slog.Group("agent-log", + slog.String("computation-id", message.GetAgentLog().GetComputationId()), + slog.String("level", message.GetAgentLog().GetLevel()), + slog.String("timestamp", message.GetAgentLog().GetTimestamp().String()), + slog.String("message", message.GetAgentLog().GetMessage()))) + } + + ms.logger.Info("", args...) } } From e266e91033db390191079faa6ef9ac9c6c422e33 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Thu, 19 Sep 2024 22:32:38 +0300 Subject: [PATCH 24/83] COCOS-238 - Add measurement directly on backend info file (#245) * add measurement directly on backendinfo Signed-off-by: Sammy Oina * add host data Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- cmd/cli/main.go | 4 ++++ go.mod | 3 +++ go.sum | 2 ++ manager/backend_info.go | 45 +++++++++++++++++++++++++++++++++++++++++ manager/qemu/config.go | 7 +++++-- test/manual/README.md | 4 ++-- 6 files changed, 61 insertions(+), 4 deletions(-) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 3a9d3bd48..c5724ded6 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -16,6 +16,7 @@ import ( "github.com/ultravioletrs/cocos/pkg/clients/grpc" "github.com/ultravioletrs/cocos/pkg/clients/grpc/agent" "github.com/ultravioletrs/cocos/pkg/sdk" + cmd "github.com/virtee/sev-snp-measure-go/sevsnpmeasure/cmd" ) const ( @@ -115,6 +116,9 @@ func main() { attestationCmd.AddCommand(cliSVC.NewGetAttestationCmd()) attestationCmd.AddCommand(cliSVC.NewValidateAttestationValidationCmd()) + // measure. + rootCmd.AddCommand(cmd.NewRootCmd()) + // Flags keysCmd.PersistentFlags().StringVarP( &cli.KeyType, diff --git a/go.mod b/go.mod index a531c5fad..d3644d10d 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 + github.com/virtee/sev-snp-measure-go v0.0.0-20240530153610-e6e8dc9b6877 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0 go.opentelemetry.io/otel/trace v1.28.0 golang.org/x/crypto v0.25.0 @@ -79,3 +80,5 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/virtee/sev-snp-measure-go => github.com/sammyoina/sev-snp-measure-go v0.0.0-20240918192515-70b6b9542aa5 diff --git a/go.sum b/go.sum index 196dc3160..0c187662c 100644 --- a/go.sum +++ b/go.sum @@ -100,6 +100,8 @@ github.com/prometheus/procfs v0.13.0/go.mod h1:cd4PFCR54QLnGKPaKGA6l+cfuNXtht43Z github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sammyoina/sev-snp-measure-go v0.0.0-20240918192515-70b6b9542aa5 h1:w5R0cZgvakKxBsIrzboOb0DcHdkzEQ4tcQ6wLEn/FWo= +github.com/sammyoina/sev-snp-measure-go v0.0.0-20240918192515-70b6b9542aa5/go.mod h1:dEkBe8JnxU5itNjZDEQINFd7f7l4DtjfqRuzPQcit4w= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= diff --git a/manager/backend_info.go b/manager/backend_info.go index 528d8817e..265329573 100644 --- a/manager/backend_info.go +++ b/manager/backend_info.go @@ -7,11 +7,21 @@ package manager import ( + "encoding/base64" + "encoding/json" "fmt" "os" "os/exec" + + "github.com/ultravioletrs/cocos/cli" + "github.com/ultravioletrs/cocos/manager/qemu" + "github.com/virtee/sev-snp-measure-go/cpuid" + "github.com/virtee/sev-snp-measure-go/guest" + "github.com/virtee/sev-snp-measure-go/vmmtypes" ) +const defGuestFeatures = 0x1 + func (ms *managerService) FetchBackendInfo() ([]byte, error) { cmd := exec.Command("sudo", fmt.Sprintf("%s/backend_info", ms.backendMeasurementBinaryPath), "--policy", "1966081") @@ -25,5 +35,40 @@ func (ms *managerService) FetchBackendInfo() ([]byte, error) { return nil, err } + var backendInfo cli.AttestationConfiguration + + if err = json.Unmarshal(f, &backendInfo); err != nil { + return nil, err + } + + var measurement []byte + if ms.qemuCfg.EnableSEV { + measurement, err = guest.CalcLaunchDigest(guest.SEV, ms.qemuCfg.SMPCount, uint64(cpuid.CpuSigs[ms.qemuCfg.CPU]), ms.qemuCfg.OVMFCodeConfig.File, ms.qemuCfg.KernelFile, ms.qemuCfg.RootFsFile, qemu.KernelCommandLine, defGuestFeatures, "", vmmtypes.QEMU, false, "", 0) + if err != nil { + return nil, err + } + } else if ms.qemuCfg.EnableSEVSNP { + measurement, err = guest.CalcLaunchDigest(guest.SEV_SNP, ms.qemuCfg.SMPCount, uint64(cpuid.CpuSigs[ms.qemuCfg.CPU]), ms.qemuCfg.OVMFCodeConfig.File, ms.qemuCfg.KernelFile, ms.qemuCfg.RootFsFile, qemu.KernelCommandLine, defGuestFeatures, "", vmmtypes.QEMU, false, "", 0) + if err != nil { + return nil, err + } + } + if measurement == nil { + backendInfo.SNPPolicy.Measurement = measurement + } + + if ms.qemuCfg.HostData != "" { + hostData, err := base64.StdEncoding.DecodeString(ms.qemuCfg.HostData) + if err != nil { + return nil, err + } + backendInfo.SNPPolicy.HostData = hostData + } + + f, err = json.Marshal(backendInfo) + if err != nil { + return nil, err + } + return f, nil } diff --git a/manager/qemu/config.go b/manager/qemu/config.go index 9dca5d5b7..2b884e15c 100644 --- a/manager/qemu/config.go +++ b/manager/qemu/config.go @@ -7,7 +7,10 @@ import ( "strconv" ) -const BaseGuestCID = 3 +const ( + BaseGuestCID = 3 + KernelCommandLine = "quiet console=null rootfstype=ramfs" +) type MemoryConfig struct { Size string `env:"MEMORY_SIZE" envDefault:"2048M"` @@ -175,7 +178,7 @@ func (config Config) ConstructQemuArgs() []string { } args = append(args, "-kernel", config.DiskImgConfig.KernelFile) - args = append(args, "-append", strconv.Quote("quiet console=null rootfstype=ramfs")) + args = append(args, "-append", strconv.Quote(KernelCommandLine)) args = append(args, "-initrd", config.DiskImgConfig.RootFsFile) // SEV diff --git a/test/manual/README.md b/test/manual/README.md index eeff276fa..1e01e05cc 100644 --- a/test/manual/README.md +++ b/test/manual/README.md @@ -19,7 +19,7 @@ All assets/datasets the algorithm uses are stored in the `datasets` directory. T Agent is started automatically in the VM when launched but requires configuration and manifest to be passed by manager. Alternatively you can pass configuration using this [simplified script](./agent-config/main.go) -For attested TLS, you will have to calculate the VM's measurement, which can be done using a tool [sev-snp-measure](https://pypi.org/project/sev-snp-measure/). +For attested TLS, you will have to calculate the VM's measurement, which can be done using cli. This information is also contained in the backend info file. ```bash # Define the path to the OVMF, KERNEL, INITRD and CMD Kernel line arguments. @@ -29,7 +29,7 @@ KERNEL="/home/cocosai/bzImage" LINE="earlyprintk=serial console=ttyS0" # Call sev-snp-measure -sev-snp-measure --mode snp --vcpus 4 --vcpu-type EPYC-v4 --ovmf $OVMF_CODE --kernel $KERNEL --initrd $INITRD --append "$LINE" --output-format base64 +./build/cocos-cli sevsnpmeasure --mode snp --vcpus 4 --vcpu-type EPYC-v4 --ovmf $OVMF_CODE --kernel $KERNEL --initrd $INITRD --append "$LINE" ``` To speed up the verification process of attested TLS, download the ARK and ASK certificates using the CLI tool. The CLI tool will download the certificates under your home directory in the `.cocos` directory. From 5ff8b96311e7a4902ed0b0a180d3dfca31b62c76 Mon Sep 17 00:00:00 2001 From: Washington Kigani Kamadi Date: Fri, 20 Sep 2024 11:00:48 +0300 Subject: [PATCH 25/83] add disconnected status (#246) Signed-off-by: WashingtonKK --- manager/agentEventsLogs.go | 4 ++-- pkg/manager/manager_states.go | 1 + pkg/manager/managerstatus_string.go | 5 +++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/manager/agentEventsLogs.go b/manager/agentEventsLogs.go index 9e0cc3bf4..aa38a31e2 100644 --- a/manager/agentEventsLogs.go +++ b/manager/agentEventsLogs.go @@ -119,9 +119,9 @@ func (ms *managerService) reportBrokenConnection(cmpID string) { ms.eventsChan <- &manager.ClientStreamMessage{ Message: &manager.ClientStreamMessage_AgentEvent{ AgentEvent: &manager.AgentEvent{ - EventType: "vm running", + EventType: manager.VmRunning.String(), ComputationId: cmpID, - Status: "disconnected", + Status: manager.Disconnected.String(), Timestamp: timestamppb.Now(), Originator: "manager", }, diff --git a/pkg/manager/manager_states.go b/pkg/manager/manager_states.go index d972b02cc..9f1f8ec1b 100644 --- a/pkg/manager/manager_states.go +++ b/pkg/manager/manager_states.go @@ -18,4 +18,5 @@ const ( Starting ManagerStatus = iota Stopped Error + Disconnected ) diff --git a/pkg/manager/managerstatus_string.go b/pkg/manager/managerstatus_string.go index 0289df2da..5ddb22067 100644 --- a/pkg/manager/managerstatus_string.go +++ b/pkg/manager/managerstatus_string.go @@ -11,11 +11,12 @@ func _() { _ = x[Starting-0] _ = x[Stopped-1] _ = x[Error-2] + _ = x[Disconnected-3] } -const _ManagerStatus_name = "StartingStoppedError" +const _ManagerStatus_name = "StartingStoppedErrorDisconnected" -var _ManagerStatus_index = [...]uint8{0, 8, 15, 20} +var _ManagerStatus_index = [...]uint8{0, 8, 15, 20, 32} func (i ManagerStatus) String() string { if i >= ManagerStatus(len(_ManagerStatus_index)-1) { From eab3a06705ba87abdaaab0c793fdc520a84b993c Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Fri, 20 Sep 2024 12:09:31 +0300 Subject: [PATCH 26/83] fix redundant logs (#247) Signed-off-by: Sammy Oina --- manager/agentEventsLogs.go | 10 +++++----- manager/vm/logging.go | 2 +- manager/vm/logging_test.go | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/manager/agentEventsLogs.go b/manager/agentEventsLogs.go index aa38a31e2..a9bd7a4a7 100644 --- a/manager/agentEventsLogs.go +++ b/manager/agentEventsLogs.go @@ -45,17 +45,17 @@ func (ms *managerService) RetrieveAgentEventsLogs() { } func (ms *managerService) handleConnections(conn net.Conn) { + cmpID, err := ms.computationIDFromAddress(conn.RemoteAddr().String()) + if err != nil { + ms.logger.Warn(err.Error()) + return + } defer conn.Close() for { b := make([]byte, messageSize) n, err := conn.Read(b) if err != nil { ms.logger.Warn(err.Error()) - cmpID, err := ms.computationIDFromAddress(conn.RemoteAddr().String()) - if err != nil { - ms.logger.Warn(err.Error()) - continue - } go ms.reportBrokenConnection(cmpID) return } diff --git a/manager/vm/logging.go b/manager/vm/logging.go index 5bd56f71e..1449a3308 100644 --- a/manager/vm/logging.go +++ b/manager/vm/logging.go @@ -118,7 +118,7 @@ func (s *Stderr) Write(p []byte) (n int, err error) { Message: &manager.ClientStreamMessage_AgentEvent{ AgentEvent: &manager.AgentEvent{ ComputationId: s.ComputationId, - EventType: manager.VmProvision.String(), + EventType: manager.VmRunning.String(), Timestamp: timestamppb.Now(), Originator: "manager", Status: manager.Error.String(), diff --git a/manager/vm/logging_test.go b/manager/vm/logging_test.go index b16982a72..04c25d6d3 100644 --- a/manager/vm/logging_test.go +++ b/manager/vm/logging_test.go @@ -121,7 +121,7 @@ func TestStderrWrite(t *testing.T) { agentEvent := msg.GetAgentEvent() assert.NotNil(t, agentEvent) assert.Equal(t, "test-computation", agentEvent.ComputationId) - assert.Equal(t, manager.VmProvision.String(), agentEvent.EventType) + assert.Equal(t, manager.VmRunning.String(), agentEvent.EventType) assert.Equal(t, manager.Error.String(), agentEvent.Status) assert.NotNil(t, agentEvent.Timestamp) } From a7caa5913729fb4100e6dba26c6021a25d0e54dc Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Fri, 20 Sep 2024 12:59:56 +0300 Subject: [PATCH 27/83] NOISSUE - Fix race condition in tests (#248) Signed-off-by: Sammy Oina --- agent/service.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/agent/service.go b/agent/service.go index 4f6c6b517..580175793 100644 --- a/agent/service.go +++ b/agent/service.go @@ -90,8 +90,6 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp quoteProvider: quoteProvider, } - go svc.sm.Start(ctx) - svc.sm.SendEvent(start) svc.sm.StateFunctions[Idle] = svc.publishEvent(IdleState.String(), json.RawMessage{}) svc.sm.StateFunctions[ReceivingManifest] = svc.publishEvent(InProgress.String(), json.RawMessage{}) svc.sm.StateFunctions[ReceivingAlgorithm] = svc.publishEvent(InProgress.String(), json.RawMessage{}) @@ -101,6 +99,9 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp svc.sm.StateFunctions[Running] = svc.runComputation svc.sm.StateFunctions[Failed] = svc.publishEvent(Failed.String(), json.RawMessage{}) + go svc.sm.Start(ctx) + svc.sm.SendEvent(start) + svc.computation = cmp svc.sm.SendEvent(manifestReceived) From df923f9b1fde74f43c46a4a2756ad9f03a7b82b5 Mon Sep 17 00:00:00 2001 From: Washington Kigani Kamadi Date: Fri, 20 Sep 2024 19:33:10 +0300 Subject: [PATCH 28/83] NOISSUE - Rename error to warning (#249) * rename error to warning Signed-off-by: WashingtonKK * update logging package Signed-off-by: WashingtonKK --------- Signed-off-by: WashingtonKK --- agent/algorithm/logging.go | 8 ++++---- agent/algorithm/logging_test.go | 2 +- agent/state.go | 2 +- agent/status_string.go | 6 +++--- manager/vm/logging.go | 2 +- manager/vm/logging_test.go | 2 +- pkg/manager/manager_states.go | 2 +- pkg/manager/managerstatus_string.go | 6 +++--- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/agent/algorithm/logging.go b/agent/algorithm/logging.go index e1289511e..4836f0e98 100644 --- a/agent/algorithm/logging.go +++ b/agent/algorithm/logging.go @@ -17,9 +17,9 @@ var ( ) const ( - bufSize = 1024 - algorithmRun = "AlgorithmRun" - errorStatus = "Error" + bufSize = 1024 + algorithmRun = "AlgorithmRun" + warningStatus = "Warning" ) type Stdout struct { @@ -70,7 +70,7 @@ func (s *Stderr) Write(p []byte) (n int, err error) { s.Logger.Error(string(buf[:n])) } - if err := s.EventSvc.SendEvent(algorithmRun, errorStatus, json.RawMessage{}); err != nil { + if err := s.EventSvc.SendEvent(algorithmRun, warningStatus, json.RawMessage{}); err != nil { return len(p), err } diff --git a/agent/algorithm/logging_test.go b/agent/algorithm/logging_test.go index 91427028b..1c5e846ba 100644 --- a/agent/algorithm/logging_test.go +++ b/agent/algorithm/logging_test.go @@ -73,7 +73,7 @@ func TestStderrWrite(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockEventService := mocks.NewService(t) - mockEventService.On("SendEvent", "AlgorithmRun", manager.Error.String(), mock.Anything).Return(nil) + mockEventService.On("SendEvent", "AlgorithmRun", manager.Warning.String(), mock.Anything).Return(nil) stderr := &Stderr{Logger: mglog.NewMock(), EventSvc: mockEventService} n, err := stderr.Write([]byte(tt.input)) diff --git a/agent/state.go b/agent/state.go index cafee0b89..a6b2192d8 100644 --- a/agent/state.go +++ b/agent/state.go @@ -33,7 +33,7 @@ const ( Ready Completed Terminated - Error + Warning ) type event uint8 diff --git a/agent/status_string.go b/agent/status_string.go index 0b19bb195..a2fdf1bb3 100644 --- a/agent/status_string.go +++ b/agent/status_string.go @@ -13,12 +13,12 @@ func _() { _ = x[Ready-2] _ = x[Completed-3] _ = x[Terminated-4] - _ = x[Error-5] + _ = x[Warning-5] } -const _Status_name = "IdleStateInProgressReadyCompletedTerminatedError" +const _Status_name = "IdleStateInProgressReadyCompletedTerminatedWarning" -var _Status_index = [...]uint8{0, 9, 19, 24, 33, 43, 48} +var _Status_index = [...]uint8{0, 9, 19, 24, 33, 43, 50} func (i Status) String() string { if i >= Status(len(_Status_index)-1) { diff --git a/manager/vm/logging.go b/manager/vm/logging.go index 1449a3308..8fb4e0763 100644 --- a/manager/vm/logging.go +++ b/manager/vm/logging.go @@ -121,7 +121,7 @@ func (s *Stderr) Write(p []byte) (n int, err error) { EventType: manager.VmRunning.String(), Timestamp: timestamppb.Now(), Originator: "manager", - Status: manager.Error.String(), + Status: manager.Warning.String(), }, }, } diff --git a/manager/vm/logging_test.go b/manager/vm/logging_test.go index 04c25d6d3..872221119 100644 --- a/manager/vm/logging_test.go +++ b/manager/vm/logging_test.go @@ -122,7 +122,7 @@ func TestStderrWrite(t *testing.T) { assert.NotNil(t, agentEvent) assert.Equal(t, "test-computation", agentEvent.ComputationId) assert.Equal(t, manager.VmRunning.String(), agentEvent.EventType) - assert.Equal(t, manager.Error.String(), agentEvent.Status) + assert.Equal(t, manager.Warning.String(), agentEvent.Status) assert.NotNil(t, agentEvent.Timestamp) } case <-time.After(time.Second): diff --git a/pkg/manager/manager_states.go b/pkg/manager/manager_states.go index 9f1f8ec1b..b7c74758e 100644 --- a/pkg/manager/manager_states.go +++ b/pkg/manager/manager_states.go @@ -17,6 +17,6 @@ type ManagerStatus uint8 const ( Starting ManagerStatus = iota Stopped - Error + Warning Disconnected ) diff --git a/pkg/manager/managerstatus_string.go b/pkg/manager/managerstatus_string.go index 5ddb22067..9572c1aa4 100644 --- a/pkg/manager/managerstatus_string.go +++ b/pkg/manager/managerstatus_string.go @@ -10,13 +10,13 @@ func _() { var x [1]struct{} _ = x[Starting-0] _ = x[Stopped-1] - _ = x[Error-2] + _ = x[Warning-2] _ = x[Disconnected-3] } -const _ManagerStatus_name = "StartingStoppedErrorDisconnected" +const _ManagerStatus_name = "StartingStoppedWarningDisconnected" -var _ManagerStatus_index = [...]uint8{0, 8, 15, 20, 32} +var _ManagerStatus_index = [...]uint8{0, 8, 15, 22, 34} func (i ManagerStatus) String() string { if i >= ManagerStatus(len(_ManagerStatus_index)-1) { From 5d5ae35e2b5c86dcbbad616a8ed66c9a436f3b46 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Mon, 23 Sep 2024 19:38:02 +0300 Subject: [PATCH 29/83] NOISSUE - Reduce message loss via vsock with acks (#252) * state check within func Signed-off-by: Sammy Oina * debug logs sending Signed-off-by: Sammy Oina * debug message sending Signed-off-by: Sammy Oina * ack messages Signed-off-by: Sammy Oina * handle proto better Signed-off-by: Sammy Oina * improve concurrency Signed-off-by: Sammy Oina * improve manager handling Signed-off-by: Sammy Oina * remove debug lines Signed-off-by: Sammy Oina * sync next id Signed-off-by: Sammy Oina * reduce locks Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- agent/events/events.go | 6 +- agent/service.go | 6 +- agent/state.go | 24 ++-- cmd/agent/main.go | 9 +- internal/vsock/client.go | 228 +++++++++++++++++++++++++++++++ manager/agentEventsLogs.go | 22 +-- test/manual/agent-config/main.go | 50 ++++--- 7 files changed, 297 insertions(+), 48 deletions(-) create mode 100644 internal/vsock/client.go diff --git a/agent/events/events.go b/agent/events/events.go index b62bbfd1d..74d131e84 100644 --- a/agent/events/events.go +++ b/agent/events/events.go @@ -4,10 +4,10 @@ package events import ( "encoding/json" + "io" "sync" "time" - "github.com/mdlayher/vsock" "github.com/ultravioletrs/cocos/pkg/manager" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" @@ -18,7 +18,7 @@ const retryInterval = 5 * time.Second type service struct { service string computationID string - conn *vsock.Conn + conn io.Writer cachedMessages [][]byte mutex sync.Mutex stopRetry chan struct{} @@ -39,7 +39,7 @@ type Service interface { Close() } -func New(svc, computationID string, conn *vsock.Conn) (Service, error) { +func New(svc, computationID string, conn io.Writer) (Service, error) { s := &service{ service: svc, computationID: computationID, diff --git a/agent/service.go b/agent/service.go index 580175793..540d475e7 100644 --- a/agent/service.go +++ b/agent/service.go @@ -236,7 +236,7 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error { } if len(as.computation.Datasets) == 0 { - as.sm.SendEvent(dataReceived) + defer as.sm.SendEvent(dataReceived) } return nil @@ -256,7 +256,7 @@ func (as *agentService) Result(ctx context.Context) ([]byte, error) { as.computation.ResultConsumers = slices.Delete(as.computation.ResultConsumers, index, index+1) if len(as.computation.ResultConsumers) == 0 && as.sm.GetState() == ConsumingResults { - as.sm.SendEvent(resultsConsumed) + defer as.sm.SendEvent(resultsConsumed) } return as.result, as.runError @@ -320,8 +320,8 @@ func (as *agentService) runComputation() { } func (as *agentService) publishEvent(status string, details json.RawMessage) func() { - st := as.sm.GetState().String() return func() { + st := as.sm.GetState().String() if err := as.eventSvc.SendEvent(st, status, details); err != nil { as.sm.logger.Warn(err.Error()) } diff --git a/agent/state.go b/agent/state.go index a6b2192d8..f109735c3 100644 --- a/agent/state.go +++ b/agent/state.go @@ -104,22 +104,28 @@ func (sm *StateMachine) Start(ctx context.Context) { for { select { case event := <-sm.EventChan: + currentState := sm.GetState() + var nextState State + var stateFunc func() + var valid bool + sm.mu.Lock() - nextState, valid := sm.Transitions[sm.State][event] + nextState, valid = sm.Transitions[sm.State][event] if valid { sm.State = nextState - sm.logger.Debug(fmt.Sprintf("Transition: %v -> %v\n", sm.State, nextState)) - } else { - sm.logger.Error(fmt.Sprintf("Invalid transition: %v -> ???\n", sm.State)) + stateFunc = sm.StateFunctions[nextState] } sm.mu.Unlock() - sm.mu.Lock() - stateFunc, exists := sm.StateFunctions[sm.State] - sm.mu.Unlock() - if exists { - go stateFunc() + if valid { + sm.logger.Debug(fmt.Sprintf("Transition: %v -> %v\n", currentState, nextState)) + if stateFunc != nil { + go stateFunc() + } + } else { + sm.logger.Error(fmt.Sprintf("Invalid transition: %v -> ???\n", sm.State)) } + case <-ctx.Done(): return } diff --git a/cmd/agent/main.go b/cmd/agent/main.go index c03cdb3af..a6d803372 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -25,6 +25,7 @@ import ( agentlogger "github.com/ultravioletrs/cocos/internal/logger" "github.com/ultravioletrs/cocos/internal/server" grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc" + ackvsock "github.com/ultravioletrs/cocos/internal/vsock" "github.com/ultravioletrs/cocos/manager" "github.com/ultravioletrs/cocos/manager/qemu" "golang.org/x/sync/errgroup" @@ -53,6 +54,8 @@ func main() { } defer conn.Close() + ackConn := ackvsock.NewAckWriter(conn) + var exitCode int defer mglog.ExitWithError(&exitCode) @@ -63,10 +66,10 @@ func main() { return } - handler := agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}, cfg.ID) + handler := agentlogger.NewProtoHandler(ackConn, &slog.HandlerOptions{Level: level}, cfg.ID) logger := slog.New(handler) - eventSvc, err := events.New(svcName, cfg.ID, conn) + eventSvc, err := events.New(svcName, cfg.ID, ackConn) if err != nil { logger.Error(fmt.Sprintf("failed to create events service %s", err.Error())) exitCode = 1 @@ -116,8 +119,6 @@ func main() { if err != nil { log.Fatal("failed to reconnect: ", err) } - handler = agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}, cfg.ID) - logger = slog.New(handler) } time.Sleep(retryInterval) } diff --git a/internal/vsock/client.go b/internal/vsock/client.go new file mode 100644 index 000000000..446d625e4 --- /dev/null +++ b/internal/vsock/client.go @@ -0,0 +1,228 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package vsock + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "net" + "sync" + "time" + + "google.golang.org/protobuf/proto" +) + +const ( + maxRetries = 3 + retryDelay = time.Second + maxMessageSize = 1 << 20 // 1 MB + ackTimeout = 5 * time.Second + maxConcurrent = 100 // Maximum number of concurrent messages +) + +type Message struct { + ID uint32 + Content []byte +} + +type AckWriter struct { + conn net.Conn + pendingMessages chan *Message + ackChannels map[uint32]chan bool + ackMu sync.RWMutex + nextID uint32 + done chan struct{} + wg sync.WaitGroup +} + +func NewAckWriter(conn net.Conn) *AckWriter { + aw := &AckWriter{ + conn: conn, + pendingMessages: make(chan *Message, maxConcurrent), + ackChannels: make(map[uint32]chan bool), + nextID: 1, + done: make(chan struct{}), + } + aw.wg.Add(2) + go aw.sendMessages() + go aw.handleAcknowledgments() + return aw +} + +func (aw *AckWriter) WriteProto(msg proto.Message) (int, error) { + data, err := proto.Marshal(msg) + if err != nil { + return 0, fmt.Errorf("error marshaling protobuf message: %v", err) + } + return aw.Write(data) +} + +func (aw *AckWriter) Write(p []byte) (int, error) { + if len(p) > maxMessageSize { + return 0, fmt.Errorf("message size exceeds maximum allowed size of %d bytes", maxMessageSize) + } + + aw.ackMu.Lock() + messageID := aw.nextID + aw.nextID++ + + ackCh := make(chan bool, 1) + aw.ackChannels[messageID] = ackCh + aw.ackMu.Unlock() + + message := &Message{ID: messageID, Content: p} + + select { + case aw.pendingMessages <- message: + // Message queued successfully + case <-aw.done: + return 0, fmt.Errorf("writer is closed") + } + + select { + case <-ackCh: + return len(p), nil + case <-time.After(ackTimeout): + return 0, fmt.Errorf("timeout waiting for acknowledgment") + case <-aw.done: + return 0, fmt.Errorf("writer closed while waiting for acknowledgment") + } +} + +func (aw *AckWriter) sendMessages() { + defer aw.wg.Done() + for { + select { + case <-aw.done: + return + case msg := <-aw.pendingMessages: + for i := 0; i < maxRetries; i++ { + if err := aw.writeMessage(msg.ID, msg.Content); err != nil { + log.Printf("Error writing message %d (attempt %d): %v", msg.ID, i+1, err) + time.Sleep(retryDelay) + continue + } + break + } + } + } +} + +func (aw *AckWriter) writeMessage(messageID uint32, p []byte) error { + // Write message ID + if err := binary.Write(aw.conn, binary.LittleEndian, messageID); err != nil { + return err + } + + // Write message length + messageLen := uint32(len(p)) + if err := binary.Write(aw.conn, binary.LittleEndian, messageLen); err != nil { + return err + } + + // Write message content + if _, err := aw.conn.Write(p); err != nil { + return err + } + + return nil +} + +func (aw *AckWriter) handleAcknowledgments() { + defer aw.wg.Done() + for { + select { + case <-aw.done: + return + default: + var ackID uint32 + err := binary.Read(aw.conn, binary.LittleEndian, &ackID) + if err != nil { + if err == io.EOF { + log.Println("Connection closed, stopping acknowledgment handler") + return + } + log.Printf("Error reading ACK: %v", err) + time.Sleep(retryDelay) + continue + } + + aw.ackMu.RLock() + ackCh, ok := aw.ackChannels[ackID] + aw.ackMu.RUnlock() + + if ok { + select { + case ackCh <- true: + default: + // Channel is already closed or full + } + aw.ackMu.Lock() + delete(aw.ackChannels, ackID) + aw.ackMu.Unlock() + } else { + log.Printf("Received ACK for unknown message ID: %d", ackID) + } + } + } +} + +func (aw *AckWriter) Close() error { + close(aw.done) + aw.wg.Wait() + return aw.conn.Close() +} + +type AckReader struct { + conn net.Conn +} + +func NewAckReader(conn net.Conn) *AckReader { + return &AckReader{ + conn: conn, + } +} + +func (ar *AckReader) ReadProto(msg proto.Message) error { + data, err := ar.Read() + if err != nil { + return err + } + + return proto.Unmarshal(data, msg) +} + +func (ar *AckReader) Read() ([]byte, error) { + var messageID uint32 + if err := binary.Read(ar.conn, binary.LittleEndian, &messageID); err != nil { + return nil, fmt.Errorf("error reading message ID: %v", err) + } + + var messageLen uint32 + if err := binary.Read(ar.conn, binary.LittleEndian, &messageLen); err != nil { + return nil, fmt.Errorf("error reading message length: %v", err) + } + + if messageLen > maxMessageSize { + return nil, fmt.Errorf("message size exceeds maximum allowed size of %d bytes", maxMessageSize) + } + + data := make([]byte, messageLen) + _, err := io.ReadFull(ar.conn, data) + if err != nil { + return nil, fmt.Errorf("error reading message content: %v", err) + } + + if err := ar.sendAck(messageID); err != nil { + return nil, fmt.Errorf("error sending ACK: %v", err) + } + + return data, nil +} + +func (ar *AckReader) sendAck(messageID uint32) error { + return binary.Write(ar.conn, binary.LittleEndian, messageID) +} diff --git a/manager/agentEventsLogs.go b/manager/agentEventsLogs.go index a9bd7a4a7..2178b7e5e 100644 --- a/manager/agentEventsLogs.go +++ b/manager/agentEventsLogs.go @@ -10,6 +10,7 @@ import ( "strconv" "github.com/mdlayher/vsock" + internalvsock "github.com/ultravioletrs/cocos/internal/vsock" "github.com/ultravioletrs/cocos/pkg/manager" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" @@ -40,30 +41,35 @@ func (ms *managerService) RetrieveAgentEventsLogs() { continue } - go ms.handleConnections(conn) + go ms.handleConnection(conn) } } -func (ms *managerService) handleConnections(conn net.Conn) { +func (ms *managerService) handleConnection(conn net.Conn) { + defer conn.Close() + cmpID, err := ms.computationIDFromAddress(conn.RemoteAddr().String()) if err != nil { ms.logger.Warn(err.Error()) return } - defer conn.Close() + + ackReader := internalvsock.NewAckReader(conn) + for { - b := make([]byte, messageSize) - n, err := conn.Read(b) + var message manager.ClientStreamMessage + data, err := ackReader.Read() if err != nil { - ms.logger.Warn(err.Error()) go ms.reportBrokenConnection(cmpID) + ms.logger.Warn(err.Error()) return } - var message manager.ClientStreamMessage - if err := proto.Unmarshal(b[:n], &message); err != nil { + + if err := proto.Unmarshal(data, &message); err != nil { ms.logger.Warn(err.Error()) continue } + ms.eventsChan <- &message args := []any{} diff --git a/test/manual/agent-config/main.go b/test/manual/agent-config/main.go index d4c30f165..95a466da7 100644 --- a/test/manual/agent-config/main.go +++ b/test/manual/agent-config/main.go @@ -8,7 +8,6 @@ package main import ( "encoding/json" "encoding/pem" - "fmt" "log" "net" "os" @@ -17,10 +16,15 @@ import ( "github.com/mdlayher/vsock" "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/internal" + internalvsock "github.com/ultravioletrs/cocos/internal/vsock" "github.com/ultravioletrs/cocos/manager" "github.com/ultravioletrs/cocos/manager/qemu" pkgmanager "github.com/ultravioletrs/cocos/pkg/manager" - "google.golang.org/protobuf/proto" +) + +const ( + managerVsockPort = manager.ManagerVsockPort + vsockConfigPort = qemu.VsockConfigPort ) func main() { @@ -50,10 +54,6 @@ func main() { log.Fatalf("failed to calculate checksum: %s", err) } - l, err := vsock.Listen(manager.ManagerVsockPort, nil) - if err != nil { - log.Fatal(err) - } ac := agent.Computation{ ID: "123", Datasets: agent.Datasets{agent.Dataset{Hash: [32]byte(dataHash), UserKey: pubPem.Bytes}}, @@ -65,21 +65,30 @@ func main() { AttestedTls: attestedTLS, }, } - if err := SendAgentConfig(3, ac); err != nil { + if err := sendAgentConfig(3, ac); err != nil { log.Fatal(err) } + listener, err := vsock.Listen(managerVsockPort, nil) + if err != nil { + log.Fatalf("failed to listen on vsock: %s", err) + } + defer listener.Close() + + log.Printf("Listening on vsock port %d", managerVsockPort) + for { - conn, err := l.Accept() + conn, err := listener.Accept() if err != nil { - log.Println(err) + log.Printf("failed to accept connection: %s", err) continue } - go handleConnections(conn) + + go handleConnection(conn) } } -func SendAgentConfig(cid uint32, ac agent.Computation) error { +func sendAgentConfig(cid uint32, ac agent.Computation) error { conn, err := vsock.Dial(cid, qemu.VsockConfigPort, nil) if err != nil { return err @@ -100,20 +109,19 @@ func SendAgentConfig(cid uint32, ac agent.Computation) error { return nil } -func handleConnections(conn net.Conn) { +func handleConnection(conn net.Conn) { defer conn.Close() + + ackReader := internalvsock.NewAckReader(conn) + for { - b := make([]byte, 1024) - n, err := conn.Read(b) - if err != nil { - log.Println(err) - return - } var message pkgmanager.ClientStreamMessage - if err := proto.Unmarshal(b[:n], &message); err != nil { - log.Println(err) + err := ackReader.ReadProto(&message) + if err != nil { + log.Printf("Error reading message: %v", err) return } - fmt.Println(message.String()) + + log.Printf("Received message: %s", message.String()) } } From af3817d3b73164dec0156979be066a219d3c2e8a Mon Sep 17 00:00:00 2001 From: Smith Jilks <41241359+smithjilks@users.noreply.github.com> Date: Tue, 24 Sep 2024 15:44:27 +0300 Subject: [PATCH 30/83] COCOS-242-Agent tests fail ocassionally due to a missing mock (#251) * Clean up files after algo run Signed-off-by: Jilks Smith * Add test cleanup Signed-off-by: Jilks Smith --------- Signed-off-by: Jilks Smith --- agent/service_test.go | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/agent/service_test.go b/agent/service_test.go index 47f1b85fd..545dae480 100644 --- a/agent/service_test.go +++ b/agent/service_test.go @@ -126,6 +126,11 @@ func TestAlgo(t *testing.T) { err = svc.Algo(ctx, tc.algo) assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err) + t.Cleanup(func() { + err = os.RemoveAll("venv") + err = os.RemoveAll("algo") + err = os.RemoveAll("datasets") + }) }) } } @@ -219,9 +224,13 @@ func TestData(t *testing.T) { time.Sleep(300 * time.Millisecond) } err = svc.Data(ctx, tc.data) - _ = os.RemoveAll("datasets") - _ = os.RemoveAll("results") assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err) + t.Cleanup(func() { + _ = os.RemoveAll("datasets") + _ = os.RemoveAll("results") + err = os.RemoveAll("venv") + err = os.RemoveAll("algo") + }) }) } } @@ -302,9 +311,10 @@ func TestResult(t *testing.T) { go svc.sm.Start(ctx) tc.setup(svc) _, err := svc.Result(ctx) - _ = os.RemoveAll("datasets") - _ = os.RemoveAll("results") - + t.Cleanup(func() { + _ = os.RemoveAll("datasets") + _ = os.RemoveAll("results") + }) assert.ErrorIs(t, err, tc.err, "expected %v, got %v", tc.err, err) }) } From 8b37b3575043e2ca70408a0148dd5b6456984ffe Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:27:28 +0300 Subject: [PATCH 31/83] COCOS-154 - Fix HAL release pipeline (#254) * increase release pipeline Signed-off-by: Sammy Oina * update go and free some space Signed-off-by: Sammy Oina * optimize Signed-off-by: Sammy Oina * fix cache Signed-off-by: Sammy Oina * free up space Signed-off-by: Sammy Oina * modify Signed-off-by: Sammy Oina * remove restrictions Signed-off-by: Sammy Oina * fifty gigs Signed-off-by: Sammy Oina * fourty gigs Signed-off-by: Sammy Oina * old mbs Signed-off-by: Sammy Oina * remove outdated actions Signed-off-by: Sammy Oina * rename Signed-off-by: Sammy Oina * weekly update Signed-off-by: SammyOina --------- Signed-off-by: Sammy Oina Signed-off-by: SammyOina --- .github/dependabot.yaml | 22 ++++++++ .github/workflows/hal.yml | 111 +++++++++++++++++++------------------- 2 files changed, 77 insertions(+), 56 deletions(-) create mode 100644 .github/dependabot.yaml diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml new file mode 100644 index 000000000..28b3d1c7d --- /dev/null +++ b/.github/dependabot.yaml @@ -0,0 +1,22 @@ +version: 2 +updates: + - package-ecosystem: "cargo" + directory: "/scripts/backend_info" + schedule: + interval: "weekly" + day: "monday" + groups: + rs-dependencies: + patterns: + - "*" + + - package-ecosystem: "gomod" + directories: + - "/" + schedule: + interval: "weekly" + day: "monday" + groups: + go-dependency: + patterns: + - "*" diff --git a/.github/workflows/hal.yml b/.github/workflows/hal.yml index 606adbd76..12ac82d37 100644 --- a/.github/workflows/hal.yml +++ b/.github/workflows/hal.yml @@ -8,66 +8,65 @@ on: jobs: build: runs-on: ubuntu-latest - + timeout-minutes: 120 steps: - - name: Update Ubuntu - run: | - sudo apt-get update - sudo apt-get upgrade -y - - - name: Install Go - uses: actions/setup-go@v5 - with: - go-version: 1.23.x - cache-dependency-path: "go.sum" + - name: Free Disk Space + run: | + sudo apt-get clean + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + df -h + + - name: Update Ubuntu + run: | + sudo apt-get update + sudo apt-get upgrade -y + sudo apt-get clean + sudo apt-get autoremove -y + df -h + + - name: Maximize build space + uses: easimon/maximize-build-space@master + with: + root-reserve-mb: 35000 + swap-size-mb: 1024 + remove-dotnet: 'true' + remove-android: 'true' + - name: Check free space + run: | + echo "Free space:" + df -h - - name: Checkout cocos - uses: actions/checkout@v4 - with: - repository: 'ultravioletrs/cocos' - path: cocos + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: 1.23.x + cache-dependency-path: "go.sum" - - name: Checkout buildroot - uses: actions/checkout@v4 - with: - repository: 'buildroot/buildroot' - path: buildroot + - name: Checkout cocos + uses: actions/checkout@v4 + with: + repository: 'ultravioletrs/cocos' + path: cocos - - name: Build - run: | - cd buildroot - make BR2_EXTERNAL=../cocos/hal/linux cocos_defconfig - make + - name: Checkout buildroot + uses: actions/checkout@v4 + with: + repository: 'buildroot/buildroot' + path: buildroot - - name: Create Release - id: create_release - uses: actions/create-release@latest - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ github.ref }} - release_name: Release ${{ github.ref }} - draft: false - prerelease: false + - name: Build + run: | + cd buildroot + make BR2_EXTERNAL=../cocos/hal/linux cocos_defconfig + make - - name: Upload Release Asset - id: upload-release-kernel - uses: actions/upload-release-asset@latest - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./buildroot/output/images/bzImage - asset_name: bzImage - asset_content_type: application/octet-stream + - name: Release + uses: softprops/action-gh-release@v2 + with: + files: | + buildroot/output/images/bzImage + buildroot/output/images/rootfs.cpio.gz - - name: Upload Release Asset - id: upload-release-rootfs - uses: actions/upload-release-asset@latest - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.create_release.outputs.upload_url }} - asset_path: ./buildroot/output/images/rootfs.cpio.gz - asset_name: rootfs.cpio.gz - asset_content_type: application/gzip From 18102db235dc9b2576be9b574709194bcfe82dcb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:15:58 +0200 Subject: [PATCH 32/83] Bump the go-dependency group across 1 directory with 8 updates (#258) Bumps the go-dependency group with 3 updates in the / directory: [github.com/caarlos0/env/v11](https://github.com/caarlos0/env), [go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc](https://github.com/open-telemetry/opentelemetry-go-contrib) and [github.com/docker/docker](https://github.com/docker/docker). Updates `github.com/caarlos0/env/v11` from 11.1.0 to 11.2.2 - [Release notes](https://github.com/caarlos0/env/releases) - [Changelog](https://github.com/caarlos0/env/blob/main/.goreleaser.yml) - [Commits](https://github.com/caarlos0/env/compare/v11.1.0...v11.2.2) Updates `go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc` from 0.53.0 to 0.55.0 - [Release notes](https://github.com/open-telemetry/opentelemetry-go-contrib/releases) - [Changelog](https://github.com/open-telemetry/opentelemetry-go-contrib/blob/main/CHANGELOG.md) - [Commits](https://github.com/open-telemetry/opentelemetry-go-contrib/compare/zpages/v0.53.0...zpages/v0.55.0) Updates `go.opentelemetry.io/otel/trace` from 1.28.0 to 1.30.0 - [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases) - [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md) - [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.28.0...v1.30.0) Updates `golang.org/x/crypto` from 0.25.0 to 0.27.0 - [Commits](https://github.com/golang/crypto/compare/v0.25.0...v0.27.0) Updates `golang.org/x/sync` from 0.7.0 to 0.8.0 - [Commits](https://github.com/golang/sync/compare/v0.7.0...v0.8.0) Updates `google.golang.org/grpc` from 1.65.0 to 1.66.1 - [Release notes](https://github.com/grpc/grpc-go/releases) - [Commits](https://github.com/grpc/grpc-go/compare/v1.65.0...v1.66.1) Updates `github.com/docker/docker` from 27.1.0+incompatible to 27.3.1+incompatible - [Release notes](https://github.com/docker/docker/releases) - [Commits](https://github.com/docker/docker/compare/v27.1.0...v27.3.1) Updates `golang.org/x/term` from 0.22.0 to 0.24.0 - [Commits](https://github.com/golang/term/compare/v0.22.0...v0.24.0) --- updated-dependencies: - dependency-name: github.com/caarlos0/env/v11 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: go-dependency - dependency-name: go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc dependency-type: direct:production update-type: version-update:semver-minor dependency-group: go-dependency - dependency-name: go.opentelemetry.io/otel/trace dependency-type: direct:production update-type: version-update:semver-minor dependency-group: go-dependency - dependency-name: golang.org/x/crypto dependency-type: direct:production update-type: version-update:semver-minor dependency-group: go-dependency - dependency-name: golang.org/x/sync dependency-type: direct:production update-type: version-update:semver-minor dependency-group: go-dependency - dependency-name: google.golang.org/grpc dependency-type: direct:production update-type: version-update:semver-minor dependency-group: go-dependency - dependency-name: github.com/docker/docker dependency-type: direct:production update-type: version-update:semver-minor dependency-group: go-dependency - dependency-name: golang.org/x/term dependency-type: direct:production update-type: version-update:semver-minor dependency-group: go-dependency ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 28 ++++++++++++++-------------- go.sum | 56 ++++++++++++++++++++++++++++---------------------------- 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/go.mod b/go.mod index d3644d10d..0cac06411 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22.5 require ( github.com/absmach/magistrala v0.14.1-0.20240709113739-04c359462746 - github.com/caarlos0/env/v11 v11.1.0 + github.com/caarlos0/env/v11 v11.2.2 github.com/cenkalti/backoff/v4 v4.3.0 github.com/digitalocean/go-libvirt v0.0.0-20240709142323-d8406205c752 github.com/go-kit/kit v0.13.0 @@ -15,11 +15,11 @@ require ( github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 github.com/virtee/sev-snp-measure-go v0.0.0-20240530153610-e6e8dc9b6877 - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0 - go.opentelemetry.io/otel/trace v1.28.0 - golang.org/x/crypto v0.25.0 - golang.org/x/sync v0.7.0 - google.golang.org/grpc v1.65.0 + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.55.0 + go.opentelemetry.io/otel/trace v1.30.0 + golang.org/x/crypto v0.27.0 + golang.org/x/sync v0.8.0 + google.golang.org/grpc v1.66.1 google.golang.org/protobuf v1.34.2 ) @@ -36,7 +36,7 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 // indirect - go.opentelemetry.io/otel v1.28.0 // indirect + go.opentelemetry.io/otel v1.30.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.28.0 // indirect go.opentelemetry.io/otel/sdk v1.28.0 // indirect @@ -49,7 +49,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/docker/docker v27.1.0+incompatible + github.com/docker/docker v27.3.1+incompatible github.com/go-kit/log v0.2.1 // indirect github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/go-logr/logr v1.4.2 // indirect @@ -69,15 +69,15 @@ require ( github.com/prometheus/common v0.52.2 // indirect github.com/prometheus/procfs v0.13.0 // indirect github.com/stretchr/objx v0.5.2 // indirect - go.opentelemetry.io/otel/metric v1.28.0 // indirect + go.opentelemetry.io/otel/metric v1.30.0 // indirect go.opentelemetry.io/proto/otlp v1.3.1 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/net v0.27.0 // indirect - golang.org/x/sys v0.22.0 // indirect - golang.org/x/term v0.22.0 - golang.org/x/text v0.16.0 // indirect + golang.org/x/net v0.29.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/term v0.24.0 + golang.org/x/text v0.18.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0c187662c..eddc31559 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ github.com/absmach/magistrala v0.14.1-0.20240709113739-04c359462746 h1:Tj567KeGV github.com/absmach/magistrala v0.14.1-0.20240709113739-04c359462746/go.mod h1:CIx3OsPFc4doJZmBWSA6LNWefcznKv9c3cLOxNxL4q4= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/caarlos0/env/v11 v11.1.0 h1:a5qZqieE9ZfzdvbbdhTalRrHT5vu/4V1/ad1Ka6frhI= -github.com/caarlos0/env/v11 v11.1.0/go.mod h1:LwgkYk1kDvfGpHthrWWLof3Ny7PezzFwS4QrsJdHTMo= +github.com/caarlos0/env/v11 v11.2.2 h1:95fApNrUyueipoZN/EhA8mMxiNxrBwDa+oAZrMWl3Kg= +github.com/caarlos0/env/v11 v11.2.2/go.mod h1:JBfcdeQiBoI3Zh1QRAWfe+tpiNTmDtcCj/hHHHMx0vc= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -23,8 +23,8 @@ github.com/digitalocean/go-libvirt v0.0.0-20240709142323-d8406205c752 h1:NI7XEcH github.com/digitalocean/go-libvirt v0.0.0-20240709142323-d8406205c752/go.mod h1:/Ok8PA2qi/ve0Py38+oL+VxoYmlowigYRyLEODRYdgc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= -github.com/docker/docker v27.1.0+incompatible h1:rEHVQc4GZ0MIQKifQPHSFGV/dVgaZafgRf8fCPtDYBs= -github.com/docker/docker v27.1.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v27.3.1+incompatible h1:KttF0XoteNTicmUtBO0L2tP+J7FGRFTjaEF4k6WdhfI= +github.com/docker/docker v27.3.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= @@ -114,22 +114,22 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0 h1:9G6E0TXzGFVfTnawRzrPl83iHOAV7L8NJiR8RSGYV1g= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0/go.mod h1:azvtTADFQJA8mX80jIH/akaE7h+dbm/sVuaHqN13w74= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.55.0 h1:hCq2hNMwsegUvPzI7sPOvtO9cqyy5GbWt/Ybp2xrx8Q= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.55.0/go.mod h1:LqaApwGx/oUmzsbqxkzuBvyoPpkxk3JQWnqfVrJ3wCA= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 h1:4K4tsIXefpVJtvA/8srF4V4y0akAoPHkIslgAkjixJA= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0/go.mod h1:jjdQuTGVsXV4vSs+CJ2qYDeDPf9yIJV23qlIzBm73Vg= -go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo= -go.opentelemetry.io/otel v1.28.0/go.mod h1:q68ijF8Fc8CnMHKyzqL6akLO46ePnjkgfIMIjUIX9z4= +go.opentelemetry.io/otel v1.30.0 h1:F2t8sK4qf1fAmY9ua4ohFS/K+FUuOPemHUIXHtktrts= +go.opentelemetry.io/otel v1.30.0/go.mod h1:tFw4Br9b7fOS+uEao81PJjVMjW/5fvNCbpsDIXqP0pc= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0 h1:3Q/xZUyC1BBkualc9ROb4G8qkH90LXEIICcs5zv1OYY= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0/go.mod h1:s75jGIWA9OfCMzF0xr+ZgfrB5FEbbV7UuYo32ahUiFI= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.28.0 h1:j9+03ymgYhPKmeXGk5Zu+cIZOlVzd9Zv7QIiyItjFBU= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.28.0/go.mod h1:Y5+XiUG4Emn1hTfciPzGPJaSI+RpDts6BnCIir0SLqk= -go.opentelemetry.io/otel/metric v1.28.0 h1:f0HGvSl1KRAU1DLgLGFjrwVyismPlnuU6JD6bOeuA5Q= -go.opentelemetry.io/otel/metric v1.28.0/go.mod h1:Fb1eVBFZmLVTMb6PPohq3TO9IIhUisDsbJoL/+uQW4s= +go.opentelemetry.io/otel/metric v1.30.0 h1:4xNulvn9gjzo4hjg+wzIKG7iNFEaBMX00Qd4QIZs7+w= +go.opentelemetry.io/otel/metric v1.30.0/go.mod h1:aXTfST94tswhWEb+5QjlSqG+cZlmyXy/u8jFpor3WqQ= go.opentelemetry.io/otel/sdk v1.28.0 h1:b9d7hIry8yZsgtbmM0DKyPWMMUMlK9NEKuIG4aBqWyE= go.opentelemetry.io/otel/sdk v1.28.0/go.mod h1:oYj7ClPUA7Iw3m+r7GeEjz0qckQRJK2B8zjcZEfu7Pg= -go.opentelemetry.io/otel/trace v1.28.0 h1:GhQ9cUuQGmNDd5BTCP2dAvv75RdMxEfTmYejp+lkx9g= -go.opentelemetry.io/otel/trace v1.28.0/go.mod h1:jPyXzNPg6da9+38HEwElrQiHlVMTnVfM3/yv2OlIHaI= +go.opentelemetry.io/otel/trace v1.30.0 h1:7UBkkYzeg3C7kQX8VAidWh2biiQbtAKjyIML8dQ9wmc= +go.opentelemetry.io/otel/trace v1.30.0/go.mod h1:5EyKqTzzmyqB9bwtCCq6pDLktPK6fmGf/Dph+8VI02o= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= @@ -137,8 +137,8 @@ go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN8 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= -golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.19.0 h1:fEdghXQSo20giMthA7cd28ZC+jts4amQ3YMXiP5oMQ8= @@ -147,25 +147,25 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= +golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210426230700-d19ff857e887/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= -golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= +golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -180,10 +180,10 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 h1:0+ozOGcrp+Y8Aq8TLNN2Aliibms5LEzsq99ZZmAGYm0= google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094/go.mod h1:fJ/e3If/Q67Mj99hin0hMhiNyCRmt6BQ2aWIJshUSJw= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 h1:BwIjyKYGsK9dMCBOorzRri8MQwmi7mT9rGHsCEinZkA= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= -google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= -google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/grpc v1.66.1 h1:hO5qAXR19+/Z44hmvIM4dQFMSYX9XcWsByfoxutBpAM= +google.golang.org/grpc v1.66.1/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From aa01ef795da8fe2e24073bc5b860dfe2cf4bb0c0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:38:24 +0200 Subject: [PATCH 33/83] Update sev requirement (#257) Updates the requirements on [sev](https://github.com/virtee/sev) to permit the latest version. Updates `sev` to 4.0.0 - [Commits](https://github.com/virtee/sev/compare/v3.1.1...v4.0.0) --- updated-dependencies: - dependency-name: sev dependency-type: direct:production dependency-group: rs-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- scripts/backend_info/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/backend_info/Cargo.toml b/scripts/backend_info/Cargo.toml index 2eecad833..1d6c47dc4 100644 --- a/scripts/backend_info/Cargo.toml +++ b/scripts/backend_info/Cargo.toml @@ -9,4 +9,4 @@ edition = "2021" clap = { version = "4.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -sev = "3.1.1" +sev = "4.0.0" From 6c4819563c59467f82ca2d5e63fb346f77b08a89 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Wed, 25 Sep 2024 22:36:40 +0300 Subject: [PATCH 34/83] remove vnc (#259) Signed-off-by: Sammy Oina --- manager/qemu/config.go | 2 -- manager/service.go | 2 -- manager/service_test.go | 1 - 3 files changed, 5 deletions(-) diff --git a/manager/qemu/config.go b/manager/qemu/config.go index 2b884e15c..d5f3382db 100644 --- a/manager/qemu/config.go +++ b/manager/qemu/config.go @@ -61,7 +61,6 @@ type SevConfig struct { type VSockConfig struct { ID string `env:"VSOCK_ID" envDefault:"vhost-vsock-pci0"` GuestCID int `env:"VSOCK_GUEST_CID" envDefault:"3"` - Vnc int `env:"VSOCK_VNC" envDefault:"0"` } type Config struct { @@ -165,7 +164,6 @@ func (config Config) ConstructQemuArgs() []string { config.VirtioNetPciConfig.ROMFile)) args = append(args, "-device", fmt.Sprintf("vhost-vsock-pci,id=%s,guest-cid=%d", config.VSockConfig.ID, config.VSockConfig.GuestCID)) - args = append(args, "-vnc", fmt.Sprintf(":%d", config.Vnc)) if config.EnableSEVSNP { args = append(args, "-object", diff --git a/manager/service.go b/manager/service.go index 9d97d3162..d0dac45d6 100644 --- a/manager/service.go +++ b/manager/service.go @@ -185,8 +185,6 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) return "", err } - ms.qemuCfg.VSockConfig.Vnc++ - ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Completed.String(), json.RawMessage{}) return fmt.Sprint(ms.qemuCfg.HostFwdAgent), nil } diff --git a/manager/service_test.go b/manager/service_test.go index ee1148444..2998134e0 100644 --- a/manager/service_test.go +++ b/manager/service_test.go @@ -88,7 +88,6 @@ func TestRun(t *testing.T) { qemuCfg := qemu.Config{ VSockConfig: qemu.VSockConfig{ GuestCID: 3, - Vnc: 5900, }, } logger := slog.Default() From c69dcd0e2db2f8fefea119f918694091fa5ebe9a Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Thu, 26 Sep 2024 12:59:26 +0300 Subject: [PATCH 35/83] NOISSUE - Improve reliability of state machine test (#260) * add sleep to prevent test failing Signed-off-by: Sammy Oina * add coverage Signed-off-by: Sammy Oina * use codecov Signed-off-by: Sammy Oina * create dir Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- .github/workflows/main.yaml | 15 ++++++++++++++- README.md | 4 ++++ agent/state_test.go | 28 +++++++++++++++++----------- test/computations/main.go | 2 +- 4 files changed, 36 insertions(+), 13 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index d8ca4ee76..0d20e9795 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -30,5 +30,18 @@ jobs: run: | make + - name: Create coverage directory + run: | + mkdir coverage + - name: Run tests - run: go test -v --race -covermode=atomic -coverprofile cover.out ./... + run: go test -v --race -covermode=atomic -coverprofile coverage/cover.txt ./... + + - name: Upload results to Codecov + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + directory: ./coverage/ + name: codecov-umbrella + verbose: true + diff --git a/README.md b/README.md index 49874737a..2ea8ae073 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,8 @@ # Cocos AI + +[![codecov](https://codecov.io/gh/ultravioletrs/cocos/graph/badge.svg?token=HX01LR01K9)](https://codecov.io/gh/ultravioletrs/cocos) +![Go report card](https://goreportcard.com/badge/github.com/ultravioletrs/cocos) + [Cocos AI (Confdential Computing System for AI/ML)][cocos] is a platform for secure multiparty computation (SMPC) based on the [Confidential Computing][cc] and [Trusted Execution Environments (TEEs)][tee]. diff --git a/agent/state_test.go b/agent/state_test.go index 5ce786359..43e3fb92b 100644 --- a/agent/state_test.go +++ b/agent/state_test.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "testing" + "time" mglog "github.com/absmach/magistrala/logger" ) @@ -39,19 +40,20 @@ func TestStateMachineTransitions(t *testing.T) { t.Run(fmt.Sprintf("Transition from %v to %v", tc.fromState, tc.expected), func(t *testing.T) { sm := NewStateMachine(mglog.NewMock(), tc.cmp) ctx, cancel := context.WithCancel(context.Background()) - go func() { - sm.Start(ctx) - }() - sm.wg.Wait() - sm.SetState(tc.fromState) + defer cancel() + + go sm.Start(ctx) + + time.Sleep(50 * time.Millisecond) + sm.SetState(tc.fromState) sm.SendEvent(tc.event) + time.Sleep(50 * time.Millisecond) + if sm.GetState() != tc.expected { t.Errorf("Expected state %v after the event, but got %v", tc.expected, sm.GetState()) } - close(sm.EventChan) - cancel() }) } } @@ -59,14 +61,18 @@ func TestStateMachineTransitions(t *testing.T) { func TestStateMachineInvalidTransition(t *testing.T) { sm := NewStateMachine(mglog.NewMock(), cmp) ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go sm.Start(ctx) - sm.SetState(Idle) + time.Sleep(50 * time.Millisecond) + sm.SetState(Idle) sm.SendEvent(dataReceived) - if sm.State != Idle { - t.Errorf("State should not change on an invalid event, but got %v", sm.State) + time.Sleep(50 * time.Millisecond) + + if sm.GetState() != Idle { + t.Errorf("State should not change on an invalid event, but got %v", sm.GetState()) } - cancel() } diff --git a/test/computations/main.go b/test/computations/main.go index f6befe973..1c8a4f332 100644 --- a/test/computations/main.go +++ b/test/computations/main.go @@ -27,7 +27,7 @@ import ( var _ managergrpc.Service = (*svc)(nil) const ( - svcName = "manager_test_server" + svcName = "computations_test_server" defaultPort = "7001" ) From 115c6c24c09a5225156f29653739a5a9dcdc4c7f Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Fri, 27 Sep 2024 11:52:52 +0300 Subject: [PATCH 36/83] NOISSUE - Fix file format (#261) * add coverage Signed-off-by: Sammy Oina * use codecov Signed-off-by: Sammy Oina * rename extension Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- .github/workflows/main.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 0d20e9795..fffe2eb1c 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -35,7 +35,7 @@ jobs: mkdir coverage - name: Run tests - run: go test -v --race -covermode=atomic -coverprofile coverage/cover.txt ./... + run: go test -v --race -covermode=atomic -coverprofile coverage/cover.out ./... - name: Upload results to Codecov uses: codecov/codecov-action@v4 From 63994d78b8c29df491d1da643bf252b284eae221 Mon Sep 17 00:00:00 2001 From: b1ackd0t <28790446+rodneyosodo@users.noreply.github.com> Date: Mon, 30 Sep 2024 12:49:18 +0300 Subject: [PATCH 37/83] NOISSUE - Add Rust gitignore (#268) * chore(backendinfo): Add rust build artefacts to gitignore Signed-off-by: Rodney Osodo * style: format file following rust linter guidelines Signed-off-by: Rodney Osodo * chore(CI): Add rust CI pipeline Signed-off-by: Rodney Osodo --------- Signed-off-by: Rodney Osodo --- .github/workflows/rust.yaml | 41 +++++++++++++++++++++ .gitignore | 15 ++++++++ scripts/backend_info/src/main.rs | 61 +++++++++++++++++--------------- 3 files changed, 88 insertions(+), 29 deletions(-) create mode 100644 .github/workflows/rust.yaml diff --git a/.github/workflows/rust.yaml b/.github/workflows/rust.yaml new file mode 100644 index 000000000..19d358d48 --- /dev/null +++ b/.github/workflows/rust.yaml @@ -0,0 +1,41 @@ +name: Rust CI Pipeline + +on: + push: + branches: + - main + paths: + - "scripts/backend_info/**" + - ".github/workflows/rust.yaml" + pull_request: + branches: + - main + paths: + - "scripts/backend_info/**" + - ".github/workflows/rust.yaml" + +env: + CARGO_TERM_COLOR: always + +jobs: + rust-check: + runs-on: ubuntu-latest + defaults: + run: + working-directory: ./scripts/backend_info + + steps: + - name: Checkout Code + uses: actions/checkout@v4 + + - name: Check cargo + run: cargo check --release --all-targets + + - name: Check formatting + run: cargo fmt --all -- --check + + - name: Run linter + run: cargo clippy -- -D warnings + + - name: Build for all features + run: cargo build --release --all-features diff --git a/.gitignore b/.gitignore index f2818ed77..b23806e1d 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,18 @@ dist/ results.zip *.spec *.tar + +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb diff --git a/scripts/backend_info/src/main.rs b/scripts/backend_info/src/main.rs index 4657153e9..6de335422 100644 --- a/scripts/backend_info/src/main.rs +++ b/scripts/backend_info/src/main.rs @@ -1,11 +1,11 @@ -use clap::{Arg, Command, value_parser}; +use clap::{value_parser, Arg, Command}; use serde::Serialize; +use sev::firmware::host::*; use std::arch::x86_64::__cpuid; use std::fs::File; use std::io::Write; -use sev::firmware::host::*; -const BACKEND_INFO_JSON : &str = "backend_info.json"; +const BACKEND_INFO_JSON: &str = "backend_info.json"; const EXTENDED_FAMILY_SHIFT: u32 = 20; const EXTENDED_MODEL_SHIFT: u32 = 16; const FAMILY_SHIFT: u32 = 8; @@ -25,7 +25,7 @@ struct SevProduct { #[derive(Serialize)] struct Vmpl { - value : u32, + value: u32, } #[derive(Serialize)] @@ -45,15 +45,15 @@ struct SnpPolicy { minimum_version: String, permit_provisional_firmware: bool, require_id_block: bool, - product: SevProduct, + product: SevProduct, } #[derive(Serialize)] struct RootOfTrust { product: String, - check_crl : bool, - disallow_network : bool, - product_line : String, + check_crl: bool, + disallow_network: bool, + product_line: String, } #[derive(Serialize)] @@ -63,19 +63,19 @@ struct Computation { } fn get_sev_snp_processor() -> u32 { - let cpuid_result = unsafe { __cpuid(1)}; + let cpuid_result = unsafe { __cpuid(1) }; cpuid_result.eax } fn get_product_name(product: i32) -> String { match product { - SEV_PRODUCT_MILAN => return "Milan".to_string(), - SEV_PRODUCT_GENOA => return "Genoa".to_string(), - _ => return "Unknown".to_string(), + SEV_PRODUCT_MILAN => "Milan".to_string(), + SEV_PRODUCT_GENOA => "Genoa".to_string(), + _ => "Unknown".to_string(), } } -fn get_uint64_from_tcb(tcb_version : &TcbVersion) -> u64 { +fn get_uint64_from_tcb(tcb_version: &TcbVersion) -> u64 { let microcode = (tcb_version.microcode as u64) << 56; let snp = (tcb_version.snp as u64) << 48; let tee = (tcb_version.tee as u64) << 8; @@ -103,20 +103,22 @@ fn sev_product(eax: u32) -> SevProduct { }; } - SevProduct { - name: product_name, - } + SevProduct { name: product_name } } fn main() { let matches = Command::new("Backend info") - .about("Processes command line options and outputs a JSON file for Attestation verification") - .arg(Arg::new("policy") - .long("policy") - .value_name("INT") - .help("Sets the policy integer") - .required(true) - .value_parser(value_parser!(u64))) + .about( + "Processes command line options and outputs a JSON file for Attestation verification", + ) + .arg( + Arg::new("policy") + .long("policy") + .value_name("INT") + .help("Sets the policy integer") + .required(true) + .value_parser(value_parser!(u64)), + ) .get_matches(); let mut firmware: Firmware = Firmware::open().unwrap(); @@ -125,7 +127,7 @@ fn main() { let policy: u64 = *matches.get_one::("policy").unwrap(); let family_id = vec![0; 16]; let image_id = vec![0; 16]; - let vmpl = Vmpl { value: 0}; + let vmpl = Vmpl { value: 0 }; let minimum_tcb = get_uint64_from_tcb(&status.platform_tcb_version); let minimum_launch_tcb = get_uint64_from_tcb(&status.platform_tcb_version); let require_author_key = false; @@ -160,10 +162,10 @@ fn main() { }; let root_of_trust = RootOfTrust { - product : get_product_name(product.name), - check_crl : true, - disallow_network : false, - product_line : get_product_name(product.name), + product: get_product_name(product.name), + check_crl: true, + disallow_network: false, + product_line: get_product_name(product.name), }; let computation = Computation { @@ -173,7 +175,8 @@ fn main() { let json = serde_json::to_string_pretty(&computation).expect("Failed to serialize to JSON"); let mut file = File::create(BACKEND_INFO_JSON).expect("Failed to create file"); - file.write_all(json.as_bytes()).expect("Failed to write to file"); + file.write_all(json.as_bytes()) + .expect("Failed to write to file"); println!("Computation JSON has been written to {}", BACKEND_INFO_JSON); } From 3d9fde39c205496972fb9ad602b5008b0e1c974b Mon Sep 17 00:00:00 2001 From: Smith Jilks <41241359+smithjilks@users.noreply.github.com> Date: Tue, 1 Oct 2024 11:25:52 +0300 Subject: [PATCH 38/83] NOISSUE - Enhance CLI (#250) * Enhance CLI progressbar Signed-off-by: Jilks Smith * Update cli error and success messages colors Signed-off-by: Jilks Smith * Update cli emojis Signed-off-by: Jilks Smith * Add logs for cli interrupt by user Signed-off-by: Jilks Smith * Remove extra whitespaces Signed-off-by: Jilks Smith * Update upload data emoji Signed-off-by: Jilks Smith * Update cli main.go Signed-off-by: Jilks Smith * Update cli errors Signed-off-by: Jilks Smith * Update cli Signed-off-by: Jilks Smith * Update cli Signed-off-by: Jilks Smith * Update go sum Signed-off-by: Jilks Smith * Add progressbar tests Signed-off-by: Jilks Smith * Fix cli cmd error formating Signed-off-by: Jilks Smith * Add cli datasets, algo and result tests Signed-off-by: Jilks Smith --------- Signed-off-by: Jilks Smith --- cli/algorithm_test.go | 144 +++++++++++++++++++++++++++++++ cli/algorithms.go | 19 ++-- cli/datasets.go | 23 +++-- cli/datasets_test.go | 119 +++++++++++++++++++++++++ cli/result.go | 17 ++-- cli/result_test.go | 107 +++++++++++++++++++++++ cmd/cli/main.go | 22 ++++- go.mod | 3 + go.sum | 9 ++ pkg/progressbar/progress_test.go | 132 ++++++++++++++++++++++++++++ pkg/progressbar/progressbar.go | 99 +++++++++++---------- pkg/sdk/agent.go | 1 + pkg/sdk/mocks/sdk.go | 129 +++++++++++++++++++++++++++ 13 files changed, 762 insertions(+), 62 deletions(-) create mode 100644 cli/algorithm_test.go create mode 100644 cli/datasets_test.go create mode 100644 cli/result_test.go create mode 100644 pkg/progressbar/progress_test.go create mode 100644 pkg/sdk/mocks/sdk.go diff --git a/cli/algorithm_test.go b/cli/algorithm_test.go new file mode 100644 index 000000000..fe853e195 --- /dev/null +++ b/cli/algorithm_test.go @@ -0,0 +1,144 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package cli + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "log" + "os" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/ultravioletrs/cocos/pkg/sdk/mocks" +) + +const algorithmFile = "test_algo_file.py" + +func captureLogOutput(f func()) string { + var buf bytes.Buffer + log.SetOutput(&buf) + defer log.SetOutput(os.Stderr) + f() + return buf.String() +} + +func generateRSAPrivateKeyFile(fileName string) error { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + + privKeyFile, err := os.Create(fileName) + if err != nil { + return err + } + defer privKeyFile.Close() + + privateKeyPEM := &pem.Block{ + Type: rsaKeyType, + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + } + + err = pem.Encode(privKeyFile, privateKeyPEM) + if err != nil { + return err + } + + return nil +} + +func TestAlgorithmCmd_Success(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil) + testCLI := New(mockSDK) + + err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644) + require.NoError(t, err) + + err = generateRSAPrivateKeyFile(privateKeyFile) + require.NoError(t, err) + + cmd := testCLI.NewAlgorithmCmd() + output := captureLogOutput(func() { + cmd.SetArgs([]string{algorithmFile, privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) + }) + + require.Contains(t, output, "Successfully uploaded algorithm") + t.Cleanup(func() { + os.Remove(privateKeyFile) + os.Remove(algorithmFile) + }) +} + +func TestAlgorithmCmd_MissingAlgorithmFile(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil) + testCLI := New(mockSDK) + + cmd := testCLI.NewAlgorithmCmd() + + output := captureLogOutput(func() { + cmd.SetArgs([]string{"non_existent_algo_file.py", privateKeyFile}) + err := cmd.Execute() + require.NoError(t, err) + }) + + require.Contains(t, output, "Error reading algorithm file") +} + +func TestAlgorithmCmd_MissingPrivateKeyFile(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil) + testCLI := New(mockSDK) + + err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644) + require.NoError(t, err) + + cmd := testCLI.NewAlgorithmCmd() + + output := captureLogOutput(func() { + cmd.SetArgs([]string{algorithmFile, "non_existent_private_key.pem"}) + err = cmd.Execute() + require.NoError(t, err) + }) + + require.Contains(t, output, "Error reading private key file") + t.Cleanup(func() { + os.Remove(algorithmFile) + }) +} + +func TestAlgorithmCmd_UploadFailure(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error")) + testCLI := New(mockSDK) + + err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644) + require.NoError(t, err) + + err = generateRSAPrivateKeyFile(privateKeyFile) + require.NoError(t, err) + + cmd := testCLI.NewAlgorithmCmd() + + output := captureLogOutput(func() { + cmd.SetArgs([]string{algorithmFile, privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) + }) + + require.Contains(t, output, "Failed to upload algorithm") + + t.Cleanup(func() { + os.Remove(privateKeyFile) + os.Remove(algorithmFile) + }) +} diff --git a/cli/algorithms.go b/cli/algorithms.go index c9da27281..e0a1ab7cb 100644 --- a/cli/algorithms.go +++ b/cli/algorithms.go @@ -8,6 +8,7 @@ import ( "log" "os" + "github.com/fatih/color" "github.com/spf13/cobra" "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/agent/algorithm" @@ -35,14 +36,18 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { algorithm, err := os.ReadFile(algorithmFile) if err != nil { - log.Fatalf("Error reading algorithm file: %v", err) + msg := color.New(color.FgRed).Sprintf("Error reading algorithm file: %v ❌ ", err) + log.Println(msg) + return } var req []byte if requirementsFile != "" { req, err = os.ReadFile(requirementsFile) if err != nil { - log.Fatalf("Error reading requirments file: %v", err) + msg := color.New(color.FgRed).Sprintf("Error reading requirments file: %v ❌ ", err) + log.Println(msg) + return } } @@ -53,7 +58,9 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { privKeyFile, err := os.ReadFile(args[1]) if err != nil { - log.Fatalf("Error reading private key file: %v", err) + msg := color.New(color.FgRed).Sprintf("Error reading private key file: %v ❌ ", err.Error()) + log.Println(msg) + return } pemBlock, _ := pem.Decode(privKeyFile) @@ -63,10 +70,12 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string))) if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algoReq, privKey); err != nil { - log.Fatalf("Error uploading algorithm with error: %v", err) + msg := color.New(color.FgRed).Sprintf("Failed to upload algorithm due to error: %v ❌ ", err.Error()) + log.Println(msg) + return } - log.Println("Successfully uploaded algorithm") + log.Println(color.New(color.FgGreen).Sprint("Successfully uploaded algorithm! ✔ ")) }, } diff --git a/cli/datasets.go b/cli/datasets.go index 875f1aa27..49c8c4279 100644 --- a/cli/datasets.go +++ b/cli/datasets.go @@ -10,6 +10,7 @@ import ( "os" "path" + "github.com/fatih/color" "github.com/spf13/cobra" "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/internal" @@ -31,7 +32,9 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { f, err := os.Stat(datasetPath) if err != nil { - log.Fatalf("Error reading dataset file: %v", err) + msg := color.New(color.FgRed).Sprintf("Error reading dataset file: %v ❌ ", err) + log.Println(msg) + return } var dataset []byte @@ -39,12 +42,16 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { if f.IsDir() { dataset, err = internal.ZipDirectoryToMemory(datasetPath) if err != nil { - log.Fatalf("Error zipping dataset directory: %v", err) + msg := color.New(color.FgRed).Sprintf("Error zipping dataset directory: %v ❌ ", err) + log.Println(msg) + return } } else { dataset, err = os.ReadFile(datasetPath) if err != nil { - log.Fatalf("Error reading dataset file: %v", err) + msg := color.New(color.FgRed).Sprintf("Error reading dataset file: %v ❌ ", err) + log.Println(msg) + return } } @@ -55,7 +62,9 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { privKeyFile, err := os.ReadFile(args[1]) if err != nil { - log.Fatalf("Error reading private key file: %v", err) + msg := color.New(color.FgRed).Sprintf("Error reading private key file: %v ❌ ", err) + log.Println(msg) + return } pemBlock, _ := pem.Decode(privKeyFile) @@ -64,10 +73,12 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string))) if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataReq, privKey); err != nil { - log.Fatalf("Error uploading dataset: %v", err) + msg := color.New(color.FgRed).Sprintf("Failed to upload dataset due to error: %v ❌ ", err.Error()) + log.Println(msg) + return } - log.Println("Successfully uploaded dataset") + log.Println(color.New(color.FgGreen).Sprint("Successfully uploaded dataset! ✔ ")) }, } diff --git a/cli/datasets_test.go b/cli/datasets_test.go new file mode 100644 index 000000000..fade39c27 --- /dev/null +++ b/cli/datasets_test.go @@ -0,0 +1,119 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package cli + +import ( + "errors" + "os" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/ultravioletrs/cocos/pkg/sdk/mocks" +) + +func createTempDatasetFile(content string) (string, error) { + tmpFile, err := os.CreateTemp("", "dataset-*.txt") + if err != nil { + return "", err + } + defer tmpFile.Close() + + _, err = tmpFile.WriteString(content) + if err != nil { + return "", err + } + return tmpFile.Name(), nil +} + +func TestDatasetsCmd_Success(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil) + testCLI := New(mockSDK) + + datasetFile, err := createTempDatasetFile("test dataset content") + require.NoError(t, err) + + err = generateRSAPrivateKeyFile(privateKeyFile) + require.NoError(t, err) + + cmd := testCLI.NewDatasetsCmd() + + output := captureLogOutput(func() { + cmd.SetArgs([]string{datasetFile, privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) + }) + + require.Contains(t, output, "Successfully uploaded dataset") + mockSDK.AssertCalled(t, "Data", mock.Anything, mock.Anything, mock.Anything) + + t.Cleanup(func() { + os.Remove(datasetFile) + os.Remove(privateKeyFile) + }) +} + +func TestDatasetsCmd_MissingDatasetFile(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil) + testCLI := New(mockSDK) + + cmd := testCLI.NewDatasetsCmd() + + output := captureLogOutput(func() { + cmd.SetArgs([]string{"non_existent_dataset.txt", privateKeyFile}) + err := cmd.Execute() + require.NoError(t, err) + }) + + require.Contains(t, output, "Error reading dataset file") +} + +func TestDatasetsCmd_MissingPrivateKeyFile(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil) + testCLI := New(mockSDK) + + datasetFile, err := createTempDatasetFile("test dataset content") + require.NoError(t, err) + + cmd := testCLI.NewDatasetsCmd() + + output := captureLogOutput(func() { + cmd.SetArgs([]string{datasetFile, "non_existent_private_key.pem"}) + err = cmd.Execute() + require.NoError(t, err) + }) + + require.Contains(t, output, "Error reading private key file") + t.Cleanup(func() { + os.Remove(datasetFile) + }) +} + +func TestDatasetsCmd_UploadFailure(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error")) + testCLI := New(mockSDK) + + datasetFile, err := createTempDatasetFile("test dataset content") + require.NoError(t, err) + + err = generateRSAPrivateKeyFile(privateKeyFile) + require.NoError(t, err) + + cmd := testCLI.NewDatasetsCmd() + + output := captureLogOutput(func() { + cmd.SetArgs([]string{datasetFile, privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) + }) + + require.Contains(t, output, "Failed to upload dataset due to error") + t.Cleanup(func() { + os.Remove(datasetFile) + os.Remove(privateKeyFile) + }) +} diff --git a/cli/result.go b/cli/result.go index 5fb3ac7b5..728d726e3 100644 --- a/cli/result.go +++ b/cli/result.go @@ -7,6 +7,7 @@ import ( "log" "os" + "github.com/fatih/color" "github.com/spf13/cobra" ) @@ -19,11 +20,13 @@ func (cli *CLI) NewResultsCmd() *cobra.Command { Example: "result ", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { - log.Println("Retrieving computation result file") + log.Println("⏳ Retrieving computation result file") privKeyFile, err := os.ReadFile(args[0]) if err != nil { - log.Fatalf("Error reading private key file: %v", err) + msg := color.New(color.FgRed).Sprintf("Error reading private key file: %v ❌ ", err) + log.Println(msg) + return } pemBlock, _ := pem.Decode(privKeyFile) @@ -33,14 +36,18 @@ func (cli *CLI) NewResultsCmd() *cobra.Command { privKey := decodeKey(pemBlock) result, err = cli.agentSDK.Result(cmd.Context(), privKey) if err != nil { - log.Fatalf("Error retrieving computation result: %v", err) + msg := color.New(color.FgRed).Sprintf("Error retrieving computation result: %v ❌ ", err) + log.Println(msg) + return } if err := os.WriteFile(resultFilePath, result, 0o644); err != nil { - log.Fatalf("Error saving computation result to %s: %v", resultFilePath, err) + msg := color.New(color.FgRed).Sprintf("Error saving computation result to %s: %v ❌ ", resultFilePath, err) + log.Println(msg) + return } - log.Println("Computation result retrieved and saved successfully!") + log.Println(color.New(color.FgGreen).Sprint("Computation result retrieved and saved successfully! ✔ ")) }, } } diff --git a/cli/result_test.go b/cli/result_test.go new file mode 100644 index 000000000..71cccf45d --- /dev/null +++ b/cli/result_test.go @@ -0,0 +1,107 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package cli + +import ( + "errors" + "os" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/ultravioletrs/cocos/pkg/sdk/mocks" +) + +const compResult = "Test computation result" + +func TestResultsCmd_Success(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) + testCLI := New(mockSDK) + + err := generateRSAPrivateKeyFile(privateKeyFile) + require.NoError(t, err) + + cmd := testCLI.NewResultsCmd() + output := captureLogOutput(func() { + cmd.SetArgs([]string{privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) + }) + + require.Contains(t, output, "Computation result retrieved and saved successfully") + + resultFile, err := os.ReadFile("results.zip") + require.NoError(t, err) + require.Equal(t, compResult, string(resultFile)) + + t.Cleanup(func() { + os.Remove("results.zip") + os.Remove(privateKeyFile) + }) +} + +func TestResultsCmd_MissingPrivateKeyFile(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) + testCLI := New(mockSDK) + + cmd := testCLI.NewResultsCmd() + output := captureLogOutput(func() { + cmd.SetArgs([]string{"non_existent_private_key.pem"}) + err := cmd.Execute() + require.NoError(t, err) + }) + + require.Contains(t, output, "Error reading private key file") +} + +func TestResultsCmd_ResultFailure(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Result", mock.Anything, mock.Anything).Return(nil, errors.New("error retrieving computation result")) + testCLI := New(mockSDK) + + err := generateRSAPrivateKeyFile(privateKeyFile) + require.NoError(t, err) + + cmd := testCLI.NewResultsCmd() + output := captureLogOutput(func() { + cmd.SetArgs([]string{privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) + }) + + require.Contains(t, output, "error retrieving computation result") + mockSDK.AssertCalled(t, "Result", mock.Anything, mock.Anything) + t.Cleanup(func() { + os.Remove(privateKeyFile) + }) +} + +func TestResultsCmd_SaveFailure(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) + testCLI := New(mockSDK) + + err := generateRSAPrivateKeyFile(privateKeyFile) + require.NoError(t, err) + + // Simulate failure in saving the result file by making a directory with the same name as the result file + err = os.Mkdir("results.zip", 0o755) + require.NoError(t, err) + + cmd := testCLI.NewResultsCmd() + output := captureLogOutput(func() { + cmd.SetArgs([]string{privateKeyFile}) + err := cmd.Execute() + require.NoError(t, err) + }) + + require.Contains(t, output, "Error saving computation result to results.zip") + mockSDK.AssertCalled(t, "Result", mock.Anything, mock.Anything) + + t.Cleanup(func() { + os.Remove("results.zip") + os.Remove(privateKeyFile) + }) +} diff --git a/cmd/cli/main.go b/cmd/cli/main.go index c5724ded6..15f166f40 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -6,10 +6,13 @@ import ( "fmt" "log" "os" + "os/signal" "path" + "syscall" mglog "github.com/absmach/magistrala/logger" "github.com/caarlos0/env/v11" + "github.com/fatih/color" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/ultravioletrs/cocos/cli" @@ -32,6 +35,16 @@ type config struct { } func main() { + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) + + go func() { + <-signalChan + fmt.Println() + log.Println(color.New(color.FgRed).Sprint("Operation aborted by user!")) + os.Exit(2) + }() + var cfg config if err := env.Parse(&cfg); err != nil { log.Fatalf("failed to load %s configuration : %s", svcName, err) @@ -133,7 +146,14 @@ func main() { backendCmd.AddCommand(cliSVC.NewAddHostDataCmd()) if err := rootCmd.Execute(); err != nil { - logger.Error(fmt.Sprintf("Command execution failed: %s", err)) + logErrorCmd(*rootCmd, err) return } } + +func logErrorCmd(cmd cobra.Command, err error) { + boldRed := color.New(color.FgRed, color.Bold) + boldRed.Fprintf(cmd.ErrOrStderr(), "\nerror: ") + + fmt.Fprintf(cmd.ErrOrStderr(), "%s\n\n", color.RedString(err.Error())) +} diff --git a/go.mod b/go.mod index 0cac06411..45823cab5 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/caarlos0/env/v11 v11.2.2 github.com/cenkalti/backoff/v4 v4.3.0 github.com/digitalocean/go-libvirt v0.0.0-20240709142323-d8406205c752 + github.com/fatih/color v1.17.0 github.com/go-kit/kit v0.13.0 github.com/gofrs/uuid v4.4.0+incompatible github.com/google/go-sev-guest v0.11.1 @@ -31,6 +32,8 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/morikuni/aec v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect diff --git a/go.sum b/go.sum index eddc31559..de8a39697 100644 --- a/go.sum +++ b/go.sum @@ -29,6 +29,8 @@ github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= +github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/go-kit/kit v0.13.0 h1:OoneCcHKHQ03LfBpoQCUfCluwd2Vt3ohz+kvbJneZAU= @@ -69,6 +71,11 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= github.com/mdlayher/vsock v1.2.1 h1:pC1mTJTvjo1r9n9fbm7S1j04rCgCzhCOS5DY0zqHlnQ= @@ -158,6 +165,8 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210426230700-d19ff857e887/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= diff --git a/pkg/progressbar/progress_test.go b/pkg/progressbar/progress_test.go new file mode 100644 index 000000000..daac3cc13 --- /dev/null +++ b/pkg/progressbar/progress_test.go @@ -0,0 +1,132 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package progressbar + +import ( + "bytes" + "io" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRenderProgressBarWithMockedWidth(t *testing.T) { + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + defer func() { + w.Close() + os.Stdout = oldStdout + }() + + pb := &ProgressBar{ + numberOfBytes: 100, + currentUploadedBytes: 0, + TerminalWidthFunc: func() (int, error) { + return 170, nil + }, + } + + err := pb.updateProgress(50) + assert.NoError(t, err) + err = pb.renderProgressBar() + assert.NoError(t, err) + + err = w.Close() + assert.NoError(t, err) + + var buf bytes.Buffer + _, err = io.Copy(&buf, r) + assert.NoError(t, err) + + renderedBar := buf.String() + assert.Contains(t, renderedBar, "[50%]") +} + +func TestClearProgressBar(t *testing.T) { + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + defer func() { + w.Close() + os.Stdout = oldStdout + }() + + pb := &ProgressBar{ + numberOfBytes: 100, + currentUploadedBytes: 0, + maxWidth: 100, + TerminalWidthFunc: func() (int, error) { + return 50, nil + }, + } + + err := pb.updateProgress(50) + assert.NoError(t, err) + err = pb.renderProgressBar() + assert.NoError(t, err) + + err = pb.clearProgressBar() + assert.NoError(t, err) + + w.Close() + + var buf bytes.Buffer + _, err = io.Copy(&buf, r) + assert.NoError(t, err) + + clearedBar := buf.String() + expectedClear := "\r" + strings.Repeat(" ", pb.maxWidth) + "\r" + assert.Contains(t, clearedBar, expectedClear) +} + +func TestReset(t *testing.T) { + pb := &ProgressBar{ + numberOfBytes: 100, + currentUploadedBytes: 7, + maxWidth: 100, + description: "Test Upload", + } + + description := "" + totalBytes := 0 + pb.reset(description, totalBytes) + + assert.Equal(t, 0, pb.currentUploadedBytes) + assert.Equal(t, 0, pb.currentUploadPercentage) + assert.Equal(t, totalBytes, pb.numberOfBytes) + assert.Equal(t, description, pb.description) +} + +func TestUpdateProgress(t *testing.T) { + pb := &ProgressBar{ + numberOfBytes: 100, + currentUploadPercentage: 0, + currentUploadedBytes: 0, + } + + bytesRead := 25 + err := pb.updateProgress(bytesRead) + assert.NoError(t, err) + assert.Equal(t, 25, pb.currentUploadedBytes) + assert.Equal(t, 25, pb.currentUploadPercentage) + + bytesRead = 50 + err = pb.updateProgress(bytesRead) + assert.NoError(t, err) + assert.Equal(t, 75, pb.currentUploadedBytes) + assert.Equal(t, 75, pb.currentUploadPercentage) + + bytesRead = 50 + err = pb.updateProgress(bytesRead) + assert.Error(t, err) + assert.EqualError(t, err, "progress update exceeds total bytes: attempted to add 50 bytes, but only 25 bytes remain") + + // Ensure the progress does not exceed 100% after the error + assert.Equal(t, 75, pb.currentUploadedBytes) + assert.Equal(t, 75, pb.currentUploadPercentage) +} diff --git a/pkg/progressbar/progressbar.go b/pkg/progressbar/progressbar.go index eb23cd81a..68624d8b2 100644 --- a/pkg/progressbar/progressbar.go +++ b/pkg/progressbar/progressbar.go @@ -9,18 +9,15 @@ import ( "os" "strings" + "github.com/fatih/color" "github.com/ultravioletrs/cocos/agent" "golang.org/x/term" ) const ( - progressBarDots = "... " - leftBracket = "[" - rightBracket = "]" - head = ">" - body = "=" - bodyPadding = "." - bufferSize = 1024 * 1024 + leftBracket = "[" + rightBracket = "]" + bufferSize = 1024 * 1024 ) var ( @@ -74,10 +71,13 @@ type ProgressBar struct { currentUploadPercentage int description string maxWidth int + TerminalWidthFunc func() (int, error) } func New() *ProgressBar { - return &ProgressBar{} + return &ProgressBar{ + TerminalWidthFunc: terminalWidth, + } } func (p *ProgressBar) SendAlgorithm(description string, algobuffer, reqBuffer *bytes.Buffer, stream *agent.AgentService_AlgoClient) error { @@ -131,7 +131,10 @@ func (p *ProgressBar) sendData(description string, buffer *bytes.Buffer, stream return err } - p.updateProgress(n) + err = p.updateProgress(n) + if err != nil { + return err + } if err := stream.Send(createRequest(buf[:n])); err != nil { return err @@ -158,7 +161,10 @@ func (p *ProgressBar) sendBuffer(buffer *bytes.Buffer, stream streamSender, crea return err } - p.updateProgress(n) + err = p.updateProgress(n) + if err != nil { + return err + } if err := stream.Send(createRequest(buf[:n])); err != nil { return err @@ -179,22 +185,26 @@ func (p *ProgressBar) reset(description string, totalBytes int) { p.description = description } -func (p *ProgressBar) updateProgress(bytesRead int) { - if p.currentUploadedBytes < p.numberOfBytes { - p.currentUploadedBytes += bytesRead - p.currentUploadPercentage = p.currentUploadedBytes * 100 / p.numberOfBytes +func (p *ProgressBar) updateProgress(bytesRead int) error { + if p.currentUploadedBytes+bytesRead > p.numberOfBytes { + return fmt.Errorf("progress update exceeds total bytes: attempted to add %d bytes, but only %d bytes remain", bytesRead, p.numberOfBytes-p.currentUploadedBytes) } + + p.currentUploadedBytes += bytesRead + p.currentUploadPercentage = p.currentUploadedBytes * 100 / p.numberOfBytes + + return nil } -// Progress bar example: Uploading algorithm... 25% [==> ]. +// Progress bar example: 📦 Uploading algorithm... [█████░░░░░░░░░░░░] [25%]. func (p *ProgressBar) renderProgressBar() error { var builder strings.Builder // Get terminal width. - width, err := terminalWidth() + width, err := p.TerminalWidthFunc() if err != nil { if !warnOnlyOnce { - fmt.Println("Progress bar could not be rendered") + color.Red("Progress bar could not be rendered") warnOnlyOnce = true } return nil @@ -208,28 +218,29 @@ func (p *ProgressBar) renderProgressBar() error { return fmt.Errorf("failed to clear progress bar: %v", err) } - // The progress bar starts with the description. - if _, err := builder.WriteString(p.description); err != nil { - return fmt.Errorf("failed to add description: %v", err) + // Emoji to indicate progress action (📥 for datasets). + emoji := "🚀 " + if strings.Contains(p.description, "data") { + emoji = "📦 " } - - // Add dots to progress bar. - if _, err := builder.WriteString(progressBarDots); err != nil { - return fmt.Errorf("failed to add dots: %v", err) + if _, err := builder.WriteString(color.New(color.FgYellow).Sprint(emoji)); err != nil { + return fmt.Errorf("failed to add emoji: %v", err) } - // Add uploaded percentage. - strCurrentUploadPercentage := fmt.Sprintf("%4d%% ", p.currentUploadPercentage) - if _, err := builder.WriteString(strCurrentUploadPercentage); err != nil { - return fmt.Errorf("failed to add upload percentage bracket: %v", err) + // The progress bar starts with the description. + description := color.New(color.FgYellow).Sprintf("%s ", p.description) + if _, err := builder.WriteString(description); err != nil { + return fmt.Errorf("failed to add description: %v", err) } - // Add letf bracket and space to progress bar. + // Add left bracket (colored). + leftBracket := color.New(color.FgBlue).Sprint(leftBracket) if _, err := builder.WriteString(leftBracket); err != nil { return fmt.Errorf("failed to add left bracket: %v", err) } - progressWidth := width - builder.Len() - len(rightBracket+" ") + // Calculate the progress bar's width. + progressWidth := width - builder.Len() - len(rightBracket+" [100%]") numOfCharactersBody := progressWidth * p.currentUploadPercentage / 100 if numOfCharactersBody == 0 { numOfCharactersBody = 1 @@ -237,33 +248,31 @@ func (p *ProgressBar) renderProgressBar() error { numOfCharactersPadding := progressWidth - numOfCharactersBody - // Add body which represents the percentage. - progress := strings.Repeat(body, numOfCharactersBody-1) - - // Add progress to the progress bar. + // Using unicode block characters for a smooth bar. + progress := color.New(color.FgGreen).Sprint(strings.Repeat("█", numOfCharactersBody)) if _, err := builder.WriteString(progress); err != nil { - return fmt.Errorf("failed to add progress strings to padding: %v", err) + return fmt.Errorf("failed to add progress strings: %v", err) } - // Add head to progress bar. - if _, err := builder.WriteString(head); err != nil { - return fmt.Errorf("failed to add head to padding: %v", err) - } - - // Add padding to end of bar. - padding := strings.Repeat(bodyPadding, numOfCharactersPadding) - - // Add padding to progress bar. + // Add the unfilled part (light blocks as padding). + padding := strings.Repeat("░", numOfCharactersPadding) if _, err := builder.WriteString(padding); err != nil { return fmt.Errorf("failed to add padding: %v", err) } // Add right bracket to progress bar. + rightBracket := color.New(color.FgBlue).Sprint("]") if _, err := builder.WriteString(rightBracket); err != nil { return fmt.Errorf("failed to add right bracket: %v", err) } - // Write progress bar. + // Add the percentage at the end inside square brackets. + strCurrentUploadPercentage := color.New(color.FgGreen).Sprintf(" [%d%%]", p.currentUploadPercentage) + if _, err := builder.WriteString(strCurrentUploadPercentage); err != nil { + return fmt.Errorf("failed to add upload percentage: %v", err) + } + + // Write progress bar to the console. if _, err := io.WriteString(os.Stdout, builder.String()); err != nil { return fmt.Errorf("failed to write string: %v", err) } diff --git a/pkg/sdk/agent.go b/pkg/sdk/agent.go index a1ad99de5..231f239f3 100644 --- a/pkg/sdk/agent.go +++ b/pkg/sdk/agent.go @@ -22,6 +22,7 @@ import ( "google.golang.org/grpc/metadata" ) +//go:generate mockery --name SDK --output=mocks --filename sdk.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0" type SDK interface { Algo(ctx context.Context, algorithm agent.Algorithm, privKey any) error Data(ctx context.Context, dataset agent.Dataset, privKey any) error diff --git a/pkg/sdk/mocks/sdk.go b/pkg/sdk/mocks/sdk.go new file mode 100644 index 000000000..be6efd1a9 --- /dev/null +++ b/pkg/sdk/mocks/sdk.go @@ -0,0 +1,129 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + context "context" + + agent "github.com/ultravioletrs/cocos/agent" + + mock "github.com/stretchr/testify/mock" +) + +// SDK is an autogenerated mock type for the SDK type +type SDK struct { + mock.Mock +} + +// Algo provides a mock function with given fields: ctx, algorithm, privKey +func (_m *SDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey interface{}) error { + ret := _m.Called(ctx, algorithm, privKey) + + if len(ret) == 0 { + panic("no return value specified for Algo") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, agent.Algorithm, interface{}) error); ok { + r0 = rf(ctx, algorithm, privKey) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Attestation provides a mock function with given fields: ctx, reportData +func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte) ([]byte, error) { + ret := _m.Called(ctx, reportData) + + if len(ret) == 0 { + panic("no return value specified for Attestation") + } + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, [64]byte) ([]byte, error)); ok { + return rf(ctx, reportData) + } + if rf, ok := ret.Get(0).(func(context.Context, [64]byte) []byte); ok { + r0 = rf(ctx, reportData) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, [64]byte) error); ok { + r1 = rf(ctx, reportData) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Data provides a mock function with given fields: ctx, dataset, privKey +func (_m *SDK) Data(ctx context.Context, dataset agent.Dataset, privKey interface{}) error { + ret := _m.Called(ctx, dataset, privKey) + + if len(ret) == 0 { + panic("no return value specified for Data") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, agent.Dataset, interface{}) error); ok { + r0 = rf(ctx, dataset, privKey) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Result provides a mock function with given fields: ctx, privKey +func (_m *SDK) Result(ctx context.Context, privKey interface{}) ([]byte, error) { + ret := _m.Called(ctx, privKey) + + if len(ret) == 0 { + panic("no return value specified for Result") + } + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, interface{}) ([]byte, error)); ok { + return rf(ctx, privKey) + } + if rf, ok := ret.Get(0).(func(context.Context, interface{}) []byte); ok { + r0 = rf(ctx, privKey) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, interface{}) error); ok { + r1 = rf(ctx, privKey) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewSDK creates a new instance of SDK. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewSDK(t interface { + mock.TestingT + Cleanup(func()) +}) *SDK { + mock := &SDK{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} From faaddc35712b3b9d2a07fe57e8a7e6084d4daa01 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 2 Oct 2024 15:55:44 +0200 Subject: [PATCH 39/83] NOISSUE - Bump google.golang.org/grpc (#270) Bumps the go-dependency group with 1 update in the / directory: [google.golang.org/grpc](https://github.com/grpc/grpc-go). Updates `google.golang.org/grpc` from 1.66.1 to 1.67.0 - [Release notes](https://github.com/grpc/grpc-go/releases) - [Commits](https://github.com/grpc/grpc-go/compare/v1.66.1...v1.67.0) --- updated-dependencies: - dependency-name: google.golang.org/grpc dependency-type: direct:production update-type: version-update:semver-minor dependency-group: go-dependency ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 45823cab5..e4fbd958c 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( go.opentelemetry.io/otel/trace v1.30.0 golang.org/x/crypto v0.27.0 golang.org/x/sync v0.8.0 - google.golang.org/grpc v1.66.1 + google.golang.org/grpc v1.67.0 google.golang.org/protobuf v1.34.2 ) @@ -79,7 +79,7 @@ require ( golang.org/x/sys v0.25.0 // indirect golang.org/x/term v0.24.0 golang.org/x/text v0.18.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index de8a39697..2126363ea 100644 --- a/go.sum +++ b/go.sum @@ -187,12 +187,12 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 h1:0+ozOGcrp+Y8Aq8TLNN2Aliibms5LEzsq99ZZmAGYm0= -google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094/go.mod h1:fJ/e3If/Q67Mj99hin0hMhiNyCRmt6BQ2aWIJshUSJw= +google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142 h1:wKguEg1hsxI2/L3hUYrpo1RVi48K+uTyzKqprwLXsb8= +google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142/go.mod h1:d6be+8HhtEtucleCbxpPW9PA9XwISACu8nvpPqF0BVo= google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= -google.golang.org/grpc v1.66.1 h1:hO5qAXR19+/Z44hmvIM4dQFMSYX9XcWsByfoxutBpAM= -google.golang.org/grpc v1.66.1/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= +google.golang.org/grpc v1.67.0 h1:IdH9y6PF5MPSdAntIcpjQ+tXO41pcQsfZV2RxtQgVcw= +google.golang.org/grpc v1.67.0/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From f6b69d65df2458137267a3b81c32646358f631fa Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Tue, 8 Oct 2024 16:29:21 +0300 Subject: [PATCH 40/83] NOISSUE - Add agent pkg tests (#271) * add agent tests Signed-off-by: Sammy Oina * fix lint Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- agent/algorithm/binary/binary_test.go | 100 ++++++++++++++ agent/algorithm/docker/docker_test.go | 29 ++++ agent/algorithm/python/python.go | 7 + agent/algorithm/python/python_test.go | 148 ++++++++++++++++++++ agent/algorithm/wasm/wasm_test.go | 89 ++++++++++++ agent/api/grpc/endpoint_test.go | 173 ++++++++++++++++++++++++ agent/api/grpc/server_test.go | 187 ++++++++++++++++++++++++++ agent/computations_test.go | 133 ++++++++++++++++++ agent/events/events_test.go | 82 +++++++++++ 9 files changed, 948 insertions(+) create mode 100644 agent/algorithm/binary/binary_test.go create mode 100644 agent/algorithm/docker/docker_test.go create mode 100644 agent/algorithm/python/python_test.go create mode 100644 agent/algorithm/wasm/wasm_test.go create mode 100644 agent/api/grpc/endpoint_test.go create mode 100644 agent/api/grpc/server_test.go create mode 100644 agent/computations_test.go create mode 100644 agent/events/events_test.go diff --git a/agent/algorithm/binary/binary_test.go b/agent/algorithm/binary/binary_test.go new file mode 100644 index 000000000..e1e793cfa --- /dev/null +++ b/agent/algorithm/binary/binary_test.go @@ -0,0 +1,100 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package binary + +import ( + "bytes" + "log/slog" + "os" + "testing" + + "github.com/ultravioletrs/cocos/agent/algorithm" + "github.com/ultravioletrs/cocos/agent/events/mocks" +) + +func TestNewAlgorithm(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + eventsSvc := new(mocks.Service) + algoFile := "/path/to/algo" + args := []string{"arg1", "arg2"} + + algo := NewAlgorithm(logger, eventsSvc, algoFile, args) + + b, ok := algo.(*binary) + if !ok { + t.Fatalf("NewAlgorithm did not return a *binary") + } + + if b.algoFile != algoFile { + t.Errorf("Expected algoFile to be %s, got %s", algoFile, b.algoFile) + } + + if len(b.args) != len(args) { + t.Errorf("Expected %d args, got %d", len(args), len(b.args)) + } + + for i, arg := range args { + if b.args[i] != arg { + t.Errorf("Expected arg %d to be %s, got %s", i, arg, b.args[i]) + } + } + + if _, ok := b.stderr.(*algorithm.Stderr); !ok { + t.Errorf("Expected stderr to be *algorithm.Stderr") + } + + if _, ok := b.stdout.(*algorithm.Stdout); !ok { + t.Errorf("Expected stdout to be *algorithm.Stdout") + } +} + +func TestBinaryRun(t *testing.T) { + tests := []struct { + name string + algoFile string + args []string + expectedError bool + }{ + { + name: "Successful execution", + algoFile: "echo", + args: []string{"Hello, World!"}, + expectedError: false, + }, + { + name: "Non-existent binary", + algoFile: "non_existent_binary", + args: []string{}, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + eventsSvc := new(mocks.Service) + + b := NewAlgorithm(logger, eventsSvc, tt.algoFile, tt.args).(*binary) + + var stdout, stderr bytes.Buffer + b.stdout = &stdout + b.stderr = &stderr + + err := b.Run() + + if tt.expectedError && err == nil { + t.Errorf("Expected an error, but got none") + } + + if !tt.expectedError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if !tt.expectedError { + if stdout.Len() == 0 { + t.Errorf("Expected non-empty stdout") + } + } + }) + } +} diff --git a/agent/algorithm/docker/docker_test.go b/agent/algorithm/docker/docker_test.go new file mode 100644 index 000000000..4b89421c5 --- /dev/null +++ b/agent/algorithm/docker/docker_test.go @@ -0,0 +1,29 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package docker + +import ( + "log/slog" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/ultravioletrs/cocos/agent/algorithm" + "github.com/ultravioletrs/cocos/agent/events/mocks" +) + +// TestNewAlgorithm tests the NewAlgorithm function. +func TestNewAlgorithm(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + eventsSvc := new(mocks.Service) + algoFile := "/path/to/algo.tar" + + algo := NewAlgorithm(logger, eventsSvc, algoFile) + + d, ok := algo.(*docker) + assert.True(t, ok, "NewAlgorithm should return a *docker") + assert.Equal(t, algoFile, d.algoFile, "algoFile should be set correctly") + assert.NotNil(t, d.logger, "logger should be set") + assert.IsType(t, &algorithm.Stderr{}, d.stderr, "stderr should be of type *algorithm.Stderr") + assert.IsType(t, &algorithm.Stdout{}, d.stdout, "stdout should be of type *algorithm.Stdout") +} diff --git a/agent/algorithm/python/python.go b/agent/algorithm/python/python.go index 206133354..a5c29495d 100644 --- a/agent/algorithm/python/python.go +++ b/agent/algorithm/python/python.go @@ -67,6 +67,13 @@ func (p *python) Run() error { pythonPath := filepath.Join(venvPath, "bin", "python") + updatePipCmd := exec.Command(pythonPath, "-m", "pip", "install", "--upgrade", "pip") + updatePipCmd.Stderr = p.stderr + updatePipCmd.Stdout = p.stdout + if err := updatePipCmd.Run(); err != nil { + return fmt.Errorf("error updating pip: %v", err) + } + if p.requirementsFile != "" { rcmd := exec.Command(pythonPath, "-m", "pip", "install", "-r", p.requirementsFile) rcmd.Stderr = p.stderr diff --git a/agent/algorithm/python/python_test.go b/agent/algorithm/python/python_test.go new file mode 100644 index 000000000..9db3ce2bc --- /dev/null +++ b/agent/algorithm/python/python_test.go @@ -0,0 +1,148 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package python + +import ( + "bytes" + "context" + "io" + "log/slog" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/ultravioletrs/cocos/agent/algorithm" + "github.com/ultravioletrs/cocos/agent/events/mocks" + "google.golang.org/grpc/metadata" +) + +const runtime = "python3" + +func TestPythonRunTimeToContext(t *testing.T) { + ctx := context.Background() + newCtx := PythonRunTimeToContext(ctx, runtime) + + md, ok := metadata.FromOutgoingContext(newCtx) + if !ok { + t.Fatal("Expected metadata in context") + } + + values := md.Get(PyRuntimeKey) + if len(values) != 1 || values[0] != runtime { + t.Errorf("Expected runtime %s, got %v", runtime, values) + } +} + +func TestPythonRunTimeFromContext(t *testing.T) { + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(PyRuntimeKey, runtime)) + + got := PythonRunTimeFromContext(ctx) + if got != runtime { + t.Errorf("Expected runtime %s, got %s", runtime, got) + } +} + +func TestNewAlgorithm(t *testing.T) { + logger := &slog.Logger{} + eventsSvc := new(mocks.Service) + requirementsFile := "requirements.txt" + algoFile := "algorithm.py" + args := []string{"--arg1", "value1"} + + algo := NewAlgorithm(logger, eventsSvc, runtime, requirementsFile, algoFile, args) + + p, ok := algo.(*python) + if !ok { + t.Fatal("Expected *python type") + } + + if p.runtime != runtime { + t.Errorf("Expected runtime %s, got %s", runtime, p.runtime) + } + if p.requirementsFile != requirementsFile { + t.Errorf("Expected requirementsFile %s, got %s", requirementsFile, p.requirementsFile) + } + if p.algoFile != algoFile { + t.Errorf("Expected algoFile %s, got %s", algoFile, p.algoFile) + } + if len(p.args) != len(args) { + t.Errorf("Expected %d args, got %d", len(args), len(p.args)) + } +} + +func TestRun(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "python-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + scriptContent := []byte("print('Hello, World!')") + scriptPath := filepath.Join(tmpDir, "test_script.py") + if err := os.WriteFile(scriptPath, scriptContent, 0o644); err != nil { + t.Fatal(err) + } + + eventsSvc := new(mocks.Service) + + var stdout, stderr bytes.Buffer + + algo := &python{ + algoFile: scriptPath, + stderr: io.MultiWriter(&stderr, &algorithm.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}), + stdout: io.MultiWriter(&stdout, &algorithm.Stdout{Logger: slog.Default()}), + runtime: "python3", + } + + err = algo.Run() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + expectedOutput := "Hello, World!\n" + if !strings.Contains(stdout.String(), expectedOutput) { + t.Errorf("Expected output to contain %q, got %q", expectedOutput, stdout.String()) + } +} + +func TestRunWithRequirements(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "python-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + scriptContent := []byte("import requests\nprint(requests.__version__)") + scriptPath := filepath.Join(tmpDir, "test_script.py") + if err := os.WriteFile(scriptPath, scriptContent, 0o644); err != nil { + t.Fatal(err) + } + + requirementsContent := []byte("requests==2.26.0") + requirementsPath := filepath.Join(tmpDir, "requirements.txt") + if err := os.WriteFile(requirementsPath, requirementsContent, 0o644); err != nil { + t.Fatal(err) + } + + eventsSvc := new(mocks.Service) + + var stdout, stderr bytes.Buffer + + algo := &python{ + algoFile: scriptPath, + requirementsFile: requirementsPath, + stderr: io.MultiWriter(&stderr, &algorithm.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}), + stdout: io.MultiWriter(&stdout, &algorithm.Stdout{Logger: slog.Default()}), + runtime: "python3", + } + + err = algo.Run() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !strings.Contains(stdout.String(), "2.26.0") { + t.Errorf("Expected output to contain requests version 2.26.0, got %q", stdout.String()) + } +} diff --git a/agent/algorithm/wasm/wasm_test.go b/agent/algorithm/wasm/wasm_test.go new file mode 100644 index 000000000..d89927261 --- /dev/null +++ b/agent/algorithm/wasm/wasm_test.go @@ -0,0 +1,89 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package wasm + +import ( + "log/slog" + "os" + "os/exec" + "testing" + + "github.com/ultravioletrs/cocos/agent/algorithm" + "github.com/ultravioletrs/cocos/agent/events/mocks" +) + +func TestNewAlgorithm(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + eventsSvc := new(mocks.Service) + algoFile := "test.wasm" + args := []string{"arg1", "arg2"} + + algo := NewAlgorithm(logger, eventsSvc, algoFile, args) + + w, ok := algo.(*wasm) + if !ok { + t.Fatalf("NewAlgorithm did not return a *wasm") + } + + if w.algoFile != algoFile { + t.Errorf("Expected algoFile to be %s, got %s", algoFile, w.algoFile) + } + + if len(w.args) != len(args) { + t.Errorf("Expected %d args, got %d", len(args), len(w.args)) + } + + _, ok = w.stderr.(*algorithm.Stderr) + if !ok { + t.Errorf("Expected stderr to be *algorithm.Stderr") + } + + _, ok = w.stdout.(*algorithm.Stdout) + if !ok { + t.Errorf("Expected stdout to be *algorithm.Stdout") + } +} + +func TestRunError(t *testing.T) { + // Mock exec.Command to return an error + execCommand = mockExecCommandError + defer func() { execCommand = exec.Command }() + + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + eventsSvc := new(mocks.Service) + algoFile := "test.wasm" + args := []string{"arg1", "arg2"} + + w := NewAlgorithm(logger, eventsSvc, algoFile, args).(*wasm) + + err := w.Run() + if err == nil { + t.Errorf("Run() should have returned an error") + } +} + +func mockExecCommand(command string, args ...string) *exec.Cmd { + cs := []string{"-test.run=TestHelperProcess", "--", command} + cs = append(cs, args...) + cmd := exec.Command(os.Args[0], cs...) + cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"} + return cmd +} + +func mockExecCommandError(command string, args ...string) *exec.Cmd { + cmd := mockExecCommand(command, args...) + cmd.Env = append(cmd.Env, "GO_WANT_HELPER_PROCESS_ERROR=1") + return cmd +} + +func TestHelperProcess(t *testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + if os.Getenv("GO_WANT_HELPER_PROCESS_ERROR") == "1" { + os.Exit(1) + } + os.Exit(0) +} + +var execCommand = exec.Command diff --git a/agent/api/grpc/endpoint_test.go b/agent/api/grpc/endpoint_test.go new file mode 100644 index 000000000..4ab3494de --- /dev/null +++ b/agent/api/grpc/endpoint_test.go @@ -0,0 +1,173 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package grpc + +import ( + "context" + "errors" + "testing" + + "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/agent/mocks" + "golang.org/x/crypto/sha3" +) + +const svcErr = "Service Error" + +func TestAlgoEndpoint(t *testing.T) { + svc := new(mocks.Service) + tests := []struct { + name string + req algoReq + expectedErr bool + }{ + { + name: "Success", + req: algoReq{Algorithm: []byte("algorithm")}, + }, + { + name: "Validation Error", + req: algoReq{}, + expectedErr: true, + }, + { + name: "Service Error", + req: algoReq{Algorithm: []byte("algorithm")}, + expectedErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.name == svcErr { + svc.On("Algo", context.Background(), agent.Algorithm{Algorithm: tt.req.Algorithm}).Return(errors.New("")).Once() + } else { + svc.On("Algo", context.Background(), agent.Algorithm{Algorithm: tt.req.Algorithm}).Return(nil).Once() + } + endpoint := algoEndpoint(svc) + _, err := endpoint(context.Background(), tt.req) + if (err != nil) != tt.expectedErr { + t.Errorf("algoEndpoint() error = %v, expectedErr %v", err, tt.expectedErr) + } + }) + } +} + +func TestDataEndpoint(t *testing.T) { + svc := new(mocks.Service) + tests := []struct { + name string + req dataReq + expectedErr bool + }{ + { + name: "Success", + req: dataReq{Dataset: []byte("dataset")}, + }, + { + name: "Validation Error", + req: dataReq{}, + expectedErr: true, + }, + { + name: "Service Error", + req: dataReq{Dataset: []byte("dataset")}, + expectedErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.name == svcErr { + svc.On("Data", context.Background(), agent.Dataset{Dataset: tt.req.Dataset}).Return(errors.New("")).Once() + } else { + svc.On("Data", context.Background(), agent.Dataset{Dataset: tt.req.Dataset}).Return(nil).Once() + } + endpoint := dataEndpoint(svc) + _, err := endpoint(context.Background(), tt.req) + if (err != nil) != tt.expectedErr { + t.Errorf("dataEndpoint() error = %v, expectedErr %v", err, tt.expectedErr) + } + }) + } +} + +func TestResultEndpoint(t *testing.T) { + svc := new(mocks.Service) + tests := []struct { + name string + req resultReq + expectedErr bool + }{ + { + name: "Success", + req: resultReq{}, + }, + { + name: "Service Error", + req: resultReq{}, + expectedErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.name == svcErr { + svc.On("Result", context.Background()).Return([]byte{}, errors.New("")).Once() + } else { + svc.On("Result", context.Background()).Return([]byte{}, nil).Once() + } + endpoint := resultEndpoint(svc) + res, err := endpoint(context.Background(), tt.req) + if (err != nil) != tt.expectedErr { + t.Errorf("resultEndpoint() error = %v, expectedErr %v", err, tt.expectedErr) + } + if err == nil { + _, ok := res.(resultRes) + if !ok { + t.Errorf("resultEndpoint() returned unexpected type %T", res) + } + } + }) + } +} + +func TestAttestationEndpoint(t *testing.T) { + svc := new(mocks.Service) + tests := []struct { + name string + req attestationReq + expectedErr bool + }{ + { + name: "Success", + req: attestationReq{ReportData: sha3.Sum512([]byte("report data"))}, + }, + { + name: "Service Error", + req: attestationReq{ReportData: sha3.Sum512([]byte("report data"))}, + expectedErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.name == svcErr { + svc.On("Attestation", context.Background(), tt.req.ReportData).Return([]byte{}, errors.New("")).Once() + } else { + svc.On("Attestation", context.Background(), tt.req.ReportData).Return([]byte{}, nil).Once() + } + endpoint := attestationEndpoint(svc) + res, err := endpoint(context.Background(), tt.req) + if (err != nil) != tt.expectedErr { + t.Errorf("attestationEndpoint() error = %v, expectedErr %v", err, tt.expectedErr) + } + if err == nil { + _, ok := res.(attestationRes) + if !ok { + t.Errorf("attestationEndpoint() returned unexpected type %T", res) + } + } + }) + } +} diff --git a/agent/api/grpc/server_test.go b/agent/api/grpc/server_test.go new file mode 100644 index 000000000..4f6398a6c --- /dev/null +++ b/agent/api/grpc/server_test.go @@ -0,0 +1,187 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package grpc + +import ( + "context" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/agent/mocks" + "google.golang.org/grpc" +) + +type MockAgentService_AlgoServer struct { + grpc.ServerStream + mock.Mock + ctx context.Context +} + +func (m *MockAgentService_AlgoServer) Context() context.Context { + return m.ctx +} + +func (m *MockAgentService_AlgoServer) Recv() (*agent.AlgoRequest, error) { + args := m.Called() + return args.Get(0).(*agent.AlgoRequest), args.Error(1) +} + +func (m *MockAgentService_AlgoServer) SendAndClose(resp *agent.AlgoResponse) error { + args := m.Called(resp) + return args.Error(0) +} + +type MockAgentService_DataServer struct { + grpc.ServerStream + mock.Mock + ctx context.Context +} + +func (m *MockAgentService_DataServer) Context() context.Context { + return m.ctx +} + +func (m *MockAgentService_DataServer) Recv() (*agent.DataRequest, error) { + args := m.Called() + return args.Get(0).(*agent.DataRequest), args.Error(1) +} + +func (m *MockAgentService_DataServer) SendAndClose(resp *agent.DataResponse) error { + args := m.Called(resp) + return args.Error(0) +} + +type MockAgentService_ResultServer struct { + grpc.ServerStream + mock.Mock + ctx context.Context +} + +func (m *MockAgentService_ResultServer) Context() context.Context { + return m.ctx +} + +func (m *MockAgentService_ResultServer) Send(resp *agent.ResultResponse) error { + args := m.Called(resp) + return args.Error(0) +} + +func TestAlgo(t *testing.T) { + mockService := new(mocks.Service) + server := NewServer(mockService) + + mockStream := &MockAgentService_AlgoServer{ctx: context.Background()} + mockStream.On("Recv").Return(&agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")}, nil).Once() + mockStream.On("Recv").Return(&agent.AlgoRequest{}, io.EOF) + mockStream.On("SendAndClose", &agent.AlgoResponse{}).Return(nil) + + mockService.On("Algo", context.Background(), agent.Algorithm{Algorithm: []byte("algo"), Requirements: []byte("req")}).Return(nil) + + err := server.Algo(mockStream) + assert.NoError(t, err) + + mockStream.AssertExpectations(t) + mockService.AssertExpectations(t) +} + +func TestData(t *testing.T) { + mockService := new(mocks.Service) + server := NewServer(mockService) + + mockStream := &MockAgentService_DataServer{ctx: context.Background()} + mockStream.On("Recv").Return(&agent.DataRequest{Dataset: []byte("data"), Filename: "test.txt"}, nil).Once() + mockStream.On("Recv").Return(&agent.DataRequest{}, io.EOF) + mockStream.On("SendAndClose", &agent.DataResponse{}).Return(nil) + + mockService.On("Data", context.Background(), agent.Dataset{Dataset: []byte("data"), Filename: "test.txt"}).Return(nil) + + err := server.Data(mockStream) + assert.NoError(t, err) + + mockStream.AssertExpectations(t) + mockService.AssertExpectations(t) +} + +func TestResult(t *testing.T) { + mockService := new(mocks.Service) + server := NewServer(mockService) + + mockStream := &MockAgentService_ResultServer{ctx: context.Background()} + mockService.On("Result", mock.Anything).Return([]byte("result data"), nil) + mockStream.On("Send", mock.AnythingOfType("*agent.ResultResponse")).Return(nil) + + err := server.Result(&agent.ResultRequest{}, mockStream) + assert.NoError(t, err) + + mockStream.AssertExpectations(t) + mockService.AssertExpectations(t) +} + +func TestAttestation(t *testing.T) { + mockService := new(mocks.Service) + server := NewServer(mockService) + + reportData := [agent.ReportDataSize]byte{} + mockService.On("Attestation", mock.Anything, reportData).Return([]byte("attestation data"), nil) + + resp, err := server.Attestation(context.Background(), &agent.AttestationRequest{ReportData: reportData[:]}) + assert.NoError(t, err) + assert.Equal(t, []byte("attestation data"), resp.File) + + mockService.AssertExpectations(t) +} + +func TestDecodeAlgoRequest(t *testing.T) { + req := &agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")} + decoded, err := decodeAlgoRequest(context.Background(), req) + assert.NoError(t, err) + assert.Equal(t, algoReq{Algorithm: []byte("algo"), Requirements: []byte("req")}, decoded) +} + +func TestEncodeAlgoResponse(t *testing.T) { + encoded, err := encodeAlgoResponse(context.Background(), algoRes{}) + assert.NoError(t, err) + assert.Equal(t, &agent.AlgoResponse{}, encoded) +} + +func TestDecodeDataRequest(t *testing.T) { + req := &agent.DataRequest{Dataset: []byte("data"), Filename: "test.txt"} + decoded, err := decodeDataRequest(context.Background(), req) + assert.NoError(t, err) + assert.Equal(t, dataReq{Dataset: []byte("data"), Filename: "test.txt"}, decoded) +} + +func TestEncodeDataResponse(t *testing.T) { + encoded, err := encodeDataResponse(context.Background(), dataRes{}) + assert.NoError(t, err) + assert.Equal(t, &agent.DataResponse{}, encoded) +} + +func TestDecodeResultRequest(t *testing.T) { + decoded, err := decodeResultRequest(context.Background(), &agent.ResultRequest{}) + assert.NoError(t, err) + assert.Equal(t, resultReq{}, decoded) +} + +func TestEncodeResultResponse(t *testing.T) { + encoded, err := encodeResultResponse(context.Background(), resultRes{File: []byte("result")}) + assert.NoError(t, err) + assert.Equal(t, &agent.ResultResponse{File: []byte("result")}, encoded) +} + +func TestDecodeAttestationRequest(t *testing.T) { + reportData := [agent.ReportDataSize]byte{} + req := &agent.AttestationRequest{ReportData: reportData[:]} + decoded, err := decodeAttestationRequest(context.Background(), req) + assert.NoError(t, err) + assert.Equal(t, attestationReq{ReportData: reportData}, decoded) +} + +func TestEncodeAttestationResponse(t *testing.T) { + encoded, err := encodeAttestationResponse(context.Background(), attestationRes{File: []byte("attestation")}) + assert.NoError(t, err) + assert.Equal(t, &agent.AttestationResponse{File: []byte("attestation")}, encoded) +} diff --git a/agent/computations_test.go b/agent/computations_test.go new file mode 100644 index 000000000..f2edf8e91 --- /dev/null +++ b/agent/computations_test.go @@ -0,0 +1,133 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package agent + +import ( + "context" + "encoding/json" + "reflect" + "testing" + + "google.golang.org/grpc/metadata" +) + +func TestDatasetsString(t *testing.T) { + datasets := Datasets{ + { + Hash: [32]byte{1, 2, 3}, + UserKey: []byte("user_key"), + Filename: "test.dat", + }, + } + + expected := `[{"hash":[1,2,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"user_key":"dXNlcl9rZXk=","filename":"test.dat"}]` + result := datasets.String() + + if result != expected { + t.Errorf("Datasets.String() = %v, want %v", result, expected) + } +} + +func TestIndexToContext(t *testing.T) { + ctx := context.Background() + index := 5 + + newCtx := IndexToContext(ctx, index) + result, ok := IndexFromContext(newCtx) + + if !ok { + t.Errorf("IndexFromContext() ok = false, want true") + } + + if result != index { + t.Errorf("IndexFromContext() = %v, want %v", result, index) + } +} + +func TestDecompressFromContext(t *testing.T) { + tests := []struct { + name string + ctx context.Context + expected bool + }{ + { + name: "No decompress metadata", + ctx: context.Background(), + expected: false, + }, + { + name: "Decompress true", + ctx: metadata.NewIncomingContext( + context.Background(), + metadata.Pairs(DecompressKey, "true"), + ), + expected: true, + }, + { + name: "Decompress false", + ctx: metadata.NewIncomingContext( + context.Background(), + metadata.Pairs(DecompressKey, "false"), + ), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := DecompressFromContext(tt.ctx) + if result != tt.expected { + t.Errorf("DecompressFromContext() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestDecompressToContext(t *testing.T) { + ctx := context.Background() + decompress := true + + newCtx := DecompressToContext(ctx, decompress) + md, ok := metadata.FromOutgoingContext(newCtx) + + if !ok { + t.Errorf("metadata.FromOutgoingContext() ok = false, want true") + } + + vals := md.Get(DecompressKey) + if len(vals) != 1 { + t.Errorf("len(md.Get(DecompressKey)) = %v, want 1", len(vals)) + } + + if vals[0] != "true" { + t.Errorf("md.Get(DecompressKey)[0] = %v, want 'true'", vals[0]) + } +} + +func TestAgentConfigJSON(t *testing.T) { + config := AgentConfig{ + LogLevel: "info", + Host: "localhost", + Port: "8080", + CertFile: "cert.pem", + KeyFile: "key.pem", + ServerCAFile: "server_ca.pem", + ClientCAFile: "client_ca.pem", + AttestedTls: true, + } + + data, err := json.Marshal(config) + if err != nil { + t.Fatalf("Failed to marshal AgentConfig: %v", err) + } + + var unmarshaledConfig AgentConfig + err = json.Unmarshal(data, &unmarshaledConfig) + if err != nil { + t.Fatalf("Failed to unmarshal AgentConfig: %v", err) + } + + if !reflect.DeepEqual(config, unmarshaledConfig) { + t.Errorf("Unmarshaled config does not match original. Got %+v, want %+v", unmarshaledConfig, config) + } +} diff --git a/agent/events/events_test.go b/agent/events/events_test.go new file mode 100644 index 000000000..f2f2d1b4f --- /dev/null +++ b/agent/events/events_test.go @@ -0,0 +1,82 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package events + +import ( + "bytes" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/ultravioletrs/cocos/pkg/manager" + "google.golang.org/protobuf/proto" +) + +type mockConn struct { + writeErr error + buf bytes.Buffer +} + +func (m *mockConn) Write(p []byte) (n int, err error) { + if m.writeErr != nil { + return 0, m.writeErr + } + return m.buf.Write(p) +} + +func TestSendEventSuccess(t *testing.T) { + mockConnection := &mockConn{} + + svc, err := New("test_service", "12345", mockConnection) + assert.NoError(t, err) + + details := json.RawMessage(`{"key": "value"}`) + + err = svc.SendEvent("test_event", "success", details) + assert.NoError(t, err) + + var writtenMessage manager.ClientStreamMessage + err = proto.Unmarshal(mockConnection.buf.Bytes(), &writtenMessage) + assert.NoError(t, err) + + assert.Equal(t, "test_event", writtenMessage.GetAgentEvent().EventType) + assert.Equal(t, "12345", writtenMessage.GetAgentEvent().ComputationId) + assert.Equal(t, "test_service", writtenMessage.GetAgentEvent().Originator) + assert.Equal(t, "success", writtenMessage.GetAgentEvent().Status) + + now := time.Now() + eventTimestamp := writtenMessage.GetAgentEvent().GetTimestamp().AsTime() + assert.WithinDuration(t, now, eventTimestamp, 1*time.Second) +} + +func TestSendEventFailure(t *testing.T) { + mockConnection := &mockConn{writeErr: errors.New("write error")} + + svc, err := New("test_service", "12345", mockConnection) + assert.NoError(t, err) + + details := json.RawMessage(`{"key": "value"}`) + + err = svc.SendEvent("test_event", "failure", details) + assert.Error(t, err) + assert.Equal(t, "write error", err.Error()) + + assert.Len(t, svc.(*service).cachedMessages, 1) +} + +func TestClose(t *testing.T) { + mockConnection := &mockConn{} + + svc, err := New("test_service", "12345", mockConnection) + assert.NoError(t, err) + + svc.Close() + + time.Sleep(1 * time.Second) + + details := json.RawMessage(`{"key": "value"}`) + err = svc.SendEvent("test_event", "success", details) + assert.NoError(t, err) +} From 5e01ecdab752cb6cb3c103e1db51719ba7f4e622 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Tue, 8 Oct 2024 16:35:17 +0300 Subject: [PATCH 41/83] add manager tests (#273) Signed-off-by: Sammy Oina --- manager/agentEventsLogs_test.go | 168 ++++++++++++++++++++++ manager/api/grpc/client_test.go | 181 ++++++++++++++++++++++++ manager/api/grpc/server.go | 27 ++-- manager/api/grpc/server_test.go | 232 +++++++++++++++++++++++++++++++ manager/mocks/service.go | 114 +++++++++++++++ manager/qemu/config_test.go | 232 +++++++++++++++++++++++++++++++ manager/qemu/persistence_test.go | 144 +++++++++++++++++++ manager/qemu/vm_test.go | 111 +++++++++++++-- manager/service.go | 2 + manager/service_test.go | 111 +++++++++++++++ 10 files changed, 1301 insertions(+), 21 deletions(-) create mode 100644 manager/agentEventsLogs_test.go create mode 100644 manager/api/grpc/client_test.go create mode 100644 manager/api/grpc/server_test.go create mode 100644 manager/mocks/service.go create mode 100644 manager/qemu/config_test.go create mode 100644 manager/qemu/persistence_test.go diff --git a/manager/agentEventsLogs_test.go b/manager/agentEventsLogs_test.go new file mode 100644 index 000000000..dac0f82f4 --- /dev/null +++ b/manager/agentEventsLogs_test.go @@ -0,0 +1,168 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package manager + +import ( + "net" + "testing" + "time" + + mglog "github.com/absmach/magistrala/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/ultravioletrs/cocos/manager/qemu" + "github.com/ultravioletrs/cocos/manager/vm" + "github.com/ultravioletrs/cocos/pkg/manager" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type MockConn struct { + mock.Mock +} + +func (m *MockConn) Read(b []byte) (n int, err error) { + args := m.Called(b) + return args.Int(0), args.Error(1) +} + +func (m *MockConn) Write(b []byte) (n int, err error) { + args := m.Called(b) + return args.Int(0), args.Error(1) +} + +func (m *MockConn) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockConn) LocalAddr() net.Addr { + args := m.Called() + return args.Get(0).(net.Addr) +} + +func (m *MockConn) RemoteAddr() net.Addr { + args := m.Called() + return args.Get(0).(net.Addr) +} + +func (m *MockConn) SetDeadline(t time.Time) error { + args := m.Called(t) + return args.Error(0) +} + +func (m *MockConn) SetReadDeadline(t time.Time) error { + args := m.Called(t) + return args.Error(0) +} + +func (m *MockConn) SetWriteDeadline(t time.Time) error { + args := m.Called(t) + return args.Error(0) +} + +type MockAddr struct { + mock.Mock +} + +func (m *MockAddr) Network() string { + args := m.Called() + return args.String(0) +} + +func (m *MockAddr) String() string { + args := m.Called() + return args.String(0) +} + +func TestComputationIDFromAddress(t *testing.T) { + ms := &managerService{ + vms: map[string]vm.VM{ + "comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, make(chan *manager.ClientStreamMessage), "comp1"), + "comp2": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 5}}, make(chan *manager.ClientStreamMessage), "comp2"), + }, + } + + tests := []struct { + name string + address string + want string + wantErr bool + }{ + {"Valid address", "vm(3)", "comp1", false}, + {"Invalid address", "invalid", "", true}, + {"Non-existent CID", "vm(10)", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ms.computationIDFromAddress(tt.address) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestHandleConnection(t *testing.T) { + ms := &managerService{ + vms: map[string]vm.VM{ + "comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, make(chan *manager.ClientStreamMessage), "comp1"), + }, + eventsChan: make(chan *manager.ClientStreamMessage, 1), + logger: mglog.NewMock(), + } + + mockConn := new(MockConn) + mockAddr := new(MockAddr) + mockConn.On("RemoteAddr").Return(mockAddr) + mockConn.On("Close").Return(nil) + mockAddr.On("String").Return("vm(3)") + + msg := &manager.ClientStreamMessage{ + Message: &manager.ClientStreamMessage_AgentEvent{ + AgentEvent: &manager.AgentEvent{ + EventType: manager.VmRunning.String(), + ComputationId: "comp1", + Status: manager.VmRunning.String(), + Timestamp: timestamppb.Now(), + Originator: "agent", + }, + }, + } + msgBytes, _ := proto.Marshal(msg) + + mockConn.On("Read", mock.Anything).Return(len(msgBytes), nil).Run(func(args mock.Arguments) { + copy(args.Get(0).([]byte), msgBytes) + }).Once() + + mockConn.On("Read", mock.Anything).Return(0, net.ErrClosed) + + go ms.handleConnection(mockConn) + + receivedMsg := <-ms.eventsChan + assert.Equal(t, msg.GetAgentEvent().EventType, receivedMsg.GetAgentEvent().EventType) + assert.Equal(t, msg.GetAgentEvent().ComputationId, receivedMsg.GetAgentEvent().ComputationId) + + mockConn.AssertExpectations(t) +} + +func TestReportBrokenConnection(t *testing.T) { + ms := &managerService{ + eventsChan: make(chan *manager.ClientStreamMessage, 1), + } + + ms.reportBrokenConnection("comp1") + + select { + case msg := <-ms.eventsChan: + assert.Equal(t, "comp1", msg.GetAgentEvent().ComputationId) + assert.Equal(t, manager.Disconnected.String(), msg.GetAgentEvent().Status) + assert.Equal(t, "manager", msg.GetAgentEvent().Originator) + default: + t.Error("Expected message in eventsChan, but none received") + } +} diff --git a/manager/api/grpc/client_test.go b/manager/api/grpc/client_test.go new file mode 100644 index 000000000..353f2d21c --- /dev/null +++ b/manager/api/grpc/client_test.go @@ -0,0 +1,181 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package grpc + +import ( + "context" + "testing" + "time" + + mglog "github.com/absmach/magistrala/logger" + "github.com/absmach/magistrala/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/ultravioletrs/cocos/manager/mocks" + pkgmanager "github.com/ultravioletrs/cocos/pkg/manager" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" +) + +type mockStream struct { + mock.Mock + grpc.ClientStream +} + +func (m *mockStream) Recv() (*pkgmanager.ServerStreamMessage, error) { + args := m.Called() + return args.Get(0).(*pkgmanager.ServerStreamMessage), args.Error(1) +} + +func (m *mockStream) Send(msg *pkgmanager.ClientStreamMessage) error { + args := m.Called(msg) + return args.Error(0) +} + +func TestManagerClient_Process(t *testing.T) { + mockStream := new(mockStream) + mockSvc := new(mocks.Service) + messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10) + logger := mglog.NewMock() + + client := NewClient(mockStream, mockSvc, messageQueue, logger) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + mockStream.On("Recv").Return(&pkgmanager.ServerStreamMessage{Message: &pkgmanager.ServerStreamMessage_StopComputation{StopComputation: &pkgmanager.StopComputation{}}}, nil).Maybe() + mockStream.On("Send", mock.Anything).Return(nil).Maybe() + + mockSvc.On("Stop", mock.Anything, mock.Anything).Return(nil).Maybe() + + err := client.Process(ctx, cancel) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") +} + +func TestManagerClient_handleRunReqChunks(t *testing.T) { + mockStream := new(mockStream) + mockSvc := new(mocks.Service) + messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10) + logger := mglog.NewMock() + + client := NewClient(mockStream, mockSvc, messageQueue, logger) + + runReq := &pkgmanager.ComputationRunReq{ + Id: "test-id", + } + runReqBytes, _ := proto.Marshal(runReq) + + chunk1 := &pkgmanager.ServerStreamMessage_RunReqChunks{ + RunReqChunks: &pkgmanager.RunReqChunks{ + Id: "chunk-1", + Data: runReqBytes[:len(runReqBytes)/2], + IsLast: false, + }, + } + chunk2 := &pkgmanager.ServerStreamMessage_RunReqChunks{ + RunReqChunks: &pkgmanager.RunReqChunks{ + Id: "chunk-1", + Data: runReqBytes[len(runReqBytes)/2:], + IsLast: true, + }, + } + + mockSvc.On("Run", mock.Anything, mock.AnythingOfType("*manager.ComputationRunReq")).Return("8080", nil) + + err := client.handleRunReqChunks(context.Background(), chunk1) + assert.NoError(t, err) + + err = client.handleRunReqChunks(context.Background(), chunk2) + assert.NoError(t, err) + + // Wait for the goroutine to finish + time.Sleep(50 * time.Millisecond) + + mockSvc.AssertExpectations(t) + assert.Len(t, messageQueue, 1) + + msg := <-messageQueue + runRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_RunRes) + assert.True(t, ok) + assert.Equal(t, "8080", runRes.RunRes.AgentPort) + assert.Equal(t, "test-id", runRes.RunRes.ComputationId) +} + +func TestManagerClient_handleTerminateReq(t *testing.T) { + client := ManagerClient{} + + terminateReq := &pkgmanager.ServerStreamMessage_TerminateReq{ + TerminateReq: &pkgmanager.Terminate{ + Message: "Test termination", + }, + } + + err := client.handleTerminateReq(terminateReq) + assert.Error(t, err) + assert.Contains(t, err.Error(), "Test termination") + assert.True(t, errors.Contains(err, errTerminationFromServer)) +} + +func TestManagerClient_handleStopComputation(t *testing.T) { + mockStream := new(mockStream) + mockSvc := new(mocks.Service) + messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10) + logger := mglog.NewMock() + + client := NewClient(mockStream, mockSvc, messageQueue, logger) + + stopReq := &pkgmanager.ServerStreamMessage_StopComputation{ + StopComputation: &pkgmanager.StopComputation{ + ComputationId: "test-comp-id", + }, + } + + mockSvc.On("Stop", mock.Anything, "test-comp-id").Return(nil) + + client.handleStopComputation(context.Background(), stopReq) + + // Wait for the goroutine to finish + time.Sleep(50 * time.Millisecond) + + mockSvc.AssertExpectations(t) + assert.Len(t, messageQueue, 1) + + msg := <-messageQueue + stopRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_StopComputationRes) + assert.True(t, ok) + assert.Equal(t, "test-comp-id", stopRes.StopComputationRes.ComputationId) + assert.Empty(t, stopRes.StopComputationRes.Message) +} + +func TestManagerClient_handleBackendInfoReq(t *testing.T) { + mockStream := new(mockStream) + mockSvc := new(mocks.Service) + messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10) + logger := mglog.NewMock() + + client := NewClient(mockStream, mockSvc, messageQueue, logger) + + infoReq := &pkgmanager.ServerStreamMessage_BackendInfoReq{ + BackendInfoReq: &pkgmanager.BackendInfoReq{ + Id: "test-info-id", + }, + } + + mockSvc.On("FetchBackendInfo").Return([]byte("test-backend-info"), nil) + + client.handleBackendInfoReq(infoReq) + + // Wait for the goroutine to finish + time.Sleep(50 * time.Millisecond) + + mockSvc.AssertExpectations(t) + assert.Len(t, messageQueue, 1) + + msg := <-messageQueue + infoRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_BackendInfo) + assert.True(t, ok) + assert.Equal(t, "test-info-id", infoRes.BackendInfo.Id) + assert.Equal(t, []byte("test-backend-info"), infoRes.BackendInfo.Info) +} diff --git a/manager/api/grpc/server.go b/manager/api/grpc/server.go index 9e17c1e6b..d57538b6d 100644 --- a/manager/api/grpc/server.go +++ b/manager/api/grpc/server.go @@ -56,22 +56,31 @@ func (s *grpcServer) Process(stream manager.ManagerService_ProcessServer) error eg.Go(func() error { for { - req, err := stream.Recv() - if err != nil { - return err + select { + case <-ctx.Done(): + return ctx.Err() + default: + req, err := stream.Recv() + if err != nil { + return err + } + s.incoming <- req } - - s.incoming <- req } }) eg.Go(func() error { sendMessage := func(msg *manager.ServerStreamMessage) error { - switch m := msg.Message.(type) { - case *manager.ServerStreamMessage_RunReq: - return s.sendRunReqInChunks(stream, m.RunReq) + select { + case <-ctx.Done(): + return ctx.Err() default: - return stream.Send(msg) + switch m := msg.Message.(type) { + case *manager.ServerStreamMessage_RunReq: + return s.sendRunReqInChunks(stream, m.RunReq) + default: + return stream.Send(msg) + } } } diff --git a/manager/api/grpc/server_test.go b/manager/api/grpc/server_test.go new file mode 100644 index 000000000..6da977f9d --- /dev/null +++ b/manager/api/grpc/server_test.go @@ -0,0 +1,232 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package grpc + +import ( + "context" + "testing" + "time" + + "github.com/absmach/magistrala/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/ultravioletrs/cocos/pkg/manager" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" +) + +type mockServerStream struct { + mock.Mock + manager.ManagerService_ProcessServer +} + +func (m *mockServerStream) Send(msg *manager.ServerStreamMessage) error { + args := m.Called(msg) + return args.Error(0) +} + +func (m *mockServerStream) Recv() (*manager.ClientStreamMessage, error) { + args := m.Called() + return args.Get(0).(*manager.ClientStreamMessage), args.Error(1) +} + +func (m *mockServerStream) Context() context.Context { + args := m.Called() + return args.Get(0).(context.Context) +} + +type mockService struct { + mock.Mock +} + +func (m *mockService) Run(ctx context.Context, ipAddress string, sendMessage SendFunc, authInfo credentials.AuthInfo) { + m.Called(ctx, ipAddress, sendMessage, authInfo) +} + +func TestNewServer(t *testing.T) { + incoming := make(chan *manager.ClientStreamMessage) + mockSvc := new(mockService) + + server := NewServer(incoming, mockSvc) + + assert.NotNil(t, server) + assert.IsType(t, &grpcServer{}, server) +} + +func TestGrpcServer_Process(t *testing.T) { + incoming := make(chan *manager.ClientStreamMessage, 1) + mockSvc := new(mockService) + server := NewServer(incoming, mockSvc).(*grpcServer) + + mockStream := new(mockServerStream) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + mockStream.On("Context").Return(peer.NewContext(ctx, &peer.Peer{ + Addr: mockAddr{}, + AuthInfo: mockAuthInfo{}, + })) + + go func() { + for mes := range incoming { + assert.NotNil(t, mes) + } + }() + + mockStream.On("Recv").Return(&manager.ClientStreamMessage{}, nil) + mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).Return() + + err := server.Process(mockStream) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") + mockStream.AssertExpectations(t) + mockSvc.AssertExpectations(t) +} + +func TestGrpcServer_sendRunReqInChunks(t *testing.T) { + incoming := make(chan *manager.ClientStreamMessage) + mockSvc := new(mockService) + server := NewServer(incoming, mockSvc).(*grpcServer) + + mockStream := new(mockServerStream) + + runReq := &manager.ComputationRunReq{ + Id: "test-id", + } + + largePayload := make([]byte, bufferSize*2) + for i := range largePayload { + largePayload[i] = byte(i % 256) + } + runReq.Algorithm = &manager.Algorithm{} + runReq.Algorithm.UserKey = largePayload + + mockStream.On("Send", mock.AnythingOfType("*manager.ServerStreamMessage")).Return(nil).Times(4) + + err := server.sendRunReqInChunks(mockStream, runReq) + + assert.NoError(t, err) + mockStream.AssertExpectations(t) + + calls := mockStream.Calls + assert.Equal(t, 4, len(calls)) + + for i, call := range calls { + msg := call.Arguments[0].(*manager.ServerStreamMessage) + chunk := msg.GetRunReqChunks() + + assert.NotNil(t, chunk) + assert.Equal(t, "test-id", chunk.Id) + + if i < 3 { + assert.False(t, chunk.IsLast) + } else { + assert.Equal(t, 0, len(chunk.Data)) + assert.True(t, chunk.IsLast) + } + } +} + +type mockAddr struct{} + +func (mockAddr) Network() string { return "test network" } +func (mockAddr) String() string { return "test" } + +type mockAuthInfo struct{} + +func (mockAuthInfo) AuthType() string { return "test auth" } + +func TestGrpcServer_ProcessWithMockService(t *testing.T) { + incoming := make(chan *manager.ClientStreamMessage, 10) + mockSvc := new(mockService) + server := NewServer(incoming, mockSvc).(*grpcServer) + + go func() { + for mes := range incoming { + assert.NotNil(t, mes) + } + }() + + mockStream := new(mockServerStream) + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + peerCtx := peer.NewContext(ctx, &peer.Peer{ + Addr: mockAddr{}, + AuthInfo: mockAuthInfo{}, + }) + + mockStream.On("Context").Return(peerCtx) + mockStream.On("Recv").Return(&manager.ClientStreamMessage{}, nil).Maybe() + + mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")). + Run(func(args mock.Arguments) { + sendFunc := args.Get(2).(SendFunc) + // Simulate sending a RunReq + runReq := &manager.ComputationRunReq{Id: "test-run-id"} + err := sendFunc(&manager.ServerStreamMessage{ + Message: &manager.ServerStreamMessage_RunReq{ + RunReq: runReq, + }, + }) + assert.NoError(t, err) + }). + Return() + + mockStream.On("Send", mock.MatchedBy(func(msg *manager.ServerStreamMessage) bool { + chunks := msg.GetRunReqChunks() + return chunks != nil && chunks.Id == "test-run-id" + })).Return(nil) + + go func() { + time.Sleep(150 * time.Millisecond) + cancel() + }() + + err := server.Process(mockStream) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + mockStream.AssertExpectations(t) + mockSvc.AssertExpectations(t) +} + +func TestGrpcServer_sendRunReqInChunksError(t *testing.T) { + incoming := make(chan *manager.ClientStreamMessage) + mockSvc := new(mockService) + server := NewServer(incoming, mockSvc).(*grpcServer) + + mockStream := new(mockServerStream) + + runReq := &manager.ComputationRunReq{ + Id: "test-id", + } + + // Simulate an error when sending + mockStream.On("Send", mock.AnythingOfType("*manager.ServerStreamMessage")).Return(errors.New("send error")).Once() + + err := server.sendRunReqInChunks(mockStream, runReq) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "send error") + mockStream.AssertExpectations(t) +} + +func TestGrpcServer_ProcessMissingPeerInfo(t *testing.T) { + incoming := make(chan *manager.ClientStreamMessage) + mockSvc := new(mockService) + server := NewServer(incoming, mockSvc).(*grpcServer) + + mockStream := new(mockServerStream) + ctx := context.Background() + + // Return a context without peer info + mockStream.On("Context").Return(ctx) + + err := server.Process(mockStream) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to get peer info") + mockStream.AssertExpectations(t) +} diff --git a/manager/mocks/service.go b/manager/mocks/service.go new file mode 100644 index 000000000..cf9b10b2a --- /dev/null +++ b/manager/mocks/service.go @@ -0,0 +1,114 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + pkgmanager "github.com/ultravioletrs/cocos/pkg/manager" +) + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +// FetchBackendInfo provides a mock function with given fields: +func (_m *Service) FetchBackendInfo() ([]byte, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for FetchBackendInfo") + } + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func() ([]byte, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() []byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RetrieveAgentEventsLogs provides a mock function with given fields: +func (_m *Service) RetrieveAgentEventsLogs() { + _m.Called() +} + +// Run provides a mock function with given fields: ctx, c +func (_m *Service) Run(ctx context.Context, c *pkgmanager.ComputationRunReq) (string, error) { + ret := _m.Called(ctx, c) + + if len(ret) == 0 { + panic("no return value specified for Run") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *pkgmanager.ComputationRunReq) (string, error)); ok { + return rf(ctx, c) + } + if rf, ok := ret.Get(0).(func(context.Context, *pkgmanager.ComputationRunReq) string); ok { + r0 = rf(ctx, c) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(context.Context, *pkgmanager.ComputationRunReq) error); ok { + r1 = rf(ctx, c) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Stop provides a mock function with given fields: ctx, computationID +func (_m *Service) Stop(ctx context.Context, computationID string) error { + ret := _m.Called(ctx, computationID) + + if len(ret) == 0 { + panic("no return value specified for Stop") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, computationID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/manager/qemu/config_test.go b/manager/qemu/config_test.go new file mode 100644 index 000000000..53fbb1d29 --- /dev/null +++ b/manager/qemu/config_test.go @@ -0,0 +1,232 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package qemu + +import ( + "reflect" + "testing" +) + +func TestConstructQemuArgs(t *testing.T) { + tests := []struct { + name string + config Config + expected []string + }{ + { + name: "Default configuration", + config: Config{ + QemuBinPath: "qemu-system-x86_64", + EnableKVM: true, + Machine: "q35", + CPU: "EPYC", + SMPCount: 4, + MaxCPUs: 64, + MemID: "ram1", + MemoryConfig: MemoryConfig{ + Size: "2048M", + Slots: 5, + Max: "30G", + }, + OVMFCodeConfig: OVMFCodeConfig{ + If: "pflash", + Format: "raw", + Unit: 0, + File: "/usr/share/OVMF/OVMF_CODE.fd", + ReadOnly: "on", + }, + OVMFVarsConfig: OVMFVarsConfig{ + If: "pflash", + Format: "raw", + Unit: 1, + File: "/usr/share/OVMF/OVMF_VARS.fd", + }, + NetDevConfig: NetDevConfig{ + ID: "vmnic", + HostFwdAgent: 7020, + GuestFwdAgent: 7002, + }, + VirtioNetPciConfig: VirtioNetPciConfig{ + DisableLegacy: "on", + IOMMUPlatform: true, + Addr: "0x2", + }, + VSockConfig: VSockConfig{ + ID: "vhost-vsock-pci0", + GuestCID: 3, + }, + DiskImgConfig: DiskImgConfig{ + KernelFile: "img/bzImage", + RootFsFile: "img/rootfs.cpio.gz", + }, + NoGraphic: true, + Monitor: "pty", + }, + expected: []string{ + "-enable-kvm", + "-machine", "q35", + "-cpu", "EPYC", + "-smp", "4,maxcpus=64", + "-m", "2048M,slots=5,maxmem=30G", + "-drive", "if=pflash,format=raw,unit=0,file=/usr/share/OVMF/OVMF_CODE.fd,readonly=on", + "-drive", "if=pflash,format=raw,unit=1,file=/usr/share/OVMF/OVMF_VARS.fd", + "-netdev", "user,id=vmnic,hostfwd=tcp::7020-:7002", + "-device", "virtio-net-pci,disable-legacy=on,iommu_platform=true,netdev=vmnic,addr=0x2,romfile=", + "-device", "vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3", + "-kernel", "img/bzImage", + "-append", "\"quiet console=null rootfstype=ramfs\"", + "-initrd", "img/rootfs.cpio.gz", + "-nographic", + "-monitor", "pty", + }, + }, + { + name: "SEV-SNP enabled configuration", + config: Config{ + QemuBinPath: "qemu-system-x86_64", + EnableKVM: true, + EnableSEVSNP: true, + Machine: "q35", + CPU: "EPYC", + SMPCount: 4, + MaxCPUs: 64, + MemID: "ram1", + MemoryConfig: MemoryConfig{ + Size: "2048M", + Slots: 5, + Max: "30G", + }, + OVMFCodeConfig: OVMFCodeConfig{ + If: "pflash", + Format: "raw", + Unit: 0, + File: "/usr/share/OVMF/OVMF_CODE.fd", + ReadOnly: "on", + }, + OVMFVarsConfig: OVMFVarsConfig{ + If: "pflash", + Format: "raw", + Unit: 1, + File: "/usr/share/OVMF/OVMF_VARS.fd", + }, + NetDevConfig: NetDevConfig{ + ID: "vmnic", + HostFwdAgent: 7020, + GuestFwdAgent: 7002, + }, + VirtioNetPciConfig: VirtioNetPciConfig{ + DisableLegacy: "on", + IOMMUPlatform: true, + Addr: "0x2", + }, + VSockConfig: VSockConfig{ + ID: "vhost-vsock-pci0", + GuestCID: 3, + }, + DiskImgConfig: DiskImgConfig{ + KernelFile: "img/bzImage", + RootFsFile: "img/rootfs.cpio.gz", + }, + SevConfig: SevConfig{ + ID: "sev0", + CBitPos: 51, + ReducedPhysBits: 1, + }, + NoGraphic: true, + Monitor: "pty", + }, + expected: []string{ + "-enable-kvm", + "-machine", "q35", + "-cpu", "EPYC", + "-smp", "4,maxcpus=64", + "-m", "2048M,slots=5,maxmem=30G", + "-drive", "if=pflash,format=raw,unit=0,file=/usr/share/OVMF/OVMF_CODE.fd,readonly=on", + "-drive", "if=pflash,format=raw,unit=1,file=/usr/share/OVMF/OVMF_VARS.fd", + "-netdev", "user,id=vmnic,hostfwd=tcp::7020-:7002", + "-device", "virtio-net-pci,disable-legacy=on,iommu_platform=true,netdev=vmnic,addr=0x2,romfile=", + "-device", "vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3", + "-object", "memory-backend-memfd-private,id=ram1,size=2048M,share=true", + "-machine", "memory-backend=ram1,kvm-type=protected", + "-kernel", "img/bzImage", + "-append", "\"quiet console=null rootfstype=ramfs\"", + "-initrd", "img/rootfs.cpio.gz", + "-object", "sev-snp-guest,id=sev0,cbitpos=51,reduced-phys-bits=1", + "-machine", "memory-encryption=sev0", + "-nographic", + "-monitor", "pty", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.config.ConstructQemuArgs() + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("ConstructQemuArgs() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestConstructQemuArgs_KernelHash(t *testing.T) { + config := Config{ + EnableSEVSNP: true, + KernelHash: true, + SevConfig: SevConfig{ + ID: "sev0", + CBitPos: 51, + ReducedPhysBits: 1, + }, + } + + result := config.ConstructQemuArgs() + + expected := "-object" + expectedValue := "sev-snp-guest,id=sev0,cbitpos=51,reduced-phys-bits=1,discard=none,kernel-hashes=on" + + found := false + for i, arg := range result { + if arg == expected && i+1 < len(result) { + if result[i+1] == expectedValue { + found = true + break + } + } + } + + if !found { + t.Errorf("ConstructQemuArgs() did not contain expected SEV-SNP configuration with kernel hashes enabled") + } +} + +func TestConstructQemuArgs_HostData(t *testing.T) { + config := Config{ + EnableSEVSNP: true, + SevConfig: SevConfig{ + ID: "sev0", + CBitPos: 51, + ReducedPhysBits: 1, + HostData: "test-host-data", + }, + } + + result := config.ConstructQemuArgs() + + expected := "-object" + expectedValue := "sev-snp-guest,id=sev0,cbitpos=51,reduced-phys-bits=1,host-data=test-host-data" + + found := false + for i, arg := range result { + if arg == expected && i+1 < len(result) { + if result[i+1] == expectedValue { + found = true + break + } + } + } + + if !found { + t.Errorf("ConstructQemuArgs() did not contain expected SEV-SNP configuration with host data") + } +} diff --git a/manager/qemu/persistence_test.go b/manager/qemu/persistence_test.go new file mode 100644 index 000000000..4c17e1ea3 --- /dev/null +++ b/manager/qemu/persistence_test.go @@ -0,0 +1,144 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package qemu + +import ( + "fmt" + "os" + "path/filepath" + "sync" + "testing" +) + +func TestNewFilePersistence(t *testing.T) { + tempDir := t.TempDir() + + fp, err := NewFilePersistence(tempDir) + if err != nil { + t.Fatalf("NewFilePersistence failed: %v", err) + } + + if _, ok := fp.(*FilePersistence); !ok { + t.Fatalf("NewFilePersistence didn't return a FilePersistence") + } +} + +func TestSaveVM(t *testing.T) { + tempDir := t.TempDir() + fp, _ := NewFilePersistence(tempDir) + + state := VMState{ + ID: "test-vm", + Config: Config{}, + PID: 1234, + } + + err := fp.SaveVM(state) + if err != nil { + t.Fatalf("SaveVM failed: %v", err) + } + + // Check if file exists + if _, err := os.Stat(filepath.Join(tempDir, "test-vm.json")); os.IsNotExist(err) { + t.Fatalf("SaveVM didn't create a file") + } +} + +func TestLoadVMs(t *testing.T) { + tempDir := t.TempDir() + fp, _ := NewFilePersistence(tempDir) + + // Save two VMs + states := []VMState{ + {ID: "vm1", Config: Config{}, PID: 1234}, + {ID: "vm2", Config: Config{}, PID: 5678}, + } + + for _, state := range states { + if err := fp.SaveVM(state); err != nil { + t.Fatalf("SaveVM failed: %v", err) + } + } + + // Load VMs + loadedStates, err := fp.LoadVMs() + if err != nil { + t.Fatalf("LoadVMs failed: %v", err) + } + + if len(loadedStates) != len(states) { + t.Fatalf("LoadVMs returned %d states, expected %d", len(loadedStates), len(states)) + } + + // Check if loaded states match saved states + for i, state := range states { + if state.ID != loadedStates[i].ID || state.PID != loadedStates[i].PID { + t.Fatalf("Loaded state %v doesn't match saved state %v", loadedStates[i], state) + } + } +} + +func TestDeleteVM(t *testing.T) { + tempDir := t.TempDir() + fp, _ := NewFilePersistence(tempDir) + + state := VMState{ID: "test-vm", Config: Config{}, PID: 1234} + + // Save VM + if err := fp.SaveVM(state); err != nil { + t.Fatalf("SaveVM failed: %v", err) + } + + // Delete VM + if err := fp.DeleteVM(state.ID); err != nil { + t.Fatalf("DeleteVM failed: %v", err) + } + + // Check if file is deleted + if _, err := os.Stat(filepath.Join(tempDir, "test-vm.json")); !os.IsNotExist(err) { + t.Fatalf("DeleteVM didn't remove the file") + } +} + +func TestLoadVMsWithInvalidFile(t *testing.T) { + tempDir := t.TempDir() + fp, _ := NewFilePersistence(tempDir) + + invalidData := []byte("{invalid json") + if err := os.WriteFile(filepath.Join(tempDir, "invalid.json"), invalidData, 0o644); err != nil { + t.Fatalf("Failed to create invalid JSON file: %v", err) + } + + _, err := fp.LoadVMs() + if err == nil { + t.Fatalf("LoadVMs should have failed with invalid JSON") + } +} + +func TestConcurrentAccess(t *testing.T) { + tempDir := t.TempDir() + fp, _ := NewFilePersistence(tempDir) + + const numGoroutines = 10 + var wg sync.WaitGroup + wg.Add(numGoroutines * 2) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + state := VMState{ID: fmt.Sprintf("vm-%d", id), Config: Config{}, PID: id} + if err := fp.SaveVM(state); err != nil { + t.Errorf("Concurrent SaveVM failed: %v", err) + } + }(i) + + go func() { + defer wg.Done() + if _, err := fp.LoadVMs(); err != nil { + t.Errorf("Concurrent LoadVMs failed: %v", err) + } + }() + } + + wg.Wait() +} diff --git a/manager/qemu/vm_test.go b/manager/qemu/vm_test.go index 5465f1408..4e4e2b75e 100644 --- a/manager/qemu/vm_test.go +++ b/manager/qemu/vm_test.go @@ -3,8 +3,10 @@ package qemu import ( + "os" "os/exec" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/ultravioletrs/cocos/pkg/manager" @@ -15,25 +17,110 @@ func TestNewVM(t *testing.T) { logsChan := make(chan *manager.ClientStreamMessage) computationId := "test-computation" - nvm := NewVM(config, logsChan, computationId) + vm := NewVM(config, logsChan, computationId) - assert.NotNil(t, nvm) - assert.IsType(t, &qemuVM{}, nvm) + assert.NotNil(t, vm) + assert.IsType(t, &qemuVM{}, vm) } -func TestVM_Stop(t *testing.T) { - // Setup - v := &qemuVM{ - cmd: exec.Command("sleep", "1"), +func TestStart(t *testing.T) { + // Create a temporary file for testing + tmpFile, err := os.CreateTemp("", "test-ovmf-vars") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + config := Config{ + OVMFVarsConfig: OVMFVarsConfig{ + File: tmpFile.Name(), + }, + QemuBinPath: "echo", // Use 'echo' as a dummy QEMU binary } + logsChan := make(chan *manager.ClientStreamMessage) + computationId := "test-computation" - err := v.cmd.Start() + vm := NewVM(config, logsChan, computationId).(*qemuVM) + + err = vm.Start() assert.NoError(t, err) + assert.NotNil(t, vm.cmd) + + // Clean up + _ = vm.Stop() +} - // Test - err = v.Stop() +func TestStop(t *testing.T) { + cmd := exec.Command("echo", "test") + err := cmd.Start() + assert.NoError(t, err) + + vm := &qemuVM{ + cmd: &exec.Cmd{ + Process: cmd.Process, + }, + } + + err = vm.Stop() + assert.NoError(t, err) +} + +func TestSetProcess(t *testing.T) { + vm := &qemuVM{ + config: Config{ + QemuBinPath: "echo", // Use 'echo' as a dummy QEMU binary + }, + } - // Assert + err := vm.SetProcess(os.Getpid()) // Use current process as a dummy assert.NoError(t, err) - assert.Error(t, v.cmd.Wait()) // Process should have been killed + assert.NotNil(t, vm.cmd) + assert.NotNil(t, vm.cmd.Process) +} + +func TestGetProcess(t *testing.T) { + expectedPid := 12345 + vm := &qemuVM{ + cmd: &exec.Cmd{ + Process: &os.Process{Pid: expectedPid}, + }, + } + + pid := vm.GetProcess() + assert.Equal(t, expectedPid, pid) +} + +func TestGetCID(t *testing.T) { + expectedCID := 42 + vm := &qemuVM{ + config: Config{ + VSockConfig: VSockConfig{ + GuestCID: expectedCID, + }, + }, + } + + cid := vm.GetCID() + assert.Equal(t, expectedCID, cid) +} + +func TestCheckVMProcessPeriodically(t *testing.T) { + logsChan := make(chan *manager.ClientStreamMessage, 1) + vm := &qemuVM{ + logsChan: logsChan, + computationId: "test-computation", + cmd: &exec.Cmd{ + Process: &os.Process{Pid: -1}, // Use an invalid PID to simulate a stopped process + }, + } + + go vm.checkVMProcessPeriodically() + + select { + case msg := <-logsChan: + assert.NotNil(t, msg.GetAgentEvent()) + assert.Equal(t, "test-computation", msg.GetAgentEvent().ComputationId) + assert.Equal(t, manager.VmRunning.String(), msg.GetAgentEvent().EventType) + assert.Equal(t, manager.Stopped.String(), msg.GetAgentEvent().Status) + case <-time.After(2 * interval): + t.Fatal("Timeout waiting for VM stopped message") + } } diff --git a/manager/service.go b/manager/service.go index d0dac45d6..cd51dd66c 100644 --- a/manager/service.go +++ b/manager/service.go @@ -53,6 +53,8 @@ var ( // Service specifies an API that must be fulfilled by the domain service // implementation, and all of its decorators (e.g. logging & metrics). +// +//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0" type Service interface { // Run create a computation. Run(ctx context.Context, c *manager.ComputationRunReq) (string, error) diff --git a/manager/service_test.go b/manager/service_test.go index 2998134e0..dacde4970 100644 --- a/manager/service_test.go +++ b/manager/service_test.go @@ -6,11 +6,15 @@ import ( "context" "encoding/json" "log/slog" + "os" + "os/exec" "testing" + mglog "github.com/absmach/magistrala/logger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/manager/qemu" persistenceMocks "github.com/ultravioletrs/cocos/manager/qemu/mocks" "github.com/ultravioletrs/cocos/manager/vm" @@ -252,3 +256,110 @@ func TestPublishEvent(t *testing.T) { }) } } + +func TestComputationHash(t *testing.T) { + tests := []struct { + name string + computation agent.Computation + wantErr bool + }{ + { + name: "Valid computation", + computation: agent.Computation{ + ID: "test-id", + Name: "test-name", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash, err := computationHash(tt.computation) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, hash) + + hash2, _ := computationHash(tt.computation) + assert.Equal(t, hash, hash2) + } + }) + } +} + +func TestDecodeRange(t *testing.T) { + tests := []struct { + name string + input string + wantStart int + wantEnd int + wantErr bool + }{ + {"Valid range", "1-5", 1, 5, false}, + {"Invalid format", "1:5", 0, 0, true}, + {"Start greater than end", "5-1", 0, 0, true}, + {"Non-numeric input", "a-b", 0, 0, true}, + {"Single number", "5", 0, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + start, end, err := decodeRange(tt.input) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantStart, start) + assert.Equal(t, tt.wantEnd, end) + } + }) + } +} + +func TestRestoreVMs(t *testing.T) { + mockPersistence := new(persistenceMocks.Persistence) + vmf := new(mocks.Provider) + vmMock := new(mocks.VM) + vmf.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return(vmMock) + vmMock.On("SetProcess", mock.Anything).Return(nil) + ms := &managerService{ + persistence: mockPersistence, + vms: make(map[string]vm.VM), + eventsChan: make(chan *manager.ClientStreamMessage, 10), + vmFactory: vmf.Execute, + logger: mglog.NewMock(), + } + + cmd := exec.Command("echo", "test") + err := cmd.Start() + assert.NoError(t, err) + + mockPersistence.On("LoadVMs").Return([]qemu.VMState{ + {ID: "vm1", PID: cmd.Process.Pid}, + {ID: "vm2", PID: 1000}, + }, nil) + + mockPersistence.On("DeleteVM", mock.Anything).Return(nil) + + err = ms.restoreVMs() + assert.NoError(t, err) + + assert.Len(t, ms.vms, 1) + assert.Contains(t, ms.vms, "vm1") + + mockPersistence.AssertExpectations(t) +} + +func TestProcessExists(t *testing.T) { + ms := &managerService{} + + assert.True(t, ms.processExists(os.Getpid())) + + assert.False(t, ms.processExists(99999)) + + if os.Getuid() != 0 { // Skip this test if running as root. + assert.False(t, ms.processExists(1)) // PID 1 is usually the init process. + } +} From 643c132ff7ff7bc001ef7534406dd02e59e9eefc Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Tue, 8 Oct 2024 16:50:50 +0300 Subject: [PATCH 42/83] NOISSUE - Add pkg tests (#269) * add pkg tests Signed-off-by: Sammy Oina * rename function Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- cli/cache.go | 2 +- pkg/clients/grpc/connect.go | 6 +- pkg/clients/grpc/connect_test.go | 192 +++++++++++++++++++++++++++++++ pkg/socket/socket.go | 56 --------- 4 files changed, 196 insertions(+), 60 deletions(-) create mode 100644 pkg/clients/grpc/connect_test.go delete mode 100644 pkg/socket/socket.go diff --git a/cli/cache.go b/cli/cache.go index 3c03f7c90..957362d6a 100644 --- a/cli/cache.go +++ b/cli/cache.go @@ -27,7 +27,7 @@ func (cli *CLI) NewCABundleCmd(fileSavePath string) *cobra.Command { Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { attestationConfiguration := grpc.AttestationConfiguration{} - err := grpc.ReadManifest(args[0], &attestationConfiguration) + err := grpc.ReadBackendInfo(args[0], &attestationConfiguration) if err != nil { log.Fatalf("Error while reading manifest: %v", err) } diff --git a/pkg/clients/grpc/connect.go b/pkg/clients/grpc/connect.go index c3407cdc5..57b621533 100644 --- a/pkg/clients/grpc/connect.go +++ b/pkg/clients/grpc/connect.go @@ -68,7 +68,7 @@ type Config struct { URL string `env:"URL" envDefault:"localhost:7001"` Timeout time.Duration `env:"TIMEOUT" envDefault:"60s"` AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"` - Manifest string `env:"MANIFEST" envDefault:""` + BackendInfo string `env:"BACKEND_INFO" envDefault:""` } type AttestationConfiguration struct { @@ -142,7 +142,7 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) { tc := insecure.NewCredentials() if cfg.AttestedTLS { - err := ReadManifest(cfg.Manifest, &attestationConfiguration) + err := ReadBackendInfo(cfg.BackendInfo, &attestationConfiguration) if err != nil { return nil, secure, fmt.Errorf("failed to read Manifest %w", err) } @@ -193,7 +193,7 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) { return conn, secure, nil } -func ReadManifest(manifestPath string, attestationConfiguration *AttestationConfiguration) error { +func ReadBackendInfo(manifestPath string, attestationConfiguration *AttestationConfiguration) error { if manifestPath != "" { manifest, err := os.Open(manifestPath) if err != nil { diff --git a/pkg/clients/grpc/connect_test.go b/pkg/clients/grpc/connect_test.go new file mode 100644 index 000000000..5922ad3a6 --- /dev/null +++ b/pkg/clients/grpc/connect_test.go @@ -0,0 +1,192 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package grpc + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewClient(t *testing.T) { + caCertFile, clientCertFile, clientKeyFile, err := createCertificatesFiles() + require.NoError(t, err) + + t.Cleanup(func() { + os.Remove(caCertFile) + os.Remove(clientCertFile) + os.Remove(clientKeyFile) + }) + + tests := []struct { + name string + cfg Config + wantErr bool + }{ + { + name: "Success without TLS", + cfg: Config{ + URL: "localhost:7001", + }, + wantErr: false, + }, + { + name: "Success with TLS", + cfg: Config{ + URL: "localhost:7001", + ServerCAFile: caCertFile, + }, + wantErr: false, + }, + { + name: "Success with mTLS", + cfg: Config{ + URL: "localhost:7001", + ServerCAFile: caCertFile, + ClientCert: clientCertFile, + ClientKey: clientKeyFile, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := NewClient(tt.cfg) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, client) + } else { + assert.NoError(t, err) + assert.NotNil(t, client) + assert.NoError(t, client.Close()) + } + }) + } +} + +func TestClientSecure(t *testing.T) { + tests := []struct { + name string + secure security + expected string + }{ + { + name: "Without TLS", + secure: withoutTLS, + expected: "without TLS", + }, + { + name: "With TLS", + secure: withTLS, + expected: "with TLS", + }, + { + name: "With mTLS", + secure: withmTLS, + expected: "with mTLS", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &client{secure: tt.secure} + assert.Equal(t, tt.expected, c.Secure()) + }) + } +} + +func createCertificatesFiles() (string, string, string, error) { + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", "", err + } + + caTemplate := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey) + if err != nil { + return "", "", "", err + } + + caCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCertDER})) + if err != nil { + return "", "", "", err + } + + clientKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", "", err + } + + clientTemplate := x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + } + + clientCertDER, err := x509.CreateCertificate(rand.Reader, &clientTemplate, &caTemplate, &clientKey.PublicKey, caKey) + if err != nil { + return "", "", "", err + } + + clientCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER})) + if err != nil { + return "", "", "", err + } + + clientKeyFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey)})) + if err != nil { + return "", "", "", err + } + + return caCertFile, clientCertFile, clientKeyFile, nil +} + +func createTempFile(data []byte) (string, error) { + file, err := createTempFileHandle() + if err != nil { + return "", err + } + + _, err = file.Write(data) + if err != nil { + return "", err + } + + err = file.Close() + if err != nil { + return "", err + } + + return file.Name(), nil +} + +func createTempFileHandle() (*os.File, error) { + return os.CreateTemp("", "test") +} diff --git a/pkg/socket/socket.go b/pkg/socket/socket.go deleted file mode 100644 index 62d3eb2ce..000000000 --- a/pkg/socket/socket.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 -package socket - -import ( - "fmt" - "io" - "net" - "os" -) - -func StartUnixSocketServer(socketPath string) (net.Listener, error) { - // Remove any existing socket file - _ = os.Remove(socketPath) - - // Create a Unix domain socket listener - listener, err := net.Listen("unix", socketPath) - if err != nil { - return nil, fmt.Errorf("error creating socket listener: %v", err) - } - - fmt.Println("Unix domain socket server is listening on", socketPath) - - return listener, nil -} - -func AcceptConnection(listener net.Listener, dataChannel chan []byte, errorChannel chan error) { - conn, err := listener.Accept() - if err != nil { - errorChannel <- fmt.Errorf("error accepting connection:: %v", err) - } - - handleConnection(conn, dataChannel, errorChannel) -} - -func handleConnection(conn net.Conn, dataChannel chan []byte, errorChannel chan error) { - defer conn.Close() - - // Create a dynamic buffer to store incoming data - var buffer []byte - tmp := make([]byte, 1024) - - for { - // Read data into the temporary buffer - n, err := conn.Read(tmp) - if err != nil { - if err == io.EOF { - break - } - errorChannel <- err - } - buffer = append(buffer, tmp[:n]...) - } - - dataChannel <- buffer -} From 034547d66754265e77a3899c768bd6f8ec063209 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Tue, 8 Oct 2024 17:02:17 +0300 Subject: [PATCH 43/83] NOISSUE - Add VM state machine and filter on qemu logs (#272) * add vm state machine and filter on qemu logs Signed-off-by: Sammy Oina * fix lint Signed-off-by: Sammy Oina * fix failing test Signed-off-by: Sammy Oina * fix logging test Signed-off-by: Sammy Oina * fix tests Signed-off-by: Sammy Oina * fix failing test Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- manager/agentEventsLogs.go | 2 +- manager/agentEventsLogs_test.go | 7 ++-- manager/qemu/vm.go | 22 ++++++++++-- manager/qemu/vm_test.go | 5 ++- manager/service.go | 8 +++++ manager/service_test.go | 2 ++ manager/vm/logging.go | 61 +++++++++++++++++++-------------- manager/vm/logging_test.go | 6 +++- manager/vm/mocks/vm.go | 40 ++++++++++++++++++++- manager/vm/state.go | 53 ++++++++++++++++++++++++++++ manager/vm/vm.go | 2 ++ 11 files changed, 175 insertions(+), 33 deletions(-) create mode 100644 manager/vm/state.go diff --git a/manager/agentEventsLogs.go b/manager/agentEventsLogs.go index 2178b7e5e..cf85b73bc 100644 --- a/manager/agentEventsLogs.go +++ b/manager/agentEventsLogs.go @@ -125,7 +125,7 @@ func (ms *managerService) reportBrokenConnection(cmpID string) { ms.eventsChan <- &manager.ClientStreamMessage{ Message: &manager.ClientStreamMessage_AgentEvent{ AgentEvent: &manager.AgentEvent{ - EventType: manager.VmRunning.String(), + EventType: ms.vms[cmpID].State(), ComputationId: cmpID, Status: manager.Disconnected.String(), Timestamp: timestamppb.Now(), diff --git a/manager/agentEventsLogs_test.go b/manager/agentEventsLogs_test.go index dac0f82f4..b242e7ac5 100644 --- a/manager/agentEventsLogs_test.go +++ b/manager/agentEventsLogs_test.go @@ -125,9 +125,9 @@ func TestHandleConnection(t *testing.T) { msg := &manager.ClientStreamMessage{ Message: &manager.ClientStreamMessage_AgentEvent{ AgentEvent: &manager.AgentEvent{ - EventType: manager.VmRunning.String(), + EventType: manager.VmProvision.String(), ComputationId: "comp1", - Status: manager.VmRunning.String(), + Status: manager.VmProvision.String(), Timestamp: timestamppb.Now(), Originator: "agent", }, @@ -153,6 +153,9 @@ func TestHandleConnection(t *testing.T) { func TestReportBrokenConnection(t *testing.T) { ms := &managerService{ eventsChan: make(chan *manager.ClientStreamMessage, 1), + vms: map[string]vm.VM{ + "comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, make(chan *manager.ClientStreamMessage), "comp1"), + }, } ms.reportBrokenConnection("comp1") diff --git a/manager/qemu/vm.go b/manager/qemu/vm.go index fc7bf23c4..f21cdab1e 100644 --- a/manager/qemu/vm.go +++ b/manager/qemu/vm.go @@ -30,6 +30,7 @@ type qemuVM struct { cmd *exec.Cmd logsChan chan *manager.ClientStreamMessage computationId string + vm.StateMachine } func NewVM(config interface{}, logsChan chan *manager.ClientStreamMessage, computationId string) vm.VM { @@ -37,6 +38,7 @@ func NewVM(config interface{}, logsChan chan *manager.ClientStreamMessage, compu config: config.(Config), logsChan: logsChan, computationId: computationId, + StateMachine: vm.NewStateMachine(), } } @@ -73,12 +75,28 @@ func (v *qemuVM) Start() (err error) { v.cmd = exec.Command(exe, args...) v.cmd.Stdout = &vm.Stdout{LogsChan: v.logsChan, ComputationId: v.computationId} - v.cmd.Stderr = &vm.Stderr{LogsChan: v.logsChan, ComputationId: v.computationId} + v.cmd.Stderr = &vm.Stderr{LogsChan: v.logsChan, ComputationId: v.computationId, StateMachine: v.StateMachine} return v.cmd.Start() } func (v *qemuVM) Stop() error { + defer func() { + err := v.StateMachine.Transition(manager.StopComputationRun) + if err != nil { + v.logsChan <- &manager.ClientStreamMessage{ + Message: &manager.ClientStreamMessage_AgentEvent{ + AgentEvent: &manager.AgentEvent{ + ComputationId: v.computationId, + EventType: v.StateMachine.State(), + Status: manager.Warning.String(), + Timestamp: timestamppb.Now(), + Originator: "manager", + }, + }, + } + } + }() err := v.cmd.Process.Signal(syscall.SIGTERM) if err != nil { return fmt.Errorf("failed to send SIGTERM: %v", err) @@ -146,7 +164,7 @@ func (v *qemuVM) checkVMProcessPeriodically() { Message: &manager.ClientStreamMessage_AgentEvent{ AgentEvent: &manager.AgentEvent{ ComputationId: v.computationId, - EventType: manager.VmRunning.String(), + EventType: v.StateMachine.State(), Status: manager.Stopped.String(), Timestamp: timestamppb.Now(), Originator: "manager", diff --git a/manager/qemu/vm_test.go b/manager/qemu/vm_test.go index 4e4e2b75e..e14a37bf5 100644 --- a/manager/qemu/vm_test.go +++ b/manager/qemu/vm_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/ultravioletrs/cocos/manager/vm" "github.com/ultravioletrs/cocos/pkg/manager" ) @@ -57,6 +58,7 @@ func TestStop(t *testing.T) { cmd: &exec.Cmd{ Process: cmd.Process, }, + StateMachine: vm.NewStateMachine(), } err = vm.Stop() @@ -110,6 +112,7 @@ func TestCheckVMProcessPeriodically(t *testing.T) { cmd: &exec.Cmd{ Process: &os.Process{Pid: -1}, // Use an invalid PID to simulate a stopped process }, + StateMachine: vm.NewStateMachine(), } go vm.checkVMProcessPeriodically() @@ -118,7 +121,7 @@ func TestCheckVMProcessPeriodically(t *testing.T) { case msg := <-logsChan: assert.NotNil(t, msg.GetAgentEvent()) assert.Equal(t, "test-computation", msg.GetAgentEvent().ComputationId) - assert.Equal(t, manager.VmRunning.String(), msg.GetAgentEvent().EventType) + assert.Equal(t, manager.VmProvision.String(), msg.GetAgentEvent().EventType) assert.Equal(t, manager.Stopped.String(), msg.GetAgentEvent().Status) case <-time.After(2 * interval): t.Fatal("Timeout waiting for VM stopped message") diff --git a/manager/service.go b/manager/service.go index cd51dd66c..8bfa4e4ff 100644 --- a/manager/service.go +++ b/manager/service.go @@ -187,6 +187,10 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) return "", err } + if err := ms.vms[c.Id].Transition(manager.VmRunning); err != nil { + ms.logger.Warn("Failed to transition VM state", "computation", c.Id, "error", err) + } + ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Completed.String(), json.RawMessage{}) return fmt.Sprint(ms.qemuCfg.HostFwdAgent), nil } @@ -327,6 +331,10 @@ func (ms *managerService) restoreVMs() error { continue } + if err := cvm.Transition(manager.VmRunning); err != nil { + ms.logger.Warn("Failed to transition VM state", "computation", state.ID, "error", err) + } + ms.vms[state.ID] = cvm ms.logger.Info("Successfully restored VM state", "id", state.ID, "computationId", state.ID, "pid", state.PID) } diff --git a/manager/service_test.go b/manager/service_test.go index dacde4970..4d6441ff3 100644 --- a/manager/service_test.go +++ b/manager/service_test.go @@ -86,6 +86,7 @@ func TestRun(t *testing.T) { vmMock.On("SendAgentConfig", mock.Anything).Return(nil) vmMock.On("GetProcess").Return(1234) + vmMock.On("Transition", mock.Anything).Return(nil) persistence.On("SaveVM", mock.Anything).Return(nil) @@ -324,6 +325,7 @@ func TestRestoreVMs(t *testing.T) { vmMock := new(mocks.VM) vmf.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return(vmMock) vmMock.On("SetProcess", mock.Anything).Return(nil) + vmMock.On("Transition", mock.Anything).Return(nil) ms := &managerService{ persistence: mockPersistence, vms: make(map[string]vm.VM), diff --git a/manager/vm/logging.go b/manager/vm/logging.go index 8fb4e0763..07f0ee477 100644 --- a/manager/vm/logging.go +++ b/manager/vm/logging.go @@ -7,6 +7,7 @@ import ( "errors" "io" "log/slog" + "strings" "github.com/ultravioletrs/cocos/pkg/manager" "google.golang.org/protobuf/types/known/timestamppb" @@ -58,18 +59,7 @@ func (s *Stdout) Write(p []byte) (n int, err error) { return len(p) - inBuf.Len(), err } - msg := &manager.ClientStreamMessage{ - Message: &manager.ClientStreamMessage_AgentLog{ - AgentLog: &manager.AgentLog{ - Message: string(buf[:n]), - ComputationId: s.ComputationId, - Level: slog.LevelDebug.String(), - Timestamp: timestamppb.Now(), - }, - }, - } - - if err := safeSend(s.LogsChan, msg); err != nil { + if err := sendLog(s.LogsChan, s.ComputationId, string(buf[:n]), slog.LevelDebug.String()); err != nil { return len(p) - inBuf.Len(), err } } @@ -80,6 +70,7 @@ func (s *Stdout) Write(p []byte) (n int, err error) { type Stderr struct { LogsChan chan *manager.ClientStreamMessage ComputationId string + StateMachine StateMachine } // Write implements io.Writer. @@ -97,18 +88,7 @@ func (s *Stderr) Write(p []byte) (n int, err error) { return len(p) - inBuf.Len(), err } - msg := &manager.ClientStreamMessage{ - Message: &manager.ClientStreamMessage_AgentLog{ - AgentLog: &manager.AgentLog{ - Message: string(buf[:n]), - ComputationId: s.ComputationId, - Level: slog.LevelError.String(), - Timestamp: timestamppb.Now(), - }, - }, - } - - if err := safeSend(s.LogsChan, msg); err != nil { + if err := sendLog(s.LogsChan, s.ComputationId, string(buf[:n]), ""); err != nil { return len(p) - inBuf.Len(), err } } @@ -118,7 +98,7 @@ func (s *Stderr) Write(p []byte) (n int, err error) { Message: &manager.ClientStreamMessage_AgentEvent{ AgentEvent: &manager.AgentEvent{ ComputationId: s.ComputationId, - EventType: manager.VmRunning.String(), + EventType: s.StateMachine.State(), Timestamp: timestamppb.Now(), Originator: "manager", Status: manager.Warning.String(), @@ -132,3 +112,34 @@ func (s *Stderr) Write(p []byte) (n int, err error) { return len(p), nil } + +func sendLog(logsChan chan *manager.ClientStreamMessage, computationID, message, level string) error { + if len(message) < 3 { + return nil + } + + if level == "" { + if strings.Contains(strings.ToLower(message), "warning") { + level = slog.LevelWarn.String() + } else { + level = slog.LevelError.String() + } + } + + msg := &manager.ClientStreamMessage{ + Message: &manager.ClientStreamMessage_AgentLog{ + AgentLog: &manager.AgentLog{ + Message: message, + ComputationId: computationID, + Level: level, + Timestamp: timestamppb.Now(), + }, + }, + } + + if err := safeSend(logsChan, msg); err != nil { + return err + } + + return nil +} diff --git a/manager/vm/logging_test.go b/manager/vm/logging_test.go index 872221119..3cac55116 100644 --- a/manager/vm/logging_test.go +++ b/manager/vm/logging_test.go @@ -29,7 +29,7 @@ func TestStdoutWrite(t *testing.T) { }, { name: "Large write exceeding buffer size", - input: string(make([]byte, bufSize*2+1)), + input: string(make([]byte, bufSize*2+3)), expectedWrites: 3, }, } @@ -97,8 +97,12 @@ func TestStderrWrite(t *testing.T) { s := &Stderr{ LogsChan: logsChan, ComputationId: "test-computation", + StateMachine: NewStateMachine(), } + err := s.StateMachine.Transition(manager.VmRunning) + assert.NoError(t, err) + n, err := s.Write([]byte(tt.input)) assert.NoError(t, err) diff --git a/manager/vm/mocks/vm.go b/manager/vm/mocks/vm.go index f1d977e22..27952e88e 100644 --- a/manager/vm/mocks/vm.go +++ b/manager/vm/mocks/vm.go @@ -6,8 +6,10 @@ package mocks import ( - mock "github.com/stretchr/testify/mock" agent "github.com/ultravioletrs/cocos/agent" + manager "github.com/ultravioletrs/cocos/pkg/manager" + + mock "github.com/stretchr/testify/mock" ) // VM is an autogenerated mock type for the VM type @@ -105,6 +107,24 @@ func (_m *VM) Start() error { return r0 } +// State provides a mock function with given fields: +func (_m *VM) State() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for State") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + // Stop provides a mock function with given fields: func (_m *VM) Stop() error { ret := _m.Called() @@ -123,6 +143,24 @@ func (_m *VM) Stop() error { return r0 } +// Transition provides a mock function with given fields: newState +func (_m *VM) Transition(newState manager.ManagerState) error { + ret := _m.Called(newState) + + if len(ret) == 0 { + panic("no return value specified for Transition") + } + + var r0 error + if rf, ok := ret.Get(0).(func(manager.ManagerState) error); ok { + r0 = rf(newState) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // NewVM creates a new instance of VM. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewVM(t interface { diff --git a/manager/vm/state.go b/manager/vm/state.go new file mode 100644 index 000000000..86f149e76 --- /dev/null +++ b/manager/vm/state.go @@ -0,0 +1,53 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package vm + +import ( + "errors" + "sync" + + "github.com/ultravioletrs/cocos/pkg/manager" +) + +type sm struct { + sync.Mutex + state manager.ManagerState +} + +type StateMachine interface { + Transition(newState manager.ManagerState) error + State() string +} + +func NewStateMachine() StateMachine { + return &sm{state: manager.VmProvision} +} + +func (sm *sm) Transition(newState manager.ManagerState) error { + sm.Lock() + defer sm.Unlock() + switch sm.state { + case manager.VmProvision: + if newState == manager.VmRunning || newState == manager.StopComputationRun { + sm.state = newState + return nil + } + case manager.VmRunning: + if newState == manager.StopComputationRun { + sm.state = newState + return nil + } + case manager.StopComputationRun: + if newState == manager.VmRunning { + sm.state = newState + return nil + } + } + return errors.New("invalid state transition") +} + +func (sm *sm) State() string { + sm.Lock() + defer sm.Unlock() + return sm.state.String() +} diff --git a/manager/vm/vm.go b/manager/vm/vm.go index b6d463ee3..7431c2506 100644 --- a/manager/vm/vm.go +++ b/manager/vm/vm.go @@ -17,6 +17,8 @@ type VM interface { SetProcess(pid int) error GetProcess() int GetCID() int + Transition(newState manager.ManagerState) error + State() string } //go:generate mockery --name Provider --output=./mocks --filename provider.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0" From 7ef25674c44c4dafe024da1041844b593c949b40 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Tue, 8 Oct 2024 17:28:17 +0300 Subject: [PATCH 44/83] add cli tests (#274) Signed-off-by: Sammy Oina --- cli/attestation.go | 16 ++-- cli/attestation_test.go | 190 +++++++++++++++++++++++++++++++++++++++ cli/backend_info_test.go | 136 ++++++++++++++++++++++++++++ cli/cache_test.go | 59 ++++++++++++ cli/checksum.go | 7 +- cli/checksum_test.go | 102 +++++++++++++++++++++ cli/keys.go | 6 +- cli/keys_test.go | 92 +++++++++++++++++++ 8 files changed, 596 insertions(+), 12 deletions(-) create mode 100644 cli/attestation_test.go create mode 100644 cli/backend_info_test.go create mode 100644 cli/cache_test.go create mode 100644 cli/checksum_test.go create mode 100644 cli/keys_test.go diff --git a/cli/attestation.go b/cli/attestation.go index 37c5d6c63..4dd4e596d 100644 --- a/cli/attestation.go +++ b/cli/attestation.go @@ -154,26 +154,30 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { Example: "get ", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { - log.Println("Getting attestation") + cmd.Println("Getting attestation") reportData, err := hex.DecodeString(args[0]) if err != nil { - log.Fatalf("attestation validation and verification failed with error: %s", err) + cmd.Printf("attestation validation and verification failed with error: %s", err) + return } if len(reportData) != agent.ReportDataSize { - log.Fatalf("report data must be a hex encoded string of length %d bytes", agent.ReportDataSize) + cmd.Printf("report data must be a hex encoded string of length %d bytes", agent.ReportDataSize) + return } result, err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData)) if err != nil { - log.Fatalf("Error retrieving attestation: %v", err) + cmd.Printf("Error retrieving attestation: %v", err) + return } if err = os.WriteFile(attestationFilePath, result, 0o644); err != nil { - log.Fatalf("Error saving attestation result: %v", err) + cmd.Printf("Error saving attestation result: %v", err) + return } - log.Println("Attestation result retrieved and saved successfully!") + cmd.Println("Attestation result retrieved and saved successfully!") }, } } diff --git a/cli/attestation_test.go b/cli/attestation_test.go new file mode 100644 index 000000000..ba0b9fe97 --- /dev/null +++ b/cli/attestation_test.go @@ -0,0 +1,190 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package cli + +import ( + "bytes" + "encoding/hex" + "fmt" + "os" + "testing" + + "github.com/google/go-sev-guest/proto/check" + "github.com/google/go-sev-guest/proto/sevsnp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/pkg/sdk/mocks" +) + +func TestNewAttestationCmd(t *testing.T) { + cli := &CLI{} + cmd := cli.NewAttestationCmd() + + assert.Equal(t, "attestation [command]", cmd.Use) + assert.Equal(t, "Get and validate attestations", cmd.Short) +} + +func TestNewGetAttestationCmd(t *testing.T) { + mockSDK := new(mocks.SDK) + cli := &CLI{agentSDK: mockSDK} + cmd := cli.NewGetAttestationCmd() + var buf bytes.Buffer + + cmd.SetOutput(&buf) + + assert.Equal(t, "get", cmd.Use) + assert.Equal(t, "Retrieve attestation information from agent. Report data expected in hex enoded string of length 64 bytes.", cmd.Short) + + reportData := bytes.Repeat([]byte{0x01}, agent.ReportDataSize) + mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(reportData)).Return([]byte("mock attestation"), nil) + + cmd.SetArgs([]string{hex.EncodeToString(reportData)}) + err := cmd.Execute() + assert.NoError(t, err) + + assert.Contains(t, buf.String(), "Attestation result retrieved and saved successfully!") + + os.Remove(attestationFilePath) +} + +func TestNewValidateAttestationValidationCmd(t *testing.T) { + cli := &CLI{} + cmd := cli.NewValidateAttestationValidationCmd() + + assert.Equal(t, "validate", cmd.Use) + assert.Equal(t, "Validate and verify attestation information. The report is provided as a file path.", cmd.Short) + + assert.Equal(t, fmt.Sprint(defaultMinimumTcb), cmd.Flag("minimum_tcb").Value.String()) + assert.Equal(t, fmt.Sprint(defaultMinimumLaunchTcb), cmd.Flag("minimum_lauch_tcb").Value.String()) + assert.Equal(t, fmt.Sprint(defaultGuestPolicy), cmd.Flag("guest_policy").Value.String()) + assert.Equal(t, fmt.Sprint(defaultMinimumGuestSvn), cmd.Flag("minimum_guest_svn").Value.String()) + assert.Equal(t, fmt.Sprint(defaultMinimumBuild), cmd.Flag("minimum_build").Value.String()) + assert.Equal(t, defaultCheckCrl, cmd.Flag("check_crl").Value.String() == "true") + assert.Equal(t, fmt.Sprint(defaultTimeout), cmd.Flag("timeout").Value.String()) + assert.Equal(t, fmt.Sprint(defaultMaxRetryDelay), cmd.Flag("max_retry_delay").Value.String()) +} + +func TestParseConfig(t *testing.T) { + cfgString = "" + err := parseConfig() + assert.NoError(t, err) + assert.NotNil(t, cfg.RootOfTrust) + assert.NotNil(t, cfg.Policy) + + cfgString = `{"rootOfTrust":{"product":"test_product"},"policy":{"minimumGuestSvn":1}}` + err = parseConfig() + assert.NoError(t, err) + assert.Equal(t, "test_product", cfg.RootOfTrust.Product) + assert.Equal(t, uint32(1), cfg.Policy.MinimumGuestSvn) + + cfgString = `{"invalid_json"` + err = parseConfig() + assert.Error(t, err) +} + +func TestParseHashes(t *testing.T) { + trustedAuthorHashes = []string{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"} + trustedIdKeyHashes = []string{"fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210"} + + cfg = check.Config{} + if cfg.Policy == nil { + cfg.Policy = &check.Policy{} + } + + err := parseHashes() + assert.NoError(t, err) + assert.Len(t, cfg.Policy.TrustedAuthorKeyHashes, 1) + assert.Len(t, cfg.Policy.TrustedIdKeyHashes, 1) + + trustedAuthorHashes = []string{"invalid_hash"} + err = parseHashes() + assert.Error(t, err) +} + +func TestParseFiles(t *testing.T) { + attestationFile = "test_attestation.bin" + authorKeyFile := "test_author_key.pem" + idKeyFile := "test_id_key.pem" + + err := os.WriteFile(attestationFile, []byte("test attestation"), 0o644) + assert.NoError(t, err) + err = os.WriteFile(authorKeyFile, []byte("test author key"), 0o644) + assert.NoError(t, err) + err = os.WriteFile(idKeyFile, []byte("test id key"), 0o644) + assert.NoError(t, err) + + trustedAuthorKeys = []string{authorKeyFile} + trustedIdKeys = []string{idKeyFile} + + err = parseFiles() + assert.NoError(t, err) + assert.Equal(t, []byte("test attestation"), attestation) + assert.Len(t, cfg.Policy.TrustedAuthorKeys, 1) + assert.Len(t, cfg.Policy.TrustedIdKeys, 1) + + os.Remove(attestationFile) + os.Remove(authorKeyFile) + os.Remove(idKeyFile) + + attestationFile = "non_existent_file.bin" + err = parseFiles() + assert.Error(t, err) +} + +func TestParseUints(t *testing.T) { + stepping = "10" + platformInfo = "0xFF" + + cfg = check.Config{} + if cfg.Policy == nil { + cfg.Policy = &check.Policy{ + Product: &sevsnp.SevProduct{}, + } + } + err := parseUints() + assert.NoError(t, err) + assert.Equal(t, uint32(10), cfg.Policy.Product.MachineStepping.Value) + assert.Equal(t, uint64(255), cfg.Policy.PlatformInfo.Value) + + stepping = "invalid" + err = parseUints() + assert.Error(t, err) + + stepping = "10" + platformInfo = "invalid" + err = parseUints() + assert.Error(t, err) +} + +func TestValidateInput(t *testing.T) { + cfg = check.Config{} + if cfg.Policy == nil { + cfg.Policy = &check.Policy{} + } + if cfg.RootOfTrust == nil { + cfg.RootOfTrust = &check.RootOfTrust{} + } + cfg.Policy.ReportData = make([]byte, 64) + cfg.Policy.HostData = make([]byte, 32) + cfg.Policy.FamilyId = make([]byte, 16) + cfg.Policy.ImageId = make([]byte, 16) + cfg.Policy.ReportId = make([]byte, 32) + cfg.Policy.ReportIdMa = make([]byte, 32) + cfg.Policy.Measurement = make([]byte, 48) + cfg.Policy.ChipId = make([]byte, 64) + + err := validateInput() + assert.NoError(t, err) + + cfg.Policy.ReportData = make([]byte, 32) + err = validateInput() + assert.Error(t, err) +} + +func TestGetBase(t *testing.T) { + assert.Equal(t, 16, getBase("0xFF")) + assert.Equal(t, 8, getBase("0o77")) + assert.Equal(t, 2, getBase("0b1010")) + assert.Equal(t, 10, getBase("123")) +} diff --git a/cli/backend_info_test.go b/cli/backend_info_test.go new file mode 100644 index 000000000..6bc2e6984 --- /dev/null +++ b/cli/backend_info_test.go @@ -0,0 +1,136 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package cli + +import ( + "encoding/base64" + "encoding/json" + "os" + "testing" + + "github.com/google/go-sev-guest/proto/check" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestChangeAttestationConfiguration(t *testing.T) { + tmpfile, err := os.CreateTemp("", "backend_info.json") + require.NoError(t, err) + defer os.Remove(tmpfile.Name()) + + initialConfig := AttestationConfiguration{ + SNPPolicy: &check.Policy{ + Measurement: make([]byte, measurementLength), + HostData: make([]byte, hostDataLength), + }, + } + + initialJSON, err := json.Marshal(initialConfig) + require.NoError(t, err) + err = os.WriteFile(tmpfile.Name(), initialJSON, 0o644) + require.NoError(t, err) + + tests := []struct { + name string + base64Data string + expectedLength int + field fieldType + expectError bool + errorType error + }{ + { + name: "Valid Measurement", + base64Data: base64.StdEncoding.EncodeToString(make([]byte, measurementLength)), + expectedLength: measurementLength, + field: measurementField, + expectError: false, + }, + { + name: "Valid Host Data", + base64Data: base64.StdEncoding.EncodeToString(make([]byte, hostDataLength)), + expectedLength: hostDataLength, + field: hostDataField, + expectError: false, + }, + { + name: "Invalid Base64", + base64Data: "Invalid Base64", + expectedLength: measurementLength, + field: measurementField, + expectError: true, + errorType: errDecode, + }, + { + name: "Invalid Data Length", + base64Data: base64.StdEncoding.EncodeToString(make([]byte, measurementLength-1)), + expectedLength: measurementLength, + field: measurementField, + expectError: true, + errorType: errDataLength, + }, + { + name: "Invalid Field Type", + base64Data: base64.StdEncoding.EncodeToString(make([]byte, measurementLength)), + expectedLength: measurementLength, + field: fieldType(999), + expectError: true, + errorType: errBackendField, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := changeAttestationConfiguration(tmpfile.Name(), tt.base64Data, tt.expectedLength, tt.field) + + if tt.expectError { + assert.Error(t, err) + assert.ErrorIs(t, err, tt.errorType) + } else { + assert.NoError(t, err) + + content, err := os.ReadFile(tmpfile.Name()) + require.NoError(t, err) + + var config AttestationConfiguration + err = json.Unmarshal(content, &config) + require.NoError(t, err) + + decodedData, _ := base64.StdEncoding.DecodeString(tt.base64Data) + if tt.field == measurementField { + assert.Equal(t, decodedData, config.SNPPolicy.Measurement) + } else if tt.field == hostDataField { + assert.Equal(t, decodedData, config.SNPPolicy.HostData) + } + } + }) + } +} + +func TestNewBackendCmd(t *testing.T) { + cli := &CLI{} + cmd := cli.NewBackendCmd() + + assert.Equal(t, "backend [command]", cmd.Use) + assert.Equal(t, "Change backend information", cmd.Short) + assert.NotNil(t, cmd.Run) +} + +func TestNewAddMeasurementCmd(t *testing.T) { + cli := &CLI{} + cmd := cli.NewAddMeasurementCmd() + + assert.Equal(t, "measurement", cmd.Use) + assert.Equal(t, "Add measurement to the backend info file. The value should be in base64. The second parameter is backend_info.json file", cmd.Short) + assert.Equal(t, "measurement ", cmd.Example) + assert.NotNil(t, cmd.Run) +} + +func TestNewAddHostDataCmd(t *testing.T) { + cli := &CLI{} + cmd := cli.NewAddHostDataCmd() + + assert.Equal(t, "hostdata", cmd.Use) + assert.Equal(t, "Add host data to the backend info file. The value should be in base64. The second parameter is backend_info.json file", cmd.Short) + assert.Equal(t, "hostdata ", cmd.Example) + assert.NotNil(t, cmd.Run) +} diff --git a/cli/cache_test.go b/cli/cache_test.go new file mode 100644 index 000000000..f257cc378 --- /dev/null +++ b/cli/cache_test.go @@ -0,0 +1,59 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package cli + +import ( + "bytes" + "os" + "path" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewCABundleCmd(t *testing.T) { + cli := &CLI{} + tempDir, err := os.MkdirTemp("", "ca-bundle-test") + assert.NoError(t, err) + defer os.RemoveAll(tempDir) + + manifestContent := []byte(`{"root_of_trust": {"product": "Milan"}}`) + manifestPath := path.Join(tempDir, "manifest.json") + err = os.WriteFile(manifestPath, manifestContent, 0o644) + assert.NoError(t, err) + + cmd := cli.NewCABundleCmd(tempDir) + cmd.SetArgs([]string{manifestPath}) + output := &bytes.Buffer{} + cmd.SetOutput(output) + err = cmd.Execute() + + assert.NoError(t, err) + + expectedFilePath := path.Join(tempDir, "Milan", caBundleName) + _, err = os.Stat(expectedFilePath) + assert.NoError(t, err) + + content, err := os.ReadFile(expectedFilePath) + assert.NoError(t, err) + assert.NotNil(t, content) +} + +func TestSaveToFile(t *testing.T) { + tempDir, err := os.MkdirTemp("", "save-to-file-test") + assert.NoError(t, err) + defer os.RemoveAll(tempDir) + + filePath := path.Join(tempDir, "test-file.txt") + content := []byte("test content") + + err = saveToFile(filePath, content) + assert.NoError(t, err) + + savedContent, err := os.ReadFile(filePath) + assert.NoError(t, err) + assert.Equal(t, content, savedContent) + + _, err = os.Stat(filePath) + assert.NoError(t, err) +} diff --git a/cli/checksum.go b/cli/checksum.go index c9b2cb0cb..ab234ca92 100644 --- a/cli/checksum.go +++ b/cli/checksum.go @@ -3,8 +3,6 @@ package cli import ( - "log" - "github.com/spf13/cobra" "github.com/ultravioletrs/cocos/internal" ) @@ -20,10 +18,11 @@ func (cli *CLI) NewFileHashCmd() *cobra.Command { hash, err := internal.ChecksumHex(path) if err != nil { - log.Fatalf("Error computing hash: %v", err) + cmd.Printf("Error computing hash: %v", err) + return } - log.Println("Hash of file:", hash) + cmd.Println("Hash of file:", hash) }, } } diff --git a/cli/checksum_test.go b/cli/checksum_test.go new file mode 100644 index 000000000..93d0185c9 --- /dev/null +++ b/cli/checksum_test.go @@ -0,0 +1,102 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package cli + +import ( + "bytes" + "os" + "strings" + "testing" + + "github.com/ultravioletrs/cocos/internal" +) + +func TestNewFileHashCmd(t *testing.T) { + cli := &CLI{} + cmd := cli.NewFileHashCmd() + + if cmd.Use != "checksum" { + t.Errorf("Expected Use to be 'checksum', got %s", cmd.Use) + } + + if cmd.Short != "Compute the sha3-256 hash of a file" { + t.Errorf("Expected Short to be 'Compute the sha3-256 hash of a file', got %s", cmd.Short) + } + + if cmd.Example != "checksum " { + t.Errorf("Expected Example to be 'checksum ', got %s", cmd.Example) + } +} + +func TestNewFileHashCmdRun(t *testing.T) { + cli := &CLI{} + cmd := cli.NewFileHashCmd() + + content := []byte("test content") + tmpfile, err := os.CreateTemp("", "example") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpfile.Name()) + + if _, err := tmpfile.Write(content); err != nil { + t.Fatal(err) + } + if err := tmpfile.Close(); err != nil { + t.Fatal(err) + } + + var output bytes.Buffer + cmd.SetOut(&output) + cmd.SetErr(&output) + + cmd.SetArgs([]string{tmpfile.Name()}) + err = cmd.Execute() + if err != nil { + t.Fatalf("Error executing command: %v", err) + } + + expectedHash, err := internal.ChecksumHex(tmpfile.Name()) + if err != nil { + t.Fatalf("Error computing expected hash: %v", err) + } + + if !strings.Contains(output.String(), expectedHash) { + t.Errorf("Expected output to contain hash %s, got %s", expectedHash, output.String()) + } +} + +func TestNewFileHashCmdInvalidArgs(t *testing.T) { + cli := &CLI{} + cmd := cli.NewFileHashCmd() + + err := cmd.Execute() + if err == nil { + t.Error("Expected error when executing without arguments, got nil") + } + + cmd.SetArgs([]string{"file1", "file2"}) + err = cmd.Execute() + if err == nil { + t.Error("Expected error when executing with too many arguments, got nil") + } +} + +func TestNewFileHashCmdNonExistentFile(t *testing.T) { + cli := &CLI{} + cmd := cli.NewFileHashCmd() + + var output bytes.Buffer + cmd.SetOut(&output) + cmd.SetErr(&output) + + cmd.SetArgs([]string{"non_existent_file.txt"}) + err := cmd.Execute() + if err != nil { + t.Fatalf("Error executing command: %v", err) + } + + if !strings.Contains(output.String(), "Error computing hash") { + t.Errorf("Expected output to contain 'Error computing hash', got %s", output.String()) + } +} diff --git a/cli/keys.go b/cli/keys.go index 418f1c50f..424aeeed7 100644 --- a/cli/keys.go +++ b/cli/keys.go @@ -25,6 +25,8 @@ const ( publicKeyType = "PUBLIC KEY" publicKeyFile = "public.pem" privateKeyFile = "private.pem" + ECDSA = "ecdsa" + ED25519 = "ed25519" ) var KeyType string @@ -39,7 +41,7 @@ func (cli *CLI) NewKeysCmd() *cobra.Command { Args: cobra.ExactArgs(0), Run: func(cmd *cobra.Command, args []string) { switch KeyType { - case "ecdsa": + case ECDSA: privEcdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { log.Fatalf("Error generating keys: %v", err) @@ -52,7 +54,7 @@ func (cli *CLI) NewKeysCmd() *cobra.Command { generateAndWriteKeys(privEcdsaKey, pubKeyBytes, ecdsaKeyType) - case "ed25519": + case ED25519: pubEd25519Key, privEd25519Key, err := ed25519.GenerateKey(rand.Reader) if err != nil { log.Fatalf("Error generating keys: %v", err) diff --git a/cli/keys_test.go b/cli/keys_test.go new file mode 100644 index 000000000..322314145 --- /dev/null +++ b/cli/keys_test.go @@ -0,0 +1,92 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package cli + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "os" + "testing" +) + +func TestNewKeysCmd(t *testing.T) { + cli := &CLI{} + cmd := cli.NewKeysCmd() + + if cmd.Use != "keys" { + t.Errorf("Expected Use to be 'keys', got %s", cmd.Use) + } + + if cmd.Short != "Generate a new public/private key pair" { + t.Errorf("Unexpected Short description: %s", cmd.Short) + } +} + +func TestGenerateAndWriteKeys(t *testing.T) { + tests := []struct { + name string + keyType string + }{ + {"RSA", "rsa"}, + {"ECDSA", "ecdsa"}, + {"ED25519", "ed25519"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + KeyType = tt.keyType + cmd := (&CLI{}).NewKeysCmd() + cmd.Run(cmd, []string{}) + + if _, err := os.Stat(privateKeyFile); os.IsNotExist(err) { + t.Errorf("Private key file was not created") + } + if _, err := os.Stat(publicKeyFile); os.IsNotExist(err) { + t.Errorf("Public key file was not created") + } + + privKeyData, err := os.ReadFile(privateKeyFile) + if err != nil { + t.Fatalf("Failed to read private key file: %v", err) + } + privPem, _ := pem.Decode(privKeyData) + if privPem == nil { + t.Fatalf("Failed to decode private key PEM") + } + + var privKey interface{} + switch tt.keyType { + case "rsa": + privKey, err = x509.ParsePKCS1PrivateKey(privPem.Bytes) + case "ecdsa": + privKey, err = x509.ParseECPrivateKey(privPem.Bytes) + case "ed25519": + privKey, err = x509.ParsePKCS8PrivateKey(privPem.Bytes) + } + if err != nil { + t.Fatalf("Failed to parse private key: %v", err) + } + + switch tt.keyType { + case "rsa": + if _, ok := privKey.(*rsa.PrivateKey); !ok { + t.Errorf("Expected RSA private key, got %T", privKey) + } + case "ecdsa": + if _, ok := privKey.(*ecdsa.PrivateKey); !ok { + t.Errorf("Expected ECDSA private key, got %T", privKey) + } + case "ed25519": + if _, ok := privKey.(ed25519.PrivateKey); !ok { + t.Errorf("Expected ED25519 private key, got %T", privKey) + } + } + + os.Remove(privateKeyFile) + os.Remove(publicKeyFile) + }) + } +} From fb0fbaeb9a989025b39b09523df3f28935dfb395 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Tue, 8 Oct 2024 18:11:37 +0300 Subject: [PATCH 45/83] COCOS-253 - Improve CLI error handling (#277) * decode errors Signed-off-by: Sammy Oina * standardise error formatting Signed-off-by: Sammy Oina * fix failing tests Signed-off-by: Sammy Oina * add errors tests Signed-off-by: Sammy Oina * pass lint Signed-off-by: Sammy Oina * add test cases Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- cli/algorithm_test.go | 89 ++++++++++++++++++++++-------------- cli/algorithms.go | 23 +++++----- cli/attestation.go | 43 ++++++++++------- cli/backend_info.go | 7 +-- cli/cache.go | 18 ++++++-- cli/checksum.go | 2 +- cli/datasets.go | 43 ++++++++--------- cli/datasets_test.go | 73 +++++++++++++++++++---------- cli/errors.go | 46 +++++++++++++++++++ cli/errors_test.go | 104 ++++++++++++++++++++++++++++++++++++++++++ cli/keys.go | 56 +++++++++++++---------- cli/result.go | 20 ++++---- cli/result_test.go | 75 ++++++++++++++++++++---------- cli/sdk.go | 2 + cmd/cli/main.go | 85 +++++++++++++++++----------------- pkg/sdk/agent.go | 14 +----- pkg/sdk/agent_test.go | 21 ++------- 17 files changed, 478 insertions(+), 243 deletions(-) create mode 100644 cli/errors.go create mode 100644 cli/errors_test.go diff --git a/cli/algorithm_test.go b/cli/algorithm_test.go index fe853e195..5a82fb9de 100644 --- a/cli/algorithm_test.go +++ b/cli/algorithm_test.go @@ -9,7 +9,6 @@ import ( "crypto/x509" "encoding/pem" "errors" - "log" "os" "testing" @@ -20,14 +19,6 @@ import ( const algorithmFile = "test_algo_file.py" -func captureLogOutput(f func()) string { - var buf bytes.Buffer - log.SetOutput(&buf) - defer log.SetOutput(os.Stderr) - f() - return buf.String() -} - func generateRSAPrivateKeyFile(fileName string) error { privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { @@ -65,13 +56,13 @@ func TestAlgorithmCmd_Success(t *testing.T) { require.NoError(t, err) cmd := testCLI.NewAlgorithmCmd() - output := captureLogOutput(func() { - cmd.SetArgs([]string{algorithmFile, privateKeyFile}) - err = cmd.Execute() - require.NoError(t, err) - }) + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{algorithmFile, privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) - require.Contains(t, output, "Successfully uploaded algorithm") + require.Contains(t, buf.String(), "Successfully uploaded algorithm") t.Cleanup(func() { os.Remove(privateKeyFile) os.Remove(algorithmFile) @@ -84,14 +75,14 @@ func TestAlgorithmCmd_MissingAlgorithmFile(t *testing.T) { testCLI := New(mockSDK) cmd := testCLI.NewAlgorithmCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) - output := captureLogOutput(func() { - cmd.SetArgs([]string{"non_existent_algo_file.py", privateKeyFile}) - err := cmd.Execute() - require.NoError(t, err) - }) + cmd.SetArgs([]string{"non_existent_algo_file.py", privateKeyFile}) + err := cmd.Execute() + require.NoError(t, err) - require.Contains(t, output, "Error reading algorithm file") + require.Contains(t, buf.String(), "Error reading algorithm file") } func TestAlgorithmCmd_MissingPrivateKeyFile(t *testing.T) { @@ -103,14 +94,13 @@ func TestAlgorithmCmd_MissingPrivateKeyFile(t *testing.T) { require.NoError(t, err) cmd := testCLI.NewAlgorithmCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{algorithmFile, "non_existent_private_key.pem"}) + err = cmd.Execute() + require.NoError(t, err) - output := captureLogOutput(func() { - cmd.SetArgs([]string{algorithmFile, "non_existent_private_key.pem"}) - err = cmd.Execute() - require.NoError(t, err) - }) - - require.Contains(t, output, "Error reading private key file") + require.Contains(t, buf.String(), "Error reading private key file") t.Cleanup(func() { os.Remove(algorithmFile) }) @@ -128,17 +118,48 @@ func TestAlgorithmCmd_UploadFailure(t *testing.T) { require.NoError(t, err) cmd := testCLI.NewAlgorithmCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) - output := captureLogOutput(func() { - cmd.SetArgs([]string{algorithmFile, privateKeyFile}) - err = cmd.Execute() - require.NoError(t, err) - }) + cmd.SetArgs([]string{algorithmFile, privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) - require.Contains(t, output, "Failed to upload algorithm") + require.Contains(t, buf.String(), "Failed to upload algorithm") t.Cleanup(func() { os.Remove(privateKeyFile) os.Remove(algorithmFile) }) } + +func TestAlgorithmCmd_InvalidPrivateKey(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil) + testCLI := New(mockSDK) + + err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644) + require.NoError(t, err) + + privKeyFile, err := os.Create(privateKeyFile) + require.NoError(t, err) + defer privKeyFile.Close() + + _, err = privKeyFile.WriteString("invalid private key") + require.NoError(t, err) + + cmd := testCLI.NewAlgorithmCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + + cmd.SetArgs([]string{algorithmFile, privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) + + require.Contains(t, buf.String(), "Error decoding private key") + + t.Cleanup(func() { + os.Remove(algorithmFile) + os.Remove(privateKeyFile) + }) +} diff --git a/cli/algorithms.go b/cli/algorithms.go index e0a1ab7cb..589188678 100644 --- a/cli/algorithms.go +++ b/cli/algorithms.go @@ -5,7 +5,6 @@ package cli import ( "context" "encoding/pem" - "log" "os" "github.com/fatih/color" @@ -32,12 +31,11 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { Run: func(cmd *cobra.Command, args []string) { algorithmFile := args[0] - log.Println("Uploading algorithm file:", algorithmFile) + cmd.Println("Uploading algorithm file:", algorithmFile) algorithm, err := os.ReadFile(algorithmFile) if err != nil { - msg := color.New(color.FgRed).Sprintf("Error reading algorithm file: %v ❌ ", err) - log.Println(msg) + printError(cmd, "Error reading algorithm file: %v ❌ ", err) return } @@ -45,8 +43,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { if requirementsFile != "" { req, err = os.ReadFile(requirementsFile) if err != nil { - msg := color.New(color.FgRed).Sprintf("Error reading requirments file: %v ❌ ", err) - log.Println(msg) + printError(cmd, "Error reading requirments file: %v ❌ ", err) return } } @@ -58,24 +55,26 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { privKeyFile, err := os.ReadFile(args[1]) if err != nil { - msg := color.New(color.FgRed).Sprintf("Error reading private key file: %v ❌ ", err.Error()) - log.Println(msg) + printError(cmd, "Error reading private key file: %v ❌ ", err) return } pemBlock, _ := pem.Decode(privKeyFile) - privKey := decodeKey(pemBlock) + privKey, err := decodeKey(pemBlock) + if err != nil { + printError(cmd, "Error decoding private key: %v ❌ ", err) + return + } ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string))) if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algoReq, privKey); err != nil { - msg := color.New(color.FgRed).Sprintf("Failed to upload algorithm due to error: %v ❌ ", err.Error()) - log.Println(msg) + printError(cmd, "Failed to upload algorithm due to error: %v ❌ ", err) return } - log.Println(color.New(color.FgGreen).Sprint("Successfully uploaded algorithm! ✔ ")) + cmd.Println(color.New(color.FgGreen).Sprint("Successfully uploaded algorithm! ✔ ")) }, } diff --git a/cli/attestation.go b/cli/attestation.go index 4dd4e596d..fd46e9423 100644 --- a/cli/attestation.go +++ b/cli/attestation.go @@ -5,12 +5,12 @@ package cli import ( "encoding/hex" "fmt" - "log" "os" "strconv" "strings" "time" + "github.com/fatih/color" "github.com/google/go-sev-guest/abi" "github.com/google/go-sev-guest/proto/check" "github.com/google/go-sev-guest/proto/sevsnp" @@ -158,22 +158,23 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { reportData, err := hex.DecodeString(args[0]) if err != nil { - cmd.Printf("attestation validation and verification failed with error: %s", err) + printError(cmd, "Error decoding report data: %v ❌ ", err) return } if len(reportData) != agent.ReportDataSize { - cmd.Printf("report data must be a hex encoded string of length %d bytes", agent.ReportDataSize) + msg := color.New(color.FgRed).Sprintf("report data must be a hex encoded string of length %d bytes ❌ ", agent.ReportDataSize) + cmd.Println(msg) return } result, err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData)) if err != nil { - cmd.Printf("Error retrieving attestation: %v", err) + printError(cmd, "Failed to get attestation due to error: %v ❌ ", err) return } if err = os.WriteFile(attestationFilePath, result, 0o644); err != nil { - cmd.Printf("Error saving attestation result: %v", err) + printError(cmd, "Error saving attestation result: %v ❌ ", err) return } @@ -189,37 +190,45 @@ func (cli *CLI) NewValidateAttestationValidationCmd() *cobra.Command { Example: "validate ", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { - log.Println("Checking attestation") + cmd.Println("Checking attestation") attestationFile = string(args[0]) if err := parseConfig(); err != nil { - log.Fatalf("attestation validation and verification failed with error: %s", err) + printError(cmd, "Error parsing config: %v ❌ ", err) + return } if err := parseHashes(); err != nil { - log.Fatalf("attestation validation and verification failed with error: %s", err) + printError(cmd, "Error parsing hashes: %v ❌ ", err) + return } if err := parseFiles(); err != nil { - log.Fatalf("attestation validation and verification failed with error: %s", err) + printError(cmd, "Error parsing files: %v ❌ ", err) + return } // This format is the attestation report in AMD's specified ABI format, immediately // followed by the certificate table bytes. if len(attestation) < abi.ReportSize { - log.Fatalf("attestation contents too small (0x%x bytes). Want at least 0x%x bytes", len(attestation), abi.ReportSize) + msg := color.New(color.FgRed).Sprintf("attestation contents too small (0x%x bytes). Want at least 0x%x bytes ❌ ", len(attestation), abi.ReportSize) + cmd.Println(msg) + return } if err := parseUints(); err != nil { - log.Fatalf("attestation validation and verification failed with error: %s", err) + printError(cmd, "Error parsing uints: %v ❌ ", err) + return } cfg.Policy.Vmpl = wrapperspb.UInt32(0) if err := validateInput(); err != nil { - log.Fatalf("attestation validation and verification failed with error: %s", err) + printError(cmd, "Error validating input: %v ❌ ", err) + return } if err := verifyAndValidateAttestation(attestation); err != nil { - log.Fatalf("attestation validation and verification failed with error: %s", err) + printError(cmd, "Attestation validation and verification failed with error: %v ❌ ", err) + return } - log.Println("Attestation validation and verification is successful!") + cmd.Println("Attestation validation and verification is successful!") }, } cmd.Flags().StringVar( @@ -398,11 +407,13 @@ func (cli *CLI) NewValidateAttestationValidationCmd() *cobra.Command { ) if err := cmd.MarkFlagRequired("report_data"); err != nil { - log.Fatalf("Failed to mark flag as required: %s", err) + printError(cmd, "Failed to mark flag as required: %v ❌ ", err) + return nil } if err := cmd.MarkFlagRequired("product"); err != nil { - log.Fatalf("Failed to mark flag as required: %s", err) + printError(cmd, "Failed to mark flag as required: %v ❌ ", err) + return nil } return cmd diff --git a/cli/backend_info.go b/cli/backend_info.go index 481d1d0dc..628987d44 100644 --- a/cli/backend_info.go +++ b/cli/backend_info.go @@ -6,7 +6,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "log" "os" "github.com/absmach/magistrala/pkg/errors" @@ -83,7 +82,8 @@ func (cli *CLI) NewAddMeasurementCmd() *cobra.Command { Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { if err := changeAttestationConfiguration(args[1], args[0], measurementLength, measurementField); err != nil { - log.Fatalf("Error could not change measurement data %v", err) + printError(cmd, "Error could not change measurement data: %v ❌ ", err) + return } }, } @@ -97,7 +97,8 @@ func (cli *CLI) NewAddHostDataCmd() *cobra.Command { Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { if err := changeAttestationConfiguration(args[1], args[0], hostDataLength, hostDataField); err != nil { - log.Fatalf("Error could not change host data %v", err) + printError(cmd, "Error could not change host data: %v ❌ ", err) + return } }, } diff --git a/cli/cache.go b/cli/cache.go index 957362d6a..20bf52837 100644 --- a/cli/cache.go +++ b/cli/cache.go @@ -3,7 +3,7 @@ package cli import ( - "log" + "fmt" "os" "path" @@ -29,7 +29,8 @@ func (cli *CLI) NewCABundleCmd(fileSavePath string) *cobra.Command { attestationConfiguration := grpc.AttestationConfiguration{} err := grpc.ReadBackendInfo(args[0], &attestationConfiguration) if err != nil { - log.Fatalf("Error while reading manifest: %v", err) + printError(cmd, "Error while reading manifest: %v ❌ ", err) + return } product := attestationConfiguration.RootOfTrust.Product @@ -39,17 +40,24 @@ func (cli *CLI) NewCABundleCmd(fileSavePath string) *cobra.Command { bundle, err := getter.Get(caURL) if err != nil { - log.Fatalf("Error fetching ARK and ASK from AMD KDS for product: %s, error: %v", product, err) + message := fmt.Sprintf("Error fetching ARK and ASK from AMD KDS for product: %s", product) + message += ", error: %v ❌ " + printError(cmd, message, err) + return } err = os.MkdirAll(path.Join(fileSavePath, product), filePermisionKeys) if err != nil { - log.Fatalf("Error while creating directory for product name %s, error: %v", product, err) + message := fmt.Sprintf("Error while creating directory for product name %s", product) + message += ", error: %v ❌ " + printError(cmd, message, err) + return } bundleFilePath := path.Join(fileSavePath, product, caBundleName) if err = saveToFile(bundleFilePath, bundle); err != nil { - log.Fatalf("Error while saving ARK-ASK to file: %v", err) + printError(cmd, "Error while saving ARK-ASK to file: %v ❌ ", err) + return } }, } diff --git a/cli/checksum.go b/cli/checksum.go index ab234ca92..bcbadfcc6 100644 --- a/cli/checksum.go +++ b/cli/checksum.go @@ -18,7 +18,7 @@ func (cli *CLI) NewFileHashCmd() *cobra.Command { hash, err := internal.ChecksumHex(path) if err != nil { - cmd.Printf("Error computing hash: %v", err) + printError(cmd, "Error computing hash: %v ❌ ", err) return } diff --git a/cli/datasets.go b/cli/datasets.go index 49c8c4279..49cbbde13 100644 --- a/cli/datasets.go +++ b/cli/datasets.go @@ -6,10 +6,10 @@ import ( "context" "crypto/x509" "encoding/pem" - "log" "os" "path" + "github.com/absmach/magistrala/pkg/errors" "github.com/fatih/color" "github.com/spf13/cobra" "github.com/ultravioletrs/cocos/agent" @@ -28,12 +28,11 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { Run: func(cmd *cobra.Command, args []string) { datasetPath := args[0] - log.Println("Uploading dataset:", datasetPath) + cmd.Println("Uploading dataset:", datasetPath) f, err := os.Stat(datasetPath) if err != nil { - msg := color.New(color.FgRed).Sprintf("Error reading dataset file: %v ❌ ", err) - log.Println(msg) + printError(cmd, "Error reading dataset file: %v ❌ ", err) return } @@ -42,15 +41,13 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { if f.IsDir() { dataset, err = internal.ZipDirectoryToMemory(datasetPath) if err != nil { - msg := color.New(color.FgRed).Sprintf("Error zipping dataset directory: %v ❌ ", err) - log.Println(msg) + printError(cmd, "Error zipping dataset directory: %v ❌ ", err) return } } else { dataset, err = os.ReadFile(datasetPath) if err != nil { - msg := color.New(color.FgRed).Sprintf("Error reading dataset file: %v ❌ ", err) - log.Println(msg) + printError(cmd, "Error reading dataset file: %v ❌ ", err) return } } @@ -62,23 +59,25 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { privKeyFile, err := os.ReadFile(args[1]) if err != nil { - msg := color.New(color.FgRed).Sprintf("Error reading private key file: %v ❌ ", err) - log.Println(msg) + printError(cmd, "Error reading private key file: %v ❌ ", err) return } pemBlock, _ := pem.Decode(privKeyFile) - privKey := decodeKey(pemBlock) + privKey, err := decodeKey(pemBlock) + if err != nil { + printError(cmd, "Error decoding private key: %v ❌ ", err) + return + } ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string))) if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataReq, privKey); err != nil { - msg := color.New(color.FgRed).Sprintf("Failed to upload dataset due to error: %v ❌ ", err.Error()) - log.Println(msg) + printError(cmd, "Failed to upload dataset due to error: %v ❌ ", err) return } - log.Println(color.New(color.FgGreen).Sprint("Successfully uploaded dataset! ✔ ")) + cmd.Println(color.New(color.FgGreen).Sprint("Successfully uploaded dataset! ✔ ")) }, } @@ -86,26 +85,28 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { return cmd } -func decodeKey(b *pem.Block) interface{} { +func decodeKey(b *pem.Block) (interface{}, error) { + if b == nil { + return nil, errors.New("error decoding key") + } switch b.Type { case rsaKeyType: privKey, err := x509.ParsePKCS8PrivateKey(b.Bytes) if err != nil { privKey, err = x509.ParsePKCS1PrivateKey(b.Bytes) if err != nil { - log.Fatalf("Error parsing private key: %v", err) + return nil, err } } - return privKey + return privKey, nil case ecdsaKeyType: privKey, err := x509.ParseECPrivateKey(b.Bytes) if err != nil { - log.Fatalf("Error parsing private key: %v", err) + return nil, err } - return privKey + return privKey, nil default: - log.Fatalf("Error decoding key") - return nil + return nil, errors.New("error decoding key") } } diff --git a/cli/datasets_test.go b/cli/datasets_test.go index fade39c27..eec9ddc55 100644 --- a/cli/datasets_test.go +++ b/cli/datasets_test.go @@ -3,6 +3,7 @@ package cli import ( + "bytes" "errors" "os" "testing" @@ -38,14 +39,14 @@ func TestDatasetsCmd_Success(t *testing.T) { require.NoError(t, err) cmd := testCLI.NewDatasetsCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) - output := captureLogOutput(func() { - cmd.SetArgs([]string{datasetFile, privateKeyFile}) - err = cmd.Execute() - require.NoError(t, err) - }) + cmd.SetArgs([]string{datasetFile, privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) - require.Contains(t, output, "Successfully uploaded dataset") + require.Contains(t, buf.String(), "Successfully uploaded dataset") mockSDK.AssertCalled(t, "Data", mock.Anything, mock.Anything, mock.Anything) t.Cleanup(func() { @@ -60,14 +61,14 @@ func TestDatasetsCmd_MissingDatasetFile(t *testing.T) { testCLI := New(mockSDK) cmd := testCLI.NewDatasetsCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) - output := captureLogOutput(func() { - cmd.SetArgs([]string{"non_existent_dataset.txt", privateKeyFile}) - err := cmd.Execute() - require.NoError(t, err) - }) + cmd.SetArgs([]string{"non_existent_dataset.txt", privateKeyFile}) + err := cmd.Execute() + require.NoError(t, err) - require.Contains(t, output, "Error reading dataset file") + require.Contains(t, buf.String(), "Error reading dataset file") } func TestDatasetsCmd_MissingPrivateKeyFile(t *testing.T) { @@ -79,14 +80,14 @@ func TestDatasetsCmd_MissingPrivateKeyFile(t *testing.T) { require.NoError(t, err) cmd := testCLI.NewDatasetsCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) - output := captureLogOutput(func() { - cmd.SetArgs([]string{datasetFile, "non_existent_private_key.pem"}) - err = cmd.Execute() - require.NoError(t, err) - }) + cmd.SetArgs([]string{datasetFile, "non_existent_private_key.pem"}) + err = cmd.Execute() + require.NoError(t, err) - require.Contains(t, output, "Error reading private key file") + require.Contains(t, buf.String(), "Error reading private key file") t.Cleanup(func() { os.Remove(datasetFile) }) @@ -104,14 +105,40 @@ func TestDatasetsCmd_UploadFailure(t *testing.T) { require.NoError(t, err) cmd := testCLI.NewDatasetsCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) - output := captureLogOutput(func() { - cmd.SetArgs([]string{datasetFile, privateKeyFile}) - err = cmd.Execute() - require.NoError(t, err) + cmd.SetArgs([]string{datasetFile, privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) + + require.Contains(t, buf.String(), "Failed to upload dataset due to error") + t.Cleanup(func() { + os.Remove(datasetFile) + os.Remove(privateKeyFile) }) +} + +func TestDatasetsCmd_InvalidPrivateKey(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil) + testCLI := New(mockSDK) + + datasetFile, err := createTempDatasetFile("test dataset content") + require.NoError(t, err) + + err = os.WriteFile(privateKeyFile, []byte("invalid private key"), 0o644) + require.NoError(t, err) + + cmd := testCLI.NewDatasetsCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + + cmd.SetArgs([]string{datasetFile, privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) - require.Contains(t, output, "Failed to upload dataset due to error") + require.Contains(t, buf.String(), "Error decoding private key") t.Cleanup(func() { os.Remove(datasetFile) os.Remove(privateKeyFile) diff --git a/cli/errors.go b/cli/errors.go new file mode 100644 index 000000000..0262c29a1 --- /dev/null +++ b/cli/errors.go @@ -0,0 +1,46 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package cli + +import ( + "github.com/absmach/magistrala/pkg/errors" + "github.com/fatih/color" + "github.com/spf13/cobra" + "github.com/ultravioletrs/cocos/agent/auth" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +var ( + errAgentUnavailable = errors.New("agent is unavailable on the current address") + errDigitalSignatureVerificationFailed = errors.New("digital signature verification failed, check the provided public key") +) + +func decodeErros(err error) error { + statusErr, ok := status.FromError(err) + if ok { + switch statusErr.Code() { + case codes.PermissionDenied: + return errDigitalSignatureVerificationFailed + case codes.Unavailable: + return errAgentUnavailable + case codes.Unknown: + return err + } + } + switch { + case errors.Contains(err, auth.ErrSignatureVerificationFailed): + return auth.ErrSignatureVerificationFailed + + default: + return err + } +} + +func printError(cmd *cobra.Command, message string, err error) { + if !Verbose { + err = decodeErros(err) + } + msg := color.New(color.FgRed).Sprintf(message, err) + cmd.Println(msg) +} diff --git a/cli/errors_test.go b/cli/errors_test.go new file mode 100644 index 000000000..51e408b85 --- /dev/null +++ b/cli/errors_test.go @@ -0,0 +1,104 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package cli + +import ( + "bytes" + "errors" + "testing" + + mgerrors "github.com/absmach/magistrala/pkg/errors" + "github.com/fatih/color" + "github.com/spf13/cobra" + "github.com/ultravioletrs/cocos/agent/auth" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestDecodeErros(t *testing.T) { + tests := []struct { + name string + input error + expected error + }{ + { + name: "Permission Denied", + input: status.Error(codes.PermissionDenied, "permission denied"), + expected: errDigitalSignatureVerificationFailed, + }, + { + name: "Unavailable", + input: status.Error(codes.Unavailable, "service unavailable"), + expected: errAgentUnavailable, + }, + { + name: "Unknown", + input: status.Error(codes.Unknown, "unknown error"), + expected: status.Error(codes.Unknown, "unknown error"), + }, + { + name: "Signature Verification Failed", + input: mgerrors.Wrap(auth.ErrSignatureVerificationFailed, errors.New("wrapped error")), + expected: auth.ErrSignatureVerificationFailed, + }, + { + name: "Other Error", + input: errors.New("other error"), + expected: errors.New("other error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := decodeErros(tt.input) + if result.Error() != tt.expected.Error() { + t.Errorf("decodeErros(%v) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestPrintError(t *testing.T) { + // Save the original color.NoColor value and restore it after the test + origNoColor := color.NoColor + color.NoColor = true + defer func() { color.NoColor = origNoColor }() + + tests := []struct { + name string + message string + err error + verbose bool + expected string + }{ + { + name: "Non-verbose mode", + message: "Error: %s", + err: status.Error(codes.PermissionDenied, "permission denied"), + verbose: false, + expected: "Error: digital signature verification failed, check the provided public key\n", + }, + { + name: "Verbose mode", + message: "Error: %s", + err: status.Error(codes.PermissionDenied, "permission denied"), + verbose: true, + expected: "Error: rpc error: code = PermissionDenied desc = permission denied\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Verbose = tt.verbose + cmd := &cobra.Command{} + buf := new(bytes.Buffer) + cmd.SetOut(buf) + + printError(cmd, tt.message, tt.err) + + if got := buf.String(); got != tt.expected { + t.Errorf("printError() output = %q, want %q", got, tt.expected) + } + }) + } +} diff --git a/cli/keys.go b/cli/keys.go index 424aeeed7..5b12acd41 100644 --- a/cli/keys.go +++ b/cli/keys.go @@ -10,9 +10,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" - "log" "os" - "reflect" "github.com/spf13/cobra" ) @@ -44,48 +42,64 @@ func (cli *CLI) NewKeysCmd() *cobra.Command { case ECDSA: privEcdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { - log.Fatalf("Error generating keys: %v", err) + printError(cmd, "Error generating keys: %v ❌ ", err) + return } pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privEcdsaKey.PublicKey) if err != nil { - log.Fatalf("Error marshalling public key: %v", err) + printError(cmd, "Error marshalling public key: %v ❌ ", err) + return } - generateAndWriteKeys(privEcdsaKey, pubKeyBytes, ecdsaKeyType) + if err := generateAndWriteKeys(privEcdsaKey, pubKeyBytes, ecdsaKeyType); err != nil { + printError(cmd, "Error generating and writing keys: %v ❌ ", err) + return + } case ED25519: pubEd25519Key, privEd25519Key, err := ed25519.GenerateKey(rand.Reader) if err != nil { - log.Fatalf("Error generating keys: %v", err) + printError(cmd, "Error generating keys: %v ❌ ", err) + return } pubKey, err := x509.MarshalPKIXPublicKey(pubEd25519Key) if err != nil { - log.Fatalf("Error marshalling public key: %v", err) + printError(cmd, "Error marshalling public key: %v ❌ ", err) + return + } + if err := generateAndWriteKeys(privEd25519Key, pubKey, ed25519KeyType); err != nil { + printError(cmd, "Error generating and writing keys: %v ❌ ", err) + return } - generateAndWriteKeys(privEd25519Key, pubKey, ed25519KeyType) - // Default to RSA default: privKey, err := rsa.GenerateKey(rand.Reader, keyBitSize) if err != nil { - log.Fatalf("Error generating keys: %v", err) + printError(cmd, "Error generating keys: %v ❌ ", err) + return } pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey) if err != nil { - log.Fatalf("Error marshalling public key: %v", err) + printError(cmd, "Error marshalling public key: %v ❌ ", err) + return + } + if err := generateAndWriteKeys(privKey, pubKeyBytes, rsaKeyType); err != nil { + printError(cmd, "Error generating and writing keys: %v ❌ ", err) + return } - generateAndWriteKeys(privKey, pubKeyBytes, rsaKeyType) } + + cmd.Printf("Successfully generated public/private key pair of type: %s", KeyType) }, } } -func generateAndWriteKeys(privKey interface{}, pubKeyBytes []byte, keyType string) { +func generateAndWriteKeys(privKey interface{}, pubKeyBytes []byte, keyType string) error { privFile, err := os.Create(privateKeyFile) if err != nil { - log.Fatalf("Error creating private key file: %v", err) + return err } defer privFile.Close() @@ -99,22 +113,19 @@ func generateAndWriteKeys(privKey interface{}, pubKeyBytes []byte, keyType strin b, err = x509.MarshalPKCS8PrivateKey(privKey) } if err != nil { - log.Printf("Error marshalling private key: %v", err) - return + return err } if err := pem.Encode(privFile, &pem.Block{ Type: keyType, Bytes: b, }); err != nil { - log.Printf("Error encoding private key: %v", err) - return + return err } pubFile, err := os.Create(publicKeyFile) if err != nil { - log.Printf("Error creating public key file: %v", err) - return + return err } defer pubFile.Close() @@ -122,9 +133,8 @@ func generateAndWriteKeys(privKey interface{}, pubKeyBytes []byte, keyType strin Type: publicKeyType, Bytes: pubKeyBytes, }); err != nil { - log.Printf("Error encoding public key: %v", err) - return + return err } - log.Printf("Successfully generated public/private key pair of type: %s", reflect.TypeOf(privKey).String()) + return nil } diff --git a/cli/result.go b/cli/result.go index 728d726e3..d43fa59d1 100644 --- a/cli/result.go +++ b/cli/result.go @@ -4,7 +4,6 @@ package cli import ( "encoding/pem" - "log" "os" "github.com/fatih/color" @@ -20,12 +19,11 @@ func (cli *CLI) NewResultsCmd() *cobra.Command { Example: "result ", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { - log.Println("⏳ Retrieving computation result file") + cmd.Println("⏳ Retrieving computation result file") privKeyFile, err := os.ReadFile(args[0]) if err != nil { - msg := color.New(color.FgRed).Sprintf("Error reading private key file: %v ❌ ", err) - log.Println(msg) + printError(cmd, "Error reading private key file: %v ❌ ", err) return } @@ -33,21 +31,23 @@ func (cli *CLI) NewResultsCmd() *cobra.Command { var result []byte - privKey := decodeKey(pemBlock) + privKey, err := decodeKey(pemBlock) + if err != nil { + printError(cmd, "Error decoding private key: %v ❌ ", err) + return + } result, err = cli.agentSDK.Result(cmd.Context(), privKey) if err != nil { - msg := color.New(color.FgRed).Sprintf("Error retrieving computation result: %v ❌ ", err) - log.Println(msg) + printError(cmd, "Error retrieving computation result: %v ❌ ", err) return } if err := os.WriteFile(resultFilePath, result, 0o644); err != nil { - msg := color.New(color.FgRed).Sprintf("Error saving computation result to %s: %v ❌ ", resultFilePath, err) - log.Println(msg) + printError(cmd, "Error saving computation result file: %v ❌ ", err) return } - log.Println(color.New(color.FgGreen).Sprint("Computation result retrieved and saved successfully! ✔ ")) + cmd.Println(color.New(color.FgGreen).Sprint("Computation result retrieved and saved successfully! ✔ ")) }, } } diff --git a/cli/result_test.go b/cli/result_test.go index 71cccf45d..6ef363bcc 100644 --- a/cli/result_test.go +++ b/cli/result_test.go @@ -3,6 +3,7 @@ package cli import ( + "bytes" "errors" "os" "testing" @@ -23,13 +24,13 @@ func TestResultsCmd_Success(t *testing.T) { require.NoError(t, err) cmd := testCLI.NewResultsCmd() - output := captureLogOutput(func() { - cmd.SetArgs([]string{privateKeyFile}) - err = cmd.Execute() - require.NoError(t, err) - }) + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) - require.Contains(t, output, "Computation result retrieved and saved successfully") + require.Contains(t, buf.String(), "Computation result retrieved and saved successfully") resultFile, err := os.ReadFile("results.zip") require.NoError(t, err) @@ -47,13 +48,13 @@ func TestResultsCmd_MissingPrivateKeyFile(t *testing.T) { testCLI := New(mockSDK) cmd := testCLI.NewResultsCmd() - output := captureLogOutput(func() { - cmd.SetArgs([]string{"non_existent_private_key.pem"}) - err := cmd.Execute() - require.NoError(t, err) - }) + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{"non_existent_private_key.pem"}) + err := cmd.Execute() + require.NoError(t, err) - require.Contains(t, output, "Error reading private key file") + require.Contains(t, buf.String(), "Error reading private key file") } func TestResultsCmd_ResultFailure(t *testing.T) { @@ -65,13 +66,13 @@ func TestResultsCmd_ResultFailure(t *testing.T) { require.NoError(t, err) cmd := testCLI.NewResultsCmd() - output := captureLogOutput(func() { - cmd.SetArgs([]string{privateKeyFile}) - err = cmd.Execute() - require.NoError(t, err) - }) + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) - require.Contains(t, output, "error retrieving computation result") + require.Contains(t, buf.String(), "error retrieving computation result") mockSDK.AssertCalled(t, "Result", mock.Anything, mock.Anything) t.Cleanup(func() { os.Remove(privateKeyFile) @@ -91,13 +92,13 @@ func TestResultsCmd_SaveFailure(t *testing.T) { require.NoError(t, err) cmd := testCLI.NewResultsCmd() - output := captureLogOutput(func() { - cmd.SetArgs([]string{privateKeyFile}) - err := cmd.Execute() - require.NoError(t, err) - }) + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{privateKeyFile}) + err = cmd.Execute() + require.NoError(t, err) - require.Contains(t, output, "Error saving computation result to results.zip") + require.Contains(t, buf.String(), "Error saving computation result file") mockSDK.AssertCalled(t, "Result", mock.Anything, mock.Anything) t.Cleanup(func() { @@ -105,3 +106,29 @@ func TestResultsCmd_SaveFailure(t *testing.T) { os.Remove(privateKeyFile) }) } + +func TestResultsCmd_InvalidPrivateKey(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) + testCLI := New(mockSDK) + + invalidPrivateKey, err := os.CreateTemp("", "invalid_private_key.pem") + require.NoError(t, err) + err = invalidPrivateKey.Close() + require.NoError(t, err) + + t.Cleanup(func() { + err := os.Remove(invalidPrivateKey.Name()) + require.NoError(t, err) + }) + + cmd := testCLI.NewResultsCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{invalidPrivateKey.Name()}) + err = cmd.Execute() + require.NoError(t, err) + + require.Contains(t, buf.String(), "Error decoding private key") + mockSDK.AssertNotCalled(t, "Result", mock.Anything, mock.Anything) +} diff --git a/cli/sdk.go b/cli/sdk.go index 7edda7f8f..f961aa22f 100644 --- a/cli/sdk.go +++ b/cli/sdk.go @@ -4,6 +4,8 @@ package cli import "github.com/ultravioletrs/cocos/pkg/sdk" +var Verbose bool + type CLI struct { agentSDK sdk.SDK } diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 15f166f40..3d95006fb 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -4,13 +4,11 @@ package main import ( "fmt" - "log" "os" "os/signal" "path" "syscall" - mglog "github.com/absmach/magistrala/logger" "github.com/caarlos0/env/v11" "github.com/fatih/color" "github.com/spf13/cobra" @@ -35,81 +33,86 @@ type config struct { } func main() { + rootCmd := &cobra.Command{ + Use: "cocos-cli [command]", + Short: "CLI application for CoCos Service API", + Run: func(cmd *cobra.Command, args []string) { + fmt.Printf("CLI application for CoCos Service API\n\n") + fmt.Printf("Usage:\n %s [command]\n\n", cmd.CommandPath()) + fmt.Printf("Available Commands:\n") + + // Filter out "completion" command + availableCommands := make([]*cobra.Command, 0) + for _, subCmd := range cmd.Commands() { + if subCmd.Name() != completion { + availableCommands = append(availableCommands, subCmd) + } + } + + for _, subCmd := range availableCommands { + fmt.Printf(" %-15s%s\n", subCmd.Name(), subCmd.Short) + } + + fmt.Printf("\nFlags:\n") + cmd.Flags().VisitAll(func(flag *pflag.Flag) { + fmt.Printf(" -%s, --%s %s\n", flag.Shorthand, flag.Name, flag.Usage) + }) + fmt.Printf("\nUse \"%s [command] --help\" for more information about a command.\n", cmd.CommandPath()) + }, + } + signalChan := make(chan os.Signal, 1) signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) go func() { <-signalChan fmt.Println() - log.Println(color.New(color.FgRed).Sprint("Operation aborted by user!")) + rootCmd.Println(color.New(color.FgRed).Sprint("Operation aborted by user!")) os.Exit(2) }() var cfg config if err := env.Parse(&cfg); err != nil { - log.Fatalf("failed to load %s configuration : %s", svcName, err) + message := color.New(color.FgRed).Sprintf("failed to load %s configuration : %s", svcName, err) + rootCmd.Println(message) + return } homePath, err := os.UserHomeDir() if err != nil { - log.Fatalf("Error fetching user home directory: %v", err) + message := color.New(color.FgRed).Sprintf("failed to fetch user home directory: %s", err) + rootCmd.Println(message) + return } directoryCachePath := path.Join(homePath, cocosDirectory) if err := os.MkdirAll(directoryCachePath, filePermision); err != nil { - log.Fatalf("Error while creating directory %s, error: %v", directoryCachePath, err) - } - - logger, err := mglog.New(os.Stdout, cfg.LogLevel) - if err != nil { - log.Fatalf("Error creating logger: %s", err) + message := color.New(color.FgRed).Sprintf("failed to create directory %s : %s", directoryCachePath, err) + rootCmd.Println(message) + return } agentGRPCConfig := grpc.Config{} if err := env.ParseWithOptions(&agentGRPCConfig, env.Options{Prefix: envPrefixAgentGRPC}); err != nil { - logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)) + message := color.New(color.FgRed).Sprintf("failed to load %s gRPC client configuration : %s", svcName, err) + rootCmd.Println(message) return } agentGRPCClient, agentClient, err := agent.NewAgentClient(agentGRPCConfig) if err != nil { - logger.Error(err.Error()) + message := color.New(color.FgRed).Sprintf("failed to create %s gRPC client : %s", svcName, err) + rootCmd.Println(message) return } defer agentGRPCClient.Close() - agentSDK := sdk.NewAgentSDK(logger, agentClient) + agentSDK := sdk.NewAgentSDK(agentClient) cliSVC := cli.New(agentSDK) - rootCmd := &cobra.Command{ - Use: "cocos-cli [command]", - Short: "CLI application for CoCos Service API", - Run: func(cmd *cobra.Command, args []string) { - fmt.Printf("CLI application for CoCos Service API\n\n") - fmt.Printf("Usage:\n %s [command]\n\n", cmd.CommandPath()) - fmt.Printf("Available Commands:\n") - - // Filter out "completion" command - availableCommands := make([]*cobra.Command, 0) - for _, subCmd := range cmd.Commands() { - if subCmd.Name() != completion { - availableCommands = append(availableCommands, subCmd) - } - } - - for _, subCmd := range availableCommands { - fmt.Printf(" %-15s%s\n", subCmd.Name(), subCmd.Short) - } - - fmt.Printf("\nFlags:\n") - cmd.Flags().VisitAll(func(flag *pflag.Flag) { - fmt.Printf(" -%s, --%s %s\n", flag.Shorthand, flag.Name, flag.Usage) - }) - fmt.Printf("\nUse \"%s [command] --help\" for more information about a command.\n", cmd.CommandPath()) - }, - } + rootCmd.PersistentFlags().BoolVarP(&cli.Verbose, "verbose", "v", false, "Enable verbose output") keysCmd := cliSVC.NewKeysCmd() attestationCmd := cliSVC.NewAttestationCmd() diff --git a/pkg/sdk/agent.go b/pkg/sdk/agent.go index 231f239f3..d354edb71 100644 --- a/pkg/sdk/agent.go +++ b/pkg/sdk/agent.go @@ -14,7 +14,6 @@ import ( "encoding/base64" "errors" "io" - "log/slog" "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/agent/auth" @@ -38,20 +37,17 @@ const ( type agentSDK struct { client agent.AgentServiceClient - logger *slog.Logger } -func NewAgentSDK(log *slog.Logger, agentClient agent.AgentServiceClient) SDK { +func NewAgentSDK(agentClient agent.AgentServiceClient) SDK { return &agentSDK{ client: agentClient, - logger: log, } } func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey any) error { md, err := generateMetadata(string(auth.AlgorithmProviderRole), privKey) if err != nil { - sdk.logger.Error("Failed to generate metadata") return err } @@ -61,7 +57,6 @@ func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKe stream, err := sdk.client.Algo(ctx) if err != nil { - sdk.logger.Error("Failed to call Algo RPC") return err } algoBuffer := bytes.NewBuffer(algorithm.Algorithm) @@ -69,7 +64,6 @@ func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKe pb := progressbar.New() if err := pb.SendAlgorithm(algoProgressBarDescription, algoBuffer, reqBuffer, &stream); err != nil { - sdk.logger.Error("Failed to send Algorithm") return err } @@ -79,7 +73,6 @@ func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKe func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey any) error { md, err := generateMetadata(string(auth.DataProviderRole), privKey) if err != nil { - sdk.logger.Error("Failed to generate metadata") return err } @@ -89,14 +82,12 @@ func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey an stream, err := sdk.client.Data(ctx) if err != nil { - sdk.logger.Error("Failed to call Data RPC") return err } dataBuffer := bytes.NewBuffer(dataset.Dataset) pb := progressbar.New() if err := pb.SendData(dataProgressBarDescription, dataset.Filename, dataBuffer, &stream); err != nil { - sdk.logger.Error("Failed to send Data") return err } @@ -108,14 +99,12 @@ func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) { md, err := generateMetadata(string(auth.ConsumerRole), privKey) if err != nil { - sdk.logger.Error("Failed to generate metadata") return nil, err } ctx = metadata.NewOutgoingContext(ctx, md) stream, err := sdk.client.Result(ctx, request) if err != nil { - sdk.logger.Error("Failed to call Result RPC") return nil, err } @@ -141,7 +130,6 @@ func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) ( response, err := sdk.client.Attestation(ctx, request) if err != nil { - sdk.logger.Error("Failed to call Attestation RPC") return nil, err } diff --git a/pkg/sdk/agent_test.go b/pkg/sdk/agent_test.go index 263c3a86f..faaca252c 100644 --- a/pkg/sdk/agent_test.go +++ b/pkg/sdk/agent_test.go @@ -15,7 +15,6 @@ import ( "os" "testing" - mglog "github.com/absmach/magistrala/logger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -34,9 +33,6 @@ var ( ) func TestAlgo(t *testing.T) { - logger, err := mglog.New(os.Stdout, "info") - require.NoError(t, err) - conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) if err != nil { t.Fatalf("Failed to dial bufnet: %v", err) @@ -45,7 +41,7 @@ func TestAlgo(t *testing.T) { client := agent.NewAgentServiceClient(conn) - sdk := sdk.NewAgentSDK(logger, client) + sdk := sdk.NewAgentSDK(client) algo, err := os.ReadFile(algoPath) require.NoError(t, err) @@ -124,9 +120,6 @@ func TestAlgo(t *testing.T) { } func TestData(t *testing.T) { - logger, err := mglog.New(os.Stdout, "info") - require.NoError(t, err) - conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) if err != nil { t.Fatalf("Failed to dial bufnet: %v", err) @@ -135,7 +128,7 @@ func TestData(t *testing.T) { client := agent.NewAgentServiceClient(conn) - sdk := sdk.NewAgentSDK(logger, client) + sdk := sdk.NewAgentSDK(client) data, err := os.ReadFile(dataPath) require.NoError(t, err) @@ -224,9 +217,6 @@ func TestData(t *testing.T) { } func TestResult(t *testing.T) { - logger, err := mglog.New(os.Stdout, "info") - require.NoError(t, err) - conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) if err != nil { t.Fatalf("Failed to dial bufnet: %v", err) @@ -235,7 +225,7 @@ func TestResult(t *testing.T) { client := agent.NewAgentServiceClient(conn) - sdk := sdk.NewAgentSDK(logger, client) + sdk := sdk.NewAgentSDK(client) resultConsumerKey, _ := generateKeys(t, "ecdsa") resultConsumer1Key, _ := generateKeys(t, "ed25519") @@ -319,9 +309,6 @@ func TestResult(t *testing.T) { } func TestAttestation(t *testing.T) { - logger, err := mglog.New(os.Stdout, "info") - require.NoError(t, err) - resultConsumerKey, _ := generateKeys(t, "rsa") resultConsumer1Key, _ := generateKeys(t, "ed25519") @@ -339,7 +326,7 @@ func TestAttestation(t *testing.T) { client := agent.NewAgentServiceClient(conn) - sdk := sdk.NewAgentSDK(logger, client) + sdk := sdk.NewAgentSDK(client) _, err = rand.Read(reportData) require.NoError(t, err) From db7f3c7a4bac2a240683ce3e2d51356e4205e565 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:19:12 +0300 Subject: [PATCH 46/83] COCOS-278 - Abstract state machine (#280) * abstract state machine Signed-off-by: Sammy Oina * perpetual results consumption Signed-off-by: Sammy Oina * async action Signed-off-by: Sammy Oina * fix failing tests Signed-off-by: Sammy Oina * fix failing test Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- agent/agentevent_string.go | 29 +++++ agent/agentstate_string.go | 30 +++++ agent/service.go | 177 ++++++++++++++++++-------- agent/service_test.go | 32 ++--- agent/state.go | 150 ---------------------- agent/state_string.go | 31 ----- agent/state_test.go | 201 +++++++++++++++++++++++------- agent/statemachine/mocks/state.go | 86 +++++++++++++ agent/statemachine/state.go | 113 +++++++++++++++++ cli/result.go | 31 ++++- cli/result_test.go | 78 +++++++++++- 11 files changed, 654 insertions(+), 304 deletions(-) create mode 100644 agent/agentevent_string.go create mode 100644 agent/agentstate_string.go delete mode 100644 agent/state.go delete mode 100644 agent/state_string.go create mode 100644 agent/statemachine/mocks/state.go create mode 100644 agent/statemachine/state.go diff --git a/agent/agentevent_string.go b/agent/agentevent_string.go new file mode 100644 index 000000000..1cb344eb0 --- /dev/null +++ b/agent/agentevent_string.go @@ -0,0 +1,29 @@ +// Code generated by "stringer -type=AgentEvent"; DO NOT EDIT. + +package agent + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Start-0] + _ = x[ManifestReceived-1] + _ = x[AlgorithmReceived-2] + _ = x[DataReceived-3] + _ = x[RunComplete-4] + _ = x[ResultsConsumed-5] + _ = x[RunFailed-6] +} + +const _AgentEvent_name = "StartManifestReceivedAlgorithmReceivedDataReceivedRunCompleteResultsConsumedRunFailed" + +var _AgentEvent_index = [...]uint8{0, 5, 21, 38, 50, 61, 76, 85} + +func (i AgentEvent) String() string { + if i < 0 || i >= AgentEvent(len(_AgentEvent_index)-1) { + return "AgentEvent(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _AgentEvent_name[_AgentEvent_index[i]:_AgentEvent_index[i+1]] +} diff --git a/agent/agentstate_string.go b/agent/agentstate_string.go new file mode 100644 index 000000000..e620f39cc --- /dev/null +++ b/agent/agentstate_string.go @@ -0,0 +1,30 @@ +// Code generated by "stringer -type=AgentState"; DO NOT EDIT. + +package agent + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Idle-0] + _ = x[ReceivingManifest-1] + _ = x[ReceivingAlgorithm-2] + _ = x[ReceivingData-3] + _ = x[Running-4] + _ = x[ConsumingResults-5] + _ = x[Complete-6] + _ = x[Failed-7] +} + +const _AgentState_name = "IdleReceivingManifestReceivingAlgorithmReceivingDataRunningConsumingResultsCompleteFailed" + +var _AgentState_index = [...]uint8{0, 4, 21, 39, 52, 59, 75, 83, 89} + +func (i AgentState) String() string { + if i < 0 || i >= AgentState(len(_AgentState_index)-1) { + return "AgentState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _AgentState_name[_AgentState_index[i]:_AgentState_index[i+1]] +} diff --git a/agent/service.go b/agent/service.go index 540d475e7..9f4e5e1a6 100644 --- a/agent/service.go +++ b/agent/service.go @@ -20,12 +20,52 @@ import ( "github.com/ultravioletrs/cocos/agent/algorithm/python" "github.com/ultravioletrs/cocos/agent/algorithm/wasm" "github.com/ultravioletrs/cocos/agent/events" + "github.com/ultravioletrs/cocos/agent/statemachine" "github.com/ultravioletrs/cocos/internal" "golang.org/x/crypto/sha3" ) var _ Service = (*agentService)(nil) +//go:generate stringer -type=AgentState +type AgentState int + +const ( + Idle AgentState = iota + ReceivingManifest + ReceivingAlgorithm + ReceivingData + Running + ConsumingResults + Complete + Failed +) + +//go:generate stringer -type=AgentEvent +type AgentEvent int + +const ( + Start AgentEvent = iota + ManifestReceived + AlgorithmReceived + DataReceived + RunComplete + ResultsConsumed + RunFailed +) + +//go:generate stringer -type=Status +type Status uint8 + +const ( + IdleState Status = iota + InProgress + Ready + Completed + Terminated + Warning +) + const ( // ReportDataSize is the size of the report data expected by the attestation service. ReportDataSize = 64 @@ -71,40 +111,69 @@ type Service interface { } type agentService struct { - computation Computation // Holds the current computation request details. - algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation. - result []byte // Stores the result of the computation. - sm *StateMachine // Manages the state transitions of the agent service. - runError error // Stores any error encountered during the computation run. - eventSvc events.Service // Service for publishing events related to computation. - quoteProvider client.QuoteProvider // Provider for generating attestation quotes. + computation Computation // Holds the current computation request details. + algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation. + result []byte // Stores the result of the computation. + sm statemachine.StateMachine // Manages the state transitions of the agent service. + runError error // Stores any error encountered during the computation run. + eventSvc events.Service // Service for publishing events related to computation. + quoteProvider client.QuoteProvider // Provider for generating attestation quotes. + logger *slog.Logger // Logger for the agent service. + resultsConsumed bool // Indicates if the results have been consumed. } var _ Service = (*agentService)(nil) // New instantiates the agent service implementation. func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp Computation, quoteProvider client.QuoteProvider) Service { + sm := statemachine.NewStateMachine(Idle) svc := &agentService{ - sm: NewStateMachine(logger, cmp), + sm: sm, eventSvc: eventSvc, quoteProvider: quoteProvider, + logger: logger, + computation: cmp, } - svc.sm.StateFunctions[Idle] = svc.publishEvent(IdleState.String(), json.RawMessage{}) - svc.sm.StateFunctions[ReceivingManifest] = svc.publishEvent(InProgress.String(), json.RawMessage{}) - svc.sm.StateFunctions[ReceivingAlgorithm] = svc.publishEvent(InProgress.String(), json.RawMessage{}) - svc.sm.StateFunctions[ReceivingData] = svc.publishEvent(InProgress.String(), json.RawMessage{}) - svc.sm.StateFunctions[ConsumingResults] = svc.publishEvent(Ready.String(), json.RawMessage{}) - svc.sm.StateFunctions[Complete] = svc.publishEvent(Completed.String(), json.RawMessage{}) - svc.sm.StateFunctions[Running] = svc.runComputation - svc.sm.StateFunctions[Failed] = svc.publishEvent(Failed.String(), json.RawMessage{}) + transitions := []statemachine.Transition{ + {From: Idle, Event: Start, To: ReceivingManifest}, + {From: ReceivingManifest, Event: ManifestReceived, To: ReceivingAlgorithm}, + } + + if len(cmp.Datasets) == 0 { + transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: Running}) + } else { + transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: ReceivingData}) + transitions = append(transitions, statemachine.Transition{From: ReceivingData, Event: DataReceived, To: Running}) + } - go svc.sm.Start(ctx) - svc.sm.SendEvent(start) + transitions = append(transitions, []statemachine.Transition{ + {From: Running, Event: RunComplete, To: ConsumingResults}, + {From: Running, Event: RunFailed, To: Failed}, + {From: ConsumingResults, Event: ResultsConsumed, To: Complete}, + }...) - svc.computation = cmp + for _, t := range transitions { + sm.AddTransition(t) + } + + sm.SetAction(Idle, svc.publishEvent(IdleState.String())) + sm.SetAction(ReceivingManifest, svc.publishEvent(InProgress.String())) + sm.SetAction(ReceivingAlgorithm, svc.publishEvent(InProgress.String())) + sm.SetAction(ReceivingData, svc.publishEvent(InProgress.String())) + sm.SetAction(Running, svc.runComputation) + sm.SetAction(ConsumingResults, svc.publishEvent(Ready.String())) + sm.SetAction(Complete, svc.publishEvent(Completed.String())) + sm.SetAction(Failed, svc.publishEvent(Failed.String())) + + go func() { + if err := sm.Start(ctx); err != nil { + logger.Error(err.Error()) + } + }() + sm.SendEvent(Start) + defer sm.SendEvent(ManifestReceived) - svc.sm.SendEvent(manifestReceived) return svc } @@ -153,7 +222,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { switch algoType { case string(algorithm.AlgoTypeBin): - as.algorithm = binary.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name(), args) + as.algorithm = binary.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args) case string(algorithm.AlgoTypePython): var requirementsFile string if len(algo.Requirements) > 0 { @@ -171,11 +240,11 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { requirementsFile = fr.Name() } runtime := python.PythonRunTimeFromContext(ctx) - as.algorithm = python.NewAlgorithm(as.sm.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args) + as.algorithm = python.NewAlgorithm(as.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args) case string(algorithm.AlgoTypeWasm): - as.algorithm = wasm.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name(), args) + as.algorithm = wasm.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args) case string(algorithm.AlgoTypeDocker): - as.algorithm = docker.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name()) + as.algorithm = docker.NewAlgorithm(as.logger, as.eventSvc, f.Name()) } if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil { @@ -183,7 +252,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { } if as.algorithm != nil { - as.sm.SendEvent(algorithmReceived) + as.sm.SendEvent(AlgorithmReceived) } return nil @@ -236,27 +305,30 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error { } if len(as.computation.Datasets) == 0 { - defer as.sm.SendEvent(dataReceived) + defer as.sm.SendEvent(DataReceived) } return nil } func (as *agentService) Result(ctx context.Context) ([]byte, error) { - if as.sm.GetState() != ConsumingResults && as.sm.GetState() != Failed { + currentState := as.sm.GetState() + if currentState != ConsumingResults && currentState != Complete && currentState != Failed { return []byte{}, ErrResultsNotReady } - if len(as.computation.ResultConsumers) == 0 { - return []byte{}, ErrAllResultsConsumed - } + index, ok := IndexFromContext(ctx) if !ok { return []byte{}, ErrUndeclaredConsumer } - as.computation.ResultConsumers = slices.Delete(as.computation.ResultConsumers, index, index+1) - if len(as.computation.ResultConsumers) == 0 && as.sm.GetState() == ConsumingResults { - defer as.sm.SendEvent(resultsConsumed) + if index < 0 || index >= len(as.computation.ResultConsumers) { + return []byte{}, ErrUndeclaredConsumer + } + + if !as.resultsConsumed && currentState == ConsumingResults { + as.resultsConsumed = true + defer as.sm.SendEvent(ResultsConsumed) } return as.result, as.runError @@ -271,59 +343,58 @@ func (as *agentService) Attestation(ctx context.Context, reportData [ReportDataS return rawQuote, nil } -func (as *agentService) runComputation() { - as.publishEvent(InProgress.String(), json.RawMessage{})() - as.sm.logger.Debug("computation run started") +func (as *agentService) runComputation(state statemachine.State) { + as.publishEvent(InProgress.String())(state) + as.logger.Debug("computation run started") defer func() { if as.runError != nil { - as.sm.SendEvent(runFailed) + as.sm.SendEvent(RunFailed) } else { - as.sm.SendEvent(runComplete) + as.sm.SendEvent(RunComplete) } }() if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil { as.runError = fmt.Errorf("error creating results directory: %s", err.Error()) - as.sm.logger.Warn(as.runError.Error()) - as.publishEvent(Failed.String(), json.RawMessage{})() + as.logger.Warn(as.runError.Error()) + as.publishEvent(Failed.String())(state) return } defer func() { if err := os.RemoveAll(algorithm.ResultsDir); err != nil { - as.sm.logger.Warn(fmt.Sprintf("error removing results directory and its contents: %s", err.Error())) + as.logger.Warn(fmt.Sprintf("error removing results directory and its contents: %s", err.Error())) } if err := os.RemoveAll(algorithm.DatasetsDir); err != nil { - as.sm.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error())) + as.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error())) } }() - as.publishEvent(InProgress.String(), json.RawMessage{})() + as.publishEvent(InProgress.String())(state) if err := as.algorithm.Run(); err != nil { as.runError = err - as.sm.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error())) - as.publishEvent(Failed.String(), json.RawMessage{})() + as.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error())) + as.publishEvent(Failed.String())(state) return } results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir) if err != nil { as.runError = err - as.sm.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error())) - as.publishEvent(Failed.String(), json.RawMessage{})() + as.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error())) + as.publishEvent(Failed.String())(state) return } - as.publishEvent(Completed.String(), json.RawMessage{})() + as.publishEvent(Completed.String())(state) as.result = results } -func (as *agentService) publishEvent(status string, details json.RawMessage) func() { - return func() { - st := as.sm.GetState().String() - if err := as.eventSvc.SendEvent(st, status, details); err != nil { - as.sm.logger.Warn(err.Error()) +func (as *agentService) publishEvent(status string) statemachine.Action { + return func(state statemachine.State) { + if err := as.eventSvc.SendEvent(state.String(), status, json.RawMessage{}); err != nil { + as.logger.Warn(err.Error()) } } } diff --git a/agent/service_test.go b/agent/service_test.go index 545dae480..f237b9df2 100644 --- a/agent/service_test.go +++ b/agent/service_test.go @@ -20,6 +20,8 @@ import ( "github.com/ultravioletrs/cocos/agent/events/mocks" "github.com/ultravioletrs/cocos/agent/quoteprovider" mocks2 "github.com/ultravioletrs/cocos/agent/quoteprovider/mocks" + "github.com/ultravioletrs/cocos/agent/statemachine" + smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks" "golang.org/x/crypto/sha3" "google.golang.org/grpc/metadata" ) @@ -249,45 +251,36 @@ func TestResult(t *testing.T) { err error setup func(svc *agentService) ctxSetup func(ctx context.Context) context.Context + state statemachine.State }{ { name: "Test results not ready", err: ErrResultsNotReady, setup: func(svc *agentService) { }, - }, - { - name: "Test all results consumed", - err: ErrAllResultsConsumed, - setup: func(svc *agentService) { - svc.sm.SetState(ConsumingResults) - svc.computation.ResultConsumers = []ResultConsumer{} - }, - ctxSetup: func(ctx context.Context) context.Context { - return IndexToContext(ctx, 0) - }, + state: Running, }, { name: "Test undeclared consumer", err: ErrUndeclaredConsumer, setup: func(svc *agentService) { - svc.sm.SetState(ConsumingResults) svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("user")}} }, ctxSetup: func(ctx context.Context) context.Context { return ctx }, + state: ConsumingResults, }, { name: "Test results consumed and event sent", err: nil, setup: func(svc *agentService) { - svc.sm.SetState(ConsumingResults) svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("key")}} }, ctxSetup: func(ctx context.Context) context.Context { return IndexToContext(ctx, 0) }, + state: ConsumingResults, }, } @@ -301,14 +294,23 @@ func TestResult(t *testing.T) { ctx = tc.ctxSetup(ctx) } + sm := new(smmocks.StateMachine) + sm.On("Start", ctx).Return(nil) + sm.On("GetState").Return(tc.state) + sm.On("SendEvent", mock.Anything).Return() + svc := &agentService{ - sm: NewStateMachine(mglog.NewMock(), testComputation(t)), + sm: sm, eventSvc: events, quoteProvider: qp, computation: testComputation(t), } - go svc.sm.Start(ctx) + go func() { + if err := svc.sm.Start(ctx); err != nil { + t.Errorf("Error starting state machine: %v", err) + } + }() tc.setup(svc) _, err := svc.Result(ctx) t.Cleanup(func() { diff --git a/agent/state.go b/agent/state.go deleted file mode 100644 index f109735c3..000000000 --- a/agent/state.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 -package agent - -import ( - "context" - "fmt" - "log/slog" - "sync" -) - -//go:generate stringer -type=State -type State uint8 - -const ( - Idle State = iota - ReceivingManifest - ReceivingAlgorithm - ReceivingData - Running - ConsumingResults - Complete - Failed - AlgorithmRun -) - -//go:generate stringer -type=Status -type Status uint8 - -const ( - IdleState Status = iota - InProgress - Ready - Completed - Terminated - Warning -) - -type event uint8 - -const ( - start event = iota - manifestReceived - algorithmReceived - dataReceived - runComplete - resultsConsumed - runFailed -) - -// StateMachine represents the state machine. -type StateMachine struct { - mu sync.Mutex - State State - EventChan chan event - Transitions map[State]map[event]State - StateFunctions map[State]func() - logger *slog.Logger - wg *sync.WaitGroup -} - -// NewStateMachine creates a new StateMachine. -func NewStateMachine(logger *slog.Logger, cmp Computation) *StateMachine { - sm := &StateMachine{ - State: Idle, - EventChan: make(chan event), - Transitions: make(map[State]map[event]State), - StateFunctions: make(map[State]func()), - logger: logger, - wg: &sync.WaitGroup{}, - } - - sm.Transitions[Idle] = make(map[event]State) - sm.Transitions[Idle][start] = ReceivingManifest - - sm.Transitions[ReceivingManifest] = make(map[event]State) - sm.Transitions[ReceivingManifest][manifestReceived] = ReceivingAlgorithm - - sm.Transitions[ReceivingAlgorithm] = make(map[event]State) - switch len(cmp.Datasets) { - case 0: - sm.Transitions[ReceivingAlgorithm][algorithmReceived] = Running - default: - sm.Transitions[ReceivingAlgorithm][algorithmReceived] = ReceivingData - } - - sm.Transitions[ReceivingData] = make(map[event]State) - sm.Transitions[ReceivingData][dataReceived] = Running - - sm.Transitions[Running] = make(map[event]State) - sm.Transitions[Running][runComplete] = ConsumingResults - sm.Transitions[Running][runFailed] = Failed - - sm.Transitions[ConsumingResults] = make(map[event]State) - sm.Transitions[ConsumingResults][resultsConsumed] = Complete - - return sm -} - -// Start the state machine. -func (sm *StateMachine) Start(ctx context.Context) { - sm.wg.Add(1) - defer sm.wg.Done() - for { - select { - case event := <-sm.EventChan: - currentState := sm.GetState() - var nextState State - var stateFunc func() - var valid bool - - sm.mu.Lock() - nextState, valid = sm.Transitions[sm.State][event] - if valid { - sm.State = nextState - stateFunc = sm.StateFunctions[nextState] - } - sm.mu.Unlock() - - if valid { - sm.logger.Debug(fmt.Sprintf("Transition: %v -> %v\n", currentState, nextState)) - if stateFunc != nil { - go stateFunc() - } - } else { - sm.logger.Error(fmt.Sprintf("Invalid transition: %v -> ???\n", sm.State)) - } - - case <-ctx.Done(): - return - } - } -} - -// SendEvent sends an event to the state machine. -func (sm *StateMachine) SendEvent(event event) { - sm.EventChan <- event -} - -func (sm *StateMachine) GetState() State { - sm.mu.Lock() - defer sm.mu.Unlock() - return sm.State -} - -func (sm *StateMachine) SetState(state State) { - sm.mu.Lock() - defer sm.mu.Unlock() - sm.State = state -} diff --git a/agent/state_string.go b/agent/state_string.go deleted file mode 100644 index b084ec0bc..000000000 --- a/agent/state_string.go +++ /dev/null @@ -1,31 +0,0 @@ -// Code generated by "stringer -type=State"; DO NOT EDIT. - -package agent - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[Idle-0] - _ = x[ReceivingManifest-1] - _ = x[ReceivingAlgorithm-2] - _ = x[ReceivingData-3] - _ = x[Running-4] - _ = x[ConsumingResults-5] - _ = x[Complete-6] - _ = x[Failed-7] - _ = x[AlgorithmRun-8] -} - -const _State_name = "IdleReceivingManifestReceivingAlgorithmReceivingDataRunningConsumingResultsCompleteFailedAlgorithmRun" - -var _State_index = [...]uint8{0, 4, 21, 39, 52, 59, 75, 83, 89, 101} - -func (i State) String() string { - if i >= State(len(_State_index)-1) { - return "State(" + strconv.FormatInt(int64(i), 10) + ")" - } - return _State_name[_State_index[i]:_State_index[i+1]] -} diff --git a/agent/state_test.go b/agent/state_test.go index 43e3fb92b..1bd269417 100644 --- a/agent/state_test.go +++ b/agent/state_test.go @@ -4,75 +4,182 @@ package agent import ( "context" - "fmt" + sync "sync" "testing" "time" - mglog "github.com/absmach/magistrala/logger" + "github.com/ultravioletrs/cocos/agent/statemachine" ) -var cmp = Computation{ - Datasets: []Dataset{ - { - Dataset: []byte("test"), - UserKey: []byte("test"), - }, - }, +type MockState int + +type MockEvent int + +func (s MockState) String() string { + return []string{"State1", "State2", "State3"}[s] } -func TestStateMachineTransitions(t *testing.T) { - cases := []struct { - fromState State - event event - expected State - cmp Computation - }{ - {Idle, start, ReceivingManifest, cmp}, - {ReceivingManifest, manifestReceived, ReceivingAlgorithm, cmp}, - {ReceivingAlgorithm, algorithmReceived, ReceivingData, cmp}, - {ReceivingAlgorithm, algorithmReceived, Running, Computation{}}, - {ReceivingData, dataReceived, Running, cmp}, - {Running, runComplete, ConsumingResults, cmp}, - {ConsumingResults, resultsConsumed, Complete, cmp}, +func (e MockEvent) String() string { + return []string{"Event1", "Event2", "Event3"}[e] +} + +const ( + State1 MockState = iota + State2 + State3 +) + +const ( + Event1 MockEvent = iota + Event2 + Event3 +) + +func TestNewStateMachine(t *testing.T) { + sm := statemachine.NewStateMachine(State1) + if sm == nil { + t.Fatal("NewStateMachine returned nil") } + if sm.GetState() != State1 { + t.Errorf("Initial state not set correctly, got %v, want %v", sm.GetState(), State1) + } +} - for _, tc := range cases { - t.Run(fmt.Sprintf("Transition from %v to %v", tc.fromState, tc.expected), func(t *testing.T) { - sm := NewStateMachine(mglog.NewMock(), tc.cmp) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +func TestAddTransition(t *testing.T) { + sm := statemachine.NewStateMachine(State1) + sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2}) - go sm.Start(ctx) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() - time.Sleep(50 * time.Millisecond) + go func() { + if err := sm.Start(ctx); err != context.Canceled { + t.Errorf("Start returned error: %v", err) + } + }() - sm.SetState(tc.fromState) - sm.SendEvent(tc.event) + sm.SendEvent(Event1) - time.Sleep(50 * time.Millisecond) + time.Sleep(50 * time.Millisecond) - if sm.GetState() != tc.expected { - t.Errorf("Expected state %v after the event, but got %v", tc.expected, sm.GetState()) - } - }) + if sm.GetState() != State2 { + t.Errorf("Transition not applied correctly, got state %v, want %v", sm.GetState(), State2) } } -func TestStateMachineInvalidTransition(t *testing.T) { - sm := NewStateMachine(mglog.NewMock(), cmp) - ctx, cancel := context.WithCancel(context.Background()) +func TestSetAction(t *testing.T) { + sm := statemachine.NewStateMachine(State1) + + var wg sync.WaitGroup + wg.Add(1) + + sm.SetAction(State2, func(s statemachine.State) { + defer wg.Done() + }) + + sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - go sm.Start(ctx) + go func() { + if err := sm.Start(ctx); err != context.Canceled { + t.Errorf("Start returned error: %v", err) + } + }() - time.Sleep(50 * time.Millisecond) + sm.SendEvent(Event1) - sm.SetState(Idle) - sm.SendEvent(dataReceived) + wg.Wait() - time.Sleep(50 * time.Millisecond) + if ctx.Err() != nil { + t.Error("Action was not called within the expected time") + } +} + +func TestInvalidTransition(t *testing.T) { + sm := statemachine.NewStateMachine(State1) + sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + errChan := make(chan error) + go func() { + errChan <- sm.Start(ctx) + }() + + sm.SendEvent(Event2) + + select { + case err := <-errChan: + if err == nil { + t.Errorf("Expected invalid transition error, got: %v", err) + } + case <-time.After(150 * time.Millisecond): + t.Error("Timeout waiting for invalid transition error") + } +} + +func TestMultipleTransitions(t *testing.T) { + sm := statemachine.NewStateMachine(State1) + sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2}) + sm.AddTransition(statemachine.Transition{From: State2, Event: Event2, To: State3}) + sm.AddTransition(statemachine.Transition{From: State3, Event: Event3, To: State1}) + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + go func() { + if err := sm.Start(ctx); err != context.Canceled { + t.Errorf("Start returned error: %v", err) + } + }() + + transitions := []struct { + event MockEvent + want MockState + }{ + {Event1, State2}, + {Event2, State3}, + {Event3, State1}, + } + + for _, tt := range transitions { + sm.SendEvent(tt.event) + time.Sleep(50 * time.Millisecond) + + if sm.GetState() != tt.want { + t.Errorf("After event %v, got state %v, want %v", tt.event, sm.GetState(), tt.want) + } + } +} + +func TestConcurrency(t *testing.T) { + sm := statemachine.NewStateMachine(State1) + sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2}) + sm.AddTransition(statemachine.Transition{From: State2, Event: Event2, To: State1}) + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + go func() { + if err := sm.Start(ctx); err == nil { + t.Errorf("Expected context error, got nil") + } + }() + + for i := 0; i < 100; i++ { + go func() { + sm.SendEvent(Event1) + sm.SendEvent(Event2) + }() + } + + time.Sleep(400 * time.Millisecond) - if sm.GetState() != Idle { - t.Errorf("State should not change on an invalid event, but got %v", sm.GetState()) + finalState := sm.GetState() + if finalState != State1 && finalState != State2 { + t.Errorf("Unexpected final state: %v", finalState) } } diff --git a/agent/statemachine/mocks/state.go b/agent/statemachine/mocks/state.go new file mode 100644 index 000000000..e3a3a6f8d --- /dev/null +++ b/agent/statemachine/mocks/state.go @@ -0,0 +1,86 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + context "context" + + agent "github.com/ultravioletrs/cocos/agent/statemachine" + + mock "github.com/stretchr/testify/mock" +) + +// StateMachine is an autogenerated mock type for the StateMachine type +type StateMachine struct { + mock.Mock +} + +// AddTransition provides a mock function with given fields: t +func (_m *StateMachine) AddTransition(t agent.Transition) { + _m.Called(t) +} + +// GetState provides a mock function with given fields: +func (_m *StateMachine) GetState() agent.State { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetState") + } + + var r0 agent.State + if rf, ok := ret.Get(0).(func() agent.State); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(agent.State) + } + } + + return r0 +} + +// SendEvent provides a mock function with given fields: event +func (_m *StateMachine) SendEvent(event agent.Event) { + _m.Called(event) +} + +// SetAction provides a mock function with given fields: state, action +func (_m *StateMachine) SetAction(state agent.State, action agent.Action) { + _m.Called(state, action) +} + +// Start provides a mock function with given fields: ctx +func (_m *StateMachine) Start(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Start") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewStateMachine creates a new instance of StateMachine. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewStateMachine(t interface { + mock.TestingT + Cleanup(func()) +}) *StateMachine { + mock := &StateMachine{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/agent/statemachine/state.go b/agent/statemachine/state.go new file mode 100644 index 000000000..27b5d1e4c --- /dev/null +++ b/agent/statemachine/state.go @@ -0,0 +1,113 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package statemachine + +import ( + "context" + "fmt" + "sync" +) + +type State interface { + String() string +} + +type Event interface { + String() string +} + +type Action func(State) + +type Transition struct { + From State + Event Event + To State +} + +//go:generate mockery --name StateMachine --output=mocks --filename state.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0" +type StateMachine interface { + AddTransition(t Transition) + SetAction(state State, action Action) + GetState() State + SendEvent(event Event) + Start(ctx context.Context) error +} + +type stateMachine struct { + mu sync.Mutex + currentState State + transitions map[State]map[Event]State + actions map[State]Action + eventChan chan Event +} + +func NewStateMachine(initialState State) StateMachine { + return &stateMachine{ + currentState: initialState, + transitions: make(map[State]map[Event]State), + actions: make(map[State]Action), + eventChan: make(chan Event), + } +} + +func (sm *stateMachine) AddTransition(t Transition) { + sm.mu.Lock() + defer sm.mu.Unlock() + + if _, ok := sm.transitions[t.From]; !ok { + sm.transitions[t.From] = make(map[Event]State) + } + sm.transitions[t.From][t.Event] = t.To +} + +func (sm *stateMachine) SetAction(state State, action Action) { + sm.mu.Lock() + defer sm.mu.Unlock() + + sm.actions[state] = action +} + +func (sm *stateMachine) GetState() State { + sm.mu.Lock() + defer sm.mu.Unlock() + return sm.currentState +} + +func (sm *stateMachine) SendEvent(event Event) { + sm.eventChan <- event +} + +func (sm *stateMachine) Start(ctx context.Context) error { + for { + select { + case event := <-sm.eventChan: + if err := sm.handleEvent(event); err != nil { + return err + } + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (sm *stateMachine) handleEvent(event Event) error { + sm.mu.Lock() + currentState := sm.currentState + nextState, valid := sm.transitions[currentState][event] + sm.mu.Unlock() + + if !valid { + return fmt.Errorf("invalid transition: %v -> %v", currentState, event) + } + + sm.mu.Lock() + sm.currentState = nextState + action := sm.actions[nextState] + sm.mu.Unlock() + + if action != nil { + go action(nextState) + } + + return nil +} diff --git a/cli/result.go b/cli/result.go index d43fa59d1..490172e30 100644 --- a/cli/result.go +++ b/cli/result.go @@ -4,13 +4,17 @@ package cli import ( "encoding/pem" + "fmt" "os" "github.com/fatih/color" "github.com/spf13/cobra" ) -const resultFilePath = "results.zip" +const ( + resultFilePrefix = "results" + resultFileExt = ".zip" +) func (cli *CLI) NewResultsCmd() *cobra.Command { return &cobra.Command{ @@ -42,12 +46,35 @@ func (cli *CLI) NewResultsCmd() *cobra.Command { return } + resultFilePath, err := getUniqueFilePath(resultFilePrefix, resultFileExt) + if err != nil { + printError(cmd, "Error generating unique file path: %v ❌ ", err) + return + } + if err := os.WriteFile(resultFilePath, result, 0o644); err != nil { printError(cmd, "Error saving computation result file: %v ❌ ", err) return } - cmd.Println(color.New(color.FgGreen).Sprint("Computation result retrieved and saved successfully! ✔ ")) + cmd.Println(color.New(color.FgGreen).Sprintf("Computation result retrieved and saved successfully as %s! ✔ ", resultFilePath)) }, } } + +func getUniqueFilePath(prefix, ext string) (string, error) { + for i := 0; ; i++ { + var filename string + if i == 0 { + filename = prefix + ext + } else { + filename = fmt.Sprintf("%s_%d%s", prefix, i, ext) + } + + if _, err := os.Stat(filename); os.IsNotExist(err) { + return filename, nil + } else if err != nil { + return "", err + } + } +} diff --git a/cli/result_test.go b/cli/result_test.go index 6ef363bcc..fcd0f9989 100644 --- a/cli/result_test.go +++ b/cli/result_test.go @@ -1,11 +1,14 @@ // Copyright (c) Ultraviolet // SPDX-License-Identifier: Apache-2.0 + package cli import ( "bytes" "errors" + "fmt" "os" + "path/filepath" "testing" "github.com/stretchr/testify/mock" @@ -32,12 +35,50 @@ func TestResultsCmd_Success(t *testing.T) { require.Contains(t, buf.String(), "Computation result retrieved and saved successfully") - resultFile, err := os.ReadFile("results.zip") + files, err := filepath.Glob("results*.zip") + require.NoError(t, err) + require.Len(t, files, 1) + + resultFile, err := os.ReadFile(files[0]) require.NoError(t, err) require.Equal(t, compResult, string(resultFile)) t.Cleanup(func() { - os.Remove("results.zip") + for _, file := range files { + os.Remove(file) + } + os.Remove(privateKeyFile) + }) +} + +func TestResultsCmd_MultipleExecutions(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) + testCLI := New(mockSDK) + + err := generateRSAPrivateKeyFile(privateKeyFile) + require.NoError(t, err) + + cmd := testCLI.NewResultsCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{privateKeyFile}) + + for i := 0; i < 3; i++ { + err = cmd.Execute() + require.NoError(t, err) + require.Contains(t, buf.String(), "Computation result retrieved and saved successfully") + buf.Reset() + } + + files, err := filepath.Glob("results*.zip") + require.NoError(t, err) + require.Len(t, files, 3) + + t.Cleanup(func() { + for _, file := range files { + os.Remove(file) + } os.Remove(privateKeyFile) }) } @@ -87,8 +128,8 @@ func TestResultsCmd_SaveFailure(t *testing.T) { err := generateRSAPrivateKeyFile(privateKeyFile) require.NoError(t, err) - // Simulate failure in saving the result file by making a directory with the same name as the result file - err = os.Mkdir("results.zip", 0o755) + // Simulate failure in saving the result file by making all files read-only + err = os.Chmod(".", 0o555) require.NoError(t, err) cmd := testCLI.NewResultsCmd() @@ -102,8 +143,10 @@ func TestResultsCmd_SaveFailure(t *testing.T) { mockSDK.AssertCalled(t, "Result", mock.Anything, mock.Anything) t.Cleanup(func() { - os.Remove("results.zip") - os.Remove(privateKeyFile) + err := os.Chmod(".", 0o755) + require.NoError(t, err) + err = os.Remove(privateKeyFile) + require.NoError(t, err) }) } @@ -132,3 +175,26 @@ func TestResultsCmd_InvalidPrivateKey(t *testing.T) { require.Contains(t, buf.String(), "Error decoding private key") mockSDK.AssertNotCalled(t, "Result", mock.Anything, mock.Anything) } + +func TestGetUniqueFilePath(t *testing.T) { + prefix := "test" + ext := ".txt" + + path, err := getUniqueFilePath(prefix, ext) + require.NoError(t, err) + require.Equal(t, "test.txt", path) + + _, err = os.Create("test.txt") + require.NoError(t, err) + defer os.Remove("test.txt") + for i := 1; i < 3; i++ { + fileName := fmt.Sprintf("%s_%d%s", prefix, i, ext) + _, err := os.Create(fileName) + require.NoError(t, err) + defer os.Remove(fileName) + } + + path, err = getUniqueFilePath(prefix, ext) + require.NoError(t, err) + require.Equal(t, "test_3.txt", path) +} From 18aa8ba785c888eeb5d98eeba659d1b7f452ea53 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Wed, 9 Oct 2024 21:01:11 +0300 Subject: [PATCH 47/83] NOISSUE - Add internal tests (#266) * add internal tests Signed-off-by: Sammy Oina * fix linter Signed-off-by: Sammy Oina * fix race conditions Signed-off-by: Sammy Oina * remove all races Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- cmd/manager/start_VM.sh | 62 ------- cmd/manager/xml/dom.xml | 62 ------- cmd/manager/xml/pool.xml | 6 - cmd/manager/xml/vol.xml | 15 -- internal/cmd.go | 64 ------- internal/file_test.go | 136 +++++++++++++++ internal/libvirt/connect.go | 45 ----- internal/libvirt/libvirt.go | 167 ------------------ internal/logger/protohandler.go | 2 + internal/logger/protohandler_test.go | 83 +++++++++ internal/server/doc.go | 2 +- internal/server/grpc/grpc.go | 36 ++-- internal/server/grpc/grpc_test.go | 250 +++++++++++++++++++++++++++ internal/server/mocks/server.go | 60 +++++++ internal/server/server.go | 1 + internal/server/server_test.go | 138 +++++++++++++++ internal/vsock/client.go | 17 +- internal/vsock/client_test.go | 205 ++++++++++++++++++++++ internal/zip_test.go | 106 ++++++++++++ manager/agentEventsLogs_test.go | 2 - pkg/sdk/agent_test.go | 9 +- 21 files changed, 1013 insertions(+), 455 deletions(-) delete mode 100755 cmd/manager/start_VM.sh delete mode 100644 cmd/manager/xml/dom.xml delete mode 100644 cmd/manager/xml/pool.xml delete mode 100644 cmd/manager/xml/vol.xml delete mode 100644 internal/cmd.go create mode 100644 internal/file_test.go delete mode 100644 internal/libvirt/connect.go delete mode 100644 internal/libvirt/libvirt.go create mode 100644 internal/logger/protohandler_test.go create mode 100644 internal/server/grpc/grpc_test.go create mode 100644 internal/server/mocks/server.go create mode 100644 internal/server/server_test.go create mode 100644 internal/vsock/client_test.go create mode 100644 internal/zip_test.go diff --git a/cmd/manager/start_VM.sh b/cmd/manager/start_VM.sh deleted file mode 100755 index de88725d5..000000000 --- a/cmd/manager/start_VM.sh +++ /dev/null @@ -1,62 +0,0 @@ -#!/bin/bash - -# Set your default values for sudo and sev -sudo_option=false -sev_option=false - -# Parse command line arguments -while [[ $# -gt 0 ]]; do - key="$1" - - case $key in - --sudo) - sudo_option=true - shift - ;; - --sev) - sev_option=true - shift - ;; - *) - echo "Unknown option: $key" - exit 1 - ;; - esac -done - -build_qemu_command() { - local qemu_command="/usr/bin/qemu-system-x86_64 -enable-kvm -machine q35 -cpu EPYC -smp 4,maxcpus=64 -m 2048M,slots=5,maxmem=30G -drive if=pflash,format=raw,unit=0,file=$MANAGER_QEMU_OVMF_CODE_FILE,readonly=on -drive if=pflash,format=raw,unit=1,file=img/OVMF_VARS.fd -device virtio-scsi-pci,id=scsi,disable-legacy=on,iommu_platform=true -drive file=img/focal-server-cloudimg-amd64.img,if=none,id=disk0,format=qcow2 -device scsi-hd,drive=disk0 -netdev user,id=vmnic,hostfwd=tcp::2222-:22,hostfwd=tcp::9301-:9031,hostfwd=tcp::7020-:7002 -device virtio-net-pci,disable-legacy=on,iommu_platform=true,netdev=vmnic,romfile= -nographic -monitor pty" - - if [ "$sev_option" = true ]; then - qemu_command="$qemu_command -object sev-guest,id=sev0,cbitpos=51,reduced-phys-bits=1 -machine memory-encryption=sev0" - fi - - echo "$qemu_command" -} - -if [ ! -f "img/OVMF_VARS.fd" ]; then - cp "$MANAGER_QEMU_OVMF_VARS_FILE" "img/OVMF_VARS.fd" - echo "Copied $MANAGER_QEMU_OVMF_VARS_FILE to img/OVMF_VARS.fd" -else - echo "img/OVMF_VARS.fd already exists. No need to copy." -fi - -echo "Launching VM ..." - -qemu_command=$(build_qemu_command) -echo "$qemu_command" - -echo "Mapping CTRL-C to CTRL-]" -stty intr ^] - -if [ "$sudo_option" = true ]; then - # Split the command and arguments into an array; << operator is known as a "here string" - IFS=" " read -r -a qemu_command_array <<< "$qemu_command" - # Treat each element in the array as a separate word, preserving spaces within each element - sudo "${qemu_command_array[@]}" -else - $qemu_command -fi - -# Restore the mapping -stty intr ^c diff --git a/cmd/manager/xml/dom.xml b/cmd/manager/xml/dom.xml deleted file mode 100644 index 2cccad454..000000000 --- a/cmd/manager/xml/dom.xml +++ /dev/null @@ -1,62 +0,0 @@ - - QEmu-alpine-standard-x86_64 - c7a5fdbd-cdaf-9455-926a-d65c16db1809 - - - - - - 4194304 - 4194304 - 1 - - hvm - - /usr/share/OVMF/OVMF_CODE.fd - ./img/OVMF_VARS.fd - - - - - - - - - - - - - - destroy - restart - destroy - - - - - - /usr/bin/qemu-system-x86_64 - - - - -
- - - - - - -