From 7131ee2448c19d328dcb879a52b293b6d9aa0fb2 Mon Sep 17 00:00:00 2001
From: Ben Sowell
Date: Thu, 15 May 2025 12:33:27 -0700
Subject: [PATCH 1/3] Add support for default llm kwargs for all models.

These kwargs can be passed to the constructor of the LLM object and will
then be used for each generate call. Arguments passed via the generate_*
methods will override the defaults.

The integration tests verify that the default arguments are in fact
applied, by greatly restricting the max output tokens.
---
 lib/sycamore/sycamore/llms/anthropic.py       | 16 ++++++++--
 lib/sycamore/sycamore/llms/bedrock.py         |  7 ++--
 lib/sycamore/sycamore/llms/gemini.py          | 16 ++++++++--
 lib/sycamore/sycamore/llms/llms.py            | 32 +++++++++++++++++--
 lib/sycamore/sycamore/llms/openai.py          |  7 +++-
 .../tests/integration/llms/test_anthropic.py  | 11 +++++++
 .../tests/integration/llms/test_bedrock.py    | 10 ++++++
 .../tests/integration/llms/test_gemini.py     | 11 +++++++
 .../tests/integration/llms/test_openai.py     | 14 ++++++++
 .../sycamore/tests/unit/llms/test_llms.py     |  7 ++++
 .../tests/unit/transforms/test_base_llm.py    |  2 ++
 11 files changed, 122 insertions(+), 11 deletions(-)

diff --git a/lib/sycamore/sycamore/llms/anthropic.py b/lib/sycamore/sycamore/llms/anthropic.py
index 4d98a52c9..75b088e9d 100644
--- a/lib/sycamore/sycamore/llms/anthropic.py
+++ b/lib/sycamore/sycamore/llms/anthropic.py
@@ -109,6 +109,7 @@ def __init__(
         self,
         model_name: Union[AnthropicModels, str],
         default_mode: LLMMode = LLMMode.ASYNC,
         cache: Optional[Cache] = None,
+        default_llm_kwargs: Optional[dict[str, Any]] = None,
     ):
         # We import this here so we can share utility code with the Bedrock
@@ -128,13 +129,18 @@ def __init__(
         self._client = AnthropicClient()
         self._async_client = AsyncAnthropicClient()
 
-        super().__init__(self.model.value, default_mode, cache)
+        super().__init__(self.model.value, default_mode, cache, default_llm_kwargs=default_llm_kwargs)
 
     def __reduce__(self):
         def deserializer(kwargs):
             return Anthropic(**kwargs)
 
-        kwargs = {"model_name": self.model_name, "cache": self._cache, "default_mode": self._default_mode}
+        kwargs = {
+            "model_name": self.model_name,
+            "cache": self._cache,
+            "default_mode": self._default_mode,
+            "default_llm_kwargs": self._default_llm_kwargs,
+        }
         return deserializer, (kwargs,)
 
     def default_mode(self) -> LLMMode:
@@ -165,6 +171,8 @@ def _metadata_from_response(self, kwargs, response, starttime) -> dict:
         return ret
 
     def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
+
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if isinstance(ret, dict):
             return ret
@@ -186,6 +194,8 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None)
     async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
         from anthropic import RateLimitError, APIConnectionError
 
+        self._merge_llm_kwargs(llm_kwargs)
+
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if isinstance(ret, dict):
             return ret["output"]
@@ -215,6 +225,8 @@ def generate_batch(self, *, prompts: list[RenderedPrompt], llm_kwargs: Optional[
         from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
         from anthropic.types.messages.batch_create_params import Request
 
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
+
         cache_hits = [self._llm_cache_get(p, llm_kwargs) for p in prompts]
 
         calls = []
diff --git a/lib/sycamore/sycamore/llms/bedrock.py b/lib/sycamore/sycamore/llms/bedrock.py
index a372c0cba..99f0bb079 100644
--- a/lib/sycamore/sycamore/llms/bedrock.py
+++ b/lib/sycamore/sycamore/llms/bedrock.py
@@ -26,6 +26,7 @@ def __init__(
         self,
         model_name: Union[BedrockModels, str],
         cache: Optional[Cache] = None,
+        default_llm_kwargs: Optional[dict[str, Any]] = None,
     ):
         import boto3
 
@@ -37,13 +38,13 @@ def __init__(
         self.model = BedrockModel(name=model_name)
         self._client = boto3.client(service_name="bedrock-runtime")
 
-        super().__init__(self.model.name, default_mode=LLMMode.SYNC, cache=cache)
+        super().__init__(self.model.name, default_mode=LLMMode.SYNC, cache=cache, default_llm_kwargs=default_llm_kwargs)
 
     def __reduce__(self):
         def deserializer(kwargs):
             return Bedrock(**kwargs)
 
-        kwargs = {"model_name": self.model_name, "cache": self._cache}
+        kwargs = {"model_name": self.model_name, "cache": self._cache, "default_llm_kwargs": self._default_llm_kwargs}
         return deserializer, (kwargs,)
 
     def is_chat_mode(self) -> bool:
@@ -56,6 +57,8 @@ def format_image(self, image: Image.Image) -> dict[str, Any]:
         raise NotImplementedError("Images not supported for non-Anthropic Bedrock models.")
 
     def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
+
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if isinstance(ret, dict):
             print(f"cache return {ret}")
diff --git a/lib/sycamore/sycamore/llms/gemini.py b/lib/sycamore/sycamore/llms/gemini.py
index 68cd30ace..5696b54ec 100644
--- a/lib/sycamore/sycamore/llms/gemini.py
+++ b/lib/sycamore/sycamore/llms/gemini.py
@@ -31,6 +31,7 @@ def __init__(
         default_mode: LLMMode = LLMMode.ASYNC,
         cache: Optional[Cache] = None,
         api_key: Optional[str] = None,
+        default_llm_kwargs: Optional[dict[str, Any]] = None,
     ):
         from google.genai import Client
 
@@ -42,13 +43,18 @@ def __init__(
         self.model = GeminiModel(name=model_name)
         api_key = api_key if api_key else os.getenv("GEMINI_API_KEY")
         self._client = Client(api_key=api_key)
-        super().__init__(self.model.name, default_mode, cache)
+        super().__init__(self.model.name, default_mode, cache, default_llm_kwargs=default_llm_kwargs)
 
     def __reduce__(self):
         def deserializer(kwargs):
             return Gemini(**kwargs)
 
-        kwargs = {"model_name": self.model_name, "cache": self._cache, "default_mode": self._default_mode}
+        kwargs = {
+            "model_name": self.model_name,
+            "cache": self._cache,
+            "default_mode": self._default_mode,
+            "default_llm_kwargs": self._default_llm_kwargs,
+        }
         return deserializer, (kwargs,)
 
     def default_mode(self) -> LLMMode:
@@ -106,7 +112,7 @@ def _metadata_from_response(self, kwargs, response, starttime) -> dict:
         reason = response.candidates[0].finish_reason
         if reason != FinishReason.STOP:
-            logger.warn(f"Gemini model stopped for unexpected reason {reason}. Full response:\n{response}")
+            logger.warning(f"Gemini model stopped for unexpected reason {reason}. Full response:\n{response}")
         ret = {
             "output": output,
             "wall_latency": wall_latency,
@@ -117,6 +123,8 @@ def _metadata_from_response(self, kwargs, response, starttime) -> dict:
         return ret
 
     def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
+
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if isinstance(ret, dict):
             return ret
@@ -137,6 +145,8 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None)
         return d["output"]
 
     async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
+
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if isinstance(ret, dict):
             return ret["output"]
diff --git a/lib/sycamore/sycamore/llms/llms.py b/lib/sycamore/sycamore/llms/llms.py
index e694876b6..0401faf10 100644
--- a/lib/sycamore/sycamore/llms/llms.py
+++ b/lib/sycamore/sycamore/llms/llms.py
@@ -1,6 +1,8 @@
 import inspect
 from abc import ABC, abstractmethod
+import copy
 from enum import Enum
+import logging
 import pickle
 import base64
 from PIL import Image
@@ -23,15 +25,32 @@ class LLMMode(Enum):
 class LLM(ABC):
     """Abstract representation of an LLM instance. and should be subclassed to implement specific LLM providers."""
 
-    def __init__(self, model_name, default_mode: LLMMode, cache: Optional[Cache] = None):
+    def __init__(
+        self,
+        model_name,
+        default_mode: LLMMode,
+        cache: Optional[Cache] = None,
+        default_llm_kwargs: Optional[dict[str, Any]] = None,
+    ):
         self._model_name = model_name
         self._cache = cache
         self._default_mode = default_mode
+        self._default_llm_kwargs = default_llm_kwargs or {}
 
     def default_mode(self) -> LLMMode:
         """Returns the default execution mode for the llm"""
         return self._default_mode
 
+    def _merge_llm_kwargs(self, llm_kwargs: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        """Merges the default LLM kwargs with any provided LLM kwargs.
+
+        Prefers the passed in values if there is a conflict.
+        """
+        new_kwargs = copy.copy(self._default_llm_kwargs)
+        new_kwargs.update(llm_kwargs or {})
+        logging.debug(f"Merging LLM kwargs: {new_kwargs}")
+        return new_kwargs
+
     @abstractmethod
     def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
         """Generates a response from the LLM for the given prompt and LLM parameters."""
@@ -210,8 +229,15 @@ def add_llm_metadata(self, kwargs, output, wall_latency, in_tokens, out_tokens):
 class FakeLLM(LLM):
     """Useful for tests where the fake LLM needs to run in a ray function because mocks are not serializable"""
 
-    def __init__(self, *, return_value="trivial", cache: Optional[Cache] = None, default_mode: LLMMode = LLMMode.SYNC):
-        super().__init__("trivial", cache=cache, default_mode=default_mode)
+    def __init__(
+        self,
+        *,
+        return_value="trivial",
+        cache: Optional[Cache] = None,
+        default_mode: LLMMode = LLMMode.SYNC,
+        default_llm_kwargs: Optional[dict[str, Any]] = None,
+    ):
+        super().__init__("trivial", cache=cache, default_mode=default_mode, default_llm_kwargs=default_llm_kwargs)
         self._return_value = return_value
 
     def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
diff --git a/lib/sycamore/sycamore/llms/openai.py b/lib/sycamore/sycamore/llms/openai.py
index 3b47e8d64..091f76f5d 100644
--- a/lib/sycamore/sycamore/llms/openai.py
+++ b/lib/sycamore/sycamore/llms/openai.py
@@ -238,6 +238,7 @@ def __init__(
         params: Optional[OpenAIClientParameters] = None,
         default_mode: LLMMode = LLMMode.ASYNC,
         cache: Optional[Cache] = None,
+        default_llm_kwargs: Optional[dict[str, Any]] = None,
         **kwargs,
     ):
         if isinstance(model_name, OpenAIModels):
@@ -252,7 +253,7 @@ def __init__(
         if self.model.name == OpenAIModels.TEXT_DAVINCI.value.name:
             logger.warning("text-davinci-003 is deprecated. Falling back to gpt-3.5-turbo-instruct")
             self.model = OpenAIModels.GPT_3_5_TURBO_INSTRUCT.value
-        super().__init__(self.model.name, default_mode, cache)
+        super().__init__(self.model.name, default_mode, cache, default_llm_kwargs=default_llm_kwargs)
 
         # This is somewhat complex to provide a degree of backward compatibility.
         if client_wrapper is None:
@@ -279,6 +280,7 @@ def __reduce__(self):
             "model_name": self.model,
             "cache": self._cache,
             "default_mode": self._default_mode,
+            "default_llm_kwargs": self._default_llm_kwargs,
         }
         return openai_deserializer, (kwargs,)
 
@@ -352,6 +354,7 @@ def _get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict
         return kwargs
 
     def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
         llm_kwargs = self._convert_response_format(llm_kwargs)
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if ret is not None:
@@ -411,6 +414,7 @@ def _generate_using_openai_structured(self, prompt: RenderedPrompt, llm_kwargs:
             raise e
 
     async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if ret is not None:
             return ret
@@ -488,6 +492,7 @@ async def _generate_awaitable_using_openai_structured(
             raise e
 
     def generate_batch(self, *, prompts: list[RenderedPrompt], llm_kwargs: Optional[dict] = None) -> list[str]:
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
         cache_hits = [self._llm_cache_get(p, llm_kwargs) for p in prompts]
 
         calls = []
diff --git a/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py b/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py
index 567cecb77..72d34503f 100644
--- a/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py
+++ b/lib/sycamore/sycamore/tests/integration/llms/test_anthropic.py
@@ -151,3 +151,14 @@ def test_metadata():
     assert "wall_latency" in res
     assert "in_tokens" in res
     assert "out_tokens" in res
+
+
+def test_default_llm_kwargs():
+    llm = Anthropic(AnthropicModels.CLAUDE_3_HAIKU, default_llm_kwargs={"max_tokens": 5})
+
+    res = llm.generate_metadata(
+        prompt=RenderedPrompt(
+            messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")]
+        )
+    )
+    assert res["out_tokens"] <= 5
diff --git a/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py b/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py
index f1b71fd85..b685455c8 100644
--- a/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py
+++ b/lib/sycamore/sycamore/tests/integration/llms/test_bedrock.py
@@ -155,3 +155,13 @@ def test_metadata():
     assert "server_latency" in res
     assert "in_tokens" in res
     assert "out_tokens" in res
+
+
+def test_default_llm_kwargs():
+    llm = Bedrock(BedrockModels.CLAUDE_3_HAIKU, default_llm_kwargs={"max_tokens": 5})
+    res = llm.generate_metadata(
+        prompt=RenderedPrompt(
+            messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")]
+        )
+    )
+    assert res["out_tokens"] <= 5
diff --git a/lib/sycamore/sycamore/tests/integration/llms/test_gemini.py b/lib/sycamore/sycamore/tests/integration/llms/test_gemini.py
index ed17c784e..12f801ca3 100644
--- a/lib/sycamore/sycamore/tests/integration/llms/test_gemini.py
+++ b/lib/sycamore/sycamore/tests/integration/llms/test_gemini.py
@@ -151,3 +151,14 @@ def test_metadata():
     assert "wall_latency" in res
     assert "in_tokens" in res
     assert "out_tokens" in res
+
+
+def test_default_llm_kwargs():
+    llm = Gemini(GeminiModels.GEMINI_2_FLASH_LITE, default_llm_kwargs={"max_output_tokens": 5})
+    res = llm.generate_metadata(
+        prompt=RenderedPrompt(
+            messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")]
+        ),
+        llm_kwargs={},
+    )
+    assert res["out_tokens"] <= 5
diff --git a/lib/sycamore/sycamore/tests/integration/llms/test_openai.py b/lib/sycamore/sycamore/tests/integration/llms/test_openai.py
index a5fd851af..38308e8b2 100644
--- a/lib/sycamore/sycamore/tests/integration/llms/test_openai.py
+++ b/lib/sycamore/sycamore/tests/integration/llms/test_openai.py
@@ -4,6 +4,7 @@
 import pytest
 from typing import Any
 
+from sycamore.functions.tokenizer import OpenAITokenizer
 from sycamore.llms.openai import OpenAI, OpenAIModels, OpenAIClientWrapper
 from sycamore.llms.openai import OpenAIModel, OpenAIClientType
 from sycamore.llms.prompts import RenderedPrompt, RenderedMessage, StaticPrompt
@@ -226,6 +227,19 @@ def test_openai_defaults_guidance_instruct():
     assert len(res) > 0
 
 
+def test_default_llm_kwargs():
+    llm = OpenAI(OpenAIModels.GPT_4O_MINI, default_llm_kwargs={"max_tokens": 5})
+
+    res = llm.generate(
+        prompt=RenderedPrompt(
+            messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")]
+        )
+    )
+
+    num_tokens = len(OpenAITokenizer(OpenAIModels.GPT_4O_MINI.value.name).tokenize(res))
+    assert num_tokens <= 5, f"Expected max_tokens to be 5, but got {num_tokens} tokens in the response: {res}"
+
+
 @pytest.fixture(scope="module")
 def azure_llm():
     # Note this deployment name is different from the official model name, which has a '.'
diff --git a/lib/sycamore/sycamore/tests/unit/llms/test_llms.py b/lib/sycamore/sycamore/tests/unit/llms/test_llms.py
index c35658644..5c168b05c 100644
--- a/lib/sycamore/sycamore/tests/unit/llms/test_llms.py
+++ b/lib/sycamore/sycamore/tests/unit/llms/test_llms.py
@@ -59,6 +59,13 @@ def test_default_llm_mode():
     assert async_llm.default_mode() == LLMMode.ASYNC
 
 
+def test_merge_llm_kwargs():
+    llm = FakeLLM(default_llm_kwargs={"temperature": 0.5, "max_tokens": 100})
+    llm_kwargs = {"thinking_config": {"token_budget": 1000}, "max_tokens": 500}
+    merged_kwargs = llm._merge_llm_kwargs(llm_kwargs)
+    assert merged_kwargs == {"temperature": 0.5, "max_tokens": 500, "thinking_config": {"token_budget": 1000}}
+
+
 @patch("boto3.client")
 def test_get_llm(mock_boto3_client):
     assert isinstance(get_llm("openai." + OpenAIModels.TEXT_DAVINCI.value.name)(), OpenAI)
diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py b/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py
index 535a30264..d1cc6f965 100644
--- a/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py
+++ b/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py
@@ -11,11 +11,13 @@ class FakeLLM(LLM):
     def __init__(self, default_mode: LLMMode = LLMMode.SYNC):
         super().__init__(model_name="dummy", default_mode=default_mode)
         self.async_calls = 0
+        self.used_llm_kwargs = {}
 
     def is_chat_mode(self) -> bool:
         return True
 
     def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
+        self.used_llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
         return "".join(m.content for m in prompt.messages)
 
     async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:

From b46e77cc36def98e5d39fc00655daadda2e910ec Mon Sep 17 00:00:00 2001
From: Ben Sowell
Date: Thu, 15 May 2025 12:39:45 -0700
Subject: [PATCH 2/3] Fix bug with async anthropic llm.
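
In generate_async, the merged kwargs were computed but never assigned, so
the async path silently ignored any default_llm_kwargs configured on the
Anthropic LLM:

    self._merge_llm_kwargs(llm_kwargs)               # merged dict was discarded
    llm_kwargs = self._merge_llm_kwargs(llm_kwargs)  # fix: keep the result

The generate_metadata and generate_batch paths already assigned the result
and were unaffected.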
---
 lib/sycamore/sycamore/llms/anthropic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/sycamore/sycamore/llms/anthropic.py b/lib/sycamore/sycamore/llms/anthropic.py
index 75b088e9d..23806709d 100644
--- a/lib/sycamore/sycamore/llms/anthropic.py
+++ b/lib/sycamore/sycamore/llms/anthropic.py
@@ -194,7 +194,7 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None)
     async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
         from anthropic import RateLimitError, APIConnectionError
 
-        self._merge_llm_kwargs(llm_kwargs)
+        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
 
         ret = self._llm_cache_get(prompt, llm_kwargs)
         if isinstance(ret, dict):

From fda41a55e0169b8d1ed17ea8a4663b4c37d90273 Mon Sep 17 00:00:00 2001
From: Ben Sowell
Date: Thu, 15 May 2025 16:14:51 -0700
Subject: [PATCH 3/3] Add type annotation for mypy.

---
 lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py b/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py
index d1cc6f965..f60439fa8 100644
--- a/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py
+++ b/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py
@@ -4,14 +4,14 @@
 from sycamore.llms.prompts.prompts import RenderedMessage
 from sycamore.transforms.base_llm import LLMMap, LLMMapElements
 import pytest
-from typing import Optional
+from typing import Any, Optional
 
 
 class FakeLLM(LLM):
     def __init__(self, default_mode: LLMMode = LLMMode.SYNC):
         super().__init__(model_name="dummy", default_mode=default_mode)
         self.async_calls = 0
-        self.used_llm_kwargs = {}
+        self.used_llm_kwargs: dict[str, Any] = {}
 
     def is_chat_mode(self) -> bool:
         return True
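
Usage sketch for review, based on the integration tests above (the model
choice, prompt text, and temperature value are illustrative assumptions,
and an OpenAI API key is assumed to be configured):

    from sycamore.llms.openai import OpenAI, OpenAIModels
    from sycamore.llms.prompts import RenderedPrompt, RenderedMessage

    # Defaults set once on the constructor apply to every generate call.
    llm = OpenAI(OpenAIModels.GPT_4O_MINI, default_llm_kwargs={"max_tokens": 5, "temperature": 0.0})

    prompt = RenderedPrompt(
        messages=[RenderedMessage(role="user", content="Write a limerick about large language models.")]
    )

    # Capped at 5 output tokens by the default kwargs.
    short = llm.generate(prompt=prompt)

    # Per-call llm_kwargs take precedence on conflict (see LLM._merge_llm_kwargs).
    longer = llm.generate(prompt=prompt, llm_kwargs={"max_tokens": 100})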