Fix SummarizeData so that downstream .materialize operations will work. #1030
@@ -127,12 +127,18 @@ def _get_text_for_summarize_data(
     # consolidates relevant properties to give to LLM
     if isinstance(result, DocSet):
-        for i, doc in enumerate(result.take(NUM_DOCS_GENERATE, **kwargs)):
+        done = False
+        # For query result caching in the executor, we need to consume the documents
+        # so that the materialized data is complete, even if they are not all included
+        # in the input prompt to the LLM.
+        for di, doc in enumerate(result.take_all()):
             if isinstance(doc, MetadataDocument):
[Review comment] You shouldn't need the MetadataDocument check; take_all removes those.
                 continue
+            if done:
+                continue
             props_dict = doc.properties.get("entity", {})
             props_dict.update({p: doc.properties[p] for p in set(doc.properties) - set(BASE_PROPS)})
-            doc_text = f"Document {i}:\n"
+            doc_text = f"Document {di}:\n"
             for k, v in props_dict.items():
                 doc_text += f"{k}: {v}\n"
@@ -153,9 +159,10 @@ def _get_text_for_summarize_data(
             if total_token_count > max_tokens:
                 log.warn(
                     "Unable to add all text from to the LLM summary request due to token limit."
-                    f" Sending text from {i + 1} docs."
+                    f" Sending text from {di + 1} docs."
                 )
-                break
+                done = True
+                continue
             text += doc_text + "\n"
     else:
         text += str(result_data) + "\n"
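The shape of the fix is worth calling out: the loop now drains the whole DocSet, so a downstream or cached .materialize sees every document, and a done flag replaces break so that accumulation stops at the token budget without abandoning the iterator. A standalone sketch of that drain-but-truncate pattern, using generic names and a whitespace split as a stand-in for a real token counter (illustrative only, not the Sycamore API):

def summarize_input(docs, max_tokens):
    """Consume ALL of `docs`, but stop accumulating prompt text once the
    token budget is reached. Draining the iterable to the end is what lets
    upstream caching/materialization complete."""
    text = ""
    token_count = 0
    done = False
    for i, doc in enumerate(docs):
        if done:
            continue  # keep consuming; just stop retaining
        doc_text = f"Document {i}:\n{doc}\n"
        tokens = len(doc_text.split())  # stand-in for a real tokenizer
        if token_count + tokens > max_tokens:
            done = True  # a `break` here would leave `docs` partially consumed
            continue
        token_count += tokens
        text += doc_text
    return text

With a plain break, a lazy source would stop being pulled at the cutoff; the done flag keeps the pull going while making the remaining iterations nearly free.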
@@ -1,14 +1,16 @@
 import os
 import tempfile
-from unittest.mock import patch
+from unittest.mock import patch, Mock

 import pytest

 import sycamore
+from sycamore.llms import LLM
 from sycamore.query.execution.sycamore_executor import SycamoreExecutor
 from sycamore.query.logical_plan import LogicalPlan
 from sycamore.query.operators.count import Count
 from sycamore.query.operators.query_database import QueryDatabase
+from sycamore.query.operators.summarize_data import SummarizeData
 from sycamore.query.result import SycamoreQueryResult
@@ -102,3 +104,55 @@ def test_run_plan_with_caching(test_count_docs_query_plan, mock_sycamore_docsetreader):
     # No new directories should have been created.
     existing_dirs = [os.path.join(temp_dir, x) for x in os.listdir(temp_dir)]
     assert set(existing_dirs) == set(cache_dirs)
+
+
+def test_run_summarize_data_plan(mock_sycamore_docsetreader):
+    plan = LogicalPlan(
+        query="Test query",
+        result_node=1,
+        nodes={
+            0: QueryDatabase(node_id=0, description="Load data", index="test_index"),
+            1: SummarizeData(node_id=1, description="Summarize data", question="Summarize this data", inputs=[0]),
+        },
+    )
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        with (
+            patch("sycamore.reader.DocSetReader", new=mock_sycamore_docsetreader),
+            patch("sycamore.tests.unit.query.conftest.MOCK_SCAN_NUM_DOCUMENTS", new=1000),
+        ):
+            context = sycamore.init(
+                params={
+                    "default": {"llm": Mock(spec=LLM)},
+                    "opensearch": {
+                        "os_client_args": {
+                            "hosts": [{"host": "localhost", "port": 9200}],
+                            "http_compress": True,
+                            "http_auth": ("admin", "admin"),
+                            "use_ssl": True,
+                            "verify_certs": False,
+                            "ssl_assert_hostname": False,
+                            "ssl_show_warn": False,
+                            "timeout": 120,
+                        }
+                    },
+                }
+            )
+
+            # First run should populate cache.
+            executor = SycamoreExecutor(context, cache_dir=temp_dir)
[Review comment] Is this test fast? If it is, then it's fine to leave in unit tests, but if it's more than 5-10s, I'd like to get it moved into integration tests. I initially thought this would use ray (which basically guarantees it's slow), but I'm no longer sure.
+            result = executor.execute(plan, query_id="test_query_id")
+            assert result.plan == plan
+            assert result.query_id == "test_query_id"
+
+            # Check that a directory was created for each node.
+            cache_dirs = [os.path.join(temp_dir, node.cache_key()) for node in plan.nodes.values()]
+            for cache_dir in cache_dirs:
+                assert os.path.exists(cache_dir)
+
+            # Check that the materialized data is complete.
+            assert os.path.exists(os.path.join(cache_dirs[0], "materialize.success"))
+            assert os.path.exists(os.path.join(cache_dirs[0], "materialize.clean"))
+            # 1000 docs + 2 'materialize' files
+            assert len(os.listdir(cache_dirs[0])) == 1000 + 2
[Review comment] This works at small scale, but will blow up memory at large scale. Approving as this is NTSB only, but I suggest a TODO or something.
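For context on that concern: take_all() realizes the entire result set as an in-memory list before the loop runs, so peak memory grows with the corpus even though only the first few documents contribute prompt text. The difference the suggested TODO would chase, in miniature (plain Python; a streaming accessor on DocSet is an assumption here, not a current API):

import sys

# A realized list holds every element at once; a generator yields one at a time.
docs_list = [f"doc {i}" for i in range(1_000_000)]  # ~8 MB of list pointers alone, plus the strings
docs_gen = (f"doc {i}" for i in range(1_000_000))   # a few hundred bytes, regardless of length

print(sys.getsizeof(docs_list))  # megabytes
print(sys.getsizeof(docs_gen))   # constant, tiny

The drain-but-truncate loop from the fix works unchanged over either shape, so swapping take_all() for a lazy iterator (if DocSet grows one) would bound memory without changing the caching behavior.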