16 changes: 14 additions & 2 deletions lib/sycamore/sycamore/llms/anthropic.py
@@ -109,6 +109,7 @@ def __init__(
         model_name: Union[AnthropicModels, str],
         default_mode: LLMMode = LLMMode.ASYNC,
         cache: Optional[Cache] = None,
+        default_llm_kwargs: Optional[dict[str, Any]] = None,
     ):

         # We import this here so we can share utility code with the Bedrock
@@ -128,13 +129,18 @@ def __init__(

         self._client = AnthropicClient()
         self._async_client = AsyncAnthropicClient()
-        super().__init__(self.model.value, default_mode, cache)
+        super().__init__(self.model.value, default_mode, cache, default_llm_kwargs=default_llm_kwargs)

     def __reduce__(self):
         def deserializer(kwargs):
             return Anthropic(**kwargs)

-        kwargs = {"model_name": self.model_name, "cache": self._cache, "default_mode": self._default_mode}
+        kwargs = {
+            "model_name": self.model_name,
+            "cache": self._cache,
+            "default_mode": self._default_mode,
+            "default_llm_kwargs": self._default_llm_kwargs,
+        }
         return deserializer, (kwargs,)

     def default_mode(self) -> LLMMode:
@@ -165,6 +171,8 @@ def _metadata_from_response(self, kwargs, response, starttime) -> dict:
         return ret

     def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
+
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if isinstance(ret, dict):
             return ret
@@ -186,6 +194,8 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None)
     async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
         from anthropic import RateLimitError, APIConnectionError

+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
+
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if isinstance(ret, dict):
             return ret["output"]
@@ -215,6 +225,8 @@ def generate_batch(self, *, prompts: list[RenderedPrompt], llm_kwargs: Optional[
         from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
         from anthropic.types.messages.batch_create_params import Request

+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
+
         cache_hits = [self._llm_cache_get(p, llm_kwargs) for p in prompts]

         calls = []
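Not part of the diff, but as a reading aid: a minimal sketch of how the new `default_llm_kwargs` parameter is intended to be used. The construction mirrors the integration test added in this PR; the import paths are assumed from the file layout, and the per-call override value of 100 is hypothetical.

# Sketch only: constructor-level defaults plus a per-call override.
from sycamore.llms.anthropic import Anthropic, AnthropicModels
from sycamore.llms.prompts import RenderedPrompt, RenderedMessage

llm = Anthropic(AnthropicModels.CLAUDE_3_HAIKU, default_llm_kwargs={"max_tokens": 5})
prompt = RenderedPrompt(
    messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")]
)

short = llm.generate(prompt=prompt)  # merged kwargs include max_tokens=5 from the defaults
longer = llm.generate(prompt=prompt, llm_kwargs={"max_tokens": 100})  # per-call value wins on conflict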
7 changes: 5 additions & 2 deletions lib/sycamore/sycamore/llms/bedrock.py
@@ -26,6 +26,7 @@ def __init__(
         self,
         model_name: Union[BedrockModels, str],
         cache: Optional[Cache] = None,
+        default_llm_kwargs: Optional[dict[str, Any]] = None,
     ):
         import boto3

@@ -37,13 +38,13 @@ def __init__(
         self.model = BedrockModel(name=model_name)

         self._client = boto3.client(service_name="bedrock-runtime")
-        super().__init__(self.model.name, default_mode=LLMMode.SYNC, cache=cache)
+        super().__init__(self.model.name, default_mode=LLMMode.SYNC, cache=cache, default_llm_kwargs=default_llm_kwargs)

     def __reduce__(self):
         def deserializer(kwargs):
             return Bedrock(**kwargs)

-        kwargs = {"model_name": self.model_name, "cache": self._cache}
+        kwargs = {"model_name": self.model_name, "cache": self._cache, "default_llm_kwargs": self._default_llm_kwargs}
         return deserializer, (kwargs,)

     def is_chat_mode(self) -> bool:
@@ -56,6 +57,8 @@ def format_image(self, image: Image.Image) -> dict[str, Any]:
         raise NotImplementedError("Images not supported for non-Anthropic Bedrock models.")

     def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
+
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if isinstance(ret, dict):
             print(f"cache return {ret}")
16 changes: 13 additions & 3 deletions lib/sycamore/sycamore/llms/gemini.py
@@ -31,6 +31,7 @@ def __init__(
         default_mode: LLMMode = LLMMode.ASYNC,
         cache: Optional[Cache] = None,
         api_key: Optional[str] = None,
+        default_llm_kwargs: Optional[dict[str, Any]] = None,
     ):
         from google.genai import Client

@@ -42,13 +43,18 @@ def __init__(
         self.model = GeminiModel(name=model_name)
         api_key = api_key if api_key else os.getenv("GEMINI_API_KEY")
         self._client = Client(api_key=api_key)
-        super().__init__(self.model.name, default_mode, cache)
+        super().__init__(self.model.name, default_mode, cache, default_llm_kwargs=default_llm_kwargs)

     def __reduce__(self):
         def deserializer(kwargs):
             return Gemini(**kwargs)

-        kwargs = {"model_name": self.model_name, "cache": self._cache, "default_mode": self._default_mode}
+        kwargs = {
+            "model_name": self.model_name,
+            "cache": self._cache,
+            "default_mode": self._default_mode,
+            "default_llm_kwargs": self._default_llm_kwargs,
+        }
         return deserializer, (kwargs,)

     def default_mode(self) -> LLMMode:
@@ -106,7 +112,7 @@ def _metadata_from_response(self, kwargs, response, starttime) -> dict:

         reason = response.candidates[0].finish_reason
         if reason != FinishReason.STOP:
-            logger.warn(f"Gemini model stopped for unexpected reason {reason}. Full response:\n{response}")
+            logger.warning(f"Gemini model stopped for unexpected reason {reason}. Full response:\n{response}")
         ret = {
             "output": output,
             "wall_latency": wall_latency,
@@ -117,6 +123,8 @@ def _metadata_from_response(self, kwargs, response, starttime) -> dict:
         return ret

     def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
+
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if isinstance(ret, dict):
             return ret
@@ -137,6 +145,8 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None)
         return d["output"]

     async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
+
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if isinstance(ret, dict):
             return ret["output"]
32 changes: 29 additions & 3 deletions lib/sycamore/sycamore/llms/llms.py
@@ -1,6 +1,8 @@
 import inspect
 from abc import ABC, abstractmethod
+import copy
 from enum import Enum
+import logging
 import pickle
 import base64
 from PIL import Image
@@ -23,15 +25,32 @@ class LLMMode(Enum):
 class LLM(ABC):
     """Abstract representation of an LLM instance. and should be subclassed to implement specific LLM providers."""

-    def __init__(self, model_name, default_mode: LLMMode, cache: Optional[Cache] = None):
+    def __init__(
+        self,
+        model_name,
+        default_mode: LLMMode,
+        cache: Optional[Cache] = None,
+        default_llm_kwargs: Optional[dict[str, Any]] = None,
+    ):
         self._model_name = model_name
         self._cache = cache
         self._default_mode = default_mode
+        self._default_llm_kwargs = default_llm_kwargs or {}

     def default_mode(self) -> LLMMode:
         """Returns the default execution mode for the llm"""
         return self._default_mode

+    def _merge_llm_kwargs(self, llm_kwargs: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        """Merges the default LLM kwargs with any provided LLM kwargs.
+
+        Prefers the passed in values if there is a conflict.
+        """
+        new_kwargs = copy.copy(self._default_llm_kwargs)
+        new_kwargs.update(llm_kwargs or {})
+        logging.debug(f"Merging LLM kwargs: {new_kwargs}")
+        return new_kwargs
+
     @abstractmethod
     def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
         """Generates a response from the LLM for the given prompt and LLM parameters."""
@@ -210,8 +229,15 @@ def add_llm_metadata(self, kwargs, output, wall_latency, in_tokens, out_tokens):
 class FakeLLM(LLM):
     """Useful for tests where the fake LLM needs to run in a ray function because mocks are not serializable"""

-    def __init__(self, *, return_value="trivial", cache: Optional[Cache] = None, default_mode: LLMMode = LLMMode.SYNC):
-        super().__init__("trivial", cache=cache, default_mode=default_mode)
+    def __init__(
+        self,
+        *,
+        return_value="trivial",
+        cache: Optional[Cache] = None,
+        default_mode: LLMMode = LLMMode.SYNC,
+        default_llm_kwargs: Optional[dict[str, Any]] = None,
+    ):
+        super().__init__("trivial", cache=cache, default_mode=default_mode, default_llm_kwargs=default_llm_kwargs)
         self._return_value = return_value

     def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
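The merge in `_merge_llm_kwargs` is a shallow `dict.update`: per-call values replace defaults key by key, and nested dicts are replaced wholesale rather than deep-merged. A small sketch of the resulting precedence (it mirrors the unit test added below):

# Sketch of the precedence implemented by LLM._merge_llm_kwargs above.
defaults = {"temperature": 0.5, "max_tokens": 100}
per_call = {"max_tokens": 500, "thinking_config": {"token_budget": 1000}}

merged = dict(defaults)  # copy.copy of a dict is an equivalent shallow copy
merged.update(per_call)  # passed-in values win on conflict

assert merged == {"temperature": 0.5, "max_tokens": 500, "thinking_config": {"token_budget": 1000}}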
7 changes: 6 additions & 1 deletion lib/sycamore/sycamore/llms/openai.py
@@ -238,6 +238,7 @@ def __init__(
         params: Optional[OpenAIClientParameters] = None,
         default_mode: LLMMode = LLMMode.ASYNC,
         cache: Optional[Cache] = None,
+        default_llm_kwargs: Optional[dict[str, Any]] = None,
         **kwargs,
     ):
         if isinstance(model_name, OpenAIModels):
@@ -252,7 +253,7 @@ def __init__(
         if self.model.name == OpenAIModels.TEXT_DAVINCI.value.name:
             logger.warning("text-davinci-003 is deprecated. Falling back to gpt-3.5-turbo-instruct")
             self.model = OpenAIModels.GPT_3_5_TURBO_INSTRUCT.value
-        super().__init__(self.model.name, default_mode, cache)
+        super().__init__(self.model.name, default_mode, cache, default_llm_kwargs=default_llm_kwargs)

         # This is somewhat complex to provide a degree of backward compatibility.
         if client_wrapper is None:
@@ -279,6 +280,7 @@ def __reduce__(self):
             "model_name": self.model,
             "cache": self._cache,
             "default_mode": self._default_mode,
+            "default_llm_kwargs": self._default_llm_kwargs,
         }

         return openai_deserializer, (kwargs,)
@@ -352,6 +354,7 @@ def _get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict
         return kwargs

     def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
         llm_kwargs = self._convert_response_format(llm_kwargs)
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if ret is not None:
@@ -411,6 +414,7 @@ def _generate_using_openai_structured(self, prompt: RenderedPrompt, llm_kwargs:
             raise e

     async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if ret is not None:
             return ret
@@ -488,6 +492,7 @@ async def _generate_awaitable_using_openai_structured(
             raise e

     def generate_batch(self, *, prompts: list[RenderedPrompt], llm_kwargs: Optional[dict] = None) -> list[str]:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
         cache_hits = [self._llm_cache_get(p, llm_kwargs) for p in prompts]

         calls = []
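In each provider the merge runs before `_llm_cache_get`, so the merged kwargs (defaults included) participate in the cache lookup; the same model constructed with different `default_llm_kwargs` should therefore not share cache entries. A hedged sketch of the call-site effect, reusing the constructor shape from the integration test below:

# Sketch only: the defaults flow into every OpenAI entry point, including batch calls.
from sycamore.llms.openai import OpenAI, OpenAIModels
from sycamore.llms.prompts import RenderedPrompt, RenderedMessage

llm = OpenAI(OpenAIModels.GPT_4O_MINI, default_llm_kwargs={"max_tokens": 5})
prompt = RenderedPrompt(messages=[RenderedMessage(role="user", content="Say hi.")])

llm.generate(prompt=prompt)           # generate() merges defaults, then applies _convert_response_format
llm.generate_batch(prompts=[prompt])  # generate_batch() merges the same defaults once for the whole batch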
11 changes: 11 additions & 0 deletions lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py
@@ -151,3 +151,14 @@ def test_metadata():
     assert "wall_latency" in res
     assert "in_tokens" in res
     assert "out_tokens" in res
+
+
+def test_default_llm_kwargs():
+    llm = Anthropic(AnthropicModels.CLAUDE_3_HAIKU, default_llm_kwargs={"max_tokens": 5})
+
+    res = llm.generate_metadata(
+        prompt=RenderedPrompt(
+            messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")]
+        )
+    )
+    assert res["out_tokens"] <= 5
10 changes: 10 additions & 0 deletions lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py
@@ -155,3 +155,13 @@ def test_metadata():
     assert "server_latency" in res
     assert "in_tokens" in res
     assert "out_tokens" in res
+
+
+def test_default_llm_kwargs():
+    llm = Bedrock(BedrockModels.CLAUDE_3_HAIKU, default_llm_kwargs={"max_tokens": 5})
+    res = llm.generate_metadata(
+        prompt=RenderedPrompt(
+            messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")]
+        )
+    )
+    assert res["out_tokens"] <= 5
11 changes: 11 additions & 0 deletions lib/sycamore/sycamore/tests/integration/llms/test_gemini.py
@@ -151,3 +151,14 @@ def test_metadata():
     assert "wall_latency" in res
     assert "in_tokens" in res
     assert "out_tokens" in res
+
+
+def test_default_llm_kwargs():
+    llm = Gemini(GeminiModels.GEMINI_2_FLASH_LITE, default_llm_kwargs={"max_output_tokens": 5})
+    res = llm.generate_metadata(
+        prompt=RenderedPrompt(
+            messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")]
+        ),
+        llm_kwargs={},
+    )
+    assert res["out_tokens"] <= 5
14 changes: 14 additions & 0 deletions lib/sycamore/sycamore/tests/integration/llms/test_openai.py
@@ -4,6 +4,7 @@
 import pytest
 from typing import Any

+from sycamore.functions.tokenizer import OpenAITokenizer
 from sycamore.llms.openai import OpenAI, OpenAIModels, OpenAIClientWrapper
 from sycamore.llms.openai import OpenAIModel, OpenAIClientType
 from sycamore.llms.prompts import RenderedPrompt, RenderedMessage, StaticPrompt
@@ -226,6 +227,19 @@ def test_openai_defaults_guidance_instruct():
     assert len(res) > 0


+def test_default_llm_kwargs():
+    llm = OpenAI(OpenAIModels.GPT_4O_MINI, default_llm_kwargs={"max_tokens": 5})
+
+    res = llm.generate(
+        prompt=RenderedPrompt(
+            messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")]
+        )
+    )
+
+    num_tokens = len(OpenAITokenizer(OpenAIModels.GPT_4O_MINI.value.name).tokenize(res))
+    assert num_tokens <= 5, f"Expected max_tokens to be 5, but got {num_tokens} tokens in the response: {res}"
+
+
 @pytest.fixture(scope="module")
 def azure_llm():
     # Note this deployment name is different from the official model name, which has a '.'
7 changes: 7 additions & 0 deletions lib/sycamore/sycamore/tests/unit/llms/test_llms.py
@@ -59,6 +59,13 @@ def test_default_llm_mode():
     assert async_llm.default_mode() == LLMMode.ASYNC


+def test_merge_llm_kwargs():
+    llm = FakeLLM(default_llm_kwargs={"temperature": 0.5, "max_tokens": 100})
+    llm_kwargs = {"thinking_config": {"token_budget": 1000}, "max_tokens": 500}
+    merged_kwargs = llm._merge_llm_kwargs(llm_kwargs)
+    assert merged_kwargs == {"temperature": 0.5, "max_tokens": 500, "thinking_config": {"token_budget": 1000}}
+
+
 @patch("boto3.client")
 def test_get_llm(mock_boto3_client):
     assert isinstance(get_llm("openai." + OpenAIModels.TEXT_DAVINCI.value.name)(), OpenAI)
4 changes: 3 additions & 1 deletion lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py
@@ -4,18 +4,20 @@
 from sycamore.llms.prompts.prompts import RenderedMessage
 from sycamore.transforms.base_llm import LLMMap, LLMMapElements
 import pytest
-from typing import Optional
+from typing import Any, Optional


 class FakeLLM(LLM):
     def __init__(self, default_mode: LLMMode = LLMMode.SYNC):
         super().__init__(model_name="dummy", default_mode=default_mode)
         self.async_calls = 0
+        self.used_llm_kwargs: dict[str, Any] = {}

     def is_chat_mode(self) -> bool:
         return True

     def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
+        self.used_llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
         return "".join(m.content for m in prompt.messages)

     async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: