diff --git a/python/README.md b/python/README.md
index c750724..6dec8fb 100644
--- a/python/README.md
+++ b/python/README.md
@@ -13,8 +13,9 @@
 ```python
 from reag.client import ReagClient, Document
-
-async with ReagClient() as client:
+async with ReagClient(
+    model="ollama/deepseek-r1:14b",
+    api_base="http://localhost:11434") as client:
     docs = [
         Document(
             name="Superagent",
@@ -36,7 +37,7 @@ Initialize the client by providing required configuration options:
 ```typescript
 client = new ReagClient(
-  model: "o3-mini", // LiteLLM model name
+  model: "gpt-4o-mini", // LiteLLM model name
   system: Optional[str] // Optional system prompt
   batchSize: Optional[Number] // Optional batch size
   schema: Optional[BaseModel] // Optional Pydantic schema
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 8cae2f9..4d449f4 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -29,6 +29,7 @@ python = ">=3.9,<3.13"
 pydantic = "^2.0.0"
 httpx = "^0.25.0"
 litellm = "^1.60.0"
+ollama = "0.3.1"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.4.0"
@@ -51,4 +52,4 @@ multi_line_output = 3
 [tool.pytest.ini_options]
 pythonpath = [
   "src"
-]
\ No newline at end of file
+]
diff --git a/python/src/reag/client.py b/python/src/reag/client.py
index 10e6900..5337919 100644
--- a/python/src/reag/client.py
+++ b/python/src/reag/client.py
@@ -1,7 +1,7 @@
 import httpx
 import asyncio
 import json
-
+import re
 from typing import List, Optional, TypeVar, Dict, Union
 from pydantic import BaseModel
 from litellm import acompletion
@@ -42,11 +42,13 @@ def __init__(
         system: str = None,
         batch_size: int = DEFAULT_BATCH_SIZE,
         schema: Optional[BaseModel] = None,
+        api_base: Optional[str] = None,  # Added for Ollama support
     ):
         self.model = model
         self.system = system or REAG_SYSTEM_PROMPT
         self.batch_size = batch_size
         self.schema = schema or ResponseSchemaMessage
+        self.api_base = api_base  # New attribute for API base URL
         self._http_client = None
 
     async def __aenter__(self):
@@ -127,6 +129,31 @@ def _filter_documents_by_metadata(
         return filtered_docs
 
+    def _extract_think_content(self, text: str) -> tuple[str, str, bool]:
+        """Extract content from <think> tags and parse the bulleted response format."""
+        # Extract think content
+        think_match = re.search(r'<think>(.*?)</think>', text, flags=re.DOTALL)
+        reasoning = think_match.group(1).strip() if think_match else ""
+
+        # Remove think tags and get remaining text
+        remaining_text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
+
+        # Initialize default values
+        content = ""
+        is_irrelevant = True
+
+        # Extract is_irrelevant value
+        irrelevant_match = re.search(r'\*\*isIrrelevant:\*\*\s*(true|false)', remaining_text, re.IGNORECASE)
+        if irrelevant_match:
+            is_irrelevant = irrelevant_match.group(1).lower() == 'true'
+
+        # Extract content value
+        content_match = re.search(r'\*\*Answer:\*\*\s*(.*?)(?:\n|$)', remaining_text, re.DOTALL)
+        if content_match:
+            content = content_match.group(1).strip()
+
+        return content, reasoning, is_irrelevant
+
     async def query(
         self, prompt: str, documents: List[Document], options: Optional[Dict] = None
     ) -> List[QueryResult]:
@@ -179,29 +206,40 @@ def format_doc(doc: Document) -> str:
                     message_content = response.choices[0].message.content
 
                     try:
-                        # Ensure it's parsed as a dict
-                        data = (
-                            json.loads(message_content)
-                            if isinstance(message_content, str)
-                            else message_content
-                        )
+                        if self.model.startswith("ollama/"):
+                            content, reasoning, is_irrelevant = self._extract_think_content(message_content)
+                            results.append(
+                                QueryResult(
+                                    content=content,
+                                    reasoning=reasoning,
+                                    is_irrelevant=is_irrelevant,
+                                    document=document,
+                                )
+                            )
+                        else:
+                            # Ensure it's parsed as a dict
+                            data = (
+                                json.loads(message_content)
+                                if isinstance(message_content, str)
+                                else message_content
+                            )
 
-                        if data["source"].get("is_irrelevant", True):
-                            continue
+                            if data["source"].get("is_irrelevant", True):
+                                continue
 
-                        results.append(
-                            QueryResult(
-                                content=data["source"].get("content", ""),
-                                reasoning=data["source"].get("reasoning", ""),
-                                is_irrelevant=data["source"].get("is_irrelevant", False),
-                                document=document,
+                            results.append(
+                                QueryResult(
+                                    content=data["source"].get("content", ""),
+                                    reasoning=data["source"].get("reasoning", ""),
+                                    is_irrelevant=data["source"].get("is_irrelevant", False),
+                                    document=document,
+                                )
                             )
-                        )
                     except json.JSONDecodeError:
                         print("Error: Could not parse response:", message_content)
                         continue
 
-            return results  # Moved outside the batch loop to return all results
+            return results
 
         except Exception as e:
             raise Exception(f"Query failed: {str(e)}")
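
For reference, here is a minimal standalone sketch (not part of the diff) of the parsing that the new `_extract_think_content` helper applies to a deepseek-r1 style completion: reasoning is taken from the `<think>...</think>` block, and the answer from the `**isIrrelevant:**` / `**Answer:**` bullets the regexes expect. The `parse_r1_response` name and the sample response text are illustrative assumptions, not output captured from a real model run.

```python
import re


def parse_r1_response(text: str) -> tuple[str, str, bool]:
    """Mirror of the parsing done by _extract_think_content (illustrative only)."""
    # Reasoning lives inside <think>...</think>; the answer follows as bulleted markdown.
    think_match = re.search(r'<think>(.*?)</think>', text, flags=re.DOTALL)
    reasoning = think_match.group(1).strip() if think_match else ""
    remaining = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()

    content, is_irrelevant = "", True
    irrelevant_match = re.search(r'\*\*isIrrelevant:\*\*\s*(true|false)', remaining, re.IGNORECASE)
    if irrelevant_match:
        is_irrelevant = irrelevant_match.group(1).lower() == 'true'
    content_match = re.search(r'\*\*Answer:\*\*\s*(.*?)(?:\n|$)', remaining, re.DOTALL)
    if content_match:
        content = content_match.group(1).strip()
    return content, reasoning, is_irrelevant


# Hypothetical deepseek-r1 output in the bulleted format the regexes expect.
sample = (
    "<think>The source describes Superagent, which answers the question.</think>\n"
    "- **isIrrelevant:** false\n"
    "- **Answer:** Superagent is an open-source AI assistant framework."
)
print(parse_r1_response(sample))
# ('Superagent is an open-source AI assistant framework.',
#  'The source describes Superagent, which answers the question.', False)
```

Note that this code path assumes the Ollama-hosted model actually emits that bulleted answer format; if the system prompt changes, the regexes in `_extract_think_content` need to change with it.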