diff --git a/server/__tests__/utils/agents/defaults.test.js b/server/__tests__/utils/agents/defaults.test.js new file mode 100644 index 00000000000..b3992a9067e --- /dev/null +++ b/server/__tests__/utils/agents/defaults.test.js @@ -0,0 +1,131 @@ +// Set required env vars before requiring modules +process.env.STORAGE_DIR = __dirname; +process.env.NODE_ENV = "test"; + +const { SystemPromptVariables } = require("../../../models/systemPromptVariables"); +const Provider = require("../../../utils/agents/aibitat/providers/ai-provider"); + +jest.mock("../../../models/systemPromptVariables"); +jest.mock("../../../models/systemSettings"); +jest.mock("../../../utils/agents/imported", () => ({ + activeImportedPlugins: jest.fn().mockReturnValue([]), +})); +jest.mock("../../../utils/agentFlows", () => ({ + AgentFlows: { + activeFlowPlugins: jest.fn().mockReturnValue([]), + }, +})); +jest.mock("../../../utils/MCP", () => { + return jest.fn().mockImplementation(() => ({ + activeMCPServers: jest.fn().mockResolvedValue([]), + })); +}); + +const { WORKSPACE_AGENT } = require("../../../utils/agents/defaults"); + +describe("WORKSPACE_AGENT.getDefinition", () => { + beforeEach(() => { + jest.clearAllMocks(); + // Mock SystemSettings to return empty arrays for agent skills + const { SystemSettings } = require("../../../models/systemSettings"); + SystemSettings.getValueOrFallback = jest.fn().mockResolvedValue("[]"); + }); + + it("should use provider default system prompt when workspace has no openAiPrompt", async () => { + const workspace = { + id: 1, + name: "Test Workspace", + openAiPrompt: null, + }; + const user = { id: 1 }; + const provider = "openai"; + const expectedPrompt = await Provider.systemPrompt({ provider, workspace, user }); + const definition = await WORKSPACE_AGENT.getDefinition( + provider, + workspace, + user + ); + expect(definition.role).toBe(expectedPrompt); + expect(SystemPromptVariables.expandSystemPromptVariables).not.toHaveBeenCalled(); + }); + + it("should use 
workspace system prompt with variable expansion when openAiPrompt exists", async () => { + const workspace = { + id: 1, + name: "Test Workspace", + openAiPrompt: "You are a helpful assistant for {workspace.name}. The current user is {user.name}.", + }; + const user = { id: 1 }; + const provider = "openai"; + + const expandedPrompt = "You are a helpful assistant for Test Workspace. The current user is John Doe."; + SystemPromptVariables.expandSystemPromptVariables.mockResolvedValue(expandedPrompt); + + const definition = await WORKSPACE_AGENT.getDefinition( + provider, + workspace, + user + ); + + expect(SystemPromptVariables.expandSystemPromptVariables).toHaveBeenCalledWith( + workspace.openAiPrompt, + user.id, + workspace.id + ); + expect(definition.role).toBe(expandedPrompt); + }); + + it("should handle workspace system prompt without user context", async () => { + const workspace = { + id: 1, + name: "Test Workspace", + openAiPrompt: "You are a helpful assistant. Today is {date}.", + }; + const user = null; + const provider = "lmstudio"; + const expandedPrompt = "You are a helpful assistant. 
Today is January 1, 2024."; + SystemPromptVariables.expandSystemPromptVariables.mockResolvedValue(expandedPrompt); + + const definition = await WORKSPACE_AGENT.getDefinition( + provider, + workspace, + user + ); + + expect(SystemPromptVariables.expandSystemPromptVariables).toHaveBeenCalledWith( + workspace.openAiPrompt, + null, + workspace.id + ); + expect(definition.role).toBe(expandedPrompt); + }); + + it("should return functions array in definition", async () => { + const workspace = { id: 1, openAiPrompt: null }; + const provider = "openai"; + + const definition = await WORKSPACE_AGENT.getDefinition( + provider, + workspace, + null + ); + + expect(definition).toHaveProperty("functions"); + expect(Array.isArray(definition.functions)).toBe(true); + }); + + it("should use LMStudio specific prompt when workspace has no openAiPrompt", async () => { + const workspace = { id: 1, openAiPrompt: null }; + const user = null; + const provider = "lmstudio"; + const definition = await WORKSPACE_AGENT.getDefinition( + provider, + workspace, + null + ); + + expect(definition.role).toBe(await Provider.systemPrompt({ provider, workspace, user })); + expect(definition.role).toContain("helpful ai assistant"); + }); +}); + diff --git a/server/utils/agents/aibitat/providers/ai-provider.js b/server/utils/agents/aibitat/providers/ai-provider.js index 507015cb0cb..5c35ee2e03e 100644 --- a/server/utils/agents/aibitat/providers/ai-provider.js +++ b/server/utils/agents/aibitat/providers/ai-provider.js @@ -19,6 +19,9 @@ const { toValidNumber, safeJsonParse } = require("../../../http"); const { getLLMProviderClass } = require("../../../helpers"); const { parseLMStudioBasePath } = require("../../../AiProviders/lmStudio"); const { parseFoundryBasePath } = require("../../../AiProviders/foundry"); +const { + SystemPromptVariables, +} = require("../../../../models/systemPromptVariables"); const DEFAULT_WORKSPACE_PROMPT = "You are a helpful ai assistant who can assist the user and use tools 
available to help answer the users prompts and questions."; @@ -288,10 +291,7 @@ class Provider { return llm.promptWindowLimit(modelName); } - // For some providers we may want to override the system prompt to be more verbose. - // Currently we only do this for lmstudio, but we probably will want to expand this even more - // to any Untooled LLM. - static systemPrompt(provider = null) { + static defaultSystemPromptForProvider(provider = null) { switch (provider) { case "lmstudio": return "You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions. Tools will be handled by another assistant and you will simply receive their responses to help answer the user prompt - always try to answer the user's prompt the best you can with the context available to you and your general knowledge."; @@ -300,6 +300,27 @@ class Provider { } } + /** + * Get the agent system prompt: the workspace's own prompt (openAiPrompt, with system prompt variables expanded) when set, otherwise the provider's default prompt. The options object may be omitted. + * @param {string|null} [options.provider] - provider key used to pick the default prompt + * @param {import("@prisma/client").workspaces | null} [options.workspace] - workspace whose openAiPrompt, if set, replaces the default + * @param {import("@prisma/client").users | null} [options.user] - user whose id is forwarded to variable expansion (null when unknown) + * @returns {Promise<string>} + */ + static async systemPrompt({ + provider = null, + workspace = null, + user = null, + } = {}) { + if (!workspace?.openAiPrompt) + return Provider.defaultSystemPromptForProvider(provider); + return await SystemPromptVariables.expandSystemPromptVariables( + workspace.openAiPrompt, + user?.id ?? null, + workspace.id + ); + } + /** * Whether the provider supports agent streaming. 
* Disabled by default and needs to be explicitly enabled in the provider diff --git a/server/utils/agents/defaults.js b/server/utils/agents/defaults.js index ee12974cf80..f55589639f7 100644 --- a/server/utils/agents/defaults.js +++ b/server/utils/agents/defaults.js @@ -5,6 +5,7 @@ const Provider = require("./aibitat/providers/ai-provider"); const ImportedPlugin = require("./imported"); const { AgentFlows } = require("../agentFlows"); const MCPCompatibilityLayer = require("../MCP"); +const { SystemPromptVariables } = require("../../models/systemPromptVariables"); // This is a list of skills that are built-in and default enabled. const DEFAULT_SKILLS = [ @@ -15,7 +16,7 @@ const DEFAULT_SKILLS = [ const USER_AGENT = { name: "USER", - getDefinition: async () => { + getDefinition: () => { return { interrupt: "ALWAYS", role: "I am the human monitor and oversee this chat. Any questions on action or decision making should be directed to me.", @@ -25,9 +26,16 @@ const USER_AGENT = { const WORKSPACE_AGENT = { name: "@agent", - getDefinition: async (provider = null) => { + /** + * Get the definition for the workspace agent with its role (prompt) and functions in Aibitat format + * @param {string} provider + * @param {import("@prisma/client").workspaces | null} workspace + * @param {import("@prisma/client").users | null} user + * @returns {Promise<{ role: string, functions: object[] }>} + */ + getDefinition: async (provider = null, workspace = null, user = null) => { return { - role: Provider.systemPrompt(provider), + role: await Provider.systemPrompt({ provider, workspace, user }), functions: [ ...(await agentSkillsFromSystemSettings()), ...ImportedPlugin.activeImportedPlugins(), diff --git a/server/utils/agents/ephemeral.js b/server/utils/agents/ephemeral.js index 9106af24d71..bbcc99cf278 100644 --- a/server/utils/agents/ephemeral.js +++ b/server/utils/agents/ephemeral.js @@ -4,6 +4,7 @@ const ImportedPlugin = require("./imported"); const MCPCompatibilityLayer = 
require("../MCP"); const { AgentFlows } = require("../agentFlows"); const { httpSocket } = require("./aibitat/plugins/http-socket.js"); +const { User } = require("../../models/user"); const { WorkspaceChats } = require("../../models/workspaceChats"); const { safeJsonParse } = require("../http"); const { @@ -26,7 +27,7 @@ class EphemeralAgentHandler extends AgentHandler { #invocationUUID = null; /** @type {import("@prisma/client").workspaces|null} the workspace to use for the agent */ #workspace = null; - /** @type {import("@prisma/client").users|null} the user id to use for the agent */ + /** @type {import("@prisma/client").users["id"]|null} the user id to use for the agent */ #userId = null; /** @type {import("@prisma/client").workspace_threads|null} the workspace thread id to use for the agent */ #threadId = null; @@ -69,6 +70,9 @@ class EphemeralAgentHandler extends AgentHandler { this.#workspace = workspace; this.#prompt = prompt; + // Note: userId for ephemeral agent is only available + // via the workspace-thread chat endpoints for the API + // since workspaces can belong to multiple users. this.#userId = userId; this.#threadId = threadId; this.#sessionId = sessionId; @@ -319,10 +323,14 @@ class EphemeralAgentHandler extends AgentHandler { async #loadAgents() { // Default User agent and workspace agent this.log(`Attaching user and default agent to Agent cluster.`); - this.aibitat.agent(USER_AGENT.name, await USER_AGENT.getDefinition()); + this.aibitat.agent(USER_AGENT.name, USER_AGENT.getDefinition()); + const user = this.#userId + ? 
await User.get({ id: Number(this.#userId) }) + : null; + + this.aibitat.agent( WORKSPACE_AGENT.name, - await WORKSPACE_AGENT.getDefinition(this.provider) + await WORKSPACE_AGENT.getDefinition(this.provider, this.#workspace, user) ); this.#funcsToLoad = [ diff --git a/server/utils/agents/index.js b/server/utils/agents/index.js index 98d3d774a09..4ed1ddea0a8 100644 --- a/server/utils/agents/index.js +++ b/server/utils/agents/index.js @@ -3,6 +3,7 @@ const AgentPlugins = require("./aibitat/plugins"); const { WorkspaceAgentInvocation, } = require("../../models/workspaceAgentInvocation"); +const { User } = require("../../models/user"); const { WorkspaceChats } = require("../../models/workspaceChats"); const { safeJsonParse } = require("../http"); const { USER_AGENT, WORKSPACE_AGENT } = require("./defaults"); @@ -523,15 +524,21 @@ class AgentHandler { async #loadAgents() { // Default User agent and workspace agent this.log(`Attaching user and default agent to Agent cluster.`); - this.aibitat.agent(USER_AGENT.name, await USER_AGENT.getDefinition()); - this.aibitat.agent( - WORKSPACE_AGENT.name, - await WORKSPACE_AGENT.getDefinition(this.provider) + const user = this.invocation.user_id + ? await User.get({ id: Number(this.invocation.user_id) }) + : null; + const userAgentDef = USER_AGENT.getDefinition(); + const workspaceAgentDef = await WORKSPACE_AGENT.getDefinition( + this.provider, + this.invocation.workspace, + user ); + this.aibitat.agent(USER_AGENT.name, userAgentDef); + this.aibitat.agent(WORKSPACE_AGENT.name, workspaceAgentDef); this.#funcsToLoad = [ - ...((await USER_AGENT.getDefinition())?.functions || []), - ...((await WORKSPACE_AGENT.getDefinition())?.functions || []), + ...(userAgentDef?.functions || []), + ...(workspaceAgentDef?.functions || []), ]; }