diff --git a/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx b/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx
new file mode 100644
index 00000000000..5e5816cda8d
--- /dev/null
+++ b/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx
@@ -0,0 +1,51 @@
+import { useState } from "react";
+
+// We don't support reranking for all vectorDBs yet due to the complexities of how each provider
+// returns information. We need to normalize the response data so the reranker can be used with each provider.
+const supportedVectorDBs = ["lancedb"];
+const hint = {
+ default: {
+ title: "Default",
+ description:
+ "This is the fastest performance, but may not return the most relevant results leading to model hallucinations.",
+ },
+ rerank: {
+ title: "Accuracy Optimized",
+ description:
+ "LLM responses may take longer to generate, but your responses will be more accurate and relevant.",
+ },
+};
+
+export default function VectorSearchMode({ workspace, setHasChanges }) {
+ const [selection, setSelection] = useState(
+ workspace?.vectorSearchMode ?? "default"
+ );
+ if (!workspace?.vectorDB || !supportedVectorDBs.includes(workspace?.vectorDB))
+ return null;
+
+  return (
+    <div>
+      <label htmlFor="vectorSearchMode">Search Preference</label>
+      <select
+        id="vectorSearchMode"
+        name="vectorSearchMode"
+        value={selection}
+        onChange={(e) => {
+          setSelection(e.target.value);
+          setHasChanges(true);
+        }}
+      >
+        {Object.entries(hint).map(([mode, { title }]) => (
+          <option key={mode} value={mode}>
+            {title}
+          </option>
+        ))}
+      </select>
+      <p>{hint[selection]?.description}</p>
+    </div>
+  );
+}
diff --git a/frontend/src/pages/WorkspaceSettings/VectorDatabase/index.jsx b/frontend/src/pages/WorkspaceSettings/VectorDatabase/index.jsx
index 97d63291cda..7d7d44e8f44 100644
--- a/frontend/src/pages/WorkspaceSettings/VectorDatabase/index.jsx
+++ b/frontend/src/pages/WorkspaceSettings/VectorDatabase/index.jsx
@@ -7,6 +7,7 @@ import MaxContextSnippets from "./MaxContextSnippets";
import DocumentSimilarityThreshold from "./DocumentSimilarityThreshold";
import ResetDatabase from "./ResetDatabase";
import VectorCount from "./VectorCount";
+import VectorSearchMode from "./VectorSearchMode";
export default function VectorDatabase({ workspace }) {
const [hasChanges, setHasChanges] = useState(false);
@@ -43,6 +44,7 @@ export default function VectorDatabase({ workspace }) {
+        <VectorSearchMode workspace={workspace} setHasChanges={setHasChanges} />
diff --git a/server/models/workspace.js b/server/models/workspace.js
--- a/server/models/workspace.js
+++ b/server/models/workspace.js
+    vectorSearchMode: (value) => {
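+      // Normalize any unexpected value, e.g. null or "fuzzy", back to "default";
+      // "default" and "rerank" pass through unchanged.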
+ if (
+ !value ||
+ typeof value !== "string" ||
+ !["default", "rerank"].includes(value)
+ )
+ return "default";
+ return value;
+ },
},
/**
diff --git a/server/prisma/migrations/20250102204948_init/migration.sql b/server/prisma/migrations/20250102204948_init/migration.sql
new file mode 100644
index 00000000000..788409bfa1a
--- /dev/null
+++ b/server/prisma/migrations/20250102204948_init/migration.sql
@@ -0,0 +1,2 @@
+-- AlterTable
+ALTER TABLE "workspaces" ADD COLUMN "vectorSearchMode" TEXT DEFAULT 'default';
diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma
index 143646e6579..37c82d4ddfc 100644
--- a/server/prisma/schema.prisma
+++ b/server/prisma/schema.prisma
@@ -137,6 +137,7 @@ model workspaces {
agentProvider String?
agentModel String?
queryRefusalResponse String?
+ vectorSearchMode String? @default("default")
workspace_users workspace_users[]
documents workspace_documents[]
workspace_suggested_messages workspace_suggested_messages[]
diff --git a/server/storage/models/.gitignore b/server/storage/models/.gitignore
index e669b51b277..7a8f66d8ff4 100644
--- a/server/storage/models/.gitignore
+++ b/server/storage/models/.gitignore
@@ -3,4 +3,5 @@ downloaded/*
!downloaded/.placeholder
openrouter
apipie
-novita
\ No newline at end of file
+novita
+mixedbread-ai*
\ No newline at end of file
diff --git a/server/utils/EmbeddingRerankers/native/index.js b/server/utils/EmbeddingRerankers/native/index.js
new file mode 100644
index 00000000000..f3d468cd7b2
--- /dev/null
+++ b/server/utils/EmbeddingRerankers/native/index.js
@@ -0,0 +1,153 @@
+const path = require("path");
+const fs = require("fs");
+
+class NativeEmbeddingReranker {
+ static #model = null;
+ static #tokenizer = null;
+ static #transformers = null;
+
+ constructor() {
+ // This model is an alternative to mixedbread-ai/mxbai-rerank-xsmall-v1, which is much slower on CPU (18 docs = ~6s vs ~1.6s here).
+ // Model Card: https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2
+ this.model = "Xenova/ms-marco-MiniLM-L-6-v2";
+ this.cacheDir = path.resolve(
+ process.env.STORAGE_DIR
+ ? path.resolve(process.env.STORAGE_DIR, `models`)
+ : path.resolve(__dirname, `../../../storage/models`)
+ );
+ this.modelPath = path.resolve(this.cacheDir, ...this.model.split("/"));
+ // Create the models directory when it does not already exist (e.g., on older installations)
+ if (!fs.existsSync(this.cacheDir)) fs.mkdirSync(this.cacheDir);
+ this.log("Initialized");
+ }
+
+ log(text, ...args) {
+ console.log(`\x1b[36m[NativeEmbeddingReranker]\x1b[0m ${text}`, ...args);
+ }
+
+ /**
+ * Preloads the reranker model and tokenizer.
+ * This reduces the latency of the first rerank call by downloading and caching the models
+ * ahead of time, so that first call does not have to wait on a download.
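+ * @example
+ * // e.g. fire-and-forget at server boot (hypothetical call site):
+ * new NativeEmbeddingReranker().preload();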
+ */
+ async preload() {
+ try {
+ this.log(`Preloading reranker suite...`);
+ await this.initClient();
+ this.log(
+ `Preloaded reranker suite. Reranking is available as a service now.`
+ );
+ return;
+ } catch (e) {
+ console.error(e);
+ this.log(
+ `Failed to preload reranker suite. Reranking will be available on the first rerank call.`
+ );
+ return;
+ }
+ }
+
+ async initClient() {
+ if (NativeEmbeddingReranker.#transformers) {
+ this.log(`Reranker suite already initialized - reusing.`);
+ return;
+ }
+
+ await import("@xenova/transformers").then(
+ async ({ AutoModelForSequenceClassification, AutoTokenizer }) => {
+ this.log(`Loading reranker suite...`);
+ NativeEmbeddingReranker.#transformers = {
+ AutoModelForSequenceClassification,
+ AutoTokenizer,
+ };
+ await this.#getPreTrainedModel();
+ await this.#getPreTrainedTokenizer();
+ }
+ );
+ return;
+ }
+
+ async #getPreTrainedModel() {
+ if (NativeEmbeddingReranker.#model) {
+ this.log(`Loading model from singleton...`);
+ return NativeEmbeddingReranker.#model;
+ }
+
+ const model =
+ await NativeEmbeddingReranker.#transformers.AutoModelForSequenceClassification.from_pretrained(
+ this.model,
+ {
+ progress_callback: (p) =>
+ p.status === "progress" &&
+ this.log(`Loading model ${this.model}... ${p?.progress}%`),
+ cache_dir: this.cacheDir,
+ }
+ );
+ this.log(`Loaded model ${this.model}`);
+ NativeEmbeddingReranker.#model = model;
+ return model;
+ }
+
+ async #getPreTrainedTokenizer() {
+ if (NativeEmbeddingReranker.#tokenizer) {
+ this.log(`Loading tokenizer from singleton...`);
+ return NativeEmbeddingReranker.#tokenizer;
+ }
+
+ const tokenizer =
+ await NativeEmbeddingReranker.#transformers.AutoTokenizer.from_pretrained(
+ this.model,
+ {
+ progress_callback: (p) =>
+ p.status === "progress" &&
+ this.log(`Loading tokenizer ${this.model}... ${p?.progress}%`),
+ cache_dir: this.cacheDir,
+ }
+ );
+ this.log(`Loaded tokenizer ${this.model}`);
+ NativeEmbeddingReranker.#tokenizer = tokenizer;
+ return tokenizer;
+ }
+
+ /**
+ * Reranks a list of documents based on the query.
+ * @param {string} query - The query to rerank the documents against.
+ * @param {{text: string}[]} documents - The list of document text snippets to rerank. Should be output from a vector search.
+ * @param {Object} options - The options for the reranking.
+ * @param {number} options.topK - The number of top documents to return.
+ * @returns {Promise<object[]>} - The reranked documents, each copied with added rerank_corpus_id and rerank_score fields.
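+ * @example
+ * // Illustrative usage; `hits` is assumed to be vector search output where each item has a `text` field:
+ * const reranker = new NativeEmbeddingReranker();
+ * const topDocs = await reranker.rerank("what is a reranker?", hits, { topK: 4 });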
+ */
+ async rerank(query, documents, options = { topK: 4 }) {
+ await this.initClient();
+ const model = NativeEmbeddingReranker.#model;
+ const tokenizer = NativeEmbeddingReranker.#tokenizer;
+
+ const start = Date.now();
+ this.log(`Reranking ${documents.length} documents...`);
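+    // Cross-encoder input: the query is paired with every document's text so the
+    // model scores each (query, document) pair jointly.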
+ const inputs = tokenizer(new Array(documents.length).fill(query), {
+ text_pair: documents.map((doc) => doc.text),
+ padding: true,
+ truncation: true,
+ });
+ const { logits } = await model(inputs);
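+    // Each logit is squashed through a sigmoid into a (0, 1) relevance score; documents
+    // are then sorted by score descending and trimmed to the topK.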
+ const reranked = logits
+ .sigmoid()
+ .tolist()
+ .map(([score], i) => ({
+ rerank_corpus_id: i,
+ rerank_score: score,
+ ...documents[i],
+ }))
+ .sort((a, b) => b.rerank_score - a.rerank_score)
+ .slice(0, options.topK);
+
+ this.log(
+ `Reranking ${documents.length} documents to top ${options.topK} took ${Date.now() - start}ms`
+ );
+ return reranked;
+ }
+}
+
+module.exports = {
+ NativeEmbeddingReranker,
+};
diff --git a/server/utils/agents/aibitat/plugins/memory.js b/server/utils/agents/aibitat/plugins/memory.js
index 4f43d0ec460..df52843015f 100644
--- a/server/utils/agents/aibitat/plugins/memory.js
+++ b/server/utils/agents/aibitat/plugins/memory.js
@@ -95,6 +95,7 @@ const memory = {
input: query,
LLMConnector,
topN: workspace?.topN ?? 4,
+ rerank: workspace?.vectorSearchMode === "rerank",
});
if (contextTexts.length === 0) {
diff --git a/server/utils/chats/apiChatHandler.js b/server/utils/chats/apiChatHandler.js
index 7ba45fed62d..11421ea128e 100644
--- a/server/utils/chats/apiChatHandler.js
+++ b/server/utils/chats/apiChatHandler.js
@@ -180,6 +180,7 @@ async function chatSync({
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
+ rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
@@ -480,6 +481,7 @@ async function streamChat({
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
+ rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
diff --git a/server/utils/chats/embed.js b/server/utils/chats/embed.js
index 7196d161e2b..550a460f874 100644
--- a/server/utils/chats/embed.js
+++ b/server/utils/chats/embed.js
@@ -93,6 +93,7 @@ async function streamChatWithForEmbed(
similarityThreshold: embed.workspace?.similarityThreshold,
topN: embed.workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
+ rerank: embed.workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
diff --git a/server/utils/chats/openaiCompatible.js b/server/utils/chats/openaiCompatible.js
index a76347bf71e..fcae9782767 100644
--- a/server/utils/chats/openaiCompatible.js
+++ b/server/utils/chats/openaiCompatible.js
@@ -89,6 +89,7 @@ async function chatSync({
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
+ rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
@@ -304,6 +305,7 @@ async function streamChat({
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
+ rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js
index 35b0c191e6b..bd81f130898 100644
--- a/server/utils/chats/stream.js
+++ b/server/utils/chats/stream.js
@@ -139,6 +139,7 @@ async function streamChatWithWorkspace(
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
+ rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js
index fa47f9cf78e..544bd36ff10 100644
--- a/server/utils/helpers/index.js
+++ b/server/utils/helpers/index.js
@@ -56,6 +56,7 @@
* @property {Function} totalVectors - Returns the total number of vectors in the database.
* @property {Function} namespaceCount - Returns the count of vectors in a given namespace.
* @property {Function} similarityResponse - Performs a similarity search on a given namespace.
+ * @property {Function} rerankedSimilarityResponse - Performs a similarity search on a given namespace with reranking (if supported by provider).
* @property {Function} namespace - Retrieves the specified namespace collection.
* @property {Function} hasNamespace - Checks if a namespace exists.
* @property {Function} namespaceExists - Verifies if a namespace exists in the client.
diff --git a/server/utils/vectorDbProviders/lance/index.js b/server/utils/vectorDbProviders/lance/index.js
index 78a32b80c78..e3f285478b7 100644
--- a/server/utils/vectorDbProviders/lance/index.js
+++ b/server/utils/vectorDbProviders/lance/index.js
@@ -5,6 +5,7 @@ const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid");
const { sourceIdentifier } = require("../../chats");
+const { NativeEmbeddingReranker } = require("../../EmbeddingRerankers/native");
/**
* LanceDB Client connection object
@@ -57,6 +58,91 @@ const LanceDb = {
const table = await client.openTable(_namespace);
return (await table.countRows()) || 0;
},
+ /**
+ * Performs a SimilaritySearch + Reranking on a namespace.
+ * @param {Object} params - The parameters for the rerankedSimilarityResponse.
+ * @param {Object} params.client - The vectorDB client.
+ * @param {string} params.namespace - The namespace to search in.
+ * @param {string} params.query - The query to search for (plain text).
+ * @param {number[]} params.queryVector - The vector of the query.
+ * @param {number} params.similarityThreshold - The threshold for similarity.
+ * @param {number} params.topN - The number of results to return from this process.
+ * @param {string[]} params.filterIdentifiers - The identifiers of the documents to filter out.
+ * @returns {Promise<{contextTexts: string[], sourceDocuments: object[], scores: number[]}>} The reranked search results.
+ */
+ rerankedSimilarityResponse: async function ({
+ client,
+ namespace,
+ query,
+ queryVector,
+ topN = 4,
+ similarityThreshold = 0.25,
+ filterIdentifiers = [],
+ }) {
+ const reranker = new NativeEmbeddingReranker();
+ const collection = await client.openTable(namespace);
+ const totalEmbeddings = await this.namespaceCount(namespace);
+ const result = {
+ contextTexts: [],
+ sourceDocuments: [],
+ scores: [],
+ };
+
+ /**
+ * For reranking, we want to work with a larger candidate pool than topN, because the
+ * reranker can only reorder the results it is given; we don't auto-expand the results.
+ *
+ * However, we cannot make this pool boundless, as reranking is expensive and time-consuming.
+ * So we take 10% of the total embeddings, clamped to a maximum of 50 and a minimum of 10.
+ * This is a good balance between the number of results to rerank and the cost of reranking,
+ * and it ensures workspaces with 10K embeddings will still rerank within a reasonable
+ * timeframe on base-level hardware.
+ *
+ * Benchmarks:
+ * On Intel Mac: 2.6 GHz 6-Core Intel Core i7 - 20 docs reranked in ~5.2 sec
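+ *
+ * Worked examples of the clamp:
+ * 40 embeddings -> ceil(40 * 0.1) = 4, raised to the minimum of 10
+ * 320 embeddings -> ceil(320 * 0.1) = 32, within bounds, so 32
+ * 10,000 embeddings -> ceil(10000 * 0.1) = 1000, capped at the maximum of 50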
+ */
+ const searchLimit = Math.max(
+ 10,
+ Math.min(50, Math.ceil(totalEmbeddings * 0.1))
+ );
+ const vectorSearchResults = await collection
+ .vectorSearch(queryVector)
+ .distanceType("cosine")
+ .limit(searchLimit)
+ .toArray();
+
+ await reranker
+ .rerank(query, vectorSearchResults, { topK: topN })
+ .then((rerankResults) => {
+ rerankResults.forEach((item) => {
+ if (this.distanceToSimilarity(item._distance) < similarityThreshold)
+ return;
+ const { vector: _, ...rest } = item;
+ if (filterIdentifiers.includes(sourceIdentifier(rest))) {
+ console.log(
+ "LanceDB: A source was filtered from context as it's parent document is pinned."
+ );
+ return;
+ }
+ const score =
+ item?.rerank_score || this.distanceToSimilarity(item._distance);
+
+ result.contextTexts.push(rest.text);
+ result.sourceDocuments.push({
+ ...rest,
+ score,
+ });
+ result.scores.push(score);
+ });
+ })
+ .catch((e) => {
+ console.error(e);
+ console.error("LanceDB::rerankedSimilarityResponse", e.message);
+ });
+
+ return result;
+ },
+
/**
* Performs a SimilaritySearch on a given LanceDB namespace.
* @param {Object} params
@@ -300,6 +386,7 @@ const LanceDb = {
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
+ rerank = false,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
@@ -314,15 +401,26 @@ const LanceDb = {
}
const queryVector = await LLMConnector.embedTextInput(input);
- const { contextTexts, sourceDocuments } = await this.similarityResponse({
- client,
- namespace,
- queryVector,
- similarityThreshold,
- topN,
- filterIdentifiers,
- });
+ const result = rerank
+ ? await this.rerankedSimilarityResponse({
+ client,
+ namespace,
+ query: input,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ })
+ : await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ });
+ const { contextTexts, sourceDocuments } = result;
const sources = sourceDocuments.map((metadata, i) => {
return { metadata: { ...metadata, text: contextTexts[i] } };
});