diff --git a/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx
index 79c72fffa6..105487bf3b 100644
--- a/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx
+++ b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx
@@ -1,4 +1,10 @@
+import System from "@/models/system";
+import { useEffect, useState } from "react";
+
 export default function GeminiLLMOptions({ settings }) {
+  const [inputValue, setInputValue] = useState(settings?.GeminiLLMApiKey);
+  const [geminiApiKey, setGeminiApiKey] = useState(settings?.GeminiLLMApiKey);
+
   return (
     <div className="w-full flex flex-col gap-y-7">
       <div className="w-full flex items-center gap-[36px] mt-1.5">
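Note: the hunk above splits the API key across two pieces of state on purpose. `inputValue` follows every keystroke, while `geminiApiKey` is only committed by the `onBlur` handler added in the next hunk, so the model list is refetched once per completed edit rather than on every keypress. A minimal sketch of the pattern in isolation (the component and prop names here are illustrative, not part of this PR):

```js
import { useState } from "react";

// Commit-on-blur: cheap state tracks the keystrokes, while the "committed"
// value only updates when the field loses focus, so expensive consumers
// (like a remote model fetch keyed on the API key) fire once per edit.
function ApiKeyField({ onCommit }) {
  const [draft, setDraft] = useState("");
  return (
    <input
      type="password"
      value={draft}
      onChange={(e) => setDraft(e.target.value)} // every keystroke
      onBlur={() => onCommit(draft)} // once, on focus loss
    />
  );
}
```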
@@ -15,56 +21,14 @@ export default function GeminiLLMOptions({ settings }) {
             required={true}
             autoComplete="off"
             spellCheck={false}
+            onChange={(e) => setInputValue(e.target.value)}
+            onBlur={() => setGeminiApiKey(inputValue)}
           />
         </div>

         {!settings?.credentialsOnly && (
           <>
-            <div className="flex flex-col w-60">
-              <label className="text-white text-sm font-semibold block mb-3">
-                Chat Model Selection
-              </label>
-              <select
-                name="GeminiLLMModelPref"
-                defaultValue={settings?.GeminiLLMModelPref || "gemini-pro"}
-                required={true}
-                className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
-              >
-                <optgroup label="Stable Models">
-                  {[
-                    "gemini-pro",
-                    "gemini-1.0-pro",
-                    "gemini-1.5-pro-latest",
-                    "gemini-1.5-flash-latest",
-                  ].map((model) => {
-                    return (
-                      <option key={model} value={model}>
-                        {model}
-                      </option>
-                    );
-                  })}
-                </optgroup>
-                <optgroup label="Experimental Models">
-                  {[
-                    "gemini-1.5-pro-exp-0801",
-                    "gemini-1.5-pro-exp-0827",
-                    "gemini-1.5-flash-exp-0827",
-                    "gemini-1.5-flash-8b-exp-0827",
-                    "gemini-exp-1114",
-                    "gemini-exp-1121",
-                    "gemini-exp-1206",
-                    "learnlm-1.5-pro-experimental",
-                    "gemini-2.0-flash-exp",
-                  ].map((model) => {
-                    return (
-                      <option key={model} value={model}>
-                        {model}
-                      </option>
-                    );
-                  })}
-                </optgroup>
-              </select>
-            </div>
+            <GeminiModelSelection apiKey={geminiApiKey} settings={settings} />
             <div className="flex flex-col w-60">
               <label className="text-white text-sm font-semibold block mb-3">
                 Safety Setting
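Note: the static model `<select>` removed above is replaced by the `GeminiModelSelection` component added in the next hunk, which asks the server for the live model list and buckets each entry by its `experimental` flag. A small sketch of that grouping step, runnable on its own (the sample model objects are illustrative):

```js
// Group models into "Experimental" and "Stable" buckets, mirroring the
// reduce() used by GeminiModelSelection below.
const models = [
  { id: "gemini-1.5-pro-latest", experimental: false },
  { id: "gemini-2.0-flash-exp", experimental: true },
];

const grouped = models.reduce((acc, model) => {
  const key = model.experimental ? "Experimental" : "Stable";
  (acc[key] = acc[key] || []).push(model);
  return acc;
}, {});

console.log(grouped);
// { Stable: [ { id: "gemini-1.5-pro-latest", ... } ],
//   Experimental: [ { id: "gemini-2.0-flash-exp", ... } ] }
```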
@@ -83,2 +47,72 @@ export default function GeminiLLMOptions({ settings }) {
   );
 }
+
+function GeminiModelSelection({ apiKey, settings }) {
+  const [groupedModels, setGroupedModels] = useState({});
+  const [loading, setLoading] = useState(true);
+
+  useEffect(() => {
+    async function findCustomModels() {
+      setLoading(true);
+      const { models } = await System.customModels("gemini", apiKey);
+
+      if (models?.length > 0) {
+        const modelsByOrganization = models.reduce((acc, model) => {
+          acc[model.experimental ? "Experimental" : "Stable"] =
+            acc[model.experimental ? "Experimental" : "Stable"] || [];
+          acc[model.experimental ? "Experimental" : "Stable"].push(model);
+          return acc;
+        }, {});
+        setGroupedModels(modelsByOrganization);
+      }
+      setLoading(false);
+    }
+    findCustomModels();
+  }, [apiKey]);
+
+  if (loading) {
+    return (
+      <div className="flex flex-col w-60">
+        <label className="text-white text-sm font-semibold block mb-3">
+          Chat Model Selection
+        </label>
+        <select
+          name="GeminiLLMModelPref"
+          disabled={true}
+          className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
+        >
+          <option disabled={true} selected={true}>
+            -- loading available models --
+          </option>
+        </select>
+      </div>
+    );
+  }
+
+  return (
+    <div className="flex flex-col w-60">
+      <label className="text-white text-sm font-semibold block mb-3">
+        Chat Model Selection
+      </label>
+      <select
+        name="GeminiLLMModelPref"
+        required={true}
+        className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
+      >
+        {Object.keys(groupedModels).map((group) => (
+          <optgroup key={group} label={group}>
+            {groupedModels[group].map((model) => (
+              <option
+                key={model.id}
+                value={model.id}
+                selected={settings?.GeminiLLMModelPref === model.id}
+              >
+                {model.name}
+              </option>
+            ))}
+          </optgroup>
+        ))}
+      </select>
+    </div>
+  );
+}
diff --git a/server/utils/AiProviders/gemini/defaultModals.js b/server/utils/AiProviders/gemini/defaultModals.js
new file mode 100644
index 0000000000..303a0aafff
--- /dev/null
+++ b/server/utils/AiProviders/gemini/defaultModals.js
@@ -0,0 +1,46 @@
+const { MODEL_MAP } = require("../modelMap");
+
+const stableModels = [
+  "gemini-pro",
+  "gemini-1.0-pro",
+  "gemini-1.5-pro-latest",
+  "gemini-1.5-flash-latest",
+];
+
+const experimentalModels = [
+  "gemini-1.5-pro-exp-0801",
+  "gemini-1.5-pro-exp-0827",
+  "gemini-1.5-flash-exp-0827",
+  "gemini-1.5-flash-8b-exp-0827",
+  "gemini-exp-1114",
+  "gemini-exp-1121",
+  "gemini-exp-1206",
+  "learnlm-1.5-pro-experimental",
+  "gemini-2.0-flash-exp",
+];
+
+// Some models are only available in the v1beta API and some only in the
+// v1 API. Generally, v1beta models have `exp` in the name, but not always,
+// so we also check against this static list.
+const v1BetaModels = ["gemini-1.5-pro-latest", "gemini-1.5-flash-latest"];
+
+const defaultGeminiModels = [
+  ...stableModels.map((model) => ({
+    id: model,
+    name: model,
+    contextWindow: MODEL_MAP.gemini[model],
+    experimental: false,
+  })),
+  ...experimentalModels.map((model) => ({
+    id: model,
+    name: model,
+    contextWindow: MODEL_MAP.gemini[model],
+    experimental: true,
+  })),
+];
+
+module.exports = {
+  defaultGeminiModels,
+  v1BetaModels,
+};
diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js
index f658b3c5f1..3554a51fcf 100644
--- a/server/utils/AiProviders/gemini/index.js
+++ b/server/utils/AiProviders/gemini/index.js
@@ -7,6 +7,7 @@ const {
   clientAbortedHandler,
 } = require("../../helpers/chat/responses");
 const { MODEL_MAP } = require("../modelMap");
+const { defaultGeminiModels, v1BetaModels } = require("./defaultModals");

 class GeminiLLM {
   constructor(embedder = null, modelPreference = null) {
@@ -21,22 +22,17 @@ class GeminiLLM {
     this.gemini = genAI.getGenerativeModel(
       { model: this.model },
       {
-        // Gemini-1.5-pro-* and Gemini-1.5-flash are only available on the v1beta API.
-        apiVersion: [
-          "gemini-1.5-pro-latest",
-          "gemini-1.5-flash-latest",
-          "gemini-1.5-pro-exp-0801",
-          "gemini-1.5-pro-exp-0827",
-          "gemini-1.5-flash-exp-0827",
-          "gemini-1.5-flash-8b-exp-0827",
-          "gemini-exp-1114",
-          "gemini-exp-1121",
-          "gemini-exp-1206",
-          "learnlm-1.5-pro-experimental",
-          "gemini-2.0-flash-exp",
-        ].includes(this.model)
-          ? "v1beta"
-          : "v1",
+        apiVersion:
+          /**
+           * Some models are only available in the v1beta API and some
+           * only in the v1 API. Generally, v1beta models have `exp` in
+           * the name, but not always, so we also check against a static
+           * list.
+           * @see {v1BetaModels}
+           */
+          this.model.includes("exp") || v1BetaModels.includes(this.model)
+            ? "v1beta"
+            : "v1",
       }
     );
     this.limits = {
@@ -48,6 +44,11 @@ class GeminiLLM {
     this.embedder = embedder ?? new NativeEmbedder();
     this.defaultTemp = 0.7; // not used for Gemini
     this.safetyThreshold = this.#fetchSafetyThreshold();
+    this.#log(`Initialized with model: ${this.model}`);
+  }
+
+  #log(text, ...args) {
+    console.log(`\x1b[32m[GeminiLLM]\x1b[0m ${text}`, ...args);
   }

   #appendContext(contextTexts = []) {
@@ -109,25 +110,63 @@ class GeminiLLM {
     return MODEL_MAP.gemini[this.model] ?? 30_720;
   }

-  isValidChatCompletionModel(modelName = "") {
-    const validModels = [
-      "gemini-pro",
-      "gemini-1.0-pro",
-      "gemini-1.5-pro-latest",
-      "gemini-1.5-flash-latest",
-      "gemini-1.5-pro-exp-0801",
-      "gemini-1.5-pro-exp-0827",
-      "gemini-1.5-flash-exp-0827",
-      "gemini-1.5-flash-8b-exp-0827",
-      "gemini-exp-1114",
-      "gemini-exp-1121",
-      "gemini-exp-1206",
-      "learnlm-1.5-pro-experimental",
-      "gemini-2.0-flash-exp",
-    ];
-    return validModels.includes(modelName);
+  /**
+   * Fetches Gemini models from the Google Generative AI API
+   * @param {string} apiKey - The API key to use for the request
+   * @param {number} limit - The maximum number of models to fetch
+   * @param {string} pageToken - The page token to use for pagination
+   * @returns {Promise<Array<{id: string, name: string, contextWindow: number, experimental: boolean}>>} A promise that resolves to an array of Gemini models
+   */
+  static async fetchModels(apiKey, limit = 1_000, pageToken = null) {
+    const url = new URL(
+      "https://generativelanguage.googleapis.com/v1beta/models"
+    );
+    url.searchParams.set("pageSize", limit);
+    url.searchParams.set("key", apiKey);
+    if (pageToken) url.searchParams.set("pageToken", pageToken);
+
+    return fetch(url.toString(), {
+      method: "GET",
+      headers: { "Content-Type": "application/json" },
+    })
+      .then((res) => res.json())
+      .then((data) => {
+        if (data.error) throw new Error(data.error.message);
+        return data.models ?? [];
+      })
+      .then((models) =>
+        models
+          .filter(
+            (model) => !model.displayName.toLowerCase().includes("tuning")
+          )
+          .filter((model) =>
+            model.supportedGenerationMethods.includes("generateContent")
+          ) // Only generateContent is supported
+          .map((model) => {
+            return {
+              id: model.name.split("/").pop(),
+              name: model.displayName,
+              contextWindow: model.inputTokenLimit,
+              experimental: model.name.includes("exp"),
+            };
+          })
+      )
+      .catch((e) => {
+        console.error(`Gemini:fetchModels`, e.message);
+        return defaultGeminiModels;
+      });
   }

+  /**
+   * Checks if a model is valid for chat completion (unused)
+   * @deprecated
+   * @param {string} modelName - The name of the model to check
+   * @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the model is valid
+   */
+  async isValidChatCompletionModel(modelName = "") {
+    const models = await GeminiLLM.fetchModels(process.env.GEMINI_API_KEY);
+    return models.some((model) => model.id === modelName);
+  }
+
   /**
    * Generates appropriate content array for a message + attachments.
   * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
@@ -218,11 +257,6 @@ class GeminiLLM {
   }

   async getChatCompletion(messages = [], _opts = {}) {
-    if (!this.isValidChatCompletionModel(this.model))
-      throw new Error(
-        `Gemini chat: ${this.model} is not valid for chat completion!`
-      );
-
     const prompt = messages.find(
       (chat) => chat.role === "USER_PROMPT"
     )?.content;
@@ -256,11 +290,6 @@ class GeminiLLM {
   }

   async streamGetChatCompletion(messages = [], _opts = {}) {
-    if (!this.isValidChatCompletionModel(this.model))
-      throw new Error(
-        `Gemini chat: ${this.model} is not valid for chat completion!`
-      );
-
     const prompt = messages.find(
       (chat) => chat.role === "USER_PROMPT"
     )?.content;
diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js
index a763635fb5..7adb276261 100644
--- a/server/utils/helpers/customModels.js
+++ b/server/utils/helpers/customModels.js
@@ -7,6 +7,7 @@ const { ElevenLabsTTS } = require("../TextToSpeech/elevenLabs");
 const { fetchNovitaModels } = require("../AiProviders/novita");
 const { parseLMStudioBasePath } = require("../AiProviders/lmStudio");
 const { parseNvidiaNimBasePath } = require("../AiProviders/nvidiaNim");
+const { GeminiLLM } = require("../AiProviders/gemini");

 const SUPPORT_CUSTOM_MODELS = [
   "openai",
@@ -28,6 +29,7 @@ const SUPPORT_CUSTOM_MODELS = [
   "apipie",
   "novita",
   "xai",
+  "gemini",
 ];

 async function getCustomModels(provider = "", apiKey = null, basePath = null) {
@@ -73,6 +75,8 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) {
       return await getXAIModels(apiKey);
     case "nvidia-nim":
       return await getNvidiaNimModels(basePath);
+    case "gemini":
+      return await getGeminiModels(apiKey);
     default:
       return { models: [], error: "Invalid provider for custom models" };
   }
@@ -572,6 +576,17 @@ async function getNvidiaNimModels(basePath = null) {
   }
 }

+async function getGeminiModels(_apiKey = null) {
+  const apiKey =
+    _apiKey === true
+      ? process.env.GEMINI_API_KEY
+      : _apiKey || process.env.GEMINI_API_KEY || null;
+  const models = await GeminiLLM.fetchModels(apiKey);
+  // The API key worked, so save it for future use.
+  if (models.length > 0 && !!apiKey) process.env.GEMINI_API_KEY = apiKey;
+  return { models, error: null };
+}
+
 module.exports = {
   getCustomModels,
 };
diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js
index 948703dca2..da30b6ee0d 100644
--- a/server/utils/helpers/updateENV.js
+++ b/server/utils/helpers/updateENV.js
@@ -52,7 +52,7 @@ const KEY_MAPPING = {
   },
   GeminiLLMModelPref: {
     envKey: "GEMINI_LLM_MODEL_PREF",
-    checks: [isNotEmpty, validGeminiModel],
+    checks: [isNotEmpty],
   },
   GeminiSafetySetting: {
     envKey: "GEMINI_SAFETY_SETTING",
@@ -724,27 +724,6 @@ function supportedTranscriptionProvider(input = "") {
     : `${input} is not a valid transcription model provider.`;
 }

-function validGeminiModel(input = "") {
-  const validModels = [
-    "gemini-pro",
-    "gemini-1.0-pro",
-    "gemini-1.5-pro-latest",
-    "gemini-1.5-flash-latest",
-    "gemini-1.5-pro-exp-0801",
-    "gemini-1.5-pro-exp-0827",
-    "gemini-1.5-flash-exp-0827",
-    "gemini-1.5-flash-8b-exp-0827",
-    "gemini-exp-1114",
-    "gemini-exp-1121",
-    "gemini-exp-1206",
-    "learnlm-1.5-pro-experimental",
-    "gemini-2.0-flash-exp",
-  ];
-  return validModels.includes(input)
-    ? null
-    : `Invalid Model type. Must be one of ${validModels.join(", ")}.`;
-}
-
 function validGeminiSafetySetting(input = "") {
   const validModes = [
     "BLOCK_NONE",
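Note: with model validation now dynamic, the end-to-end flow is: the frontend calls `System.customModels("gemini", apiKey)`, the server routes that through `getCustomModels` to `getGeminiModels`, which defers to the static `GeminiLLM.fetchModels`, and any fetch failure falls back to `defaultGeminiModels`. A sketch of exercising the new fetcher directly (assumes a valid key in `GEMINI_API_KEY` and a require path relative to the server root; both are illustrative):

```js
const { GeminiLLM } = require("./utils/AiProviders/gemini");

(async () => {
  // Resolves to [{ id, name, contextWindow, experimental }, ...]; on any
  // API error the catch() inside fetchModels returns defaultGeminiModels.
  const models = await GeminiLLM.fetchModels(process.env.GEMINI_API_KEY);
  for (const m of models) {
    console.log(
      `${m.id} [${m.experimental ? "experimental" : "stable"}] ctx=${m.contextWindow}`
    );
  }
})();
```

This is also why `validGeminiModel` could be deleted from updateENV.js: a hardcoded whitelist goes stale as Google ships new model ids, so `GEMINI_LLM_MODEL_PREF` now only needs to pass the existing `isNotEmpty` check.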