
Add caching to Gemini /models #2969


Merged
merged 2 commits on Jan 13, 2025
3 changes: 2 additions & 1 deletion server/storage/models/.gitignore
@@ -4,4 +4,5 @@ downloaded/*
openrouter
apipie
novita
mixedbread-ai*
mixedbread-ai*
gemini
107 changes: 97 additions & 10 deletions server/utils/AiProviders/gemini/index.js
@@ -1,3 +1,5 @@
const fs = require("fs");
const path = require("path");
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
LLMPerformanceMonitor,
@@ -7,7 +9,13 @@ const {
clientAbortedHandler,
} = require("../../helpers/chat/responses");
const { MODEL_MAP } = require("../modelMap");
const { defaultGeminiModels, v1BetaModels } = require("./defaultModals");
const { defaultGeminiModels, v1BetaModels } = require("./defaultModels");
const { safeJsonParse } = require("../../http");
const cacheFolder = path.resolve(
process.env.STORAGE_DIR
? path.resolve(process.env.STORAGE_DIR, "models", "gemini")
: path.resolve(__dirname, `../../../storage/models/gemini`)
);

class GeminiLLM {
constructor(embedder = null, modelPreference = null) {
@@ -44,13 +52,33 @@ class GeminiLLM {
this.embedder = embedder ?? new NativeEmbedder();
this.defaultTemp = 0.7; // not used for Gemini
this.safetyThreshold = this.#fetchSafetyThreshold();
this.#log(`Initialized with model: ${this.model}`);

if (!fs.existsSync(cacheFolder))
fs.mkdirSync(cacheFolder, { recursive: true });
this.cacheModelPath = path.resolve(cacheFolder, "models.json");
this.cacheAtPath = path.resolve(cacheFolder, ".cached_at");
this.#log(
`Initialized with model: ${this.model} (${this.promptWindowLimit()})`
);
}

#log(text, ...args) {
console.log(`\x1b[32m[GeminiLLM]\x1b[0m ${text}`, ...args);
}

// Checks whether the .cached_at file's timestamp is more than one week (in ms)
// older than the current date. If it is, we re-fetch from the API so the cached
// model list stays up to date.
static cacheIsStale() {
const MAX_STALE = 6.048e8; // 1 Week in MS
if (!fs.existsSync(path.resolve(cacheFolder, ".cached_at"))) return true;
const now = Number(new Date());
const timestampMs = Number(
fs.readFileSync(path.resolve(cacheFolder, ".cached_at"))
);
return now - timestampMs > MAX_STALE;
}

#appendContext(contextTexts = []) {
if (!contextTexts || !contextTexts.length) return "";
return (
Expand Down Expand Up @@ -103,11 +131,40 @@ class GeminiLLM {
}

static promptWindowLimit(modelName) {
return MODEL_MAP.gemini[modelName] ?? 30_720;
try {
const cacheModelPath = path.resolve(cacheFolder, "models.json");
if (!fs.existsSync(cacheModelPath))
return MODEL_MAP.gemini[modelName] ?? 30_720;

const models = safeJsonParse(fs.readFileSync(cacheModelPath));
const model = models.find((model) => model.id === modelName);
if (!model)
throw new Error(
"Model not found in cache - falling back to default model."
);
return model.contextWindow;
} catch (e) {
console.error(`GeminiLLM:promptWindowLimit`, e.message);
return MODEL_MAP.gemini[modelName] ?? 30_720;
}
}

promptWindowLimit() {
return MODEL_MAP.gemini[this.model] ?? 30_720;
try {
if (!fs.existsSync(this.cacheModelPath))
return MODEL_MAP.gemini[this.model] ?? 30_720;

const models = safeJsonParse(fs.readFileSync(this.cacheModelPath));
const model = models.find((model) => model.id === this.model);
if (!model)
throw new Error(
"Model not found in cache - falling back to default model."
);
return model.contextWindow;
} catch (e) {
console.error(`GeminiLLM:promptWindowLimit`, e.message);
return MODEL_MAP.gemini[this.model] ?? 30_720;
}
}

/**
@@ -118,14 +175,25 @@
* @returns {Promise<[{id: string, name: string, contextWindow: number, experimental: boolean}]>} A promise that resolves to an array of Gemini models
*/
static async fetchModels(apiKey, limit = 1_000, pageToken = null) {
if (!apiKey) return [];
if (fs.existsSync(cacheFolder) && !this.cacheIsStale()) {
console.log(
`\x1b[32m[GeminiLLM]\x1b[0m Using cached models API response.`
);
return safeJsonParse(
fs.readFileSync(path.resolve(cacheFolder, "models.json"))
);
}

const url = new URL(
"https://generativelanguage.googleapis.com/v1beta/models"
);
url.searchParams.set("pageSize", limit);
url.searchParams.set("key", apiKey);
if (pageToken) url.searchParams.set("pageToken", pageToken);
let success = false;

return fetch(url.toString(), {
const models = await fetch(url.toString(), {
method: "GET",
headers: { "Content-Type": "application/json" },
})
@@ -134,8 +202,9 @@
if (data.error) throw new Error(data.error.message);
return data.models ?? [];
})
.then((models) =>
models
.then((models) => {
success = true;
return models
.filter(
(model) => !model.displayName.toLowerCase().includes("tuning")
)
@@ -149,12 +218,30 @@
contextWindow: model.inputTokenLimit,
experimental: model.name.includes("exp"),
};
})
)
});
})
.catch((e) => {
console.error(`Gemini:getGeminiModels`, e.message);
success = false;
return defaultGeminiModels;
});

if (success) {
console.log(
`\x1b[32m[GeminiLLM]\x1b[0m Writing cached models API response to disk.`
);
if (!fs.existsSync(cacheFolder))
fs.mkdirSync(cacheFolder, { recursive: true });
fs.writeFileSync(
path.resolve(cacheFolder, "models.json"),
JSON.stringify(models)
);
fs.writeFileSync(
path.resolve(cacheFolder, ".cached_at"),
new Date().getTime().toString()
);
}
return models;
}

/**
@@ -164,7 +251,7 @@
* @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the model is valid
*/
async isValidChatCompletionModel(modelName = "") {
const models = await this.fetchModels(true);
const models = await this.fetchModels(process.env.GEMINI_API_KEY);
return models.some((model) => model.id === modelName);
}
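
A minimal sketch of how the caching introduced in this PR behaves across calls. It is not part of the diff: the require path and the GeminiLLM export name are assumptions about the surrounding codebase, and it presumes GEMINI_API_KEY (and optionally STORAGE_DIR) are set in the environment.

// Illustrative sketch only — exercises the model caching added in this PR.
const { GeminiLLM } = require("./server/utils/AiProviders/gemini");

(async () => {
  // First call with a missing or >1-week-old cache: queries the Google
  // /v1beta/models endpoint, then writes models.json and .cached_at under
  // STORAGE_DIR/models/gemini (or server/storage/models/gemini as a fallback).
  const fresh = await GeminiLLM.fetchModels(process.env.GEMINI_API_KEY);

  // Any later call within the week: cacheIsStale() returns false, so the model
  // list is read back from models.json and no network request is made.
  const cached = await GeminiLLM.fetchModels(process.env.GEMINI_API_KEY);

  // promptWindowLimit() now also reads contextWindow from the cached models.json,
  // falling back to MODEL_MAP when the file or the model id is missing.
  console.log(fresh.length === cached.length); // same list, one API round trip
})();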