Merged
41 changes: 41 additions & 0 deletions lib/sycamore/sycamore/docset.py
@@ -16,6 +16,7 @@
)
from sycamore.plan_nodes import Node, Transform
from sycamore.transforms.augment_text import TextAugmentor
from sycamore.transforms.clustering import KMeans
from sycamore.transforms.embed import Embedder
from sycamore.transforms import DocumentStructure, Sort
from sycamore.transforms.extract_entity import EntityExtractor, OpenAIEntityExtractor
@@ -36,6 +37,7 @@

if TYPE_CHECKING:
from sycamore.writer import DocSetWriter
from sycamore.grouped_data import GroupedData

logger = logging.getLogger(__name__)

@@ -920,6 +922,40 @@ def map(self, f: Callable[[Document], Document], **resource_args) -> "DocSet":
mapping = Map(self.plan, f=f, **resource_args)
return DocSet(self.context, mapping)

def kmeans(self, K: int, iterations: int = 20, init_mode: str = "random", epsilon: float = 1e-4):
"""
Apply kmeans over embedding field

Args:
K: the count of centroids
iterations: the max iteration runs before converge
init_mode: how the initial centroids are select
epsilon: the condition for determining if it's converged
Return a list of max K centroids
"""

def init_embedding(row):
doc = Document.from_row(row)
return {"vector": doc.embedding, "cluster": -1}

Collaborator: can you add a "assert self.context.exec_mode == ExecMode.RAY" in here?

embeddings = self.plan.execute().map(init_embedding).materialize()

initial_centroids = KMeans.init(embeddings, K, init_mode)
centroids = KMeans.update(embeddings, initial_centroids, iterations, epsilon)
return centroids

def clustering(self, centroids, cluster_field_name, **resource_args) -> "DocSet":
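"""Assign each document to its closest centroid, storing the centroid index under cluster_field_name."""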
def cluster(doc: Document) -> Document:
idx = KMeans.closest(doc.embedding, centroids)
doc[cluster_field_name] = idx
return doc

from sycamore.transforms import Map

resource_args["enable_auto_metadata"] = False
mapping = Map(self.plan, f=cluster, **resource_args)
return DocSet(self.context, mapping)
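Taken together, kmeans computes the centroids and clustering tags every document with its nearest one. A minimal sketch of the two steps (it mirrors the unit tests added in this PR; per the review comment above, kmeans assumes Ray execution, and each Document needs an embedding):

context = sycamore.init()
docset = context.read.document(docs)  # docs: list[Document], each with .embedding set
centroids = docset.kmeans(K=3)
clustered = docset.clustering(centroids, "cluster")  # writes doc["cluster"] = index of the closest centroid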

def flat_map(self, f: Callable[[Document], list[Document]], **resource_args) -> "DocSet":
"""
Applies the FlatMap transformation on the Docset.
@@ -1315,6 +1351,11 @@ def llm_query(self, query_agent: LLMTextQueryAgent, **kwargs) -> "DocSet":
queries = LLMQuery(self.plan, query_agent=query_agent, **kwargs)
return DocSet(self.context, queries)

def groupby(self, key: Union[str, list[str]]) -> "GroupedData":
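"""Group documents by a field path (or list of field paths), returning a GroupedData for aggregation."""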
from sycamore.grouped_data import GroupedData

return GroupedData(self, key)
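Continuing the clustering sketch above, groupby plus count yields per-cluster frequencies, which feed directly into the existing sort/limit transforms (mirrors test_top_k_with_clustering_groupby below):

aggregated = clustered.groupby("cluster").count()  # one Document per cluster; the count lands in properties["count"]
top_two = aggregated.sort(True, "properties.count", 0).limit(2).take()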

@context_params(OperationTypes.INFORMATION_EXTRACTOR)
def top_k(
self,
36 changes: 36 additions & 0 deletions lib/sycamore/sycamore/grouped_data.py
@@ -0,0 +1,36 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from ray.data.aggregate import AggregateFn

from sycamore import DocSet
from sycamore.data import Document


class GroupedData:
def __init__(self, docset: DocSet, key):
self._docset = docset
self._key = key

def aggregate(self, f: "AggregateFn") -> DocSet:
dataset = self._docset.plan.execute()
grouped = dataset.map(Document.from_row).groupby(self._key)
aggregated = grouped.aggregate(f)

def to_doc(row: dict):
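# assumes a Count aggregation: Ray names its result column "count()"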
count = row.pop("count()")
doc = Document(row)
properties = doc.properties
properties["count"] = count
doc.properties = properties
return doc.to_row()

serialized = aggregated.map(to_doc)
from sycamore.transforms import DatasetScan

return DocSet(self._docset.context, DatasetScan(serialized))

def count(self) -> DocSet:
Collaborator: import count here

from ray.data._internal.aggregate import Count

return self.aggregate(Count())
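aggregate itself is generic over Ray's AggregateFn, but note that to_doc above pops the "count()" column that Count produces, so only count() works end-to-end for now. A hedged sketch of what a custom sum aggregation could look like if to_doc keyed off the aggregation name instead (hypothetical names, not in this PR):

from ray.data.aggregate import AggregateFn

# hypothetical: sum a numeric property per group; to_doc would need to pop "total" rather than "count()"
total = AggregateFn(
    init=lambda key: 0,
    accumulate_row=lambda acc, row: acc + row["properties"]["value"],
    merge=lambda a, b: a + b,
    name="total",
)
grouped_docset = docset.groupby("text_representation").aggregate(total)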
23 changes: 23 additions & 0 deletions lib/sycamore/sycamore/tests/unit/test_docset.py
@@ -457,6 +457,29 @@ def test_top_k_discrete(self, fruits_docset):
assert top_k_list[1].properties["key"] == "banana"
assert top_k_list[1].properties["count"] == 2

@pytest.fixture
def fruits_embedding_docset(self) -> DocSet:
doc_list = [
Document(text_representation="apple", parent_id=8, embedding=[1]),
Document(text_representation="banana", parent_id=7, embedding=[2]),
Document(text_representation="apple", parent_id=8, embedding=[1]),
Document(text_representation="banana", parent_id=7, embedding=[2]),
Document(text_representation="cherry", parent_id=6, embedding=[3]),
Document(text_representation="apple", parent_id=9, embedding=[1]),
]
context = sycamore.init()
return context.read.document(doc_list)

def test_top_k_with_clustering_groupby(self, fruits_embedding_docset):
centroids = fruits_embedding_docset.kmeans(K=3)
clustered = fruits_embedding_docset.clustering(centroids, "centroids")
aggregated = clustered.groupby("centroids").count()
top_k_docset = aggregated.sort(True, "properties.count", 0).limit(2)

top_k_list = top_k_docset.take()
assert top_k_list[0].properties["count"] == 3
assert top_k_list[1].properties["count"] == 2

def test_top_k_unique_field(self, fruits_docset):

top_k_docset = fruits_docset.top_k(
24 changes: 24 additions & 0 deletions lib/sycamore/sycamore/tests/unit/test_grouped_data.py
@@ -0,0 +1,24 @@
import pytest

import sycamore
from sycamore import DocSet
from sycamore.data import Document


class TestGroup:
@pytest.fixture
def fruits_docset(self) -> DocSet:
doc_list = [
Document(text_representation="apple", parent_id=8),
Document(text_representation="banana", parent_id=7),
Document(text_representation="apple", parent_id=8),
Document(text_representation="banana", parent_id=7),
Document(text_representation="cherry", parent_id=6),
Document(text_representation="apple", parent_id=9),
]
context = sycamore.init()
return context.read.document(doc_list)

def test_groupby_count(self, fruits_docset):
aggregated = fruits_docset.groupby("text_representation").count()
Collaborator: What do the Documents in aggregated look like at this point?

assert aggregated.count() == 3
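On the review question above: per to_doc in GroupedData.aggregate, each aggregated Document carries the group key as a top-level field and the count under properties. A hedged sketch of assertions that would pin that down for this fixture:

docs = aggregated.take()
assert {d.text_representation for d in docs} == {"apple", "banana", "cherry"}
assert sorted(d.properties["count"] for d in docs) == [1, 2, 3]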
65 changes: 65 additions & 0 deletions lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py
@@ -0,0 +1,65 @@
import numpy as np
import ray.data

import sycamore
from sycamore.data import Document
from sycamore.transforms.clustering import KMeans


class TestKMeans:

def test_kmeans(self):
points = np.random.uniform(0, 40, (20, 4))
docs = [
Document(text_representation=f"Document {i}", doc_id=i, embedding=point, properties={"document_number": i})
for i, point in enumerate(points)
]
context = sycamore.init()
docset = context.read.document(docs)
centroids = docset.kmeans(3, 4)
assert len(centroids) == 3

def test_closest(self):
row = [[0, 0, 0, 0]]
centroids = [
[1, 1, 1, 1],
[2, 2, 2, 2],
[-1, -1, -1, -1],
]
assert KMeans.closest(row, centroids) == 0

def test_random(self):
points = np.random.uniform(0, 40, (20, 4))
embeddings = [{"vector": list(point), "cluster": -1} for point in points]
embeddings = ray.data.from_items(embeddings)
centroids = KMeans.random_init(embeddings, 10)
assert len(centroids) == 10

def test_converged(self):
last_ones = [[1.0, 1.0], [10.0, 10.0]]
next_ones = [[2.0, 2.0], [12.0, 12.0]]
assert KMeans.converged(last_ones, next_ones, 10).item() is True
assert KMeans.converged(last_ones, next_ones, 1).item() is False

def test_converge(self):
points = np.random.uniform(0, 10, (20, 4))
embeddings = [{"vector": list(point), "cluster": -1} for point in points]
embeddings = ray.data.from_items(embeddings)
centroids = [[2.0, 2.0, 2.0, 2.0], [8.0, 8.0, 8.0, 8.0]]
new_centroids = KMeans.update(embeddings, centroids, 2, 1e-4)
assert len(new_centroids) == 2

def test_clustering(self):
np.random.seed(2024)
points = np.random.uniform(0, 40, (20, 4))
docs = [
Document(text_representation=f"Document {i}", doc_id=i, embedding=point, properties={"document_number": i})
for i, point in enumerate(points)
]
context = sycamore.init()
docset = context.read.document(docs)
centroids = docset.kmeans(3, 4)

clustered_docs = docset.clustering(centroids, "cluster").take_all()
ids = [doc["cluster"] for doc in clustered_docs]
assert all(0 <= idx < 3 for idx in ids)
77 changes: 77 additions & 0 deletions lib/sycamore/sycamore/transforms/clustering.py
@@ -0,0 +1,77 @@
import random


class KMeans:

@staticmethod
def closest(row, centroids):
import torch

row = torch.Tensor([row])
centroids = torch.Tensor(centroids)
distance = torch.cdist(row, centroids)
idx = torch.argmin(distance)
return idx

@staticmethod
def converged(last_ones, next_ones, epsilon):
import torch

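# cdist computes all K x K old-vs-new distances; convergence requires exactly K of them
# below epsilon (the diagonal entries, when centroids are well separated)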
distance = torch.cdist(torch.Tensor(last_ones), torch.Tensor(next_ones))
return len(last_ones) == torch.sum(distance < epsilon)

@staticmethod
def random_init(embeddings, K):
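# sample roughly 2K candidate vectors, deduplicate them, then draw K initial centroids at random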
count = embeddings.count()
assert count > 0 and K < count
fraction = min(2 * K / count, 1.0)

candidates = [list(c["vector"]) for c in embeddings.random_sample(fraction).take()]
candidates.sort()
from itertools import groupby

uniques = [key for key, _ in groupby(candidates)]
assert len(uniques) >= K

centroids = random.sample(uniques, K)
return centroids

@staticmethod
def init(embeddings, K, init_mode):
if init_mode == "random":
Collaborator (on lines +40 to +41): supernit: init_mode could be an Enum but str is fine too. I guess would be nice to have the list of known init_modes in the exception?

return KMeans.random_init(embeddings, K)
else:
raise Exception("Unknown init mode")
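A hedged sketch of the Enum variant floated in the review comment above (hypothetical, not part of this change):

from enum import Enum

class InitMode(Enum):
    RANDOM = "random"

def resolve_init_mode(mode: str) -> InitMode:
    # hypothetical helper: validate the mode and list the known ones on failure
    try:
        return InitMode(mode)
    except ValueError:
        known = [m.value for m in InitMode]
        raise ValueError(f"Unknown init mode {mode!r}; known modes: {known}")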

@staticmethod
def update(embeddings, centroids, iterations, epsilon):
i = 0
Collaborator: import AggregateFn here

d = len(centroids[0])

from ray.data.aggregate import AggregateFn

update_centroids = AggregateFn(
init=lambda v: ([0] * d, 0),
accumulate_row=lambda a, row: ([x + y for x, y in zip(a[0], row["vector"])], a[1] + 1),
merge=lambda a1, a2: ([x + y for x, y in zip(a1[0], a2[0])], a1[1] + a2[1]),
name="centroids",
)

while i < iterations:

def _find_cluster(row):
idx = KMeans.closest(row["vector"], centroids)
return {"vector": row["vector"], "cluster": idx}

aggregated = embeddings.map(_find_cluster).groupby("cluster").aggregate(update_centroids).take()
import numpy as np

new_centroids = [list(np.array(c["centroids"][0]) / c["centroids"][1]) for c in aggregated]

if KMeans.converged(centroids, new_centroids, epsilon):
return new_centroids
else:
i += 1
centroids = new_centroids

return centroids
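Each iteration assigns every vector to its closest centroid, then update_centroids reduces each cluster to a (component-wise vector sum, count) pair, and the new centroid is sum / count. A toy version of that reduction in plain Python, for reference:

# mirrors update_centroids on a single cluster of two 2-d vectors
acc = ([0.0, 0.0], 0)  # init: (zero vector, zero count)
for vec in ([1.0, 3.0], [3.0, 5.0]):  # accumulate_row
    acc = ([x + y for x, y in zip(acc[0], vec)], acc[1] + 1)
centroid = [x / acc[1] for x in acc[0]]  # sum / count
assert centroid == [2.0, 4.0]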