+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
" and Lincoln ended up running the store by himself.[60] Although the economy was booming, the business"
" struggled and went into debt, causing Lincoln to sell his share."
),
(""),
]


Expand All @@ -41,8 +42,9 @@ def check_embedder(embedder: Embedder, expected_dim: int):
assert len(new_docs) == len(docs)

for doc in new_docs:
assert doc.embedding is not None
assert len(doc.embedding) == expected_dim
if doc.text_representation != "":
assert doc.embedding is not None
assert len(doc.embedding) == expected_dim


def test_openai_embedding():
Expand Down
10 changes: 7 additions & 3 deletions lib/sycamore/sycamore/transforms/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def _pre_process_document(document: Document) -> str:
return document.text_representation if document.text_representation is not None else ""


def _text_representation_is_empty(doc: Document) -> bool:
return doc.text_representation is None or doc.text_representation.strip() == ""


class Embedder(ABC):
def __init__(
self,
Expand Down Expand Up @@ -188,14 +192,14 @@ def generate_embeddings(self, doc_batch: list[Document]) -> list[Document]:
text_to_embed = [
self.pre_process_document(doc).replace("\n", " ")
for doc in batch
if doc.text_representation is not None
if not _text_representation_is_empty(doc)
]

embeddings = self._client.embeddings.create(model=self.model_name, input=text_to_embed).data

i = 0
for doc in batch:
if doc.text_representation is not None:
if not _text_representation_is_empty(doc):
doc.embedding = embeddings[i].embedding
i += 1

Expand Down Expand Up @@ -272,7 +276,7 @@ def generate_embeddings(self, doc_batch: list[Document]) -> list[Document]:
client = boto3.client("bedrock-runtime")

for doc in doc_batch:
if doc.text_representation is not None:
if not _text_representation_is_empty(doc):
doc.embedding = self._generate_embedding(client, self.pre_process_document(doc))
return doc_batch

Expand Down
Loading
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载