θΏ™ζ˜―indexlocζδΎ›ηš„ζœεŠ‘οΌŒδΈθ¦θΎ“ε…₯任何密码
Skip to content

Reranker option for RAG #2929

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import { useState } from "react";

// We dont support all vectorDBs yet for reranking due to complexities of how each provider
// returns information. We need to normalize the response data so Reranker can be used for each provider.
const supportedVectorDBs = ["lancedb"];
const hint = {
default: {
title: "Default",
description:
"This is the fastest performance, but may not return the most relevant results leading to model hallucinations.",
},
rerank: {
title: "Accuracy Optimized",
description:
"LLM responses may take longer to generate, but your responses will be more accurate and relevant.",
},
};

export default function VectorSearchMode({ workspace, setHasChanges }) {
const [selection, setSelection] = useState(
workspace?.vectorSearchMode ?? "default"
);
if (!workspace?.vectorDB || !supportedVectorDBs.includes(workspace?.vectorDB))
return null;

return (
<div>
<div className="flex flex-col">
<label htmlFor="name" className="block input-label">
Search Preference
</label>
</div>
<select
name="vectorSearchMode"
value={selection}
className="border-none bg-theme-settings-input-bg text-white text-sm mt-2 rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
onChange={(e) => {
setSelection(e.target.value);
setHasChanges(true);
}}
required={true}
>
<option value="default">Default</option>
<option value="rerank">Accuracy Optimized</option>
</select>
<p className="text-white text-opacity-60 text-xs font-medium py-1.5">
{hint[selection]?.description}
</p>
</div>
);
}
2 changes: 2 additions & 0 deletions frontend/src/pages/WorkspaceSettings/VectorDatabase/index.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import MaxContextSnippets from "./MaxContextSnippets";
import DocumentSimilarityThreshold from "./DocumentSimilarityThreshold";
import ResetDatabase from "./ResetDatabase";
import VectorCount from "./VectorCount";
import VectorSearchMode from "./VectorSearchMode";

