diff --git a/python/README.md b/python/README.md
index e8e09da..45d940c 100644
--- a/python/README.md
+++ b/python/README.md
@@ -38,7 +38,8 @@ Initialize the client by providing required configuration options:
 
 ```typescript
 client = new ReagClient(
-  model: "gpt-4o-mini", // LiteLLM model name
+  model: "gpt-4.1-nano", // LiteLLM model name
+  filtration_model: "minimax", // Filtration model name
   system: Optional[str] // Optional system prompt
   batchSize: Optional[Number] // Optional batch size
   schema: Optional[BaseModel] // Optional Pydantic schema
diff --git a/python/src/reag/client.py b/python/src/reag/client.py
index b7d4a3e..faf7840 100644
--- a/python/src/reag/client.py
+++ b/python/src/reag/client.py
@@ -39,12 +39,14 @@ class ReagClient:
     def __init__(
         self,
         model: str = "gpt-4o-mini",
+        filtration_model: str = "minimax",
         system: str = None,
         batch_size: int = DEFAULT_BATCH_SIZE,
         schema: Optional[BaseModel] = None,
         model_kwargs: Optional[Dict] = None,
     ):
         self.model = model
+        self.filtration_model = filtration_model
         self.system = system or REAG_SYSTEM_PROMPT
         self.batch_size = batch_size
         self.schema = schema or ResponseSchemaMessage
@@ -199,46 +201,64 @@ def format_doc(doc: Document) -> str:
                     )
                 )
 
-            # Process all documents in the batch concurrently
-            batch_responses = await asyncio.gather(*tasks)
+                # Use litellm for model completion with the filtration model
+                filtration_response = await acompletion(
+                    model=self.filtration_model,
+                    messages=[
+                        {"role": "system", "content": system},
+                        {"role": "user", "content": prompt},
+                    ],
+                    response_format=self.schema,
+                )
 
-            # Process the responses
-            for document, response in zip(batch, batch_responses):
-                message_content = response.choices[0].message.content
+                filtration_message_content = filtration_response.choices[
+                    0
+                ].message.content  # Might be a JSON string
 
                 try:
-                    if self.model.startswith("ollama/"):
-                        content, reasoning, is_irrelevant = self._extract_think_content(message_content)
-                        results.append(
-                            QueryResult(
-                                content=content,
-                                reasoning=reasoning,
-                                is_irrelevant=is_irrelevant,
-                                document=document,
-                            )
-                        )
-                    else:
-                        # Ensure it's parsed as a dict
-                        data = (
-                            json.loads(message_content)
-                            if isinstance(message_content, str)
-                            else message_content
-                        )
+                    # Ensure it's parsed as a dict
+                    filtration_data = (
+                        json.loads(filtration_message_content)
+                        if isinstance(filtration_message_content, str)
+                        else filtration_message_content
+                    )
+
+                    if filtration_data["source"].get("is_irrelevant", True):
+                        continue
+
+                    # Use litellm for model completion with the reasoning model
+                    reasoning_response = await acompletion(
+                        model=self.model,
+                        messages=[
+                            {"role": "system", "content": system},
+                            {"role": "user", "content": prompt},
+                        ],
+                        response_format=self.schema,
+                    )
+
+                    reasoning_message_content = reasoning_response.choices[
+                        0
+                    ].message.content  # Might be a JSON string
 
-                        if data["source"].get("is_irrelevant", True):
-                            continue
+                    reasoning_data = (
+                        json.loads(reasoning_message_content)
+                        if isinstance(reasoning_message_content, str)
+                        else reasoning_message_content
+                    )
 
-                        results.append(
-                            QueryResult(
-                                content=data["source"].get("content", ""),
-                                reasoning=data["source"].get("reasoning", ""),
-                                is_irrelevant=data["source"].get("is_irrelevant", False),
-                                document=document,
-                            )
+                    results.append(
+                        QueryResult(
+                            content=reasoning_data["source"].get("content", ""),
+                            reasoning=reasoning_data["source"].get("reasoning", ""),
+                            is_irrelevant=reasoning_data["source"].get(
+                                "is_irrelevant", False
+                            ),
+                            document=document,
                         )
+                    )
                 except json.JSONDecodeError:
-                    print("Error: Could not parse response:", message_content)
-                    continue
+                    print("Error: Could not parse response:", filtration_message_content)
+                    continue  # Skip this iteration if parsing fails
+
         return results
diff --git a/typescript/README.md b/typescript/README.md
index 955f916..f8f8792 100644
--- a/typescript/README.md
+++ b/typescript/README.md
@@ -21,6 +21,7 @@ import { openai } from "@ai-sdk/openai";
 
 // Initialize the SDK with required options
 const client = new ReagClient({
   model: openai("o3-mini", { structuredOutputs: true }),
+  filtrationModel: openai("minimax", { structuredOutputs: true }),
   // system: optional system prompt here or use the default
 });
@@ -62,6 +63,7 @@ Initialize the client by providing required configuration options:
 ```typescript
 const client = new ReagClient({
   model: openai("o3-mini", { structuredOutputs: true }),
+  filtrationModel: openai("minimax", { structuredOutputs: true }),
   system?: string // Optional system prompt
   batchSize?: number // Optional batch size
   schema?: z.ZodSchema // Optional schema
diff --git a/typescript/src/client.ts b/typescript/src/client.ts
index 011ac1c..20def40 100644
--- a/typescript/src/client.ts
+++ b/typescript/src/client.ts
@@ -16,6 +16,11 @@ export interface ClientOptions {
    * See: https://sdk.vercel.ai/docs/foundations/providers-and-models
    */
   model: LanguageModel;
+  /**
+   * The filtration model instance to use for document filtration.
+   * This should be an instance of a model that implements the Vercel AI SDK's LanguageModel interface.
+   */
+  filtrationModel: LanguageModel;
   /**
    * The system prompt that provides context and instructions to the model.
    * This string sets the behavior and capabilities of the model for all queries.
@@ -69,6 +74,7 @@ const DEFAULT_BATCH_SIZE = 20;
  */
 export class ReagClient {
   private readonly model: LanguageModel;
+  private readonly filtrationModel: LanguageModel;
   private readonly system: string;
   private readonly batchSize: number;
   private readonly schema: z.ZodSchema;
@@ -79,6 +85,7 @@ export class ReagClient {
    */
   constructor(options: ClientOptions) {
     this.model = options.model;
+    this.filtrationModel = options.filtrationModel;
     this.system = options.system || REAG_SYSTEM_PROMPT;
     this.batchSize = options.batchSize || DEFAULT_BATCH_SIZE;
     this.schema = options.schema || RESPONSE_SCHEMA;
@@ -171,7 +178,23 @@ export class ReagClient {
          const system = `${
            this.system
          }\n\n# Available source\n\n${formatDoc(document)}`;
-          const response = await generateObject({
+
+          // Use the filtration model for document filtration
+          const filtrationResponse = await generateObject({
+            model: this.filtrationModel,
+            system,
+            prompt,
+            schema: this.schema,
+          });
+
+          const filtrationData = filtrationResponse.object;
+
+          if (filtrationData.isIrrelevant) {
+            return null;
+          }
+
+          // Use the reasoning model for generating the final answer
+          const reasoningResponse = await generateObject({
            model: this.model,
            system,
            prompt,
@@ -179,12 +202,13 @@
          });
 
          return {
-            response,
+            response: reasoningResponse,
            document,
          };
        })
      );
-      return batchResponses;
+
+      return batchResponses.filter((response) => response !== null);
    })
  );
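
The patch above splits each document check into two passes: the `filtration_model` (or `filtrationModel` in TypeScript) is called first, its `is_irrelevant` flag decides whether the document is dropped, and only surviving documents are sent to the reasoning `model` for the full answer. Below is a minimal usage sketch of the new Python constructor parameter, assuming the existing `ReagClient` import path from this repository; the model names are illustrative placeholders, not recommendations.

```python
from reag.client import ReagClient

# Sketch only: "model" handles reasoning over relevant documents,
# "filtration_model" performs the cheaper relevance pre-check.
client = ReagClient(
    model="gpt-4o-mini",          # reasoning model (LiteLLM model name)
    filtration_model="minimax",   # filtration model (LiteLLM model name)
)
```

On the TypeScript side the same split applies: documents the filtration model flags as irrelevant cause the mapper to return `null`, and those entries are removed from the batch results by the final `filter`.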