2 changes: 2 additions & 0 deletions CHANGELOG.rst
@@ -17,6 +17,8 @@ Changed
- Aggregation mechanisms can now be specified as strings instead of enums, e.g. ``"laplace"`` instead of ``CountMechanism.LAPLACE`` or ``SumMechanism.LAPLACE``.
- Removed previously deprecated argument ``max_num_rows`` to ``flat_map``. Use ``max_rows`` instead.
- Removed previously deprecated argument ``cols`` to ``count_distinct``. Use ``columns`` instead.
- Infinity values are now automatically dropped before a floating-point column is passed to ``get_bounds``. (The documentation previously claimed this behavior, but it was not actually implemented.)
- Fixed the documentation of several numeric aggregations (``sum``, ``average``, ``stdev``, ``variance``, ``quantile``) to match their actual behavior: infinity values are clamped to the specified bounds before being passed to the aggregation, not dropped.


.. _v0.20.2:
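To illustrate the infinity-handling behavior documented in the changelog entries above: for the clamped aggregations, infinite values in the measure column are clamped to the query's bounds, while get_bounds now drops them first. A minimal sketch using the public API (the source name "my_data", the column name "income", and the exact keyword arguments are assumptions rather than part of this change):

    from tmlt.analytics import QueryBuilder

    # sum/average/stdev/variance/quantile: +inf/-inf values in "income" are
    # clamped to the [0, 100_000] bounds before the aggregation is computed.
    sum_query = QueryBuilder("my_data").sum("income", low=0, high=100_000)

    # get_bounds: infinite values are dropped before the bounds are estimated.
    bounds_query = QueryBuilder("my_data").get_bounds("income")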
55 changes: 13 additions & 42 deletions src/tmlt/analytics/_query_expr.py
@@ -1,11 +1,10 @@
"""Building blocks of the Tumult Analytics query language. Not for direct use.

Defines the :class:`QueryExpr` class, which represents expressions in the
Tumult Analytics query language. QueryExpr and its subclasses should not be
directly constructed or deconstructed by most users; interfaces such as
:class:`tmlt.analytics.QueryBuilder` to create them and
:class:`tmlt.analytics.Session` to consume them provide more
user-friendly features.
Defines the :class:`QueryExpr` class, which represents expressions in the Tumult
Analytics query language. QueryExpr and its subclasses should not be constructed
directly, but instead built using a :class:`tmlt.analytics.QueryBuilder`. The
documentation of :class:`tmlt.analytics.QueryBuilder` provides more information
about the intended semantics of :class:`QueryExpr` objects.
"""

# SPDX-License-Identifier: Apache-2.0
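A minimal sketch of the workflow the new docstring describes (the source and column names are hypothetical, and the use of KeySet.from_dict is an assumption about the public API, not part of this diff):

    from tmlt.analytics import KeySet, QueryBuilder

    # Queries are expressed through the public QueryBuilder; the corresponding
    # QueryExpr tree (a groupby sum over a PrivateSource leaf) is built
    # internally and later consumed by a Session.
    query = (
        QueryBuilder("my_data")
        .groupby(KeySet.from_dict({"state": ["CA", "NY"]}))
        .sum("income", low=0, high=100_000)
    )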
@@ -175,14 +174,10 @@ class StdevMechanism(Enum):
class QueryExpr(ABC):
"""A query expression, base class for relational operators.

In most cases, QueryExpr should not be manipulated directly, but rather
created using :class:`tmlt.analytics.QueryBuilder` and then
consumed by :class:`tmlt.analytics.Session`. While they can be
created and modified directly, this is an advanced usage and is not
recommended for typical users.

QueryExpr are organized in a tree, where each node is an operator which
returns a relation.
QueryExpr objects are organized in a tree, where each node is an operator that
returns a table. They are built using :class:`tmlt.analytics.QueryBuilder`, then
rewritten during the compilation process. They should not be created directly,
except in tests.
"""

@abstractmethod
@@ -1775,13 +1770,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:

@dataclass(frozen=True)
class GroupByBoundedSum(SingleChildQueryExpr):
"""Returns the bounded sum of a column for each combination of groupby domains.

If the column to be measured contains null, NaN, or positive or negative infinity,
those values will be dropped (as if dropped explicitly via
:class:`DropNullAndNan` and :class:`DropInfinity`) before the sum is
calculated.
"""
"""Returns the bounded sum of a column for each combination of groupby domains."""

groupby_keys: Union[KeySet, Tuple[str, ...]]
"""The keys, or columns list to collect keys from, to be grouped on."""
@@ -1842,13 +1831,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:

@dataclass(frozen=True)
class GroupByBoundedAverage(SingleChildQueryExpr):
"""Returns bounded average of a column for each combination of groupby domains.

If the column to be measured contains null, NaN, or positive or negative infinity,
those values will be dropped (as if dropped explicitly via
:class:`DropNullAndNan` and :class:`DropInfinity`) before the average is
calculated.
"""
"""Returns bounded average of a column for each combination of groupby domains."""

groupby_keys: Union[KeySet, Tuple[str, ...]]
"""The keys, or columns list to collect keys from, to be grouped on."""
@@ -1909,13 +1892,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:

@dataclass(frozen=True)
class GroupByBoundedVariance(SingleChildQueryExpr):
"""Returns bounded variance of a column for each combination of groupby domains.

If the column to be measured contains null, NaN, or positive or negative infinity,
those values will be dropped (as if dropped explicitly via
:class:`DropNullAndNan` and :class:`DropInfinity`) before the variance is
calculated.
"""
"""Returns bounded variance of a column for each combination of groupby domains."""

groupby_keys: Union[KeySet, Tuple[str, ...]]
"""The keys, or columns list to collect keys from, to be grouped on."""
@@ -1976,13 +1953,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:

@dataclass(frozen=True)
class GroupByBoundedStdev(SingleChildQueryExpr):
"""Returns bounded stdev of a column for each combination of groupby domains.

If the column to be measured contains null, NaN, or positive or negative infinity,
those values will be dropped (as if dropped explicitly via
:class:`DropNullAndNan` and :class:`DropInfinity`) before the
standard deviation is calculated.
"""
"""Returns bounded stdev of a column for each combination of groupby domains."""

groupby_keys: Union[KeySet, Tuple[str, ...]]
"""The keys, or columns list to collect keys from, to be grouped on."""
@@ -1,5 +1,4 @@
"""Defines a base class for building measurement visitors."""
import dataclasses
import math
import warnings
from abc import abstractmethod
@@ -101,7 +100,7 @@
SuppressAggregates,
VarianceMechanism,
)
from tmlt.analytics._schema import ColumnType, FrozenDict, Schema
from tmlt.analytics._schema import Schema
from tmlt.analytics._table_identifier import Identifier
from tmlt.analytics._table_reference import TableReference
from tmlt.analytics._transformation_utils import get_table_from_ref
@@ -675,65 +674,6 @@ def _validate_approxDP_and_adjust_budget(
else:
raise AnalyticsInternalError(f"Unknown mechanism {mechanism}.")

def _add_special_value_handling_to_query(
self,
query: Union[
GroupByBoundedAverage,
GroupByBoundedStdev,
GroupByBoundedSum,
GroupByBoundedVariance,
GroupByQuantile,
GetBounds,
],
):
"""Returns a new query that handles nulls, NaNs and infinite values.

If the measure column allows nulls or NaNs, the new query
will drop those values.

If the measure column allows infinite values, the new query will replace those
values with the low and high values specified in the query.

These changes are added immediately before the groupby aggregation in the query.
"""
expected_schema = query.child.schema(self.catalog)

# You can't perform these queries on nulls, NaNs, or infinite values
# so check for those
try:
measure_desc = expected_schema[query.measure_column]
except KeyError as e:
raise KeyError(
f"Measure column {query.measure_column} is not in the input schema."
) from e

new_child: QueryExpr
# If null or NaN values are allowed ...
if measure_desc.allow_null or (
measure_desc.column_type == ColumnType.DECIMAL and measure_desc.allow_nan
):
# then drop those values
# (but don't mutate the original query)
new_child = DropNullAndNan(
child=query.child, columns=tuple([query.measure_column])
)
query = dataclasses.replace(query, child=new_child)
if not isinstance(query, GetBounds):
# If infinite values are allowed...
if (
measure_desc.column_type == ColumnType.DECIMAL
and measure_desc.allow_inf
):
# then clamp them (to low/high values)
new_child = ReplaceInfinity(
child=query.child,
replace_with=FrozenDict.from_dict(
{query.measure_column: (query.low, query.high)}
),
)
query = dataclasses.replace(query, child=new_child)
return query

def _validate_measurement(self, measurement: Measurement, mid_stability: sp.Expr):
"""Validate a measurement."""
if isinstance(self.adjusted_budget.value, tuple):
@@ -1146,7 +1086,6 @@ def visit_groupby_quantile(

# Peek at the schema, to see if there are errors there
expr.schema(self.catalog)
expr = self._add_special_value_handling_to_query(expr)

if isinstance(expr.groupby_keys, KeySet):
groupby_cols = tuple(expr.groupby_keys.dataframe().columns)
@@ -1241,7 +1180,6 @@ def visit_groupby_bounded_sum(

# Peek at the schema, to see if there are errors there
expr.schema(self.catalog)
expr = self._add_special_value_handling_to_query(expr)

if isinstance(expr.groupby_keys, KeySet):
groupby_cols = tuple(expr.groupby_keys.dataframe().columns)
@@ -1337,7 +1275,6 @@ def visit_groupby_bounded_average(

# Peek at the schema, to see if there are errors there
expr.schema(self.catalog)
expr = self._add_special_value_handling_to_query(expr)

if isinstance(expr.groupby_keys, KeySet):
groupby_cols = tuple(expr.groupby_keys.dataframe().columns)
@@ -1433,7 +1370,6 @@ def visit_groupby_bounded_variance(

# Peek at the schema, to see if there are errors there
expr.schema(self.catalog)
expr = self._add_special_value_handling_to_query(expr)

if isinstance(expr.groupby_keys, KeySet):
groupby_cols = tuple(expr.groupby_keys.dataframe().columns)
@@ -1529,7 +1465,6 @@ def visit_groupby_bounded_stdev(

# Peek at the schema, to see if there are errors there
expr.schema(self.catalog)
expr = self._add_special_value_handling_to_query(expr)

if isinstance(expr.groupby_keys, KeySet):
groupby_cols = tuple(expr.groupby_keys.dataframe().columns)
@@ -1622,7 +1557,6 @@ def visit_get_bounds(self, expr: GetBounds) -> Tuple[Measurement, NoiseInfo]:
# Peek at the schema, to see if there are errors there
expr.schema(self.catalog)

expr = self._add_special_value_handling_to_query(expr)
if isinstance(expr.groupby_keys, KeySet):
groupby_cols = tuple(expr.groupby_keys.dataframe().columns)
keyset_budget = self._get_zero_budget()
65 changes: 65 additions & 0 deletions src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py
@@ -20,15 +20,21 @@
AverageMechanism,
CountDistinctMechanism,
CountMechanism,
DropInfinity,
DropNullAndNan,
FrozenDict,
GetBounds,
GroupByBoundedAverage,
GroupByBoundedStdev,
GroupByBoundedSum,
GroupByBoundedVariance,
GroupByCount,
GroupByCountDistinct,
GroupByQuantile,
JoinPrivate,
PrivateSource,
QueryExpr,
ReplaceInfinity,
SingleChildQueryExpr,
StdevMechanism,
SumMechanism,
@@ -193,9 +199,68 @@ def select_noise(expr: QueryExpr) -> QueryExpr:
return select_noise


def add_special_value_handling(
info: CompilationInfo,
) -> Callable[[QueryExpr], QueryExpr]:
"""Rewrites the query to handle nulls, NaNs and infinite values.

If the measure column allows nulls or NaNs, the rewritten query will drop those
values. If the measure column allows infinite values, the rewritten query will
replace them with the clamping bounds specified in the query, or drop them in the
case of :meth:`~tmlt.analytics.QueryBuilder.get_bounds`.
"""

@depth_first
def handle_special_values(expr: QueryExpr) -> QueryExpr:
if not isinstance(
expr,
(
GroupByBoundedAverage,
GroupByBoundedStdev,
GroupByBoundedSum,
GroupByBoundedVariance,
GroupByQuantile,
GetBounds,
),
):
return expr
schema = expr.child.schema(info.catalog)
measure_desc = schema[expr.measure_column]
# Remove nulls/NaN if necessary
if measure_desc.allow_null or (
measure_desc.column_type == ColumnType.DECIMAL and measure_desc.allow_nan
):
expr = replace(
expr,
child=DropNullAndNan(child=expr.child, columns=(expr.measure_column,)),
)
# Remove infinities if necessary
if measure_desc.column_type == ColumnType.DECIMAL and measure_desc.allow_inf:
if isinstance(expr, GetBounds):
return replace(
expr,
child=DropInfinity(
child=expr.child, columns=(expr.measure_column,)
),
)
return replace(
expr,
child=ReplaceInfinity(
child=expr.child,
replace_with=FrozenDict.from_dict(
{expr.measure_column: (expr.low, expr.high)}
),
),
)
return expr

return handle_special_values
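For reference, a sketch of the rewrite this rule performs on a sum whose DECIMAL measure column allows nulls, NaNs, and infinities (constructor arguments are illustrative and elided where not shown in this diff; class and field names are taken from it):

    # Before rewriting:
    #   GroupByBoundedSum(child=PrivateSource("my_data"),
    #                     measure_column="income", low=0.0, high=100_000.0, ...)
    #
    # After add_special_value_handling:
    #   GroupByBoundedSum(
    #       ...,
    #       child=ReplaceInfinity(
    #           replace_with=FrozenDict.from_dict({"income": (0.0, 100_000.0)}),
    #           child=DropNullAndNan(
    #               columns=("income",),
    #               child=PrivateSource("my_data"),
    #           ),
    #       ),
    #   )
    #
    # For GetBounds, the outer wrapper is DropInfinity(columns=("income",), ...)
    # instead of ReplaceInfinity, since get_bounds has no clamping bounds to use.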


def rewrite(info: CompilationInfo, expr: QueryExpr) -> QueryExpr:
"""Rewrites the given QueryExpr into a QueryExpr that can be compiled."""
rewrite_rules = [
add_special_value_handling(info),
select_noise_mechanism(info),
]
for rule in rewrite_rules: