+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions lib/sycamore/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

79 changes: 79 additions & 0 deletions lib/sycamore/sycamore/evaluation/ocr/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from argparse import ArgumentParser

from ray.data import ActorPoolStrategy
import sycamore
from sycamore.evaluation.ocr.models import (
PaddleOCR,
EasyOCR,
Tesseract,
Textract,
LLMOCR,
DocTR,
RapidOCR,
ExtractOCRFromImage,
)
from sycamore.evaluation.ocr.metrics import (
CharacterErrorRate,
WordErrorRate,
MatchErrorRate,
WordInformationLost,
apply_metric,
)
import time
import json
from sycamore.evaluation.ocr.data import BaseOCREvalScan, InvoiceOCREvalScan, HandwritingOCREvalScan

# Dataset scan classes, keyed by the name accepted for the `dataset` CLI argument.
DATASETS = dict(
    base=BaseOCREvalScan,
    handwriting=HandwritingOCREvalScan,
    invoice=InvoiceOCREvalScan,
)

# OCR model classes, keyed by the name accepted for the `model` CLI argument.
MODELS = dict(
    easy=EasyOCR,
    tesseract=Tesseract,
    textract=Textract,
    paddle=PaddleOCR,
    llm=LLMOCR,
    doctr=DocTR,
    rapid=RapidOCR,
)

# One instance of every metric, applied to each document and then aggregated.
METRICS = [
    metric_class()
    for metric_class in (CharacterErrorRate, MatchErrorRate, WordErrorRate, WordInformationLost)
]

# Compute strategy handed to `map_batch` for the OCR extraction stage.
model_actorpool = ActorPoolStrategy(size=2)
# NOTE(review): not referenced by the visible evaluation loop; confirm whether
# it is meant to be forwarded to the model constructors.
model_kwargs = dict(device="mps")

# Command-line interface: two optional positionals select a dataset and a
# model by name, `--limit` caps the number of values processed, and `--debug`
# shrinks the run to a single value.
parser = ArgumentParser()
parser.add_argument("dataset", nargs="?", choices=list(DATASETS), help="dataset to evaluate")
parser.add_argument("model", nargs="?", choices=list(MODELS), help="OCR Model to use")
parser.add_argument("--debug", required=False, action="store_true")
parser.add_argument("--limit", type=int, default=1000, help="A limit on the number of values to run")
args = parser.parse_args()

# An omitted positional parses as None, which is not a key of either registry,
# so `.get` falls back to the default class.
# NOTE(review): `dataset` and `model` are resolved here, but the evaluation
# loop iterates over every entry of DATASETS and MODELS; confirm whether these
# selections are meant to restrict the run.
dataset = DATASETS.get(args.dataset, BaseOCREvalScan)
model = MODELS.get(args.model, EasyOCR)

# Use an explicit int in debug mode rather than reusing the boolean flag
# itself as the limit value.
limit = 1 if args.debug else args.limit

# Aggregated results, nested as dataset name -> model name -> metric name -> value.
all_results: dict[str, dict[str, dict[str, float]]] = {}
for dataset_name, dataset_class in DATASETS.items():
    ctx = sycamore.init()

    # Dataset-level pipeline shared by every model evaluated on this dataset.
    pipeline = dataset_class().to_docset(ctx)  # type: ignore
    if dataset_name == "base":
        # Keep only the documents that carry no "index" entry.
        pipeline = pipeline.filter(lambda doc: "index" not in doc.data)  # type: ignore
    all_results[dataset_name] = {}
    for model_name, model_class in MODELS.items():
        print(f"Running evaluation for dataset '{dataset_name}' and model '{model_name}'")

        curr_time = time.time()
        # Derive a fresh pipeline per model from the dataset pipeline. Rebinding
        # `pipeline` itself would stack each model's OCR and metric stages on top
        # of the previous model's, re-running earlier models and inflating latency.
        model_pipeline = pipeline.limit(limit)
        model_pipeline = model_pipeline.map_batch(ExtractOCRFromImage(model_class()), compute=model_actorpool)
        for m in METRICS:
            model_pipeline = model_pipeline.map(apply_metric(m))
        aggs = model_pipeline.plan.execute().aggregate(*[m.to_aggregate_fn() for m in METRICS])
        # Wall-clock time for building and executing this model's pipeline.
        aggs["latency"] = time.time() - curr_time

        print("=" * 80)
        print(aggs)
        all_results[dataset_name][model_name] = aggs

# Write the results to a JSON file
with open("ocr_evaluation_results.json", "w", encoding="utf-8") as f:
    json.dump(all_results, f, indent=4)
131 changes: 131 additions & 0 deletions lib/sycamore/sycamore/evaluation/ocr/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from sycamore.data import Document
from PIL import Image
from datasets import load_dataset
from datasets.load import IterableDataset
from ray.data import from_huggingface, Dataset
from typing import Optional
import io
from sycamore.docset import DocSet
from sycamore.context import Context
from sycamore.plan_nodes import Scan
import json
import math
import ast


