From c2a8cfa0ecdf5a1db1a2a3971344305468401fcb Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 16 Jan 2025 16:40:07 -0800 Subject: [PATCH 1/8] add prompt base classes and ElementListPrompt Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 211 ++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 lib/sycamore/sycamore/llms/prompts/prompts.py diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py new file mode 100644 index 000000000..c7ec8ec37 --- /dev/null +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -0,0 +1,211 @@ +from dataclasses import dataclass +from typing import Any, Union, Optional, Callable +import copy + +import pydantic +from sycamore.data.document import Document, Element + + +@dataclass +class RenderedMessage: + """Represents a message per the LLM messages interface - i.e. a role and a content string + + Args: + role: the role of this message. Should be one of "user", "system", "assistant" + content: the content of this message. + """ + + role: str + content: str + + def to_dict(self): + return {"role": self.role, "content": self.content} + + +@dataclass +class RenderedPrompt: + """Represents a prompt to be sent to the LLM per the LLM messages interface + + Args: + messages: the list of messages to be sent to the LLM + response_format: optional output schema, speicified as pydict/json or + a pydantic model. Can only be used (iirc) with modern OpenAI models. + """ + + messages: list[RenderedMessage] + response_format: Union[None, dict[str, Any], pydantic.BaseModel] = None + + def to_dict(self): + res = {"messages": [m.to_dict() for m in self.messages]} + if self.response_format is not None: + res["response_format"] = self.output_structure # type: ignore + return res + + +class SycamorePrompt: + """Base class/API for all Sycamore LLM Prompt objects. 
Sycamore Prompts + convert sycamore objects (``Document``s, ``Element``s) into ``RenderedPrompts`` + """ + + def render_document(self, doc: Document) -> RenderedPrompt: + """Render this prompt, given this document as context. + Used in llm_map + + Args: + doc: The document to use to populate the prompt + + Returns: + A fully rendered prompt that can be sent to an LLM for inference + """ + raise NotImplementedError(f"render_document is not implemented for {self.__class__.__name__}") + + def render_element(self, elt: Element) -> RenderedPrompt: + """Render this prompt, given this element as context. + Used in llm_map_elements + + Args: + elt: The element to use to populate the prompt + + Returns: + A fully rendered prompt that can be sent to an LLM for inference + """ + raise NotImplementedError(f"render_element is not implemented for {self.__class__.__name__}") + + def render_multiple_documents(self, docs: list[Document]) -> RenderedPrompt: + """Render this prompt, given a list of documents as context. + Used in llm_reduce + + Args: + docs: The list of documents to use to populate the prompt + + Returns: + A fully rendered prompt that can be sent to an LLM for inference""" + raise NotImplementedError(f"render_multiple_documents is not implemented for {self.__class__.__name__}") + + def instead(self, **kwargs) -> "SycamorePrompt": + """Create a new prompt with some fields changed. + + Args: + **kwargs: any keyword arguments will get set as fields in the + resulting prompt + + Returns: + A new SycamorePrompt with updated fields. + + Example: + .. 
code-block:: python + + p = StaticPrompt(system="hello", user="world") + p.render_document(Document()) + # [ + # {"role": "system", "content": "hello"}, + # {"role": "user", "content": "world"} + # ] + p2 = p.instead(user="bob") + p2.render_document(Document()) + # [ + # {"role": "system", "content": "hello"}, + # {"role": "user", "content": "bob"} + # ] + """ + new = copy.deepcopy(self) + new.__dict__.update(kwargs) + return new + + +class ElementListPrompt(SycamorePrompt): + """A prompt with utilities for constructing a list of elements to include + in the rendered prompt. + + Args: + system: The system prompt string. Use {} to reference names that should + be interpolated. Defaults to None + user: The user prompt string. Use {} to reference names that should be + interpolated. Defaults to None + element_select: Function to choose which set of elements to include in + the prompt. If None, defaults to the first ``num_elements`` elements. + element_order: Function to reorder the selected elements. Defaults to + a noop. + element_list_constructor: Function to turn a list of elements into a + string that can be accessed with the interpolation key "{elements}". + Defaults to "ELEMENT 0: {elts[0].text_representation}\n + ELEMENT 1: {elts[1].text_representation}\n + ..." + num_elements: Sets the number of elements to take if ``element_select`` is + unset. Default is 35. + **kwargs: other keyword arguments are stored and can be used as interpolation keys. + + Example: + .. code-block:: python + + prompt = ElementListPrompt( + system = "Hello {name}. This is a prompt about {doc_property_path}" + user = "What do you make of these tables?\nTables:\n{elements}" + element_select = lambda elts: [e for e in elts if e.type == "table"] + element_order = reversed + name = "David Rothschild" + ) + prompt.render_document(doc) + # [ + # {"role": "system", "content": "Hello David Rothschild. 
This is a prompt about data/mypdf.pdf"}, + # {"role": "user", "content": "What do you make of these tables?\nTables:\n + # ELEMENT 0: \nELEMENT 1: ..."} + # ] + """ + + def __init__( + self, + *, + system: Optional[str] = None, + user: Optional[str] = None, + element_select: Optional[Callable[[list[Element]], list[Element]]] = None, + element_order: Optional[Callable[[list[Element]], list[Element]]] = None, + element_list_constructor: Optional[Callable[[list[Element]], str]] = None, + num_elements: int = 35, + **kwargs, + ): + self.system = system + self.user = user + self.element_select = element_select or (lambda elts: elts[:num_elements]) + self.element_order = element_order or (lambda elts: elts) + self.element_list_constructor = element_list_constructor or ( + lambda elts: "\n".join(f"ELEMENT {i}: {elts[i].text_representation}" for i in range(len(elts))) + ) + self.kwargs = kwargs + super().__init__() + + def _render_element_list_to_string(self, doc: Document): + elts = self.element_select(doc.elements) + elts = self.element_order(elts) + return self.element_list_constructor(elts) + + def render_document(self, doc: Document) -> RenderedPrompt: + """Render this prompt, given this document as context, using python's + ``str.format()`` method. The keys passed into ``format()`` are as follows: + + - self.kwargs: the additional kwargs specified when creating this prompt. + - doc_text: doc.text_representation + - doc_property_: each property name in doc.properties is + prefixed with 'doc_property_'. So if ``doc.properties = {'k1': 0, 'k2': 3}``, + you get ``doc_property_k1 = 0, doc_property_k2 = 3``. + - elements: the element list constructed from doc.elements using ``self.element_select``, + ``self.element_order``, and ``self.element_list_constructor``. 
+ + Args: + doc: The document to use as context for rendering this prompt + + Returns: + A two-message RenderedPrompt containing ``self.system.format()`` and ``self.user.format()`` + using the format keys as specified above. + """ + format_args = self.kwargs + format_args["doc_text"] = doc.text_representation + format_args.update({"doc_property_" + k: v for k, v in doc.properties.items()}) + format_args["elements"] = self._render_element_list_to_string(doc) + + result = RenderedPrompt(messages=[]) + if self.system is not None: + result.messages.append(RenderedMessage(role="system", content=self.system.format(**format_args))) + if self.user is not None: + result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) + return result From 21a115a614133244facdb99bb6ef20dc10427550 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 16 Jan 2025 16:49:31 -0800 Subject: [PATCH 2/8] override .instead in ElementListPrompt to store net-new keys in self.kwargs Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index c7ec8ec37..f4185af51 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -209,3 +209,12 @@ def render_document(self, doc: Document) -> RenderedPrompt: if self.user is not None: result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) return result + + def instead(self, **kwargs) -> "SycamorePrompt": + new = copy.deepcopy(self) + for k in kwargs: + if k in new.__dict__: + new.__dict__[k] = kwargs[k] + else: + new.kwargs[k] = kwargs[k] + return new From f94da80f841104870e770a13de2f1f9912f0a078 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Fri, 17 Jan 2025 15:38:49 -0800 Subject: [PATCH 3/8] add ElementPrompt and StaticPrompt Signed-off-by: Henry 
Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 184 +++++++++++++++--- 1 file changed, 161 insertions(+), 23 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index f4185af51..9feb8d27f 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -3,7 +3,9 @@ import copy import pydantic +from PIL import Image from sycamore.data.document import Document, Element +from sycamore.utils.pdf_utils import get_element_image @dataclass @@ -12,11 +14,13 @@ class RenderedMessage: Args: role: the role of this message. Should be one of "user", "system", "assistant" - content: the content of this message. + content: the content of this message, either a python string or a PIL image. + images: optional list of images to include in this message. """ role: str content: str + images: Optional[list[Image.Image]] = None def to_dict(self): return {"role": self.role, "content": self.content} @@ -59,8 +63,8 @@ def render_document(self, doc: Document) -> RenderedPrompt: """ raise NotImplementedError(f"render_document is not implemented for {self.__class__.__name__}") - def render_element(self, elt: Element) -> RenderedPrompt: - """Render this prompt, given this element as context. + def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: + """Render this prompt, given this element and its parent document as context. Used in llm_map_elements Args: @@ -86,10 +90,12 @@ def instead(self, **kwargs) -> "SycamorePrompt": """Create a new prompt with some fields changed. Args: + **kwargs: any keyword arguments will get set as fields in the resulting prompt Returns: + A new SycamorePrompt with updated fields. 
Example: @@ -98,18 +104,22 @@ def instead(self, **kwargs) -> "SycamorePrompt": p = StaticPrompt(system="hello", user="world") p.render_document(Document()) # [ - # {"role": "system", "content": "hello"}, - # {"role": "user", "content": "world"} - # ] - p2 = p.instead(user="bob") - p2.render_document(Document()) + # {"role": "system", "content": "hello"}, + # {"role": "user", "content": "world"} + # ] + p2 = p.instead(user="bob") + p2.render_document(Document()) # [ # {"role": "system", "content": "hello"}, # {"role": "user", "content": "bob"} # ] """ new = copy.deepcopy(self) - new.__dict__.update(kwargs) + for k, v in kwargs.items(): + if hasattr(new, "kwargs") and k not in new.__dict__: + getattr(new, "kwargs")[k] = v + else: + new.__dict__[k] = v return new @@ -118,6 +128,7 @@ class ElementListPrompt(SycamorePrompt): in the rendered prompt. Args: + system: The system prompt string. Use {} to reference names that should be interpolated. Defaults to None user: The user prompt string. Use {} to reference names that should be @@ -128,8 +139,8 @@ class ElementListPrompt(SycamorePrompt): a noop. element_list_constructor: Function to turn a list of elements into a string that can be accessed with the interpolation key "{elements}". - Defaults to "ELEMENT 0: {elts[0].text_representation}\n - ELEMENT 1: {elts[1].text_representation}\n + Defaults to "ELEMENT 0: {elts[0].text_representation}\\n + ELEMENT 1: {elts[1].text_representation}\\n ..." num_elements: Sets the number of elements to take if ``element_select`` is unset. Default is 35. @@ -140,7 +151,7 @@ class ElementListPrompt(SycamorePrompt): prompt = ElementListPrompt( system = "Hello {name}. 
This is a prompt about {doc_property_path}" - user = "What do you make of these tables?\nTables:\n{elements}" + user = "What do you make of these tables?\\nTables:\\n{elements}" element_select = lambda elts: [e for e in elts if e.type == "table"] element_order = reversed name = "David Rothschild" @@ -148,8 +159,8 @@ class ElementListPrompt(SycamorePrompt): prompt.render_document(doc) # [ # {"role": "system", "content": "Hello David Rothschild. This is a prompt about data/mypdf.pdf"}, - # {"role": "user", "content": "What do you make of these tables?\nTables:\n - # ELEMENT 0: \nELEMENT 1: ..."} + # {"role": "user", "content": "What do you make of these tables?\\nTables:\\n + # ELEMENT 0: \\nELEMENT 1: ..."} # ] """ @@ -164,6 +175,7 @@ def __init__( num_elements: int = 35, **kwargs, ): + super().__init__() self.system = system self.user = user self.element_select = element_select or (lambda elts: elts[:num_elements]) @@ -172,7 +184,6 @@ def __init__( lambda elts: "\n".join(f"ELEMENT {i}: {elts[i].text_representation}" for i in range(len(elts))) ) self.kwargs = kwargs - super().__init__() def _render_element_list_to_string(self, doc: Document): elts = self.element_select(doc.elements) @@ -210,11 +221,138 @@ def render_document(self, doc: Document) -> RenderedPrompt: result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) return result - def instead(self, **kwargs) -> "SycamorePrompt": - new = copy.deepcopy(self) - for k in kwargs: - if k in new.__dict__: - new.__dict__[k] = kwargs[k] - else: - new.kwargs[k] = kwargs[k] - return new + +class ElementPrompt(SycamorePrompt): + """A prompt for rendering an element with utilities for capturing information + from the element's parent document, with a system and user prompt. + + Args: + system: The system prompt string. Use {} to reference names to be interpolated. + Defaults to None + user: The user prompt string. Use {} to reference names to be interpolated. 
+ Defaults to None + include_element_image: Whether to include an image of the element in the rendered user + message. Only works if the parent document is a PDF. Defaults to False (no image) + capture_parent_context: Function to gather context from the element's parent document. + Should return {"key": value} dictionary, which will be made available as interpolation + keys. Defaults to returning {} + **kwargs: other keyword arguments are stored and can be used as interpolation keys + + Example: + .. code-block:: python + + prompt = ElementPrompt( + system = "You know everything there is to know about {custom_kwarg}, {name}", + user = "Summarize the information on page {elt_property_page}. \\nTEXT: {elt_text}", + capture_parent_context = lambda doc, elt: {"custom_kwarg": doc.properties["path"]}, + name = "Frank Sinatra", + ) + prompt.render_element(doc.elements[0], doc) + # [ + # {"role": "system", "content": "You know everything there is to know + # about /path/to/doc.pdf, Frank Sinatra"}, + # {"role": "user", "content": "Summarize the information on page 1. \\nTEXT: "} + # ] + """ + + def __init__( + self, + *, + system: Optional[str] = None, + user: Optional[str] = None, + include_element_image: bool = False, + capture_parent_context: Optional[Callable[[Document, Element], dict[str, Any]]] = None, + **kwargs, + ): + super().__init__() + self.system = system + self.user = user + self.include_element_image = include_element_image + self.capture_parent_context = capture_parent_context or (lambda doc, elt: {}) + self.kwargs = kwargs + + def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: + """Render this prompt for this element; also take the parent document + if there is context in that to account for as well. Rendering is done + using pythons ``str.format()`` method. The keys passed into ``format`` + are as follows: + + - self.kwargs: the additional kwargs specified when creating this prompt. 
+ - self.capture_parent_context(doc, elt): key-value pairs returned by the + context-capturing function. + - elt_text: elt.text_representation (the text representation of the element) + - elt_property_: each property name in elt.properties is + prefixed with 'elt_property_'. So if ``elt.properties = {'k1': 0, 'k2': 3}``, + you get ``elt_property_k1 = 0, elt_property_k2 = 3``. + + Args: + elt: The element used as context for rendering this prompt. + doc: The element's parent document; used to add additional context. + + Returns: + A two-message rendered prompt containing ``self.system.format()`` and + ``self.user.format()`` using the format keys as specified above. + If self.include_element_image is true, crop out the image from the page + of the PDF it's on and attach it to the last message (user message if there + is one, o/w system message). + """ + format_args = self.kwargs + format_args.update(self.capture_parent_context(doc, elt)) + format_args["elt_text"] = elt.text_representation + format_args.update({"elt_property_" + k: v for k, v in elt.properties.items()}) + + result = RenderedPrompt(messages=[]) + if self.system is not None: + result.messages.append(RenderedMessage(role="system", content=self.system.format(**format_args))) + if self.user is not None: + result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) + if self.include_element_image and len(result.messages) > 0: + result.messages[-1].images = [get_element_image(elt, doc)] + return result + + +class StaticPrompt(SycamorePrompt): + """A prompt that always renders the same regardless of the Document or Elements + passed in as context. + + Args: + + system: the system prompt string. Use {} to reference names to be interpolated. + Interpolated names only come from kwargs. + user: the user prompt string. Use {} to reference names to be interpolated. + Interpolated names only come from kwargs. + **kwargs: keyword arguments to interpolate. + + Example: + .. 
code-block:: python + + prompt = StaticPrompt(system="static", user = "prompt - {number}", number=7) + prompt.render_document(Document()) + # [ + # { "role": "system", "content": "static" }, + # { "role": "user", "content": "prompt - 7" }, + # ] + """ + + def __init__(self, *, system: Optional[str] = None, user: Optional[str] = None, **kwargs): + super().__init__() + self.system = system + self.user = user + self.kwargs = kwargs + + def render_generic(self) -> RenderedPrompt: + result = RenderedPrompt(messages=[]) + if self.system is not None: + result.messages.append(RenderedMessage(role="system", content=self.system.format(**self.kwargs))) + if self.user is not None: + result.messages.append(RenderedMessage(role="user", content=self.user.format(**self.kwargs))) + return result + + def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: + return self.render_generic() + + def render_document(self, doc: Document) -> RenderedPrompt: + return self.render_generic() + + def render_multiple_documents(self, docs: list[Document]) -> RenderedPrompt: + return self.render_generic() From b73c1624951f39799d20548df332d37cbeb915b1 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 21 Jan 2025 09:34:59 -0800 Subject: [PATCH 4/8] add unit tests for prompts Signed-off-by: Henry Lindeman --- .../tests/unit/llms/prompts/test_prompts.py | 238 ++++++++++++++++++ lib/sycamore/sycamore/utils/pdf_utils.py | 18 +- 2 files changed, 255 insertions(+), 1 deletion(-) create mode 100644 lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py diff --git a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py new file mode 100644 index 000000000..6ce28c1dc --- /dev/null +++ b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py @@ -0,0 +1,238 @@ +from sycamore.data.element import Element +from sycamore.llms.prompts.prompts import ( + RenderedPrompt, + RenderedMessage, + StaticPrompt, + 
SycamorePrompt, + ElementPrompt, + ElementListPrompt, +) +from sycamore.data import Document +from sycamore.tests.config import TEST_DIR +from pyarrow.fs import LocalFileSystem +import pytest + + +@pytest.fixture(scope="module") +def dummy_document(): + docpath = TEST_DIR / "resources/data/pdfs/ntsb-report.pdf" + local = LocalFileSystem() + path = str(docpath) + input_stream = local.open_input_stream(path) + document = Document() + document.binary_representation = input_stream.readall() + document.type = "pdf" + document.properties["path"] = path + document.properties["pages"] = 6 + document.elements = [ + Element( + text_representation="Element 1", + type="Text", + element_id="e1", + properties={"page_number": 1}, + bbox=(0.1, 0.1, 0.4, 0.4), + ), + Element( + text_representation="Element 2", + type="Text", + element_id="e2", + properties={"page_number": 2}, + bbox=(0.1, 0.1, 0.4, 0.4), + ), + Element( + text_representation="Element 3", + type="Text", + element_id="e3", + properties={"page_number": 3}, + bbox=(0.1, 0.1, 0.4, 0.4), + ), + Element( + text_representation="Element 4", + type="Text", + element_id="e4", + properties={"page_number": 3}, + bbox=(0.4, 0.1, 0.8, 0.4), + ), + Element( + text_representation="Element 5", + type="Text", + element_id="e5", + properties={"page_number": 3}, + bbox=(0.1, 0.4, 0.8, 0.8), + ), + Element( + text_representation="Element 6", + type="Text", + element_id="e6", + properties={"page_number": 4}, + bbox=(0.1, 0.1, 0.4, 0.4), + ), + ] + return document + + +class TestRenderedPrompt: + """RenderedPrompt and RenderedMessage are dataclasses, + no need to test them. 
Nothing to test :)""" + + pass + + +class TestSycamorePrompt: + def test_instead_is_cow(self): + sp = SycamorePrompt() + sp.__dict__["key"] = "value" + sp2 = sp.instead(key="other value") + assert sp.key == "value" + assert sp2.key == "other value" + + +class TestStaticPrompt: + def test_static_rd(self, dummy_document): + prompt = StaticPrompt(system="system {x}", user="computers") + with pytest.raises(KeyError): + prompt.render_document(dummy_document) + + prompt = prompt.instead(x=76) + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="system 76"), + RenderedMessage(role="user", content="computers"), + ] + ) + assert prompt.render_document(dummy_document) == expected + assert prompt.render_element(dummy_document.elements[0], dummy_document) == expected + assert prompt.render_multiple_documents([dummy_document]) == expected + + +class TestElementPrompt: + def test_basic(self, dummy_document): + prompt = ElementPrompt( + system="You know everything there is to know about jazz, {name}", + user="Summarize the information on page {elt_property_page_number}.\nTEXT: {elt_text}", + name="Frank Sinatra", + ) + expected = RenderedPrompt( + messages=[ + RenderedMessage( + role="system", content="You know everything there is to know about jazz, Frank Sinatra" + ), + RenderedMessage(role="user", content="Summarize the information on page 3.\nTEXT: Element 4"), + ] + ) + assert prompt.render_element(dummy_document.elements[3], dummy_document) == expected + with pytest.raises(NotImplementedError): + prompt.render_document(dummy_document) + with pytest.raises(NotImplementedError): + prompt.render_multiple_documents([dummy_document]) + + def test_get_parent_context(self, dummy_document): + prompt = ElementPrompt( + system="You know everything there is to know about {custom_property}, {name}", + user="Summarize the information on page {elt_property_page_number}.\nTEXT: {elt_text}", + name="Frank Sinatra", + capture_parent_context=lambda doc, elt: 
{"custom_property": doc.properties["pages"]}, + ) + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="You know everything there is to know about 6, Frank Sinatra"), + RenderedMessage(role="user", content="Summarize the information on page 3.\nTEXT: Element 4"), + ] + ) + assert prompt.render_element(dummy_document.elements[3], dummy_document) == expected + + def test_include_image(self, dummy_document): + prompt = ElementPrompt( + system="You know everything there is to know about {custom_property}, {name}", + user="Summarize the information on page {elt_property_page_number}.\nTEXT: {elt_text}", + name="Frank Sinatra", + capture_parent_context=lambda doc, elt: {"custom_property": doc.properties["pages"]}, + include_element_image=True, + ) + rp = prompt.render_element(dummy_document.elements[3], dummy_document) + assert rp.messages[1].images is not None and len(rp.messages[1].images) == 1 + assert rp.messages[1].role == "user" + assert rp.messages[0].images is None + + prompt = prompt.instead(user=None) + rp2 = prompt.render_element(dummy_document.elements[1], dummy_document) + assert len(rp2.messages) == 1 + assert rp2.messages[0].role == "system" + assert rp2.messages[0].images is not None + assert len(rp2.messages[0].images) == 1 + + +class TestElementListPrompt: + def test_basic(self, dummy_document): + prompt = ElementListPrompt(system="sys", user="usr: {elements}") + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="sys"), + RenderedMessage( + role="user", + content="usr: ELEMENT 0: Element 1\nELEMENT 1: Element 2\n" + "ELEMENT 2: Element 3\nELEMENT 3: Element 4\nELEMENT 4: Element 5\nELEMENT 5: Element 6", + ), + ] + ) + assert prompt.render_document(dummy_document) == expected + + def test_limit_elements(self, dummy_document): + prompt = ElementListPrompt(system="sys", user="usr: {elements}", num_elements=3) + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", 
content="sys"), + RenderedMessage( + role="user", + content="usr: ELEMENT 0: Element 1\nELEMENT 1: Element 2\nELEMENT 2: Element 3", + ), + ] + ) + assert prompt.render_document(dummy_document) == expected + + def test_select_odd_elements(self, dummy_document): + prompt = ElementListPrompt( + system="sys", + user="usr: {elements}", + element_select=lambda elts: [elts[i] for i in range(len(elts)) if i % 2 == 1], + ) + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="sys"), + RenderedMessage( + role="user", + content="usr: ELEMENT 0: Element 2\nELEMENT 1: Element 4\nELEMENT 2: Element 6", + ), + ] + ) + assert prompt.render_document(dummy_document) == expected + + def test_order_elements(self, dummy_document): + prompt = ElementListPrompt(system="sys", user="usr: {elements}", element_order=lambda e: list(reversed(e))) + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="sys"), + RenderedMessage( + role="user", + content="usr: ELEMENT 0: Element 6\nELEMENT 1: Element 5\n" + "ELEMENT 2: Element 4\nELEMENT 3: Element 3\nELEMENT 4: Element 2\nELEMENT 5: Element 1", + ), + ] + ) + assert prompt.render_document(dummy_document) == expected + + def test_construct_element_list(self, dummy_document): + def list_constructor(elts: list[Element]) -> str: + return "<>" + "<>".join(f"{i}-{e.type}" for i, e in enumerate(elts)) + "" + + prompt = ElementListPrompt(system="sys", user="usr: {elements}", element_list_constructor=list_constructor) + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="sys"), + RenderedMessage( + role="user", + content="usr: <>0-Text<>1-Text<>2-Text<>3-Text<>4-Text<>5-Text", + ), + ] + ) + assert prompt.render_document(dummy_document) == expected diff --git a/lib/sycamore/sycamore/utils/pdf_utils.py b/lib/sycamore/sycamore/utils/pdf_utils.py index 8665bee1e..92331b578 100644 --- a/lib/sycamore/sycamore/utils/pdf_utils.py +++ 
b/lib/sycamore/sycamore/utils/pdf_utils.py @@ -5,10 +5,11 @@ from PIL import Image from pypdf import PdfReader, PdfWriter +import pdf2image from sycamore import DocSet from sycamore.functions.document import DrawBoxes, split_and_convert_to_image -from sycamore.utils.image_utils import show_images +from sycamore.utils.image_utils import show_images, crop_to_bbox from sycamore.data import Document, Element import json @@ -180,3 +181,18 @@ def promote_title(elements: list[Element], title_candidate_elements=["Section-he if section_header: section_header.type = "Title" return elements + + +def get_element_image(element: Element, document: Document) -> Image.Image: + assert document.type == "pdf", "Cannot get picture of element from non-pdf" + assert document.binary_representation is not None, "Cannot get image since there is not binary representation" + assert element.bbox is not None, "Cannot get picture of element if it has no BBox" + assert element.properties.get("page_number") is not None and isinstance( + element.properties["page_number"], int + ), "Cannot get picture of element without known page number" + bits = BytesIO(document.binary_representation) + pagebits = BytesIO() + select_pdf_pages(bits, pagebits, [element.properties["page_number"]]) + images = pdf2image.convert_from_bytes(pagebits.getvalue()) + im = crop_to_bbox(images[0], element.bbox) + return im From 17b21635a41ac4a18b3384ba3cb057fc48e29802 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 21 Jan 2025 10:33:34 -0800 Subject: [PATCH 5/8] forgot to commit this Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 9feb8d27f..40b020f76 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -5,7 +5,6 @@ import pydantic from PIL import 
Image from sycamore.data.document import Document, Element -from sycamore.utils.pdf_utils import get_element_image @dataclass @@ -22,9 +21,6 @@ class RenderedMessage: content: str images: Optional[list[Image.Image]] = None - def to_dict(self): - return {"role": self.role, "content": self.content} - @dataclass class RenderedPrompt: @@ -39,12 +35,6 @@ class RenderedPrompt: messages: list[RenderedMessage] response_format: Union[None, dict[str, Any], pydantic.BaseModel] = None - def to_dict(self): - res = {"messages": [m.to_dict() for m in self.messages]} - if self.response_format is not None: - res["response_format"] = self.output_structure # type: ignore - return res - class SycamorePrompt: """Base class/API for all Sycamore LLM Prompt objects. Sycamore Prompts @@ -307,6 +297,8 @@ def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: if self.user is not None: result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) if self.include_element_image and len(result.messages) > 0: + from sycamore.utils.pdf_utils import get_element_image + result.messages[-1].images = [get_element_image(elt, doc)] return result From 5d145d5f8772e6385e7bf1a987547161728e58d9 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 21 Jan 2025 12:37:35 -0800 Subject: [PATCH 6/8] address pr comments; flatten properties with flatten_data Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 25 ++++++++----------- .../tests/unit/llms/prompts/test_prompts.py | 10 +++++++- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 40b020f76..24c2e6392 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -5,6 +5,7 @@ import pydantic from PIL import Image from sycamore.data.document import Document, Element +from sycamore.connectors.common import 
flatten_data @dataclass @@ -12,8 +13,8 @@ class RenderedMessage: """Represents a message per the LLM messages interface - i.e. a role and a content string Args: - role: the role of this message. Should be one of "user", "system", "assistant" - content: the content of this message, either a python string or a PIL image. + role: the role of this message. e.g. for OpenAI should be one of "user", "system", "assistant" + content: the content of this message images: optional list of images to include in this message. """ @@ -29,7 +30,7 @@ class RenderedPrompt: Args: messages: the list of messages to be sent to the LLM response_format: optional output schema, speicified as pydict/json or - a pydantic model. Can only be used (iirc) with modern OpenAI models. + a pydantic model. Can only be used with modern OpenAI models. """ messages: list[RenderedMessage] @@ -123,10 +124,8 @@ class ElementListPrompt(SycamorePrompt): be interpolated. Defaults to None user: The user prompt string. Use {} to reference names that should be interpolated. Defaults to None - element_select: Function to choose which set of elements to include in - the prompt. If None, defaults to the first ``num_elements`` elements. - element_order: Function to reorder the selected elements. Defaults to - a noop. + element_select: Function to choose the elements (and their order) to include + in the prompt. If None, defaults to the first ``num_elements`` elements. element_list_constructor: Function to turn a list of elements into a string that can be accessed with the interpolation key "{elements}". Defaults to "ELEMENT 0: {elts[0].text_representation}\\n @@ -142,8 +141,7 @@ class ElementListPrompt(SycamorePrompt): prompt = ElementListPrompt( system = "Hello {name}. 
This is a prompt about {doc_property_path}" user = "What do you make of these tables?\\nTables:\\n{elements}" - element_select = lambda elts: [e for e in elts if e.type == "table"] - element_order = reversed + element_select = lambda elts: list(reversed([e for e in elts if e.type == "table"])) name = "David Rothschild" ) prompt.render_document(doc) @@ -160,7 +158,6 @@ def __init__( system: Optional[str] = None, user: Optional[str] = None, element_select: Optional[Callable[[list[Element]], list[Element]]] = None, - element_order: Optional[Callable[[list[Element]], list[Element]]] = None, element_list_constructor: Optional[Callable[[list[Element]], str]] = None, num_elements: int = 35, **kwargs, @@ -169,7 +166,6 @@ def __init__( self.system = system self.user = user self.element_select = element_select or (lambda elts: elts[:num_elements]) - self.element_order = element_order or (lambda elts: elts) self.element_list_constructor = element_list_constructor or ( lambda elts: "\n".join(f"ELEMENT {i}: {elts[i].text_representation}" for i in range(len(elts))) ) @@ -177,7 +173,6 @@ def __init__( def _render_element_list_to_string(self, doc: Document): elts = self.element_select(doc.elements) - elts = self.element_order(elts) return self.element_list_constructor(elts) def render_document(self, doc: Document) -> RenderedPrompt: @@ -201,7 +196,8 @@ def render_document(self, doc: Document) -> RenderedPrompt: """ format_args = self.kwargs format_args["doc_text"] = doc.text_representation - format_args.update({"doc_property_" + k: v for k, v in doc.properties.items()}) + flat_props = flatten_data(doc.properties, prefix="doc_property", separator="_") + format_args.update(flat_props) format_args["elements"] = self._render_element_list_to_string(doc) result = RenderedPrompt(messages=[]) @@ -289,7 +285,8 @@ def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: format_args = self.kwargs format_args.update(self.capture_parent_context(doc, elt)) format_args["elt_text"] 
= elt.text_representation - format_args.update({"elt_property_" + k: v for k, v in elt.properties.items()}) + flat_props = flatten_data(elt.properties, prefix="elt_property", separator="_") + format_args.update(flat_props) result = RenderedPrompt(messages=[]) if self.system is not None: diff --git a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py index 6ce28c1dc..6c36d7510 100644 --- a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py +++ b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py @@ -208,7 +208,7 @@ def test_select_odd_elements(self, dummy_document): assert prompt.render_document(dummy_document) == expected def test_order_elements(self, dummy_document): - prompt = ElementListPrompt(system="sys", user="usr: {elements}", element_order=lambda e: list(reversed(e))) + prompt = ElementListPrompt(system="sys", user="usr: {elements}", element_select=lambda e: list(reversed(e))) expected = RenderedPrompt( messages=[ RenderedMessage(role="system", content="sys"), @@ -236,3 +236,11 @@ def list_constructor(elts: list[Element]) -> str: ] ) assert prompt.render_document(dummy_document) == expected + + def test_flattened_properties(self, dummy_document): + doc = dummy_document.copy() + doc.properties["entity"] = {"key": "value"} + + prompt = ElementListPrompt(system="sys {doc_property_entity_key}") + expected = RenderedPrompt(messages=[RenderedMessage(role="system", content="sys value")]) + assert prompt.render_document(doc) == expected From 7fa2ff1488a4db4be5b58c8e15a565b37c32854b Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 21 Jan 2025 13:12:51 -0800 Subject: [PATCH 7/8] support multiple user prompts Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 40 ++++++++++--------- .../tests/unit/llms/prompts/test_prompts.py | 11 +++++ 2 files changed, 33 insertions(+), 18 deletions(-) diff --git 
a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 24c2e6392..40930bad3 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -114,6 +114,19 @@ def instead(self, **kwargs) -> "SycamorePrompt": return new +def _build_format_str( + system: Optional[str], user: Union[None, str, list[str]], format_args: dict[str, Any] +) -> list[RenderedMessage]: + messages = [] + if system is not None: + messages.append(RenderedMessage(role="system", content=system.format(**format_args))) + if isinstance(user, list): + messages.extend([RenderedMessage(role="user", content=u.format(**format_args)) for u in user]) + elif isinstance(user, str): + messages.append(RenderedMessage(role="user", content=user.format(**format_args))) + return messages + + class ElementListPrompt(SycamorePrompt): """A prompt with utilities for constructing a list of elements to include in the rendered prompt. @@ -156,7 +169,7 @@ def __init__( self, *, system: Optional[str] = None, - user: Optional[str] = None, + user: Union[None, str, list[str]] = None, element_select: Optional[Callable[[list[Element]], list[Element]]] = None, element_list_constructor: Optional[Callable[[list[Element]], str]] = None, num_elements: int = 35, @@ -200,11 +213,8 @@ def render_document(self, doc: Document) -> RenderedPrompt: format_args.update(flat_props) format_args["elements"] = self._render_element_list_to_string(doc) - result = RenderedPrompt(messages=[]) - if self.system is not None: - result.messages.append(RenderedMessage(role="system", content=self.system.format(**format_args))) - if self.user is not None: - result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) + messages = _build_format_str(self.system, self.user, format_args) + result = RenderedPrompt(messages=messages) return result @@ -245,7 +255,7 @@ def __init__( self, *, system: Optional[str] = None, - user: 
Optional[str] = None, + user: Union[None, str, list[str]] = None, include_element_image: bool = False, capture_parent_context: Optional[Callable[[Document, Element], dict[str, Any]]] = None, **kwargs, @@ -288,11 +298,8 @@ def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: flat_props = flatten_data(elt.properties, prefix="elt_property", separator="_") format_args.update(flat_props) - result = RenderedPrompt(messages=[]) - if self.system is not None: - result.messages.append(RenderedMessage(role="system", content=self.system.format(**format_args))) - if self.user is not None: - result.messages.append(RenderedMessage(role="user", content=self.user.format(**format_args))) + messages = _build_format_str(self.system, self.user, format_args) + result = RenderedPrompt(messages=messages) if self.include_element_image and len(result.messages) > 0: from sycamore.utils.pdf_utils import get_element_image @@ -323,18 +330,15 @@ class StaticPrompt(SycamorePrompt): # ] """ - def __init__(self, *, system: Optional[str] = None, user: Optional[str] = None, **kwargs): + def __init__(self, *, system: Optional[str] = None, user: Union[None, str, list[str]] = None, **kwargs): super().__init__() self.system = system self.user = user self.kwargs = kwargs def render_generic(self) -> RenderedPrompt: - result = RenderedPrompt(messages=[]) - if self.system is not None: - result.messages.append(RenderedMessage(role="system", content=self.system.format(**self.kwargs))) - if self.user is not None: - result.messages.append(RenderedMessage(role="user", content=self.user.format(**self.kwargs))) + messages = _build_format_str(self.system, self.user, self.kwargs) + result = RenderedPrompt(messages=messages) return result def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: diff --git a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py index 6c36d7510..111a6ee33 100644 --- 
a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py +++ b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py @@ -104,6 +104,17 @@ def test_static_rd(self, dummy_document): assert prompt.render_element(dummy_document.elements[0], dummy_document) == expected assert prompt.render_multiple_documents([dummy_document]) == expected + def test_static_with_multiple_user_prompts(self, dummy_document): + prompt = StaticPrompt(system="system {x}", user=["{x} user {y}", "{x} user {z}"], x=1, y=2, z=3) + expected = RenderedPrompt( + messages=[ + RenderedMessage(role="system", content="system 1"), + RenderedMessage(role="user", content="1 user 2"), + RenderedMessage(role="user", content="1 user 3"), + ] + ) + assert prompt.render_document(dummy_document) == expected + class TestElementPrompt: def test_basic(self, dummy_document): From abf9b0b7dcd32d37e1d849cd509f8e33bf710d0b Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 21 Jan 2025 16:16:30 -0800 Subject: [PATCH 8/8] rename instead to set Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/prompts/prompts.py | 4 ++-- .../sycamore/tests/unit/llms/prompts/test_prompts.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 40930bad3..0ea81112e 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -77,7 +77,7 @@ def render_multiple_documents(self, docs: list[Document]) -> RenderedPrompt: A fully rendered prompt that can be sent to an LLM for inference""" raise NotImplementedError(f"render_multiple_documents is not implemented for {self.__class__.__name__}") - def instead(self, **kwargs) -> "SycamorePrompt": + def set(self, **kwargs) -> "SycamorePrompt": """Create a new prompt with some fields changed. 
Args: @@ -98,7 +98,7 @@ def instead(self, **kwargs) -> "SycamorePrompt": # {"role": "system", "content": "hello"}, # {"role": "user", "content": "world"} # ] - p2 = p.instead(user="bob") + p2 = p.set(user="bob") p2.render_document(Document()) # [ # {"role": "system", "content": "hello"}, diff --git a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py index 111a6ee33..76d4fefdf 100644 --- a/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py +++ b/lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py @@ -79,10 +79,10 @@ class TestRenderedPrompt: class TestSycamorePrompt: - def test_instead_is_cow(self): + def test_set_is_cow(self): sp = SycamorePrompt() sp.__dict__["key"] = "value" - sp2 = sp.instead(key="other value") + sp2 = sp.set(key="other value") assert sp.key == "value" assert sp2.key == "other value" @@ -93,7 +93,7 @@ def test_static_rd(self, dummy_document): with pytest.raises(KeyError): prompt.render_document(dummy_document) - prompt = prompt.instead(x=76) + prompt = prompt.set(x=76) expected = RenderedPrompt( messages=[ RenderedMessage(role="system", content="system 76"), @@ -165,7 +165,7 @@ def test_include_image(self, dummy_document): assert rp.messages[1].role == "user" assert rp.messages[0].images is None - prompt = prompt.instead(user=None) + prompt = prompt.set(user=None) rp2 = prompt.render_element(dummy_document.elements[1], dummy_document) assert len(rp2.messages) == 1 assert rp2.messages[0].role == "system"