Merged
36 changes: 16 additions & 20 deletions libs/superagent/app/agents/llm.py
@@ -1,11 +1,10 @@
import logging
from typing import Optional

from langchain.schema.messages import AIMessage
from langchain.schema.output import ChatGeneration, LLMResult
from litellm import acompletion

from app.agents.base import AgentBase
from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
from app.utils.llm import LLM_REVERSE_MAPPING
from app.utils.prisma import prisma
from prisma.enums import AgentType, LLMProvider
@@ -68,8 +67,9 @@ async def get_agent(self):

class CustomAgentExecutor:
async def ainvoke(self, input, *_, **kwargs):
function_calling_res = {"output": ""}
function_calling_res = {}

print("agent_config.tools", agent_config, input)
if len(agent_config.tools) > 0:
function_calling = await FunctionCalling(
enable_streaming=False,
@@ -113,28 +113,24 @@ async def ainvoke(self, input, *_, **kwargs):

output = ""
if enable_streaming:
streaming = kwargs["config"]["callbacks"][0]
await streaming.on_llm_start()
streaming_callback = None
for callback in kwargs["config"]["callbacks"]:
if isinstance(callback, CustomAsyncIteratorCallbackHandler):
streaming_callback = callback

if not streaming_callback:
raise Exception("Streaming Callback not found")
await streaming_callback.on_llm_start()

async for chunk in res:
token = chunk.choices[0].delta.content
if token:
output += token
await streaming.on_llm_new_token(token)

await streaming.on_llm_end(
response=LLMResult(
generations=[
[
ChatGeneration(
message=AIMessage(
content=output,
)
)
]
],
)
)
await streaming_callback.on_llm_new_token(token)

streaming_callback.done.set()
else:
output = res.choices[0].message.content

return {
**function_calling_res,
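Two behaviors change in the executor: the streaming callback is now located by type rather than by position in `kwargs["config"]["callbacks"]`, and the stream is finalized with `done.set()` instead of hand-building an `LLMResult` with a nested `ChatGeneration`/`AIMessage`. A minimal runnable sketch of that flow, with the handler reduced to the attributes the executor touches (anything beyond the diff is assumed):

```python
import asyncio
from typing import Any, List


class CustomAsyncIteratorCallbackHandler:
    """Reduced stand-in for the real handler in app/utils/callbacks.py."""

    def __init__(self) -> None:
        self.queue: "asyncio.Queue[str]" = asyncio.Queue(maxsize=5)
        self.done = asyncio.Event()

    async def on_llm_start(self) -> None:
        pass

    async def on_llm_new_token(self, token: str) -> None:
        await self.queue.put(token)


def find_streaming_callback(callbacks: List[Any]) -> CustomAsyncIteratorCallbackHandler:
    # Mirrors the new lookup: scan for the handler instead of assuming index 0.
    for callback in callbacks:
        if isinstance(callback, CustomAsyncIteratorCallbackHandler):
            return callback
    raise Exception("Streaming Callback not found")


async def demo() -> None:
    handler = CustomAsyncIteratorCallbackHandler()
    callback = find_streaming_callback([object(), handler])  # index 0 is not the handler
    await callback.on_llm_start()
    output = ""
    for token in ("Hello", ", ", "world"):  # stand-in for litellm's streamed chunks
        output += token
        await callback.on_llm_new_token(token)
    callback.done.set()  # replaces the old on_llm_end(LLMResult(...)) call
    print(output)


asyncio.run(demo())
```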
2 changes: 2 additions & 0 deletions libs/superagent/app/api/workflow_configs/saml_schema.py
@@ -134,6 +134,7 @@ class LLMAgentTool(BaseAgentToolModel, LLMAgent):
SAML_OSS_LLM_PROVIDERS = [
LLMProvider.PERPLEXITY.value,
LLMProvider.TOGETHER_AI.value,
LLMProvider.ANTHROPIC.value,
]


@@ -143,6 +144,7 @@ class Workflow(BaseModel):
# ~~OSS LLM providers~~
perplexity: Optional[LLMAgent]
together_ai: Optional[LLMAgent]
anthropic: Optional[LLMAgent]
llm: Optional[LLMAgent] = Field(
description="Deprecated! Use LLM providers instead. e.g. `perplexity` or `together_ai`"
)
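With the new `anthropic` field, a SAML workflow config can carry an Anthropic agent block and have it validated like the other OSS providers. A hedged, self-contained sketch — the real `LLMAgent` fields are not shown in this diff, so `llm_model` and `prompt` are assumed placeholders:

```python
from typing import Optional

from pydantic import BaseModel


class LLMAgent(BaseModel):
    """Stand-in for the real model in saml_schema.py; the fields are assumed."""

    llm_model: Optional[str] = None
    prompt: Optional[str] = None


class Workflow(BaseModel):
    perplexity: Optional[LLMAgent] = None
    together_ai: Optional[LLMAgent] = None
    anthropic: Optional[LLMAgent] = None  # new provider key from this PR


# A config naming the new provider now validates instead of being rejected.
wf = Workflow.parse_obj({"anthropic": {"prompt": "You are a helpful assistant."}})
print(wf.anthropic)
```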
33 changes: 15 additions & 18 deletions libs/superagent/app/utils/callbacks.py
@@ -27,7 +27,7 @@ def always_verbose(self) -> bool:
return True

def __init__(self) -> None:
self.queue = asyncio.Queue()
self.queue = asyncio.Queue(maxsize=5)
self.done = asyncio.Event()

async def on_chat_model_start(
@@ -44,9 +44,15 @@ async def on_llm_start(self) -> None:

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: # noqa
if token is not None and token != "":
self.queue.put_nowait(token)

async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: # noqa
has_put = False
while not has_put:
try:
await self.queue.put(token)
has_put = True
except asyncio.QueueFull:
continue

async def on_llm_end(self, response, **kwargs: Any) -> None: # noqa
# TODO:
# This should be removed when Langchain has merged
# https://github.com/langchain-ai/langchain/pull/9536
@@ -60,36 +66,27 @@ async def on_llm_error(self, *args: Any, **kwargs: Any) -> None:  # noqa

async def aiter(self) -> AsyncIterator[str]:
while not self.queue.empty() or not self.done.is_set():
# Wait for the next token in the queue,
# but stop waiting if the done event is set
done, other = await asyncio.wait(
done, pending = await asyncio.wait(
[
# NOTE: If you add other tasks here, update the code below,
# which assumes each set has exactly one task each
asyncio.ensure_future(self.queue.get()),
asyncio.ensure_future(self.done.wait()),
],
return_when=asyncio.FIRST_COMPLETED,
timeout=self.TIMEOUT_SECONDS,
)
# if we the timeout has been reached
if not done or not other:
if not done:
logger.warning(f"{self.TIMEOUT_SECONDS} seconds of timeout reached")
self.done.set()
break

# Cancel the other task
if other:
other.pop().cancel()
for future in pending:
future.cancel()

# Extract the value of the first completed task
token_or_done = cast(Union[str, Literal[True]], done.pop().result())

# If the extracted value is the boolean True, the done event was set
if token_or_done is True:
break
continue

# Otherwise, the extracted value is a token, which we yield
yield token_or_done


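Two things change in this file: the token queue is now bounded (`maxsize=5`), which gives the producer backpressure, and the consumer cancels every pending future instead of exactly one. Note that `await queue.put(...)` simply suspends when the queue is full; `asyncio.QueueFull` is only raised by `put_nowait`, so the `except` branch in the diff is defensive rather than load-bearing. A self-contained sketch of the producer/consumer shape (the timeout value is assumed):

```python
import asyncio
from typing import AsyncIterator

TIMEOUT_SECONDS = 2  # assumed value; the real handler defines its own constant


async def aiter_tokens(queue: "asyncio.Queue[str]", done: asyncio.Event) -> AsyncIterator[str]:
    # Same overall shape as the handler's aiter(): race the queue against the
    # done event, cancel whatever is still pending, and stop on timeout.
    while not queue.empty() or not done.is_set():
        finished, pending = await asyncio.wait(
            [asyncio.ensure_future(queue.get()), asyncio.ensure_future(done.wait())],
            return_when=asyncio.FIRST_COMPLETED,
            timeout=TIMEOUT_SECONDS,
        )
        for future in pending:
            future.cancel()
        if not finished:  # timeout: neither task completed
            done.set()
            break
        # The real handler pops a single result and assumes one completion per
        # wait; draining every finished future here means a dequeued token is
        # never silently lost when both tasks complete in the same iteration.
        for future in finished:
            result = future.result()
            if result is not True:  # queue.get() yields a token; done.wait() yields True
                yield result


async def demo() -> None:
    queue: "asyncio.Queue[str]" = asyncio.Queue(maxsize=5)  # bounded, as in this PR
    done = asyncio.Event()

    async def produce() -> None:
        for token in ("stream", "ing ", "works"):
            await queue.put(token)  # suspends (never raises QueueFull) when full
        done.set()

    producer = asyncio.ensure_future(produce())
    async for token in aiter_tokens(queue, done):
        print(token, end="")
    await producer
    print()


asyncio.run(demo())
```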
8 changes: 4 additions & 4 deletions libs/superagent/poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions libs/superagent/prisma/migrations/…/migration.sql
@@ -0,0 +1,2 @@
-- AlterEnum
ALTER TYPE "LLMProvider" ADD VALUE 'ANTHROPIC';
1 change: 1 addition & 0 deletions libs/superagent/prisma/schema.prisma
@@ -22,6 +22,7 @@ enum LLMProvider {
HUGGINGFACE
PERPLEXITY
TOGETHER_AI
ANTHROPIC
}

enum LLMModel {
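The migration and schema change add the enum value at the database level; after the Python client is regenerated, application code can reference it like the existing members. A trivial check, assuming `prisma generate` has been re-run:

```python
from prisma.enums import LLMProvider

# The generated enum mirrors schema.prisma, so the new member is a plain string value.
assert LLMProvider.ANTHROPIC.value == "ANTHROPIC"
print([member.value for member in LLMProvider])
```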
2 changes: 1 addition & 1 deletion libs/superagent/pyproject.toml
@@ -50,7 +50,7 @@ openai = "^1.1.1"
langchain-experimental = "^0.0.37"
pydub = "^0.25.1"
algoliasearch = "^3.0.0"
litellm = "^1.14.1"
litellm = "^1.29.4"
weaviate-client = "^3.25.3"
qdrant-client = "^1.6.9"
vecs = "^0.4.2"
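The litellm bump is what actually carries the new provider: recent litellm versions route `anthropic/...` model strings through Anthropic behind the same `acompletion` interface that `llm.py` already uses. A hedged sketch of a streaming call — the model identifier and the `ANTHROPIC_API_KEY` environment variable are assumptions, not taken from this diff:

```python
import asyncio

from litellm import acompletion  # same entry point llm.py imports


async def main() -> None:
    # Requires ANTHROPIC_API_KEY in the environment; the model id is an assumption.
    res = await acompletion(
        model="anthropic/claude-3-opus-20240229",
        messages=[{"role": "user", "content": "Say hello in five words."}],
        stream=True,
    )
    async for chunk in res:
        token = chunk.choices[0].delta.content  # same access pattern as llm.py
        if token:
            print(token, end="", flush=True)
    print()


asyncio.run(main())
```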
1 change: 1 addition & 0 deletions libs/ui/app/workflows/[id]/chat.tsx
@@ -141,6 +141,7 @@ export default function Chat({
resetState()
},
async onmessage(event) {
console.log(event)
if (event.id) currentEventId = event.id

if (event.event === "function_call") {
13 changes: 13 additions & 0 deletions libs/ui/config/site.ts
@@ -434,6 +434,19 @@ export const siteConfig = {
},
],
},
{
disabled: false,
formDescription: "Please enter your Anthropic API key.",
provider: "ANTHROPIC",
name: "Anthropic",
metadata: [
{
key: "apiKey",
type: "input",
label: "Anthropic API Key",
},
],
},
{
disabled: true,
formDescription: "Please enter your HF API key.",