diff --git a/frontend/src/components/LLMSelection/OpenAiOptions/index.jsx b/frontend/src/components/LLMSelection/OpenAiOptions/index.jsx
index c5ec337d0a8..67f7d291b6f 100644
--- a/frontend/src/components/LLMSelection/OpenAiOptions/index.jsx
+++ b/frontend/src/components/LLMSelection/OpenAiOptions/index.jsx
@@ -32,22 +32,26 @@ export default function OpenAiOptions({ settings }) {
}
function OpenAIModelSelection({ apiKey, settings }) {
- const [customModels, setCustomModels] = useState([]);
+ const [groupedModels, setGroupedModels] = useState({});
const [loading, setLoading] = useState(true);
useEffect(() => {
async function findCustomModels() {
- if (!apiKey) {
- setCustomModels([]);
- setLoading(false);
- return;
- }
setLoading(true);
const { models } = await System.customModels(
"openai",
typeof apiKey === "boolean" ? null : apiKey
);
- setCustomModels(models || []);
+
+ if (models?.length > 0) {
+ const modelsByOrganization = models.reduce((acc, model) => {
+ acc[model.organization] = acc[model.organization] || [];
+ acc[model.organization].push(model);
+ return acc;
+ }, {});
+ setGroupedModels(modelsByOrganization);
+ }
+
setLoading(false);
}
findCustomModels();
@@ -82,41 +86,21 @@ function OpenAIModelSelection({ apiKey, settings }) {
required={true}
className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
-
- {customModels.length > 0 && (
-
+ ))}
);
diff --git a/frontend/src/hooks/useGetProvidersModels.js b/frontend/src/hooks/useGetProvidersModels.js
index 95df82a3a98..513bfdbe19c 100644
--- a/frontend/src/hooks/useGetProvidersModels.js
+++ b/frontend/src/hooks/useGetProvidersModels.js
@@ -4,14 +4,7 @@ import { useEffect, useState } from "react";
// Providers which cannot use this feature for workspace<>model selection
export const DISABLED_PROVIDERS = ["azure", "lmstudio", "native"];
const PROVIDER_DEFAULT_MODELS = {
- openai: [
- "gpt-3.5-turbo",
- "gpt-3.5-turbo-1106",
- "gpt-4",
- "gpt-4-turbo-preview",
- "gpt-4-1106-preview",
- "gpt-4-32k",
- ],
+ openai: [],
gemini: ["gemini-pro"],
anthropic: [
"claude-instant-1.2",
@@ -41,6 +34,7 @@ function groupModels(models) {
}, {});
}
+const groupedProviders = ["togetherai", "openai"];
export default function useGetProviderModels(provider = null) {
const [defaultModels, setDefaultModels] = useState([]);
const [customModels, setCustomModels] = useState([]);
@@ -50,9 +44,12 @@ export default function useGetProviderModels(provider = null) {
async function fetchProviderModels() {
if (!provider) return;
const { models = [] } = await System.customModels(provider);
- if (PROVIDER_DEFAULT_MODELS.hasOwnProperty(provider))
+ if (
+ PROVIDER_DEFAULT_MODELS.hasOwnProperty(provider) &&
+ !groupedProviders.includes(provider)
+ )
setDefaultModels(PROVIDER_DEFAULT_MODELS[provider]);
- provider === "togetherai"
+ groupedProviders.includes(provider)
? setCustomModels(groupModels(models))
: setCustomModels(models);
setLoading(false);
diff --git a/server/utils/AiProviders/openAi/index.js b/server/utils/AiProviders/openAi/index.js
index d4dc14dc63b..e09e11c2141 100644
--- a/server/utils/AiProviders/openAi/index.js
+++ b/server/utils/AiProviders/openAi/index.js
@@ -46,32 +46,28 @@ class OpenAiLLM {
promptWindowLimit() {
switch (this.model) {
case "gpt-3.5-turbo":
- return 4096;
case "gpt-3.5-turbo-1106":
- return 16385;
- case "gpt-4":
- return 8192;
+ return 16_385;
+ case "gpt-4-turbo":
case "gpt-4-1106-preview":
- return 128000;
case "gpt-4-turbo-preview":
- return 128000;
+ return 128_000;
+ case "gpt-4":
+ return 8_192;
case "gpt-4-32k":
- return 32000;
+ return 32_000;
default:
- return 4096; // assume a fine-tune 3.5
+      return 4_096; // unknown model — assume a fine-tuned gpt-3.5 with the base 4K context window
}
}
+  // Short-circuit when the name contains 'gpt': models are now fetched from the
+  // OpenAI API using the user's key, so any gpt-* name should be a real model.
+  // If somehow it is not, the chat request will fail, and that failure is caught.
+  // We avoid hitting the OpenAI API on every chat because doing so would spam
+  // the endpoint and add latency for no benefit.
async isValidChatCompletionModel(modelName = "") {
- const validModels = [
- "gpt-4",
- "gpt-3.5-turbo",
- "gpt-3.5-turbo-1106",
- "gpt-4-1106-preview",
- "gpt-4-turbo-preview",
- "gpt-4-32k",
- ];
- const isPreset = validModels.some((model) => modelName === model);
+ const isPreset = modelName.toLowerCase().includes("gpt");
if (isPreset) return true;
const model = await this.openai
diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js
index 5dfa30e319e..d0d162c4a54 100644
--- a/server/utils/helpers/customModels.js
+++ b/server/utils/helpers/customModels.js
@@ -47,21 +47,85 @@ async function openAiModels(apiKey = null) {
apiKey: apiKey || process.env.OPEN_AI_KEY,
});
const openai = new OpenAIApi(config);
- const models = (
- await openai
- .listModels()
- .then((res) => res.data.data)
- .catch((e) => {
- console.error(`OpenAI:listModels`, e.message);
- return [];
- })
- ).filter(
- (model) => !model.owned_by.includes("openai") && model.owned_by !== "system"
- );
+ const allModels = await openai
+ .listModels()
+ .then((res) => res.data.data)
+ .catch((e) => {
+ console.error(`OpenAI:listModels`, e.message);
+ return [
+ {
+ name: "gpt-3.5-turbo",
+ id: "gpt-3.5-turbo",
+ object: "model",
+ created: 1677610602,
+ owned_by: "openai",
+ organization: "OpenAi",
+ },
+ {
+ name: "gpt-4",
+ id: "gpt-4",
+ object: "model",
+ created: 1687882411,
+ owned_by: "openai",
+ organization: "OpenAi",
+ },
+ {
+ name: "gpt-4-turbo",
+ id: "gpt-4-turbo",
+ object: "model",
+ created: 1712361441,
+ owned_by: "system",
+ organization: "OpenAi",
+ },
+ {
+ name: "gpt-4-32k",
+ id: "gpt-4-32k",
+ object: "model",
+ created: 1687979321,
+ owned_by: "openai",
+ organization: "OpenAi",
+ },
+ {
+ name: "gpt-3.5-turbo-16k",
+ id: "gpt-3.5-turbo-16k",
+ object: "model",
+ created: 1683758102,
+ owned_by: "openai-internal",
+ organization: "OpenAi",
+ },
+ ];
+ });
+
+ const gpts = allModels
+ .filter((model) => model.id.startsWith("gpt"))
+ .filter(
+ (model) => !model.id.includes("vision") && !model.id.includes("instruct")
+ )
+ .map((model) => {
+ return {
+ ...model,
+ name: model.id,
+ organization: "OpenAi",
+ };
+ });
+
+ const customModels = allModels
+ .filter(
+ (model) =>
+ !model.owned_by.includes("openai") && model.owned_by !== "system"
+ )
+ .map((model) => {
+ return {
+ ...model,
+ name: model.id,
+ organization: "Your Fine-Tunes",
+ };
+ });
// Api Key was successful so lets save it for future uses
- if (models.length > 0 && !!apiKey) process.env.OPEN_AI_KEY = apiKey;
- return { models, error: null };
+ if ((gpts.length > 0 || customModels.length > 0) && !!apiKey)
+ process.env.OPEN_AI_KEY = apiKey;
+ return { models: [...gpts, ...customModels], error: null };
}
async function localAIModels(basePath = null, apiKey = null) {