diff --git a/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx
index 79c72fffa6..105487bf3b 100644
--- a/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx
+++ b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx
@@ -1,4 +1,10 @@
+import System from "@/models/system";
+import { useEffect, useState } from "react";
+
export default function GeminiLLMOptions({ settings }) {
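+  // `inputValue` tracks the field as the user types; `geminiApiKey` is only
+  // committed on blur (see onBlur below) so the model list is not re-fetched
+  // on every keystroke.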
+ const [inputValue, setInputValue] = useState(settings?.GeminiLLMApiKey);
+ const [geminiApiKey, setGeminiApiKey] = useState(settings?.GeminiLLMApiKey);
+
return (
@@ -15,56 +21,14 @@ export default function GeminiLLMOptions({ settings }) {
required={true}
autoComplete="off"
spellCheck={false}
+ onChange={(e) => setInputValue(e.target.value)}
+ onBlur={() => setGeminiApiKey(inputValue)}
/>
{!settings?.credentialsOnly && (
<>
-            {/* ...removed: static, hard-coded list of Gemini model <option>s... */}
+            <GeminiModelSelection apiKey={geminiApiKey} settings={settings} />
          </>
        )}
);
}
+
+function GeminiModelSelection({ apiKey, settings }) {
+ const [groupedModels, setGroupedModels] = useState({});
+ const [loading, setLoading] = useState(true);
+
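+  // Fetch the models available to this API key and bucket them into
+  // "Stable" and "Experimental" groups for the <optgroup> UI below.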
+ useEffect(() => {
+ async function findCustomModels() {
+ setLoading(true);
+ const { models } = await System.customModels("gemini", apiKey);
+
+ if (models?.length > 0) {
+        const modelsByType = models.reduce((acc, model) => {
+          const key = model.experimental ? "Experimental" : "Stable";
+          acc[key] = acc[key] || [];
+          acc[key].push(model);
+          return acc;
+        }, {});
+        setGroupedModels(modelsByType);
+ }
+ setLoading(false);
+ }
+ findCustomModels();
+ }, [apiKey]);
+
+ if (loading) {
+ return (
+      <select name="GeminiLLMModelPref" disabled={true}>
+        <option>-- loading available models --</option>
+      </select>
+ );
+ }
+
+ return (
+    <select
+      name="GeminiLLMModelPref"
+      required={true}
+      defaultValue={settings?.GeminiLLMModelPref}
+    >
+      {Object.keys(groupedModels).map((group) => (
+        <optgroup key={group} label={group}>
+          {groupedModels[group].map((model) => (
+            <option key={model.id} value={model.id}>
+              {model.name}
+            </option>
+          ))}
+        </optgroup>
+      ))}
+    </select>
+ );
+}
diff --git a/server/utils/AiProviders/gemini/defaultModels.js b/server/utils/AiProviders/gemini/defaultModels.js
new file mode 100644
index 0000000000..303a0aafff
--- /dev/null
+++ b/server/utils/AiProviders/gemini/defaultModels.js
@@ -0,0 +1,46 @@
+const { MODEL_MAP } = require("../modelMap");
+
+const stableModels = [
+ "gemini-pro",
+ "gemini-1.0-pro",
+ "gemini-1.5-pro-latest",
+ "gemini-1.5-flash-latest",
+];
+
+const experimentalModels = [
+ "gemini-1.5-pro-exp-0801",
+ "gemini-1.5-pro-exp-0827",
+ "gemini-1.5-flash-exp-0827",
+ "gemini-1.5-flash-8b-exp-0827",
+ "gemini-exp-1114",
+ "gemini-exp-1121",
+ "gemini-exp-1206",
+ "learnlm-1.5-pro-experimental",
+ "gemini-2.0-flash-exp",
+];
+
+// Some models are only available in the v1beta API, while others are only
+// available in v1. Generally, v1beta models have `exp` in the name, but not
+// always, so we also check against this static list of known v1beta models.
+const v1BetaModels = ["gemini-1.5-pro-latest", "gemini-1.5-flash-latest"];
+
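+// Static fallback list. GeminiLLM.fetchModels() returns these defaults when
+// the live model listing fails (missing key, network error, etc.).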
+const defaultGeminiModels = [
+ ...stableModels.map((model) => ({
+ id: model,
+ name: model,
+ contextWindow: MODEL_MAP.gemini[model],
+ experimental: false,
+ })),
+ ...experimentalModels.map((model) => ({
+ id: model,
+ name: model,
+ contextWindow: MODEL_MAP.gemini[model],
+ experimental: true,
+ })),
+];
+
+module.exports = {
+ defaultGeminiModels,
+ v1BetaModels,
+};
diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js
index f658b3c5f1..3554a51fcf 100644
--- a/server/utils/AiProviders/gemini/index.js
+++ b/server/utils/AiProviders/gemini/index.js
@@ -7,6 +7,7 @@ const {
clientAbortedHandler,
} = require("../../helpers/chat/responses");
const { MODEL_MAP } = require("../modelMap");
+const { defaultGeminiModels, v1BetaModels } = require("./defaultModels");
class GeminiLLM {
constructor(embedder = null, modelPreference = null) {
@@ -21,22 +22,17 @@ class GeminiLLM {
this.gemini = genAI.getGenerativeModel(
{ model: this.model },
{
- // Gemini-1.5-pro-* and Gemini-1.5-flash are only available on the v1beta API.
- apiVersion: [
- "gemini-1.5-pro-latest",
- "gemini-1.5-flash-latest",
- "gemini-1.5-pro-exp-0801",
- "gemini-1.5-pro-exp-0827",
- "gemini-1.5-flash-exp-0827",
- "gemini-1.5-flash-8b-exp-0827",
- "gemini-exp-1114",
- "gemini-exp-1121",
- "gemini-exp-1206",
- "learnlm-1.5-pro-experimental",
- "gemini-2.0-flash-exp",
- ].includes(this.model)
- ? "v1beta"
- : "v1",
+        apiVersion:
+          /**
+           * Some models are only available in the v1beta API, while others
+           * are only available in v1. Generally, v1beta models have `exp`
+           * in the name, but not always, so we also check against a static
+           * list of known v1beta models.
+           * @see {v1BetaModels}
+           */
+          this.model.includes("exp") || v1BetaModels.includes(this.model)
+            ? "v1beta"
+            : "v1",
}
);
this.limits = {
@@ -48,6 +44,11 @@ class GeminiLLM {
this.embedder = embedder ?? new NativeEmbedder();
this.defaultTemp = 0.7; // not used for Gemini
this.safetyThreshold = this.#fetchSafetyThreshold();
+ this.#log(`Initialized with model: ${this.model}`);
+ }
+
+ #log(text, ...args) {
+ console.log(`\x1b[32m[GeminiLLM]\x1b[0m ${text}`, ...args);
}
#appendContext(contextTexts = []) {
@@ -109,25 +110,63 @@ class GeminiLLM {
return MODEL_MAP.gemini[this.model] ?? 30_720;
}
- isValidChatCompletionModel(modelName = "") {
- const validModels = [
- "gemini-pro",
- "gemini-1.0-pro",
- "gemini-1.5-pro-latest",
- "gemini-1.5-flash-latest",
- "gemini-1.5-pro-exp-0801",
- "gemini-1.5-pro-exp-0827",
- "gemini-1.5-flash-exp-0827",
- "gemini-1.5-flash-8b-exp-0827",
- "gemini-exp-1114",
- "gemini-exp-1121",
- "gemini-exp-1206",
- "learnlm-1.5-pro-experimental",
- "gemini-2.0-flash-exp",
- ];
- return validModels.includes(modelName);
+ /**
+ * Fetches Gemini models from the Google Generative AI API
+ * @param {string} apiKey - The API key to use for the request
+ * @param {number} limit - The maximum number of models to fetch
+ * @param {string} pageToken - The page token to use for pagination
+   * @returns {Promise<Array<{id: string, name: string, contextWindow: number, experimental: boolean}>>} A promise that resolves to an array of Gemini models
+ */
+ static async fetchModels(apiKey, limit = 1_000, pageToken = null) {
+ const url = new URL(
+ "https://generativelanguage.googleapis.com/v1beta/models"
+ );
+ url.searchParams.set("pageSize", limit);
+ url.searchParams.set("key", apiKey);
+ if (pageToken) url.searchParams.set("pageToken", pageToken);
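+    // NOTE: a single page of up to `limit` models is fetched; the response's
+    // `nextPageToken` is not followed automatically here.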
+
+ return fetch(url.toString(), {
+ method: "GET",
+ headers: { "Content-Type": "application/json" },
+ })
+ .then((res) => res.json())
+ .then((data) => {
+ if (data.error) throw new Error(data.error.message);
+ return data.models ?? [];
+ })
+ .then((models) =>
+ models
+ .filter(
+ (model) => !model.displayName.toLowerCase().includes("tuning")
+ )
+ .filter((model) =>
+ model.supportedGenerationMethods.includes("generateContent")
+ ) // Only generateContent is supported
+ .map((model) => {
+ return {
+ id: model.name.split("/").pop(),
+ name: model.displayName,
+ contextWindow: model.inputTokenLimit,
+ experimental: model.name.includes("exp"),
+ };
+ })
+ )
+ .catch((e) => {
+        console.error(`Gemini:fetchModels`, e.message);
+ return defaultGeminiModels;
+ });
}
+  /**
+   * Checks if a model is valid for chat completion (unused)
+   * @deprecated
+   * @param {string} modelName - The name of the model to check
+   * @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the model is valid
+   */
+  async isValidChatCompletionModel(modelName = "") {
+    const models = await GeminiLLM.fetchModels(process.env.GEMINI_API_KEY);
+    return models.some((model) => model.id === modelName);
+  }
/**
* Generates appropriate content array for a message + attachments.
* @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
@@ -218,11 +257,6 @@ class GeminiLLM {
}
async getChatCompletion(messages = [], _opts = {}) {
- if (!this.isValidChatCompletionModel(this.model))
- throw new Error(
- `Gemini chat: ${this.model} is not valid for chat completion!`
- );
-
const prompt = messages.find(
(chat) => chat.role === "USER_PROMPT"
)?.content;
@@ -256,11 +290,6 @@ class GeminiLLM {
}
async streamGetChatCompletion(messages = [], _opts = {}) {
- if (!this.isValidChatCompletionModel(this.model))
- throw new Error(
- `Gemini chat: ${this.model} is not valid for chat completion!`
- );
-
const prompt = messages.find(
(chat) => chat.role === "USER_PROMPT"
)?.content;
diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js
index a763635fb5..7adb276261 100644
--- a/server/utils/helpers/customModels.js
+++ b/server/utils/helpers/customModels.js
@@ -7,6 +7,7 @@ const { ElevenLabsTTS } = require("../TextToSpeech/elevenLabs");
const { fetchNovitaModels } = require("../AiProviders/novita");
const { parseLMStudioBasePath } = require("../AiProviders/lmStudio");
const { parseNvidiaNimBasePath } = require("../AiProviders/nvidiaNim");
+const { GeminiLLM } = require("../AiProviders/gemini");
const SUPPORT_CUSTOM_MODELS = [
"openai",
@@ -28,6 +29,7 @@ const SUPPORT_CUSTOM_MODELS = [
"apipie",
"novita",
"xai",
+ "gemini",
];
async function getCustomModels(provider = "", apiKey = null, basePath = null) {
@@ -73,6 +75,8 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) {
return await getXAIModels(apiKey);
case "nvidia-nim":
return await getNvidiaNimModels(basePath);
+ case "gemini":
+ return await getGeminiModels(apiKey);
default:
return { models: [], error: "Invalid provider for custom models" };
}
@@ -572,6 +576,17 @@ async function getNvidiaNimModels(basePath = null) {
}
}
+async function getGeminiModels(_apiKey = null) {
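+  // A literal `true` means "use the key already stored in the env"; otherwise
+  // prefer the provided key, falling back to the env value.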
+ const apiKey =
+ _apiKey === true
+ ? process.env.GEMINI_API_KEY
+ : _apiKey || process.env.GEMINI_API_KEY || null;
+ const models = await GeminiLLM.fetchModels(apiKey);
+  // The API key worked, so save it for future use.
+ if (models.length > 0 && !!apiKey) process.env.GEMINI_API_KEY = apiKey;
+ return { models, error: null };
+}
+
module.exports = {
getCustomModels,
};
diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js
index 948703dca2..da30b6ee0d 100644
--- a/server/utils/helpers/updateENV.js
+++ b/server/utils/helpers/updateENV.js
@@ -52,7 +52,7 @@ const KEY_MAPPING = {
},
GeminiLLMModelPref: {
envKey: "GEMINI_LLM_MODEL_PREF",
- checks: [isNotEmpty, validGeminiModel],
+ checks: [isNotEmpty],
},
GeminiSafetySetting: {
envKey: "GEMINI_SAFETY_SETTING",
@@ -724,27 +724,6 @@ function supportedTranscriptionProvider(input = "") {
: `${input} is not a valid transcription model provider.`;
}
-function validGeminiModel(input = "") {
- const validModels = [
- "gemini-pro",
- "gemini-1.0-pro",
- "gemini-1.5-pro-latest",
- "gemini-1.5-flash-latest",
- "gemini-1.5-pro-exp-0801",
- "gemini-1.5-pro-exp-0827",
- "gemini-1.5-flash-exp-0827",
- "gemini-1.5-flash-8b-exp-0827",
- "gemini-exp-1114",
- "gemini-exp-1121",
- "gemini-exp-1206",
- "learnlm-1.5-pro-experimental",
- "gemini-2.0-flash-exp",
- ];
- return validModels.includes(input)
- ? null
- : `Invalid Model type. Must be one of ${validModels.join(", ")}.`;
-}
-
function validGeminiSafetySetting(input = "") {
const validModes = [
"BLOCK_NONE",