-
Notifications
You must be signed in to change notification settings - Fork 65
[llm unify 1/n] Add consolidated prompt classes #1120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
c2a8cfa
add prompt base classes and ElementListPrompt
HenryL27 21a115a
override .instead in ElementListPrompt to store net-new keys in self.…
HenryL27 f94da80
add ElementPrompt and StaticPrompt
HenryL27 b73c162
add unit tests for prompts
HenryL27 17b2163
forgot to commit this
HenryL27 5d145d5
address pr comments; flatten properties with flatten_data
HenryL27 7fa2ff1
support multiple user prompts
HenryL27 abf9b0b
rename instead to set
HenryL27 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,351 @@ | ||
from dataclasses import dataclass | ||
from typing import Any, Union, Optional, Callable | ||
import copy | ||
|
||
import pydantic | ||
from PIL import Image | ||
from sycamore.data.document import Document, Element | ||
from sycamore.connectors.common import flatten_data | ||
|
||
|
||
@dataclass
class RenderedMessage:
    """A single chat message in the LLM "messages" format.

    Args:
        role: who is speaking, e.g. "system", "user" or "assistant" for OpenAI.
        content: the text body of the message.
        images: optional images to attach to this message; None means no images.
    """

    role: str
    content: str
    images: Union[None, list[Image.Image]] = None
|
||
|
||
@dataclass
class RenderedPrompt:
    """Represents a prompt to be sent to the LLM per the LLM messages interface.

    Args:
        messages: the list of messages to be sent to the LLM
        response_format: optional output schema, specified either as a pydict/json
            schema or as a pydantic model *class* (not an instance). Can only be
            used with modern OpenAI models.
    """

    messages: list[RenderedMessage]
    # A pydantic model class (e.g. ``MyModel``) is what structured-output APIs
    # accept, hence ``type[...]`` rather than an instance annotation.
    response_format: Union[None, dict[str, Any], type[pydantic.BaseModel]] = None
|
||
|
||
class SycamorePrompt:
    """Abstract interface shared by every Sycamore LLM prompt.

    A prompt turns Sycamore data into a ``RenderedPrompt`` that can be sent to an
    LLM. The data is a ``Document``, an ``Element`` plus its parent ``Document``, or
    a list of ``Document``s. Subclasses override whichever ``render_*`` methods
    they support; the base implementations raise ``NotImplementedError``.
    """

    def render_document(self, doc: Document) -> RenderedPrompt:
        """Render this prompt using a single document as context (used by llm_map).

        Args:
            doc: The document supplying values for the prompt.

        Returns:
            A fully rendered prompt ready for LLM inference.

        Raises:
            NotImplementedError: if this prompt type does not support documents.
        """
        message = f"render_document is not implemented for {self.__class__.__name__}"
        raise NotImplementedError(message)

    def render_element(self, elt: Element, doc: Document) -> RenderedPrompt:
        """Render this prompt using an element and its parent document as context
        (used by llm_map_elements).

        Args:
            elt: The element supplying values for the prompt.
            doc: The element's parent document, available for extra context.

        Returns:
            A fully rendered prompt ready for LLM inference.

        Raises:
            NotImplementedError: if this prompt type does not support elements.
        """
        message = f"render_element is not implemented for {self.__class__.__name__}"
        raise NotImplementedError(message)

    def render_multiple_documents(self, docs: list[Document]) -> RenderedPrompt:
        """Render this prompt using several documents as context (used by llm_reduce).

        Args:
            docs: The documents supplying values for the prompt.

        Returns:
            A fully rendered prompt ready for LLM inference.

        Raises:
            NotImplementedError: if this prompt type does not support document lists.
        """
        message = f"render_multiple_documents is not implemented for {self.__class__.__name__}"
        raise NotImplementedError(message)

    def set(self, **kwargs) -> "SycamorePrompt":
        """Return a modified deep copy of this prompt; ``self`` is left untouched.

        Each keyword is assigned on the copy. If the copy has a ``kwargs``
        attribute, keywords that do not already name an instance attribute are
        stored in that ``kwargs`` dict instead of becoming new attributes.

        Args:
            **kwargs: field names and the values to assign on the copy.

        Returns:
            A new SycamorePrompt with the updated fields.

        Example:
            .. code-block:: python

                p = StaticPrompt(system="hello", user="world")
                p.render_document(Document())
                # messages: [
                #   {"role": "system", "content": "hello"},
                #   {"role": "user", "content": "world"}
                # ]
                p2 = p.set(user="bob")
                p2.render_document(Document())
                # messages: [
                #   {"role": "system", "content": "hello"},
                #   {"role": "user", "content": "bob"}
                # ]
        """
        clone = copy.deepcopy(self)
        for name, value in kwargs.items():
            # The kwargs check is re-done per key on purpose: an earlier key may
            # itself have replaced (or created) the ``kwargs`` attribute.
            if hasattr(clone, "kwargs") and name not in clone.__dict__:
                getattr(clone, "kwargs")[name] = value
                continue
            clone.__dict__[name] = value
        return clone
|
||
|
||
def _build_format_str(
    system: Optional[str], user: Union[None, str, list[str]], format_args: dict[str, Any]
) -> list[RenderedMessage]:
    """Format the system/user templates with ``format_args`` into chat messages.

    The system message, if any, comes first. It is followed by one user message
    per user template: ``user`` may be a single string, a list of strings, or
    None. Any other ``user`` value produces no user messages.
    """
    if isinstance(user, list):
        user_templates = user
    elif isinstance(user, str):
        user_templates = [user]
    else:
        user_templates = []

    rendered: list[RenderedMessage] = (
        [] if system is None else [RenderedMessage(role="system", content=system.format(**format_args))]
    )
    rendered += [RenderedMessage(role="user", content=template.format(**format_args)) for template in user_templates]
    return rendered
|
||
|
||
class ElementListPrompt(SycamorePrompt):
    """A prompt with utilities for constructing a list of elements to include
    in the rendered prompt.

    Args:

        system: The system prompt string. Use {} to reference names that should
            be interpolated. Defaults to None
        user: The user prompt string, or a list of user prompt strings (one user
            message each). Use {} to reference names that should be
            interpolated. Defaults to None
        element_select: Function to choose the elements (and their order) to include
            in the prompt. If None, defaults to the first ``num_elements`` elements.
        element_list_constructor: Function to turn a list of elements into a
            string that can be accessed with the interpolation key "{elements}".
            Defaults to "ELEMENT 0: {elts[0].text_representation}\\n
                         ELEMENT 1: {elts[1].text_representation}\\n
                         ..."
        num_elements: Sets the number of elements to take if ``element_select`` is
            unset. Default is 35.
        **kwargs: other keyword arguments are stored and can be used as interpolation keys.

    Example:
        .. code-block:: python

            prompt = ElementListPrompt(
                system="Hello {name}. This is a prompt about {doc_property_path}",
                user="What do you make of these tables?\\nTables:\\n{elements}",
                element_select=lambda elts: list(reversed([e for e in elts if e.type == "table"])),
                name="David Rothschild",
            )
            prompt.render_document(doc)
            # [
            #   {"role": "system", "content": "Hello David Rothschild. This is a prompt about data/mypdf.pdf"},
            #   {"role": "user", "content": "What do you make of these tables?\\nTables:\\n
            #               ELEMENT 0: <last table csv>\\nELEMENT 1: <second-last table csv>..."}
            # ]
    """

    def __init__(
        self,
        *,
        system: Optional[str] = None,
        user: Union[None, str, list[str]] = None,
        element_select: Optional[Callable[[list[Element]], list[Element]]] = None,
        element_list_constructor: Optional[Callable[[list[Element]], str]] = None,
        num_elements: int = 35,
        **kwargs,
    ):
        super().__init__()
        self.system = system
        self.user = user
        self.element_select = element_select or (lambda elts: elts[:num_elements])
        self.element_list_constructor = element_list_constructor or (
            lambda elts: "\n".join(f"ELEMENT {i}: {elt.text_representation}" for i, elt in enumerate(elts))
        )
        self.kwargs = kwargs

    def _render_element_list_to_string(self, doc: Document) -> str:
        """Select elements from ``doc`` and join them into the "{elements}" string."""
        elts = self.element_select(doc.elements)
        return self.element_list_constructor(elts)

    def render_document(self, doc: Document) -> RenderedPrompt:
        """Render this prompt, given this document as context, using python's
        ``str.format()`` method. The keys passed into ``format()`` are as follows.
        Later entries override earlier ones on a name clash:

        - self.kwargs: the additional kwargs specified when creating this prompt.
        - doc_text: doc.text_representation
        - doc_property_<property_name>: each property name in doc.properties is
            prefixed with 'doc_property_'. So if ``doc.properties = {'k1': 0, 'k2': 3}``,
            you get ``doc_property_k1 = 0, doc_property_k2 = 3``.
        - elements: the element list constructed from doc.elements using
            ``self.element_select`` and ``self.element_list_constructor``.

        ``self.kwargs`` is not modified, so rendering one document never affects
        the rendering of another.

        Args:
            doc: The document to use as context for rendering this prompt

        Returns:
            A RenderedPrompt containing ``self.system.format()`` (if a system prompt
            is set) followed by one message per user prompt, using the format keys
            specified above.
        """
        # Copy: the per-document keys below must not be written back into
        # self.kwargs. Otherwise they would leak into later renders and set() copies.
        format_args = dict(self.kwargs)
        format_args["doc_text"] = doc.text_representation
        flat_props = flatten_data(doc.properties, prefix="doc_property", separator="_")
        format_args.update(flat_props)
        format_args["elements"] = self._render_element_list_to_string(doc)

        messages = _build_format_str(self.system, self.user, format_args)
        return RenderedPrompt(messages=messages)
