+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions lib/aryn-sdk/aryn_sdk/partition/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,13 @@ def partition_file(
default: English
extract_table_structure: extract tables and their structural content.
default: False
table_extraction_options: Specify options for table extraction, currently only supports boolean
'include_additional_text': if table extraction is enabled, attempt to enhance the table
structure by merging in tokens from text extraction. This can be useful for tables with missing
or misaligned text, and is False by default.
default: {}
table_extraction_options: Specify options for table extraction. Only enabled if table extraction
is enabled. Default is {}. Options:
- 'include_additional_text': Attempt to enhance the table structure by merging in tokens from
text extraction. This can be useful for tables with missing or misaligned text. Default: False
- 'model_selection': expression to instruct DocParse how to choose which model to use for table
extraction. See https://docs.aryn.ai/docparse/processing_options for more details. Default:
"pixels > 500 -> deformable_detr; table_transformer"
extract_images: extract image contents in ppm format, base64 encoded.
default: False
selected_pages: list of individual pages (1-indexed) from the pdf to partition
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from PIL import Image
from sycamore.data.bbox import BoundingBox
from sycamore.data.element import TableElement
Expand All @@ -7,6 +8,8 @@
DeformableTableStructureExtractor,
)

# Short alias for HybridTableStructureExtractor, used throughout TestHybridSelectionStatements.
HTSE = HybridTableStructureExtractor


class TestTableExtractors:

Expand All @@ -20,25 +23,111 @@ def mock_doc_image(mocker, width, height):
def mock_table_element(mocker, width, height):
    """Return a Mock TableElement with bbox (0, 0, width, height) and one 15-character text token.

    The tests pass fractional width/height, so the bbox is presumably page-relative
    (``_pick_model`` converts it with ``to_absolute``).
    """
    table = mocker.Mock(spec=TableElement)
    table.bbox = BoundingBox(0, 0, width, height)
    # A single non-empty token so the "chars" metric computed by _pick_model is non-zero.
    table.tokens = [{"text": "alrngea;rjgnekl"}]
    return table

def test_hybrid_routing_both_gt500(self, mocker):
    """A 0.7 x 0.7 table on a 1000 x 1000 mock page exceeds 500 px in both dimensions -> Deformable DETR."""
    im = TestTableExtractors.mock_doc_image(mocker, 1000, 1000)
    elt = TestTableExtractors.mock_table_element(mocker, 0.7, 0.7)
    extractor = HybridTableStructureExtractor(deformable_model="dont initialize me")
    # _pick_model requires an explicit model_selection expression.
    chosen = extractor._pick_model(elt, im, model_selection="pixels>500->deformable_detr;table_transformer")
    # Exact type check (not isinstance), in case one extractor class subclasses the other.
    assert type(chosen) is DeformableTableStructureExtractor

def test_hybrid_routing_one_gt500(self, mocker):
    """A 0.7 x 0.2 table on a 1000 x 1000 mock page exceeds 500 px in one dimension -> Deformable DETR."""
    im = TestTableExtractors.mock_doc_image(mocker, 1000, 1000)
    elt = TestTableExtractors.mock_table_element(mocker, 0.7, 0.2)
    extractor = HybridTableStructureExtractor(deformable_model="dont initialize me")
    # The "pixels" metric uses the larger table dimension, so one big side is enough.
    chosen = extractor._pick_model(elt, im, model_selection="pixels>500->deformable_detr;table_transformer")
    # Exact type check (not isinstance), in case one extractor class subclasses the other.
    assert type(chosen) is DeformableTableStructureExtractor

def test_hybrid_routing_neither_gt500(self, mocker):
    """A 0.2 x 0.2 table on a 1000 x 1000 mock page stays under 500 px -> falls through to TATR."""
    im = TestTableExtractors.mock_doc_image(mocker, 1000, 1000)
    elt = TestTableExtractors.mock_table_element(mocker, 0.2, 0.2)
    extractor = HybridTableStructureExtractor(deformable_model="dont initialize me")
    chosen = extractor._pick_model(elt, im, model_selection="pixels>500->deformable_detr;table_transformer")
    # Exact type check: isinstance could also accept a subclass of the TATR extractor.
    assert type(chosen) is TableTransformerStructureExtractor