export default function VectorDatabase({ workspace }) {
const [hasChanges, setHasChanges] = useState(false);
Expand Down Expand Up @@ -43,6 +44,7 @@ export default function VectorDatabase({ workspace }) {
<VectorDBIdentifier workspace={workspace} />
<VectorCount reload={true} workspace={workspace} />
</div>
<VectorSearchMode workspace={workspace} setHasChanges={setHasChanges} />
<MaxContextSnippets workspace={workspace} setHasChanges={setHasChanges} />
<DocumentSimilarityThreshold
workspace={workspace}
Expand Down
3 changes: 3 additions & 0 deletions frontend/src/pages/WorkspaceSettings/index.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import Members from "./Members";
import WorkspaceAgentConfiguration from "./AgentConfig";
import useUser from "@/hooks/useUser";
import { useTranslation } from "react-i18next";
import System from "@/models/system";

const TABS = {
"general-appearance": GeneralAppearance,
Expand Down Expand Up @@ -59,9 +60,11 @@ function ShowWorkspaceChat() {
return;
}

const _settings = await System.keys();
const suggestedMessages = await Workspace.getSuggestedMessages(slug);
setWorkspace({
..._workspace,
vectorDB: _settings?.VectorDB,
suggestedMessages,
});
setLoading(false);
Expand Down
1 change: 1 addition & 0 deletions server/endpoints/api/workspace/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,7 @@ function apiWorkspaceEndpoints(app) {
LLMConnector: getLLMProvider(),
similarityThreshold: parseSimilarityThreshold(),
topN: parseTopN(),
rerank: workspace?.vectorSearchMode === "rerank",
});

response.status(200).json({
Expand Down
10 changes: 10 additions & 0 deletions server/models/workspace.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const Workspace = {
"agentProvider",
"agentModel",
"queryRefusalResponse",
"vectorSearchMode",
],

validations: {
Expand Down Expand Up @@ -99,6 +100,15 @@ const Workspace = {
if (!value || typeof value !== "string") return null;
return String(value);
},
vectorSearchMode: (value) => {
if (
!value ||
typeof value !== "string" ||
!["default", "rerank"].includes(value)
)
return "default";
return value;
},
},

/**
Expand Down
2 changes: 2 additions & 0 deletions server/prisma/migrations/20250102204948_init/migration.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "workspaces" ADD COLUMN "vectorSearchMode" TEXT DEFAULT 'default';
1 change: 1 addition & 0 deletions server/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ model workspaces {
agentProvider String?
agentModel String?
queryRefusalResponse String?
vectorSearchMode String? @default("default")
workspace_users workspace_users[]
documents workspace_documents[]
workspace_suggested_messages workspace_suggested_messages[]
Expand Down
3 changes: 2 additions & 1 deletion server/storage/models/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ downloaded/*
!downloaded/.placeholder
openrouter
apipie
novita
novita
mixedbread-ai*
153 changes: 153 additions & 0 deletions server/utils/EmbeddingRerankers/native/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
const path = require("path");
const fs = require("fs");

class NativeEmbeddingReranker {
  // Process-wide singletons so the model/tokenizer are only loaded once,
  // no matter how many reranker instances are constructed.
  static #model = null;
  static #tokenizer = null;
  static #transformers = null;

  constructor() {
    // An alternative model to the mixedbread-ai/mxbai-rerank-xsmall-v1 model (speed on CPU is much slower for this model @ 18docs = 6s)
    // Model Card: https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2 (speed on CPU is much faster @ 18docs = 1.6s)
    this.model = "Xenova/ms-marco-MiniLM-L-6-v2";
    this.cacheDir = path.resolve(
      process.env.STORAGE_DIR
        ? path.resolve(process.env.STORAGE_DIR, `models`)
        : path.resolve(__dirname, `../../../storage/models`)
    );
    this.modelPath = path.resolve(this.cacheDir, ...this.model.split("/"));
    // Make directory when it does not exist in existing installations.
    // `recursive: true` also covers a missing parent storage dir instead of
    // throwing ENOENT on fresh installs.
    if (!fs.existsSync(this.cacheDir))
      fs.mkdirSync(this.cacheDir, { recursive: true });
    this.log("Initialized");
  }

  log(text, ...args) {
    console.log(`\x1b[36m[NativeEmbeddingReranker]\x1b[0m ${text}`, ...args);
  }

  /**
   * This function will preload the reranker suite and tokenizer.
   * This is useful for reducing the latency of the first rerank call and pre-downloading the models and such
   * to avoid having to wait for the models to download on the first rerank call.
   * Failures are logged and swallowed on purpose — reranking then lazy-loads
   * on the first rerank() call instead.
   */
  async preload() {
    try {
      this.log(`Preloading reranker suite...`);
      await this.initClient();
      this.log(
        `Preloaded reranker suite. Reranking is available as a service now.`
      );
      return;
    } catch (e) {
      console.error(e);
      this.log(
        `Failed to preload reranker suite. Reranking will be available on the first rerank call.`
      );
      return;
    }
  }

  /**
   * Lazily imports @xenova/transformers and warms the model/tokenizer
   * singletons. Safe to call repeatedly; subsequent calls are no-ops.
   */
  async initClient() {
    if (NativeEmbeddingReranker.#transformers) {
      this.log(`Reranker suite already initialized - reusing.`);
      return;
    }

    await import("@xenova/transformers").then(
      async ({ AutoModelForSequenceClassification, AutoTokenizer }) => {
        this.log(`Loading reranker suite...`);
        NativeEmbeddingReranker.#transformers = {
          AutoModelForSequenceClassification,
          AutoTokenizer,
        };
        await this.#getPreTrainedModel();
        await this.#getPreTrainedTokenizer();
      }
    );
    return;
  }

  /**
   * Loads (or reuses) the sequence-classification model singleton.
   * @returns {Promise<any>} the loaded transformers model.
   */
  async #getPreTrainedModel() {
    if (NativeEmbeddingReranker.#model) {
      this.log(`Loading model from singleton...`);
      return NativeEmbeddingReranker.#model;
    }

    const model =
      await NativeEmbeddingReranker.#transformers.AutoModelForSequenceClassification.from_pretrained(
        this.model,
        {
          progress_callback: (p) =>
            p.status === "progress" &&
            this.log(`Loading model ${this.model}... ${p?.progress}%`),
          cache_dir: this.cacheDir,
        }
      );
    this.log(`Loaded model ${this.model}`);
    NativeEmbeddingReranker.#model = model;
    return model;
  }

  /**
   * Loads (or reuses) the tokenizer singleton for the reranker model.
   * @returns {Promise<any>} the loaded transformers tokenizer.
   */
  async #getPreTrainedTokenizer() {
    if (NativeEmbeddingReranker.#tokenizer) {
      this.log(`Loading tokenizer from singleton...`);
      return NativeEmbeddingReranker.#tokenizer;
    }

    const tokenizer =
      await NativeEmbeddingReranker.#transformers.AutoTokenizer.from_pretrained(
        this.model,
        {
          progress_callback: (p) =>
            p.status === "progress" &&
            this.log(`Loading tokenizer ${this.model}... ${p?.progress}%`),
          cache_dir: this.cacheDir,
        }
      );
    this.log(`Loaded tokenizer ${this.model}`);
    NativeEmbeddingReranker.#tokenizer = tokenizer;
    return tokenizer;
  }

  /**
   * Reranks a list of documents based on the query.
   * @param {string} query - The query to rerank the documents against.
   * @param {{text: string}[]} documents - The list of document text snippets to rerank. Should be output from a vector search.
   * @param {Object} options - The options for the reranking.
   * @param {number} options.topK - The number of top documents to return.
   * @returns {Promise<any[]>} - The reranked list of documents.
   */
  async rerank(query, documents, options = { topK: 4 }) {
    // Guard: a partial options object (e.g. { signal }) previously left topK
    // undefined, and `slice(0, undefined)` returns ALL documents untruncated.
    const topK = options?.topK ?? 4;
    // Nothing to rank — avoid tokenizing empty inputs.
    if (!documents?.length) return [];

    await this.initClient();
    const model = NativeEmbeddingReranker.#model;
    const tokenizer = NativeEmbeddingReranker.#tokenizer;

    const start = Date.now();
    this.log(`Reranking ${documents.length} documents...`);
    // Pair the query against every document text for cross-encoder scoring.
    const inputs = tokenizer(new Array(documents.length).fill(query), {
      text_pair: documents.map((doc) => doc.text),
      padding: true,
      truncation: true,
    });
    const { logits } = await model(inputs);
    const reranked = logits
      .sigmoid()
      .tolist()
      .map(([score], i) => ({
        rerank_corpus_id: i,
        rerank_score: score,
        ...documents[i],
      }))
      .sort((a, b) => b.rerank_score - a.rerank_score)
      .slice(0, topK);

    this.log(
      `Reranking ${documents.length} documents to top ${topK} took ${Date.now() - start}ms`
    );
    return reranked;
  }
}

