-
Notifications
You must be signed in to change notification settings - Fork 65
Fix bugs in Unflattening Data #930
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5a7487b
54d2888
cce4500
35a7cca
850ffd5
7a06f24
b2c8f3d
3a43bbb
6661802
e11b0b1
eb16d42
d4f8041
eb5ca12
08bd1e0
5d360ed
b7e22c8
8f1ec4e
aabaab3
ee690a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,11 @@ | ||
from dataclasses import dataclass | ||
from typing import Callable, Iterator, Union, Iterable, Tuple, Any, Dict | ||
from sycamore.data import Document | ||
import json | ||
import string | ||
import random | ||
import math | ||
import numpy as np | ||
|
||
|
||
@dataclass | ||
|
@@ -29,25 +32,42 @@ def generate_random_string(length=8): | |
return "".join(random.choice(characters) for _ in range(length)) | ||
|
||
|
||
def filter_doc(obj, include): | ||
return {k: v for k, v in obj.__dict__.items() if k in include} | ||
def filter_doc(doc: Document, include): | ||
return {k: v for k, v in doc.items() if k in include} | ||
|
||
|
||
def check_dictionary_compatibility(dict1: dict[Any, Any], dict2: dict[Any, Any], ignore: list[str] = []): | ||
for k in dict1: | ||
if ignore and any(val in k for val in ignore): | ||
if not dict1.get(k) or (ignore and any(val in k for val in ignore)): | ||
continue | ||
if k not in dict2: | ||
return False | ||
if dict1[k] != dict2[k]: | ||
if dict1[k] != dict2[k] and (dict1[k] or dict2[k]): | ||
return False | ||
return True | ||
|
||
|
||
def compare_docs(doc1, doc2): | ||
def compare_docs(doc1: Document, doc2: Document): | ||
filtered_doc1 = filter_doc(doc1, DEFAULT_RECORD_PROPERTIES.keys()) | ||
filtered_doc2 = filter_doc(doc2, DEFAULT_RECORD_PROPERTIES.keys()) | ||
return filtered_doc1 == filtered_doc2 | ||
for key in filtered_doc1: | ||
if isinstance(filtered_doc1[key], (list, np.ndarray)) or isinstance(filtered_doc2.get(key), (list, np.ndarray)): | ||
assert len(filtered_doc1[key]) == len(filtered_doc2[key]) | ||
for item1, item2 in zip(filtered_doc1[key], filtered_doc2[key]): | ||
try: | ||
# Convert items to float for numerical comparison | ||
num1 = float(item1) | ||
num2 = float(item2) | ||
# Check if numbers are close within tolerance | ||
assert math.isclose(num1, num2, rel_tol=1e-5, abs_tol=1e-5) | ||
except (ValueError, TypeError): | ||
# If conversion to float fails, do direct comparison | ||
assert item1 == item2 | ||
elif isinstance(filtered_doc1[key], dict) and isinstance(filtered_doc2.get(key), dict): | ||
assert check_dictionary_compatibility(filtered_doc1[key], filtered_doc2.get(key)) | ||
else: | ||
assert filtered_doc1[key] == filtered_doc2.get(key) | ||
return True | ||
|
||
|
||
def _add_key_to_prefix(prefix, key, separator="."): | ||
|
@@ -88,49 +108,86 @@ def flatten_data( | |
return items | ||
|
||
|
||
def unflatten_data(data: dict[str, Any], separator: str = ".") -> dict[Any, Any]: | ||
result: dict[Any, Any] = {} | ||
def unflatten_data(data: dict[Any, Any], separator: str = ".") -> dict[Any, Any]: | ||
""" | ||
Unflattens a dictionary with keys that contain separators into a nested dictionary. The separator can be escaped, | ||
and if there are integer keys in the path, the result will be a list instead of a dictionary. | ||
""" | ||
|
||
def parse_key(key: str) -> list: | ||
# Handle escaped separator | ||
def split_key(key: str, separator: str = ".") -> list[str]: | ||
karanataryn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Splits the key by separator (which can be multiple characters), respecting escaped separators. | ||
""" | ||
parts = [] | ||
current = "" | ||
escape = False | ||
for char in key: | ||
if escape: | ||
if char == separator: | ||
current += separator | ||
i = 0 | ||
while i < len(key): | ||
if key[i] == "\\": | ||
# Escape character | ||
if i + 1 < len(key): | ||
current += key[i + 1] | ||
i += 2 | ||
else: | ||
current += "\\" + char | ||
escape = False | ||
elif char == "\\": | ||
escape = True | ||
elif char == separator: | ||
# Trailing backslash, treat it as literal backslash | ||
current += "\\" | ||
i += 1 | ||
elif key[i : i + len(separator)] == separator: | ||
# Found separator | ||
parts.append(current) | ||
current = "" | ||
i += len(separator) | ||
else: | ||
current += char | ||
current += key[i] | ||
i += 1 | ||
parts.append(current) | ||
return parts | ||
|
||
for key, value in data.items(): | ||
parts = parse_key(key) | ||
result: dict[Any, Any] = {} | ||
for flat_key, value in data.items(): | ||
parts = split_key(flat_key, separator) | ||
current = result | ||
for i, part in enumerate(parts): | ||
part_key: Union[str, int] = int(part) if part.isdigit() else part | ||
# Determine whether the key part is an integer (for list indices) | ||
key: Union[str, int] | ||
try: | ||
key = int(part) | ||
except ValueError: | ||
key = part | ||
Review thread on lines +152 to +155:

- could prob do [inline code missing from the extracted page]
- Part will also be a [inline code missing from the extracted page]
- `isnumeric` is a string method that tells you whether all the characters are digits, so you could do:

      if part.isnumeric():
          key = int(part)
      else:
          key = part

  I just like to avoid extraneous try/catches where possible. Maybe that's silly.
- Ah, that makes sense. I can do that.
- Wait, `isdigit` does the same thing. Why did we change this from the previous implementation?
- This is because both [remainder of the comment missing from the extracted page]
||
|
||
is_last = i == len(parts) - 1 | ||
|
||
if is_last: | ||
current[part_key] = value | ||
else: | ||
next_part_is_digit = parts[i + 1].isdigit() if i + 1 < len(parts) else False | ||
if part_key not in current: | ||
current[part_key] = [] if next_part_is_digit else {} | ||
current = current[part_key] | ||
# If current is a list and the next part is a digit, ensure proper length | ||
# Set the value at the deepest level | ||
if isinstance(current, list): | ||
if next_part_is_digit and len(current) <= int(parts[i + 1]): | ||
current.extend("" for _ in range(int(parts[i + 1]) - len(current) + 1)) | ||
# Ensure the list is big enough | ||
while len(current) <= key: | ||
current.append("") | ||
current[key] = value | ||
else: | ||
current[key] = value | ||
else: | ||
# Determine the type of the next part | ||
next_part = parts[i + 1] | ||
|
||
# Check if the next part is an index (integer) | ||
try: | ||
int(next_part) | ||
next_is_index = True | ||
except ValueError: | ||
next_is_index = False | ||
|
||
# Initialize containers as needed | ||
if isinstance(current, list): | ||
# Ensure the list is big enough | ||
while len(current) <= key: | ||
current.append("") | ||
if current[key] == "" or current[key] is None: | ||
current[key] = [] if next_is_index else {} | ||
current = current[key] | ||
else: | ||
if key not in current: | ||
current[key] = [] if next_is_index else {} | ||
current = current[key] | ||
return result | ||
|
||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.