class TestHybridSelectionStatements:
    """Tests for ``HybridTableStructureExtractor.parse_model_selection``.

    ``params`` holds ``(pixels, chars)`` pairs that are fed to the parsed selector.
    """

    params = [(1000, 25), (25, 1000), (25, 25), (1000, 1000)]

    def _select_each(self, expression):
        """Parse ``expression`` and evaluate the resulting selector on every entry of ``params``."""
        selector = HybridTableStructureExtractor.parse_model_selection(expression)
        return [selector(pixels, chars) for pixels, chars in self.params]

    def test_static(self):
        cases = (
            ("table_transformer", "table_transformer"),
            ("deformable_detr ", "deformable_detr"),
            ("deformable_detr; this is a comment", "deformable_detr"),
        )
        for expression, expected in cases:
            # An unconditional statement picks the same model regardless of the metrics.
            assert self._select_each(expression) == [expected] * len(self.params)

    def test_pixelmetric(self):
        got = self._select_each("pixels > 500 -> deformable_detr; table_transformer")
        assert got == ["deformable_detr", "table_transformer", "table_transformer", "deformable_detr"]

    def test_charmetric(self):
        got = self._select_each("chars > 500 -> deformable_detr; table_transformer")
        assert got == ["table_transformer", "deformable_detr", "table_transformer", "deformable_detr"]

    def test_bad_modelname(self):
        cases = (
            (r"Invalid statement.* model_name was not in.*", "tatr"),
            (r"Invalid statement.* Result model .* was not in.*", "pixels>500 -> nonmodel"),
            (r"Invalid statement.* model_name was not in.*", "pixels>500 -> deformable_detr; yo_mama"),
        )
        for pattern, expression in cases:
            with pytest.raises(ValueError, match=pattern):
                HTSE.parse_model_selection(expression)

    def test_multiple_arrows(self):
        with pytest.raises(ValueError, match=r"Invalid statement.* Found more than 2 instances of '->'"):
            HTSE.parse_model_selection("pixels>500 -> vrooooom -> vrooooooooooom")

    def test_no_comparison(self):
        for expression in ("pixels=3->deformable_detr", "chars->deformable_detr"):
            with pytest.raises(ValueError, match=r"Invalid statement.* Did not find a comparison operator .*"):
                HTSE.parse_model_selection(expression)

    def test_multiple_comparisons(self):
        with pytest.raises(ValueError, match=r"Invalid comparison.* Comparison statements must take the form .*"):
            HTSE.parse_model_selection("1000 > pixels > 300 -> deformable_detr")

    def test_backwards_comparison(self):
        with pytest.raises(ValueError, match=r"Invalid comparison.* Allowed metrics are.*"):
            HTSE.parse_model_selection("1000 > pixels -> table_transformer")

    def test_bad_metric(self):
        for expression in ("pickles > 1000 -> table_transformer", "charm < 5 -> deformable_detr"):
            with pytest.raises(ValueError, match=r"Invalid comparison.* Allowed metrics are.*"):
                HTSE.parse_model_selection(expression)

    def test_bad_threshold(self):
        with pytest.raises(ValueError, match=r"Invalid comparison.* Threshold .* must be numeric"):
            HTSE.parse_model_selection("pixels > chars -> table_transformer")

    def test_complicated(self):
        selector = HTSE.parse_model_selection(
            "pixels>5->table_transformer; chars<30->deformable_detr;chars>35->table_transformer;"
            "pixels>2->deformable_detr;table_transformer;comment"
        )
        # ((pixels, chars), expected model): statements are evaluated left to right.
        expectations = [
            ((10, 14), "table_transformer"),
            ((5, 15), "deformable_detr"),
            ((5, 42), "table_transformer"),
            ((5, 32), "deformable_detr"),
            ((0, 32), "table_transformer"),
        ]
        for (pixels, chars), expected in expectations:
            assert selector(pixels, chars) == expected

    def test_excess_semicolons_ok(self):
        for expression in ("chars>0->table_transformer;", ";;;chars>0->table_transformer;"):
            selector = HTSE.parse_model_selection(expression)
            assert selector(10, 10) == "table_transformer"
