这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 20 additions & 23 deletions libs/superagent/app/agents/langchain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import datetime
import json
import re
from typing import Any, List

from decouple import config
from langchain.agents import AgentType, initialize_agent
Expand Down Expand Up @@ -63,9 +62,9 @@ def recursive_json_loads(data):
class LangchainAgent(AgentBase):
async def _get_tools(
self,
agent_datasources: List[AgentDatasource],
agent_tools: List[AgentTool],
) -> List:
agent_datasources: list[AgentDatasource],
agent_tools: list[AgentTool],
) -> list:
tools = []
for agent_datasource in agent_datasources:
tool_type = (
Expand Down Expand Up @@ -103,10 +102,15 @@ async def _get_tools(
for agent_tool in agent_tools:
agent_tool_metadata = json.loads(agent_tool.tool.metadata or "{}")

# user id is added to the metadata for superrag tool
agent_tool_metadata = {
"user_id": self.agent_config.apiUserId,
**agent_tool_metadata,
"params": {
**(agent_tool_metadata.get("params", {}) or {}),
# user id is added to the metadata for superrag tool
"user_id": self.agent_config.apiUserId,
# session id is added to the metadata for agent as tool
"session_id": self.session_id,
},
}

tool_info = TOOL_TYPE_MAPPING.get(agent_tool.tool.type)
Expand All @@ -132,17 +136,12 @@ async def _get_tools(
description=agent_tool.tool.description,
metadata=metadata,
args_schema=tool_info["schema"],
session_id=(
f"{self.agent_id}-{self.session_id}"
if self.session_id
else f"{self.agent_id}"
),
return_direct=agent_tool.tool.returnDirect,
)
tools.append(tool)
return tools

async def _get_llm(self, llm: LLM, model: str) -> Any:
async def _get_llm(self, llm: LLM, model: str):
llm_params = {
"temperature": 0,
**(self.llm_params.dict() if self.llm_params else {}),
Expand Down Expand Up @@ -193,16 +192,18 @@ async def _get_prompt(self, agent: Agent) -> str:
content = f"{content}" f"\n\n{datetime.datetime.now().strftime('%Y-%m-%d')}"
return SystemMessage(content=content)

async def _get_memory(self) -> List:
async def _get_memory(
self,
) -> None | MotorheadMemory | ConversationBufferWindowMemory:
# if memory is already set, in the main agent base class, return it
if not self.session_id:
raise ValueError("Session ID is required to initialize memory")

memory_type = config("MEMORY", "motorhead")
if memory_type == "redis":
memory = ConversationBufferWindowMemory(
chat_memory=RedisChatMessageHistory(
session_id=(
f"{self.agent_id}-{self.session_id}"
if self.session_id
else f"{self.agent_id}"
),
session_id=self.session_id,
url=config("REDIS_MEMORY_URL", "redis://localhost:6379/0"),
key_prefix="superagent:",
),
Expand All @@ -213,11 +214,7 @@ async def _get_memory(self) -> List:
)
else:
memory = MotorheadMemory(
session_id=(
f"{self.agent_id}-{self.session_id}"
if self.session_id
else f"{self.agent_id}"
),
session_id=self.session_id,
memory_key="chat_history",
url=config("MEMORY_API_URL"),
return_messages=True,
Expand Down
13 changes: 8 additions & 5 deletions libs/superagent/app/api/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,14 +412,17 @@ async def invoke(
langfuse_public_key = config("LANGFUSE_PUBLIC_KEY", "")
langfuse_host = config("LANGFUSE_HOST", "https://cloud.langfuse.com")
langfuse_handler = None

session_id = body.sessionId or ""
session_id = f"agt_{agent_id}_{session_id}"

if langfuse_public_key and langfuse_secret_key:
langfuse = Langfuse(
public_key=langfuse_public_key,
secret_key=langfuse_secret_key,
host=langfuse_host,
sdk_integration="Superagent",
)
session_id = f"{agent_id}-{body.sessionId}" if body.sessionId else f"{agent_id}"
trace = langfuse.trace(
id=session_id,
name="Assistant",
Expand All @@ -437,7 +440,7 @@ async def invoke(
agentops_handler = AsyncLangchainCallbackHandler(
api_key=agentops_api_key,
org_key=agentops_org_key,
tags=[agent_id, str(body.sessionId)],
tags=[agent_id, session_id],
)

agent_config = await prisma.agent.find_unique_or_raise(
Expand Down Expand Up @@ -526,7 +529,7 @@ async def send_message(
{
"user_id": api_user.id,
"agent": agent_config,
"session_id": body.sessionId,
"session_id": session_id,
**result,
**vars(cost_callback),
}
Expand Down Expand Up @@ -558,12 +561,12 @@ async def send_message(
streaming_callback.done.set()

logger.info("Invoking agent...")
session_id = body.sessionId
input = body.input
enable_streaming = body.enableStreaming
output_schema = body.outputSchema
cost_callback = CostCalcAsyncHandler(model=model)
streaming_callback = CustomAsyncIteratorCallbackHandler()

agent_base = AgentBase(
agent_id=agent_id,
session_id=session_id,
Expand Down Expand Up @@ -613,7 +616,7 @@ async def send_message(
{
"user_id": api_user.id,
"agent": agent_config,
"session_id": body.sessionId,
"session_id": session_id,
**output,
**vars(cost_callback),
}
Expand Down
6 changes: 4 additions & 2 deletions libs/superagent/app/api/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ async def invoke(
where={"id": workflow_id},
include={"steps": {"include": {"agent": True}, "order_by": {"order": "asc"}}},
)
session_id = body.sessionId or ""
session_id = f"wf_{workflow_id}_{session_id}"

workflow_steps = []
for workflow_step in workflow_data.steps:
Expand All @@ -201,7 +203,7 @@ async def invoke(
"agent_name": workflow_step.agent.name,
}
session_tracker_handler = get_session_tracker_handler(
workflow_data.id, workflow_step.agent.id, body.sessionId, api_user.id
workflow_data.id, workflow_step.agent.id, session_id, api_user.id
)

if session_tracker_handler:
Expand All @@ -216,7 +218,6 @@ async def invoke(
callbacks.append(v)
workflow_callbacks.append(callbacks)

session_id = body.sessionId
input = body.input
enable_streaming = body.enableStreaming

Expand Down Expand Up @@ -268,6 +269,7 @@ async def send_message() -> AsyncIterable[str]:
raise exception

workflow_result = task.result()

for index, workflow_step in enumerate(workflow_steps):
workflow_step_result = workflow_result.get("steps")[index]

Expand Down
3 changes: 0 additions & 3 deletions libs/superagent/app/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,7 @@ def create_tool(
args_schema: Any,
metadata: Optional[Dict[str, Any]],
return_direct: Optional[bool],
session_id: str = None,
) -> Any:
if metadata:
metadata["sessionId"] = session_id
return tool_class(
name=name,
description=description,
Expand Down
10 changes: 8 additions & 2 deletions libs/superagent/app/tools/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ class Agent(BaseTool):
description = "useful for answering questions."

def _run(self, input: str) -> str:
agent_id = self.metadata.get("agentId")
agent_id = self.metadata["agentId"]
params = self.metadata["params"]
session_id = params.get("session_id")

agent_config = prisma.agent.find_unique_or_raise(
where={"id": agent_id},
Expand All @@ -29,6 +31,7 @@ def _run(self, input: str) -> str:
agent_id,
enable_streaming=False,
agent_config=agent_config,
session_id=session_id,
)

agent = agent_base.get_agent()
Expand All @@ -44,7 +47,9 @@ def _run(self, input: str) -> str:
return result.get("output")

async def _arun(self, input: str) -> str:
agent_id = self.metadata.get("agentId")
agent_id = self.metadata["agentId"]
params = self.metadata["params"]
session_id = params.get("session_id")

agent_config = await prisma.agent.find_unique_or_raise(
where={"id": agent_id},
Expand All @@ -61,6 +66,7 @@ async def _arun(self, input: str) -> str:
agent_id,
enable_streaming=False,
agent_config=agent_config,
session_id=session_id,
)

agent = await agent_base.get_agent()
Expand Down
9 changes: 6 additions & 3 deletions libs/superagent/app/tools/superrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def _run(
question: str,
) -> str:
"""Use the tool."""
pass
raise NotImplementedError(
"Sync run not implemented for SuperRag tool. Use async run."
)

async def _arun(
self,
Expand All @@ -40,15 +42,16 @@ async def _arun(
vector_database = self.metadata.get("vector_database")
interpreter_mode = self.metadata.get("interpreter_mode")

api_user_id = self.metadata.get("user_id")
params = self.metadata.get("params")
user_id = params.get("user_id")

# with lower case e.g. pinecone, qdrant
database_provider = vector_database.get("type").lower()

provider = await prisma.vectordb.find_first(
where={
"provider": VECTOR_DB_MAPPING.get(database_provider),
"apiUserId": api_user_id,
"apiUserId": user_id,
}
)

Expand Down
7 changes: 2 additions & 5 deletions libs/superagent/app/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _calculate_cost_per_token(
def get_session_tracker_handler(
workflow_id,
agent_id,
req_session_id,
session_id,
user_id,
):
langfuse_secret_key = config("LANGFUSE_SECRET_KEY", "")
Expand All @@ -150,11 +150,8 @@ def get_session_tracker_handler(
host=langfuse_host,
sdk_integration="Superagent",
)
trace_id = (
f"{workflow_id}-{req_session_id}" if req_session_id else f"{workflow_id}"
)
trace = langfuse.trace(
id=trace_id,
id=session_id,
name="Workflow",
tags=[agent_id],
metadata={"agentId": agent_id, "workflowId": workflow_id},
Expand Down