diff --git a/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx b/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx
index 5e5816cda8d..f257156af96 100644
--- a/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx
+++ b/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx
@@ -1,8 +1,5 @@
import { useState } from "react";
-// We dont support all vectorDBs yet for reranking due to complexities of how each provider
-// returns information. We need to normalize the response data so Reranker can be used for each provider.
-const supportedVectorDBs = ["lancedb"];
const hint = {
default: {
title: "Default",
@@ -20,8 +17,7 @@ export default function VectorSearchMode({ workspace, setHasChanges }) {
const [selection, setSelection] = useState(
workspace?.vectorSearchMode ?? "default"
);
- if (!workspace?.vectorDB || !supportedVectorDBs.includes(workspace?.vectorDB))
- return null;
+ if (!workspace?.vectorDB) return null;
return (
diff --git a/server/utils/EmbeddingRerankers/rerank.js b/server/utils/EmbeddingRerankers/rerank.js
new file mode 100644
index 00000000000..8bbe588d54e
--- /dev/null
+++ b/server/utils/EmbeddingRerankers/rerank.js
@@ -0,0 +1,29 @@
+const { getRerankerProvider } = require("../helpers");
+
+async function rerank(query, documents, topN = 4) {
+ const reranker = getRerankerProvider();
+ return await reranker.rerank(query, documents, { topK: topN });
+}
+
+/**
+ * For reranking, we want to work with a larger number of results than the topN.
+ * This is because the reranker can only rerank the results it is given and we dont auto-expand the results.
+ * We want to give the reranker a larger number of results to work with.
+ *
+ * However, we cannot make this boundless as reranking is expensive and time consuming.
+ * So we limit the number of results to a maximum of 50 and a minimum of 10.
+ * This is a good balance between the number of results to rerank and the cost of reranking
+ * and ensures workspaces with 10K embeddings will still rerank within a reasonable timeframe on base level hardware.
+ *
+ * Benchmarks:
+ * On Intel Mac: 2.6 GHz 6-Core Intel Core i7 - 20 docs reranked in ~5.2 sec
+ */
+
+function getSearchLimit(totalEmbeddings = 0) {
+ return Math.max(10, Math.min(50, Math.ceil(totalEmbeddings * 0.1)));
+}
+
+module.exports = {
+ rerank,
+ getSearchLimit,
+};
diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js
index 819a464c6d0..494943cf97d 100644
--- a/server/utils/helpers/index.js
+++ b/server/utils/helpers/index.js
@@ -75,6 +75,11 @@
* @property {Function} embedChunks - Embeds multiple chunks of text.
*/
+/**
+ * @typedef {Object} BaseRerankerProvider
+ * @property {function(string, {text: string}[], {topK: number}): Promise<any[]>} rerank - Reranks a list of documents.
+ */
+
/**
* Gets the systems current vector database provider.
* @param {('pinecone' | 'chroma' | 'chromacloud' | 'lancedb' | 'weaviate' | 'qdrant' | 'milvus' | 'zilliz' | 'astra') | null} getExactly - If provided, this will return an explit provider.
@@ -471,6 +476,29 @@ function toChunks(arr, size) {
);
}
+/**
+ * Returns the Reranker provider.
+ * @returns {BaseRerankerProvider}
+ */
+function getRerankerProvider() {
+ const rerankerSelection = process.env.RERANKING_PROVIDER ?? "native";
+ switch (rerankerSelection) {
+ case "native":
+ const {
+ NativeEmbeddingReranker,
+ } = require("../EmbeddingRerankers/native");
+ return new NativeEmbeddingReranker();
+ default:
+ console.log(
+ `[RERANKING] Reranker provider ${rerankerSelection} is not supported. Using native reranker as fallback.`
+ );
+ const {
+ NativeEmbeddingReranker: Native,
+ } = require("../EmbeddingRerankers/native");
+ return new Native();
+ }
+}
+
module.exports = {
getEmbeddingEngineSelection,
maximumChunkLength,
@@ -479,4 +507,5 @@ module.exports = {
getBaseLLMProviderModel,
getLLMProvider,
toChunks,
+ getRerankerProvider,
};
diff --git a/server/utils/vectorDbProviders/astra/index.js b/server/utils/vectorDbProviders/astra/index.js
index b34a8d83afa..783e6340733 100644
--- a/server/utils/vectorDbProviders/astra/index.js
+++ b/server/utils/vectorDbProviders/astra/index.js
@@ -5,6 +5,7 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats");
+const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank");
const sanitizeNamespace = (namespace) => {
// If namespace already starts with ns_, don't add it again
@@ -301,6 +302,7 @@ const AstraDB = {
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
+ rerank = false,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
@@ -319,17 +321,27 @@ const AstraDB = {
}
const queryVector = await LLMConnector.embedTextInput(input);
- const { contextTexts, sourceDocuments } = await this.similarityResponse({
- client,
- namespace: sanitizedNamespace,
- queryVector,
- similarityThreshold,
- topN,
- filterIdentifiers,
- });
+ const { contextTexts, sourceDocuments } = rerank
+ ? await this.rerankedSimilarityResponse({
+ client,
+ namespace: sanitizedNamespace,
+ query: input,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ })
+ : await this.similarityResponse({
+ client,
+ namespace: sanitizedNamespace,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ });
- const sources = sourceDocuments.map((metadata, i) => {
- return { ...metadata, text: contextTexts[i] };
+ const sources = sourceDocuments.map((doc, i) => {
+ return { metadata: doc, text: contextTexts[i] };
});
return {
contextTexts,
@@ -373,11 +385,55 @@ const AstraDB = {
return;
}
result.contextTexts.push(response.metadata.text);
- result.sourceDocuments.push(response);
+ result.sourceDocuments.push({
+ ...response.metadata,
+ score: response.$similarity,
+ });
result.scores.push(response.$similarity);
});
return result;
},
+ rerankedSimilarityResponse: async function ({
+ client,
+ namespace,
+ query,
+ queryVector,
+ topN = 4,
+ similarityThreshold = 0.25,
+ filterIdentifiers = [],
+ }) {
+ const totalEmbeddings = await this.namespaceCount(namespace);
+ const searchLimit = getSearchLimit(totalEmbeddings);
+ const { sourceDocuments } = await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN: searchLimit,
+ filterIdentifiers,
+ });
+
+ const rerankedResults = await rerank(query, sourceDocuments, topN);
+ const result = {
+ contextTexts: [],
+ sourceDocuments: [],
+ scores: [],
+ };
+
+ rerankedResults.forEach((item) => {
+ if (item.rerank_score < similarityThreshold) return;
+ const { rerank_score, ...rest } = item;
+ if (filterIdentifiers.includes(sourceIdentifier(rest))) return;
+
+ result.contextTexts.push(rest.text);
+ result.sourceDocuments.push({
+ ...rest,
+ score: rerank_score,
+ });
+ result.scores.push(rerank_score);
+ });
+ return result;
+ },
allNamespaces: async function (client) {
try {
let header = new Headers();
@@ -432,12 +488,11 @@ const AstraDB = {
curateSources: function (sources = []) {
const documents = [];
for (const source of sources) {
- if (Object.keys(source).length > 0) {
- const metadata = source.hasOwnProperty("metadata")
- ? source.metadata
- : source;
+ const { metadata = {} } = source;
+ if (Object.keys(metadata).length > 0) {
documents.push({
...metadata,
+ ...(source.text ? { text: source.text } : {}),
});
}
}
diff --git a/server/utils/vectorDbProviders/chroma/index.js b/server/utils/vectorDbProviders/chroma/index.js
index bc12818fd18..a8aa09018b0 100644
--- a/server/utils/vectorDbProviders/chroma/index.js
+++ b/server/utils/vectorDbProviders/chroma/index.js
@@ -6,6 +6,7 @@ const { v4: uuidv4 } = require("uuid");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { parseAuthHeader } = require("../../http");
const { sourceIdentifier } = require("../../chats");
+const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank");
const COLLECTION_REGEX = new RegExp(
/^(?!\d+\.\d+\.\d+\.\d+$)(?!.*\.\.)(?=^[a-zA-Z0-9][a-zA-Z0-9_-]{1,61}[a-zA-Z0-9]$).{3,63}$/
);
@@ -150,6 +151,51 @@ const Chroma = {
return result;
},
+ rerankedSimilarityResponse: async function ({
+ client,
+ namespace,
+ query,
+ queryVector,
+ topN = 4,
+ similarityThreshold = 0.25,
+ filterIdentifiers = [],
+ }) {
+ const totalEmbeddings = await this.namespaceCount(namespace);
+ const searchLimit = getSearchLimit(totalEmbeddings);
+ const { sourceDocuments, contextTexts } = await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN: searchLimit,
+ filterIdentifiers,
+ });
+ const documentsForReranking = sourceDocuments.map((metadata, i) => ({
+ ...metadata,
+ text: contextTexts[i],
+ }));
+
+ const rerankedResults = await rerank(query, documentsForReranking, topN);
+ const result = {
+ contextTexts: [],
+ sourceDocuments: [],
+ scores: [],
+ };
+
+ rerankedResults.forEach((item) => {
+ if (item.rerank_score < similarityThreshold) return;
+ const { vector: _, rerank_score, ...rest } = item;
+ if (filterIdentifiers.includes(sourceIdentifier(rest))) return;
+
+ result.contextTexts.push(rest.text);
+ result.sourceDocuments.push({
+ ...rest,
+ score: rerank_score,
+ });
+ result.scores.push(rerank_score);
+ });
+ return result;
+ },
namespace: async function (client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided.");
const collection = await client
@@ -348,12 +394,14 @@ const Chroma = {
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
+ rerank = false,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
const { client } = await this.connect();
- if (!(await this.namespaceExists(client, this.normalize(namespace)))) {
+ const collectionName = this.normalize(namespace);
+ if (!(await this.namespaceExists(client, collectionName))) {
return {
contextTexts: [],
sources: [],
@@ -362,16 +410,26 @@ const Chroma = {
}
const queryVector = await LLMConnector.embedTextInput(input);
- const { contextTexts, sourceDocuments, scores } =
- await this.similarityResponse({
- client,
- namespace,
- queryVector,
- similarityThreshold,
- topN,
- filterIdentifiers,
- });
+ const result = rerank
+ ? await this.rerankedSimilarityResponse({
+ client,
+ namespace,
+ query: input,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ })
+ : await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ });
+ const { contextTexts, sourceDocuments, scores } = result;
const sources = sourceDocuments.map((metadata, i) => ({
metadata: {
...metadata,
diff --git a/server/utils/vectorDbProviders/lance/index.js b/server/utils/vectorDbProviders/lance/index.js
index 563095fe5db..adda0cefc44 100644
--- a/server/utils/vectorDbProviders/lance/index.js
+++ b/server/utils/vectorDbProviders/lance/index.js
@@ -5,7 +5,7 @@ const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid");
const { sourceIdentifier } = require("../../chats");
-const { NativeEmbeddingReranker } = require("../../EmbeddingRerankers/native");
+const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank");
/**
* LancedDB Client connection object
@@ -79,67 +79,43 @@ const LanceDb = {
similarityThreshold = 0.25,
filterIdentifiers = [],
}) {
- const reranker = new NativeEmbeddingReranker();
- const collection = await client.openTable(namespace);
const totalEmbeddings = await this.namespaceCount(namespace);
+ const searchLimit = getSearchLimit(totalEmbeddings);
+ const vectorSearchResults = await client
+ .openTable(namespace)
+ .then((tbl) =>
+ tbl
+ .vectorSearch(queryVector)
+ .distanceType("cosine")
+ .limit(searchLimit)
+ .toArray()
+ );
+
+ const rerankedResults = await rerank(query, vectorSearchResults, topN);
const result = {
contextTexts: [],
sourceDocuments: [],
scores: [],
};
- /**
- * For reranking, we want to work with a larger number of results than the topN.
- * This is because the reranker can only rerank the results it it given and we dont auto-expand the results.
- * We want to give the reranker a larger number of results to work with.
- *
- * However, we cannot make this boundless as reranking is expensive and time consuming.
- * So we limit the number of results to a maximum of 50 and a minimum of 10.
- * This is a good balance between the number of results to rerank and the cost of reranking
- * and ensures workspaces with 10K embeddings will still rerank within a reasonable timeframe on base level hardware.
- *
- * Benchmarks:
- * On Intel Mac: 2.6 GHz 6-Core Intel Core i7 - 20 docs reranked in ~5.2 sec
- */
- const searchLimit = Math.max(
- 10,
- Math.min(50, Math.ceil(totalEmbeddings * 0.1))
- );
- const vectorSearchResults = await collection
- .vectorSearch(queryVector)
- .distanceType("cosine")
- .limit(searchLimit)
- .toArray();
-
- await reranker
- .rerank(query, vectorSearchResults, { topK: topN })
- .then((rerankResults) => {
- rerankResults.forEach((item) => {
- if (this.distanceToSimilarity(item._distance) < similarityThreshold)
- return;
- const { vector: _, ...rest } = item;
- if (filterIdentifiers.includes(sourceIdentifier(rest))) {
- console.log(
- "LanceDB: A source was filtered from context as it's parent document is pinned."
- );
- return;
- }
- const score =
- item?.rerank_score || this.distanceToSimilarity(item._distance);
+ rerankedResults.forEach((item) => {
+ if (item.rerank_score < similarityThreshold) return;
+ const { vector: _, ...rest } = item;
+ if (filterIdentifiers.includes(sourceIdentifier(rest))) {
+ console.log(
+ "LanceDB: A source was filtered from context as it's parent document is pinned."
+ );
+ return;
+ }
+ const score = item.rerank_score;
- result.contextTexts.push(rest.text);
- result.sourceDocuments.push({
- ...rest,
- score,
- });
- result.scores.push(score);
- });
- })
- .catch((e) => {
- console.error(e);
- console.error("LanceDB::rerankedSimilarityResponse", e.message);
+ result.contextTexts.push(rest.text);
+ result.sourceDocuments.push({
+ ...rest,
+ score,
});
-
+ result.scores.push(score);
+ });
return result;
},
diff --git a/server/utils/vectorDbProviders/milvus/index.js b/server/utils/vectorDbProviders/milvus/index.js
index 2ddaad567bb..b1f6800c23f 100644
--- a/server/utils/vectorDbProviders/milvus/index.js
+++ b/server/utils/vectorDbProviders/milvus/index.js
@@ -10,6 +10,7 @@ const { v4: uuidv4 } = require("uuid");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats");
+const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank");
const Milvus = {
name: "Milvus",
@@ -299,6 +300,7 @@ const Milvus = {
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
+ rerank = false,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
@@ -313,14 +315,24 @@ const Milvus = {
}
const queryVector = await LLMConnector.embedTextInput(input);
- const { contextTexts, sourceDocuments } = await this.similarityResponse({
- client,
- namespace,
- queryVector,
- similarityThreshold,
- topN,
- filterIdentifiers,
- });
+ const { contextTexts, sourceDocuments } = rerank
+ ? await this.rerankedSimilarityResponse({
+ client,
+ namespace,
+ query: input,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ })
+ : await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ });
const sources = sourceDocuments.map((doc, i) => {
return { metadata: doc, text: contextTexts[i] };
@@ -368,6 +380,47 @@ const Milvus = {
});
return result;
},
+ rerankedSimilarityResponse: async function ({
+ client,
+ namespace,
+ query,
+ queryVector,
+ topN = 4,
+ similarityThreshold = 0.25,
+ filterIdentifiers = [],
+ }) {
+ const totalEmbeddings = await this.namespaceCount(namespace);
+ const searchLimit = getSearchLimit(totalEmbeddings);
+ const { sourceDocuments } = await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN: searchLimit,
+ filterIdentifiers,
+ });
+
+ const rerankedResults = await rerank(query, sourceDocuments, topN);
+ const result = {
+ contextTexts: [],
+ sourceDocuments: [],
+ scores: [],
+ };
+
+ rerankedResults.forEach((item) => {
+ if (item.rerank_score < similarityThreshold) return;
+ const { rerank_score, ...rest } = item;
+ if (filterIdentifiers.includes(sourceIdentifier(rest))) return;
+
+ result.contextTexts.push(rest.text);
+ result.sourceDocuments.push({
+ ...rest,
+ score: rerank_score,
+ });
+ result.scores.push(rerank_score);
+ });
+ return result;
+ },
"namespace-stats": async function (reqBody = {}) {
const { namespace = null } = reqBody;
if (!namespace) throw new Error("namespace required");
diff --git a/server/utils/vectorDbProviders/pgvector/index.js b/server/utils/vectorDbProviders/pgvector/index.js
index 990498eb5cc..67245fd0a30 100644
--- a/server/utils/vectorDbProviders/pgvector/index.js
+++ b/server/utils/vectorDbProviders/pgvector/index.js
@@ -3,6 +3,7 @@ const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { TextSplitter } = require("../../TextSplitter");
const { v4: uuidv4 } = require("uuid");
const { sourceIdentifier } = require("../../chats");
+const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank");
/*
Embedding Table Schema (table name defined by user)
@@ -207,29 +208,31 @@ const PGVector = {
}, PGVector.connectionTimeout);
});
- const connectionPromise = new Promise(async (resolve) => {
- let pgClient = null;
- try {
- pgClient = this.client(connectionString);
- await pgClient.connect();
- const result = await pgClient.query(this.getTablesSql);
-
- if (result.rows.length !== 0 && !!tableName) {
- const tableExists = result.rows.some(
- (row) => row.tablename === tableName
- );
- if (tableExists)
- await this.validateExistingEmbeddingTableSchema(
- pgClient,
- tableName
+ const connectionPromise = new Promise((resolve) => {
+ (async () => {
+ let pgClient = null;
+ try {
+ pgClient = this.client(connectionString);
+ await pgClient.connect();
+ const result = await pgClient.query(this.getTablesSql);
+
+ if (result.rows.length !== 0 && !!tableName) {
+ const tableExists = result.rows.some(
+ (row) => row.tablename === tableName
);
+ if (tableExists)
+ await this.validateExistingEmbeddingTableSchema(
+ pgClient,
+ tableName
+ );
+ }
+ resolve({ error: null, success: true });
+ } catch (err) {
+ resolve({ error: err.message, success: false });
+ } finally {
+ if (pgClient) await pgClient.end();
}
- resolve({ error: null, success: true });
- } catch (err) {
- resolve({ error: err.message, success: false });
- } finally {
- if (pgClient) await pgClient.end();
- }
+ })();
});
// Race the connection attempt against the timeout
@@ -401,6 +404,48 @@ const PGVector = {
return result;
},
+ rerankedSimilarityResponse: async function ({
+ client,
+ namespace,
+ query,
+ queryVector,
+ topN = 4,
+ similarityThreshold = 0.25,
+ filterIdentifiers = [],
+ }) {
+ const totalEmbeddings = await this.namespaceCount(namespace);
+ const searchLimit = getSearchLimit(totalEmbeddings);
+ const { sourceDocuments } = await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN: searchLimit,
+ filterIdentifiers,
+ });
+
+ const rerankedResults = await rerank(query, sourceDocuments, topN);
+ const result = {
+ contextTexts: [],
+ sourceDocuments: [],
+ scores: [],
+ };
+
+ rerankedResults.forEach((item) => {
+ if (item.rerank_score < similarityThreshold) return;
+ const { rerank_score, ...rest } = item;
+ if (filterIdentifiers.includes(sourceIdentifier(rest))) return;
+
+ result.contextTexts.push(rest.text);
+ result.sourceDocuments.push({
+ ...rest,
+ score: rerank_score,
+ });
+ result.scores.push(rerank_score);
+ });
+ return result;
+ },
+
normalizeVector: function (vector) {
const magnitude = Math.sqrt(
vector.reduce((sum, val) => sum + val * val, 0)
@@ -707,6 +752,7 @@ const PGVector = {
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
+ rerank = false,
}) {
let connection = null;
if (!namespace || !input || !LLMConnector)
@@ -727,16 +773,25 @@ const PGVector = {
}
const queryVector = await LLMConnector.embedTextInput(input);
- const result = await this.similarityResponse({
- client: connection,
- namespace,
- queryVector,
- similarityThreshold,
- topN,
- filterIdentifiers,
- });
+ const { contextTexts, sourceDocuments } = rerank
+ ? await this.rerankedSimilarityResponse({
+ client: connection,
+ namespace,
+ query: input,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ })
+ : await this.similarityResponse({
+ client: connection,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ });
- const { contextTexts, sourceDocuments } = result;
const sources = sourceDocuments.map((metadata, i) => {
return { metadata: { ...metadata, text: contextTexts[i] } };
});
diff --git a/server/utils/vectorDbProviders/pinecone/index.js b/server/utils/vectorDbProviders/pinecone/index.js
index c5c55acb58c..9335c90dbcd 100644
--- a/server/utils/vectorDbProviders/pinecone/index.js
+++ b/server/utils/vectorDbProviders/pinecone/index.js
@@ -5,6 +5,7 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats");
+const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank");
const PineconeDB = {
name: "Pinecone",
@@ -76,6 +77,47 @@ const PineconeDB = {
return result;
},
+ rerankedSimilarityResponse: async function ({
+ client,
+ namespace,
+ query,
+ queryVector,
+ topN = 4,
+ similarityThreshold = 0.25,
+ filterIdentifiers = [],
+ }) {
+ const totalEmbeddings = await this.namespaceCount(namespace);
+ const searchLimit = getSearchLimit(totalEmbeddings);
+ const { sourceDocuments } = await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN: searchLimit,
+ filterIdentifiers,
+ });
+
+ const rerankedResults = await rerank(query, sourceDocuments, topN);
+ const result = {
+ contextTexts: [],
+ sourceDocuments: [],
+ scores: [],
+ };
+
+ rerankedResults.forEach((item) => {
+ if (item.rerank_score < similarityThreshold) return;
+ const { rerank_score, ...rest } = item;
+ if (filterIdentifiers.includes(sourceIdentifier(rest))) return;
+
+ result.contextTexts.push(rest.text);
+ result.sourceDocuments.push({
+ ...rest,
+ score: rerank_score,
+ });
+ result.scores.push(rerank_score);
+ });
+ return result;
+ },
namespace: async function (index, namespace = null) {
if (!namespace) throw new Error("No namespace value provided.");
const { namespaces } = await index.describeIndexStats();
@@ -247,6 +289,7 @@ const PineconeDB = {
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
+ rerank = false,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
@@ -258,14 +301,24 @@ const PineconeDB = {
);
const queryVector = await LLMConnector.embedTextInput(input);
- const { contextTexts, sourceDocuments } = await this.similarityResponse({
- client: pineconeIndex,
- namespace,
- queryVector,
- similarityThreshold,
- topN,
- filterIdentifiers,
- });
+ const { contextTexts, sourceDocuments } = rerank
+ ? await this.rerankedSimilarityResponse({
+ client: pineconeIndex,
+ namespace,
+ query: input,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ })
+ : await this.similarityResponse({
+ client: pineconeIndex,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ });
const sources = sourceDocuments.map((doc, i) => {
return { metadata: doc, text: contextTexts[i] };
diff --git a/server/utils/vectorDbProviders/qdrant/index.js b/server/utils/vectorDbProviders/qdrant/index.js
index 95e347b274e..e8831ea854e 100644
--- a/server/utils/vectorDbProviders/qdrant/index.js
+++ b/server/utils/vectorDbProviders/qdrant/index.js
@@ -5,6 +5,7 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats");
+const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank");
const QDrant = {
name: "QDrant",
@@ -80,12 +81,54 @@ const QDrant = {
result.sourceDocuments.push({
...(response?.payload || {}),
id: response.id,
+ score: response.score,
});
result.scores.push(response.score);
});
return result;
},
+ rerankedSimilarityResponse: async function ({
+ client,
+ namespace,
+ query,
+ queryVector,
+ topN = 4,
+ similarityThreshold = 0.25,
+ filterIdentifiers = [],
+ }) {
+ const totalEmbeddings = await this.namespaceCount(namespace);
+ const searchLimit = getSearchLimit(totalEmbeddings);
+ const { sourceDocuments } = await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN: searchLimit,
+ filterIdentifiers,
+ });
+
+ const rerankedResults = await rerank(query, sourceDocuments, topN);
+ const result = {
+ contextTexts: [],
+ sourceDocuments: [],
+ scores: [],
+ };
+
+ rerankedResults.forEach((item) => {
+ if (item.rerank_score < similarityThreshold) return;
+ const { rerank_score, ...rest } = item;
+ if (filterIdentifiers.includes(sourceIdentifier(rest))) return;
+
+ result.contextTexts.push(rest.text);
+ result.sourceDocuments.push({
+ ...rest,
+ score: rerank_score,
+ });
+ result.scores.push(rerank_score);
+ });
+ return result;
+ },
namespace: async function (client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided.");
const collection = await client.getCollection(namespace).catch(() => null);
@@ -334,6 +377,7 @@ const QDrant = {
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
+ rerank = false,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
@@ -348,17 +392,27 @@ const QDrant = {
}
const queryVector = await LLMConnector.embedTextInput(input);
- const { contextTexts, sourceDocuments } = await this.similarityResponse({
- client,
- namespace,
- queryVector,
- similarityThreshold,
- topN,
- filterIdentifiers,
- });
+ const { contextTexts, sourceDocuments } = rerank
+ ? await this.rerankedSimilarityResponse({
+ client,
+ namespace,
+ query: input,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ })
+ : await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ });
- const sources = sourceDocuments.map((metadata, i) => {
- return { ...metadata, text: contextTexts[i] };
+ const sources = sourceDocuments.map((doc, i) => {
+ return { metadata: doc, text: contextTexts[i] };
});
return {
contextTexts,
@@ -400,12 +454,11 @@ const QDrant = {
curateSources: function (sources = []) {
const documents = [];
for (const source of sources) {
- if (Object.keys(source).length > 0) {
- const metadata = source.hasOwnProperty("metadata")
- ? source.metadata
- : source;
+ const { metadata = {} } = source;
+ if (Object.keys(metadata).length > 0) {
documents.push({
...metadata,
+ ...(source.text ? { text: source.text } : {}),
});
}
}
diff --git a/server/utils/vectorDbProviders/weaviate/index.js b/server/utils/vectorDbProviders/weaviate/index.js
index 2385c5e8ef1..6a475380f7d 100644
--- a/server/utils/vectorDbProviders/weaviate/index.js
+++ b/server/utils/vectorDbProviders/weaviate/index.js
@@ -6,6 +6,7 @@ const { v4: uuidv4 } = require("uuid");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { camelCase } = require("../../helpers/camelcase");
const { sourceIdentifier } = require("../../chats");
+const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank");
const Weaviate = {
name: "Weaviate",
@@ -115,12 +116,53 @@ const Weaviate = {
return;
}
result.contextTexts.push(rest.text);
- result.sourceDocuments.push({ ...rest, id });
+ result.sourceDocuments.push({ ...rest, id, score: certainty });
result.scores.push(certainty);
});
return result;
},
+ rerankedSimilarityResponse: async function ({
+ client,
+ namespace,
+ query,
+ queryVector,
+ topN = 4,
+ similarityThreshold = 0.25,
+ filterIdentifiers = [],
+ }) {
+ const totalEmbeddings = await this.namespaceCount(namespace);
+ const searchLimit = getSearchLimit(totalEmbeddings);
+ const { sourceDocuments } = await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN: searchLimit,
+ filterIdentifiers,
+ });
+
+ const rerankedResults = await rerank(query, sourceDocuments, topN);
+ const result = {
+ contextTexts: [],
+ sourceDocuments: [],
+ scores: [],
+ };
+
+ rerankedResults.forEach((item) => {
+ if (item.rerank_score < similarityThreshold) return;
+ const { rerank_score, ...rest } = item;
+ if (filterIdentifiers.includes(sourceIdentifier(rest))) return;
+
+ result.contextTexts.push(rest.text);
+ result.sourceDocuments.push({
+ ...rest,
+ score: rerank_score,
+ });
+ result.scores.push(rerank_score);
+ });
+ return result;
+ },
allNamespaces: async function (client) {
try {
const { classes = [] } = await client.schema.getter().do();
@@ -368,6 +410,7 @@ const Weaviate = {
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
+ rerank = false,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
@@ -382,17 +425,27 @@ const Weaviate = {
}
const queryVector = await LLMConnector.embedTextInput(input);
- const { contextTexts, sourceDocuments } = await this.similarityResponse({
- client,
- namespace,
- queryVector,
- similarityThreshold,
- topN,
- filterIdentifiers,
- });
-
- const sources = sourceDocuments.map((metadata, i) => {
- return { ...metadata, text: contextTexts[i] };
+ const { contextTexts, sourceDocuments } = rerank
+ ? await this.rerankedSimilarityResponse({
+ client,
+ namespace,
+ query: input,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ })
+ : await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ });
+
+ const sources = sourceDocuments.map((doc, i) => {
+ return { metadata: doc, text: contextTexts[i] };
});
return {
contextTexts,
@@ -431,11 +484,12 @@ const Weaviate = {
curateSources: function (sources = []) {
const documents = [];
for (const source of sources) {
- if (Object.keys(source).length > 0) {
- const metadata = source.hasOwnProperty("metadata")
- ? source.metadata
- : source;
- documents.push({ ...metadata });
+ const { metadata = {} } = source;
+ if (Object.keys(metadata).length > 0) {
+ documents.push({
+ ...metadata,
+ ...(source.text ? { text: source.text } : {}),
+ });
}
}
diff --git a/server/utils/vectorDbProviders/zilliz/index.js b/server/utils/vectorDbProviders/zilliz/index.js
index ab866f4edd5..189d2eb85d6 100644
--- a/server/utils/vectorDbProviders/zilliz/index.js
+++ b/server/utils/vectorDbProviders/zilliz/index.js
@@ -10,6 +10,7 @@ const { v4: uuidv4 } = require("uuid");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats");
+const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank");
// Zilliz is basically a copy of Milvus DB class with a different constructor
// to connect to the cloud
@@ -157,30 +158,38 @@ const Zilliz = {
vectorDimension = chunks[0][0].values.length || null;
await this.getOrCreateCollection(client, namespace, vectorDimension);
- for (const chunk of chunks) {
- // Before sending to Pinecone and saving the records to our db
- // we need to assign the id of each chunk that is stored in the cached file.
- const newChunks = chunk.map((chunk) => {
- const id = uuidv4();
- documentVectors.push({ docId, vectorId: id });
- return { id, vector: chunk.values, metadata: chunk.metadata };
- });
- const insertResult = await client.insert({
- collection_name: this.normalize(namespace),
- data: newChunks,
- });
+ try {
+ for (const chunk of chunks) {
+ // Before sending to Zilliz and saving the records to our db
+ // we need to assign the id of each chunk that is stored in the cached file.
+ const newChunks = chunk.map((chunk) => {
+ const id = uuidv4();
+ documentVectors.push({ docId, vectorId: id });
+ return { id, vector: chunk.values, metadata: chunk.metadata };
+ });
+ const insertResult = await client.insert({
+ collection_name: this.normalize(namespace),
+ data: newChunks,
+ });
- if (insertResult?.status.error_code !== "Success") {
- throw new Error(
- `Error embedding into Zilliz! Reason:${insertResult?.status.reason}`
- );
+ if (insertResult?.status.error_code !== "Success") {
+ throw new Error(
+ `Error embedding into Milvus! Reason:${insertResult?.status.reason}`
+ );
+ }
}
+ await DocumentVectors.bulkInsert(documentVectors);
+ await client.flushSync({
+ collection_names: [this.normalize(namespace)],
+ });
+ return { vectorized: true, error: null };
+ } catch (insertError) {
+ console.error(
+ "Error inserting cached chunks:",
+ insertError.message
+ );
+ return { vectorized: false, error: insertError.message };
}
- await DocumentVectors.bulkInsert(documentVectors);
- await client.flushSync({
- collection_names: [this.normalize(namespace)],
- });
- return { vectorized: true, error: null };
}
}
@@ -239,7 +248,7 @@ const Zilliz = {
data: chunk.map((item) => ({
id: item.id,
vector: item.values,
- metadata: chunk.metadata,
+ metadata: item.metadata,
})),
});
@@ -292,6 +301,7 @@ const Zilliz = {
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
+ rerank = false,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
@@ -306,18 +316,29 @@ const Zilliz = {
}
const queryVector = await LLMConnector.embedTextInput(input);
- const { contextTexts, sourceDocuments } = await this.similarityResponse({
- client,
- namespace,
- queryVector,
- similarityThreshold,
- topN,
- filterIdentifiers,
- });
+ const { contextTexts, sourceDocuments } = rerank
+ ? await this.rerankedSimilarityResponse({
+ client,
+ namespace,
+ query: input,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ })
+ : await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN,
+ filterIdentifiers,
+ });
const sources = sourceDocuments.map((doc, i) => {
return { metadata: doc, text: contextTexts[i] };
});
+
return {
contextTexts,
sources: this.curateSources(sources),
@@ -350,6 +371,7 @@ const Zilliz = {
);
return;
}
+
result.contextTexts.push(match.metadata.text);
result.sourceDocuments.push({
...match.metadata,
@@ -359,6 +381,47 @@ const Zilliz = {
});
return result;
},
+ rerankedSimilarityResponse: async function ({
+ client,
+ namespace,
+ query,
+ queryVector,
+ topN = 4,
+ similarityThreshold = 0.25,
+ filterIdentifiers = [],
+ }) {
+ const totalEmbeddings = await this.namespaceCount(namespace);
+ const searchLimit = getSearchLimit(totalEmbeddings);
+ const { sourceDocuments } = await this.similarityResponse({
+ client,
+ namespace,
+ queryVector,
+ similarityThreshold,
+ topN: searchLimit,
+ filterIdentifiers,
+ });
+
+ const rerankedResults = await rerank(query, sourceDocuments, topN);
+ const result = {
+ contextTexts: [],
+ sourceDocuments: [],
+ scores: [],
+ };
+
+ rerankedResults.forEach((item) => {
+ if (item.rerank_score < similarityThreshold) return;
+ const { rerank_score, ...rest } = item;
+ if (filterIdentifiers.includes(sourceIdentifier(rest))) return;
+
+ result.contextTexts.push(rest.text);
+ result.sourceDocuments.push({
+ ...rest,
+ score: rerank_score,
+ });
+ result.scores.push(rerank_score);
+ });
+ return result;
+ },
"namespace-stats": async function (reqBody = {}) {
const { namespace = null } = reqBody;
if (!namespace) throw new Error("namespace required");
@@ -394,7 +457,6 @@ const Zilliz = {
});
}
}
-
return documents;
},
};