182 changes: 173 additions & 9 deletions lib/sycamore/sycamore/transforms/table_structure/extract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Any, Union
from typing import Any, Union, Optional, Callable

from PIL import Image
import pdf2image
Expand All @@ -14,6 +14,8 @@
from sycamore.utils import choose_device
from sycamore.utils.import_utils import requires_modules

# Numeric type used for comparison operands and thresholds in model-selection expressions
# (see HybridTableStructureExtractor.get_cmp_fn and parse_comparison).
Num = Union[float, int]


class TableStructureExtractor:
"""Interface for extracting table structure."""
Expand Down Expand Up @@ -235,6 +237,17 @@ class HybridTableStructureExtractor(TableStructureExtractor):
"""A TableStructureExtractor implementation that conditionally uses either Deformable or TATR
depending on the size of the table"""

# Model names allowed as the result of a selection statement (checked in parse_model_selection).
_model_names = ("table_transformer", "deformable_detr")
# Metric names allowed on the left-hand side of a comparison (checked in parse_comparison).
_metrics = ("pixels", "chars")
# Comparison operators recognised in selection statements. parse_comparison tries longer
# operators first, so ">=" is matched before ">".
_comparisons = (
"==",
"<=",
">=",
"!=",
"<",
">",
)

def __init__(
self,
deformable_model: str,
Expand All @@ -245,25 +258,47 @@ def __init__(
self._tatr = TableTransformerStructureExtractor(tatr_model, device)

def _pick_model(
    self,
    element: TableElement,
    doc_image: Image.Image,
    model_selection: str,
) -> Union[TableTransformerStructureExtractor, DeformableTableStructureExtractor]:
    """Choose the structure model for ``element`` by evaluating a model-selection expression.

    The expression is parsed with ``parse_model_selection`` and evaluated against two metrics:

    - pixels: the larger of the table's absolute width/height, plus 10 px of padding on each side.
    - chars: the total length of the ``"text"`` of every token in ``element.tokens`` (0 if none).

    Args:
        element: the table element whose structure will be extracted.
        doc_image: the page image; its size converts the element bbox to absolute pixels.
        model_selection: the selection expression (see ``parse_model_selection``).

    Returns:
        The Deformable DETR extractor when the expression selects "deformable_detr"; otherwise
        the TATR extractor (expression selects "table_transformer", no statement matches, or the
        element has no bbox).

    Raises:
        ValueError: if ``model_selection`` is malformed. It is only parsed when the element
            has a bbox.
    """
    if element.bbox is None:
        return self._tatr

    select_fn = self.parse_model_selection(model_selection)

    width, height = doc_image.size
    bb = element.bbox.to_absolute(width, height)
    padding = 10
    max_dim = max(bb.width, bb.height) + 2 * padding

    # Tables without tokens simply have zero characters.
    nchars = sum(len(tok["text"]) for tok in (element.tokens or []))

    selection = select_fn(max_dim, nchars)
    if selection == "deformable_detr":
        return self._deformable
    if selection == "table_transformer" or selection is None:
        # An expression with no matching statement falls back to table transformer.
        return self._tatr
    raise ValueError(f"Somehow we got an invalid selection: {selection}. This should be unreachable.")

def _init_structure_model(self):
    """Initialize both wrapped extractors' structure models: Deformable DETR first, then TATR."""
    for extractor in (self._deformable, self._tatr):
        extractor._init_structure_model()

def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=False) -> TableElement:
def extract(
self,
element: TableElement,
doc_image: Image.Image,
union_tokens=False,
model_selection: str = "pixels > 500 -> deformable_detr; table_transformer",
) -> TableElement:
"""Extracts the table structure from the specified element using a either a DeformableDETR or
TATR model, depending on the size of the table.

Expand All @@ -276,9 +311,138 @@ def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=Fa
Used for bounding box calculations.
union_tokens: Make sure that ocr/pdfminer tokens are _all_ included in the table.
apply_thresholds: Apply class thresholds to the objects output by the model.
model_selection: Control which model gets selected. See ``parse_model_selection`` for
expression syntax. Default is "pixels > 500 -> deformable_detr; table_transformer".
If no statements are matched, defaults to table transformer.
"""
m = self._pick_model(element, doc_image, model_selection)
return m.extract(element, doc_image, union_tokens)

@classmethod
def parse_model_selection(cls, selection: str) -> Callable[[float, int], Optional[str]]:
    """
    Parse a model selection expression. Model selection expressions are of the form:
        "metric cmp threshold -> model; metric cmp threshold -> model; model;"
    That is, any number of conditional expression selections followed by up to one unconditional
    selection expression, separated by semicolons. Expressions are processed from left to right.
    Anything after the unconditional expression is not processed (or validated). Empty statements,
    e.g. from leading, trailing or repeated semicolons, are ignored.

    - Supported metrics are "pixels" - the number of pixels in the larger dimension of the table (we
      find this to be easier to reason about than the total number of pixels which depends on two numbers),
      and "chars" - the number of characters in the table, as detected by the partitioner's text_extractor.
    - Supported comparisons are the usual set - <, >, <=, >=, ==, !=.
    - A conditional statement must read "metric cmp threshold -> model": the metric on the left of
      the operator, and exactly one "->".
    - The threshold must be numeric (an int or a float).
    - The model must be either "deformable_detr" or "table_transformer".

    Args:
        selection: the selection string.

    Returns:
        a function ``f(pixels, chars)`` that returns the model name of the first matching statement,
        or None if no statement matches.

    Raises:
        ValueError: if a processed statement is malformed (unknown model or metric, missing operator,
            more than one "->", or a non-numeric threshold).

    Examples:
        - `"table_transformer"` => always use table transformer
        - `"pixels > 500 -> deformable_detr; table_transformer"` => if the biggest dimension of
          the table is greater than 500 pixels use deformable detr. Otherwise use table_transformer.
        - `"pixels>50->table_transformer; chars<30->deformable_detr;chars>35->table_transformer;pixels>2->deformable_detr;table_transformer;comment"`
          => if the biggest dimension is more than 50 pixels use table transformer. Else if the total number of chars in the table is less than
          30 use deformable_detr. Else if there are more than 35 chars use table transformer. Else if there are more than 2 pixels in the biggest
          dimension use deformable detr. Otherwise use table transformer. comment is not processed.
    """  # noqa: E501 # line too long. long line is a long example. I want it that way.

    def make_static(model):
        # Bind ``model`` eagerly; a lambda closing over the loop variable would see later values.
        def check(pixels: float, chars: int) -> Optional[str]:
            return model

        return check

    def make_check(metric, compare, threshold, result):
        # Factory so each check captures its own metric/compare/threshold/result values.
        def check(pixels: float, chars: int) -> Optional[str]:
            cmpval = pixels if metric == "pixels" else chars
            if compare(cmpval, threshold):
                return result
            return None

        return check

    checks = []
    for statement in selection.split(sep=";"):
        statement = statement.strip()
        if statement == "":
            continue
        if "->" not in statement:
            # Unconditional statement: must be a bare model name, and it ends processing.
            if statement not in cls._model_names:
                raise ValueError(
                    f"Invalid statement: '{statement}'. Did not find '->', so this is assumed"
                    f" to be a static statement, but the model_name was not in {cls._model_names}"
                )
            checks.append(make_static(statement))
            break
        pieces = statement.split(sep="->")
        if len(pieces) > 2:
            # Two or more '->' in one statement. The message wording is kept as-is because tests match on it.
            raise ValueError(f"Invalid statement: '{statement}'. Found more than 2 instances of '->'")
        result = pieces[1].strip()
        if result not in cls._model_names:
            raise ValueError(
                f"Invalid statement: '{statement}'. Result model ({result}) was not in {cls._model_names}"
            )
        if all(c not in pieces[0] for c in cls._comparisons):
            raise ValueError(
                f"Invalid statement: '{statement}'. Did not find a comparison operator {cls._comparisons}"
            )
        metric, cmp, threshold = cls.parse_comparison(pieces[0])
        checks.append(make_check(metric, cmp, threshold, result))

    def select_fn(pixels: float, chars: int) -> Optional[str]:
        # First matching statement wins.
        for check in checks:
            if (rv := check(pixels, chars)) is not None:
                return rv
        return None

    return select_fn

@staticmethod
def get_cmp_fn(opstring: str) -> Callable[[Num, Num], bool]:
    """Map a comparison operator string to a two-argument predicate.

    Args:
        opstring: one of "==", "!=", "<", "<=", ">", ">=".

    Returns:
        a callable ``f(a, b)`` that evaluates ``a <op> b``.

    Raises:
        ValueError: if ``opstring`` is not a supported operator.
    """
    dispatch = {
        "==": lambda lhs, rhs: lhs == rhs,
        "!=": lambda lhs, rhs: lhs != rhs,
        "<": lambda lhs, rhs: lhs < rhs,
        "<=": lambda lhs, rhs: lhs <= rhs,
        ">": lambda lhs, rhs: lhs > rhs,
        ">=": lambda lhs, rhs: lhs >= rhs,
    }
    predicate = dispatch.get(opstring)
    if predicate is None:
        raise ValueError(f"Invalid comparison: Unsupported operator '{opstring}'")
    return predicate

@classmethod
def parse_comparison(cls, comparison: str) -> tuple[str, Callable[[Num, Num], bool], Union[int, float]]:
    """Split a ``"METRIC CMP THRESHOLD"`` string into its metric, predicate and numeric threshold.

    Args:
        comparison: the text to the left of "->" in a selection statement.

    Returns:
        ``(metric, compare, threshold)``. ``metric`` is one of ``cls._metrics``, ``compare`` comes
        from ``get_cmp_fn``, and ``threshold`` is an int if it parses as one, otherwise a float.

    Raises:
        ValueError: if no operator is found, the operator appears more than once, the metric is
            unknown (including a reversed "THRESHOLD CMP METRIC"), or the threshold is not numeric.
    """
    # Longest operators first, so that "<=" / ">=" are not mistaken for "<" / ">".
    longest_first = sorted(cls._comparisons, key=len, reverse=True)
    op_token = next((op for op in longest_first if op in comparison), None)
    if op_token is None:
        raise ValueError(
            f"Invalid comparison: '{comparison}'. Did not find comparison operator {cls._comparisons}."
        )
    compare = cls.get_cmp_fn(op_token)
    sides = comparison.split(sep=op_token)
    if len(sides) != 2:
        raise ValueError(
            f"Invalid comparison: '{comparison}'. Comparison statements must take the form 'METRIC CMP THRESHOLD'."
        )
    metric = sides[0].strip()
    threshold = sides[1].strip()
    if metric not in cls._metrics:
        raise ValueError(f"Invalid comparison: '{comparison}'. Allowed metrics are: '{cls._metrics}'")
    try:
        value: Num = int(threshold)
    except ValueError:
        try:
            value = float(threshold)
        except ValueError:
            raise ValueError(f"Invalid comparison: '{comparison}'. Threshold ({threshold}) must be numeric")
    return metric, compare, value


# Module-level default extractor: the TATR-based extractor *class* (not an instance).
# NOTE(review): presumably used by callers that do not pick an extractor explicitly — confirm at call sites.
DEFAULT_TABLE_STRUCTURE_EXTRACTOR = TableTransformerStructureExtractor
Expand Down
Loading
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载