+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/sycamore/sycamore/materialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(

self._maybe_anonymous()

def _maybe_anonymous(self):
def _maybe_anonymous(self) -> None:
if self._root is None:
return
from pyarrow.fs import S3FileSystem
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional
from pydantic import BaseModel
import sycamore
from sycamore.data.document import Document
from sycamore.data.document import Document, HierarchicalDocument
from sycamore.data.element import Element
from sycamore.llms.llms import LLM
from sycamore.reader import DocSetReader
Expand Down Expand Up @@ -74,7 +74,7 @@ async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict
}
"""

def test_entity_extractor(self):
def test_entity_extractor(self) -> None:
context = sycamore.init()
reader = DocSetReader(context)
ds = reader.document(self.docs)
Expand All @@ -83,10 +83,10 @@ class Company(BaseModel):
name: str

llm = self.MockLLM()
ds = ds.extract_document_structure(structure=StructureBySection).extract_graph_entities(
ds = ds.extract_document_structure(structure=StructureBySection()).extract_graph_entities(
[EntityExtractor(llm=llm, entities=[Company])]
)
docs = ds.take_all()
docs = [HierarchicalDocument(doc.data) for doc in ds.take_all()]

for doc in docs:
for section in doc.children:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional
from pydantic import BaseModel
import sycamore
from sycamore.data.document import Document
from sycamore.data.document import Document, HierarchicalDocument
from sycamore.data.element import Element
from sycamore.llms.llms import LLM
from sycamore.reader import DocSetReader
Expand Down Expand Up @@ -80,8 +80,8 @@ class MockRelationshipLLM(LLM):
def __init__(self):
super().__init__(model_name="mock_model")

def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None):
pass
def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str:
return ""

def is_chat_mode(self):
return True
Expand All @@ -94,7 +94,7 @@ async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict
}
"""

def test_relationship_extractor(self):
def test_relationship_extractor(self) -> None:
context = sycamore.init()
reader = DocSetReader(context)
ds = reader.document(self.docs)
Expand All @@ -107,11 +107,11 @@ class Competes(BaseModel):
end: Company

ds = (
ds.extract_document_structure(structure=StructureBySection)
ds.extract_document_structure(structure=StructureBySection())
.extract_graph_entities([EntityExtractor(self.MockEntityLLM(), [Company])])
.extract_graph_relationships([RelationshipExtractor(self.MockRelationshipLLM(), [Competes])])
)
docs = ds.take_all()
docs = [HierarchicalDocument(doc.data) for doc in ds.take_all()]

for doc in docs:
for section in doc.children:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional
from pydantic import BaseModel
import sycamore
from sycamore.data.document import Document
from sycamore.data.document import Document, HierarchicalDocument
from sycamore.data.element import Element
from sycamore.llms.llms import LLM
from sycamore.reader import DocSetReader
Expand Down Expand Up @@ -93,7 +93,7 @@ async def generate_async(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict
}
"""

def test_resolve_entities(self):
def test_resolve_entities(self) -> None:
context = sycamore.init()
reader = DocSetReader(context)
ds = reader.document(self.docs)
Expand All @@ -106,12 +106,12 @@ class Competes(BaseModel):
end: Company

ds = (
ds.extract_document_structure(structure=StructureBySection)
ds.extract_document_structure(structure=StructureBySection())
.extract_graph_entities([EntityExtractor(self.MockEntityLLM(), [Company])])
.extract_graph_relationships([RelationshipExtractor(self.MockRelationshipLLM(), [Competes])])
.resolve_graph_entities(resolvers=[], resolve_duplicates=True)
)
docs = ds.take_all()
docs = [HierarchicalDocument(doc.data) for doc in ds.take_all()]

for doc in docs:
if doc.data.get("EXTRACTED_NODES", False) is True:
Expand Down
14 changes: 4 additions & 10 deletions lib/sycamore/sycamore/tests/unit/utils/test_import_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import TYPE_CHECKING
from mypy import api
from typing_extensions import assert_type
from sycamore.utils.import_utils import requires_modules


Expand All @@ -9,13 +8,8 @@ def require_fn() -> int:
return 42


res = require_fn()
if TYPE_CHECKING:
reveal_type(res) # noqa: F821


# This test fails prior to adding generic (ParamSpec and TypeVar) type annotations to the
# requires_modules decorator, as the revealed type is "Any".
def test_mypy_type():
mypy_res = api.run([__file__])
assert 'Revealed type is "builtins.int"' in mypy_res[0]
def test_mypy_type() -> None:
res = require_fn()
assert_type(res, int)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`typing.assert_type` was only added in Python 3.11, so I would expect these unit tests to fail on Python 3.9 and 3.10.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out this works on Python 3.9 and 3.10 as well if we import `assert_type` from `typing_extensions` instead of `typing`!

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import BaseModel, Field


def test_pydantic_picklng():
def test_pydantic_picklng() -> None:
class BoardMember(BaseModel):
name: str
votes_for: Optional[int]
Expand Down
4 changes: 2 additions & 2 deletions lib/sycamore/sycamore/transforms/extract_graph_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import defaultdict
import hashlib
import io
from typing import Dict, Any, List
from typing import Dict, Any, List, Type
from sycamore.plan_nodes import Node
from sycamore.transforms.map import Map
from sycamore.data import HierarchicalDocument
Expand Down Expand Up @@ -43,7 +43,7 @@ class EntityExtractor(GraphEntityExtractor):
def __init__(
self,
llm: LLM,
entities: list[BaseModel],
entities: List[Type[BaseModel]],
prompt: str = GraphEntityExtractorPrompt.user,
split_calls: bool = False,
):
Expand Down
17 changes: 9 additions & 8 deletions lib/sycamore/sycamore/transforms/extract_graph_relationships.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
from abc import ABC, abstractmethod
import asyncio
import base64
from collections import defaultdict
from enum import Enum
import hashlib
import json
import io
import logging
from typing import Dict, Any, List, Type
import uuid

from PIL import Image
from typing import Dict, Any, List
from pydantic import BaseModel, create_model

from sycamore.plan_nodes import Node
from sycamore.transforms.map import Map
from sycamore.data import HierarchicalDocument
from sycamore.llms import LLM
from sycamore.llms.prompts import GraphRelationshipExtractorPrompt
from pydantic import BaseModel, create_model
import asyncio

import json
import uuid
import logging

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -45,7 +46,7 @@ class RelationshipExtractor(GraphRelationshipExtractor):
def __init__(
self,
llm: LLM,
relationships: list[BaseModel],
relationships: List[Type[BaseModel]],
prompt: str = GraphRelationshipExtractorPrompt.user,
split_calls: bool = False,
):
Expand Down
Loading
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载