这是indexloc提供的服务,不要输入任何密码
Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 59 additions & 34 deletions src/tmlt/analytics/_query_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from pyspark.sql import DataFrame
from typeguard import check_type
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from tmlt.core.measurements.aggregations import NoiseMechanism
from tmlt.core.measures import ApproxDP, PureDP, RhoZCDP

from tmlt.analytics import AnalyticsInternalError
from tmlt.analytics._catalog import Catalog
from tmlt.analytics._coerce_spark_schema import coerce_spark_schema_or_fail
Expand Down Expand Up @@ -155,7 +157,18 @@ class StdevMechanism(Enum):
Not compatible with pure DP.
"""

@dataclass(frozen=True, kw_only=True)
class CompilationInfo:
    """Contextual information added to a QueryExpr during compilation.

    An instance is stored in :attr:`QueryExpr.compilation_info` so that rewrite
    rules can consult the settings the query is being compiled under. The
    dataclass is frozen and keyword-only, so fields must be passed by name.
    """

    output_measure: Union[PureDP, ApproxDP, RhoZCDP]
    """The output measure used by this query."""

    # Previously annotated as the measure union (a copy-paste of
    # ``output_measure``), which contradicted the field's documented meaning.
    catalog: Catalog
    """The Catalog of the Session this query is executed on."""

@dataclass(frozen=True, kw_only=True)
class QueryExpr(ABC):
"""A query expression, base class for relational operators.

Expand All @@ -169,13 +182,16 @@ class QueryExpr(ABC):
returns a relation.
"""

compilation_info: Optional[CompilationInfo] = None
"""Compilation info needed for rewrite rules; ``None`` until compilation adds it."""
# NOTE(review): as an ordinary dataclass field this takes part in the
# generated __eq__/__hash__, so otherwise-identical expressions with and
# without compilation info compare unequal -- confirm that is intended.

@abstractmethod
def accept(self, visitor: "QueryExprVisitor") -> Any:
    """Dispatch methods on a visitor based on the QueryExpr type.

    Concrete subclasses implement this by calling the ``visit_*`` method of
    ``visitor`` that matches their own type and returning its result.

    Args:
        visitor: The visitor to dispatch to.

    Raises:
        NotImplementedError: Always, on this abstract base implementation.
    """
    raise NotImplementedError()


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class PrivateSource(QueryExpr):
"""Loads the private source."""

Expand All @@ -198,7 +214,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_private_source(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class GetGroups(QueryExpr):
"""Returns groups based on the geometric partition selection for these columns."""

Expand All @@ -222,7 +238,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_get_groups(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class GetBounds(QueryExpr):
"""Returns approximate upper and lower bounds of a column."""

Expand Down Expand Up @@ -252,7 +268,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_get_bounds(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class Rename(QueryExpr):
"""Returns the dataframe with columns renamed."""

Expand Down Expand Up @@ -284,7 +300,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_rename(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class Filter(QueryExpr):
"""Returns the subset of the rows that satisfy the condition."""

Expand All @@ -307,7 +323,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_filter(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class Select(QueryExpr):
"""Returns a subset of the columns."""

Expand All @@ -328,7 +344,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_select(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class Map(QueryExpr):
"""Applies a map function to each row of a relation."""

Expand Down Expand Up @@ -375,7 +391,7 @@ def __eq__(self, other: object) -> bool:
)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class FlatMap(QueryExpr):
"""Applies a flat map function to each row of a relation."""

Expand Down Expand Up @@ -442,7 +458,7 @@ def __eq__(self, other: object) -> bool:
)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class FlatMapByID(QueryExpr):
"""Applies a flat map function to each group of rows with a common ID."""

Expand Down Expand Up @@ -486,7 +502,7 @@ def __eq__(self, other: object) -> bool:
)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class JoinPrivate(QueryExpr):
"""Returns the join of two private tables.

Expand Down Expand Up @@ -532,7 +548,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_join_private(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class JoinPublic(QueryExpr):
"""Returns the join of a private and public table."""

Expand Down Expand Up @@ -641,7 +657,7 @@ class AnalyticsDefault:
"""


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class ReplaceNullAndNan(QueryExpr):
"""Returns data with null and NaN expressions replaced by a default.

Expand Down Expand Up @@ -676,7 +692,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_replace_null_and_nan(self)


@dataclass(frozen=True, init=False, eq=True)
@dataclass(frozen=True, kw_only=True)
class ReplaceInfinity(QueryExpr):
"""Returns data with +inf and -inf expressions replaced by defaults."""

Expand All @@ -693,29 +709,26 @@ class ReplaceInfinity(QueryExpr):
:class:`~.AnalyticsDefault` class variables).
"""

def __init__(
self, child: QueryExpr, replace_with: FrozenDict = FrozenDict.from_dict({})
) -> None:
def __post_init__( self) -> None:
"""Checks arguments to constructor."""
check_type(child, QueryExpr)
check_type(replace_with, FrozenDict)
check_type(dict(replace_with), Dict[str, Tuple[float, float]])
check_type(self.child, QueryExpr)
check_type(self.replace_with, FrozenDict)
check_type(dict(self.replace_with), Dict[str, Tuple[float, float]])

# Ensure that the values in replace_with are tuples of floats
updated_dict = {}
for col, val in replace_with.items():
for col, val in self.replace_with.items():
updated_dict[col] = (float(val[0]), float(val[1]))

# Subverting the frozen dataclass to update the replace_with attribute
object.__setattr__(self, "replace_with", FrozenDict.from_dict(updated_dict))
object.__setattr__(self, "child", child)

def accept(self, visitor: "QueryExprVisitor") -> Any:
"""Visit this QueryExpr with visitor."""
return visitor.visit_replace_infinity(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class DropNullAndNan(QueryExpr):
"""Returns data with rows that contain null or NaN value dropped.

