diff --git a/python/src/reag/client.py b/python/src/reag/client.py index 7d174b6..10e6900 100644 --- a/python/src/reag/client.py +++ b/python/src/reag/client.py @@ -156,24 +156,27 @@ def format_doc(doc: Document) -> str: results = [] for batch in batches: + tasks = [] + # Create tasks for parallel processing within the batch for document in batch: - system = ( - f"{self.system}\n\n# Available source\n\n{format_doc(document)}" + system = f"{self.system}\n\n# Available source\n\n{format_doc(document)}" + tasks.append( + acompletion( + model=self.model, + messages=[ + {"role": "system", "content": system}, + {"role": "user", "content": prompt}, + ], + response_format=self.schema, + ) ) - # Use litellm for model completion with the Pydantic schema - response = await acompletion( - model=self.model, - messages=[ - {"role": "system", "content": system}, - {"role": "user", "content": prompt}, - ], - response_format=self.schema, - ) + # Process all documents in the batch concurrently + batch_responses = await asyncio.gather(*tasks) - message_content = response.choices[ - 0 - ].message.content # Might be a JSON string + # Process the responses + for document, response in zip(batch, batch_responses): + message_content = response.choices[0].message.content try: # Ensure it's parsed as a dict @@ -190,17 +193,15 @@ def format_doc(doc: Document) -> str: QueryResult( content=data["source"].get("content", ""), reasoning=data["source"].get("reasoning", ""), - is_irrelevant=data["source"].get( - "is_irrelevant", False - ), + is_irrelevant=data["source"].get("is_irrelevant", False), document=document, ) ) except json.JSONDecodeError: print("Error: Could not parse response:", message_content) - continue # Skip this iteration if parsing fails + continue - return results + return results # Moved outside the batch loop to return all results except Exception as e: raise Exception(f"Query failed: {str(e)}") diff --git a/python/tests/test_client.py b/python/tests/test_client.py index 10b57be..de91999 100644 --- a/python/tests/test_client.py +++ b/python/tests/test_client.py @@ -1,5 +1,5 @@ import pytest -from src.reag.client import ReagClient, Document, QueryResult +from reag.client import ReagClient, Document, QueryResult @pytest.mark.asyncio