diff --git a/libs/superagent/app/agents/llm.py b/libs/superagent/app/agents/llm.py
index dc8a290b0..de436185d 100644
--- a/libs/superagent/app/agents/llm.py
+++ b/libs/superagent/app/agents/llm.py
@@ -1,11 +1,10 @@
 import logging
 from typing import Optional
 
-from langchain.schema.messages import AIMessage
-from langchain.schema.output import ChatGeneration, LLMResult
 from litellm import acompletion
 
 from app.agents.base import AgentBase
+from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
 from app.utils.llm import LLM_REVERSE_MAPPING
 from app.utils.prisma import prisma
 from prisma.enums import AgentType, LLMProvider
@@ -68,8 +67,9 @@ async def get_agent(self):
 
         class CustomAgentExecutor:
             async def ainvoke(self, input, *_, **kwargs):
-                function_calling_res = {"output": ""}
+                function_calling_res = {}
 
+                print("agent_config.tools", agent_config, input)
                 if len(agent_config.tools) > 0:
                     function_calling = await FunctionCalling(
                         enable_streaming=False,
@@ -113,28 +113,24 @@ async def ainvoke(self, input, *_, **kwargs):
 
                 output = ""
                 if enable_streaming:
-                    streaming = kwargs["config"]["callbacks"][0]
-                    await streaming.on_llm_start()
+                    streaming_callback = None
+                    for callback in kwargs["config"]["callbacks"]:
+                        if isinstance(callback, CustomAsyncIteratorCallbackHandler):
+                            streaming_callback = callback
+
+                    if not streaming_callback:
+                        raise Exception("Streaming Callback not found")
+                    await streaming_callback.on_llm_start()
 
                     async for chunk in res:
                         token = chunk.choices[0].delta.content
                         if token:
                             output += token
-                            await streaming.on_llm_new_token(token)
-
-                    await streaming.on_llm_end(
-                        response=LLMResult(
-                            generations=[
-                                [
-                                    ChatGeneration(
-                                        message=AIMessage(
-                                            content=output,
-                                        )
-                                    )
-                                ]
-                            ],
-                        )
-                    )
+                            await streaming_callback.on_llm_new_token(token)
+
+                    streaming_callback.done.set()
+                else:
+                    output = res.choices[0].message.content
 
                 return {
                     **function_calling_res,
diff --git a/libs/superagent/app/api/workflow_configs/saml_schema.py b/libs/superagent/app/api/workflow_configs/saml_schema.py
index 8b3a3ad86..257c5691c 100644
--- a/libs/superagent/app/api/workflow_configs/saml_schema.py
+++ b/libs/superagent/app/api/workflow_configs/saml_schema.py
@@ -134,6 +134,7 @@ class LLMAgentTool(BaseAgentToolModel, LLMAgent):
 SAML_OSS_LLM_PROVIDERS = [
     LLMProvider.PERPLEXITY.value,
     LLMProvider.TOGETHER_AI.value,
+    LLMProvider.ANTHROPIC.value,
 ]
 
 
@@ -143,6 +144,7 @@ class Workflow(BaseModel):
     # ~~OSS LLM providers~~
     perplexity: Optional[LLMAgent]
     together_ai: Optional[LLMAgent]
+    anthropic: Optional[LLMAgent]
     llm: Optional[LLMAgent] = Field(
         description="Deprecated! Use LLM providers instead. e.g. `perplexity` or `together_ai`"
     )
diff --git a/libs/superagent/app/utils/callbacks.py b/libs/superagent/app/utils/callbacks.py
index d462f6f8c..01daf67b3 100644
--- a/libs/superagent/app/utils/callbacks.py
+++ b/libs/superagent/app/utils/callbacks.py
@@ -27,7 +27,7 @@ def always_verbose(self) -> bool:
         return True
 
     def __init__(self) -> None:
-        self.queue = asyncio.Queue()
+        self.queue = asyncio.Queue(maxsize=5)
         self.done = asyncio.Event()
 
     async def on_chat_model_start(
@@ -44,9 +44,15 @@ async def on_llm_start(self) -> None:
 
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:  # noqa
         if token is not None and token != "":
-            self.queue.put_nowait(token)
-
-    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:  # noqa
+            has_put = False
+            while not has_put:
+                try:
+                    await self.queue.put(token)
+                    has_put = True
+                except asyncio.QueueFull:
+                    continue
+
+    async def on_llm_end(self, response, **kwargs: Any) -> None:  # noqa
         # TODO:
         # This should be removed when Langchain has merged
         # https://github.com/langchain-ai/langchain/pull/9536
@@ -60,36 +66,27 @@ async def on_llm_error(self, *args: Any, **kwargs: Any) -> None:  # noqa
 
     async def aiter(self) -> AsyncIterator[str]:
         while not self.queue.empty() or not self.done.is_set():
-            # Wait for the next token in the queue,
-            # but stop waiting if the done event is set
-            done, other = await asyncio.wait(
+            done, pending = await asyncio.wait(
                 [
-                    # NOTE: If you add other tasks here, update the code below,
-                    # which assumes each set has exactly one task each
                    asyncio.ensure_future(self.queue.get()),
                    asyncio.ensure_future(self.done.wait()),
                 ],
                 return_when=asyncio.FIRST_COMPLETED,
                 timeout=self.TIMEOUT_SECONDS,
             )
 
-            # if we the timeout has been reached
-            if not done or not other:
+            if not done:
                 logger.warning(f"{self.TIMEOUT_SECONDS} seconds of timeout reached")
                 self.done.set()
                 break
 
-            # Cancel the other task
-            if other:
-                other.pop().cancel()
+            for future in pending:
+                future.cancel()
 
-            # Extract the value of the first completed task
             token_or_done = cast(Union[str, Literal[True]], done.pop().result())
 
-            # If the extracted value is the boolean True, the done event was set
             if token_or_done is True:
-                break
+                continue
 
-            # Otherwise, the extracted value is a token, which we yield
             yield token_or_done
diff --git a/libs/superagent/poetry.lock b/libs/superagent/poetry.lock
index afe1ac263..7d0812c6b 100644
--- a/libs/superagent/poetry.lock
+++ b/libs/superagent/poetry.lock
@@ -2259,13 +2259,13 @@ requests = ">=2,<3"
 
 [[package]]
 name = "litellm"
-version = "1.27.6"
+version = "1.29.4"
 description = "Library to easily interface with LLM API providers"
 optional = false
 python-versions = ">=3.8, !=2.7.*, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*, !=3.7.*"
 files = [
-    {file = "litellm-1.27.6-py3-none-any.whl", hash = "sha256:a9b2831065d89b9ea02e8abe07a1e09e280ccebcd8e1a0fcbeba89561174c6fc"},
-    {file = "litellm-1.27.6.tar.gz", hash = "sha256:6a9a80fa463549846b7d0af565df593f5c565d2572527998688ba797f6ba69bf"},
+    {file = "litellm-1.29.4-py3-none-any.whl", hash = "sha256:014b03fd37864d12acb095511f42bb46b74bf77a0c7086eb5d7d3ea0a27cc238"},
+    {file = "litellm-1.29.4.tar.gz", hash = "sha256:14a3e5c5aaa042b2a732374f56260afd7761625d8ee6ac38f6e1de1c5ee5f792"},
 ]
 
 [package.dependencies]
@@ -5927,4 +5927,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8.1, <3.12"
-content-hash = "360ca47b78aa00dd95b3575f5b04d748ad7dfae34aa52e2e8c6eb2e92bfee1f1"
+content-hash = "e82f732de8734cd7b820d427a7e56a793eabf9377b33faee7e8d2b613b946b89"
diff --git a/libs/superagent/prisma/migrations/20240306005619_add_anthropic/migration.sql b/libs/superagent/prisma/migrations/20240306005619_add_anthropic/migration.sql
new file mode 100644
index 000000000..79d6cce25
--- /dev/null
+++ b/libs/superagent/prisma/migrations/20240306005619_add_anthropic/migration.sql
@@ -0,0 +1,2 @@
+-- AlterEnum
+ALTER TYPE "LLMProvider" ADD VALUE 'ANTHROPIC';
diff --git a/libs/superagent/prisma/schema.prisma b/libs/superagent/prisma/schema.prisma
index ebfc77782..5306ea927 100644
--- a/libs/superagent/prisma/schema.prisma
+++ b/libs/superagent/prisma/schema.prisma
@@ -22,6 +22,7 @@ enum LLMProvider {
   HUGGINGFACE
   PERPLEXITY
   TOGETHER_AI
+  ANTHROPIC
 }
 
 enum LLMModel {
diff --git a/libs/superagent/pyproject.toml b/libs/superagent/pyproject.toml
index c4d110e83..e00a2b583 100644
--- a/libs/superagent/pyproject.toml
+++ b/libs/superagent/pyproject.toml
@@ -50,7 +50,7 @@ openai = "^1.1.1"
 langchain-experimental = "^0.0.37"
 pydub = "^0.25.1"
 algoliasearch = "^3.0.0"
-litellm = "^1.14.1"
+litellm = "^1.29.4"
 weaviate-client = "^3.25.3"
 qdrant-client = "^1.6.9"
 vecs = "^0.4.2"
diff --git a/libs/ui/app/workflows/[id]/chat.tsx b/libs/ui/app/workflows/[id]/chat.tsx
index bc93cb9be..e1c2ce47a 100644
--- a/libs/ui/app/workflows/[id]/chat.tsx
+++ b/libs/ui/app/workflows/[id]/chat.tsx
@@ -141,6 +141,7 @@ export default function Chat({
           resetState()
         },
         async onmessage(event) {
+          console.log(event)
           if (event.id) currentEventId = event.id
 
           if (event.event === "function_call") {
diff --git a/libs/ui/config/site.ts b/libs/ui/config/site.ts
index 3ef3731af..a173cbe0a 100644
--- a/libs/ui/config/site.ts
+++ b/libs/ui/config/site.ts
@@ -434,6 +434,19 @@ export const siteConfig = {
         },
       ],
     },
+    {
+      disabled: false,
+      formDescription: "Please enter your Anthropic API key.",
+      provider: "ANTHROPIC",
+      name: "Anthropic",
+      metadata: [
+        {
+          key: "apiKey",
+          type: "input",
+          label: "Anthropic API Key",
+        },
+      ],
+    },
     {
       disabled: true,
       formDescription: "Please enter your HF API key.",