Add a groupby operator #1123
@@ -0,0 +1,36 @@
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ray.data.aggregate import AggregateFn

from sycamore import DocSet
from sycamore.data import Document


class GroupedData:
    def __init__(self, docset: DocSet, key):
        self._docset = docset
        self._key = key

    def aggregate(self, f: "AggregateFn") -> DocSet:
        dataset = self._docset.plan.execute()
        grouped = dataset.map(Document.from_row).groupby(self._key)
        aggregated = grouped.aggregate(f)

        def to_doc(row: dict):
            # NOTE: assumes the Count aggregation; Ray's Count() emits its
            # result under the literal column name "count()".
            count = row.pop("count()")
            doc = Document(row)
            properties = doc.properties
            properties["count"] = count
            doc.properties = properties
            return doc.to_row()

        serialized = aggregated.map(to_doc)
        from sycamore.transforms import DatasetScan

        return DocSet(self._docset.context, DatasetScan(serialized))

    def count(self) -> DocSet:
        from ray.data._internal.aggregate import Count

        return self.aggregate(Count())
```

Reviewer comment on `count()`: import count here
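For context on why `to_doc` pops the exact key `"count()"`: Ray's `Count()` aggregation stores its result under that literal column name. Here is a minimal standalone Ray sketch of the groupby/aggregate mechanics that `GroupedData.aggregate` drives; the printed output is what current Ray versions typically produce, modulo row order.

```python
# Standalone Ray sketch: groupby + Count() yields one row per group, with the
# count stored under the literal column name "count()".
import ray.data
from ray.data._internal.aggregate import Count  # same import path as the PR

ds = ray.data.from_items(
    [{"fruit": "apple"}, {"fruit": "banana"}, {"fruit": "apple"}]
)
rows = ds.groupby("fruit").aggregate(Count()).take()
print(rows)
# Expected (modulo row order):
# [{'fruit': 'apple', 'count()': 2}, {'fruit': 'banana', 'count()': 1}]
```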
@@ -0,0 +1,24 @@
```python
import pytest

import sycamore
from sycamore import DocSet
from sycamore.data import Document


class TestGroup:
    @pytest.fixture
    def fruits_docset(self) -> DocSet:
        doc_list = [
            Document(text_representation="apple", parent_id=8),
            Document(text_representation="banana", parent_id=7),
            Document(text_representation="apple", parent_id=8),
            Document(text_representation="banana", parent_id=7),
            Document(text_representation="cherry", parent_id=6),
            Document(text_representation="apple", parent_id=9),
        ]
        context = sycamore.init()
        return context.read.document(doc_list)

    def test_groupby_count(self, fruits_docset):
        aggregated = fruits_docset.groupby("text_representation").count()
        assert aggregated.count() == 3
```

Reviewer comment on `test_groupby_count`: What do the Documents in `aggregated` look like at this point?
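Regarding the reviewer's question: based on `to_doc` in `GroupedData.aggregate`, each aggregated Document should carry the group key as its `text_representation` and the group size in `properties["count"]`. A hedged sketch of a follow-up test asserting that shape (the exact field layout is an assumption read off the code above, not part of the PR):

```python
# Illustrative follow-up test (would live in TestGroup alongside
# test_groupby_count): each output Document keeps the groupby key and
# gains properties["count"].
def test_groupby_count_contents(self, fruits_docset):
    aggregated = fruits_docset.groupby("text_representation").count()
    counts = {
        doc.text_representation: doc.properties["count"]
        for doc in aggregated.take_all()
    }
    assert counts == {"apple": 3, "banana": 2, "cherry": 1}
```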
@@ -0,0 +1,65 @@
```python
import numpy as np
import ray.data

import sycamore
from sycamore.data import Document
from sycamore.transforms.clustering import KMeans


class TestKMeans:

    def test_kmeans(self):
        points = np.random.uniform(0, 40, (20, 4))
        docs = [
            Document(
                text_representation=f"Document {i}",
                doc_id=i,
                embedding=point,
                properties={"document_number": i},
            )
            for i, point in enumerate(points)
        ]
        context = sycamore.init()
        docset = context.read.document(docs)
        centroids = docset.kmeans(3, 4)
        assert len(centroids) == 3

    def test_closest(self):
        row = [[0, 0, 0, 0]]
        centroids = [
            [1, 1, 1, 1],
            [2, 2, 2, 2],
            [-1, -1, -1, -1],
        ]
        assert KMeans.closest(row, centroids) == 0

    def test_random(self):
        points = np.random.uniform(0, 40, (20, 4))
        embeddings = [{"vector": list(point), "cluster": -1} for point in points]
        embeddings = ray.data.from_items(embeddings)
        centroids = KMeans.random_init(embeddings, 10)
        assert len(centroids) == 10

    def test_converged(self):
        last_ones = [[1.0, 1.0], [10.0, 10.0]]
        next_ones = [[2.0, 2.0], [12.0, 12.0]]
        assert KMeans.converged(last_ones, next_ones, 10).item() is True
        assert KMeans.converged(last_ones, next_ones, 1).item() is False

    def test_converge(self):
        points = np.random.uniform(0, 10, (20, 4))
        embeddings = [{"vector": list(point), "cluster": -1} for point in points]
        embeddings = ray.data.from_items(embeddings)
        centroids = [[2.0, 2.0, 2.0, 2.0], [8.0, 8.0, 8.0, 8.0]]
        new_centroids = KMeans.update(embeddings, centroids, 2, 1e-4)
        assert len(new_centroids) == 2

    def test_clustering(self):
        np.random.seed(2024)
        points = np.random.uniform(0, 40, (20, 4))
        docs = [
            Document(
                text_representation=f"Document {i}",
                doc_id=i,
                embedding=point,
                properties={"document_number": i},
            )
            for i, point in enumerate(points)
        ]
        context = sycamore.init()
        docset = context.read.document(docs)
        centroids = docset.kmeans(3, 4)

        clustered_docs = docset.clustering(centroids, "cluster").take_all()
        ids = [doc["cluster"] for doc in clustered_docs]
        assert all(0 <= idx < 3 for idx in ids)
```
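One subtlety `test_converged` exercises: `KMeans.converged` builds the full pairwise distance matrix via `torch.cdist` rather than comparing matched centroid pairs, so it declares convergence only when exactly `len(last_ones)` entries fall below epsilon; that works here because the cross-centroid distances are large. A small NumPy sketch reproducing the arithmetic of the test:

```python
# NumPy sketch of the arithmetic behind test_converged: cdist yields the full
# 2x2 pairwise distance matrix, and convergence requires exactly len(last_ones)
# entries below epsilon (the matched pairs, when centroids are well separated).
import numpy as np

last_ones = np.array([[1.0, 1.0], [10.0, 10.0]])
next_ones = np.array([[2.0, 2.0], [12.0, 12.0]])

# Pairwise Euclidean distances, matching torch.cdist:
dist = np.linalg.norm(last_ones[:, None, :] - next_ones[None, :, :], axis=-1)
# [[ 1.41, 15.56],
#  [11.31,  2.83]]
assert int((dist < 10).sum()) == 2  # epsilon=10 -> converged
assert int((dist < 1).sum()) == 0   # epsilon=1  -> not converged
```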
@@ -0,0 +1,77 @@
```python
import random


class KMeans:

    @staticmethod
    def closest(row, centroids):
        import torch

        row = torch.Tensor([row])
        centroids = torch.Tensor(centroids)
        distance = torch.cdist(row, centroids)
        idx = torch.argmin(distance)
        return idx

    @staticmethod
    def converged(last_ones, next_ones, epsilon):
        import torch

        distance = torch.cdist(torch.Tensor(last_ones), torch.Tensor(next_ones))
        return len(last_ones) == torch.sum(distance < epsilon)

    @staticmethod
    def random_init(embeddings, K):
        count = embeddings.count()
        assert count > 0 and K < count
        fraction = min(2 * K / count, 1.0)

        candidates = [list(c["vector"]) for c in embeddings.random_sample(fraction).take()]
        candidates.sort()
        from itertools import groupby

        uniques = [key for key, _ in groupby(candidates)]
        assert len(uniques) >= K

        centroids = random.sample(uniques, K)
        return centroids

    @staticmethod
    def init(embeddings, K, init_mode):
        if init_mode == "random":
            return KMeans.random_init(embeddings, K)
        else:
            raise Exception("Unknown init mode")

    @staticmethod
    def update(embeddings, centroids, iterations, epsilon):
        i = 0
        d = len(centroids[0])

        from ray.data.aggregate import AggregateFn

        # Per-cluster running (vector sum, row count); the new centroid is the mean.
        update_centroids = AggregateFn(
            init=lambda v: ([0] * d, 0),
            accumulate_row=lambda a, row: ([x + y for x, y in zip(a[0], row["vector"])], a[1] + 1),
            merge=lambda a1, a2: ([x + y for x, y in zip(a1[0], a2[0])], a1[1] + a2[1]),
            name="centroids",
        )

        while i < iterations:

            def _find_cluster(row):
                idx = KMeans.closest(row["vector"], centroids)
                return {"vector": row["vector"], "cluster": idx}

            aggregated = embeddings.map(_find_cluster).groupby("cluster").aggregate(update_centroids).take()
            import numpy as np

            new_centroids = [list(np.array(c["centroids"][0]) / c["centroids"][1]) for c in aggregated]

            if KMeans.converged(centroids, new_centroids, epsilon):
                return new_centroids
            else:
                i += 1
                centroids = new_centroids

        return centroids
```

Reviewer comment on lines +40 to +41 (`KMeans.init`): supernit: `init_mode` could be an Enum but str is fine too. I guess would be nice to have the list of known init_modes in the exception?

Reviewer comment on `KMeans.update`: import AggregateFn here
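A hedged sketch of the reviewer's supernit: a string-backed `Enum` for the init modes, with the known modes listed in the error message. `InitMode` is an illustrative name, not part of the PR.

```python
from enum import Enum


class InitMode(str, Enum):
    """Known centroid initialization modes (illustrative)."""
    RANDOM = "random"


class KMeans:
    @staticmethod
    def init(embeddings, K, init_mode):
        # str-backed Enum: the comparison accepts either the member or its
        # raw string value ("random").
        if init_mode == InitMode.RANDOM:
            return KMeans.random_init(embeddings, K)
        raise ValueError(
            f"Unknown init mode {init_mode!r}; known modes: "
            f"{[m.value for m in InitMode]}"
        )

    # ... remaining methods unchanged ...
```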
Reviewer comment: can you add a `assert self.context.exec_mode == ExecMode.RAY` in here?
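The exact target of that comment isn't visible in this excerpt; one plausible placement (an assumption) is at the top of `GroupedData.aggregate`, which is the Ray-only path. Inside `GroupedData` the context lives on the wrapped docset, so the reviewer's `self.context` would become `self._docset.context`:

```python
# Hedged sketch of the requested guard (placement and import path assumed;
# ExecMode is sycamore's execution-mode enum).
from sycamore import DocSet, ExecMode


class GroupedData:
    # ... __init__ as in the PR ...

    def aggregate(self, f: "AggregateFn") -> DocSet:
        # groupby/aggregate is driven through Ray's Dataset API, so require
        # Ray execution mode up front.
        assert self._docset.context.exec_mode == ExecMode.RAY
        dataset = self._docset.plan.execute()
        ...
```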