diff --git a/CLAUDE.md b/CLAUDE.md index e4b8f5a..a943211 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -351,6 +351,7 @@ Examples of state transition Context blocks: - Use descriptive By() statements to explain test steps - Ensure each test verifies both the state and any side effects - Assert on specific fields that should change during the transition +- Use positive assertions (prefer `Expect(x).To(Equal(y))` over `Expect(x).NotTo(Equal(z))`) for clarity and readability - Test event recording when events are part of the controller behavior - Verify controller return values (Requeue, RequeueAfter) - For tool calls or API interactions, use mock clients with verification diff --git a/kubechain/api/v1alpha1/taskrun_types.go b/kubechain/api/v1alpha1/taskrun_types.go index c4306e8..4ef2649 100644 --- a/kubechain/api/v1alpha1/taskrun_types.go +++ b/kubechain/api/v1alpha1/taskrun_types.go @@ -121,6 +121,10 @@ type TaskRunStatus struct { // SpanContext contains OpenTelemetry span context information // +optional SpanContext *SpanContext `json:"spanContext,omitempty"` + + // ToolCallRequestID uniquely identifies a set of tool calls from a single LLM response + // +optional + ToolCallRequestID string `json:"toolCallRequestId,omitempty"` } type TaskRunStatusStatus string diff --git a/kubechain/config/crd/bases/kubechain.humanlayer.dev_taskruns.yaml b/kubechain/config/crd/bases/kubechain.humanlayer.dev_taskruns.yaml index 4e467a6..bc75764 100644 --- a/kubechain/config/crd/bases/kubechain.humanlayer.dev_taskruns.yaml +++ b/kubechain/config/crd/bases/kubechain.humanlayer.dev_taskruns.yaml @@ -232,6 +232,10 @@ spec: description: StatusDetail provides additional details about the current status type: string + toolCallRequestId: + description: ToolCallRequestID uniquely identifies a set of tool calls + from a single LLM response + type: string userMsgPreview: description: UserMsgPreview stores the first 50 characters of the user's message diff --git a/kubechain/internal/controller/taskrun/taskrun_controller.go b/kubechain/internal/controller/taskrun/taskrun_controller.go index 52bf3b4..adf1b4a 100644 --- a/kubechain/internal/controller/taskrun/taskrun_controller.go +++ b/kubechain/internal/controller/taskrun/taskrun_controller.go @@ -6,6 +6,7 @@ import ( "fmt" "time" + "github.com/google/uuid" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -194,11 +195,12 @@ func (r *TaskRunReconciler) prepareForLLM(ctx context.Context, taskRun *kubechai func (r *TaskRunReconciler) processToolCalls(ctx context.Context, taskRun *kubechainv1alpha1.TaskRun) (ctrl.Result, error) { logger := log.FromContext(ctx) - // List all tool calls for this TaskRun + // List all tool calls for this ToolCallRequestID toolCallList := &kubechainv1alpha1.TaskRunToolCallList{} - logger.Info("Listing tool calls", "taskrun", taskRun.Name) + logger.Info("Listing tool calls", "taskrun", taskRun.Name, "requestId", taskRun.Status.ToolCallRequestID) + if err := r.List(ctx, toolCallList, client.InNamespace(taskRun.Namespace), - client.MatchingLabels{"kubechain.humanlayer.dev/taskruntoolcall": taskRun.Name}); err != nil { + client.MatchingLabels{"kubechain.humanlayer.dev/toolcallrequest": taskRun.Status.ToolCallRequestID}); err != nil { logger.Error(err, "Failed to list tool calls") return ctrl.Result{}, err } @@ -438,9 +440,14 @@ func (r *TaskRunReconciler) processLLMResponse(ctx context.Context, output *kube // End the parent span since we've reached a terminal state r.endTaskRunSpan(ctx, taskRun, codes.Ok, "TaskRun completed successfully with final answer") } else { + // Generate a unique ID for this set of tool calls + toolCallRequestId := uuid.New().String()[:7] // Using first 7 characters for brevity + logger.Info("Generated toolCallRequestId for tool calls", "id", toolCallRequestId) + // tool call branch: create TaskRunToolCall objects for each tool call returned by the LLM. statusUpdate.Status.Output = "" statusUpdate.Status.Phase = kubechainv1alpha1.TaskRunPhaseToolCallsPending + statusUpdate.Status.ToolCallRequestID = toolCallRequestId statusUpdate.Status.ContextWindow = append(statusUpdate.Status.ContextWindow, kubechainv1alpha1.Message{ Role: "assistant", ToolCalls: adapters.CastOpenAIToolCallsToKubechain(output.ToolCalls), @@ -466,15 +473,22 @@ func (r *TaskRunReconciler) processLLMResponse(ctx context.Context, output *kube func (r *TaskRunReconciler) createToolCalls(ctx context.Context, taskRun *kubechainv1alpha1.TaskRun, statusUpdate *kubechainv1alpha1.TaskRun, toolCalls []kubechainv1alpha1.ToolCall) (ctrl.Result, error) { logger := log.FromContext(ctx) - // For each tool call, create a new TaskRunToolCall. + if statusUpdate.Status.ToolCallRequestID == "" { + err := fmt.Errorf("no ToolCallRequestID found in statusUpdate, cannot create tool calls") + logger.Error(err, "Missing ToolCallRequestID") + return ctrl.Result{}, err + } + + // For each tool call, create a new TaskRunToolCall with a unique name using the ToolCallRequestID for i, tc := range toolCalls { - newName := fmt.Sprintf("%s-tc-%02d", statusUpdate.Name, i+1) + newName := fmt.Sprintf("%s-%s-tc-%02d", statusUpdate.Name, statusUpdate.Status.ToolCallRequestID, i+1) newTRTC := &kubechainv1alpha1.TaskRunToolCall{ ObjectMeta: metav1.ObjectMeta{ Name: newName, Namespace: statusUpdate.Namespace, Labels: map[string]string{ "kubechain.humanlayer.dev/taskruntoolcall": statusUpdate.Name, + "kubechain.humanlayer.dev/toolcallrequest": statusUpdate.Status.ToolCallRequestID, }, OwnerReferences: []metav1.OwnerReference{ { @@ -501,7 +515,7 @@ func (r *TaskRunReconciler) createToolCalls(ctx context.Context, taskRun *kubech logger.Error(err, "Failed to create TaskRunToolCall", "name", newName) return ctrl.Result{}, err } - logger.Info("Created TaskRunToolCall", "name", newName) + logger.Info("Created TaskRunToolCall", "name", newName, "requestId", statusUpdate.Status.ToolCallRequestID) r.recorder.Event(taskRun, corev1.EventTypeNormal, "ToolCallCreated", "Created TaskRunToolCall "+newName) } return ctrl.Result{RequeueAfter: time.Second * 5}, nil @@ -731,8 +745,25 @@ func (r *TaskRunReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct } // Step 9: Process LLM response - if result, err := r.processLLMResponse(ctx, output, &taskRun, statusUpdate); err != nil || !result.IsZero() { - return result, err + var llmResult ctrl.Result + llmResult, err = r.processLLMResponse(ctx, output, &taskRun, statusUpdate) + if err != nil { + logger.Error(err, "Failed to process LLM response") + statusUpdate.Status.Status = StatusError + statusUpdate.Status.Phase = kubechainv1alpha1.TaskRunPhaseFailed + statusUpdate.Status.StatusDetail = fmt.Sprintf("Failed to process LLM response: %v", err) + statusUpdate.Status.Error = err.Error() + r.recorder.Event(&taskRun, corev1.EventTypeWarning, "LLMResponseProcessingFailed", err.Error()) + + if updateErr := r.Status().Update(ctx, statusUpdate); updateErr != nil { + logger.Error(updateErr, "Failed to update TaskRun status after LLM response processing error") + return ctrl.Result{}, updateErr + } + return ctrl.Result{}, nil // Don't return the error to avoid requeuing + } + + if !llmResult.IsZero() { + return llmResult, nil } // Step 10: Update final status diff --git a/kubechain/internal/controller/taskrun/taskrun_controller_test.go b/kubechain/internal/controller/taskrun/taskrun_controller_test.go index 5ac421b..18c2177 100644 --- a/kubechain/internal/controller/taskrun/taskrun_controller_test.go +++ b/kubechain/internal/controller/taskrun/taskrun_controller_test.go @@ -254,7 +254,7 @@ var _ = Describe("TaskRun Controller", func() { Expect(k8sClient.Get(ctx, types.NamespacedName{Name: testTaskRun.name, Namespace: "default"}, taskRun)).To(Succeed()) Expect(taskRun.Status.Status).To(Equal(kubechain.TaskRunStatusStatusError)) // Phase shouldn't be Failed for general errors - Expect(taskRun.Status.Phase).ToNot(Equal(kubechain.TaskRunPhaseFailed)) + Expect(taskRun.Status.Phase).To(Equal(kubechain.TaskRunPhaseReadyForLLM)) Expect(taskRun.Status.Error).To(Equal("connection timeout")) ExpectRecorder(recorder).ToEmitEventContaining("LLMRequestFailed") }) @@ -375,7 +375,8 @@ var _ = Describe("TaskRun Controller", func() { defer teardown() taskRun := testTaskRun.SetupWithStatus(ctx, kubechain.TaskRunStatus{ - Phase: kubechain.TaskRunPhaseToolCallsPending, + Phase: kubechain.TaskRunPhaseToolCallsPending, + ToolCallRequestID: "test123", }) defer testTaskRun.Teardown(ctx) @@ -405,7 +406,8 @@ var _ = Describe("TaskRun Controller", func() { By("setting up the taskrun with a tool call pending") taskRun := testTaskRun.SetupWithStatus(ctx, kubechain.TaskRunStatus{ - Phase: kubechain.TaskRunPhaseToolCallsPending, + Phase: kubechain.TaskRunPhaseToolCallsPending, + ToolCallRequestID: "test123", ContextWindow: []kubechain.Message{ { Role: "system", diff --git a/kubechain/internal/controller/taskrun/utils_test.go b/kubechain/internal/controller/taskrun/utils_test.go index b054cf3..5d9b7e8 100644 --- a/kubechain/internal/controller/taskrun/utils_test.go +++ b/kubechain/internal/controller/taskrun/utils_test.go @@ -248,6 +248,7 @@ func (t *TestTaskRunToolCall) Setup(ctx context.Context) *kubechain.TaskRunToolC Namespace: "default", Labels: map[string]string{ "kubechain.humanlayer.dev/taskruntoolcall": testTaskRun.name, + "kubechain.humanlayer.dev/toolcallrequest": "test123", }, }, Spec: kubechain.TaskRunToolCallSpec{