|
||
|
||
class ElementPrompt(SycamorePrompt):
    """A prompt for rendering an element with utilities for capturing information
    from the element's parent document, with a system and user prompt.

    Args:
        system: The system prompt string. Use {} to reference names to be interpolated.
            Defaults to None
        user: The user prompt string, or a list of user prompt strings (one user
            message each). Use {} to reference names to be interpolated.
            Defaults to None
        include_element_image: Whether to include an image of the element in the rendered user
            message. Only works if the parent document is a PDF. Defaults to False (no image)
        capture_parent_context: Function to gather context from the element's parent document.
            Should return {"key": value} dictionary, which will be made available as interpolation
            keys. Defaults to returning {}
        **kwargs: other keyword arguments are stored and can be used as interpolation keys

    Example:
        .. code-block:: python

            prompt = ElementPrompt(
                system="You know everything there is to know about {custom_kwarg}, {name}",
                user="Summarize the information on page {elt_property_page}. \\nTEXT: {elt_text}",
                capture_parent_context=lambda doc, elt: {"custom_kwarg": doc.properties["path"]},
                name="Frank Sinatra",
            )
            prompt.render_element(doc.elements[0], doc)
            # [
            #   {"role": "system", "content": "You know everything there is to know
            #          about /path/to/doc.pdf, Frank Sinatra"},
            #   {"role": "user", "content": "Summarize the information on page 1. \\nTEXT: <element text>"}
            # ]
    """

    def __init__(
        self,
        *,
        system: Optional[str] = None,
        user: Union[None, str, list[str]] = None,
        include_element_image: bool = False,
        capture_parent_context: Optional[Callable[[Document, Element], dict[str, Any]]] = None,
        **kwargs,
    ):
        super().__init__()
        self.system = system
        self.user = user
        self.include_element_image = include_element_image
        self.capture_parent_context = capture_parent_context or (lambda doc, elt: {})
        self.kwargs = kwargs

    def render_element(self, elt: Element, doc: Document) -> RenderedPrompt:
        """Render this prompt for this element; also take the parent document
        if there is context in that to account for as well. Rendering is done
        using pythons ``str.format()`` method. The keys passed into ``format``
        are as follows. Later entries override earlier ones on a name clash:

        - self.kwargs: the additional kwargs specified when creating this prompt.
        - self.capture_parent_context(doc, elt): key-value pairs returned by the
            context-capturing function.
        - elt_text: elt.text_representation (the text representation of the element)
        - elt_property_<property name>: each property name in elt.properties is
            prefixed with 'elt_property_'. So if ``elt.properties = {'k1': 0, 'k2': 3}``,
            you get ``elt_property_k1 = 0, elt_property_k2 = 3``.

        ``self.kwargs`` is not modified, so rendering one element never affects
        the rendering of another.

        Args:
            elt: The element used as context for rendering this prompt.
            doc: The element's parent document; used to add additional context.

        Returns:
            A rendered prompt containing ``self.system.format()`` (if a system prompt
            is set) followed by one message per user prompt, using the format keys
            specified above.
            If self.include_element_image is true, crop out the image from the page
            of the PDF it's on and attach it to the last message (user message if there
            is one, o/w system message).
        """
        # Copy: the per-element keys below must not be written back into
        # self.kwargs. Otherwise they would leak into later renders and set() copies.
        format_args = dict(self.kwargs)
        format_args.update(self.capture_parent_context(doc, elt))
        format_args["elt_text"] = elt.text_representation
        flat_props = flatten_data(elt.properties, prefix="elt_property", separator="_")
        format_args.update(flat_props)

        messages = _build_format_str(self.system, self.user, format_args)
        result = RenderedPrompt(messages=messages)
        if self.include_element_image and result.messages:
            # Imported lazily, so the PDF utilities are only loaded when images are requested.
            from sycamore.utils.pdf_utils import get_element_image

            result.messages[-1].images = [get_element_image(elt, doc)]
        return result
