diff --git a/libs/superagent/app/agents/langchain.py b/libs/superagent/app/agents/langchain.py index 805c7af62..73db7a71c 100644 --- a/libs/superagent/app/agents/langchain.py +++ b/libs/superagent/app/agents/langchain.py @@ -1,7 +1,6 @@ import datetime import json import re -from typing import Any, List from decouple import config from langchain.agents import AgentType, initialize_agent @@ -63,9 +62,9 @@ def recursive_json_loads(data): class LangchainAgent(AgentBase): async def _get_tools( self, - agent_datasources: List[AgentDatasource], - agent_tools: List[AgentTool], - ) -> List: + agent_datasources: list[AgentDatasource], + agent_tools: list[AgentTool], + ) -> list: tools = [] for agent_datasource in agent_datasources: tool_type = ( @@ -103,10 +102,15 @@ async def _get_tools( for agent_tool in agent_tools: agent_tool_metadata = json.loads(agent_tool.tool.metadata or "{}") - # user id is added to the metadata for superrag tool agent_tool_metadata = { - "user_id": self.agent_config.apiUserId, **agent_tool_metadata, + "params": { + **(agent_tool_metadata.get("params", {}) or {}), + # user id is added to the metadata for superrag tool + "user_id": self.agent_config.apiUserId, + # session id is added to the metadata for agent as tool + "session_id": self.session_id, + }, } tool_info = TOOL_TYPE_MAPPING.get(agent_tool.tool.type) @@ -132,17 +136,12 @@ async def _get_tools( description=agent_tool.tool.description, metadata=metadata, args_schema=tool_info["schema"], - session_id=( - f"{self.agent_id}-{self.session_id}" - if self.session_id - else f"{self.agent_id}" - ), return_direct=agent_tool.tool.returnDirect, ) tools.append(tool) return tools - async def _get_llm(self, llm: LLM, model: str) -> Any: + async def _get_llm(self, llm: LLM, model: str): llm_params = { "temperature": 0, **(self.llm_params.dict() if self.llm_params else {}), @@ -193,16 +192,18 @@ async def _get_prompt(self, agent: Agent) -> str: content = f"{content}" 
f"\n\n{datetime.datetime.now().strftime('%Y-%m-%d')}" return SystemMessage(content=content) - async def _get_memory(self) -> List: + async def _get_memory( + self, + ) -> None | MotorheadMemory | ConversationBufferWindowMemory: + # memory is keyed by session id below, so refuse to build it without one + if not self.session_id: + raise ValueError("Session ID is required to initialize memory") + memory_type = config("MEMORY", "motorhead") if memory_type == "redis": memory = ConversationBufferWindowMemory( chat_memory=RedisChatMessageHistory( - session_id=( - f"{self.agent_id}-{self.session_id}" - if self.session_id - else f"{self.agent_id}" - ), + session_id=self.session_id, url=config("REDIS_MEMORY_URL", "redis://localhost:6379/0"), key_prefix="superagent:", ), @@ -213,11 +214,7 @@ async def _get_memory(self) -> List: ) else: memory = MotorheadMemory( - session_id=( - f"{self.agent_id}-{self.session_id}" - if self.session_id - else f"{self.agent_id}" - ), + session_id=self.session_id, memory_key="chat_history", url=config("MEMORY_API_URL"), return_messages=True, diff --git a/libs/superagent/app/api/agents.py b/libs/superagent/app/api/agents.py index 56e62fd7c..36ead8837 100644 --- a/libs/superagent/app/api/agents.py +++ b/libs/superagent/app/api/agents.py @@ -412,6 +412,10 @@ async def invoke( langfuse_public_key = config("LANGFUSE_PUBLIC_KEY", "") langfuse_host = config("LANGFUSE_HOST", "https://cloud.langfuse.com") langfuse_handler = None + + session_id = body.sessionId or "" + session_id = f"agt_{agent_id}_{session_id}" + if langfuse_public_key and langfuse_secret_key: langfuse = Langfuse( public_key=langfuse_public_key, @@ -419,7 +423,6 @@ async def invoke( host=langfuse_host, sdk_integration="Superagent", ) - session_id = f"{agent_id}-{body.sessionId}" if body.sessionId else f"{agent_id}" trace = langfuse.trace( id=session_id, name="Assistant", @@ -437,7 +440,7 @@ async def invoke( agentops_handler = AsyncLangchainCallbackHandler( api_key=agentops_api_key, 
org_key=agentops_org_key, - tags=[agent_id, str(body.sessionId)], + tags=[agent_id, session_id], ) agent_config = await prisma.agent.find_unique_or_raise( @@ -526,7 +529,7 @@ async def send_message( { "user_id": api_user.id, "agent": agent_config, - "session_id": body.sessionId, + "session_id": session_id, **result, **vars(cost_callback), } @@ -558,12 +561,12 @@ async def send_message( streaming_callback.done.set() logger.info("Invoking agent...") - session_id = body.sessionId input = body.input enable_streaming = body.enableStreaming output_schema = body.outputSchema cost_callback = CostCalcAsyncHandler(model=model) streaming_callback = CustomAsyncIteratorCallbackHandler() + agent_base = AgentBase( agent_id=agent_id, session_id=session_id, @@ -613,7 +616,7 @@ async def send_message( { "user_id": api_user.id, "agent": agent_config, - "session_id": body.sessionId, + "session_id": session_id, **output, **vars(cost_callback), } diff --git a/libs/superagent/app/api/workflows.py b/libs/superagent/app/api/workflows.py index 9e8573a72..88faf6aaf 100644 --- a/libs/superagent/app/api/workflows.py +++ b/libs/superagent/app/api/workflows.py @@ -184,6 +184,8 @@ async def invoke( where={"id": workflow_id}, include={"steps": {"include": {"agent": True}, "order_by": {"order": "asc"}}}, ) + session_id = body.sessionId or "" + session_id = f"wf_{workflow_id}_{session_id}" workflow_steps = [] for workflow_step in workflow_data.steps: @@ -201,7 +203,7 @@ async def invoke( "agent_name": workflow_step.agent.name, } session_tracker_handler = get_session_tracker_handler( - workflow_data.id, workflow_step.agent.id, body.sessionId, api_user.id + workflow_data.id, workflow_step.agent.id, session_id, api_user.id ) if session_tracker_handler: @@ -216,7 +218,6 @@ async def invoke( callbacks.append(v) workflow_callbacks.append(callbacks) - session_id = body.sessionId input = body.input enable_streaming = body.enableStreaming @@ -268,6 +269,7 @@ async def send_message() -> AsyncIterable[str]: 
raise exception workflow_result = task.result() + for index, workflow_step in enumerate(workflow_steps): workflow_step_result = workflow_result.get("steps")[index] diff --git a/libs/superagent/app/tools/__init__.py b/libs/superagent/app/tools/__init__.py index 13d8470b0..66050fa90 100644 --- a/libs/superagent/app/tools/__init__.py +++ b/libs/superagent/app/tools/__init__.py @@ -114,10 +114,7 @@ def create_tool( args_schema: Any, metadata: Optional[Dict[str, Any]], return_direct: Optional[bool], - session_id: str = None, ) -> Any: - if metadata: - metadata["sessionId"] = session_id return tool_class( name=name, description=description, diff --git a/libs/superagent/app/tools/agent.py b/libs/superagent/app/tools/agent.py index 4286d5811..7f37dcff6 100644 --- a/libs/superagent/app/tools/agent.py +++ b/libs/superagent/app/tools/agent.py @@ -12,7 +12,9 @@ class Agent(BaseTool): description = "useful for answering questions." def _run(self, input: str) -> str: - agent_id = self.metadata.get("agentId") + agent_id = self.metadata["agentId"] + params = self.metadata["params"] + session_id = params.get("session_id") agent_config = prisma.agent.find_unique_or_raise( where={"id": agent_id}, @@ -29,6 +31,7 @@ def _run(self, input: str) -> str: agent_id, enable_streaming=False, agent_config=agent_config, + session_id=session_id, ) agent = agent_base.get_agent() @@ -44,7 +47,9 @@ def _run(self, input: str) -> str: return result.get("output") async def _arun(self, input: str) -> str: - agent_id = self.metadata.get("agentId") + agent_id = self.metadata["agentId"] + params = self.metadata["params"] + session_id = params.get("session_id") agent_config = await prisma.agent.find_unique_or_raise( where={"id": agent_id}, @@ -61,6 +66,7 @@ async def _arun(self, input: str) -> str: agent_id, enable_streaming=False, agent_config=agent_config, + session_id=session_id, ) agent = await agent_base.get_agent() diff --git a/libs/superagent/app/tools/superrag.py 
b/libs/superagent/app/tools/superrag.py index b0cadf0e1..d312b3560 100644 --- a/libs/superagent/app/tools/superrag.py +++ b/libs/superagent/app/tools/superrag.py @@ -28,7 +28,9 @@ def _run( question: str, ) -> str: """Use the tool.""" - pass + raise NotImplementedError( + "Sync run not implemented for SuperRag tool. Use async run." + ) async def _arun( self, @@ -40,7 +42,8 @@ async def _arun( vector_database = self.metadata.get("vector_database") interpreter_mode = self.metadata.get("interpreter_mode") - api_user_id = self.metadata.get("user_id") + params = self.metadata["params"] + user_id = params["user_id"] # with lower case e.g. pinecone, qdrant database_provider = vector_database.get("type").lower() @@ -48,7 +51,7 @@ async def _arun( provider = await prisma.vectordb.find_first( where={ "provider": VECTOR_DB_MAPPING.get(database_provider), - "apiUserId": api_user_id, + "apiUserId": user_id, } ) diff --git a/libs/superagent/app/utils/callbacks.py b/libs/superagent/app/utils/callbacks.py index 46ec6d43e..64dfe5d7a 100644 --- a/libs/superagent/app/utils/callbacks.py +++ b/libs/superagent/app/utils/callbacks.py @@ -136,7 +136,7 @@ def _calculate_cost_per_token( def get_session_tracker_handler( workflow_id, agent_id, - req_session_id, + session_id, user_id, ): langfuse_secret_key = config("LANGFUSE_SECRET_KEY", "") @@ -150,11 +150,8 @@ def get_session_tracker_handler( host=langfuse_host, sdk_integration="Superagent", ) - trace_id = ( - f"{workflow_id}-{req_session_id}" if req_session_id else f"{workflow_id}" - ) trace = langfuse.trace( - id=trace_id, + id=session_id, name="Workflow", tags=[agent_id], metadata={"agentId": agent_id, "workflowId": workflow_id},