123 changes: 90 additions & 33 deletions lib/sycamore/sycamore/connectors/common.py
@@ -1,8 +1,11 @@
from dataclasses import dataclass
from typing import Callable, Iterator, Union, Iterable, Tuple, Any, Dict
from sycamore.data import Document
import json
import string
import random
import math
import numpy as np


@dataclass
@@ -29,25 +32,42 @@ def generate_random_string(length=8):
return "".join(random.choice(characters) for _ in range(length))


def filter_doc(obj, include):
return {k: v for k, v in obj.__dict__.items() if k in include}
def filter_doc(doc: Document, include):
return {k: v for k, v in doc.items() if k in include}


def check_dictionary_compatibility(dict1: dict[Any, Any], dict2: dict[Any, Any], ignore: list[str] = []):
for k in dict1:
if ignore and any(val in k for val in ignore):
if not dict1.get(k) or (ignore and any(val in k for val in ignore)):
continue
if k not in dict2:
return False
if dict1[k] != dict2[k]:
if dict1[k] != dict2[k] and (dict1[k] or dict2[k]):
return False
return True
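
The relaxed rules above (skip empty values, treat two falsy values as equal) keep a database round-trip that drops empty fields from failing the comparison. A minimal sketch of the behavior, assuming the function as defined above:

assert check_dictionary_compatibility({"a": 1, "b": ""}, {"a": 1})           # empty "b" is skipped
assert not check_dictionary_compatibility({"a": 1}, {"a": 2})                # a real mismatch still fails
assert check_dictionary_compatibility({"a": 0}, {"a": None})                 # two falsy values pass
assert check_dictionary_compatibility({"meta_k": 1}, {}, ignore=["meta"])    # ignore matches by substring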


def compare_docs(doc1, doc2):
def compare_docs(doc1: Document, doc2: Document):
filtered_doc1 = filter_doc(doc1, DEFAULT_RECORD_PROPERTIES.keys())
filtered_doc2 = filter_doc(doc2, DEFAULT_RECORD_PROPERTIES.keys())
return filtered_doc1 == filtered_doc2
for key in filtered_doc1:
if isinstance(filtered_doc1[key], (list, np.ndarray)) or isinstance(filtered_doc2.get(key), (list, np.ndarray)):
assert len(filtered_doc1[key]) == len(filtered_doc2[key])
for item1, item2 in zip(filtered_doc1[key], filtered_doc2[key]):
try:
# Convert items to float for numerical comparison
num1 = float(item1)
num2 = float(item2)
# Check if numbers are close within tolerance
assert math.isclose(num1, num2, rel_tol=1e-5, abs_tol=1e-5)
except (ValueError, TypeError):
# If conversion to float fails, do direct comparison
assert item1 == item2
elif isinstance(filtered_doc1[key], dict) and isinstance(filtered_doc2.get(key), dict):
assert check_dictionary_compatibility(filtered_doc1[key], filtered_doc2.get(key))
else:
assert filtered_doc1[key] == filtered_doc2.get(key)
return True
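
The element-wise isclose check matters because embeddings written to a target store are typically serialized as float32 and read back with round-off, so exact equality would fail. A minimal sketch of the comparison now applied to list fields, with made-up values:

import math

written = [0.12345678, 0.87654321]
read_back = [0.1234568, 0.8765432]  # e.g. after a float32 round-trip
assert all(math.isclose(a, b, rel_tol=1e-5, abs_tol=1e-5) for a, b in zip(written, read_back))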


def _add_key_to_prefix(prefix, key, separator="."):
@@ -88,49 +108,86 @@ def flatten_data(
return items


def unflatten_data(data: dict[str, Any], separator: str = ".") -> dict[Any, Any]:
result: dict[Any, Any] = {}
def unflatten_data(data: dict[Any, Any], separator: str = ".") -> dict[Any, Any]:
"""
Unflattens a dictionary with keys that contain separators into a nested dictionary. The separator can be escaped,
and if there are integer keys in the path, the result will be a list instead of a dictionary.
"""

def parse_key(key: str) -> list:
# Handle escaped separator
def split_key(key: str, separator: str = ".") -> list[str]:
"""
Splits the key by separator (which can be multiple characters), respecting escaped separators.
"""
parts = []
current = ""
escape = False
for char in key:
if escape:
if char == separator:
current += separator
i = 0
while i < len(key):
if key[i] == "\\":
# Escape character
if i + 1 < len(key):
current += key[i + 1]
i += 2
else:
current += "\\" + char
escape = False
elif char == "\\":
escape = True
elif char == separator:
# Trailing backslash, treat it as literal backslash
current += "\\"
i += 1
elif key[i : i + len(separator)] == separator:
# Found separator
parts.append(current)
current = ""
i += len(separator)
else:
current += char
current += key[i]
i += 1
parts.append(current)
return parts

for key, value in data.items():
parts = parse_key(key)
result: dict[Any, Any] = {}
for flat_key, value in data.items():
parts = split_key(flat_key, separator)
current = result
for i, part in enumerate(parts):
part_key: Union[str, int] = int(part) if part.isdigit() else part
# Determine whether the key part is an integer (for list indices)
key: Union[str, int]
try:
key = int(part)
except ValueError:
key = part
Comment on lines +152 to +155

HenryL27 (Collaborator): could prob do if key.isnumeric()

karanataryn (Contributor, Author): Part will also be a str though? So I would have to cast it to make it work, I think.

HenryL27 (Collaborator), Oct 18, 2024: isnumeric is a string method that tells you if all the characters are digits, so you could do

if part.isnumeric():
    key = int(part)
else:
    key = part

I just like to avoid extraneous try/catches where possible. Maybe that's silly.

karanataryn (Contributor, Author): Ah, that makes sense. I can do that.

HenryL27 (Collaborator): Wait, isdigit does the same thing. Why did we change this from the prev implementation?

karanataryn (Contributor, Author), Oct 18, 2024: This is because both isdigit and (I just realized) isnumeric don't handle decimals or negative numbers. I'll stick with the current implementation for now.

is_last = i == len(parts) - 1

if is_last:
current[part_key] = value
else:
next_part_is_digit = parts[i + 1].isdigit() if i + 1 < len(parts) else False
if part_key not in current:
current[part_key] = [] if next_part_is_digit else {}
current = current[part_key]
# If current is a list and the next part is a digit, ensure proper length
# Set the value at the deepest level
if isinstance(current, list):
if next_part_is_digit and len(current) <= int(parts[i + 1]):
current.extend("" for _ in range(int(parts[i + 1]) - len(current) + 1))
# Ensure the list is big enough
while len(current) <= key:
current.append("")
current[key] = value
else:
current[key] = value
else:
# Determine the type of the next part
next_part = parts[i + 1]

# Check if the next part is an index (integer)
try:
int(next_part)
next_is_index = True
except ValueError:
next_is_index = False

# Initialize containers as needed
if isinstance(current, list):
# Ensure the list is big enough
while len(current) <= key:
current.append("")
if current[key] == "" or current[key] is None:
current[key] = [] if next_is_index else {}
current = current[key]
else:
if key not in current:
current[key] = [] if next_is_index else {}
current = current[key]
return result
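
A minimal usage sketch of the behavior described in the docstring, using the default "." separator (keys and values are made up):

flat = {
    "doc_id": "abc",
    "properties.title": "Report",
    "properties.entity.0": "Alice",
    "properties.entity.1": "Bob",
    "properties.path\\.name": "a.b",  # escaped separator stays inside the key
}
assert unflatten_data(flat) == {
    "doc_id": "abc",
    "properties": {"title": "Report", "entity": ["Alice", "Bob"], "path.name": "a.b"},
}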


5 changes: 4 additions & 1 deletion lib/sycamore/sycamore/connectors/duckdb/duckdb_reader.py
@@ -1,4 +1,5 @@
from sycamore.data import Document
from sycamore.data.document import DocumentPropertyTypes, DocumentSource

from dataclasses import dataclass
import typing
@@ -73,7 +74,9 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
object["properties"] = convert_from_str_dict(val)
if isinstance(object["embedding"], float):
object["embedding"] = []
result.append(Document(object))
doc = Document(object)
doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
result.append(doc)
return result


8 changes: 7 additions & 1 deletion lib/sycamore/sycamore/connectors/duckdb/duckdb_writer.py
@@ -25,6 +25,7 @@ class DuckDBWriterTargetParams(BaseDBWriter.TargetParams):
schema: Dict[str, str] = field(
default_factory=lambda: {
"doc_id": "VARCHAR",
"parent_id": "VARCHAR",
"embedding": "FLOAT",
"properties": "MAP(VARCHAR, VARCHAR)",
"text_representation": "VARCHAR",
@@ -73,10 +74,11 @@ def write_many_records(self, records: list[BaseDBWriter.Record], target_params:
), f"Wrong kind of target parameters found: {target_params}"
dict_params = asdict(target_params)
N = target_params.batch_size * 1024 # Around 1 MB
headers = ["doc_id", "embedding", "properties", "text_representation", "bbox", "shingles", "type"]
headers = ["doc_id", "parent_id", "embedding", "properties", "text_representation", "bbox", "shingles", "type"]
schema = pa.schema(
[
("doc_id", pa.string()),
("parent_id", pa.string()),
("embedding", pa.list_(pa.float32())),
("properties", pa.map_(pa.string(), pa.string())),
("text_representation", pa.string()),
@@ -100,6 +102,7 @@ def write_batch(batch_data: dict):
for r in records:
# Append the new data to the batch
batch_data["doc_id"].append(r.doc_id)
batch_data["parent_id"].append(r.parent_id)
batch_data["embedding"].append(r.embedding)
batch_data["properties"].append(convert_to_str_dict(r.properties) if r.properties else {})
batch_data["text_representation"].append(r.text_representation)
@@ -126,6 +129,7 @@ def create_target_idempotent(self, target_params: BaseDBWriter.TargetParams):
embedding_size = schema.get("embedding") + "[" + str(dict_params.get("dimensions")) + "]"
client.sql(
f"""CREATE TABLE {dict_params.get('table_name')} (doc_id {schema.get('doc_id')},
parent_id {schema.get('parent_id')},
embedding {embedding_size}, properties {schema.get('properties')},
text_representation {schema.get('text_representation')}, bbox {schema.get('bbox')},
shingles {schema.get('shingles')}, type {schema.get('type')})"""
@@ -168,6 +172,7 @@ def get_existing_target_params(self, target_params: BaseDBWriter.TargetParams) -
@dataclass
class DuckDBDocumentRecord(BaseDBWriter.Record):
doc_id: str
parent_id: Optional[str] = None
embedding: Optional[list[float]] = None
properties: Optional[dict[str, Any]] = None
text_representation: Optional[str] = None
@@ -183,6 +188,7 @@ def from_doc(cls, document: Document, target_params: BaseDBWriter.TargetParams)
raise ValueError(f"Cannot write documents without a doc_id. Found {document}")
return DuckDBDocumentRecord(
doc_id=doc_id,
parent_id=document.parent_id,
properties=document.properties,
type=document.type,
text_representation=document.text_representation,
lib/sycamore/sycamore/connectors/elasticsearch/elasticsearch_reader.py
@@ -1,5 +1,6 @@
from sycamore.data import Document
from sycamore.connectors.base_reader import BaseDBReader
from sycamore.data.document import DocumentPropertyTypes, DocumentSource
from sycamore.utils.import_utils import requires_modules
from dataclasses import dataclass, field
import typing
@@ -81,8 +82,14 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
for data in self.output:
doc_id = data["_id"]
doc = Document(
{"doc_id": doc_id, "embedding": data["_source"].get("embeddings"), **data["_source"].get("properties")}
{
"doc_id": doc_id,
"parent_id": data["_source"].get("parent_id"),
"embedding": data["_source"].get("embedding"),
**data["_source"].get("properties"),
}
)
doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
result.append(doc)
return result
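
Every reader touched in this PR now stamps documents the same way, so downstream code can tell a document hydrated from a target store apart from a freshly parsed one. A hypothetical check over any reader's output:

docs = reader.to_docs(query_params)  # hypothetical reader and query_params
assert all(doc.properties[DocumentPropertyTypes.SOURCE] == DocumentSource.DB_QUERY for doc in docs)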

lib/sycamore/sycamore/connectors/elasticsearch/elasticsearch_writer.py
@@ -25,13 +25,14 @@ class ElasticsearchWriterTargetParams(BaseDBWriter.TargetParams):
mappings: dict[str, Any] = field(
default_factory=lambda: {
"properties": {
"embeddings": {
"embedding": {
"type": "dense_vector",
"dims": 384,
"index": True,
"similarity": "cosine",
},
"properties": {"type": "object"},
"parent_id": {"type": "text"},
}
}
)
@@ -88,7 +89,8 @@ def bulk_action_generator():
"_index": target_params.index_name,
"_id": r.doc_id,
"properties": r.properties,
"embeddings": r.embeddings,
"embedding": r.embedding,
"parent_id": r.parent_id,
}

for success, info in parallel_bulk(
@@ -129,15 +131,15 @@ def get_existing_target_params(self, target_params: BaseDBWriter.TargetParams) -
class ElasticsearchWriterDocumentRecord(BaseDBWriter.Record):
doc_id: str
properties: dict
embeddings: Optional[list[float]]
parent_id: Optional[str]
embedding: Optional[list[float]]

@classmethod
def from_doc(
cls, document: Document, target_params: BaseDBWriter.TargetParams
) -> "ElasticsearchWriterDocumentRecord":
assert isinstance(target_params, ElasticsearchWriterTargetParams)
doc_id = document.doc_id
embedding = document.embedding
if doc_id is None:
raise ValueError(f"Cannot write documents without a doc_id. Found {document}")
properties = {
@@ -147,7 +149,9 @@ def from_doc(
"bbox": document.bbox.coordinates if document.bbox else None,
"shingles": document.shingles,
}
return ElasticsearchWriterDocumentRecord(doc_id=doc_id, properties=properties, embeddings=embedding)
return ElasticsearchWriterDocumentRecord(
doc_id=doc_id, parent_id=document.parent_id, properties=properties, embedding=document.embedding
)


def _narrow_list_of_doc_records(
lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py
@@ -88,6 +88,7 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
**data.get("_source", {}),
}
)
doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
result.append(doc)
else:
assert (
13 changes: 11 additions & 2 deletions lib/sycamore/sycamore/connectors/pinecone/pinecone_reader.py
@@ -65,12 +65,21 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
assert isinstance(self, PineconeReaderQueryResponse)
result = []
for data in self.output:
doc_id = data.id.split("#")[1] if len(data.id.split("#")) > 1 else data.id
if len(id := data.id.split("#")) > 1:
parent_id = id[0]
doc_id = id[1]
else:
parent_id = None
doc_id = data.id
if data.sparse_vector:
term_frequency = dict(zip(data.sparse_vector.indices, data.sparse_vector.values))
data.metadata["properties.term_frequency"] = term_frequency
metadata = data.metadata if data.metadata else {}
doc = Document({"doc_id": doc_id, "embedding": data.values} | unflatten_data(metadata))
doc_dict = {"doc_id": doc_id, "embedding": data.values, "parent_id": parent_id} | unflatten_data(metadata)
doc_dict["bbox"] = (
[bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]] if (bbox := doc_dict.get("bbox")) else []
)
doc = Document(doc_dict)
doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
result.append(doc)
return result
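
This mirrors the Pinecone writer below, which stores chunked documents under a composite "parent_id#doc_id" vector id; the reader splits the id back apart. A small sketch with made-up ids:

composite = "p123#c456"  # hypothetical vector id written as parent_id#doc_id
if len(parts := composite.split("#")) > 1:
    parent_id, doc_id = parts[0], parts[1]
else:
    parent_id, doc_id = None, composite
assert (parent_id, doc_id) == ("p123", "c456")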
4 changes: 2 additions & 2 deletions lib/sycamore/sycamore/connectors/pinecone/pinecone_writer.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass, asdict
import typing
from typing import Optional, TypedDict, Union
from typing import Optional, TypedDict, Union, Any
from sycamore.utils import batched
from typing_extensions import TypeGuard

@@ -119,7 +119,7 @@ def from_doc(cls, document: Document, target_params: "BaseDBWriter.TargetParams"
else:
id = f"{document.parent_id}#{document.doc_id}"
values = document.embedding
metadata = {
metadata: dict[str, Any] = {
"type": document.type,
"text_representation": document.text_representation,
"bbox": document.bbox.to_dict() if document.bbox else None,