|
||
|
||
class StaticPrompt(SycamorePrompt):
    """A prompt whose rendering ignores any Document or Element it is given.

    Args:

        system: the system prompt template. Use {} for interpolation; values come
            only from kwargs.
        user: the user prompt template, or a list of them. Use {} for
            interpolation; values come only from kwargs.
        **kwargs: the values available for interpolation.

    Example:
        .. code-block:: python

            prompt = StaticPrompt(system="static", user="prompt - {number}", number=7)
            prompt.render_document(Document())
            # [
            #   { "role": "system", "content": "static" },
            #   { "role": "user", "content": "prompt - 7" },
            # ]
    """

    def __init__(self, *, system: Optional[str] = None, user: Union[None, str, list[str]] = None, **kwargs):
        super().__init__()
        self.system, self.user = system, user
        self.kwargs = kwargs

    def render_generic(self) -> RenderedPrompt:
        """Format the templates with the stored kwargs; all render_* methods delegate here."""
        return RenderedPrompt(messages=_build_format_str(self.system, self.user, self.kwargs))

    def render_element(self, elt: Element, doc: Document) -> RenderedPrompt:
        """Ignore ``elt`` and ``doc``; return the static rendering."""
        return self.render_generic()

    def render_document(self, doc: Document) -> RenderedPrompt:
        """Ignore ``doc``; return the static rendering."""
        return self.render_generic()

    def render_multiple_documents(self, docs: list[Document]) -> RenderedPrompt:
        """Ignore ``docs``; return the static rendering."""
        return self.render_generic()
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we're supporting system and user prompts (i.e. the messages API), shouldn't we support a list of prompts, then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. I think we can limit it to a single system prompt, though (I don't remember what the providers do, but that seems sensible).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that makes sense