7 changes: 5 additions & 2 deletions lib/sycamore/sycamore/connectors/common.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from typing import Callable, Iterator, Union, Iterable, Tuple, Any, Dict
-from sycamore.data import Document
+from sycamore.data import Document, Element
 import json
 import string
 import random
@@ -69,7 +69,10 @@ def compare_docs(doc1: Document, doc2: Document):
                 assert math.isclose(num1, num2, rel_tol=1e-5, abs_tol=1e-5)
             except (ValueError, TypeError):
                 # If conversion to float fails, do direct comparison
-                assert item1 == item2
+                if isinstance(item1, Element):
+                    check_dictionary_compatibility(item1.data, item2.data)
+                else:
+                    assert item1 == item2
         elif isinstance(filtered_doc1[key], dict) and isinstance(filtered_doc2.get(key), dict):
             assert check_dictionary_compatibility(filtered_doc1[key], filtered_doc2.get(key))
         else:
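In effect, compare_docs now falls back to a lenient dictionary comparison when list items are Element objects, instead of requiring strict equality. A minimal sketch of that behavior, assuming Element can be built from a plain dict (the payloads here are hypothetical):

from sycamore.connectors.common import check_dictionary_compatibility
from sycamore.data import Element

# Two elements with identical payloads; strict `item1 == item2` may fail on
# otherwise-equivalent objects, so their data dicts are compared instead.
item1 = Element({"type": "NarrativeText", "text_representation": "hello"})
item2 = Element({"type": "NarrativeText", "text_representation": "hello"})

if isinstance(item1, Element):
    check_dictionary_compatibility(item1.data, item2.data)
else:
    assert item1 == item2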
31 changes: 23 additions & 8 deletions lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py
@@ -123,7 +123,9 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
                     }
                 )
                 doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
-                doc.properties["score"] = data["_score"]
+                doc.properties["score"] = (
+                    data["_score"] if doc.properties.get("score") is None else doc.properties["score"]
+                )
                 result.append(doc)
             else:
                 assert (
@@ -326,7 +328,6 @@ def _to_parent_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]:
                     break

                 for hit in hits:
-                    # print(f"Element index: {hit['_source']['properties']['_element_index']}")
                     if (
                         "parent_id" in hit["_source"]
                         and hit["_source"]["parent_id"] is not None
@@ -346,7 +347,6 @@ def _to_parent_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]:
             client.close()

         ret = [doc["_source"] for doc in results]
-        # logging.info(f"Sample: {ret[:5]}")
         return ret

     @timetrace("OpenSearchReader")
@@ -402,14 +402,14 @@ def _to_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]:
         return [{"doc": doc.serialize()} for doc in docs]

     def map_reduce_parent_id(self, group: pd.DataFrame) -> pd.DataFrame:
-        # logger.info(f"Applying on {group} ({type(group)}) ...")
         parent_ids = set()
         for row in group["parent_id"]:
-            # logging.info(f"Row: {row}: {type(row)}")
-            parent_ids.add(row)
+            if row not in parent_ids:
+                parent_ids.add(row)

-        logger.info(f"Parent IDs: {parent_ids}")
-        return pd.DataFrame([{"_source": {"doc_id": parent_id, "parent_id": parent_id}} for parent_id in parent_ids])
+        # logger.info(f"Parent IDs: {parent_ids}")
[Contributor] can you delete
+        return pd.DataFrame([{"_source": {"doc_id": parent_id}} for parent_id in parent_ids])

     def reconstruct(self, doc: dict[str, Any]) -> dict[str, Any]:
         # logging.info(f"Applying on {doc} ({type(doc)}) ...")
@@ -419,8 +419,21 @@ def reconstruct(self, doc: dict[str, Any]) -> dict[str, Any]:
             raise ValueError("Target is not present\n" f"Parameters: {self._query_params}\n")

         os_client = client._client
-        records = OpenSearchReaderQueryResponse([doc], os_client)
+        # doc["_source"]["properties"] = json.loads(doc["_source"]["properties"])
[Contributor] remove
+        doc_id = doc["_source"]["doc_id"]
+        assert isinstance(
+            self._query_params, OpenSearchReaderQueryParams
+        ), f"Wrong kind of query parameters found: {self._query_params}"
+
+        parent_doc = os_client.get(
+            index=self._query_params.index_name, id=doc_id
+        )  # , _source_includes=["properties"])["_source"]["properties"]
[Contributor] remove comments
+        records = OpenSearchReaderQueryResponse([parent_doc], os_client)
         docs = records.to_docs(query_params=self._query_params)

+        # properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DOCUMENT_RECONSTRUCTION_PARENT
[Contributor] remove
+        # docs[0].update(parent_doc["_source"])  # json.loads(doc["_source"]["properties"])
+        # docs[0]["properties"] = json.loads(doc["_source"]["properties"])
         return {"doc": docs[0].serialize()}

     def execute(self, **kwargs) -> "Dataset":
@@ -570,6 +583,8 @@ def _execute_pit(self, **kwargs) -> "Dataset":
                 ds.flat_map(self._to_parent_doc, **self.resource_args)
                 .groupby("parent_id")
                 .map_groups(self.map_reduce_parent_id)
+                # TODO use map_batches to improve 'get_all_elements_for_doc_ids' performance
+                # by making fewer requests to OpenSearch
                 .map(self.reconstruct)
             )
         else:
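The TODO above points at batching the reconstruction step. A hypothetical sketch of that optimization, assuming Ray's map_batches with pandas batches and opensearch-py's mget; reconstruct_batch and the client wiring are illustrative, not part of this PR:

import pandas as pd
from opensearchpy import OpenSearch


def reconstruct_batch(batch: pd.DataFrame, client: OpenSearch, index_name: str) -> pd.DataFrame:
    # Fetch every parent document in the batch in one mget round trip,
    # instead of the one get() per row that .map(self.reconstruct) issues.
    doc_ids = batch["_source"].apply(lambda s: s["doc_id"]).tolist()
    resp = client.mget(index=index_name, body={"ids": doc_ids})
    found = [hit for hit in resp["docs"] if hit.get("found")]
    # Converting each hit into a serialized sycamore Document (done by
    # OpenSearchReaderQueryResponse.to_docs in the real code) is elided here.
    return pd.DataFrame([{"_source": hit["_source"]} for hit in found])

Ray would invoke this via something like .map_batches(reconstruct_batch, batch_format="pandas", fn_kwargs=...) in place of the .map(self.reconstruct) call; how the OpenSearch client reaches the workers is left open here.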
10 changes: 7 additions & 3 deletions lib/sycamore/sycamore/tests/integration/connectors/common.py
@@ -2,10 +2,14 @@
 from sycamore.connectors.common import compare_docs


-def compare_connector_docs(gt_docs: list[Document], returned_docs: list[Document], parent_offset: int = 0):
+def compare_connector_docs(
+    gt_docs: list[Document], returned_docs: list[Document], parent_offset: int = 0, doc_reconstruct: bool = False
+):
     assert len(gt_docs) == (len(returned_docs) + parent_offset)
-    for doc in gt_docs:
-        doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
+
+    if not doc_reconstruct:
+        for doc in gt_docs:
+            doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY

     gt_dict = {doc.doc_id: doc for doc in gt_docs}
     returned_dict = {doc.doc_id: doc for doc in returned_docs}
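In short, doc_reconstruct=True skips stamping the ground truth with DocumentSource.DB_QUERY, since reconstructed documents should match the pre-ingest originals. A hedged usage sketch, with gt_docs and returned_docs standing in for real fixtures:

# Ordinary DB read: ground truth gets SOURCE = DB_QUERY before comparison.
compare_connector_docs(gt_docs, returned_docs)

# Document reconstruction: leave the ground-truth properties untouched.
compare_connector_docs(gt_docs, returned_docs, doc_reconstruct=True)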
@@ -10,12 +10,17 @@
 from sycamore import EXEC_LOCAL, ExecMode
 from sycamore.connectors.doc_reconstruct import DocumentReconstructor
 from sycamore.data import Document
+from sycamore.data.document import DocumentPropertyTypes
+from sycamore.llms import OpenAI, OpenAIModels
+from sycamore.tests.integration.connectors.common import compare_connector_docs
 from sycamore.tests.config import TEST_DIR
-from sycamore.transforms.partition import UnstructuredPdfPartitioner
+from sycamore.transforms.partition import ArynPartitioner

+from sycamore.transforms.extract_entity import OpenAIEntityExtractor
+
 OS_ADMIN_PASSWORD = os.getenv("OS_ADMIN_PASSWORD", "admin")
 TEST_CACHE_DIR = "/tmp/test_cache_dir"
+ARYN_API_KEY = os.getenv("ARYN_API_KEY")


 @pytest.fixture(scope="class")
@@ -49,7 +54,7 @@ def setup_index_large(os_client):

     (
         context.read.binary(path, binary_format="pdf")
-        .partition(partitioner=UnstructuredPdfPartitioner())
+        .partition(ArynPartitioner(aryn_api_key=ARYN_API_KEY))
         .explode()
         .write.opensearch(
             os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
@@ -73,6 +78,19 @@ def get_doc_count(os_client, index_name: str, query: Optional[Dict[str, Any]] =
     return res["count"]


+"""
+class MockLLM(LLM):
+    def __init__(self):
+        super().__init__(model_name="mock_model")
+
+    def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
+        return str(uuid.uuid4())
+
+    def is_chat_mode(self):
+        return True
+"""
+
+
 class TestOpenSearchRead:
     INDEX_SETTINGS = {
         "body": {
@@ -112,7 +130,7 @@ def test_ingest_and_read(self, setup_index, os_client, exec_mode):
         context = sycamore.init(exec_mode=exec_mode)
         original_docs = (
             context.read.binary(path, binary_format="pdf")
-            .partition(partitioner=UnstructuredPdfPartitioner())
+            .partition(ArynPartitioner(aryn_api_key=ARYN_API_KEY))
             .explode()
             .write.opensearch(
                 os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
@@ -162,6 +180,74 @@ def test_ingest_and_read(self, setup_index, os_client, exec_mode):

         os_client.indices.delete(setup_index, ignore_unavailable=True)

+    def test_doc_reconstruct(self, setup_index, os_client):
+        with tempfile.TemporaryDirectory() as materialized_dir:
+            self._test_doc_reconstruct(setup_index, os_client, materialized_dir)
+
+    def _test_doc_reconstruct(self, setup_index, os_client, materialized_dir):
+        """
+        Validates data is readable from OpenSearch, and that we can rebuild processed Sycamore documents.
+        """
+
+        print(f"Using materialized dir: {materialized_dir}")
+
+        def doc_reconstructor(doc_id: str) -> Document:
+            import pickle
+
+            data = pickle.load(open(f"{materialized_dir}/{setup_index}-{doc_id}", "rb"))
+            return Document(**data)
+
+        def doc_to_name(doc: Document, bin: bytes) -> str:
+            return f"{setup_index}-{doc.doc_id}"
+
+        path = str(TEST_DIR / "resources/data/pdfs/Ray.pdf")
+        context = sycamore.init(exec_mode=ExecMode.RAY)
+        llm = OpenAI(OpenAIModels.GPT_4O_MINI)
+        extractor = OpenAIEntityExtractor("title", llm=llm)
+        original_docs = (
+            context.read.binary(path, binary_format="pdf")
+            .partition(ArynPartitioner(aryn_api_key=ARYN_API_KEY))
+            .extract_entity(extractor)
+            .materialize(path={"root": materialized_dir, "name": doc_to_name})
+            .explode()
+            .write.opensearch(
+                os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
+                index_name=setup_index,
+                index_settings=TestOpenSearchRead.INDEX_SETTINGS,
+                execute=False,
+            )
+            .take_all()
+        )
+
+        os_client.indices.refresh(setup_index)
+
+        expected_count = len(original_docs)
+        actual_count = get_doc_count(os_client, setup_index)
+        # refresh should have made all ingested docs immediately available for search
+        assert actual_count == expected_count, f"Expected {expected_count} documents, found {actual_count}"
+
+        retrieved_docs_reconstructed = context.read.opensearch(
+            os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
+            index_name=setup_index,
+            reconstruct_document=True,
+        ).take_all()
+
+        assert len(retrieved_docs_reconstructed) == 1
+        retrieved_sorted = sorted(retrieved_docs_reconstructed, key=lambda d: d.doc_id)
+
+        def remove_reconstruct_doc_property(doc: Document):
+            for element in doc.data["elements"]:
+                element["properties"].pop(DocumentPropertyTypes.SOURCE)
+
+        for doc in retrieved_sorted:
+            remove_reconstruct_doc_property(doc)
+
+        from_materialized = [doc_reconstructor(doc.doc_id) for doc in retrieved_sorted]
+        compare_connector_docs(from_materialized, retrieved_sorted)
+
+        # Clean up
+        os_client.indices.delete(setup_index, ignore_unavailable=True)
+
     def _test_ingest_and_read_via_docid_reconstructor(self, setup_index, os_client, cache_dir):
         """
         Validates data is readable from OpenSearch, and that we can rebuild processed Sycamore documents.
@@ -559,7 +645,7 @@ def _test_bulk_load(self, setup_index_large, os_client):
         while doc_count < 20000:
             (
                 context.read.binary(path, binary_format="pdf")
-                .partition(partitioner=UnstructuredPdfPartitioner())
+                .partition(ArynPartitioner(aryn_api_key=ARYN_API_KEY))
                 # .materialize(path={"root": TEST_CACHE_DIR, "name": self.doc_to_name})
                 .explode()
                 .write.opensearch(
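The commented-out MockLLM earlier in this file suggests a deterministic stand-in for OpenAI in these tests. A self-contained sketch of that idea; the LLM base class usage and the RenderedPrompt import path are assumptions based on the commented block, not verified against the library:

import uuid
from typing import Optional

from sycamore.llms import LLM
from sycamore.llms.prompts import RenderedPrompt  # assumed import path


class MockLLM(LLM):
    # Deterministic stub: no network calls, no API key required.
    def __init__(self):
        super().__init__(model_name="mock_model")

    def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
        # A fresh UUID per call is enough to exercise entity-extraction plumbing.
        return str(uuid.uuid4())

    def is_chat_mode(self):
        return True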