Merged
31 commits
170b98f
chore: add common get_tools function
elisalimli Apr 18, 2024
45de98b
chore: improve logging in send_message
elisalimli Apr 18, 2024
3225575
refactor: agent classes
elisalimli Apr 18, 2024
9b25c8d
chore: remove console.log in chat.tsx
elisalimli Apr 18, 2024
3ee2776
chore: fix logging
elisalimli Apr 18, 2024
ed8c502
add 'MISTRAL' to LLMProvider enum
elisalimli Apr 18, 2024
1417896
add Mistral to SAML
elisalimli Apr 18, 2024
ff85920
feat(ui): add Mistral integration support
elisalimli Apr 18, 2024
114fbd5
feat(db): add database schemas
elisalimli Apr 18, 2024
d4337d8
feat: add groq to SAML
elisalimli Apr 18, 2024
bc183b5
feat(ui): groq integration
elisalimli Apr 18, 2024
7948610
Merge branch 'feat/mistral' into feat/native-function-calling
elisalimli Apr 19, 2024
f22d816
Merge branch 'feat/groq' into feat/native-function-calling
elisalimli Apr 19, 2024
97e0958
feat(db): add cohere migration
elisalimli Apr 20, 2024
28f9451
feat(saml): add cohere to SAML
elisalimli Apr 20, 2024
c357e75
feat(ui): add cohere integration
elisalimli Apr 20, 2024
183d6da
feat(ui): add cohere integration
elisalimli Apr 20, 2024
bcb1a05
Merge branch 'feat/cohere' into feat/native-function-calling
elisalimli Apr 20, 2024
f99d835
fix: indentation issue in CustomAsyncIteratorCallbackHandler
elisalimli Apr 22, 2024
95206ac
refactor: remove hyphen from function name regex
elisalimli Apr 22, 2024
13d26c9
feat: add native function calling
elisalimli Apr 22, 2024
430dddb
refactor: update SAML configuration to use 'browser tool' instead of …
elisalimli Apr 22, 2024
a7c60b4
deps: upgrade litellm from version 1.35.8 to 1.35.21
elisalimli Apr 24, 2024
fcd791d
fix: handle claude 3 haiku output
elisalimli Apr 24, 2024
bf6cd15
fix: passing tool error to LLM instead of moving on
elisalimli Apr 24, 2024
3dec179
refactor: LLMAgent's get_agent method to use native function calling …
elisalimli Apr 29, 2024
a0f4560
feat: add return_direct support in LLMAgent
elisalimli Apr 29, 2024
50b96de
Merge pull request #973 from superagent-ai/feat/native-function-calling
elisalimli Apr 29, 2024
325b00a
refactor
elisalimli Apr 29, 2024
d19296c
refactor LLMAgent
elisalimli Apr 29, 2024
cb09d29
Merge branch 'main' into refactor/agent-classes
elisalimli Apr 29, 2024
libs/superagent/app/agents/base.py (179 changes: 125 additions & 54 deletions)
@@ -1,106 +1,177 @@
+from abc import ABC, abstractmethod
 from typing import Any, List, Optional

-from app.models.request import LLMParams
+from langchain.agents import AgentExecutor
+from pydantic import BaseModel
+
+from app.models.request import LLMParams as LLMParamsRequest
 from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
-from prisma.enums import AgentType
-from prisma.models import Agent
+from prisma.enums import AgentType, LLMProvider
+from prisma.models import LLM, Agent

-DEFAULT_PROMPT = (
-    "You are a helpful AI Assistant, answer the users questions to "
-    "the best of your ability."
-)

+class LLMParams(BaseModel):
+    temperature: Optional[float] = 0.1
+    max_tokens: Optional[int]
+    aws_access_key_id: Optional[str] = None
+    aws_secret_access_key: Optional[str] = None
+    aws_region_name: Optional[str] = None
+
+
+class LLMData(BaseModel):
+    llm: LLM
+    params: LLMParams
+    model: str
+

-class AgentBase:
+class AgentBase(ABC):
+    _input: str
+    _messages: list = []
+    prompt: Any
+    tools: Any
+    session_id: str
+    enable_streaming: bool
+    output_schema: str
+    callbacks: List[CustomAsyncIteratorCallbackHandler]
+    agent_data: Agent
+    llm_data: LLMData
+
     def __init__(
         self,
-        agent_id: str,
-        session_id: str = None,
+        session_id: str,
         enable_streaming: bool = False,
         output_schema: str = None,
         callbacks: List[CustomAsyncIteratorCallbackHandler] = [],
-        llm_params: Optional[LLMParams] = {},
-        agent_config: Agent = None,
+        llm_data: LLMData = None,
+        agent_data: Agent = None,
     ):
-        self.agent_id = agent_id
         self.session_id = session_id
         self.enable_streaming = enable_streaming
         self.output_schema = output_schema
         self.callbacks = callbacks
-        self.llm_params = llm_params
-        self.agent_config = agent_config
+        self.llm_data = llm_data
+        self.agent_data = agent_data

-    async def _get_tools(
-        self,
-    ) -> List:
-        raise NotImplementedError
+    @property
+    def input(self):
+        return self._input

-    async def _get_llm(
-        self,
-    ) -> Any:
-        raise NotImplementedError
+    @input.setter
+    def input(self, value: str):
+        self._input = value

-    async def _get_prompt(
+    @property
+    def messages(self):
+        return self._messages
+
+    @messages.setter
+    def messages(self, value: list):
+        self._messages = value
+
+    @property
+    @abstractmethod
+    def prompt(self) -> Any:
+        ...
+
+    @property
+    @abstractmethod
+    def tools(self) -> Any:
+        ...
+
+    @abstractmethod
+    def get_agent(self) -> AgentExecutor:
+        ...
+
+
+class AgentFactory:
+    def __init__(
         self,
-    ) -> str:
-        raise NotImplementedError
+        session_id: str = None,
+        enable_streaming: bool = False,
+        output_schema: str = None,
+        callbacks: List[CustomAsyncIteratorCallbackHandler] = [],
+        llm_params: Optional[LLMParamsRequest] = {},
+        agent_data: Agent = None,
+    ):
+        self.session_id = session_id
+        self.enable_streaming = enable_streaming
+        self.output_schema = output_schema
+        self.callbacks = callbacks
+        self.api_llm_params = llm_params
+        self.agent_data = agent_data
+
+    @property
+    def llm_data(self):
+        llm = self.agent_data.llms[0].llm
+        params = self.api_llm_params.dict() if self.api_llm_params else {}
+
+        options = {
+            **(self.agent_data.metadata or {}),
+            **(llm.options or {}),
+            **(params),
+        }

-    async def _get_memory(self) -> List:
-        raise NotImplementedError
+        params = LLMParams(
+            temperature=options.get("temperature"),
+            max_tokens=options.get("max_tokens"),
+            aws_access_key_id=(
+                options.get("aws_access_key_id")
+                if llm.provider == LLMProvider.BEDROCK
+                else None
+            ),
+            aws_secret_access_key=(
+                options.get("aws_secret_access_key")
+                if llm.provider == LLMProvider.BEDROCK
+                else None
+            ),
+            aws_region_name=(
+                options.get("aws_region_name")
+                if llm.provider == LLMProvider.BEDROCK
+                else None
+            ),
+        )
+
+        return LLMData(
+            llm=llm,
+            params=LLMParams.parse_obj(options),
+            model=self.agent_data.llmModel or self.agent_data.metadata.get("model"),
+        )

     async def get_agent(self):
-        if self.agent_config.type == AgentType.OPENAI_ASSISTANT:
+        if self.agent_data.type == AgentType.OPENAI_ASSISTANT:
             from app.agents.openai import OpenAiAssistant

             agent = OpenAiAssistant(
-                agent_id=self.agent_id,
                 session_id=self.session_id,
                 enable_streaming=self.enable_streaming,
                 output_schema=self.output_schema,
                 callbacks=self.callbacks,
-                llm_params=self.llm_params,
-                agent_config=self.agent_config,
+                llm_data=self.llm_data,
+                agent_data=self.agent_data,
             )

-        elif self.agent_config.type == AgentType.LLM:
+        elif self.agent_data.type == AgentType.LLM:
             from app.agents.llm import LLMAgent

             agent = LLMAgent(
-                agent_id=self.agent_id,
                 session_id=self.session_id,
                 enable_streaming=self.enable_streaming,
                 output_schema=self.output_schema,
                 callbacks=self.callbacks,
-                llm_params=self.llm_params,
-                agent_config=self.agent_config,
+                llm_data=self.llm_data,
+                agent_data=self.agent_data,
             )

         else:
             from app.agents.langchain import LangchainAgent

             agent = LangchainAgent(
-                agent_id=self.agent_id,
                 session_id=self.session_id,
                 enable_streaming=self.enable_streaming,
                 output_schema=self.output_schema,
                 callbacks=self.callbacks,
-                llm_params=self.llm_params,
-                agent_config=self.agent_config,
+                llm_data=self.llm_data,
+                agent_data=self.agent_data,
             )

         return await agent.get_agent()
-
-    def get_input(self, input: str, agent_type: AgentType):
-        agent_input = {
-            "input": input,
-        }
-
-        if agent_type == AgentType.OPENAI_ASSISTANT:
-            agent_input = {
-                "content": input,
-            }
-
-        if agent_type == AgentType.LLM:
-            agent_input = input
-
-        return agent_input
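
Below is a minimal usage sketch, not part of the PR, of how caller code might drive the refactored AgentFactory. The run_agent helper, the Prisma client wiring, and the ainvoke call are illustrative assumptions; only the AgentFactory keyword arguments, the agent_data.llms[0].llm shape, and the type-based dispatch come from the diff above.

# Hypothetical caller code; a sketch, assuming Prisma Client Python and a
# LangChain-style executor. Names not in the diff are made up for illustration.
from app.agents.base import AgentFactory
from prisma import Prisma


async def run_agent(agent_id: str, session_id: str, user_input: str):
    db = Prisma()
    await db.connect()

    # The factory reads agent_data.type for dispatch and agent_data.llms[0].llm
    # (merged with metadata/options) to build LLMData, so the LLM relation
    # must be eagerly loaded here.
    agent_data = await db.agent.find_unique(
        where={"id": agent_id},
        include={"llms": {"include": {"llm": True}}},
    )

    factory = AgentFactory(
        session_id=session_id,
        enable_streaming=False,
        agent_data=agent_data,
    )

    # Dispatches on agent_data.type to OpenAiAssistant, LLMAgent, or
    # LangchainAgent, then awaits that agent's own get_agent().
    agent = await factory.get_agent()

    result = await agent.ainvoke({"input": user_input})  # assumed invocation shape
    await db.disconnect()
    return result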