+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion lib/sycamore/sycamore/tests/unit/transforms/test_embed.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pickle
import pytest
import ray.data

from sycamore.data import Document
from sycamore.plan_nodes import Node
from sycamore.transforms import Embed
from sycamore.transforms.embed import SentenceTransformerEmbedder
from sycamore.transforms.embed import OpenAIEmbedder, SentenceTransformerEmbedder


class TestEmbedding:
@@ -77,3 +78,10 @@ def test_sentence_transformer_embedding(self, mocker):
input_dataset.show()
output_dataset = embedding.execute()
output_dataset.show()

def test_openai_embedder_pickle(self):
    """OpenAIEmbedder must survive a pickle round trip once a client is attached.

    ``OpenAIEmbedder.__getstate__`` drops the live ``_client``. This checks that:
    the round trip succeeds, the client is reset on the copy, configuration is
    preserved, and pickling does not mutate the original embedder.
    """
    obj = OpenAIEmbedder()
    client = obj.client_wrapper.get_client()
    # Simulate an embedder that has already been used, i.e. has a client attached.
    obj._client = client

    restored = pickle.loads(pickle.dumps(obj))

    # The client is not carried across pickling; the copy starts without one.
    assert restored._client is None
    # Plain configuration attributes are restored from the pickled state.
    assert restored.model_name == obj.model_name
    # __getstate__ works on a copy of __dict__, so the original keeps its client.
    assert obj._client is client
8 changes: 8 additions & 0 deletions lib/sycamore/sycamore/transforms/embed.py
Original file line number Diff line number Diff line change
@@ -177,6 +177,14 @@ def __init__(
self._client: Optional[OpenAIClient] = None
self.model_name = model_name

def __getstate__(self):
    """Return the state to pickle: the instance attributes, minus the live client.

    The result is a new dict (a shallow copy of ``__dict__``) in which
    ``_client`` is forced to ``None``. The instance itself is not modified, so
    an embedder keeps working after it has been pickled.
    """
    # Unpack into a fresh dict so the running object's client is left in place.
    return {**self.__dict__, "_client": None}

def __setstate__(self, state):
    """Restore an unpickled instance by merging ``state`` into its ``__dict__``.

    ``state`` is the mapping produced by ``__getstate__``; ``_client`` therefore
    comes back as ``None``.
    """
    instance_attrs = self.__dict__
    instance_attrs.update(state)

def generate_embeddings(self, doc_batch: list[Document]) -> list[Document]:
# TODO: Add some input validation here.
# The OpenAI docs are quite vague on acceptable values for model_batch_size.
Expand Down
Loading
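[Proxy-injected banner, not part of the diff: "Click — this is a PHP browser service provided by indexloc; do not enter any passwords or download anything."]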