From 04c91669a672aed9d5023770c5aac23bae76a535 Mon Sep 17 00:00:00 2001 From: Dhruv Kaliraman Date: Sun, 26 Jan 2025 00:47:51 -0800 Subject: [PATCH 1/6] Initial dev --- .../sycamore/connectors/base_writer.py | 4 ++ .../opensearch/opensearch_writer.py | 24 +++++++++++- lib/sycamore/sycamore/writer.py | 37 +++++++++++++++++-- 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/lib/sycamore/sycamore/connectors/base_writer.py b/lib/sycamore/sycamore/connectors/base_writer.py index ba205fdf1..54d751ef4 100644 --- a/lib/sycamore/sycamore/connectors/base_writer.py +++ b/lib/sycamore/sycamore/connectors/base_writer.py @@ -33,6 +33,10 @@ def create_target_idempotent(self, target_params: "BaseDBWriter.TargetParams"): def get_existing_target_params(self, target_params: "BaseDBWriter.TargetParams") -> "BaseDBWriter.TargetParams": pass + @abstractmethod + def reliability_assertor(self, target_params: "BaseDBWriter.TargetParams"): + pass + def close(self): pass diff --git a/lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py b/lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py index 5ce35b04b..3a7e93e2c 100644 --- a/lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py +++ b/lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py @@ -42,6 +42,7 @@ class OpenSearchWriterClientParams(BaseDBWriter.ClientParams): @dataclass class OpenSearchWriterTargetParams(BaseDBWriter.TargetParams): index_name: str + reliability_rewriter: dict[str, Any] = field(default_factory=lambda: {"enabled": False, "num_chunks": 0}) settings: dict[str, Any] = field(default_factory=lambda: {"index.knn": True}) mappings: dict[str, Any] = field( default_factory=lambda: { @@ -196,7 +197,28 @@ def _string_values_to_python_types(obj: Any): assert isinstance(mappings, dict) settings = _string_values_to_python_types(response.get(index_name, {}).get("settings", {})) assert isinstance(settings, dict) - return OpenSearchWriterTargetParams(index_name=index_name, 
mappings=mappings, settings=settings) + reliability_rewriter = target_params.reliability_rewriter + assert isinstance(reliability_rewriter, dict) + return OpenSearchWriterTargetParams( + index_name=index_name, mappings=mappings, settings=settings, reliability_rewriter=reliability_rewriter + ) + + def reliability_assertor(self, target_params: BaseDBWriter.TargetParams): + assert isinstance( + target_params, OpenSearchWriterTargetParams + ), f"Provided target_params was not of type OpenSearchWriterTargetParams:\n{target_params}" + if not target_params.reliability_rewriter["enabled"]: + return + log.info("Flushing index...") + self._client.indices.flush(index=target_params.index_name) + log.info("Done flushing index.") + indices = self._client.cat.indices(index=target_params.index_name, format="json") + assert len(indices) == 1, f"Expected 1 index, found {len(indices)}" + num_docs = int(indices[0]["docs.count"]) + assert ( + num_docs == target_params.reliability_rewriter["num_chunks"] + ), f"Expected {target_params.reliability_rewriter['num_chunks']} docs, found {num_docs}" + log.info(f"{num_docs} chunks written in index {target_params.index_name}") @dataclass diff --git a/lib/sycamore/sycamore/writer.py b/lib/sycamore/sycamore/writer.py index 85b9d6c7c..b1aac664b 100644 --- a/lib/sycamore/sycamore/writer.py +++ b/lib/sycamore/sycamore/writer.py @@ -20,6 +20,7 @@ from neo4j import Auth from neo4j.auth_management import AuthManager +from sycamore.connectors.opensearch import OpenSearchWriterClient, OpenSearchWriter logger = logging.getLogger(__name__) @@ -45,6 +46,7 @@ def opensearch( index_settings: dict, insert_settings: Optional[dict] = None, execute: bool = True, + reliability_rewriter: bool = False, **kwargs, ) -> Optional["DocSet"]: """Writes the content of the DocSet into the specified OpenSearch index. 
@@ -101,7 +103,6 @@ def opensearch( """ from sycamore.connectors.opensearch import ( - OpenSearchWriter, OpenSearchWriterClientParams, OpenSearchWriterTargetParams, ) @@ -136,7 +137,22 @@ def _convert_to_host_port_list(hostlist: Any) -> list[HostAndPort]: client_params = OpenSearchWriterClientParams(**os_client_args) target_params: OpenSearchWriterTargetParams - target_params_dict: dict[str, Any] = {"index_name": index_name} + target_params_dict: dict[str, Any] = { + "index_name": index_name, + "reliability_rewriter": {"num_chunks": 0, "enabled": reliability_rewriter}, + } + if reliability_rewriter: + from sycamore.materialize import Materialize + + assert ( + type(self.plan) == Materialize + ), "The first node must be a materialize node for reliability rewriter to work" + logger.info(f"Reliability rewriter enabled {self.plan.children}") + assert not self.plan.children[ + 0 + ], "Pipeline should only have read materialize and write nodes for reliability rewriter to work" + target_params_dict["reliability_rewriter"]["num_chunks"] = DocSet(self.context, self.plan).count() + if insert_settings: target_params_dict["insert_settings"] = insert_settings if index_settings: @@ -146,12 +162,18 @@ def _convert_to_host_port_list(hostlist: Any) -> list[HostAndPort]: os = OpenSearchWriter( self.plan, client_params=client_params, target_params=target_params, name="OsrchWrite", **kwargs ) + client = None + if reliability_rewriter: + client = os.Client.from_client_params(client_params) + if client._client.indices.exists(index=index_name): + logger.info(f"\nWARNING: Deleting existing index {index_name}\n") + client._client.indices.delete(index=index_name) # We will probably want to break this at some point so that write # doesn't execute automatically, and instead you need to say something # like docset.write.opensearch().execute(), allowing sensible writes # to multiple locations and post-write operations. 
- return self._maybe_execute(os, execute) + return self._maybe_execute(os, execute, client) @requires_modules(["weaviate", "weaviate.collections.classes.config"], extra="weaviate") def weaviate( @@ -800,10 +822,17 @@ def json( self._maybe_execute(node, True) - def _maybe_execute(self, node: Node, execute: bool) -> Optional[DocSet]: + def _maybe_execute( + self, node: Node, execute: bool, client: Optional[OpenSearchWriterClient] = None + ) -> Optional[DocSet]: ds = DocSet(self.context, node) if not execute: return ds ds.execute() + if client is not None: + assert ( + type(node) == OpenSearchWriter + ), "The first node must be an opensearch writer node for reliability rewriter to work" + client.reliability_assertor(node._target_params) return None From fe27600984d567ec47823415486907f4ec3043b2 Mon Sep 17 00:00:00 2001 From: Dhruv Kaliraman Date: Mon, 27 Jan 2025 14:21:49 -0800 Subject: [PATCH 2/6] Add integration test, simplify API --- .../sycamore/connectors/base_writer.py | 4 -- .../opensearch/opensearch_writer.py | 18 ++++---- .../opensearch/test_opensearch_read.py | 41 +++++++++++++++++++ lib/sycamore/sycamore/writer.py | 16 ++++---- 4 files changed, 59 insertions(+), 20 deletions(-) diff --git a/lib/sycamore/sycamore/connectors/base_writer.py b/lib/sycamore/sycamore/connectors/base_writer.py index 54d751ef4..ba205fdf1 100644 --- a/lib/sycamore/sycamore/connectors/base_writer.py +++ b/lib/sycamore/sycamore/connectors/base_writer.py @@ -33,10 +33,6 @@ def create_target_idempotent(self, target_params: "BaseDBWriter.TargetParams"): def get_existing_target_params(self, target_params: "BaseDBWriter.TargetParams") -> "BaseDBWriter.TargetParams": pass - @abstractmethod - def reliability_assertor(self, target_params: "BaseDBWriter.TargetParams"): - pass - def close(self): pass diff --git a/lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py b/lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py index 3a7e93e2c..d2f634da5 100644 --- 
a/lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py +++ b/lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py @@ -42,7 +42,7 @@ class OpenSearchWriterClientParams(BaseDBWriter.ClientParams): @dataclass class OpenSearchWriterTargetParams(BaseDBWriter.TargetParams): index_name: str - reliability_rewriter: dict[str, Any] = field(default_factory=lambda: {"enabled": False, "num_chunks": 0}) + reliability_rewriter_chunks_count: int = 0 settings: dict[str, Any] = field(default_factory=lambda: {"index.knn": True}) mappings: dict[str, Any] = field( default_factory=lambda: { @@ -197,18 +197,20 @@ def _string_values_to_python_types(obj: Any): assert isinstance(mappings, dict) settings = _string_values_to_python_types(response.get(index_name, {}).get("settings", {})) assert isinstance(settings, dict) - reliability_rewriter = target_params.reliability_rewriter - assert isinstance(reliability_rewriter, dict) + reliability_rewriter_chunks_count = target_params.reliability_rewriter_chunks_count + assert isinstance(reliability_rewriter_chunks_count, int) return OpenSearchWriterTargetParams( - index_name=index_name, mappings=mappings, settings=settings, reliability_rewriter=reliability_rewriter + index_name=index_name, + mappings=mappings, + settings=settings, + reliability_rewriter_chunks_count=reliability_rewriter_chunks_count, ) + # TODO: Implement this as an abstract method in the base class and remove the NotImplementedError in writer.py def reliability_assertor(self, target_params: BaseDBWriter.TargetParams): assert isinstance( target_params, OpenSearchWriterTargetParams ), f"Provided target_params was not of type OpenSearchWriterTargetParams:\n{target_params}" - if not target_params.reliability_rewriter["enabled"]: - return log.info("Flushing index...") self._client.indices.flush(index=target_params.index_name) log.info("Done flushing index.") @@ -216,8 +218,8 @@ def reliability_assertor(self, target_params: BaseDBWriter.TargetParams): assert 
len(indices) == 1, f"Expected 1 index, found {len(indices)}" num_docs = int(indices[0]["docs.count"]) assert ( - num_docs == target_params.reliability_rewriter["num_chunks"] - ), f"Expected {target_params.reliability_rewriter['num_chunks']} docs, found {num_docs}" + num_docs == target_params.reliability_rewriter_chunks_count + ), f"Expected {target_params.reliability_rewriter_chunks_count} docs, found {num_docs}" log.info(f"{num_docs} chunks written in index {target_params.index_name}") diff --git a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py index fc20585c3..5048a5d3c 100644 --- a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py +++ b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py @@ -161,6 +161,47 @@ def test_ingest_and_read(self, setup_index, os_client, exec_mode): os_client.indices.delete(setup_index, ignore_unavailable=True) + def test_write_with_reliability(self, setup_index, os_client, exec_mode): + """ + Validates data is readable from OpenSearch, and that we can rebuild processed Sycamore documents. 
+ """ + with tempfile.TemporaryDirectory() as tmpdir1: + path = str(TEST_DIR / "resources/data/pdfs/Ray.pdf") + context = sycamore.init(exec_mode=exec_mode) + + # 2 docs for ray execution + ( + context.read.binary([path, path], binary_format="pdf") + .partition(partitioner=UnstructuredPdfPartitioner()) + .explode() + .materialize(path=tmpdir1) + .execute() + ) + + ( + context.read.materialize(tmpdir1).write.opensearch( + os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, + index_name=setup_index, + index_settings=TestOpenSearchRead.INDEX_SETTINGS, + reliability_rewriter=True, + ) + ) + count = get_doc_count(os_client, setup_index) + + # Delete and recreate the index + ( + context.read.materialize(tmpdir1).write.opensearch( + os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, + index_name=setup_index, + index_settings=TestOpenSearchRead.INDEX_SETTINGS, + reliability_rewriter=True, + ) + ) + re_count = get_doc_count(os_client, setup_index) + + assert count == re_count, f"Expected {count} documents, found {re_count}" + os_client.indices.delete(setup_index, ignore_unavailable=True) + def _test_ingest_and_read_via_docid_reconstructor(self, setup_index, os_client, cache_dir): """ Validates data is readable from OpenSearch, and that we can rebuild processed Sycamore documents. 
diff --git a/lib/sycamore/sycamore/writer.py b/lib/sycamore/sycamore/writer.py index b1aac664b..e58275e45 100644 --- a/lib/sycamore/sycamore/writer.py +++ b/lib/sycamore/sycamore/writer.py @@ -139,19 +139,19 @@ def _convert_to_host_port_list(hostlist: Any) -> list[HostAndPort]: target_params: OpenSearchWriterTargetParams target_params_dict: dict[str, Any] = { "index_name": index_name, - "reliability_rewriter": {"num_chunks": 0, "enabled": reliability_rewriter}, + "reliability_rewriter_chunks_count": 0, } if reliability_rewriter: from sycamore.materialize import Materialize + assert execute, "Reliability rewriter requires execute to be True" assert ( type(self.plan) == Materialize ), "The first node must be a materialize node for reliability rewriter to work" - logger.info(f"Reliability rewriter enabled {self.plan.children}") assert not self.plan.children[ 0 ], "Pipeline should only have read materialize and write nodes for reliability rewriter to work" - target_params_dict["reliability_rewriter"]["num_chunks"] = DocSet(self.context, self.plan).count() + target_params_dict["reliability_rewriter_chunks_count"] = DocSet(self.context, self.plan).count() if insert_settings: target_params_dict["insert_settings"] = insert_settings @@ -166,7 +166,7 @@ def _convert_to_host_port_list(hostlist: Any) -> list[HostAndPort]: if reliability_rewriter: client = os.Client.from_client_params(client_params) if client._client.indices.exists(index=index_name): - logger.info(f"\nWARNING: Deleting existing index {index_name}\n") + logger.info(f"\n\nWARNING WARNING WARNING: Deleting existing index {index_name}\n\n") client._client.indices.delete(index=index_name) # We will probably want to break this at some point so that write @@ -831,8 +831,8 @@ def _maybe_execute( ds.execute() if client is not None: - assert ( - type(node) == OpenSearchWriter - ), "The first node must be an opensearch writer node for reliability rewriter to work" - client.reliability_assertor(node._target_params) + if 
type(node) == OpenSearchWriter: + client.reliability_assertor(node._target_params) + else: + raise NotImplementedError return None From 382ba824142242390ca6f0d2f9246b1cee507524 Mon Sep 17 00:00:00 2001 From: Dhruv Kaliraman Date: Wed, 5 Feb 2025 02:00:23 -0800 Subject: [PATCH 3/6] Address comments --- .../sycamore/connectors/base_writer.py | 12 +++ .../opensearch/opensearch_writer.py | 75 ++++++++++++++++--- .../opensearch/test_opensearch_read.py | 2 +- lib/sycamore/sycamore/writer.py | 38 +++------- 4 files changed, 90 insertions(+), 37 deletions(-) diff --git a/lib/sycamore/sycamore/connectors/base_writer.py b/lib/sycamore/sycamore/connectors/base_writer.py index ba205fdf1..641e82711 100644 --- a/lib/sycamore/sycamore/connectors/base_writer.py +++ b/lib/sycamore/sycamore/connectors/base_writer.py @@ -33,6 +33,18 @@ def create_target_idempotent(self, target_params: "BaseDBWriter.TargetParams"): def get_existing_target_params(self, target_params: "BaseDBWriter.TargetParams") -> "BaseDBWriter.TargetParams": pass + def reliability_assertor(self, target_params: "BaseDBWriter.TargetParams"): + """ + Method to verify that all documents were successfully written when reliability mode is enabled. 
+ + Args: + target_params: Parameters describing the target being written to + + Raises: + NotImplementedError: If the implementing class doesn't support reliability checks + """ + raise NotImplementedError("This writer does not support reliability checks") + def close(self): pass diff --git a/lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py b/lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py index d2f634da5..b260d1896 100644 --- a/lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py +++ b/lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py @@ -16,6 +16,9 @@ DEFAULT_RECORD_PROPERTIES, ) from sycamore.utils.import_utils import requires_modules +from sycamore.plan_nodes import Node +from sycamore.docset import DocSet +from sycamore.context import Context if typing.TYPE_CHECKING: from opensearchpy import OpenSearch @@ -42,7 +45,7 @@ class OpenSearchWriterClientParams(BaseDBWriter.ClientParams): @dataclass class OpenSearchWriterTargetParams(BaseDBWriter.TargetParams): index_name: str - reliability_rewriter_chunks_count: int = 0 + _doc_count: int = 0 settings: dict[str, Any] = field(default_factory=lambda: {"index.knn": True}) mappings: dict[str, Any] = field( default_factory=lambda: { @@ -93,6 +96,61 @@ def compatible_with(self, other: "BaseDBWriter.TargetParams") -> bool: other_flat_mappings = dict(flatten_data(other.mappings)) return check_dictionary_compatibility(my_flat_mappings, other_flat_mappings) + @classmethod + def from_write_args( + cls, + index_name: str, + plan: Node, + context: Context, + reliability_rewriter: bool, + execute: bool, + insert_settings: Optional[dict] = None, + index_settings: Optional[dict] = None, + ) -> "OpenSearchWriterTargetParams": + """ + Build OpenSearchWriterTargetParams from write operation arguments. 
+ + Args: + index_name: Name of the OpenSearch index + plan: The execution plan Node + context: The execution Context + reliability_rewriter: Whether to enable reliability rewriter mode + execute: Whether to execute the pipeline immediately + insert_settings: Optional settings for data insertion + index_settings: Optional index configuration settings + + Returns: + OpenSearchWriterTargetParams configured with the provided settings + + Raises: + AssertionError: If reliability_rewriter conditions are not met + """ + target_params_dict: dict[str, Any] = { + "index_name": index_name, + "_doc_count": 0, + } + + if reliability_rewriter: + from sycamore.materialize import Materialize + + assert execute, "Reliability rewriter requires execute to be True" + assert ( + type(plan) == Materialize + ), "The first node must be a materialize node for reliability rewriter to work" + assert not plan.children[ + 0 + ], "Pipeline should only have read materialize and write nodes for reliability rewriter to work" + target_params_dict["_doc_count"] = DocSet(context, plan).count() + + if insert_settings: + target_params_dict["insert_settings"] = insert_settings + + if index_settings: + target_params_dict["settings"] = index_settings.get("body", {}).get("settings", {}) + target_params_dict["mappings"] = index_settings.get("body", {}).get("mappings", {}) + + return cls(**target_params_dict) + class OpenSearchWriterClient(BaseDBWriter.Client): def __init__(self, os_client: "OpenSearch"): @@ -188,6 +246,8 @@ def _string_values_to_python_types(obj: Any): return obj return obj + # TODO: Convert OpenSearchWriterTargetParams to pydantic model + assert isinstance( target_params, OpenSearchWriterTargetParams ), f"Provided target_params was not of type OpenSearchWriterTargetParams:\n{target_params}" @@ -197,30 +257,27 @@ def _string_values_to_python_types(obj: Any): assert isinstance(mappings, dict) settings = _string_values_to_python_types(response.get(index_name, {}).get("settings", {})) assert 
isinstance(settings, dict) - reliability_rewriter_chunks_count = target_params.reliability_rewriter_chunks_count - assert isinstance(reliability_rewriter_chunks_count, int) + _doc_count = target_params._doc_count + assert isinstance(_doc_count, int) return OpenSearchWriterTargetParams( index_name=index_name, mappings=mappings, settings=settings, - reliability_rewriter_chunks_count=reliability_rewriter_chunks_count, + _doc_count=_doc_count, ) - # TODO: Implement this as an abstract method in the base class and remove the NotImplementedError in writer.py def reliability_assertor(self, target_params: BaseDBWriter.TargetParams): assert isinstance( target_params, OpenSearchWriterTargetParams ), f"Provided target_params was not of type OpenSearchWriterTargetParams:\n{target_params}" log.info("Flushing index...") - self._client.indices.flush(index=target_params.index_name) + self._client.indices.flush(index=target_params.index_name, params={"timeout": 300}) log.info("Done flushing index.") indices = self._client.cat.indices(index=target_params.index_name, format="json") assert len(indices) == 1, f"Expected 1 index, found {len(indices)}" num_docs = int(indices[0]["docs.count"]) - assert ( - num_docs == target_params.reliability_rewriter_chunks_count - ), f"Expected {target_params.reliability_rewriter_chunks_count} docs, found {num_docs}" log.info(f"{num_docs} chunks written in index {target_params.index_name}") + assert num_docs == target_params._doc_count, f"Expected {target_params._doc_count} docs, found {num_docs}" @dataclass diff --git a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py index 5048a5d3c..47d82716c 100644 --- a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py +++ b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py @@ -200,7 +200,7 @@ def 
test_write_with_reliability(self, setup_index, os_client, exec_mode): re_count = get_doc_count(os_client, setup_index) assert count == re_count, f"Expected {count} documents, found {re_count}" - os_client.indices.delete(setup_index, ignore_unavailable=True) + os_client.indices.delete(setup_index) def _test_ingest_and_read_via_docid_reconstructor(self, setup_index, os_client, cache_dir): """ diff --git a/lib/sycamore/sycamore/writer.py b/lib/sycamore/sycamore/writer.py index e58275e45..4a6e79aac 100644 --- a/lib/sycamore/sycamore/writer.py +++ b/lib/sycamore/sycamore/writer.py @@ -136,29 +136,15 @@ def _convert_to_host_port_list(hostlist: Any) -> list[HostAndPort]: os_client_args["hosts"] = _convert_to_host_port_list(hosts) client_params = OpenSearchWriterClientParams(**os_client_args) - target_params: OpenSearchWriterTargetParams - target_params_dict: dict[str, Any] = { - "index_name": index_name, - "reliability_rewriter_chunks_count": 0, - } - if reliability_rewriter: - from sycamore.materialize import Materialize - - assert execute, "Reliability rewriter requires execute to be True" - assert ( - type(self.plan) == Materialize - ), "The first node must be a materialize node for reliability rewriter to work" - assert not self.plan.children[ - 0 - ], "Pipeline should only have read materialize and write nodes for reliability rewriter to work" - target_params_dict["reliability_rewriter_chunks_count"] = DocSet(self.context, self.plan).count() - - if insert_settings: - target_params_dict["insert_settings"] = insert_settings - if index_settings: - target_params_dict["settings"] = index_settings.get("body", {}).get("settings", {}) - target_params_dict["mappings"] = index_settings.get("body", {}).get("mappings", {}) - target_params = OpenSearchWriterTargetParams(**target_params_dict) + target_params = OpenSearchWriterTargetParams.from_write_args( + index_name=index_name, + plan=self.plan, + context=self.context, + reliability_rewriter=reliability_rewriter, + 
execute=execute, + insert_settings=insert_settings, + index_settings=index_settings, + ) os = OpenSearchWriter( self.plan, client_params=client_params, target_params=target_params, name="OsrchWrite", **kwargs ) @@ -831,8 +817,6 @@ def _maybe_execute( ds.execute() if client is not None: - if type(node) == OpenSearchWriter: - client.reliability_assertor(node._target_params) - else: - raise NotImplementedError + assert type(node) == OpenSearchWriter + client.reliability_assertor(node._target_params) return None From 0b50bb714587d53a37c871ed0cf004496ce6c205 Mon Sep 17 00:00:00 2001 From: Dhruv Kaliraman Date: Wed, 5 Feb 2025 12:25:53 -0800 Subject: [PATCH 4/6] Fix test bug, address comment --- lib/sycamore/sycamore/connectors/base_writer.py | 3 --- .../connectors/opensearch/test_opensearch_read.py | 15 ++++++++++++--- lib/sycamore/sycamore/writer.py | 5 +++-- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/lib/sycamore/sycamore/connectors/base_writer.py b/lib/sycamore/sycamore/connectors/base_writer.py index 641e82711..56c026000 100644 --- a/lib/sycamore/sycamore/connectors/base_writer.py +++ b/lib/sycamore/sycamore/connectors/base_writer.py @@ -37,9 +37,6 @@ def reliability_assertor(self, target_params: "BaseDBWriter.TargetParams"): """ Method to verify that all documents were successfully written when reliability mode is enabled. 
- Args: - target_params: Parameters describing the target being written to - Raises: NotImplementedError: If the implementing class doesn't support reliability checks """ diff --git a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py index 4372f0ce0..15845ba73 100644 --- a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py +++ b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py @@ -250,7 +250,8 @@ def remove_reconstruct_doc_property(doc: Document): def test_write_with_reliability(self, setup_index, os_client, exec_mode): """ - Validates data is readable from OpenSearch, and that we can rebuild processed Sycamore documents. + Validates that when materialized pickle outputs are deleted, the index is rewritten + with the correct (reduced) number of chunks. """ with tempfile.TemporaryDirectory() as tmpdir1: path = str(TEST_DIR / "resources/data/pdfs/Ray.pdf") @@ -273,9 +274,15 @@ def test_write_with_reliability(self, setup_index, os_client, exec_mode): reliability_rewriter=True, ) ) + os_client.indices.refresh(setup_index) count = get_doc_count(os_client, setup_index) - # Delete and recreate the index + # Delete 1 pickle file to make sure reliability rewriter works + pickle_files = [f for f in os.listdir(tmpdir1) if f.endswith(".pickle")] + assert pickle_files, "No pickle files found in materialized directory" + os.remove(os.path.join(tmpdir1, pickle_files[0])) + + # Delete and recreate the index - should have fewer chunks ( context.read.materialize(tmpdir1).write.opensearch( os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, @@ -284,9 +291,11 @@ def test_write_with_reliability(self, setup_index, os_client, exec_mode): reliability_rewriter=True, ) ) + os_client.indices.refresh(setup_index) re_count = get_doc_count(os_client, setup_index) - assert count == re_count, f"Expected 
{count} documents, found {re_count}" + # Verify document count is reduced + assert count - 1 == re_count, f"Expected {count - 1} documents, found {re_count}" os_client.indices.delete(setup_index) def _test_ingest_and_read_via_docid_reconstructor(self, setup_index, os_client, cache_dir): diff --git a/lib/sycamore/sycamore/writer.py b/lib/sycamore/sycamore/writer.py index 8626f03e8..b099e9114 100644 --- a/lib/sycamore/sycamore/writer.py +++ b/lib/sycamore/sycamore/writer.py @@ -865,7 +865,8 @@ def _maybe_execute( return ds ds.execute() + if client is not None: - assert type(node) == OpenSearchWriter - client.reliability_assertor(node._target_params) + if type(node) == OpenSearchWriter: + client.reliability_assertor(node._target_params) return None From 139bcbd22936d2e3e2adb29c7e27591dc4c4d351 Mon Sep 17 00:00:00 2001 From: Dhruv Kaliraman Date: Wed, 5 Feb 2025 12:38:48 -0800 Subject: [PATCH 5/6] Change node check to BaseDBWriter --- lib/sycamore/sycamore/writer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/sycamore/sycamore/writer.py b/lib/sycamore/sycamore/writer.py index b099e9114..9f87a0ff9 100644 --- a/lib/sycamore/sycamore/writer.py +++ b/lib/sycamore/sycamore/writer.py @@ -22,7 +22,8 @@ from neo4j import Auth from neo4j.auth_management import AuthManager -from sycamore.connectors.opensearch import OpenSearchWriterClient, OpenSearchWriter +from sycamore.connectors.opensearch import OpenSearchWriterClient +from sycamore.connectors.base_writer import BaseDBWriter logger = logging.getLogger(__name__) @@ -107,6 +108,7 @@ def opensearch( from sycamore.connectors.opensearch import ( OpenSearchWriterClientParams, OpenSearchWriterTargetParams, + OpenSearchWriter, ) from typing import Any import copy @@ -867,6 +869,6 @@ def _maybe_execute( ds.execute() if client is not None: - if type(node) == OpenSearchWriter: + if isinstance(node, BaseDBWriter): client.reliability_assertor(node._target_params) return None From
7578d57c996c4f84f9842547e25f1b26d096b0a9 Mon Sep 17 00:00:00 2001 From: Dhruv Kaliraman Date: Wed, 5 Feb 2025 16:33:31 -0800 Subject: [PATCH 6/6] BaseDBWriter client signature --- lib/sycamore/sycamore/writer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/sycamore/sycamore/writer.py b/lib/sycamore/sycamore/writer.py index 9f87a0ff9..39a7a2e6d 100644 --- a/lib/sycamore/sycamore/writer.py +++ b/lib/sycamore/sycamore/writer.py @@ -22,7 +22,6 @@ from neo4j import Auth from neo4j.auth_management import AuthManager -from sycamore.connectors.opensearch import OpenSearchWriterClient from sycamore.connectors.base_writer import BaseDBWriter logger = logging.getLogger(__name__) @@ -860,7 +859,7 @@ def aryn( return self._maybe_execute(ds, True) def _maybe_execute( - self, node: Node, execute: bool, client: Optional[OpenSearchWriterClient] = None + self, node: Node, execute: bool, client: Optional[BaseDBWriter.Client] = None ) -> Optional[DocSet]: ds = DocSet(self.context, node) if not execute: