From 9394a57896087d6ec8df8680190bc5f8648e36dc Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Tue, 4 Feb 2025 14:51:43 -0800 Subject: [PATCH] handle specified prompt and use_elements=True in extract entity Signed-off-by: Henry Lindeman --- .../unit/transforms/test_extract_entity.py | 27 +++++++++++++++++++ .../sycamore/transforms/extract_entity.py | 13 ++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py index 0c06d0bff..db420d6af 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py @@ -23,6 +23,12 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) else: usermessage = prompt.messages[1].content + if usermessage.startswith("Hi"): + return usermessage + + if usermessage.startswith("ho!"): + return "ho there! " + prompt.messages[2].content + if "s3://path" in usermessage: return "alt_title" @@ -55,11 +61,13 @@ class TestEntityExtraction: { "type": "title", "content": {"binary": None, "text": "text1"}, + "text_representation": "text1", "properties": {"coordinates": [(1, 2)], "page_number": 1, "entity": {"author": "Jack Black"}}, }, { "type": "table", "content": {"binary": None, "text": "text2"}, + "text_representation": "text2", "properties": {"page_name": "name", "coordinates": [(1, 2)], "coordinate_system": "pixel"}, }, ], @@ -113,6 +121,25 @@ def test_extract_entity_document_field_string(self, mocker): out_docs = llm_map.run([self.doc]) assert out_docs[0].properties.get("title") == "alt_title" + def test_extract_entity_with_elements_and_string_prompt(self, mocker): + llm = MockLLM() + extractor = OpenAIEntityExtractor("title", llm=llm, use_elements=True, prompt="Hi ") + llm_map = extractor.as_llm_map(None) + outdocs = llm_map.run([self.doc]) + assert outdocs[0].properties.get("title").startswith("Hi") + assert "text1" in outdocs[0].properties.get("title") + assert "text2" in outdocs[0].properties.get("title") + + def test_extract_entity_with_elements_and_messages_prompt(self, mocker): + llm = MockLLM() + prompt_messages = [{"role": "system", "content": "Yo"}, {"role": "user", "content": "ho!"}] + extractor = OpenAIEntityExtractor("title", llm=llm, use_elements=True, prompt=prompt_messages) + llm_map = extractor.as_llm_map(None) + outdocs = llm_map.run([self.doc]) + assert outdocs[0].properties.get("title").startswith("ho there!") + assert "text1" in outdocs[0].properties.get("title") + assert "text2" in outdocs[0].properties.get("title") + def test_extract_entity_with_similarity_sorting(self, mocker): doc_list = [ Document( diff --git a/lib/sycamore/sycamore/transforms/extract_entity.py b/lib/sycamore/sycamore/transforms/extract_entity.py index 260db3c12..3cd1839df 100644 --- a/lib/sycamore/sycamore/transforms/extract_entity.py +++ b/lib/sycamore/sycamore/transforms/extract_entity.py @@ -137,7 +137,18 @@ def as_llm_map( llm = self._llm assert llm is not None, "Could not find an LLM to use" prompt: SycamorePrompt # grr mypy - if self._prompt_template is not None: + if self._prompt is not None: + if isinstance(self._prompt, str): + prompt = ElementListPrompt(user=self._prompt + "\n{elements}") + else: + system = None + if len(self._prompt) > 0 and self._prompt[0]["role"] == "system": + system = self._prompt[0]["content"] + user = [p["content"] for p in self._prompt[1:]] + ["{elements}"] + else: + user = [p["content"] for p in self._prompt] + ["{elements}"] + prompt = ElementListPrompt(system=system, user=user) + elif self._prompt_template is not None: prompt = EntityExtractorFewShotGuidancePrompt prompt = cast(ElementListPrompt, prompt.set(examples=self._prompt_template)) else: