diff --git a/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx b/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx index 5e5816cda8d..f257156af96 100644 --- a/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx +++ b/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx @@ -1,8 +1,5 @@ import { useState } from "react"; -// We dont support all vectorDBs yet for reranking due to complexities of how each provider -// returns information. We need to normalize the response data so Reranker can be used for each provider. -const supportedVectorDBs = ["lancedb"]; const hint = { default: { title: "Default", @@ -20,8 +17,7 @@ export default function VectorSearchMode({ workspace, setHasChanges }) { const [selection, setSelection] = useState( workspace?.vectorSearchMode ?? "default" ); - if (!workspace?.vectorDB || !supportedVectorDBs.includes(workspace?.vectorDB)) - return null; + if (!workspace?.vectorDB) return null; return (
diff --git a/server/utils/EmbeddingRerankers/rerank.js b/server/utils/EmbeddingRerankers/rerank.js new file mode 100644 index 00000000000..8bbe588d54e --- /dev/null +++ b/server/utils/EmbeddingRerankers/rerank.js @@ -0,0 +1,29 @@ +const { getRerankerProvider } = require("../helpers"); + +async function rerank(query, documents, topN = 4) { + const reranker = getRerankerProvider(); + return await reranker.rerank(query, documents, { topK: topN }); +} + +/** + * For reranking, we want to work with a larger number of results than the topN. + * This is because the reranker can only rerank the results it it given and we dont auto-expand the results. + * We want to give the reranker a larger number of results to work with. + * + * However, we cannot make this boundless as reranking is expensive and time consuming. + * So we limit the number of results to a maximum of 50 and a minimum of 10. + * This is a good balance between the number of results to rerank and the cost of reranking + * and ensures workspaces with 10K embeddings will still rerank within a reasonable timeframe on base level hardware. + * + * Benchmarks: + * On Intel Mac: 2.6 GHz 6-Core Intel Core i7 - 20 docs reranked in ~5.2 sec + */ + +function getSearchLimit(totalEmbeddings = 0) { + return Math.max(10, Math.min(50, Math.ceil(totalEmbeddings * 0.1))); +} + +module.exports = { + rerank, + getSearchLimit, +}; diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js index 819a464c6d0..494943cf97d 100644 --- a/server/utils/helpers/index.js +++ b/server/utils/helpers/index.js @@ -75,6 +75,11 @@ * @property {Function} embedChunks - Embeds multiple chunks of text. */ +/** + * @typedef {Object} BaseRerankerProvider + * @property {function(string, {text: string}[], {topK: number}): Promise} rerank - Reranks a list of documents. + */ + /** * Gets the systems current vector database provider. 
* @param {('pinecone' | 'chroma' | 'chromacloud' | 'lancedb' | 'weaviate' | 'qdrant' | 'milvus' | 'zilliz' | 'astra') | null} getExactly - If provided, this will return an explit provider. @@ -471,6 +476,29 @@ function toChunks(arr, size) { ); } +/** + * Returns the Reranker provider. + * @returns {BaseRerankerProvider} + */ +function getRerankerProvider() { + const rerankerSelection = process.env.RERANKING_PROVIDER ?? "native"; + switch (rerankerSelection) { + case "native": + const { + NativeEmbeddingReranker, + } = require("../EmbeddingRerankers/native"); + return new NativeEmbeddingReranker(); + default: + console.log( + `[RERANKING] Reranker provider ${rerankerSelection} is not supported. Using native reranker as fallback.` + ); + const { + NativeEmbeddingReranker: Native, + } = require("../EmbeddingRerankers/native"); + return new Native(); + } +} + module.exports = { getEmbeddingEngineSelection, maximumChunkLength, @@ -479,4 +507,5 @@ module.exports = { getBaseLLMProviderModel, getLLMProvider, toChunks, + getRerankerProvider, }; diff --git a/server/utils/vectorDbProviders/astra/index.js b/server/utils/vectorDbProviders/astra/index.js index b34a8d83afa..783e6340733 100644 --- a/server/utils/vectorDbProviders/astra/index.js +++ b/server/utils/vectorDbProviders/astra/index.js @@ -5,6 +5,7 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { v4: uuidv4 } = require("uuid"); const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { sourceIdentifier } = require("../../chats"); +const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank"); const sanitizeNamespace = (namespace) => { // If namespace already starts with ns_, don't add it again @@ -301,6 +302,7 @@ const AstraDB = { similarityThreshold = 0.25, topN = 4, filterIdentifiers = [], + rerank = false, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -319,17 +321,27 @@ 
const AstraDB = { } const queryVector = await LLMConnector.embedTextInput(input); - const { contextTexts, sourceDocuments } = await this.similarityResponse({ - client, - namespace: sanitizedNamespace, - queryVector, - similarityThreshold, - topN, - filterIdentifiers, - }); + const { contextTexts, sourceDocuments } = rerank + ? await this.rerankedSimilarityResponse({ + client, + namespace: sanitizedNamespace, + query: input, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }) + : await this.similarityResponse({ + client, + namespace: sanitizedNamespace, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }); - const sources = sourceDocuments.map((metadata, i) => { - return { ...metadata, text: contextTexts[i] }; + const sources = sourceDocuments.map((doc, i) => { + return { metadata: doc, text: contextTexts[i] }; }); return { contextTexts, @@ -373,11 +385,55 @@ const AstraDB = { return; } result.contextTexts.push(response.metadata.text); - result.sourceDocuments.push(response); + result.sourceDocuments.push({ + ...response.metadata, + score: response.$similarity, + }); result.scores.push(response.$similarity); }); return result; }, + rerankedSimilarityResponse: async function ({ + client, + namespace, + query, + queryVector, + topN = 4, + similarityThreshold = 0.25, + filterIdentifiers = [], + }) { + const totalEmbeddings = await this.namespaceCount(namespace); + const searchLimit = getSearchLimit(totalEmbeddings); + const { sourceDocuments } = await this.similarityResponse({ + client, + namespace, + queryVector, + similarityThreshold, + topN: searchLimit, + filterIdentifiers, + }); + + const rerankedResults = await rerank(query, sourceDocuments, topN); + const result = { + contextTexts: [], + sourceDocuments: [], + scores: [], + }; + + rerankedResults.forEach((item) => { + if (item.rerank_score < similarityThreshold) return; + const { rerank_score, ...rest } = item; + if (filterIdentifiers.includes(sourceIdentifier(rest))) 
return; + + result.contextTexts.push(rest.text); + result.sourceDocuments.push({ + ...rest, + score: rerank_score, + }); + result.scores.push(rerank_score); + }); + return result; + }, allNamespaces: async function (client) { try { let header = new Headers(); @@ -432,12 +488,11 @@ const AstraDB = { curateSources: function (sources = []) { const documents = []; for (const source of sources) { - if (Object.keys(source).length > 0) { - const metadata = source.hasOwnProperty("metadata") - ? source.metadata - : source; + const { metadata = {} } = source; + if (Object.keys(metadata).length > 0) { documents.push({ ...metadata, + ...(source.text ? { text: source.text } : {}), }); } } diff --git a/server/utils/vectorDbProviders/chroma/index.js b/server/utils/vectorDbProviders/chroma/index.js index bc12818fd18..a8aa09018b0 100644 --- a/server/utils/vectorDbProviders/chroma/index.js +++ b/server/utils/vectorDbProviders/chroma/index.js @@ -6,6 +6,7 @@ const { v4: uuidv4 } = require("uuid"); const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { parseAuthHeader } = require("../../http"); const { sourceIdentifier } = require("../../chats"); +const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank"); const COLLECTION_REGEX = new RegExp( /^(?!\d+\.\d+\.\d+\.\d+$)(?!.*\.\.)(?=^[a-zA-Z0-9][a-zA-Z0-9_-]{1,61}[a-zA-Z0-9]$).{3,63}$/ ); @@ -150,6 +151,51 @@ const Chroma = { return result; }, + rerankedSimilarityResponse: async function ({ + client, + namespace, + query, + queryVector, + topN = 4, + similarityThreshold = 0.25, + filterIdentifiers = [], + }) { + const totalEmbeddings = await this.namespaceCount(namespace); + const searchLimit = getSearchLimit(totalEmbeddings); + const { sourceDocuments, contextTexts } = await this.similarityResponse({ + client, + namespace, + queryVector, + similarityThreshold, + topN: searchLimit, + filterIdentifiers, + }); + const documentsForReranking = sourceDocuments.map((metadata, i) => ({ + 
...metadata, + text: contextTexts[i], + })); + + const rerankedResults = await rerank(query, documentsForReranking, topN); + const result = { + contextTexts: [], + sourceDocuments: [], + scores: [], + }; + + rerankedResults.forEach((item) => { + if (item.rerank_score < similarityThreshold) return; + const { vector: _, rerank_score, ...rest } = item; + if (filterIdentifiers.includes(sourceIdentifier(rest))) return; + + result.contextTexts.push(rest.text); + result.sourceDocuments.push({ + ...rest, + score: rerank_score, + }); + result.scores.push(rerank_score); + }); + return result; + }, namespace: async function (client, namespace = null) { if (!namespace) throw new Error("No namespace value provided."); const collection = await client @@ -348,12 +394,14 @@ const Chroma = { similarityThreshold = 0.25, topN = 4, filterIdentifiers = [], + rerank = false, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); const { client } = await this.connect(); - if (!(await this.namespaceExists(client, this.normalize(namespace)))) { + const collectionName = this.normalize(namespace); + if (!(await this.namespaceExists(client, collectionName))) { return { contextTexts: [], sources: [], @@ -362,16 +410,26 @@ const Chroma = { } const queryVector = await LLMConnector.embedTextInput(input); - const { contextTexts, sourceDocuments, scores } = - await this.similarityResponse({ - client, - namespace, - queryVector, - similarityThreshold, - topN, - filterIdentifiers, - }); + const result = rerank + ? 
await this.rerankedSimilarityResponse({ + client, + namespace, + query: input, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }) + : await this.similarityResponse({ + client, + namespace, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }); + const { contextTexts, sourceDocuments, scores } = result; const sources = sourceDocuments.map((metadata, i) => ({ metadata: { ...metadata, diff --git a/server/utils/vectorDbProviders/lance/index.js b/server/utils/vectorDbProviders/lance/index.js index 563095fe5db..adda0cefc44 100644 --- a/server/utils/vectorDbProviders/lance/index.js +++ b/server/utils/vectorDbProviders/lance/index.js @@ -5,7 +5,7 @@ const { SystemSettings } = require("../../../models/systemSettings"); const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { v4: uuidv4 } = require("uuid"); const { sourceIdentifier } = require("../../chats"); -const { NativeEmbeddingReranker } = require("../../EmbeddingRerankers/native"); +const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank"); /** * LancedDB Client connection object @@ -79,67 +79,43 @@ const LanceDb = { similarityThreshold = 0.25, filterIdentifiers = [], }) { - const reranker = new NativeEmbeddingReranker(); - const collection = await client.openTable(namespace); const totalEmbeddings = await this.namespaceCount(namespace); + const searchLimit = getSearchLimit(totalEmbeddings); + const vectorSearchResults = await client + .openTable(namespace) + .then((tbl) => + tbl + .vectorSearch(queryVector) + .distanceType("cosine") + .limit(searchLimit) + .toArray() + ); + + const rerankedResults = await rerank(query, vectorSearchResults, topN); const result = { contextTexts: [], sourceDocuments: [], scores: [], }; - /** - * For reranking, we want to work with a larger number of results than the topN. - * This is because the reranker can only rerank the results it it given and we dont auto-expand the results. 
- * We want to give the reranker a larger number of results to work with. - * - * However, we cannot make this boundless as reranking is expensive and time consuming. - * So we limit the number of results to a maximum of 50 and a minimum of 10. - * This is a good balance between the number of results to rerank and the cost of reranking - * and ensures workspaces with 10K embeddings will still rerank within a reasonable timeframe on base level hardware. - * - * Benchmarks: - * On Intel Mac: 2.6 GHz 6-Core Intel Core i7 - 20 docs reranked in ~5.2 sec - */ - const searchLimit = Math.max( - 10, - Math.min(50, Math.ceil(totalEmbeddings * 0.1)) - ); - const vectorSearchResults = await collection - .vectorSearch(queryVector) - .distanceType("cosine") - .limit(searchLimit) - .toArray(); - - await reranker - .rerank(query, vectorSearchResults, { topK: topN }) - .then((rerankResults) => { - rerankResults.forEach((item) => { - if (this.distanceToSimilarity(item._distance) < similarityThreshold) - return; - const { vector: _, ...rest } = item; - if (filterIdentifiers.includes(sourceIdentifier(rest))) { - console.log( - "LanceDB: A source was filtered from context as it's parent document is pinned." - ); - return; - } - const score = - item?.rerank_score || this.distanceToSimilarity(item._distance); + rerankedResults.forEach((item) => { + if (item.rerank_score < similarityThreshold) return; + const { vector: _, ...rest } = item; + if (filterIdentifiers.includes(sourceIdentifier(rest))) { + console.log( + "LanceDB: A source was filtered from context as it's parent document is pinned." 
+ ); + return; + } + const score = item.rerank_score; - result.contextTexts.push(rest.text); - result.sourceDocuments.push({ - ...rest, - score, - }); - result.scores.push(score); - }); - }) - .catch((e) => { - console.error(e); - console.error("LanceDB::rerankedSimilarityResponse", e.message); + result.contextTexts.push(rest.text); + result.sourceDocuments.push({ + ...rest, + score, }); - + result.scores.push(score); + }); return result; }, diff --git a/server/utils/vectorDbProviders/milvus/index.js b/server/utils/vectorDbProviders/milvus/index.js index 2ddaad567bb..b1f6800c23f 100644 --- a/server/utils/vectorDbProviders/milvus/index.js +++ b/server/utils/vectorDbProviders/milvus/index.js @@ -10,6 +10,7 @@ const { v4: uuidv4 } = require("uuid"); const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { sourceIdentifier } = require("../../chats"); +const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank"); const Milvus = { name: "Milvus", @@ -299,6 +300,7 @@ const Milvus = { similarityThreshold = 0.25, topN = 4, filterIdentifiers = [], + rerank = false, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -313,14 +315,24 @@ const Milvus = { } const queryVector = await LLMConnector.embedTextInput(input); - const { contextTexts, sourceDocuments } = await this.similarityResponse({ - client, - namespace, - queryVector, - similarityThreshold, - topN, - filterIdentifiers, - }); + const { contextTexts, sourceDocuments } = rerank + ? 
await this.rerankedSimilarityResponse({ + client, + namespace, + query: input, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }) + : await this.similarityResponse({ + client, + namespace, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }); const sources = sourceDocuments.map((doc, i) => { return { metadata: doc, text: contextTexts[i] }; @@ -368,6 +380,47 @@ const Milvus = { }); return result; }, + rerankedSimilarityResponse: async function ({ + client, + namespace, + query, + queryVector, + topN = 4, + similarityThreshold = 0.25, + filterIdentifiers = [], + }) { + const totalEmbeddings = await this.namespaceCount(namespace); + const searchLimit = getSearchLimit(totalEmbeddings); + const { sourceDocuments } = await this.similarityResponse({ + client, + namespace, + queryVector, + similarityThreshold, + topN: searchLimit, + filterIdentifiers, + }); + + const rerankedResults = await rerank(query, sourceDocuments, topN); + const result = { + contextTexts: [], + sourceDocuments: [], + scores: [], + }; + + rerankedResults.forEach((item) => { + if (item.rerank_score < similarityThreshold) return; + const { rerank_score, ...rest } = item; + if (filterIdentifiers.includes(sourceIdentifier(rest))) return; + + result.contextTexts.push(rest.text); + result.sourceDocuments.push({ + ...rest, + score: rerank_score, + }); + result.scores.push(rerank_score); + }); + return result; + }, "namespace-stats": async function (reqBody = {}) { const { namespace = null } = reqBody; if (!namespace) throw new Error("namespace required"); diff --git a/server/utils/vectorDbProviders/pgvector/index.js b/server/utils/vectorDbProviders/pgvector/index.js index 990498eb5cc..67245fd0a30 100644 --- a/server/utils/vectorDbProviders/pgvector/index.js +++ b/server/utils/vectorDbProviders/pgvector/index.js @@ -3,6 +3,7 @@ const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { TextSplitter } = require("../../TextSplitter"); const { v4: 
uuidv4 } = require("uuid"); const { sourceIdentifier } = require("../../chats"); +const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank"); /* Embedding Table Schema (table name defined by user) @@ -207,29 +208,31 @@ const PGVector = { }, PGVector.connectionTimeout); }); - const connectionPromise = new Promise(async (resolve) => { - let pgClient = null; - try { - pgClient = this.client(connectionString); - await pgClient.connect(); - const result = await pgClient.query(this.getTablesSql); - - if (result.rows.length !== 0 && !!tableName) { - const tableExists = result.rows.some( - (row) => row.tablename === tableName - ); - if (tableExists) - await this.validateExistingEmbeddingTableSchema( - pgClient, - tableName + const connectionPromise = new Promise((resolve) => { + (async () => { + let pgClient = null; + try { + pgClient = this.client(connectionString); + await pgClient.connect(); + const result = await pgClient.query(this.getTablesSql); + + if (result.rows.length !== 0 && !!tableName) { + const tableExists = result.rows.some( + (row) => row.tablename === tableName ); + if (tableExists) + await this.validateExistingEmbeddingTableSchema( + pgClient, + tableName + ); + } + resolve({ error: null, success: true }); + } catch (err) { + resolve({ error: err.message, success: false }); + } finally { + if (pgClient) await pgClient.end(); } - resolve({ error: null, success: true }); - } catch (err) { - resolve({ error: err.message, success: false }); - } finally { - if (pgClient) await pgClient.end(); - } + })(); }); // Race the connection attempt against the timeout @@ -401,6 +404,48 @@ const PGVector = { return result; }, + rerankedSimilarityResponse: async function ({ + client, + namespace, + query, + queryVector, + topN = 4, + similarityThreshold = 0.25, + filterIdentifiers = [], + }) { + const totalEmbeddings = await this.namespaceCount(namespace); + const searchLimit = getSearchLimit(totalEmbeddings); + const { sourceDocuments } = await 
this.similarityResponse({ + client, + namespace, + queryVector, + similarityThreshold, + topN: searchLimit, + filterIdentifiers, + }); + + const rerankedResults = await rerank(query, sourceDocuments, topN); + const result = { + contextTexts: [], + sourceDocuments: [], + scores: [], + }; + + rerankedResults.forEach((item) => { + if (item.rerank_score < similarityThreshold) return; + const { rerank_score, ...rest } = item; + if (filterIdentifiers.includes(sourceIdentifier(rest))) return; + + result.contextTexts.push(rest.text); + result.sourceDocuments.push({ + ...rest, + score: rerank_score, + }); + result.scores.push(rerank_score); + }); + return result; + }, + normalizeVector: function (vector) { const magnitude = Math.sqrt( vector.reduce((sum, val) => sum + val * val, 0) @@ -707,6 +752,7 @@ const PGVector = { similarityThreshold = 0.25, topN = 4, filterIdentifiers = [], + rerank = false, }) { let connection = null; if (!namespace || !input || !LLMConnector) @@ -727,16 +773,25 @@ const PGVector = { } const queryVector = await LLMConnector.embedTextInput(input); - const result = await this.similarityResponse({ - client: connection, - namespace, - queryVector, - similarityThreshold, - topN, - filterIdentifiers, - }); + const { contextTexts, sourceDocuments } = rerank + ? 
await this.rerankedSimilarityResponse({ + client: connection, + namespace, + query: input, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }) + : await this.similarityResponse({ + client: connection, + namespace, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }); - const { contextTexts, sourceDocuments } = result; const sources = sourceDocuments.map((metadata, i) => { return { metadata: { ...metadata, text: contextTexts[i] } }; }); diff --git a/server/utils/vectorDbProviders/pinecone/index.js b/server/utils/vectorDbProviders/pinecone/index.js index c5c55acb58c..9335c90dbcd 100644 --- a/server/utils/vectorDbProviders/pinecone/index.js +++ b/server/utils/vectorDbProviders/pinecone/index.js @@ -5,6 +5,7 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { v4: uuidv4 } = require("uuid"); const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { sourceIdentifier } = require("../../chats"); +const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank"); const PineconeDB = { name: "Pinecone", @@ -76,6 +77,47 @@ const PineconeDB = { return result; }, + rerankedSimilarityResponse: async function ({ + client, + namespace, + query, + queryVector, + topN = 4, + similarityThreshold = 0.25, + filterIdentifiers = [], + }) { + const totalEmbeddings = await this.namespaceCount(namespace); + const searchLimit = getSearchLimit(totalEmbeddings); + const { sourceDocuments } = await this.similarityResponse({ + client, + namespace, + queryVector, + similarityThreshold, + topN: searchLimit, + filterIdentifiers, + }); + + const rerankedResults = await rerank(query, sourceDocuments, topN); + const result = { + contextTexts: [], + sourceDocuments: [], + scores: [], + }; + + rerankedResults.forEach((item) => { + if (item.rerank_score < similarityThreshold) return; + const { rerank_score, ...rest } = item; + if (filterIdentifiers.includes(sourceIdentifier(rest))) return; + 
+ result.contextTexts.push(rest.text); + result.sourceDocuments.push({ + ...rest, + score: rerank_score, + }); + result.scores.push(rerank_score); + }); + return result; + }, namespace: async function (index, namespace = null) { if (!namespace) throw new Error("No namespace value provided."); const { namespaces } = await index.describeIndexStats(); @@ -247,6 +289,7 @@ const PineconeDB = { similarityThreshold = 0.25, topN = 4, filterIdentifiers = [], + rerank = false, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -258,14 +301,24 @@ const PineconeDB = { ); const queryVector = await LLMConnector.embedTextInput(input); - const { contextTexts, sourceDocuments } = await this.similarityResponse({ - client: pineconeIndex, - namespace, - queryVector, - similarityThreshold, - topN, - filterIdentifiers, - }); + const { contextTexts, sourceDocuments } = rerank + ? await this.rerankedSimilarityResponse({ + client: pineconeIndex, + namespace, + query: input, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }) + : await this.similarityResponse({ + client: pineconeIndex, + namespace, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }); const sources = sourceDocuments.map((doc, i) => { return { metadata: doc, text: contextTexts[i] }; diff --git a/server/utils/vectorDbProviders/qdrant/index.js b/server/utils/vectorDbProviders/qdrant/index.js index 95e347b274e..e8831ea854e 100644 --- a/server/utils/vectorDbProviders/qdrant/index.js +++ b/server/utils/vectorDbProviders/qdrant/index.js @@ -5,6 +5,7 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { v4: uuidv4 } = require("uuid"); const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { sourceIdentifier } = require("../../chats"); +const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank"); const QDrant = { name: "QDrant", @@ -80,12 +81,54 @@ 
const QDrant = { result.sourceDocuments.push({ ...(response?.payload || {}), id: response.id, + score: response.score, }); result.scores.push(response.score); }); return result; }, + rerankedSimilarityResponse: async function ({ + client, + namespace, + query, + queryVector, + topN = 4, + similarityThreshold = 0.25, + filterIdentifiers = [], + }) { + const totalEmbeddings = await this.namespaceCount(namespace); + const searchLimit = getSearchLimit(totalEmbeddings); + const { sourceDocuments } = await this.similarityResponse({ + client, + namespace, + queryVector, + similarityThreshold, + topN: searchLimit, + filterIdentifiers, + }); + + const rerankedResults = await rerank(query, sourceDocuments, topN); + const result = { + contextTexts: [], + sourceDocuments: [], + scores: [], + }; + + rerankedResults.forEach((item) => { + if (item.rerank_score < similarityThreshold) return; + const { rerank_score, ...rest } = item; + if (filterIdentifiers.includes(sourceIdentifier(rest))) return; + + result.contextTexts.push(rest.text); + result.sourceDocuments.push({ + ...rest, + score: rerank_score, + }); + result.scores.push(rerank_score); + }); + return result; + }, namespace: async function (client, namespace = null) { if (!namespace) throw new Error("No namespace value provided."); const collection = await client.getCollection(namespace).catch(() => null); @@ -334,6 +377,7 @@ const QDrant = { similarityThreshold = 0.25, topN = 4, filterIdentifiers = [], + rerank = false, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -348,17 +392,27 @@ const QDrant = { } const queryVector = await LLMConnector.embedTextInput(input); - const { contextTexts, sourceDocuments } = await this.similarityResponse({ - client, - namespace, - queryVector, - similarityThreshold, - topN, - filterIdentifiers, - }); + const { contextTexts, sourceDocuments } = rerank + ? 
await this.rerankedSimilarityResponse({ + client, + namespace, + query: input, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }) + : await this.similarityResponse({ + client, + namespace, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }); - const sources = sourceDocuments.map((metadata, i) => { - return { ...metadata, text: contextTexts[i] }; + const sources = sourceDocuments.map((doc, i) => { + return { metadata: doc, text: contextTexts[i] }; }); return { contextTexts, @@ -400,12 +454,11 @@ const QDrant = { curateSources: function (sources = []) { const documents = []; for (const source of sources) { - if (Object.keys(source).length > 0) { - const metadata = source.hasOwnProperty("metadata") - ? source.metadata - : source; + const { metadata = {} } = source; + if (Object.keys(metadata).length > 0) { documents.push({ ...metadata, + ...(source.text ? { text: source.text } : {}), }); } } diff --git a/server/utils/vectorDbProviders/weaviate/index.js b/server/utils/vectorDbProviders/weaviate/index.js index 2385c5e8ef1..6a475380f7d 100644 --- a/server/utils/vectorDbProviders/weaviate/index.js +++ b/server/utils/vectorDbProviders/weaviate/index.js @@ -6,6 +6,7 @@ const { v4: uuidv4 } = require("uuid"); const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { camelCase } = require("../../helpers/camelcase"); const { sourceIdentifier } = require("../../chats"); +const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank"); const Weaviate = { name: "Weaviate", @@ -115,12 +116,53 @@ const Weaviate = { return; } result.contextTexts.push(rest.text); - result.sourceDocuments.push({ ...rest, id }); + result.sourceDocuments.push({ ...rest, id, score: certainty }); result.scores.push(certainty); }); return result; }, + rerankedSimilarityResponse: async function ({ + client, + namespace, + query, + queryVector, + topN = 4, + similarityThreshold = 0.25, + filterIdentifiers = [], + }) { + const 
totalEmbeddings = await this.namespaceCount(namespace); + const searchLimit = getSearchLimit(totalEmbeddings); + const { sourceDocuments } = await this.similarityResponse({ + client, + namespace, + queryVector, + similarityThreshold, + topN: searchLimit, + filterIdentifiers, + }); + + const rerankedResults = await rerank(query, sourceDocuments, topN); + const result = { + contextTexts: [], + sourceDocuments: [], + scores: [], + }; + + rerankedResults.forEach((item) => { + if (item.rerank_score < similarityThreshold) return; + const { rerank_score, ...rest } = item; + if (filterIdentifiers.includes(sourceIdentifier(rest))) return; + + result.contextTexts.push(rest.text); + result.sourceDocuments.push({ + ...rest, + score: rerank_score, + }); + result.scores.push(rerank_score); + }); + return result; + }, allNamespaces: async function (client) { try { const { classes = [] } = await client.schema.getter().do(); @@ -368,6 +410,7 @@ const Weaviate = { similarityThreshold = 0.25, topN = 4, filterIdentifiers = [], + rerank = false, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -382,17 +425,27 @@ const Weaviate = { } const queryVector = await LLMConnector.embedTextInput(input); - const { contextTexts, sourceDocuments } = await this.similarityResponse({ - client, - namespace, - queryVector, - similarityThreshold, - topN, - filterIdentifiers, - }); - - const sources = sourceDocuments.map((metadata, i) => { - return { ...metadata, text: contextTexts[i] }; + const { contextTexts, sourceDocuments } = rerank + ? 
await this.rerankedSimilarityResponse({ + client, + namespace, + query: input, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }) + : await this.similarityResponse({ + client, + namespace, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }); + + const sources = sourceDocuments.map((doc, i) => { + return { metadata: doc, text: contextTexts[i] }; }); return { contextTexts, @@ -431,11 +484,12 @@ const Weaviate = { curateSources: function (sources = []) { const documents = []; for (const source of sources) { - if (Object.keys(source).length > 0) { - const metadata = source.hasOwnProperty("metadata") - ? source.metadata - : source; - documents.push({ ...metadata }); + const { metadata = {} } = source; + if (Object.keys(metadata).length > 0) { + documents.push({ + ...metadata, + ...(source.text ? { text: source.text } : {}), + }); } } diff --git a/server/utils/vectorDbProviders/zilliz/index.js b/server/utils/vectorDbProviders/zilliz/index.js index ab866f4edd5..189d2eb85d6 100644 --- a/server/utils/vectorDbProviders/zilliz/index.js +++ b/server/utils/vectorDbProviders/zilliz/index.js @@ -10,6 +10,7 @@ const { v4: uuidv4 } = require("uuid"); const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { sourceIdentifier } = require("../../chats"); +const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank"); // Zilliz is basically a copy of Milvus DB class with a different constructor // to connect to the cloud @@ -157,30 +158,38 @@ const Zilliz = { vectorDimension = chunks[0][0].values.length || null; await this.getOrCreateCollection(client, namespace, vectorDimension); - for (const chunk of chunks) { - // Before sending to Pinecone and saving the records to our db - // we need to assign the id of each chunk that is stored in the cached file. 
- const newChunks = chunk.map((chunk) => { - const id = uuidv4(); - documentVectors.push({ docId, vectorId: id }); - return { id, vector: chunk.values, metadata: chunk.metadata }; - }); - const insertResult = await client.insert({ - collection_name: this.normalize(namespace), - data: newChunks, - }); + try { + for (const chunk of chunks) { + // Before sending to Zilliz and saving the records to our db + // we need to assign the id of each chunk that is stored in the cached file. + const newChunks = chunk.map((chunk) => { + const id = uuidv4(); + documentVectors.push({ docId, vectorId: id }); + return { id, vector: chunk.values, metadata: chunk.metadata }; + }); + const insertResult = await client.insert({ + collection_name: this.normalize(namespace), + data: newChunks, + }); - if (insertResult?.status.error_code !== "Success") { - throw new Error( - `Error embedding into Zilliz! Reason:${insertResult?.status.reason}` - ); + if (insertResult?.status.error_code !== "Success") { + throw new Error( + `Error embedding into Zilliz! 
Reason:${insertResult?.status.reason}` + ); + } } + await DocumentVectors.bulkInsert(documentVectors); + await client.flushSync({ + collection_names: [this.normalize(namespace)], + }); + return { vectorized: true, error: null }; + } catch (insertError) { + console.error( + "Error inserting cached chunks:", + insertError.message + ); + return { vectorized: false, error: insertError.message }; } - await DocumentVectors.bulkInsert(documentVectors); - await client.flushSync({ - collection_names: [this.normalize(namespace)], - }); - return { vectorized: true, error: null }; } } @@ -239,7 +248,7 @@ const Zilliz = { data: chunk.map((item) => ({ id: item.id, vector: item.values, - metadata: chunk.metadata, + metadata: item.metadata, })), }); @@ -292,6 +301,7 @@ const Zilliz = { similarityThreshold = 0.25, topN = 4, filterIdentifiers = [], + rerank = false, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -306,18 +316,29 @@ const Zilliz = { } const queryVector = await LLMConnector.embedTextInput(input); - const { contextTexts, sourceDocuments } = await this.similarityResponse({ - client, - namespace, - queryVector, - similarityThreshold, - topN, - filterIdentifiers, - }); + const { contextTexts, sourceDocuments } = rerank + ? 
await this.rerankedSimilarityResponse({ + client, + namespace, + query: input, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }) + : await this.similarityResponse({ + client, + namespace, + queryVector, + similarityThreshold, + topN, + filterIdentifiers, + }); const sources = sourceDocuments.map((doc, i) => { return { metadata: doc, text: contextTexts[i] }; }); + return { contextTexts, sources: this.curateSources(sources), @@ -350,6 +371,7 @@ const Zilliz = { ); return; } + result.contextTexts.push(match.metadata.text); result.sourceDocuments.push({ ...match.metadata, @@ -359,6 +381,47 @@ const Zilliz = { }); return result; }, + rerankedSimilarityResponse: async function ({ + client, + namespace, + query, + queryVector, + topN = 4, + similarityThreshold = 0.25, + filterIdentifiers = [], + }) { + const totalEmbeddings = await this.namespaceCount(namespace); + const searchLimit = getSearchLimit(totalEmbeddings); + const { sourceDocuments } = await this.similarityResponse({ + client, + namespace, + queryVector, + similarityThreshold, + topN: searchLimit, + filterIdentifiers, + }); + + const rerankedResults = await rerank(query, sourceDocuments, topN); + const result = { + contextTexts: [], + sourceDocuments: [], + scores: [], + }; + + rerankedResults.forEach((item) => { + if (item.rerank_score < similarityThreshold) return; + const { rerank_score, ...rest } = item; + if (filterIdentifiers.includes(sourceIdentifier(rest))) return; + + result.contextTexts.push(rest.text); + result.sourceDocuments.push({ + ...rest, + score: rerank_score, + }); + result.scores.push(rerank_score); + }); + return result; + }, "namespace-stats": async function (reqBody = {}) { const { namespace = null } = reqBody; if (!namespace) throw new Error("namespace required"); @@ -394,7 +457,6 @@ const Zilliz = { }); } } - return documents; }, };