12 changes: 11 additions & 1 deletion lib/sycamore/sycamore/llms/__init__.py
@@ -1,4 +1,14 @@
from sycamore.llms.llms import LLM
from sycamore.llms.openai import OpenAI, OpenAIClientType, OpenAIModels, OpenAIClientParameters, OpenAIClientWrapper
from sycamore.llms.bedrock import Bedrock, BedrockModels

__all__ = ["LLM", "OpenAI", "OpenAIClientType", "OpenAIModels", "OpenAIClientParameters", "OpenAIClientWrapper"]
__all__ = [
"LLM",
"OpenAI",
"OpenAIClientType",
"OpenAIModels",
"OpenAIClientParameters",
"OpenAIClientWrapper",
"Bedrock",
"BedrockModels",
]
100 changes: 100 additions & 0 deletions lib/sycamore/sycamore/llms/bedrock.py
@@ -0,0 +1,100 @@
from dataclasses import dataclass
from enum import Enum
import boto3
import json
from typing import Dict, Optional, Union


from sycamore.llms.llms import LLM
from sycamore.utils.cache import Cache

DEFAULT_MAX_TOKENS = 1000
DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31"


@dataclass
class BedrockModel:
name: str
is_chat: bool = False


class BedrockModels(Enum):
"""Represents available Bedrock models."""

# Note that the models available on a given Bedrock account may vary.
CLAUDE_3_HAIKU = BedrockModel(name="anthropic.claude-3-haiku-20240307-v1:0", is_chat=True)
CLAUDE_3_SONNET = BedrockModel(name="anthropic.claude-3-sonnet-20240229-v1:0", is_chat=True)
CLAUDE_3_OPUS = BedrockModel(name="anthropic.claude-3-opus-20240229-v1:0", is_chat=True)
CLAUDE_3_5_SONNET = BedrockModel(name="anthropic.claude-3-5-sonnet-20240620-v1:0", is_chat=True)

@classmethod
def from_name(cls, name: str):
for m in iter(cls):
if m.value.name == name:
return m
return None


class Bedrock(LLM):
"""This is an LLM implementation that uses the AWS Bedrock API to generate text.

Args:
model_name: The name of the Bedrock model to use.
cache: A cache object to use for caching results.
"""

def __init__(
self,
model_name: Union[BedrockModels, str],
cache: Optional[Cache] = None,
):
if isinstance(model_name, BedrockModels):
self.model = model_name.value
elif isinstance(model_name, str):
self.model = BedrockModel(name=model_name)

self._client = boto3.client(service_name="bedrock-runtime")
super().__init__(self.model.name, cache)

def is_chat_mode(self) -> bool:
"""Returns True if the LLM is in chat mode, False otherwise."""
return True

def _get_generate_kwargs(self, prompt_kwargs: Dict, llm_kwargs: Optional[Dict] = None) -> Dict:
kwargs = {
"temperature": 0,
**(llm_kwargs or {}),
}
if self._model_name.startswith("anthropic."):
kwargs["anthropic_version"] = kwargs.get("anthropic_version", DEFAULT_ANTHROPIC_VERSION)
kwargs["max_tokens"] = kwargs.get("max_tokens", DEFAULT_MAX_TOKENS)

if "prompt" in prompt_kwargs:
prompt = prompt_kwargs.get("prompt")
kwargs.update({"messages": [{"role": "user", "content": f"{prompt}"}]})
elif "messages" in prompt_kwargs:
kwargs.update({"messages": prompt_kwargs["messages"]})
else:
raise ValueError("Either prompt or messages must be present in prompt_kwargs.")
return kwargs

def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str:
key, ret = self._cache_get(prompt_kwargs, llm_kwargs)
if ret is not None:
return ret

kwargs = self._get_generate_kwargs(prompt_kwargs, llm_kwargs)
body = json.dumps(kwargs)
response = self._client.invoke_model(
body=body, modelId=self.model.name, accept="application/json", contentType="application/json"
)
response_body = json.loads(response.get("body").read())
ret = response_body.get("content", [{}])[0].get("text", "")
value = {
"result": ret,
"prompt_kwargs": prompt_kwargs,
"llm_kwargs": llm_kwargs,
"model_name": self.model.name,
}
self._cache_set(key, value)
return ret
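
A minimal usage sketch for the new Bedrock class (not part of the diff): it assumes AWS credentials with access to the listed Bedrock models are already configured for boto3, and the prompts are purely illustrative.

from sycamore.llms import Bedrock, BedrockModels

# Instantiate against a known model; a raw model-id string is also accepted.
llm = Bedrock(BedrockModels.CLAUDE_3_HAIKU)

# Plain prompt form: _get_generate_kwargs wraps it into a single user message.
print(llm.generate(prompt_kwargs={"prompt": "Summarize AWS Bedrock in one sentence."}))

# Chat-style messages and per-call overrides are also accepted.
print(
    llm.generate(
        prompt_kwargs={"messages": [{"role": "user", "content": "List two Claude 3 models."}]},
        llm_kwargs={"max_tokens": 256},
    )
)
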
56 changes: 43 additions & 13 deletions lib/sycamore/sycamore/llms/llms.py
@@ -1,36 +1,66 @@
from abc import ABC, abstractmethod
from typing import Optional
import pickle
from typing import Optional, Tuple

from sycamore.utils.cache import Cache


class LLM(ABC):
"""
Initializes a new LLM instance. This class is abstract and should be subclassed to implement specific LLM providers.
"""
"""Abstract representation of an LLM instance. and should be subclassed to implement specific LLM providers."""

def __init__(self, model_name, cache: Optional[Cache] = None):
self._model_name = model_name
self._cache = cache

"""
Generates a response from the LLM for the given prompt and LLM parameters.
"""

@abstractmethod
def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str:
"""Generates a response from the LLM for the given prompt and LLM parameters."""
pass

"""
Returns True if the LLM is in chat mode, False otherwise.
"""

@abstractmethod
def is_chat_mode(self) -> bool:
"""Returns True if the LLM is in chat mode, False otherwise."""
pass

async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str:
raise ValueError("No implementation for llm futures exists")
"""Generates a response from the LLM for the given prompt and LLM parameters asynchronously."""
raise NotImplementedError("This LLM does not support asynchronous generation.")

def _get_cache_key(self, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str:
"""Return a cache key for the given prompt and LLM parameters."""
assert self._cache
combined = {"prompt_kwargs": prompt_kwargs, "llm_kwargs": llm_kwargs, "model_name": self._model_name}
data = pickle.dumps(combined)
return self._cache.get_hash_context(data).hexdigest()

def _cache_get(self, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> Tuple[Optional[str], Optional[str]]:
"""Get a cached result for the given prompt and LLM parameters. Returns the cache key
and the cached result if found, otherwise returns None for both."""
if (llm_kwargs or {}).get("temperature", 0) != 0 or not self._cache:
# Never cache when temperature setting is nonzero.
return (None, None)

key = self._get_cache_key(prompt_kwargs, llm_kwargs)
hit = self._cache.get(key)
if hit:
assert (
hit.get("prompt_kwargs") == prompt_kwargs
and hit.get("llm_kwargs") == llm_kwargs
and hit.get("model_name") == self._model_name
), f"""
Found LLM cache content mismatch:
key={key}
prompt_kwargs={prompt_kwargs}, cached={hit.get("prompt_kwargs")}
llm_kwargs={llm_kwargs}, cached={hit.get("llm_kwargs")}
model_name={self._model_name}, cached={hit.get("model_name")}"""
return (key, hit.get("result"))
return (key, None)

def _cache_set(self, key, result):
"""Set a cached result for the given key."""
if key is None or not self._cache:
return
self._cache.set(key, result)


class FakeLLM(LLM):
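
To make the refactored caching contract in llms.py concrete, here is a minimal sketch of a subclass using the shared _cache_get/_cache_set helpers. The EchoLLM class and its stand-in "model call" are illustrative only and do not appear in the diff.

from typing import Optional

from sycamore.llms.llms import LLM
from sycamore.utils.cache import Cache


class EchoLLM(LLM):
    """Toy LLM that echoes the prompt; exists only to show the shared caching helpers."""

    def __init__(self, cache: Optional[Cache] = None):
        super().__init__("echo-model", cache)

    def is_chat_mode(self) -> bool:
        return False

    def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str:
        # Returns (None, None) when no cache is configured or temperature is nonzero.
        key, ret = self._cache_get(prompt_kwargs, llm_kwargs)
        if ret is not None:
            return ret

        result = prompt_kwargs.get("prompt", "")  # stand-in for a real model call
        # The cached value carries the inputs so _cache_get can verify them on a hit.
        self._cache_set(
            key,
            {
                "result": result,
                "prompt_kwargs": prompt_kwargs,
                "llm_kwargs": llm_kwargs,
                "model_name": self._model_name,
            },
        )
        return result
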
67 changes: 29 additions & 38 deletions lib/sycamore/sycamore/llms/openai.py
@@ -2,10 +2,9 @@
import inspect
import logging
import os
import pickle
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional, TypedDict, Union, cast, TYPE_CHECKING
from typing import Any, Dict, Optional, TypedDict, Union, cast, TYPE_CHECKING

from openai import AzureOpenAI as AzureOpenAIClient
from openai import AsyncAzureOpenAI as AsyncAzureOpenAIClient
@@ -299,10 +298,25 @@ def openai_deserializer(kwargs):


class OpenAI(LLM):
"""An LLM interface to OpenAI models.

Args:
model_name: The name of the OpenAI model to use. This can be an instance of OpenAIModels, an instance of
OpenAIModel, or a string. If a string is provided, it must be the name of the model.
api_key: The API key to use for the OpenAI client. If not provided, the key will be read from the
OPENAI_API_KEY environment variable.
client_wrapper: An instance of OpenAIClientWrapper to use for the OpenAI client. If not provided, a new
instance will be created using the provided parameters.
params: An instance of OpenAIClientParameters to use for the OpenAI client. If not provided, a new instance
will be created using the provided parameters.
cache: An instance of Cache to use for caching responses. If not provided, no caching will be used.
**kwargs: Additional parameters to pass to the OpenAI client.
"""

def __init__(
self,
model_name: Union[OpenAIModels, OpenAIModel, str],
api_key=None,
api_key: Optional[str] = None,
client_wrapper: Optional[OpenAIClientWrapper] = None,
params: Optional[OpenAIClientParameters] = None,
cache: Optional[Cache] = None,
@@ -349,42 +363,18 @@ def __reduce__(self):
def is_chat_mode(self):
return self.model.is_chat

def _get_cache_key(self, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str:
assert self._cache
combined = {"prompt_kwargs": prompt_kwargs, "llm_kwargs": llm_kwargs, "model_name": self.model.name}
data = pickle.dumps(combined)
return self._cache.get_hash_context(data).hexdigest()

def _cache_get(self, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None):
if (llm_kwargs or {}).get("temperature", 0) != 0 or not self._cache:
# Never cache when temperature setting is nonzero.
return (None, None)

response_format = (llm_kwargs or {}).get("response_format")
def _convert_response_format(self, llm_kwargs: Optional[Dict]) -> Optional[Dict]:
"""Convert the response_format parameter to the appropriate OpenAI format."""
if llm_kwargs is None:
return None
response_format = llm_kwargs.get("response_format")
if response_format is None:
return llm_kwargs
if inspect.isclass(response_format) and issubclass(response_format, pydantic.BaseModel):
assert llm_kwargs
llm_kwargs["response_format"] = type_to_response_format_param(response_format)

key = self._get_cache_key(prompt_kwargs, llm_kwargs)
hit = self._cache.get(key)
if hit:
assert (
hit.get("prompt_kwargs") == prompt_kwargs
and hit.get("llm_kwargs") == llm_kwargs
and hit.get("model_name") == self.model.name
), f"""
Found cache content mismatch:
key={key}
prompt_kwargs={prompt_kwargs}, cached={hit.get("prompt_kwargs")}
llm_kwargs={llm_kwargs}, cached={hit.get("llm_kwargs")}
model_name={self.model.name}, cached={hit.get("model_name")}"""
return (key, hit.get("result"))
return (key, None)

def _cache_set(self, key, result):
if key is None or not self._cache:
return
self._cache.set(key, result)
retval = llm_kwargs.copy()
retval["response_format"] = type_to_response_format_param(response_format)
return retval
return llm_kwargs

def _get_generate_kwargs(self, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> dict:
kwargs = {
@@ -412,6 +402,7 @@ def _determine_using_beta(self, response_format: Any) -> bool:
return False

def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str:
llm_kwargs = self._convert_response_format(llm_kwargs)
key, ret = self._cache_get(prompt_kwargs, llm_kwargs)
if ret is not None:
return ret
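
A short sketch of the structured-output path now routed through _convert_response_format before the cache lookup. This is illustrative only: the model name and pydantic fields are assumptions, and OPENAI_API_KEY is assumed to be set in the environment.

from pydantic import BaseModel

from sycamore.llms import OpenAI


class TitleAndDate(BaseModel):
    title: str
    date: str


llm = OpenAI("gpt-4o-mini")  # any chat-capable OpenAI model name should work here
out = llm.generate(
    prompt_kwargs={"prompt": "Extract the title and date from: 'Q3 Review, 2024-10-01'."},
    # A pydantic class is converted to OpenAI's response_format parameter up front,
    # so the cache key is computed on the converted llm_kwargs.
    llm_kwargs={"response_format": TitleAndDate},
)
print(out)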