+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions lib/sycamore/sycamore/connectors/base_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ def create_target_idempotent(self, target_params: "BaseDBWriter.TargetParams"):
def get_existing_target_params(self, target_params: "BaseDBWriter.TargetParams") -> "BaseDBWriter.TargetParams":
    """
    Return the parameters of the target as it currently exists.

    This base implementation is a placeholder that returns ``None``; writer
    clients that can inspect an existing target override it.

    Args:
        target_params: The requested target parameters.
    """
    return None

def reliability_assertor(self, target_params: "BaseDBWriter.TargetParams"):
    """
    Check that every document was written when reliability mode is enabled.

    Writers that support reliability mode override this method; the base
    implementation always fails.

    Args:
        target_params: Parameters describing the target that was written.

    Raises:
        NotImplementedError: Always, because this writer has no reliability check.
    """
    message = "This writer does not support reliability checks"
    raise NotImplementedError(message)

def close(self):
    """Release client resources. The base implementation does nothing."""
    return None

Expand Down
83 changes: 82 additions & 1 deletion lib/sycamore/sycamore/connectors/opensearch/opensearch_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
DEFAULT_RECORD_PROPERTIES,
)
from sycamore.utils.import_utils import requires_modules
from sycamore.plan_nodes import Node
from sycamore.docset import DocSet
from sycamore.context import Context

if typing.TYPE_CHECKING:
from opensearchpy import OpenSearch
Expand All @@ -42,6 +45,7 @@ class OpenSearchWriterClientParams(BaseDBWriter.ClientParams):
@dataclass
class OpenSearchWriterTargetParams(BaseDBWriter.TargetParams):
index_name: str
_doc_count: int = 0
settings: dict[str, Any] = field(default_factory=lambda: {"index.knn": True})
mappings: dict[str, Any] = field(
default_factory=lambda: {
Expand Down Expand Up @@ -92,6 +96,61 @@ def compatible_with(self, other: "BaseDBWriter.TargetParams") -> bool:
other_flat_mappings = dict(flatten_data(other.mappings))
return check_dictionary_compatibility(my_flat_mappings, other_flat_mappings)

@classmethod
def from_write_args(
    cls,
    index_name: str,
    plan: Node,
    context: Context,
    reliability_rewriter: bool,
    execute: bool,
    insert_settings: Optional[dict] = None,
    index_settings: Optional[dict] = None,
) -> "OpenSearchWriterTargetParams":
    """
    Create OpenSearchWriterTargetParams from the arguments of an OpenSearch write.

    In reliability-rewriter mode the plan is validated and the number of documents
    it yields, ``DocSet(context, plan).count()``, is stored in ``_doc_count`` so the
    writer client can compare it against the index afterwards. Otherwise
    ``_doc_count`` is 0.

    Args:
        index_name: Name of the OpenSearch index to write to.
        plan: The plan node feeding the writer. In reliability mode it must be a
            ``Materialize`` node whose first child is empty (falsy), i.e. a read of
            previously materialized output.
        context: The Context used to build the DocSet that is counted.
        reliability_rewriter: Whether reliability-rewriter mode is enabled.
        execute: Whether the write runs immediately; must be True in reliability mode.
        insert_settings: Optional insert settings; stored only when truthy.
        index_settings: Optional index configuration. When truthy, its
            ``body.settings`` and ``body.mappings`` (each defaulting to ``{}``)
            become the target's settings and mappings.

    Returns:
        A new instance built from the collected keyword arguments.

    Raises:
        AssertionError: If reliability mode is enabled and ``execute`` is False or
            ``plan`` does not have the required shape.
    """
    expected_docs = 0
    if reliability_rewriter:
        # Imported locally so it is only needed when reliability mode is requested.
        from sycamore.materialize import Materialize

        assert execute, "Reliability rewriter requires execute to be True"
        assert type(plan) is Materialize, (
            "The first node must be a materialize node for reliability rewriter to work"
        )
        assert not plan.children[0], (
            "Pipeline should only have read materialize and write nodes for reliability rewriter to work"
        )
        expected_docs = DocSet(context, plan).count()

    kwargs: dict[str, Any] = {"index_name": index_name, "_doc_count": expected_docs}
    if insert_settings:
        kwargs["insert_settings"] = insert_settings
    if index_settings:
        body = index_settings.get("body", {})
        kwargs["settings"] = body.get("settings", {})
        kwargs["mappings"] = body.get("mappings", {})
    return cls(**kwargs)


class OpenSearchWriterClient(BaseDBWriter.Client):
def __init__(self, os_client: "OpenSearch"):
Expand Down Expand Up @@ -187,6 +246,8 @@ def _string_values_to_python_types(obj: Any):
return obj
return obj

# TODO: Convert OpenSearchWriterTargetParams to pydantic model

assert isinstance(
target_params, OpenSearchWriterTargetParams
), f"Provided target_params was not of type OpenSearchWriterTargetParams:\n{target_params}"
Expand All @@ -196,7 +257,27 @@ def _string_values_to_python_types(obj: Any):
assert isinstance(mappings, dict)
settings = _string_values_to_python_types(response.get(index_name, {}).get("settings", {}))
assert isinstance(settings, dict)
return OpenSearchWriterTargetParams(index_name=index_name, mappings=mappings, settings=settings)
_doc_count = target_params._doc_count
assert isinstance(_doc_count, int)
return OpenSearchWriterTargetParams(
index_name=index_name,
mappings=mappings,
settings=settings,
_doc_count=_doc_count,
)

def reliability_assertor(self, target_params: BaseDBWriter.TargetParams):
    """
    Verify that the target index holds exactly the expected number of documents.

    The index is flushed, its document count is read from the ``_cat/indices``
    API, and that count is compared with ``target_params._doc_count`` (recorded
    when the write was planned in reliability mode).

    Checks use explicit ``raise`` instead of ``assert`` so they are not stripped
    when Python runs with ``-O``; the exception type and messages are unchanged.

    Args:
        target_params: Must be an ``OpenSearchWriterTargetParams``.

    Raises:
        AssertionError: If ``target_params`` has the wrong type, the index name does
            not resolve to exactly one index, or the document count differs from
            the expected count.
    """
    if not isinstance(target_params, OpenSearchWriterTargetParams):
        raise AssertionError(
            f"Provided target_params was not of type OpenSearchWriterTargetParams:\n{target_params}"
        )
    index_name = target_params.index_name

    log.info("Flushing index...")
    self._client.indices.flush(index=index_name, params={"timeout": 300})
    log.info("Done flushing index.")

    # NOTE(review): "docs.count" from _cat/indices is a Lucene-level count; if the
    # mappings ever contain nested fields it may exceed the number of top-level
    # documents — confirm it matches how _doc_count is computed.
    indices = self._client.cat.indices(index=index_name, format="json")
    if len(indices) != 1:
        raise AssertionError(f"Expected 1 index, found {len(indices)}")

    num_docs = int(indices[0]["docs.count"])
    log.info("%d chunks written in index %s", num_docs, index_name)

    expected_docs = target_params._doc_count
    if num_docs != expected_docs:
        raise AssertionError(f"Expected {expected_docs} docs, found {num_docs}")


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,56 @@ def remove_reconstruct_doc_property(doc: Document):
# Clean up
os_client.indices.delete(setup_index, ignore_unavailable=True)

def test_write_with_reliability(self, setup_index, os_client, exec_mode):
    """
    Validates that when materialized pickle outputs are deleted, the index is rewritten
    with the correct (reduced) number of chunks.

    The index is deleted in a ``finally`` block so a failing assertion does not
    leave it behind.
    """
    try:
        with tempfile.TemporaryDirectory() as tmpdir1:
            path = str(TEST_DIR / "resources/data/pdfs/Ray.pdf")
            context = sycamore.init(exec_mode=exec_mode)

            # The same PDF is read twice so Ray execution sees 2 docs.
            (
                context.read.binary([path, path], binary_format="pdf")
                .partition(ArynPartitioner(aryn_api_key=ARYN_API_KEY))
                .explode()
                .materialize(path=tmpdir1)
                .execute()
            )

            # First write from the materialized output, in reliability-rewriter mode.
            context.read.materialize(tmpdir1).write.opensearch(
                os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
                index_name=setup_index,
                index_settings=TestOpenSearchRead.INDEX_SETTINGS,
                reliability_rewriter=True,
            )
            os_client.indices.refresh(setup_index)
            count = get_doc_count(os_client, setup_index)

            # Delete 1 pickle file to make sure reliability rewriter works
            pickle_files = [f for f in os.listdir(tmpdir1) if f.endswith(".pickle")]
            assert pickle_files, "No pickle files found in materialized directory"
            os.remove(os.path.join(tmpdir1, pickle_files[0]))

            # Reliability mode deletes and recreates the index, so it should now hold fewer chunks.
            context.read.materialize(tmpdir1).write.opensearch(
                os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
                index_name=setup_index,
                index_settings=TestOpenSearchRead.INDEX_SETTINGS,
                reliability_rewriter=True,
            )
            os_client.indices.refresh(setup_index)
            re_count = get_doc_count(os_client, setup_index)

            # Verify the document count dropped by exactly the one deleted pickle.
            expected_count = count - 1
            assert re_count == expected_count, f"Expected {expected_count} documents, found {re_count}"
    finally:
        os_client.indices.delete(setup_index, ignore_unavailable=True)

def _test_ingest_and_read_via_docid_reconstructor(self, setup_index, os_client, cache_dir):
"""
Validates data is readable from OpenSearch, and that we can rebuild processed Sycamore documents.
Expand Down
37 changes: 26 additions & 11 deletions lib/sycamore/sycamore/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from neo4j import Auth
from neo4j.auth_management import AuthManager

from sycamore.connectors.base_writer import BaseDBWriter

# Module-level logger for the writer helpers (e.g. the warning before an index is deleted in reliability mode).
logger = logging.getLogger(__name__)

Expand All @@ -47,6 +48,7 @@ def opensearch(
index_settings: dict,
insert_settings: Optional[dict] = None,
execute: bool = True,
reliability_rewriter: bool = False,
**kwargs,
) -> Optional["DocSet"]:
"""Writes the content of the DocSet into the specified OpenSearch index.
Expand Down Expand Up @@ -103,9 +105,9 @@ def opensearch(
"""

from sycamore.connectors.opensearch import (
OpenSearchWriter,
OpenSearchWriterClientParams,
OpenSearchWriterTargetParams,
OpenSearchWriter,
)
from typing import Any
import copy
Expand Down Expand Up @@ -137,23 +139,30 @@ def _convert_to_host_port_list(hostlist: Any) -> list[HostAndPort]:
os_client_args["hosts"] = _convert_to_host_port_list(hosts)
client_params = OpenSearchWriterClientParams(**os_client_args)

target_params: OpenSearchWriterTargetParams
target_params_dict: dict[str, Any] = {"index_name": index_name}
if insert_settings:
target_params_dict["insert_settings"] = insert_settings
if index_settings:
target_params_dict["settings"] = index_settings.get("body", {}).get("settings", {})
target_params_dict["mappings"] = index_settings.get("body", {}).get("mappings", {})
target_params = OpenSearchWriterTargetParams(**target_params_dict)
target_params = OpenSearchWriterTargetParams.from_write_args(
index_name=index_name,
plan=self.plan,
context=self.context,
reliability_rewriter=reliability_rewriter,
execute=execute,
insert_settings=insert_settings,
index_settings=index_settings,
)
os = OpenSearchWriter(
self.plan, client_params=client_params, target_params=target_params, name="OsrchWrite", **kwargs
)
client = None
if reliability_rewriter:
client = os.Client.from_client_params(client_params)
if client._client.indices.exists(index=index_name):
logger.info(f"\n\nWARNING WARNING WARNING: Deleting existing index {index_name}\n\n")
client._client.indices.delete(index=index_name)

# We will probably want to break this at some point so that write
# doesn't execute automatically, and instead you need to say something
# like docset.write.opensearch().execute(), allowing sensible writes
# to multiple locations and post-write operations.
return self._maybe_execute(os, execute)
return self._maybe_execute(os, execute, client)

@requires_modules(["weaviate", "weaviate.collections.classes.config"], extra="weaviate")
def weaviate(
Expand Down Expand Up @@ -849,10 +858,16 @@ def aryn(

return self._maybe_execute(ds, True)

def _maybe_execute(self, node: Node, execute: bool) -> Optional[DocSet]:
def _maybe_execute(
    self, node: Node, execute: bool, client: Optional[BaseDBWriter.Client] = None
) -> Optional[DocSet]:
    """
    Wrap ``node`` in a DocSet and either return it or execute it.

    Args:
        node: The plan node to wrap in a DocSet.
        execute: If False, return the DocSet without running it. If True, execute
            it and return None.
        client: Optional writer client. When given and ``node`` is a
            ``BaseDBWriter``, ``client.reliability_assertor`` is called with the
            writer's target params after execution to verify the write.

    Returns:
        The unexecuted DocSet when ``execute`` is False, otherwise None.
    """
    ds = DocSet(self.context, node)
    if not execute:
        return ds

    ds.execute()

    # isinstance rather than an exact type comparison: writer nodes are instances of
    # BaseDBWriter subclasses (e.g. OpenSearchWriter), so `type(node) == BaseDBWriter`
    # never matched and the reliability check was silently skipped.
    if client is not None and isinstance(node, BaseDBWriter):
        client.reliability_assertor(node._target_params)
    return None
Loading
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载