+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions lib/sycamore/sycamore/connectors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import random
import math
import numpy as np
import pyarrow as pa
import re


@dataclass
Expand Down Expand Up @@ -36,9 +38,14 @@ def filter_doc(doc: Document, include):
return {k: v for k, v in doc.items() if k in include}


def check_dictionary_compatibility(dict1: dict[Any, Any], dict2: dict[Any, Any], ignore: list[str] = []):
def check_dictionary_compatibility(dict1: dict[Any, Any], dict2: dict[Any, Any], ignore_list: list[str] = []):
for k in dict1:
if not dict1.get(k) or (ignore and any(val in k for val in ignore)):
if not dict1.get(k) or (
ignore_list
and any(
(ignore_value in k and any(k in dict2_k for dict2_k in dict2.keys())) for ignore_value in ignore_list
)
): # skip if ignored key and if it exists in dict2
continue
if k not in dict2:
return False
Expand Down Expand Up @@ -298,3 +305,30 @@ def _type_filter(x):
return not isinstance(x, tuple(types))

return _type_filter


def _get_pyarrow_type(key: str, dtype: str) -> pa.DataType:
if dtype == ("VARCHAR"):
return pa.string()
elif dtype == ("DOUBLE"):
return pa.float64()
elif dtype == ("BIGINT"):
return pa.int64()
elif dtype.startswith("MAP"):
match = re.match(r"MAP\((.+),\s*(.+)\)", dtype)
if not match:
raise ValueError(f"Invalid MAP type format: {dtype}")
key_type, value_type = match.groups()
pa_key_type = _get_pyarrow_type(key, key_type)
pa_value_type = _get_pyarrow_type(key, value_type)
return pa.map_(pa_key_type, pa_value_type)
elif dtype == "VARCHAR[]":
return pa.list_(pa.string())
elif dtype == "DOUBLE[]" or key == "embedding": # embedding is a list of floats with a fixed dimension
return pa.list_(pa.float64())
elif dtype == "BIGINT[]":
return pa.list_(pa.int64())
elif dtype == "FLOAT":
return pa.float32()
else:
raise ValueError(f"Unsupported pyarrow datatype: {dtype}")
88 changes: 38 additions & 50 deletions lib/sycamore/sycamore/connectors/duckdb/duckdb_writer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from dataclasses import dataclass, asdict, field
from typing import Optional, Any, Dict
from typing import Optional, Any
from typing_extensions import TypeGuard

import pyarrow as pa
import logging
import duckdb

from sycamore.data.document import Document
from sycamore.connectors.base_writer import BaseDBWriter
from sycamore.connectors.common import convert_to_str_dict
from sycamore.utils.import_utils import requires_modules

import pyarrow as pa
import os
from sycamore.connectors.common import convert_to_str_dict, _get_pyarrow_type


