这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion app/api/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ async def read_agents(token=Depends(JWTBearer())):
include={
"user": True,
"document": True,
"prompt": True,
},
order={"createdAt": "desc"},
)
Expand Down Expand Up @@ -123,7 +124,7 @@ async def run_agent(
input["chat_history"] = []
has_streaming = body.has_streaming
agent = prisma.agent.find_unique(
where={"id": agentId}, include={"user": True, "document": True}
where={"id": agentId}, include={"user": True, "document": True, "prompt": True}
)

prisma.agentmemory.create(
Expand Down
19 changes: 17 additions & 2 deletions app/lib/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import Cohere, OpenAI
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.prompts.prompt import PromptTemplate
from langchain.vectorstores.pinecone import Pinecone

from app.lib.callbacks import StreamingCallbackHandler
Expand All @@ -32,7 +33,7 @@ def __init__(
self.has_memory = agent.hasMemory
self.type = agent.type
self.llm = agent.llm
self.prompt = default_chat_prompt
self.prompt = agent.prompt
self.has_streaming = has_streaming
self.on_llm_new_token = on_llm_new_token
self.on_llm_end = on_llm_end
Expand Down Expand Up @@ -60,6 +61,18 @@ def _get_api_key(self) -> str:
else config("COHERE_API_KEY")
)

def _get_prompt(self) -> Any:
    """Resolve the prompt to use for this agent.

    When the agent record carries a custom prompt, build a LangChain
    ``PromptTemplate`` from its stored ``input_variables`` and ``template``.
    Otherwise fall back to the module-level ``default_chat_prompt``.
    """
    # Guard clause: no custom prompt configured -> use the shared default.
    if not self.prompt:
        return default_chat_prompt

    return PromptTemplate(
        input_variables=self.prompt.input_variables,
        template=self.prompt.template,
    )

def _get_llm(self) -> Any:
if self.llm["provider"] == "openai-chat":
return (
Expand Down Expand Up @@ -174,6 +187,8 @@ def get_agent(self) -> Any:
)

else:
agent = LLMChain(llm=llm, memory=memory, verbose=True, prompt=self.prompt)
agent = LLMChain(
llm=llm, memory=memory, verbose=True, prompt=self._get_prompt()
)

return agent