class OCREvalDocument(Document):
    """Document holding one OCR evaluation sample.

    Ground-truth text, predicted text, per-document metrics and the encoded
    image are all stored in ``self.data`` under the keys ``"gt_text"``,
    ``"pred_text"``, ``"metrics"`` and ``"image"``.
    """

    def __init__(self, document=None, /, **kwargs):
        """Initialise the document by delegating to ``Document.__init__``."""
        super().__init__(document, **kwargs)

    @property
    def gt_text(self):
        """Return the ground-truth text, or ``None`` if it has not been set."""
        return self.data.get("gt_text")

    @gt_text.setter
    def gt_text(self, gt_text: str):
        """Store the ground-truth text."""
        self.data["gt_text"] = gt_text

    @property
    def pred_text(self):
        """Return the predicted text, or ``None`` if it has not been set."""
        return self.data.get("pred_text")

    @pred_text.setter
    def pred_text(self, pred_text: str):
        """Store the predicted text."""
        self.data["pred_text"] = pred_text

    @property
    def metrics(self) -> dict:
        """Return the dictionary of evaluation metrics, creating it on first access."""
        if "metrics" not in self.data:
            self.data["metrics"] = {}
        return self.data["metrics"]

    @property
    def image(self) -> Optional[Image.Image]:
        """Return the stored image decoded from its bytes, or ``None`` if absent."""
        if "image" in self.data:
            imbytes = self.data["image"]
            return Image.open(io.BytesIO(imbytes))
        return None

    @image.setter
    def image(self, im: Image.Image):
        """Encode ``im`` as PNG and store the resulting bytes."""
        buf = io.BytesIO()
        im.save(buf, format="png")
        self.data["image"] = buf.getvalue()


class OCREvalScan(Scan):
    """Shared base class for the OCR evaluation scan nodes defined in this module."""

    def format(self):  # type: ignore
        """Return the fixed format label of this scan."""
        return "OCREvalScan"

    def to_docset(self, context: Context) -> DocSet:
        """Wrap this scan node in a ``DocSet`` bound to ``context``."""
        docset = DocSet(context, self)
        return docset


class InvoiceOCREvalScan(OCREvalScan):
    """Scan over the ``mychen76/invoices-and-receipts_ocr_v1`` Hugging Face dataset."""

    @staticmethod
    def _ray_row_to_document(row) -> dict[str, bytes]:
        """Convert one dataset row into a serialized ``OCREvalDocument``.

        The image comes from ``row["image"]["bytes"]`` converted to RGB. The
        ground truth is the ``"ocr_words"`` entry of the JSON in
        ``row["raw_data"]``, parsed with ``ast.literal_eval`` and joined with
        single spaces.
        """
        encoded_image = row["image"]["bytes"]
        picture = Image.open(io.BytesIO(encoded_image)).convert("RGB")

        document = OCREvalDocument()
        document.image = picture

        raw_record = json.loads(row["raw_data"])
        words = ast.literal_eval(raw_record["ocr_words"])
        document.gt_text = " ".join(words)

        return {"doc": document.serialize()}

    def execute(self, **kwargs) -> Dataset:
        """Stream the train split and map every row to a serialized document."""
        stream = load_dataset("mychen76/invoices-and-receipts_ocr_v1", split="train", streaming=True)
        assert isinstance(stream, IterableDataset)
        return from_huggingface(stream).map(InvoiceOCREvalScan._ray_row_to_document)


class HandwritingOCREvalScan(OCREvalScan):
    """Scan over the ``corto-ai/handwritten-text`` Hugging Face dataset."""

    @staticmethod
    def _ray_row_to_document(row) -> dict[str, bytes]:
        """Convert one dataset row into a serialized ``OCREvalDocument``.

        The image comes from ``row["image"]["bytes"]`` converted to RGB, and
        the ground truth is taken verbatim from ``row["text"]``.
        """
        encoded_image = row["image"]["bytes"]
        picture = Image.open(io.BytesIO(encoded_image)).convert("RGB")

        document = OCREvalDocument()
        document.image = picture
        document.gt_text = row["text"]

        return {"doc": document.serialize()}

    def execute(self, **kwargs) -> Dataset:
        """Stream the train split and map every row to a serialized document."""
        stream = load_dataset("corto-ai/handwritten-text", split="train", streaming=True)
        assert isinstance(stream, IterableDataset)
        return from_huggingface(stream).map(HandwritingOCREvalScan._ray_row_to_document)


class BaseOCREvalScan(OCREvalScan):
    """Scan over the ``deoxykev/short_ocr_sentences`` Hugging Face dataset."""

    @staticmethod
    def _ray_row_to_document(row) -> dict[str, bytes]:
        """Convert one dataset row into a serialized ``OCREvalDocument``.

        The image comes from ``row["cropped_image"]["bytes"]`` converted to
        RGB. The ground truth is ``row["answer"]``, concatenated without a
        separator when it is a list. ``row["__index_level_0__"]`` is copied to
        the document's ``"index"`` entry when it is neither ``None`` nor NaN.
        """
        img = Image.open(io.BytesIO(row["cropped_image"]["bytes"])).convert("RGB")
        eval_doc = OCREvalDocument()
        eval_doc.image = img

        answer = row["answer"]
        eval_doc.gt_text = "".join(answer) if isinstance(answer, list) else answer

        # Test for "missing" explicitly: a plain truthiness check would also
        # discard a legitimate index of 0.
        index = row["__index_level_0__"]
        if index is not None and not math.isnan(index):
            eval_doc.data["index"] = index
        return {"doc": eval_doc.serialize()}

    def execute(self, **kwargs) -> Dataset:
        """Stream the train split and map every row to a serialized document."""
        hf_ds = load_dataset("deoxykev/short_ocr_sentences", split="train", streaming=True)
        # Narrows the type for static checkers; streaming mode is requested above.
        assert isinstance(hf_ds, IterableDataset)
        ray_ds = from_huggingface(hf_ds)
        return ray_ds.map(BaseOCREvalScan._ray_row_to_document)
Loading
Loading
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载