这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions app/api/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from fastapi.security.api_key import APIKey
from starlette.responses import StreamingResponse

from app.lib.agents import Agent as AgentDefinition
from app.lib.agents.base import AgentBase
from app.lib.agents.factory import AgentFactory
from app.lib.auth.api import get_api_key
from app.lib.auth.prisma import JWTBearer, decodeJWT
from app.lib.models.agent import Agent, PredictAgent
Expand Down Expand Up @@ -167,14 +168,15 @@ def event_stream(data_queue: Queue) -> str:
yield f"data: {data}\n\n"

def conversation_run_thread(input: dict) -> None:
agent_definition = AgentDefinition(
agent_base = AgentBase(
agent=agent,
has_streaming=has_streaming,
on_llm_new_token=on_llm_new_token,
on_llm_end=on_llm_end,
on_chain_end=on_chain_end,
)
agent_executor = agent_definition.get_agent()
agent_strategy = AgentFactory.create_agent(agent_base)
agent_executor = agent_strategy.get_agent()
agent_executor.run(input)

data_queue = Queue()
Expand All @@ -186,8 +188,9 @@ def conversation_run_thread(input: dict) -> None:
return response

else:
agent_definition = AgentDefinition(agent=agent, has_streaming=has_streaming)
agent_executor = agent_definition.get_agent()
agent_base = AgentBase(agent=agent, has_streaming=has_streaming)
agent_strategy = AgentFactory.create_agent(agent_base)
agent_executor = agent_strategy.get_agent()
output = agent_executor.run(input)
prisma.agentmemory.create(
{"author": "AI", "message": output, "agentId": agentId}
Expand Down
Empty file added app/lib/agents/__init__.py
Empty file.
107 changes: 107 additions & 0 deletions app/lib/agents/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import Any

import requests
import yaml
from langchain.agents import (
AgentExecutor,
LLMSingleActionAgent,
)
from langchain.agents.agent_toolkits.openapi import planner
from langchain.agents.agent_toolkits.openapi.spec import reduce_openapi_spec
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.conversational_retrieval.prompts import (
CONDENSE_QUESTION_PROMPT,
QA_PROMPT,
)
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
from langchain.requests import RequestsWrapper

from app.lib.agents.strategy import AgentStrategy
from app.lib.parsers import CustomOutputParser


class DefaultAgent(AgentStrategy):
    """Strategy that yields a plain ``LLMChain`` built from the agent base."""

    def __init__(self, agent_base):
        # The base object supplies the LLM, memory and prompt used below.
        self.agent_base = agent_base

    def get_agent(self) -> Any:
        """Return an ``LLMChain`` wired with the base's LLM, memory and prompt."""
        base = self.agent_base
        chain_llm = base._get_llm()
        chain_memory = base._get_memory()
        chain_prompt = base._get_prompt()

        return LLMChain(
            llm=chain_llm,
            memory=chain_memory,
            verbose=True,
            prompt=chain_prompt,
        )


class DocumentAgent(AgentStrategy):
    """Strategy that yields a ``ConversationalRetrievalChain`` over a document."""

    def __init__(self, agent_base):
        # The base object supplies the LLM, memory and document used below.
        self.agent_base = agent_base

    def get_agent(self) -> Any:
        """Assemble the retrieval chain from the base's LLM, memory and document.

        The follow-up question is condensed by a dedicated
        ``OpenAI(temperature=0)`` model, while the base's own LLM is the one
        handed to the "stuff" question-answering chain.
        """
        base = self.agent_base
        answer_llm = base._get_llm()
        chat_memory = base._get_memory()
        doc_source = base._get_document()

        condense_chain = LLMChain(
            llm=OpenAI(temperature=0),
            prompt=CONDENSE_QUESTION_PROMPT,
        )
        combine_chain = load_qa_chain(
            answer_llm,
            chain_type="stuff",
            prompt=QA_PROMPT,
            verbose=True,
        )

        def _history_unchanged(history):
            # The chat history is passed through to the chain as-is.
            return history

        return ConversationalRetrievalChain(
            retriever=doc_source.as_retriever(),
            combine_docs_chain=combine_chain,
            question_generator=condense_chain,
            memory=chat_memory,
            get_chat_history=_history_unchanged,
        )


class OpenApiDocumentAgent(AgentStrategy):
    """Strategy that yields an OpenAPI planner agent for a remote API spec.

    The spec URL and the optional authorization header are read from the
    document record held on the agent base (``agent_base.document``).
    """

    # Upper bound, in seconds, on the spec download so that an unresponsive
    # host cannot block the caller indefinitely (requests has no default
    # timeout).
    SPEC_REQUEST_TIMEOUT = 30

    def __init__(self, agent_base):
        self.agent_base = agent_base

    def get_agent(self) -> Any:
        """Download the OpenAPI spec and build the planner agent around it.

        Returns:
            The agent produced by ``planner.create_openapi_agent``.

        Raises:
            requests.HTTPError: the spec URL answered with an error status.
            requests.Timeout: the spec host did not answer within
                ``SPEC_REQUEST_TIMEOUT`` seconds.
        """
        llm = self.agent_base._get_llm()

        # Use the document record itself. The value returned by
        # ``_get_document()`` is what DocumentAgent calls ``.as_retriever()``
        # on, and it is the record on ``agent_base.document`` that the factory
        # inspects for ``.type``; ``url`` and ``authorization`` live there too.
        document = self.agent_base.document
        authorization = document.authorization

        if authorization:
            # Single custom header, e.g. an API key, forwarded on every call.
            requests_wrapper = RequestsWrapper(
                headers={authorization["key"]: authorization["value"]}
            )
        else:
            requests_wrapper = RequestsWrapper()

        spec_response = requests.get(
            document.url, timeout=self.SPEC_REQUEST_TIMEOUT
        )
        # Fail loudly instead of trying to parse an error page as a spec.
        spec_response.raise_for_status()

        # SECURITY: the spec is fetched from an externally supplied URL, so it
        # is parsed with the safe loader. The previous
        # ``yaml.load(..., Loader=yaml.Loader)`` can construct arbitrary
        # Python objects from untrusted input.
        raw_spec = yaml.safe_load(spec_response.content)
        reduced_spec = reduce_openapi_spec(raw_spec)

        return planner.create_openapi_agent(reduced_spec, requests_wrapper, llm)


class ToolAgent(AgentStrategy):
    """Strategy that yields an ``AgentExecutor`` driving the base's tools."""

    def __init__(self, agent_base):
        # The base object supplies the LLM, memory, tools and prompt used below.
        self.agent_base = agent_base

    def get_agent(self) -> Any:
        """Build a single-action agent restricted to the base's tools.

        The LLM output is interpreted by ``CustomOutputParser`` and generation
        stops at the ``"\\nObservation:"`` marker.
        """
        base = self.agent_base
        chain_llm = base._get_llm()
        chat_memory = base._get_memory()
        available_tools = base._get_tool()

        parser = CustomOutputParser()
        # Only tools known to the executor may be selected by the agent.
        permitted = [each.name for each in available_tools]
        prompt_chain = LLMChain(llm=chain_llm, prompt=base._get_prompt())

        single_action_agent = LLMSingleActionAgent(
            llm_chain=prompt_chain,
            output_parser=parser,
            stop=["\nObservation:"],
            allowed_tools=permitted,
        )

        return AgentExecutor.from_agent_and_tools(
            agent=single_action_agent,
            tools=available_tools,
            verbose=True,
            memory=chat_memory,
        )
80 changes: 2 additions & 78 deletions app/lib/agents.py → app/lib/agents/base.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,14 @@
from typing import Any

import requests
import yaml
from decouple import config
from langchain.agents import (
AgentExecutor,
LLMSingleActionAgent,
)
from langchain.agents.agent_toolkits.openapi import planner
from langchain.agents.agent_toolkits.openapi.spec import reduce_openapi_spec
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.conversational_retrieval.prompts import (
CONDENSE_QUESTION_PROMPT,
QA_PROMPT,
)
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import AzureChatOpenAI, ChatAnthropic, ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import Cohere, OpenAI
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.prompts.prompt import PromptTemplate
from langchain.requests import RequestsWrapper
from langchain.vectorstores.pinecone import Pinecone

from app.lib.callbacks import StreamingCallbackHandler
from app.lib.parsers import CustomOutputParser
from app.lib.prisma import prisma
from app.lib.prompts import (
CustomPromptTemplate,
Expand All @@ -34,7 +18,7 @@
from app.lib.tools import get_search_tool


class Agent:
class AgentBase:
def __init__(
self,
agent: dict,
Expand Down Expand Up @@ -246,64 +230,4 @@ def _get_document(self) -> Any:
return None

def get_agent(self) -> Any:
llm = self._get_llm()
memory = self._get_memory()
document = self._get_document()
tools = self._get_tool()

if self.document:
if self.document.type != "OPENAPI":
question_generator = LLMChain(
llm=OpenAI(temperature=0), prompt=CONDENSE_QUESTION_PROMPT
)
doc_chain = load_qa_chain(
llm, chain_type="stuff", prompt=QA_PROMPT, verbose=True
)
agent = ConversationalRetrievalChain(
retriever=document.as_retriever(),
combine_docs_chain=doc_chain,
question_generator=question_generator,
memory=memory,
get_chat_history=lambda h: h,
)

elif self.document.type == "OPENAPI":
requests_wrapper = (
RequestsWrapper(
headers={
self.document.authorization[
"key"
]: self.document.authorization["value"]
}
)
if self.document.authorization
else RequestsWrapper()
)
yaml_response = requests.get(self.document.url)
content = yaml_response.content
raw_odds_api_spec = yaml.load(content, Loader=yaml.Loader)
odds_api_spec = reduce_openapi_spec(raw_odds_api_spec)
agent = planner.create_openapi_agent(
odds_api_spec, requests_wrapper, llm
)

elif self.tool:
output_parser = CustomOutputParser()
tool_names = [tool.name for tool in tools]
llm_chain = LLMChain(llm=llm, prompt=self._get_prompt())
agent_config = LLMSingleActionAgent(
llm_chain=llm_chain,
output_parser=output_parser,
stop=["\nObservation:"],
allowed_tools=tool_names,
)
agent = AgentExecutor.from_agent_and_tools(
agent=agent_config, tools=tools, verbose=True, memory=memory
)

else:
agent = LLMChain(
llm=llm, memory=memory, verbose=True, prompt=self._get_prompt()
)

return agent
pass
20 changes: 20 additions & 0 deletions app/lib/agents/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from app.lib.agents.agent import (
DefaultAgent,
DocumentAgent,
OpenApiDocumentAgent,
ToolAgent,
)


class AgentFactory:
    """Picks the strategy class that matches what an agent base carries."""

    @staticmethod
    def create_agent(agent_base):
        """Return the strategy instance appropriate for ``agent_base``.

        Selection order: a document of type ``"OPENAPI"`` gives
        ``OpenApiDocumentAgent``; any other document gives ``DocumentAgent``;
        otherwise a tool gives ``ToolAgent``; with neither, ``DefaultAgent``.
        """
        attached_document = agent_base.document

        # A document takes precedence over a tool.
        if attached_document:
            if attached_document.type == "OPENAPI":
                return OpenApiDocumentAgent(agent_base)
            return DocumentAgent(agent_base)

        if agent_base.tool:
            return ToolAgent(agent_base)

        return DefaultAgent(agent_base)
6 changes: 6 additions & 0 deletions app/lib/agents/strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import Any


class AgentStrategy:
    """Common interface for objects that know how to build a runnable agent.

    Concrete strategies override :meth:`get_agent`.
    """

    def get_agent(self) -> Any:
        """Build and return the agent object.

        Raises:
            NotImplementedError: always, on the base class. A subclass that
                forgets to override this method now fails here instead of
                silently handing ``None`` to the caller.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_agent()"
        )