module.exports = {
NativeEmbeddingReranker,
};
1 change: 1 addition & 0 deletions server/utils/agents/aibitat/plugins/memory.js
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ const memory = {
input: query,
LLMConnector,
topN: workspace?.topN ?? 4,
rerank: workspace?.vectorSearchMode === "rerank",
});

if (contextTexts.length === 0) {
Expand Down
2 changes: 2 additions & 0 deletions server/utils/chats/apiChatHandler.js
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ async function chatSync({
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
Expand Down Expand Up @@ -480,6 +481,7 @@ async function streamChat({
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
Expand Down
1 change: 1 addition & 0 deletions server/utils/chats/embed.js
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ async function streamChatWithForEmbed(
similarityThreshold: embed.workspace?.similarityThreshold,
topN: embed.workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
rerank: embed.workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
Expand Down
2 changes: 2 additions & 0 deletions server/utils/chats/openaiCompatible.js
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ async function chatSync({
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
Expand Down Expand Up @@ -304,6 +305,7 @@ async function streamChat({
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
Expand Down
1 change: 1 addition & 0 deletions server/utils/chats/stream.js
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ async function streamChatWithWorkspace(
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
filterIdentifiers: pinnedDocIdentifiers,
rerank: workspace?.vectorSearchMode === "rerank",
})
: {
contextTexts: [],
Expand Down
1 change: 1 addition & 0 deletions server/utils/helpers/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
* @property {Function} totalVectors - Returns the total number of vectors in the database.
* @property {Function} namespaceCount - Returns the count of vectors in a given namespace.
* @property {Function} similarityResponse - Performs a similarity search on a given namespace.
* @property {Function} rerankedSimilarityResponse - Performs a similarity search on a given namespace with reranking (if supported by provider).
* @property {Function} namespace - Retrieves the specified namespace collection.
* @property {Function} hasNamespace - Checks if a namespace exists.
* @property {Function} namespaceExists - Verifies if a namespace exists in the client.
Expand Down
Loading