Expand Down Expand Up @@ -745,7 +758,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_drop_null_and_nan(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class DropInfinity(QueryExpr):
"""Returns data with rows that contain +inf/-inf dropped."""

Expand All @@ -768,7 +781,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_drop_infinity(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class EnforceConstraint(QueryExpr):
"""Enforces a constraint on the data."""

Expand All @@ -787,7 +800,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_enforce_constraint(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class GroupByCount(QueryExpr):
"""Returns the count of each combination of the groupby domains."""

Expand All @@ -803,6 +816,8 @@ class GroupByCount(QueryExpr):
By DEFAULT, the framework automatically selects an
appropriate mechanism.
"""
core_mechanism: Optional[NoiseMechanism] = None
"""The Core mechanism used for this aggregation. Specified during compilation."""

def __post_init__(self):
"""Checks arguments to constructor."""
Expand All @@ -818,7 +833,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_groupby_count(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class GroupByCountDistinct(QueryExpr):
"""Returns the count of distinct rows in each groupby domain value."""

Expand All @@ -838,6 +853,8 @@ class GroupByCountDistinct(QueryExpr):

By DEFAULT, the framework automatically selects an appropriate mechanism.
"""
core_mechanism: Optional[NoiseMechanism]= None
"""The Core mechanism used for this aggregation. Specified during compilation."""

def __post_init__(self):
"""Checks arguments to constructor."""
Expand All @@ -854,7 +871,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_groupby_count_distinct(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class GroupByQuantile(QueryExpr):
"""Returns the quantile of a column for each combination of the groupby domains.

Expand Down Expand Up @@ -915,7 +932,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_groupby_quantile(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class GroupByBoundedSum(QueryExpr):
"""Returns the bounded sum of a column for each combination of groupby domains.

Expand Down Expand Up @@ -947,6 +964,8 @@ class GroupByBoundedSum(QueryExpr):
By DEFAULT, the framework automatically selects an
appropriate mechanism.
"""
core_mechanism: Optional[NoiseMechanism]= None
"""The Core mechanism used for this aggregation. Specified during compilation."""

def __post_init__(self):
"""Checks arguments to constructor."""
Expand Down Expand Up @@ -976,7 +995,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_groupby_bounded_sum(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class GroupByBoundedAverage(QueryExpr):
"""Returns bounded average of a column for each combination of groupby domains.

Expand Down Expand Up @@ -1008,6 +1027,8 @@ class GroupByBoundedAverage(QueryExpr):
By DEFAULT, the framework automatically selects an
appropriate mechanism.
"""
core_mechanism: Optional[NoiseMechanism]= None
"""The Core mechanism used for this aggregation. Specified during compilation."""

def __post_init__(self):
"""Checks arguments to constructor."""
Expand Down Expand Up @@ -1037,7 +1058,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_groupby_bounded_average(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class GroupByBoundedVariance(QueryExpr):
"""Returns bounded variance of a column for each combination of groupby domains.

Expand Down Expand Up @@ -1069,6 +1090,8 @@ class GroupByBoundedVariance(QueryExpr):
By DEFAULT, the framework automatically selects an
appropriate mechanism.
"""
core_mechanism: Optional[NoiseMechanism]= None
"""The Core mechanism used for this aggregation. Specified during compilation."""

def __post_init__(self):
"""Checks arguments to constructor."""
Expand Down Expand Up @@ -1098,7 +1121,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_groupby_bounded_variance(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class GroupByBoundedSTDEV(QueryExpr):
"""Returns bounded stdev of a column for each combination of groupby domains.

Expand Down Expand Up @@ -1130,6 +1153,8 @@ class GroupByBoundedSTDEV(QueryExpr):
By DEFAULT, the framework automatically selects an
appropriate mechanism.
"""
core_mechanism: Optional[NoiseMechanism] = None
"""The Core mechanism used for this aggregation. Specified during compilation."""

def __post_init__(self):
"""Checks arguments to constructor."""
Expand Down Expand Up @@ -1160,7 +1185,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
return visitor.visit_groupby_bounded_stdev(self)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class SuppressAggregates(QueryExpr):
"""Remove all counts that are less than the threshold."""

Expand Down
Loading