diff --git a/libs/superagent/app/agents/base.py b/libs/superagent/app/agents/base.py
index dbf42e06e..552ab281f 100644
--- a/libs/superagent/app/agents/base.py
+++ b/libs/superagent/app/agents/base.py
@@ -1,5 +1,6 @@
-from typing import Any, List
+from typing import Any, List, Optional
 
+from app.models.request import LLMParams
 from app.utils.streaming import CustomAsyncIteratorCallbackHandler
 from prisma.models import Agent, AgentDatasource, AgentLLM, AgentTool
 
@@ -17,7 +18,7 @@ def __init__(
         enable_streaming: bool = False,
         output_schema: str = None,
         callback: CustomAsyncIteratorCallbackHandler = None,
-        llm_params: dict[any, any] = {},
+        llm_params: Optional[LLMParams] = None,
         agent_config: Agent = None,
     ):
         self.agent_id = agent_id
diff --git a/libs/superagent/app/agents/langchain.py b/libs/superagent/app/agents/langchain.py
index b78c5de4b..3a06c9d40 100644
--- a/libs/superagent/app/agents/langchain.py
+++ b/libs/superagent/app/agents/langchain.py
@@ -108,7 +108,7 @@ async def _get_tools(
     async def _get_llm(self, agent_llm: AgentLLM, model: str) -> Any:
         llm_params = {
             "temperature": 0,
-            **self.llm_params,
+            **(self.llm_params.dict() if self.llm_params else {}),
         }
 
         if agent_llm.llm.provider == "OPENAI":
diff --git a/libs/superagent/app/api/agents.py b/libs/superagent/app/api/agents.py
index 32859156d..86b3cbcfc 100644
--- a/libs/superagent/app/api/agents.py
+++ b/libs/superagent/app/api/agents.py
@@ -314,7 +314,7 @@ async def send_message(
         enable_streaming=enable_streaming,
         output_schema=output_schema,
         callback=callback,
-        llm_params=body.llm_params.dict() if body.llm_params else {},
+        llm_params=body.llm_params,
         agent_config=agent_config,
     ).get_agent()
 
diff --git a/libs/superagent/app/workflows/base.py b/libs/superagent/app/workflows/base.py
index b3e51bbbe..bdaeb4976 100644
--- a/libs/superagent/app/workflows/base.py
+++ b/libs/superagent/app/workflows/base.py
@@ -2,6 +2,7 @@
 from typing import Any, List
 
 from app.agents.base import AgentBase
+from app.utils.prisma import prisma
 from app.utils.streaming import CustomAsyncIteratorCallbackHandler
 from prisma.models import Workflow
 
@@ -26,11 +27,22 @@ async def arun(self, input: Any):
         stepIndex = 0
 
         for step in self.workflow.steps:
+            agent_config = await prisma.agent.find_unique_or_raise(
+                where={"id": step.agentId},
+                include={
+                    "llms": {"include": {"llm": True}},
+                    "datasources": {
+                        "include": {"datasource": {"include": {"vectorDb": True}}}
+                    },
+                    "tools": {"include": {"tool": True}},
+                },
+            )
             agent = await AgentBase(
                 agent_id=step.agentId,
                 enable_streaming=True,
                 callback=self.callbacks[stepIndex],
                 session_id=self.session_id,
+                agent_config=agent_config,
             ).get_agent()
 
             task = asyncio.ensure_future(