diff --git a/server/package.json b/server/package.json
index 8b765d46fc..99d018c7c3 100644
--- a/server/package.json
+++ b/server/package.json
@@ -22,7 +22,6 @@
     "@anthropic-ai/sdk": "^0.39.0",
     "@aws-sdk/client-bedrock-runtime": "^3.775.0",
     "@datastax/astra-db-ts": "^0.1.3",
-    "@google/generative-ai": "^0.7.1",
     "@ladjs/graceful": "^3.2.2",
     "@lancedb/lancedb": "0.15.0",
     "@langchain/anthropic": "0.1.16",
@@ -100,4 +99,4 @@
     "nodemon": "^2.0.22",
     "prettier": "^3.0.3"
   }
-}
+}
\ No newline at end of file
diff --git a/server/utils/EmbeddingEngines/gemini/index.js b/server/utils/EmbeddingEngines/gemini/index.js
index 4c60501a88..a7d50c9066 100644
--- a/server/utils/EmbeddingEngines/gemini/index.js
+++ b/server/utils/EmbeddingEngines/gemini/index.js
@@ -1,21 +1,22 @@
+const { toChunks } = require("../../helpers");
+
 class GeminiEmbedder {
   constructor() {
     if (!process.env.GEMINI_EMBEDDING_API_KEY)
       throw new Error("No Gemini API key was set.");
-    // TODO: Deprecate this and use OpenAI interface instead - after which, remove the @google/generative-ai dependency
-    const { GoogleGenerativeAI } = require("@google/generative-ai");
-    const genAI = new GoogleGenerativeAI(process.env.GEMINI_EMBEDDING_API_KEY);
+    const { OpenAI: OpenAIApi } = require("openai");
     this.model = process.env.EMBEDDING_MODEL_PREF || "text-embedding-004";
-    this.gemini = genAI.getGenerativeModel({ model: this.model });
+    this.openai = new OpenAIApi({
+      apiKey: process.env.GEMINI_EMBEDDING_API_KEY,
+      // Even models that are v1 in the Gemini API can be used via the v1beta/openai/ endpoint, though this is undocumented.
+      baseURL: "https://generativelanguage.googleapis.com/v1beta/openai/",
+    });
 
-    // This property is disabled as it causes issues when sending multiple chunks at once
-    // since when given 4 chunks at once, the gemini api returns 1 embedding for all 4 chunks
-    // instead of 4 embeddings - no idea why this is the case, but it is not how the results are
-    // expected to be returned.
-    // this.maxConcurrentChunks = 1;
+    this.maxConcurrentChunks = 4;
 
     // https://ai.google.dev/gemini-api/docs/models/gemini#text-embedding-and-embedding
+    // TODO: May need to make this dynamic based on the model
     this.embeddingMaxChunkLength = 2_048;
     this.log(`Initialized with ${this.model}`);
   }
 
@@ -30,8 +31,10 @@ class GeminiEmbedder {
    * @returns {Promise<Array<number>>} The embedding values
    */
   async embedTextInput(textInput) {
-    const result = await this.gemini.embedContent(textInput);
-    return result.embedding.values || [];
+    const result = await this.embedChunks(
+      Array.isArray(textInput) ? textInput : [textInput]
+    );
+    return result?.[0] || [];
   }
 
   /**
@@ -40,14 +43,66 @@ class GeminiEmbedder {
    * @returns {Promise<Array<Array<number>>>} The embedding values
    */
   async embedChunks(textChunks = []) {
-    let embeddings = [];
-    for (const chunk of textChunks) {
-      const results = await this.gemini.embedContent(chunk);
-      if (!results.embedding || !results.embedding.values)
-        throw new Error("No embedding values returned from gemini");
-      embeddings.push(results.embedding.values);
+    this.log(`Embedding ${textChunks.length} chunks...`);
+
+    // Because there is a hard POST size limit (~8MB) on how many chunks can be sent at once
+    // to the OpenAI-compatible endpoint, we batch the text chunks and execute the requests concurrently.
+    // Refer to the constructor's maxConcurrentChunks for the batch size.
+    const embeddingRequests = [];
+    for (const chunk of toChunks(textChunks, this.maxConcurrentChunks)) {
+      embeddingRequests.push(
+        new Promise((resolve) => {
+          this.openai.embeddings
+            .create({
+              model: this.model,
+              input: chunk,
+            })
+            .then((result) => {
+              resolve({ data: result?.data, error: null });
+            })
+            .catch((e) => {
+              e.type =
+                e?.response?.data?.error?.code ||
+                e?.response?.status ||
+                "failed_to_embed";
+              e.message = e?.response?.data?.error?.message || e.message;
+              resolve({ data: [], error: e });
+            });
+        })
+      );
     }
-    return embeddings;
+
+    const { data = [], error = null } = await Promise.all(
+      embeddingRequests
+    ).then((results) => {
+      // If any errors were returned from the endpoint, abort the entire sequence
+      // because the embeddings will be incomplete.
+      const errors = results
+        .filter((res) => !!res.error)
+        .map((res) => res.error)
+        .flat();
+      if (errors.length > 0) {
+        let uniqueErrors = new Set();
+        errors.map((error) =>
+          uniqueErrors.add(`[${error.type}]: ${error.message}`)
+        );
+
+        return {
+          data: [],
+          error: Array.from(uniqueErrors).join(", "),
+        };
+      }
+      return {
+        data: results.map((res) => res?.data || []).flat(),
+        error: null,
+      };
+    });
+
+    if (!!error) throw new Error(`Gemini failed to embed: ${error}`);
+    return data.length > 0 &&
+      data.every((embd) => embd.hasOwnProperty("embedding"))
+      ? data.map((embd) => embd.embedding)
+      : null;
   }
 }
 
diff --git a/server/yarn.lock b/server/yarn.lock
index ddc7f595d4..0f04a98179 100644
--- a/server/yarn.lock
+++ b/server/yarn.lock
@@ -1194,11 +1194,6 @@
   resolved "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.1.3.tgz"
   integrity sha512-Cm4uJX1sKarpm1mje/MiOIinM7zdUUrQp/5/qGPAgznbdd/B9zup5ehT6c1qGqycFcSopTA1J1HpqHS5kJR8hQ==
 
-"@google/generative-ai@^0.7.1":
-  version "0.7.1"
-  resolved "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.7.1.tgz"
-  integrity sha512-WTjMLLYL/xfA5BW6xAycRPiAX7FNHKAxrid/ayqC1QMam0KAK0NbMeS9Lubw80gVg5xFMLE+H7pw4wdNzTOlxw==
-
 "@graphql-typed-document-node/core@^3.1.1":
   version "3.2.0"
   resolved "https://registry.npmjs.org/@graphql-typed-document-node/core/-/core-3.2.0.tgz"
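
Note on the new `toChunks` import: the helper from `server/utils/helpers` is assumed here to split a flat array into batches of at most `n` items, so each request to the embedding endpoint carries no more than `maxConcurrentChunks` (4) inputs. A minimal sketch of that assumed behavior:

```js
// Illustrative sketch only; the real helper lives in server/utils/helpers.
// Splits an array into sub-arrays of at most `size` elements each.
function toChunks(arr, size) {
  return Array.from({ length: Math.ceil(arr.length / size) }, (_v, i) =>
    arr.slice(i * size, i * size + size)
  );
}

console.log(toChunks(["a", "b", "c", "d", "e"], 2));
// => [ [ 'a', 'b' ], [ 'c', 'd' ], [ 'e' ] ]
```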
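For reference, the change leans on Gemini's OpenAI-compatible surface: the standard `openai` client pointed at `https://generativelanguage.googleapis.com/v1beta/openai/` accepts batched embedding requests and returns one embedding object per input. A standalone sketch of the equivalent call (model name and env var taken from the diff; this is not project code):

```js
// Standalone sketch of the request path the updated GeminiEmbedder uses.
// Assumes the `openai` npm package and a valid GEMINI_EMBEDDING_API_KEY.
const { OpenAI } = require("openai");

const client = new OpenAI({
  apiKey: process.env.GEMINI_EMBEDDING_API_KEY,
  baseURL: "https://generativelanguage.googleapis.com/v1beta/openai/",
});

async function main() {
  const res = await client.embeddings.create({
    model: "text-embedding-004",
    input: ["first chunk of text", "second chunk of text"],
  });
  // One entry per input, each carrying an `embedding` array of floats,
  // which is the shape embedChunks() verifies via hasOwnProperty("embedding").
  console.log(res.data.map((d) => d.embedding.length));
}

main();
```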