@dataclass
Expand All @@ -22,7 +23,7 @@ class DuckDBWriterTargetParams(BaseDBWriter.TargetParams):
db_url: Optional[str] = "tmp.db"
table_name: Optional[str] = "default_table"
batch_size: int = 1000
schema: Dict[str, str] = field(
schema: dict[str, str] = field(
default_factory=lambda: {
"doc_id": "VARCHAR",
"parent_id": "VARCHAR",
Expand All @@ -44,8 +45,6 @@ def compatible_with(self, other: BaseDBWriter.TargetParams) -> bool:
return False
if self.table_name != other.table_name:
return False
if self.batch_size != other.batch_size:
return False
if other.schema and self.schema:
if (
"embedding" in other.schema
Expand Down Expand Up @@ -74,23 +73,21 @@ def write_many_records(self, records: list[BaseDBWriter.Record], target_params:
), f"Wrong kind of target parameters found: {target_params}"
dict_params = asdict(target_params)
N = target_params.batch_size * 1024 # Around 1 MB
headers = ["doc_id", "parent_id", "embedding", "properties", "text_representation", "bbox", "shingles", "type"]
schema = pa.schema(
[
("doc_id", pa.string()),
("parent_id", pa.string()),
("embedding", pa.list_(pa.float32())),
("properties", pa.map_(pa.string(), pa.string())),
("text_representation", pa.string()),
("bbox", pa.list_(pa.float32())),
("shingles", pa.list_(pa.int64())),
("type", pa.string()),
]
)

def write_batch(batch_data: dict):
import duckdb
# Validate schema and create pyarrow schema
headers = []
pa_fields = []
for key, dtype in target_params.schema.items():
headers.append(key)
try:
pa_dtype = _get_pyarrow_type(key, dtype)
pa_fields.append((key, pa_dtype))
except Exception as e:
raise ValueError(f"Invalid schema attribute or datatype for {key}: {e}")

schema = pa.schema(pa_fields)

def write_batch(batch_data: dict):
pa_table = pa.Table.from_pydict(batch_data, schema=schema) # noqa
client = duckdb.connect(str(dict_params.get("db_url")))
client.sql(f"INSERT INTO {dict_params.get('table_name')} SELECT * FROM pa_table")
Expand All @@ -100,63 +97,54 @@ def write_batch(batch_data: dict):
batch_data: dict[str, list[Any]] = {key: [] for key in headers}

for r in records:
# Append the new data to the batch
batch_data["doc_id"].append(r.doc_id)
batch_data["parent_id"].append(r.parent_id)
batch_data["embedding"].append(r.embedding)
batch_data["properties"].append(convert_to_str_dict(r.properties) if r.properties else {})
batch_data["text_representation"].append(r.text_representation)
batch_data["bbox"].append(r.bbox)
batch_data["shingles"].append(r.shingles)
batch_data["type"].append(r.type)
for key in headers:
value = getattr(r, key, None)
if isinstance(value, dict) and value:
value = convert_to_str_dict(value)
batch_data[key].append(value)

# If we've reached the batch size, write to the database
if batch_data.__sizeof__() >= N:
write_batch(batch_data)
# Write any remaining records
if len(batch_data["doc_id"]) > 0:
if len(batch_data[headers[0]]) > 0:
write_batch(batch_data)

def create_target_idempotent(self, target_params: BaseDBWriter.TargetParams):
    """Create the target DuckDB table described by ``target_params.schema`` if it is missing.

    One column is created per schema entry, in schema order. The ``embedding``
    column's declared type gets a ``[<dimensions>]`` suffix taken from
    ``target_params``. ``CREATE TABLE IF NOT EXISTS`` makes repeated calls a
    no-op rather than an error.

    Table creation is best-effort: a missing schema or a failing CREATE statement
    is logged as a warning instead of raised.

    Raises:
        ValueError: if ``target_params.db_url`` is empty.
    """
    import duckdb

    assert isinstance(target_params, DuckDBWriterTargetParams)
    dict_params = asdict(target_params)
    schema = dict_params.get("schema")
    table_name = dict_params.get("table_name")
    db_url = dict_params.get("db_url")
    if not target_params.db_url:
        raise ValueError(f"Must provide valid disk location. Location Specified: {target_params.db_url}")
    client = duckdb.connect(str(db_url))
    try:
        if not schema:
            logging.warning(f"Error creating table {table_name} in database {db_url}: no schema provided")
            return
        columns = []
        for key, dtype in schema.items():
            if key == "embedding":
                dtype += f"[{dict_params.get('dimensions')}]"
            columns.append(f"{key} {dtype}")
        columns_str = ", ".join(columns)
        client.sql(f"CREATE TABLE IF NOT EXISTS {table_name} ({columns_str})")
    except Exception as e:
        # Keep the best-effort contract, but surface the failure instead of silently returning.
        logging.warning(f"Error creating table {table_name} in database {db_url}: {e}")
    finally:
        # Release the database handle (and its file lock) once the DDL is done.
        client.close()

def get_existing_target_params(self, target_params: BaseDBWriter.TargetParams) -> "DuckDBWriterTargetParams":
import duckdb

assert isinstance(target_params, DuckDBWriterTargetParams)
dict_params = asdict(target_params)
schema = target_params.schema
if not target_params.db_url or not os.path.exists(target_params.db_url):
raise ValueError(f"Must provide valid disk location. Location Specified: {target_params.db_url}")
if target_params.db_url and target_params.table_name:
client = duckdb.connect(str(dict_params.get("db_url")))
try:
table = client.sql(f"SELECT * FROM {target_params.table_name}")
schema = dict(zip(table.columns, [repr(i) for i in table.dtypes]))
schema = dict(zip(table.columns, [str(i) for i in table.dtypes]))
except Exception as e:
print(
logging.warning(
f"""Table {dict_params.get('table_name')}
does not exist in database {dict_params.get('table_name')}: {e}"""
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def write_many_records(self, records: list[BaseDBWriter.Record], target_params:

assert isinstance(target_params, ElasticsearchWriterTargetParams)
assert _narrow_list_of_doc_records(records), f"Found a bad record in {records}"
if not records:
return
with self._client:

def bulk_action_generator():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import pytest
from unittest import mock
from sycamore.data import Document
from sycamore.data.document import DocumentPropertyTypes, DocumentSource

from sycamore.connectors.duckdb.duckdb_reader import (
DuckDBReaderClient,
DuckDBReaderClientParams,
DuckDBReaderQueryParams,
DuckDBReaderQueryResponse,
DuckDBReader,
)


@pytest.fixture
def mock_duckdb_connection():
    """Patch ``duckdb.connect`` for the duration of a test and yield the fake connection.

    While the test runs, every ``duckdb.connect`` call returns the same
    ``mock.Mock`` instance that this fixture yields.
    """
    fake_connection = mock.Mock()
    with mock.patch("duckdb.connect", return_value=fake_connection):
        yield fake_connection


def test_duckdb_reader_client_from_client_params(mock_duckdb_connection):
    """``from_client_params`` connects read-only to the configured database URL."""
    client_params = DuckDBReaderClientParams(db_url="test_db")
    with mock.patch("duckdb.connect", return_value=mock_duckdb_connection) as connect_spy:
        reader_client = DuckDBReaderClient.from_client_params(client_params)
        assert isinstance(reader_client, DuckDBReaderClient)
        connect_spy.assert_called_once_with(database="test_db", read_only=True)


def test_duckdb_reader_client_read_records(mock_duckdb_connection):
    """Without a custom query, ``read_records`` runs the default full-table SELECT."""
    reader_client = DuckDBReaderClient(mock_duckdb_connection)
    params = DuckDBReaderQueryParams(
        table_name="test_table",
        query=None,
        create_hnsw_table=None,
    )
    mock_duckdb_connection.execute.return_value = mock.Mock()

    outcome = reader_client.read_records(params)

    assert isinstance(outcome, DuckDBReaderQueryResponse)
    mock_duckdb_connection.execute.assert_called_once_with("SELECT * from test_table")


def test_duckdb_reader_client_read_records_with_query(mock_duckdb_connection):
    """A caller-supplied query is executed exactly once, as given."""
    custom_query = "SELECT * FROM test_table"
    reader_client = DuckDBReaderClient(mock_duckdb_connection)
    params = DuckDBReaderQueryParams(table_name="test_table", query=custom_query, create_hnsw_table=None)
    mock_duckdb_connection.execute.return_value = mock.Mock()

    outcome = reader_client.read_records(params)

    assert isinstance(outcome, DuckDBReaderQueryResponse)
    mock_duckdb_connection.execute.assert_called_once_with(custom_query)


def test_duckdb_reader_client_read_records_with_create_hnsw_table(mock_duckdb_connection):
    """With ``create_hnsw_table`` set, both that statement and the default SELECT are executed."""
    hnsw_statement = "CREATE TABLE hnsw AS SELECT * FROM test_table"
    reader_client = DuckDBReaderClient(mock_duckdb_connection)
    params = DuckDBReaderQueryParams(table_name="test_table", query=None, create_hnsw_table=hnsw_statement)
    mock_duckdb_connection.execute.return_value = mock.Mock()

    outcome = reader_client.read_records(params)

    assert isinstance(outcome, DuckDBReaderQueryResponse)
    for expected_sql in (hnsw_statement, "SELECT * from test_table"):
        mock_duckdb_connection.execute.assert_any_call(expected_sql)


def test_duckdb_reader_client_check_target_presence(mock_duckdb_connection):
    """A table that can be selected from is reported as present."""
    reader_client = DuckDBReaderClient(mock_duckdb_connection)
    params = DuckDBReaderQueryParams(table_name="test_table", query=None, create_hnsw_table=None)
    mock_duckdb_connection.sql.return_value = mock.Mock()

    is_present = reader_client.check_target_presence(params)

    assert is_present is True
    mock_duckdb_connection.sql.assert_called_once_with("SELECT * FROM test_table")


def test_duckdb_reader_client_check_target_presence_not_found(mock_duckdb_connection):
    """If selecting from the table raises, the target is reported as absent."""
    missing_table = "non_existent_table"
    reader_client = DuckDBReaderClient(mock_duckdb_connection)
    params = DuckDBReaderQueryParams(table_name=missing_table, query=None, create_hnsw_table=None)
    mock_duckdb_connection.sql.side_effect = Exception("Table not found")

    assert reader_client.check_target_presence(params) is False
    mock_duckdb_connection.sql.assert_called_once_with(f"SELECT * FROM {missing_table}")


def test_duckdb_reader_query_response_to_docs():
    """Each result row becomes a Document sourced from a DB query, keeping its properties.

    The row's scalar ``embedding`` value (0.0) is expected to come back as an
    empty embedding list.
    """
    row = {"properties": {"key": "value"}, "embedding": 0.0}
    fake_output = mock.Mock()
    fake_output.df.return_value.to_dict.return_value = [row]
    query_response = DuckDBReaderQueryResponse(output=fake_output)
    params = DuckDBReaderQueryParams(table_name="test_table", query=None, create_hnsw_table=None)

    documents = query_response.to_docs(params)

    assert len(documents) == 1
    only_doc = documents[0]
    assert isinstance(only_doc, Document)
    assert only_doc.properties[DocumentPropertyTypes.SOURCE] == DocumentSource.DB_QUERY
    assert only_doc.properties["key"] == "value"
    assert only_doc.embedding == []


def test_duckdb_reader_query_response_to_docs_empty():
    """An empty result set converts to an empty list of documents."""
    fake_output = mock.Mock()
    fake_output.df.return_value.to_dict.return_value = []
    query_response = DuckDBReaderQueryResponse(output=fake_output)
    params = DuckDBReaderQueryParams(table_name="test_table", query=None, create_hnsw_table=None)

    assert len(query_response.to_docs(params)) == 0


def test_duckdb_reader():
    """``DuckDBReader`` is wired to the DuckDB-specific client, record, and params classes."""
    expected_bindings = (
        ("Client", DuckDBReaderClient),
        ("Record", DuckDBReaderQueryResponse),
        ("ClientParams", DuckDBReaderClientParams),
        ("QueryParams", DuckDBReaderQueryParams),
    )
    for attribute, expected_class in expected_bindings:
        assert getattr(DuckDBReader, attribute) == expected_class
Loading
Loading
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载