diff --git a/.github/workflows/dev-build.yaml b/.github/workflows/dev-build.yaml
index ada7979a2b7..ef1d99f5323 100644
--- a/.github/workflows/dev-build.yaml
+++ b/.github/workflows/dev-build.yaml
@@ -6,7 +6,7 @@ concurrency:
 on:
   push:
-    branches: ['improve-url-handler-collector'] # put your current branch to create a build. Core team only.
+    branches: ['gemini-migration-agents'] # put your current branch to create a build. Core team only.
     paths-ignore:
       - '**.md'
       - 'cloud-deployments/*'
diff --git a/server/utils/agents/aibitat/providers/gemini.js b/server/utils/agents/aibitat/providers/gemini.js
index 101d970c848..53d3113a32a 100644
--- a/server/utils/agents/aibitat/providers/gemini.js
+++ b/server/utils/agents/aibitat/providers/gemini.js
@@ -1,19 +1,14 @@
 const OpenAI = require("openai");
 const Provider = require("./ai-provider.js");
-const InheritMultiple = require("./helpers/classes.js");
-const UnTooled = require("./helpers/untooled.js");
-const {
-  NO_SYSTEM_PROMPT_MODELS,
-} = require("../../../AiProviders/gemini/index.js");
-const { APIError } = require("../error.js");
-const { v4 } = require("uuid");
+const { RetryError } = require("../error.js");
 const { safeJsonParse } = require("../../../http");
+const { v4 } = require("uuid");
 
 /**
  * The agent provider for the Gemini provider.
- * We wrap Gemini in UnTooled because its tool-calling is not supported via the dedicated OpenAI API.
+ * Tool calling is supported natively via the OpenAI-compatible endpoint.
  */
-class GeminiProvider extends InheritMultiple([Provider, UnTooled]) {
+class GeminiProvider extends Provider {
   model;
 
   constructor(config = {}) {
@@ -35,6 +30,11 @@ class GeminiProvider extends InheritMultiple([Provider, UnTooled]) {
     return this._client;
   }
 
+  get supportsToolCalling() {
+    if (!this.model.startsWith("gemini")) return false;
+    return true;
+  }
+
   get supportsAgentStreaming() {
     // Tool call streaming results in a 400/503 error for all non-gemini models
     // using the compatible v1beta/openai/ endpoint
@@ -44,205 +44,250 @@ class GeminiProvider extends InheritMultiple([Provider, UnTooled]) {
       );
       return false;
     }
     return true;
   }
 
   /**
-   * Format the messages to the format required by the Gemini API since some models do not support system prompts.
-   * @see {NO_SYSTEM_PROMPT_MODELS}
-   * @param {import("openai").OpenAI.ChatCompletionMessage[]} messages
-   * @returns {import("openai").OpenAI.ChatCompletionMessage[]}
+   * Gemini specifically will throw an error if a tool call's function name
+   * starts with a non-alpha character, so we prefix the function names with a
+   * valid prefix to ensure they are always valid, then strip it back off so
+   * they may properly be used in the tool call.
+   *
+   * We force the prefix gtc__ for all tools to avoid issues. Agent flows are
+   * already prefixed with flow__, but since we strip the prefix pre- and
+   * post-reply anyway, we prefix them too to ensure consistency across all
+   * tools.
+   *
+   * This specifically impacts custom Agent Skills, since they can be a short
+   * alphanumeric string and can start with a number, eg: '12xdaya31bas' is an
+   * invalid tool name in Gemini.
+   *
+   * Even if the tool is never called, Gemini will throw an error if it is in
+   * the `tools` array and this prefix patch is not applied.
+   *
+   * This is undocumented by Google, but it is the only way to ensure that
+   * tool calls are valid.
+   *
+   * @param {string} functionName - The name of the function to prefix or strip.
+   * @param {'add' | 'strip'} action - The action to take.
+   * @returns {string} The prefixed (or stripped) function name.
+   */
+  prefixToolCall(functionName, action = "add") {
+    if (action === "add") return `gtc__${functionName}`;
+    // must start with gtc__ to be valid, and we only strip the leading prefix
+    return functionName.startsWith("gtc__")
+      ? functionName.slice("gtc__".length)
+      : functionName;
+  }
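
To make the prefix round-trip concrete, a quick sketch of the expected behavior on any GeminiProvider instance (the skill name is the hypothetical one from the comment above):

  // A custom Agent Skill named '12xdaya31bas' is an invalid Gemini tool name,
  // so it is advertised to the model with the prefix applied...
  provider.prefixToolCall("12xdaya31bas", "add"); // => "gtc__12xdaya31bas"
  // ...and stripped back before the tool call is dispatched locally.
  provider.prefixToolCall("gtc__12xdaya31bas", "strip"); // => "12xdaya31bas"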
+
+  /**
+   * Format the messages to the Gemini chat completions format.
+   * - Gemini has a loosely documented format for tool calls and it can change at any time.
+   * - We need to map each function result back to the correct tool call id, or Gemini will throw an error.
+   * @param {any[]} messages - The messages to format.
+   * @returns {OpenAI.OpenAI.Chat.ChatCompletionMessageParam[]} The formatted messages.
    */
-  formatMessages(messages) {
-    if (!NO_SYSTEM_PROMPT_MODELS.includes(this.model)) return messages;
-
-    // Replace the system message with a user/assistant message pair
-    const formattedMessages = [];
-    for (const message of messages) {
-      if (message.role === "system") {
-        formattedMessages.push({
-          role: "user",
-          content: message.content,
-        });
-        formattedMessages.push({
-          role: "assistant",
-          content: "Okay, I'll follow your instructions.",
-        });
-        continue;
-      }
-      formattedMessages.push(message);
-    }
+  #formatMessages(messages) {
+    let formattedMessages = [];
+    messages.forEach((message) => {
+      if (message.role === "function") {
+        // If the message does not have an originalFunctionCall, we cannot
+        // map it to a tool call id and Gemini will throw an error -
+        // so if this does not carry over, log and skip it.
+        if (!message.hasOwnProperty("originalFunctionCall")) {
+          this.providerLog(
+            "[Gemini.#formatMessages]: message did not pass back the originalFunctionCall. We need this to map the function call to the correct id.",
+            { message: JSON.stringify(message, null, 2) }
+          );
+          return;
+        }
+
+        formattedMessages.push(
+          {
+            role: "assistant",
+            tool_calls: [
+              {
+                type: "function",
+                function: {
+                  arguments: JSON.stringify(
+                    message.originalFunctionCall.arguments
+                  ),
+                  name: message.originalFunctionCall.name,
+                },
+                id: message.originalFunctionCall.id,
+              },
+            ],
+          },
+          {
+            role: "tool",
+            tool_call_id: message.originalFunctionCall.id,
+            content: message.content,
+          }
+        );
+        return;
+      }
+
+      formattedMessages.push({
+        role: message.role,
+        content: message.content,
+      });
+    });
     return formattedMessages;
   }
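
For reference, this is roughly the reshaping #formatMessages performs on a tool result (all values below are illustrative, not from the patch):

  // Input: an agent "function" message carrying the originalFunctionCall
  // recorded when the model first requested the tool.
  const input = {
    role: "function",
    content: "72F and sunny",
    originalFunctionCall: {
      id: "call_abc123",
      name: "get_weather",
      arguments: { city: "Austin" },
    },
  };
  // Output: an assistant tool-call message and a tool result message,
  // joined by the same id so Gemini can pair them:
  // { role: "assistant", tool_calls: [{ type: "function", id: "call_abc123",
  //     function: { name: "get_weather", arguments: '{"city":"Austin"}' } }] }
  // { role: "tool", tool_call_id: "call_abc123", content: "72F and sunny" }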
 
-  /**
-   * Format the functions for the LLM.
-   * @param {any[]} functions - The functions to format.
-   * @returns {any[]} - The formatted functions.
-   */
-  formatFunctions(functions = []) {
-    return functions.map((fn) => ({
+  #formatFunctions(functions) {
+    return functions.map((func) => ({
       type: "function",
       function: {
-        name: fn.name,
-        description: fn.description,
-        parameters: {
-          type: "object",
-          properties: fn.parameters.properties,
-        },
+        name: this.prefixToolCall(func.name, "add"),
+        description: func.description,
+        parameters: func.parameters,
       },
     }));
   }
 
-  async #handleFunctionCallChat({ messages = [] }) {
-    return await this.client.chat.completions
-      .create({
-        model: this.model,
-        messages: this.cleanMsgs(this.formatMessages(messages)),
-      })
-      .then((result) => {
-        if (!result.hasOwnProperty("choices"))
-          throw new Error("Gemini chat: No results!");
-        if (result.choices.length === 0)
-          throw new Error("Gemini chat: No results length!");
-        return result.choices[0].message.content;
-      })
-      .catch((_) => {
-        return null;
-      });
-  }
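
And the corresponding shape #formatFunctions produces for the tools array, reusing the same hypothetical skill name:

  // Input function definition:
  // { name: "12xdaya31bas", description: "Returns the weather for a city.",
  //   parameters: { type: "object", properties: { city: { type: "string" } } } }
  // Output OpenAI-style tool entry, with the gtc__ prefix applied so Gemini
  // accepts the name even if the tool is never called:
  // { type: "function",
  //   function: { name: "gtc__12xdaya31bas", description: "Returns the weather for a city.",
  //     parameters: { type: "object", properties: { city: { type: "string" } } } } }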
 
-  /**
-   * Streaming for Gemini only supports `tools` and not `functions`, so
-   * we need to apply some transformations to the messages and functions.
-   *
-   * @see {formatFunctions}
-   * @param {*} messages
-   * @param {*} functions
-   * @param {*} eventHandler
-   * @returns
-   */
-  async stream(messages, functions = [], eventHandler = null) {
-    const msgUUID = v4();
-    const stream = await this.client.chat.completions.create({
-      model: this.model,
-      stream: true,
-      messages: this.cleanMsgs(this.formatMessages(messages)),
-      ...(Array.isArray(functions) && functions?.length > 0
-        ? {
-            tools: this.formatFunctions(functions),
-            tool_choice: "auto",
-          }
-        : {}),
-    });
-
-    const result = {
-      functionCall: null,
-      textResponse: "",
-    };
-
-    for await (const chunk of stream) {
-      if (!chunk?.choices?.[0]) continue; // Skip if no choices
-      const choice = chunk.choices[0];
-
-      if (choice.delta?.content) {
-        result.textResponse += choice.delta.content;
-        eventHandler?.("reportStreamEvent", {
-          type: "textResponseChunk",
-          uuid: msgUUID,
-          content: choice.delta.content,
-        });
-      }
-
-      if (choice.delta?.tool_calls && choice.delta.tool_calls.length > 0) {
-        const toolCall = choice.delta.tool_calls[0];
-        if (result.functionCall)
-          result.functionCall.arguments += toolCall.function.arguments;
-        else {
-          result.functionCall = {
-            name: toolCall.function.name,
-            arguments: toolCall.function.arguments,
-          };
-        }
-
-        eventHandler?.("reportStreamEvent", {
-          uuid: `${msgUUID}:tool_call_invocation`,
-          type: "toolCallInvocation",
-          content: `Assembling Tool Call: ${result.functionCall.name}(${result.functionCall.arguments})`,
-        });
-      }
-    }
-
-    // If there are arguments, parse them as json so that the tools can use them
-    if (!!result.functionCall?.arguments)
-      result.functionCall.arguments = safeJsonParse(
-        result.functionCall.arguments,
-        {}
-      );
-    return result;
-  }
+  async stream(messages, functions = [], eventHandler = null) {
+    if (!this.supportsToolCalling)
+      throw new Error(`Gemini: ${this.model} does not support tool calling.`);
+    this.providerLog("Gemini.stream - will process this chat completion.");
+    try {
+      const msgUUID = v4();
+      /** @type {AsyncIterable<OpenAI.OpenAI.Chat.ChatCompletionChunk>} */
+      const response = await this.client.chat.completions.create({
+        model: this.model,
+        messages: this.#formatMessages(messages),
+        stream: true,
+        ...(Array.isArray(functions) && functions?.length > 0
+          ? { tools: this.#formatFunctions(functions), tool_choice: "auto" }
+          : {}),
+      });
+
+      const completion = {
+        content: "",
+        /** @type {null|{name: string, call_id: string, arguments: string|object}} */
+        functionCall: null,
+      };
+
+      for await (const chunk of response) {
+        const { content, tool_calls } = chunk?.choices?.[0]?.delta || {};
+
+        if (content) {
+          completion.content += content;
+          eventHandler?.("reportStreamEvent", {
+            type: "textResponseChunk",
+            uuid: msgUUID,
+            content,
+          });
+        }
+
+        if (tool_calls?.length > 0) {
+          const toolCall = tool_calls[0];
+          // Tool call arguments can arrive across multiple chunks, so we
+          // accumulate the fragments instead of overwriting the previous one.
+          if (completion.functionCall) {
+            completion.functionCall.arguments += toolCall.function.arguments;
+          } else {
+            completion.functionCall = {
+              name: this.prefixToolCall(toolCall.function.name, "strip"),
+              call_id: toolCall.id,
+              arguments: toolCall.function.arguments,
+            };
+          }
+          eventHandler?.("reportStreamEvent", {
+            type: "toolCallInvocation",
+            uuid: `${msgUUID}:tool_call_invocation`,
+            content: `Assembling Tool Call: ${completion.functionCall.name}(${completion.functionCall.arguments})`,
+          });
+        }
+      }
+
+      if (completion.functionCall) {
+        completion.functionCall.arguments = safeJsonParse(
+          completion.functionCall.arguments,
+          {}
+        );
+        return {
+          textResponse: completion.content,
+          functionCall: {
+            id: completion.functionCall.call_id,
+            name: completion.functionCall.name,
+            arguments: completion.functionCall.arguments,
+          },
+          cost: this.getCost(),
+        };
+      }
+
+      return {
+        textResponse: completion.content,
+        functionCall: null,
+        cost: this.getCost(),
+      };
+    } catch (error) {
+      if (error instanceof OpenAI.AuthenticationError) throw error;
+      if (
+        error instanceof OpenAI.RateLimitError ||
+        error instanceof OpenAI.InternalServerError ||
+        error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
+      ) {
+        throw new RetryError(error.message);
+      }
+      throw error;
+    }
+  }
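
A minimal caller-side sketch of stream() (the handler wiring here is illustrative; in practice aibitat supplies its own eventHandler):

  const result = await provider.stream(messages, functions, (event, payload) => {
    // "reportStreamEvent" fires for both text chunks and tool-call assembly.
    if (payload.type === "textResponseChunk") process.stdout.write(payload.content);
  });
  if (result.functionCall) {
    // arguments have already been parsed from JSON into an object via safeJsonParse
    console.log(result.functionCall.name, result.functionCall.arguments);
  }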
 
   /**
    * Create a completion based on the received messages.
    *
-   * TODO: see stream() - tool_calls are now supported, so we can use that instead of Untooled
-   *
-   * @param messages A list of messages to send to the API.
+   * @param messages A list of messages to send to the Gemini API.
    * @param functions
    * @returns The completion.
    */
   async complete(messages, functions = []) {
+    if (!this.supportsToolCalling)
+      throw new Error(`Gemini: ${this.model} does not support tool calling.`);
+    this.providerLog("Gemini.complete - will process this chat completion.");
     try {
-      let completion;
-
-      if (functions.length > 0) {
-        const { toolCall, text } = await this.functionCall(
-          this.cleanMsgs(this.formatMessages(messages)),
-          functions,
-          this.#handleFunctionCallChat.bind(this)
-        );
-
-        if (toolCall !== null) {
-          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
-          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
-          return {
-            result: null,
-            functionCall: {
-              name: toolCall.name,
-              arguments: toolCall.arguments,
-            },
-            cost: 0,
-          };
-        }
-        completion = { content: text };
-      }
+      const response = await this.client.chat.completions.create({
+        model: this.model,
+        stream: false,
+        messages: this.#formatMessages(messages),
+        ...(Array.isArray(functions) && functions?.length > 0
+          ? { tools: this.#formatFunctions(functions), tool_choice: "auto" }
+          : {}),
+      });
 
-      if (!completion?.content) {
-        this.providerLog(
-          "Will assume chat completion without tool call inputs."
-        );
-        const response = await this.client.chat.completions.create({
-          model: this.model,
-          messages: this.cleanMsgs(this.formatMessages(messages)),
-        });
-        completion = response.choices[0].message;
+      /** @type {OpenAI.OpenAI.Chat.ChatCompletionMessage} */
+      const completion = response.choices[0].message;
+      const cost = this.getCost(response.usage);
+      if (completion?.tool_calls?.length > 0) {
+        const toolCall = completion.tool_calls[0];
+        let functionArgs = safeJsonParse(toolCall.function.arguments, {});
+        return {
+          textResponse: null,
+          functionCall: {
+            name: this.prefixToolCall(toolCall.function.name, "strip"),
+            arguments: functionArgs,
+            id: toolCall.id,
+          },
+          cost,
+        };
       }
 
-      // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
-      // from calling the exact same function over and over in a loop within a single chat exchange
-      // _but_ we should enable it to call previously used tools in a new chat interaction.
-      this.deduplicator.reset("runs");
       return {
         textResponse: completion.content,
-        cost: 0,
+        cost,
       };
     } catch (error) {
-      throw new APIError(
-        error?.message
-          ? `${this.className} encountered an error while executing the request: ${error.message}`
-          : "There was an error with the Gemini provider executing the request"
-      );
+      // If we hit an invalid Auth error we need to abort, because no amount
+      // of waiting will make the auth any better.
+      if (error instanceof OpenAI.AuthenticationError) throw error;
+
+      if (
+        error instanceof OpenAI.RateLimitError ||
+        error instanceof OpenAI.InternalServerError ||
+        error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
+      ) {
+        throw new RetryError(error.message);
+      }
+
+      throw error;
     }
   }
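
For context on the error mapping above, a caller might treat the two failure classes like this (the backoff policy is an assumption, not part of this patch; RetryError is the same class imported at the top of this file from ../error.js):

  try {
    return await provider.complete(messages, functions);
  } catch (error) {
    if (error instanceof RetryError) {
      // Rate limits and 5xx responses are safe to retry after a short backoff.
      await new Promise((resolve) => setTimeout(resolve, 2000));
      return await provider.complete(messages, functions);
    }
    // AuthenticationError and anything unexpected aborts - retrying cannot fix bad credentials.
    throw error;
  }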