θΏ™ζ˜―indexlocζδΎ›ηš„ζœεŠ‘οΌŒδΈθ¦θΎ“ε…₯任何密码
Skip to content

Add support for gemini authenticated models endpoint #2868

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 85 additions & 45 deletions frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import System from "@/models/system";
import { useEffect, useState } from "react";

export default function GeminiLLMOptions({ settings }) {
const [inputValue, setInputValue] = useState(settings?.GeminiLLMApiKey);
const [geminiApiKey, setGeminiApiKey] = useState(settings?.GeminiLLMApiKey);

return (
<div className="w-full flex flex-col">
<div className="w-full flex items-center gap-[36px] mt-1.5">
Expand All @@ -15,56 +21,14 @@ export default function GeminiLLMOptions({ settings }) {
required={true}
autoComplete="off"
spellCheck={false}
onChange={(e) => setInputValue(e.target.value)}
onBlur={() => setGeminiApiKey(inputValue)}
/>
</div>

{!settings?.credentialsOnly && (
<>
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GeminiLLMModelPref"
defaultValue={settings?.GeminiLLMModelPref || "gemini-pro"}
required={true}
className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
<optgroup label="Stable Models">
{[
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
].map((model) => {
return (
<option key={model} value={model}>
{model}
</option>
);
})}
</optgroup>
<optgroup label="Experimental Models">
{[
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
].map((model) => {
return (
<option key={model} value={model}>
{model}
</option>
);
})}
</optgroup>
</select>
</div>
<GeminiModelSelection apiKey={geminiApiKey} settings={settings} />
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Safety Setting
Expand All @@ -91,3 +55,79 @@ export default function GeminiLLMOptions({ settings }) {
</div>
);
}

function GeminiModelSelection({ apiKey, settings }) {
const [groupedModels, setGroupedModels] = useState({});
const [loading, setLoading] = useState(true);

useEffect(() => {
async function findCustomModels() {
setLoading(true);
const { models } = await System.customModels("gemini", apiKey);

if (models?.length > 0) {
const modelsByOrganization = models.reduce((acc, model) => {
acc[model.experimental ? "Experimental" : "Stable"] =
acc[model.experimental ? "Experimental" : "Stable"] || [];
acc[model.experimental ? "Experimental" : "Stable"].push(model);
return acc;
}, {});
setGroupedModels(modelsByOrganization);
}
setLoading(false);
}
findCustomModels();
}, [apiKey]);

if (loading) {
return (
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GeminiLLMModelPref"
disabled={true}
className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
<option disabled={true} selected={true}>
-- loading available models --
</option>
</select>
</div>
);
}

return (
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GeminiLLMModelPref"
required={true}
className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
{Object.keys(groupedModels)
.sort((a, b) => {
if (a === "Stable") return -1;
if (b === "Stable") return 1;
return a.localeCompare(b);
})
.map((organization) => (
<optgroup key={organization} label={organization}>
{groupedModels[organization].map((model) => (
<option
key={model.id}
value={model.id}
selected={settings?.GeminiLLMModelPref === model.id}
>
{model.name}
</option>
))}
</optgroup>
))}
</select>
</div>
);
}
46 changes: 46 additions & 0 deletions server/utils/AiProviders/gemini/defaultModals.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
const { MODEL_MAP } = require("../modelMap");

const stableModels = [
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
];

const experimentalModels = [
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
];

// There are some models that are only available in the v1beta API
// and some models that are only available in the v1 API
// generally, v1beta models have `exp` in the name, but not always
// so we check for both against a static list as well.
const v1BetaModels = ["gemini-1.5-pro-latest", "gemini-1.5-flash-latest"];

const defaultGeminiModels = [
...stableModels.map((model) => ({
id: model,
name: model,
contextWindow: MODEL_MAP.gemini[model],
experimental: false,
})),
...experimentalModels.map((model) => ({
id: model,
name: model,
contextWindow: MODEL_MAP.gemini[model],
experimental: true,
})),
];

module.exports = {
defaultGeminiModels,
v1BetaModels,
};
115 changes: 72 additions & 43 deletions server/utils/AiProviders/gemini/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ const {
clientAbortedHandler,
} = require("../../helpers/chat/responses");
const { MODEL_MAP } = require("../modelMap");
const { defaultGeminiModels, v1BetaModels } = require("./defaultModals");

class GeminiLLM {
constructor(embedder = null, modelPreference = null) {
Expand All @@ -21,22 +22,17 @@ class GeminiLLM {
this.gemini = genAI.getGenerativeModel(
{ model: this.model },
{
// Gemini-1.5-pro-* and Gemini-1.5-flash are only available on the v1beta API.
apiVersion: [
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
].includes(this.model)
? "v1beta"
: "v1",
apiVersion:
/**
* There are some models that are only available in the v1beta API
* and some models that are only available in the v1 API
* generally, v1beta models have `exp` in the name, but not always
* so we check for both against a static list as well.
* @see {v1BetaModels}
*/
this.model.includes("exp") || v1BetaModels.includes(this.model)
? "v1beta"
: "v1",
}
);
this.limits = {
Expand All @@ -48,6 +44,11 @@ class GeminiLLM {
this.embedder = embedder ?? new NativeEmbedder();
this.defaultTemp = 0.7; // not used for Gemini
this.safetyThreshold = this.#fetchSafetyThreshold();
this.#log(`Initialized with model: ${this.model}`);
}

  // Console logger with a green "[GeminiLLM]" prefix so provider output is
  // easy to spot in mixed server logs. Extra args pass through to console.log.
  #log(text, ...args) {
    console.log(`\x1b[32m[GeminiLLM]\x1b[0m ${text}`, ...args);
  }

#appendContext(contextTexts = []) {
Expand Down Expand Up @@ -109,25 +110,63 @@ class GeminiLLM {
return MODEL_MAP.gemini[this.model] ?? 30_720;
}

isValidChatCompletionModel(modelName = "") {
const validModels = [
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
];
return validModels.includes(modelName);
/**
* Fetches Gemini models from the Google Generative AI API
* @param {string} apiKey - The API key to use for the request
* @param {number} limit - The maximum number of models to fetch
* @param {string} pageToken - The page token to use for pagination
* @returns {Promise<[{id: string, name: string, contextWindow: number, experimental: boolean}]>} A promise that resolves to an array of Gemini models
*/
static async fetchModels(apiKey, limit = 1_000, pageToken = null) {
const url = new URL(
"https://generativelanguage.googleapis.com/v1beta/models"
);
url.searchParams.set("pageSize", limit);
url.searchParams.set("key", apiKey);
if (pageToken) url.searchParams.set("pageToken", pageToken);

return fetch(url.toString(), {
method: "GET",
headers: { "Content-Type": "application/json" },
})
.then((res) => res.json())
.then((data) => {
if (data.error) throw new Error(data.error.message);
return data.models ?? [];
})
.then((models) =>
models
.filter(
(model) => !model.displayName.toLowerCase().includes("tuning")
)
.filter((model) =>
model.supportedGenerationMethods.includes("generateContent")
) // Only generateContent is supported
.map((model) => {
return {
id: model.name.split("/").pop(),
name: model.displayName,
contextWindow: model.inputTokenLimit,
experimental: model.name.includes("exp"),
};
})
)
.catch((e) => {
console.error(`Gemini:getGeminiModels`, e.message);
return defaultGeminiModels;
});
}

/**
* Checks if a model is valid for chat completion (unused)
* @deprecated
* @param {string} modelName - The name of the model to check
* @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the model is valid
*/
async isValidChatCompletionModel(modelName = "") {
const models = await this.fetchModels(true);
return models.some((model) => model.id === modelName);
}
/**
* Generates appropriate content array for a message + attachments.
* @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
Expand Down Expand Up @@ -218,11 +257,6 @@ class GeminiLLM {
}

async getChatCompletion(messages = [], _opts = {}) {
if (!this.isValidChatCompletionModel(this.model))
throw new Error(
`Gemini chat: ${this.model} is not valid for chat completion!`
);

const prompt = messages.find(
(chat) => chat.role === "USER_PROMPT"
)?.content;
Expand Down Expand Up @@ -256,11 +290,6 @@ class GeminiLLM {
}

async streamGetChatCompletion(messages = [], _opts = {}) {
if (!this.isValidChatCompletionModel(this.model))
throw new Error(
`Gemini chat: ${this.model} is not valid for chat completion!`
);

const prompt = messages.find(
(chat) => chat.role === "USER_PROMPT"
)?.content;
Expand Down
15 changes: 15 additions & 0 deletions server/utils/helpers/customModels.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ const { ElevenLabsTTS } = require("../TextToSpeech/elevenLabs");
const { fetchNovitaModels } = require("../AiProviders/novita");
const { parseLMStudioBasePath } = require("../AiProviders/lmStudio");
const { parseNvidiaNimBasePath } = require("../AiProviders/nvidiaNim");
const { GeminiLLM } = require("../AiProviders/gemini");

const SUPPORT_CUSTOM_MODELS = [
"openai",
Expand All @@ -28,6 +29,7 @@ const SUPPORT_CUSTOM_MODELS = [
"apipie",
"novita",
"xai",
"gemini",
];

async function getCustomModels(provider = "", apiKey = null, basePath = null) {
Expand Down Expand Up @@ -73,6 +75,8 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) {
return await getXAIModels(apiKey);
case "nvidia-nim":
return await getNvidiaNimModels(basePath);
case "gemini":
return await getGeminiModels(apiKey);
default:
return { models: [], error: "Invalid provider for custom models" };
}
Expand Down Expand Up @@ -572,6 +576,17 @@ async function getNvidiaNimModels(basePath = null) {
}
}

/**
 * Lists available Gemini models for the given key.
 * Passing `true` (or nothing) means "use the key already stored in the env".
 * On a successful fetch the working key is persisted to the env for reuse.
 */
async function getGeminiModels(_apiKey = null) {
  let apiKey;
  if (_apiKey === true) {
    apiKey = process.env.GEMINI_API_KEY;
  } else {
    apiKey = _apiKey || process.env.GEMINI_API_KEY || null;
  }

  const models = await GeminiLLM.fetchModels(apiKey);
  // Api Key was successful so lets save it for future uses
  if (models.length > 0 && !!apiKey) process.env.GEMINI_API_KEY = apiKey;
  return { models, error: null };
}

module.exports = {
getCustomModels,
};
Loading