From 0d6b1599209acf807f6630161004705e4aaf7f4c Mon Sep 17 00:00:00 2001
From: Jason Kneen
Date: Mon, 3 Feb 2025 19:47:54 +0000
Subject: [PATCH] Separate doc filtration and reasoning LLMs

Fixes #1

Separate the document filtration and reasoning LLMs into two distinct models.

* **Python Changes:**
  - Add a new parameter `filtration_model` to the `ReagClient` constructor in `python/src/reag/client.py`.
  - Update the `query` method to use the `filtration_model` for document filtration and the `model` for generating the final answer.
  - Update `python/README.md` to reflect the changes in the `ReagClient` class and the new two-step process.

* **TypeScript Changes:**
  - Add a new parameter `filtrationModel` to the `ReagClient` constructor in `typescript/src/client.ts`.
  - Update the `query` method to use the `filtrationModel` for document filtration and the `model` for generating the final answer.
  - Update `typescript/README.md` to reflect the changes in the `ReagClient` class and the new two-step process.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/superagent-ai/reag/issues/1?shareId=XXXX-XXXX-XXXX-XXXX).
---
 python/README.md          |  1 +
 python/src/reag/client.py | 50 ++++++++++++++++++++++++++++-----------
 typescript/README.md      |  2 ++
 typescript/src/client.ts  | 30 ++++++++++++++++++++---
 4 files changed, 66 insertions(+), 17 deletions(-)

diff --git a/python/README.md b/python/README.md
index c750724..eb17ba2 100644
--- a/python/README.md
+++ b/python/README.md
@@ -37,6 +37,7 @@ Initialize the client by providing required configuration options:
 ```typescript
 client = new ReagClient(
   model: "o3-mini", // LiteLLM model name
+  filtration_model: "minimax", // Filtration model name
   system: Optional[str] // Optional system prompt
   batchSize: Optional[Number] // Optional batch size
   schema: Optional[BaseModel] // Optional Pydantic schema
diff --git a/python/src/reag/client.py b/python/src/reag/client.py
index 7d174b6..3b8e82e 100644
--- a/python/src/reag/client.py
+++ b/python/src/reag/client.py
@@ -39,11 +39,13 @@ class ReagClient:
     def __init__(
         self,
         model: str = "gpt-4o-mini",
+        filtration_model: str = "minimax",
         system: str = None,
         batch_size: int = DEFAULT_BATCH_SIZE,
         schema: Optional[BaseModel] = None,
     ):
         self.model = model
+        self.filtration_model = filtration_model
         self.system = system or REAG_SYSTEM_PROMPT
         self.batch_size = batch_size
         self.schema = schema or ResponseSchemaMessage
@@ -161,9 +163,9 @@ def format_doc(doc: Document) -> str:
                     f"{self.system}\n\n# Available source\n\n{format_doc(document)}"
                 )
 
-                # Use litellm for model completion with the Pydantic schema
-                response = await acompletion(
-                    model=self.model,
+                # Use litellm for model completion with the filtration model
+                filtration_response = await acompletion(
+                    model=self.filtration_model,
                     messages=[
                         {"role": "system", "content": system},
                         {"role": "user", "content": prompt},
@@ -171,36 +173,56 @@ def format_doc(doc: Document) -> str:
                     response_format=self.schema,
                 )
 
-                message_content = response.choices[
+                filtration_message_content = filtration_response.choices[
                     0
                 ].message.content  # Might be a JSON string
 
                 try:
                     # Ensure it's parsed as a dict
-                    data = (
-                        json.loads(message_content)
-                        if isinstance(message_content, str)
-                        else message_content
+                    filtration_data = (
+                        json.loads(filtration_message_content)
+                        if isinstance(filtration_message_content, str)
+                        else filtration_message_content
                     )
 
-                    if data["source"].get("is_irrelevant", True):
+                    if filtration_data["source"].get("is_irrelevant", True):
                         continue
 
+                    # Use litellm for model completion with the reasoning model
+                    reasoning_response = await acompletion(
+                        model=self.model,
+                        messages=[
+                            {"role": "system", "content": system},
+                            {"role": "user", "content": prompt},
+                        ],
+                        response_format=self.schema,
+                    )
+
+                    reasoning_message_content = reasoning_response.choices[
+                        0
+                    ].message.content  # Might be a JSON string
+
+                    reasoning_data = (
+                        json.loads(reasoning_message_content)
+                        if isinstance(reasoning_message_content, str)
+                        else reasoning_message_content
+                    )
+
                     results.append(
                         QueryResult(
-                            content=data["source"].get("content", ""),
-                            reasoning=data["source"].get("reasoning", ""),
-                            is_irrelevant=data["source"].get(
+                            content=reasoning_data["source"].get("content", ""),
+                            reasoning=reasoning_data["source"].get("reasoning", ""),
+                            is_irrelevant=reasoning_data["source"].get(
                                 "is_irrelevant", False
                             ),
                             document=document,
                         )
                     )
                 except json.JSONDecodeError:
-                    print("Error: Could not parse response:", message_content)
+                    print("Error: Could not parse response:", filtration_message_content)
                     continue  # Skip this iteration if parsing fails
 
-            return results
+        return results
 
         except Exception as e:
             raise Exception(f"Query failed: {str(e)}")
diff --git a/typescript/README.md b/typescript/README.md
index 955f916..f8f8792 100644
--- a/typescript/README.md
+++ b/typescript/README.md
@@ -21,6 +21,7 @@ import { openai } from "@ai-sdk/openai";
 
 // Initialize the SDK with required options
 const client = new ReagClient({
   model: openai("o3-mini", { structuredOutputs: true }),
+  filtrationModel: openai("minimax", { structuredOutputs: true }),
   // system: optional system prompt here or use the default
 });
@@ -62,6 +63,7 @@ Initialize the client by providing required configuration options:
 ```typescript
 const client = new ReagClient({
   model: openai("o3-mini", { structuredOutputs: true }),
+  filtrationModel: openai("minimax", { structuredOutputs: true }),
   system?: string // Optional system prompt
   batchSize?: number // Optional batch size
   schema?: z.ZodSchema // Optional schema
diff --git a/typescript/src/client.ts b/typescript/src/client.ts
index 011ac1c..20def40 100644
--- a/typescript/src/client.ts
+++ b/typescript/src/client.ts
@@ -16,6 +16,11 @@ export interface ClientOptions {
    * See: https://sdk.vercel.ai/docs/foundations/providers-and-models
    */
   model: LanguageModel;
+  /**
+   * The filtration model instance to use for document filtration.
+   * This should be an instance of a model that implements the Vercel AI SDK's LanguageModel interface.
+   */
+  filtrationModel: LanguageModel;
   /**
    * The system prompt that provides context and instructions to the model.
    * This string sets the behavior and capabilities of the model for all queries.
@@ -69,6 +74,7 @@ const DEFAULT_BATCH_SIZE = 20;
  */
 export class ReagClient {
   private readonly model: LanguageModel;
+  private readonly filtrationModel: LanguageModel;
   private readonly system: string;
   private readonly batchSize: number;
   private readonly schema: z.ZodSchema;
@@ -79,6 +85,7 @@ export class ReagClient {
    */
   constructor(options: ClientOptions) {
     this.model = options.model;
+    this.filtrationModel = options.filtrationModel;
     this.system = options.system || REAG_SYSTEM_PROMPT;
     this.batchSize = options.batchSize || DEFAULT_BATCH_SIZE;
     this.schema = options.schema || RESPONSE_SCHEMA;
@@ -171,7 +178,23 @@ export class ReagClient {
           const system = `${
             this.system
           }\n\n# Available source\n\n${formatDoc(document)}`;
-          const response = await generateObject({
+
+          // Use the filtration model for document filtration
+          const filtrationResponse = await generateObject({
+            model: this.filtrationModel,
+            system,
+            prompt,
+            schema: this.schema,
+          });
+
+          const filtrationData = filtrationResponse.object;
+
+          if (filtrationData.isIrrelevant) {
+            return null;
+          }
+
+          // Use the reasoning model for generating the final answer
+          const reasoningResponse = await generateObject({
             model: this.model,
             system,
             prompt,
@@ -179,12 +202,13 @@ export class ReagClient {
           });
 
           return {
-            response,
+            response: reasoningResponse,
             document,
           };
         })
       );
-      return batchResponses;
+
+      return batchResponses.filter((response) => response !== null);
     })
   );
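For reviewers who want to exercise the new split locally, here is a minimal usage sketch of the two-step flow this patch introduces. It relies only on the constructor and `query` signatures visible in the diff above; the import path, model names, document text, and metadata are placeholder assumptions, and a LiteLLM-compatible API key is assumed to be configured in the environment.

```python
import asyncio

# Assumed import path based on the package layout shown in the diff.
from reag.client import ReagClient, Document


async def main() -> None:
    # Placeholder model names; any LiteLLM-supported models can be used.
    client = ReagClient(
        model="gpt-4o-mini",             # reasoning model: writes the final answer
        filtration_model="gpt-4o-mini",  # filtration model: screens each document for relevance
    )

    docs = [
        Document(
            name="Superagent",
            content="Superagent is a workspace for AI agents that learn, perform work, and collaborate.",
            metadata={"url": "https://superagent.sh", "source": "web"},
        ),
    ]

    # Each relevant document yields a QueryResult (content, reasoning,
    # is_irrelevant, document); documents the filtration model marks
    # irrelevant are skipped before the reasoning model is called.
    results = await client.query("What is Superagent?", documents=docs)
    for result in results:
        print(result.reasoning)
        print(result.content)


asyncio.run(main())
```

The point of the split is that the relevance check can run on a smaller, cheaper filtration model while the final answer still comes from the stronger reasoning model.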