diff --git a/frontend/src/components/Modals/ManageWorkspace/Documents/WorkspaceDirectory/index.jsx b/frontend/src/components/Modals/ManageWorkspace/Documents/WorkspaceDirectory/index.jsx
index bca6632c8bc..8d7a5ce0f7f 100644
--- a/frontend/src/components/Modals/ManageWorkspace/Documents/WorkspaceDirectory/index.jsx
+++ b/frontend/src/components/Modals/ManageWorkspace/Documents/WorkspaceDirectory/index.jsx
@@ -235,7 +235,7 @@ function WorkspaceDirectory({
         }`}

diff --git a/frontend/src/index.css b/frontend/src/index.css
index c75b98d2846..fa88feab6b9 100644
--- a/frontend/src/index.css
+++ b/frontend/src/index.css
@@ -117,7 +117,7 @@
   --theme-chat-input-border: #cccccc;
   --theme-action-menu-bg: #eaeaea;
   --theme-action-menu-item-hover: rgba(0, 0, 0, 0.1);
-  --theme-settings-input-bg: #EDF2FA;
+  --theme-settings-input-bg: #edf2fa;
   --theme-settings-input-placeholder: rgba(0, 0, 0, 0.5);
   --theme-settings-input-active: rgb(0 0 0 / 0.2);
   --theme-settings-input-text: #0e0f0f;
diff --git a/server/package.json b/server/package.json
index db153eb2145..8b765d46fc9 100644
--- a/server/package.json
+++ b/server/package.json
@@ -21,7 +21,6 @@
   "dependencies": {
     "@anthropic-ai/sdk": "^0.39.0",
     "@aws-sdk/client-bedrock-runtime": "^3.775.0",
-    "@azure/openai": "1.0.0-beta.10",
     "@datastax/astra-db-ts": "^0.1.3",
     "@google/generative-ai": "^0.7.1",
     "@ladjs/graceful": "^3.2.2",
@@ -67,7 +66,7 @@
     "multer": "^1.4.5-lts.1",
     "mysql2": "^3.9.8",
     "ollama": "^0.5.10",
-    "openai": "4.38.5",
+    "openai": "4.95.1",
     "pg": "^8.11.5",
     "pinecone-client": "^1.1.0",
     "pluralize": "^8.0.0",
diff --git a/server/utils/AiProviders/azureOpenAi/index.js b/server/utils/AiProviders/azureOpenAi/index.js
index f15d6ecfdae..cc6712c94d2 100644
--- a/server/utils/AiProviders/azureOpenAi/index.js
+++ b/server/utils/AiProviders/azureOpenAi/index.js
@@ -1,29 +1,26 @@
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
-  LLMPerformanceMonitor,
-} = require("../../helpers/chat/LLMPerformanceMonitor");
-const {
-  writeResponseChunk,
-  clientAbortedHandler,
   formatChatHistory,
+  handleDefaultStreamResponseV2,
 } = require("../../helpers/chat/responses");
+const {
+  LLMPerformanceMonitor,
+} = require("../../helpers/chat/LLMPerformanceMonitor");
 
 class AzureOpenAiLLM {
   constructor(embedder = null, modelPreference = null) {
-    const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
+    const { AzureOpenAI } = require("openai");
     if (!process.env.AZURE_OPENAI_ENDPOINT)
       throw new Error("No Azure API endpoint was set.");
     if (!process.env.AZURE_OPENAI_KEY)
       throw new Error("No Azure API key was set.");
 
     this.apiVersion = "2024-12-01-preview";
-    this.openai = new OpenAIClient(
-      process.env.AZURE_OPENAI_ENDPOINT,
-      new AzureKeyCredential(process.env.AZURE_OPENAI_KEY),
-      {
-        apiVersion: this.apiVersion,
-      }
-    );
+    this.openai = new AzureOpenAI({
+      apiKey: process.env.AZURE_OPENAI_KEY,
+      apiVersion: this.apiVersion,
+      endpoint: process.env.AZURE_OPENAI_ENDPOINT,
+    });
     this.model = modelPreference ?? process.env.OPEN_MODEL_PREF;
     this.isOTypeModel =
       process.env.AZURE_OPENAI_MODEL_TYPE === "reasoning" || false;
@@ -139,7 +136,9 @@ class AzureOpenAiLLM {
       );
 
     const result = await LLMPerformanceMonitor.measureAsyncFunction(
-      this.openai.getChatCompletions(this.model, messages, {
+      this.openai.chat.completions.create({
+        messages,
+        model: this.model,
         ...(this.isOTypeModel ? {} : { temperature }),
       })
     );
@@ -153,10 +152,10 @@ class AzureOpenAiLLM {
     return {
       textResponse: result.output.choices[0].message.content,
       metrics: {
-        prompt_tokens: result.output.usage.promptTokens || 0,
-        completion_tokens: result.output.usage.completionTokens || 0,
-        total_tokens: result.output.usage.totalTokens || 0,
-        outputTps: result.output.usage.completionTokens / result.duration,
+        prompt_tokens: result.output.usage.prompt_tokens || 0,
+        completion_tokens: result.output.usage.completion_tokens || 0,
+        total_tokens: result.output.usage.total_tokens || 0,
+        outputTps: result.output.usage.completion_tokens / result.duration,
         duration: result.duration,
       },
     };
@@ -169,9 +168,12 @@ class AzureOpenAiLLM {
       );
 
     const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
-      await this.openai.streamChatCompletions(this.model, messages, {
+      await this.openai.chat.completions.create({
+        messages,
+        model: this.model,
         ...(this.isOTypeModel ? {} : { temperature }),
         n: 1,
+        stream: true,
       }),
       messages
     );
@@ -179,64 +181,8 @@ class AzureOpenAiLLM {
     return measuredStreamRequest;
   }
 
-  /**
-   * Handles the stream response from the AzureOpenAI API.
-   * Azure does not return the usage metrics in the stream response, but 1msg = 1token
-   * so we can estimate the completion tokens by counting the number of messages.
-   * @param {Object} response - the response object
-   * @param {import('../../helpers/chat/LLMPerformanceMonitor').MonitoredStream} stream - the stream response from the AzureOpenAI API w/tracking
-   * @param {Object} responseProps - the response properties
-   * @returns {Promise}
-   */
   handleStream(response, stream, responseProps) {
-    const { uuid = uuidv4(), sources = [] } = responseProps;
-
-    return new Promise(async (resolve) => {
-      let fullText = "";
-      let usage = {
-        completion_tokens: 0,
-      };
-
-      // Establish listener to early-abort a streaming response
-      // in case things go sideways or the user does not like the response.
-      // We preserve the generated text but continue as if chat was completed
-      // to preserve previously generated content.
-      const handleAbort = () => {
-        stream?.endMeasurement(usage);
-        clientAbortedHandler(resolve, fullText);
-      };
-      response.on("close", handleAbort);
-
-      for await (const event of stream) {
-        for (const choice of event.choices) {
-          const delta = choice.delta?.content;
-          if (!delta) continue;
-          fullText += delta;
-          usage.completion_tokens++;
-
-          writeResponseChunk(response, {
-            uuid,
-            sources: [],
-            type: "textResponseChunk",
-            textResponse: delta,
-            close: false,
-            error: false,
-          });
-        }
-      }
-
-      writeResponseChunk(response, {
-        uuid,
-        sources,
-        type: "textResponseChunk",
-        textResponse: "",
-        close: true,
-        error: false,
-      });
-      response.removeListener("close", handleAbort);
-      stream?.endMeasurement(usage);
-      resolve(fullText);
-    });
+    return handleDefaultStreamResponseV2(response, stream, responseProps);
   }
 
   // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
diff --git a/server/utils/AiProviders/openAi/index.js b/server/utils/AiProviders/openAi/index.js
index 587dad81f77..2e5996c939a 100644
--- a/server/utils/AiProviders/openAi/index.js
+++ b/server/utils/AiProviders/openAi/index.js
@@ -183,6 +183,7 @@ class OpenAiLLM {
       messages
       // runPromptTokenCalculation: true - We manually count the tokens because OpenAI does not provide them in the stream
       // since we are not using the OpenAI API version that supports this `stream_options` param.
+      // TODO: implement this once we upgrade to the OpenAI API version that supports this param.
     );
 
     return measuredStreamRequest;
diff --git a/server/utils/EmbeddingEngines/azureOpenAi/index.js b/server/utils/EmbeddingEngines/azureOpenAi/index.js
index 62f69660a89..57907f45fd0 100644
--- a/server/utils/EmbeddingEngines/azureOpenAi/index.js
+++ b/server/utils/EmbeddingEngines/azureOpenAi/index.js
@@ -2,16 +2,22 @@ const { toChunks } = require("../../helpers");
 
 class AzureOpenAiEmbedder {
   constructor() {
-    const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
+    const { AzureOpenAI } = require("openai");
     if (!process.env.AZURE_OPENAI_ENDPOINT)
       throw new Error("No Azure API endpoint was set.");
     if (!process.env.AZURE_OPENAI_KEY)
       throw new Error("No Azure API key was set.");
-    const openai = new OpenAIClient(
-      process.env.AZURE_OPENAI_ENDPOINT,
-      new AzureKeyCredential(process.env.AZURE_OPENAI_KEY)
-    );
+
+    this.apiVersion = "2024-12-01-preview";
+    const openai = new AzureOpenAI({
+      apiKey: process.env.AZURE_OPENAI_KEY,
+      endpoint: process.env.AZURE_OPENAI_ENDPOINT,
+      apiVersion: this.apiVersion,
+    });
+
+    // We cannot assume the model fallback since the model is based on the deployment name
+    // and not the model name - so this will throw on embedding if the model is not defined.
+    this.model = process.env.EMBEDDING_MODEL_PREF;
     this.openai = openai;
 
     // Limit of how many strings we can process in a single pass to stay with resource or network limits
@@ -22,6 +28,10 @@ class AzureOpenAiEmbedder {
     this.embeddingMaxChunkLength = 2048;
   }
 
+  log(text, ...args) {
+    console.log(`\x1b[36m[AzureOpenAiEmbedder]\x1b[0m ${text}`, ...args);
+  }
+
   async embedTextInput(textInput) {
     const result = await this.embedChunks(
       Array.isArray(textInput) ? textInput : [textInput]
@@ -30,13 +40,9 @@ class AzureOpenAiEmbedder {
   }
 
   async embedChunks(textChunks = []) {
-    const textEmbeddingModel =
-      process.env.EMBEDDING_MODEL_PREF || "text-embedding-ada-002";
-    if (!textEmbeddingModel)
-      throw new Error(
-        "No EMBEDDING_MODEL_PREF ENV defined. This must the name of a deployment on your Azure account for an embedding model."
-      );
+    if (!this.model) throw new Error("No Embedding Model preference defined.");
+    this.log(`Embedding ${textChunks.length} chunks...`);
 
     // Because there is a limit on how many chunks can be sent at once to Azure OpenAI
     // we concurrently execute each max batch of text chunks possible.
     // Refer to constructor maxConcurrentChunks for more info.
@@ -44,8 +50,11 @@ class AzureOpenAiEmbedder {
     for (const chunk of toChunks(textChunks, this.maxConcurrentChunks)) {
       embeddingRequests.push(
         new Promise((resolve) => {
-          this.openai
-            .getEmbeddings(textEmbeddingModel, chunk)
+          this.openai.embeddings
+            .create({
+              model: this.model,
+              input: chunk,
+            })
             .then((res) => {
               resolve({ data: res.data, error: null });
             })
diff --git a/server/utils/EmbeddingEngines/openAi/index.js b/server/utils/EmbeddingEngines/openAi/index.js
index db3c92f96ca..9976ef54d14 100644
--- a/server/utils/EmbeddingEngines/openAi/index.js
+++ b/server/utils/EmbeddingEngines/openAi/index.js
@@ -16,6 +16,10 @@ class OpenAiEmbedder {
     this.embeddingMaxChunkLength = 8_191;
   }
 
+  log(text, ...args) {
+    console.log(`\x1b[36m[OpenAiEmbedder]\x1b[0m ${text}`, ...args);
+  }
+
   async embedTextInput(textInput) {
     const result = await this.embedChunks(
       Array.isArray(textInput) ? textInput : [textInput]
@@ -24,6 +28,8 @@ class OpenAiEmbedder {
   }
 
   async embedChunks(textChunks = []) {
+    this.log(`Embedding ${textChunks.length} chunks...`);
+
     // Because there is a hard POST limit on how many chunks can be sent at once to OpenAI (~8mb)
     // we concurrently execute each max batch of text chunks possible.
     // Refer to constructor maxConcurrentChunks for more info.
diff --git a/server/utils/agents/aibitat/providers/azure.js b/server/utils/agents/aibitat/providers/azure.js
index a0e94e31264..48078f2db82 100644
--- a/server/utils/agents/aibitat/providers/azure.js
+++ b/server/utils/agents/aibitat/providers/azure.js
@@ -1,96 +1,94 @@
-const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
+const { AzureOpenAI } = require("openai");
 const Provider = require("./ai-provider.js");
-const InheritMultiple = require("./helpers/classes.js");
-const UnTooled = require("./helpers/untooled.js");
+const { RetryError } = require("../error.js");
 
 /**
  * The agent provider for the Azure OpenAI API.
  */
-class AzureOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
+class AzureOpenAiProvider extends Provider {
   model;
 
-  constructor(_config = {}) {
-    super();
-    const client = new OpenAIClient(
-      process.env.AZURE_OPENAI_ENDPOINT,
-      new AzureKeyCredential(process.env.AZURE_OPENAI_KEY)
-    );
-    this._client = client;
-    this.model = process.env.OPEN_MODEL_PREF ?? "gpt-3.5-turbo";
+  constructor(config = { model: null }) {
+    const client = new AzureOpenAI({
+      apiKey: process.env.AZURE_OPENAI_KEY,
+      endpoint: process.env.AZURE_OPENAI_ENDPOINT,
+      apiVersion: "2024-12-01-preview",
+    });
+    super(client);
+    this.model = config.model ?? process.env.OPEN_MODEL_PREF;
     this.verbose = true;
   }
-
-  get client() {
-    return this._client;
-  }
-
-  async #handleFunctionCallChat({ messages = [] }) {
-    return await this.client
-      .getChatCompletions(this.model, messages, {
-        temperature: 0,
-      })
-      .then((result) => {
-        if (!result.hasOwnProperty("choices"))
-          throw new Error("Azure OpenAI chat: No results!");
-        if (result.choices.length === 0)
-          throw new Error("Azure OpenAI chat: No results length!");
-        return result.choices[0].message.content;
-      })
-      .catch((_) => {
-        return null;
-      });
-  }
-
   /**
    * Create a completion based on the received messages.
    *
-   * @param messages A list of messages to send to the API.
+   * @param messages A list of messages to send to the OpenAI API.
    * @param functions
    * @returns The completion.
    */
   async complete(messages, functions = []) {
     try {
-      let completion;
-      if (functions.length > 0) {
-        const { toolCall, text } = await this.functionCall(
-          messages,
-          functions,
-          this.#handleFunctionCallChat.bind(this)
-        );
-        if (toolCall !== null) {
-          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
-          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
-          return {
-            result: null,
-            functionCall: {
-              name: toolCall.name,
-              arguments: toolCall.arguments,
-            },
-            cost: 0,
-          };
+      const response = await this.client.chat.completions.create({
+        model: this.model,
+        // stream: true,
+        messages,
+        ...(Array.isArray(functions) && functions?.length > 0
+          ? { functions }
+          : {}),
+      });
+
+      // Right now, we only support one completion,
+      // so we just take the first one in the list
+      const completion = response.choices[0].message;
+      const cost = this.getCost(response.usage);
+      // treat function calls
+      if (completion.function_call) {
+        let functionArgs = {};
+        try {
+          functionArgs = JSON.parse(completion.function_call.arguments);
+        } catch (error) {
+          // call the complete function again in case it gets a json error
+          return this.complete(
+            [
+              ...messages,
+              {
+                role: "function",
+                name: completion.function_call.name,
+                function_call: completion.function_call,
+                content: error?.message,
+              },
+            ],
+            functions
           );
         }
-        completion = { content: text };
-      }
-      if (!completion?.content) {
-        this.providerLog(
-          "Will assume chat completion without tool call inputs."
-        );
-        const response = await this.client.getChatCompletions(
-          this.model,
-          this.cleanMsgs(messages),
-          {
-            temperature: 0.7,
-          }
-        );
-        completion = response.choices[0].message;
+
+        // console.log(completion, { functionArgs })
+        return {
+          result: null,
+          functionCall: {
+            name: completion.function_call.name,
+            arguments: functionArgs,
+          },
+          cost,
+        };
       }
-      // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
-      // from calling the exact same function over and over in a loop within a single chat exchange
-      // _but_ we should enable it to call previously used tools in a new chat interaction.
-      this.deduplicator.reset("runs");
-      return { result: completion.content, cost: 0 };
+      return {
+        result: completion.content,
+        cost,
+      };
     } catch (error) {
+      // If invalid Auth error we need to abort because no amount of waiting
+      // will make auth better.
+      if (error instanceof AzureOpenAI.AuthenticationError) throw error;
+
+      if (
+        error instanceof AzureOpenAI.RateLimitError ||
+        error instanceof AzureOpenAI.InternalServerError ||
+        error instanceof AzureOpenAI.APIError // Also will catch AuthenticationError!!!
+      ) {
+        throw new RetryError(error.message);
+      }
+
       throw error;
     }
   }
diff --git a/server/utils/agents/index.js b/server/utils/agents/index.js
index c2765b22796..77666ec7dae 100644
--- a/server/utils/agents/index.js
+++ b/server/utils/agents/index.js
@@ -221,7 +221,7 @@ class AgentHandler {
           "mistralai/Mixtral-8x7B-Instruct-v0.1"
         );
       case "azure":
-        return null;
+        return process.env.OPEN_MODEL_PREF;
       case "koboldcpp":
         return process.env.KOBOLD_CPP_MODEL_PREF ?? null;
       case "localai":
diff --git a/server/yarn.lock b/server/yarn.lock
index d263a84b1b7..ddc7f595d4b 100644
--- a/server/yarn.lock
+++ b/server/yarn.lock
@@ -961,18 +961,6 @@
     "@smithy/types" "^4.2.0"
     tslib "^2.6.2"
 
-"@azure-rest/core-client@^1.1.7":
-  version "1.4.0"
-  resolved "https://registry.npmjs.org/@azure-rest/core-client/-/core-client-1.4.0.tgz"
-  integrity sha512-ozTDPBVUDR5eOnMIwhggbnVmOrka4fXCs8n8mvUo4WLLc38kki6bAOByDoVZZPz/pZy2jMt2kwfpvy/UjALj6w==
-  dependencies:
-    "@azure/abort-controller" "^2.0.0"
-    "@azure/core-auth" "^1.3.0"
-    "@azure/core-rest-pipeline" "^1.5.0"
-    "@azure/core-tracing" "^1.0.1"
-    "@azure/core-util" "^1.0.0"
-    tslib "^2.6.2"
-
 "@azure/abort-controller@^1.0.0":
   version "1.1.0"
   resolved "https://registry.npmjs.org/@azure/abort-controller/-/abort-controller-1.1.0.tgz"
@@ -1049,27 +1037,6 @@
     https-proxy-agent "^7.0.0"
     tslib "^2.6.2"
 
-"@azure/core-rest-pipeline@^1.13.0", "@azure/core-rest-pipeline@^1.5.0":
-  version "1.15.2"
-  resolved "https://registry.npmjs.org/@azure/core-rest-pipeline/-/core-rest-pipeline-1.15.2.tgz"
-  integrity sha512-BmWfpjc/QXc2ipHOh6LbUzp3ONCaa6xzIssTU0DwH9bbYNXJlGUL6tujx5TrbVd/QQknmS+vlQJGrCq2oL1gZA==
-  dependencies:
-    "@azure/abort-controller" "^2.0.0"
-    "@azure/core-auth" "^1.4.0"
-    "@azure/core-tracing" "^1.0.1"
-    "@azure/core-util" "^1.3.0"
-    "@azure/logger" "^1.0.0"
-    http-proxy-agent "^7.0.0"
-    https-proxy-agent "^7.0.0"
-    tslib "^2.6.2"
-
-"@azure/core-sse@^2.0.0":
-  version "2.1.2"
-  resolved "https://registry.npmjs.org/@azure/core-sse/-/core-sse-2.1.2.tgz"
-  integrity sha512-yf+pFIu8yCzXu9RbH2+8kp9vITIKJLHgkLgFNA6hxiDHK3fxeP596cHUj4c8Cm8JlooaUnYdHmF84KCZt3jbmw==
-  dependencies:
-    tslib "^2.6.2"
-
 "@azure/core-tracing@^1.0.0", "@azure/core-tracing@^1.0.1":
   version "1.1.2"
   resolved "https://registry.npmjs.org/@azure/core-tracing/-/core-tracing-1.1.2.tgz"
@@ -1077,7 +1044,7 @@
   dependencies:
     tslib "^2.6.2"
 
-"@azure/core-util@^1.0.0", "@azure/core-util@^1.1.0", "@azure/core-util@^1.2.0", "@azure/core-util@^1.3.0", "@azure/core-util@^1.4.0", "@azure/core-util@^1.6.1", "@azure/core-util@^1.9.0":
+"@azure/core-util@^1.0.0", "@azure/core-util@^1.1.0", "@azure/core-util@^1.2.0", "@azure/core-util@^1.6.1", "@azure/core-util@^1.9.0":
   version "1.9.0"
   resolved "https://registry.npmjs.org/@azure/core-util/-/core-util-1.9.0.tgz"
   integrity sha512-AfalUQ1ZppaKuxPPMsFEUdX6GZPB3d9paR9d/TTL7Ow2De8cJaC7ibi7kWVlFAVPCYo31OcnGymc0R89DX8Oaw==
@@ -1122,7 +1089,7 @@
     "@azure/logger" "^1.0.0"
     tslib "^2.2.0"
 
-"@azure/logger@^1.0.0", "@azure/logger@^1.0.3":
+"@azure/logger@^1.0.0":
   version "1.1.2"
   resolved "https://registry.npmjs.org/@azure/logger/-/logger-1.1.2.tgz"
   integrity sha512-l170uE7bsKpIU6B/giRc9i4NI0Mj+tANMMMxf7Zi/5cKzEqPayP7+X1WPrG7e+91JgY8N+7K7nF2WOi7iVhXvg==
@@ -1150,19 +1117,6 @@
     jsonwebtoken "^9.0.0"
     uuid "^8.3.0"
 
-"@azure/openai@1.0.0-beta.10":
-  version "1.0.0-beta.10"
-  resolved "https://registry.npmjs.org/@azure/openai/-/openai-1.0.0-beta.10.tgz"
-  integrity sha512-6kixZSMOI5jk9TBwgXrVo5fKUPUudOXxjwCJvAGaQN6NT1Tp3IMrjGou+2iP9iX7GwND9lptxfvafHtK7RX/VA==
-  dependencies:
-    "@azure-rest/core-client" "^1.1.7"
-    "@azure/core-auth" "^1.4.0"
-    "@azure/core-rest-pipeline" "^1.13.0"
-    "@azure/core-sse" "^2.0.0"
-    "@azure/core-util" "^1.4.0"
-    "@azure/logger" "^1.0.3"
-    tslib "^2.4.0"
-
 "@babel/runtime@^7.10.5":
   version "7.24.7"
   resolved "https://registry.npmjs.org/@babel/runtime/-/runtime-7.24.7.tgz"
@@ -6410,7 +6364,20 @@ open@^8.0.0:
     is-docker "^2.1.1"
     is-wsl "^2.2.0"
 
-openai@4.38.5, openai@^4.0.0:
+openai@4.95.1:
+  version "4.95.1"
+  resolved "https://registry.yarnpkg.com/openai/-/openai-4.95.1.tgz#7157697c2b150a546b13eb860180c4a6058051da"
+  integrity sha512-IqJy+ymeW+k/Wq+2YVN3693OQMMcODRtHEYOlz263MdUwnN/Dwdl9c2EXSxLLtGEHkSHAfvzpDMHI5MaWJKXjQ==
+  dependencies:
+    "@types/node" "^18.11.18"
+    "@types/node-fetch" "^2.6.4"
+    abort-controller "^3.0.0"
+    agentkeepalive "^4.2.1"
+    form-data-encoder "1.7.2"
+    formdata-node "^4.3.2"
+    node-fetch "^2.6.7"
+
+openai@^4.0.0:
   version "4.38.5"
   resolved "https://registry.npmjs.org/openai/-/openai-4.38.5.tgz"
   integrity sha512-Ym5GJL98ZhLJJ7enBx53jjG3vwN/fsB+Ozh46nnRZZS9W1NiYqbwkJ+sXd3dkCIiWIgcyyOPL2Zr8SQAzbpj3g==
@@ -7697,7 +7664,7 @@ truncate@^3.0.0:
   resolved "https://registry.npmjs.org/truncate/-/truncate-3.0.0.tgz"
   integrity sha512-C+0Xojw7wZPl6MDq5UjMTuxZvBPK04mtdFet7k+GSZPINcvLZFCXg+15kWIL4wAqDB7CksIsKiRLbQ1wa7rKdw==
 
-tslib@^2.2.0, tslib@^2.4.0, tslib@^2.6.2:
+tslib@^2.2.0, tslib@^2.6.2:
   version "2.6.2"
   resolved "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz"
   integrity sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==