From e089efee1a2b5f38d5ff55dff459fa86f15fc4f6 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Wed, 15 Oct 2025 09:02:49 +0200 Subject: [PATCH 1/2] Make noise selection a rewriting rule --- src/tmlt/analytics/_query_expr.py | 93 +++++--- .../_base_measurement_visitor.py | 137 +----------- .../_query_expr_compiler/_compiler.py | 13 +- src/tmlt/analytics/query_builder.py | 16 +- test/system/session/rows/conftest.py | 62 ++--- test/system/session/rows/test_add_max_rows.py | 4 +- .../rows/test_add_max_rows_in_max_groups.py | 4 +- .../test_measurement_visitor.py | 98 ++++---- .../test_output_schema_visitor.py | 191 ++++++++-------- .../transformation_visitor/test_add_keys.py | 113 +++++----- .../transformation_visitor/test_add_rows.py | 44 ++-- .../test_constraints.py | 20 +- test/unit/test_query_expr_compiler.py | 211 +++++++++--------- test/unit/test_query_expression.py | 115 +++++----- test/unit/test_query_expression_visitor.py | 76 ++++--- 15 files changed, 578 insertions(+), 619 deletions(-) diff --git a/src/tmlt/analytics/_query_expr.py b/src/tmlt/analytics/_query_expr.py index 46085b23..ab8e0319 100644 --- a/src/tmlt/analytics/_query_expr.py +++ b/src/tmlt/analytics/_query_expr.py @@ -15,10 +15,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum, auto -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - from pyspark.sql import DataFrame from typeguard import check_type +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from tmlt.core.measurements.aggregations import NoiseMechanism +from tmlt.core.measures import ApproxDP, PureDP, RhoZCDP from tmlt.analytics import AnalyticsInternalError from tmlt.analytics._coerce_spark_schema import coerce_spark_schema_or_fail @@ -155,7 +157,18 @@ class StdevMechanism(Enum): Not compatible with pure DP. 
""" +@dataclass(frozen=True, kw_only=True) +class CompilationInfo: + """Contextual information added to the QueryExpr during compilation.""" + + output_measure: Union[PureDP, ApproxDP, RhoZCDP] + """The output measure used by this query.""" + catalog: Union[PureDP, ApproxDP, RhoZCDP] + """The Catalog of the Session this query is executed on.""" + + +@dataclass(frozen=True, kw_only=True) class QueryExpr(ABC): """A query expression, base class for relational operators. @@ -169,13 +182,16 @@ class QueryExpr(ABC): returns a relation. """ + compilation_info: CompilationInfo = None + """Compilation info needed for rewrite rules.""" + @abstractmethod def accept(self, visitor: "QueryExprVisitor") -> Any: """Dispatch methods on a visitor based on the QueryExpr type.""" raise NotImplementedError() -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class PrivateSource(QueryExpr): """Loads the private source.""" @@ -198,7 +214,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_private_source(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class GetGroups(QueryExpr): """Returns groups based on the geometric partition selection for these columns.""" @@ -222,7 +238,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_get_groups(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class GetBounds(QueryExpr): """Returns approximate upper and lower bounds of a column.""" @@ -252,7 +268,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_get_bounds(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class Rename(QueryExpr): """Returns the dataframe with columns renamed.""" @@ -284,7 +300,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_rename(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class Filter(QueryExpr): """Returns the subset of the rows that satisfy the condition.""" @@ 
-307,7 +323,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_filter(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class Select(QueryExpr): """Returns a subset of the columns.""" @@ -328,7 +344,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_select(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class Map(QueryExpr): """Applies a map function to each row of a relation.""" @@ -375,7 +391,7 @@ def __eq__(self, other: object) -> bool: ) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class FlatMap(QueryExpr): """Applies a flat map function to each row of a relation.""" @@ -442,7 +458,7 @@ def __eq__(self, other: object) -> bool: ) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class FlatMapByID(QueryExpr): """Applies a flat map function to each group of rows with a common ID.""" @@ -486,7 +502,7 @@ def __eq__(self, other: object) -> bool: ) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class JoinPrivate(QueryExpr): """Returns the join of two private tables. @@ -532,7 +548,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_join_private(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class JoinPublic(QueryExpr): """Returns the join of a private and public table.""" @@ -641,7 +657,7 @@ class AnalyticsDefault: """ -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ReplaceNullAndNan(QueryExpr): """Returns data with null and NaN expressions replaced by a default. 
@@ -676,7 +692,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_replace_null_and_nan(self) -@dataclass(frozen=True, init=False, eq=True) +@dataclass(frozen=True, kw_only=True) class ReplaceInfinity(QueryExpr): """Returns data with +inf and -inf expressions replaced by defaults.""" @@ -693,29 +709,26 @@ class ReplaceInfinity(QueryExpr): :class:`~.AnalyticsDefault` class variables). """ - def __init__( - self, child: QueryExpr, replace_with: FrozenDict = FrozenDict.from_dict({}) - ) -> None: + def __post_init__( self) -> None: """Checks arguments to constructor.""" - check_type(child, QueryExpr) - check_type(replace_with, FrozenDict) - check_type(dict(replace_with), Dict[str, Tuple[float, float]]) + check_type(self.child, QueryExpr) + check_type(self.replace_with, FrozenDict) + check_type(dict(self.replace_with), Dict[str, Tuple[float, float]]) # Ensure that the values in replace_with are tuples of floats updated_dict = {} - for col, val in replace_with.items(): + for col, val in self.replace_with.items(): updated_dict[col] = (float(val[0]), float(val[1])) # Subverting the frozen dataclass to update the replace_with attribute object.__setattr__(self, "replace_with", FrozenDict.from_dict(updated_dict)) - object.__setattr__(self, "child", child) def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" return visitor.visit_replace_infinity(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class DropNullAndNan(QueryExpr): """Returns data with rows that contain null or NaN value dropped. 
@@ -745,7 +758,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_drop_null_and_nan(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class DropInfinity(QueryExpr): """Returns data with rows that contain +inf/-inf dropped.""" @@ -768,7 +781,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_drop_infinity(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class EnforceConstraint(QueryExpr): """Enforces a constraint on the data.""" @@ -787,7 +800,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_enforce_constraint(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class GroupByCount(QueryExpr): """Returns the count of each combination of the groupby domains.""" @@ -803,6 +816,8 @@ class GroupByCount(QueryExpr): By DEFAULT, the framework automatically selects an appropriate mechanism. """ + core_mechanism: Optional[NoiseMechanism] = None + """The Core mechanism used for this aggregation. Specified during compilation.""" def __post_init__(self): """Checks arguments to constructor.""" @@ -818,7 +833,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_groupby_count(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class GroupByCountDistinct(QueryExpr): """Returns the count of distinct rows in each groupby domain value.""" @@ -838,6 +853,8 @@ class GroupByCountDistinct(QueryExpr): By DEFAULT, the framework automatically selects an appropriate mechanism. """ + core_mechanism: Optional[NoiseMechanism]= None + """The Core mechanism used for this aggregation. 
Specified during compilation.""" def __post_init__(self): """Checks arguments to constructor.""" @@ -854,7 +871,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_groupby_count_distinct(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class GroupByQuantile(QueryExpr): """Returns the quantile of a column for each combination of the groupby domains. @@ -915,7 +932,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_groupby_quantile(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class GroupByBoundedSum(QueryExpr): """Returns the bounded sum of a column for each combination of groupby domains. @@ -947,6 +964,8 @@ class GroupByBoundedSum(QueryExpr): By DEFAULT, the framework automatically selects an appropriate mechanism. """ + core_mechanism: Optional[NoiseMechanism]= None + """The Core mechanism used for this aggregation. Specified during compilation.""" def __post_init__(self): """Checks arguments to constructor.""" @@ -976,7 +995,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_groupby_bounded_sum(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class GroupByBoundedAverage(QueryExpr): """Returns bounded average of a column for each combination of groupby domains. @@ -1008,6 +1027,8 @@ class GroupByBoundedAverage(QueryExpr): By DEFAULT, the framework automatically selects an appropriate mechanism. """ + core_mechanism: Optional[NoiseMechanism]= None + """The Core mechanism used for this aggregation. Specified during compilation.""" def __post_init__(self): """Checks arguments to constructor.""" @@ -1037,7 +1058,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_groupby_bounded_average(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class GroupByBoundedVariance(QueryExpr): """Returns bounded variance of a column for each combination of groupby domains. 
@@ -1069,6 +1090,8 @@ class GroupByBoundedVariance(QueryExpr): By DEFAULT, the framework automatically selects an appropriate mechanism. """ + core_mechanism: Optional[NoiseMechanism]= None + """The Core mechanism used for this aggregation. Specified during compilation.""" def __post_init__(self): """Checks arguments to constructor.""" @@ -1098,7 +1121,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_groupby_bounded_variance(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class GroupByBoundedSTDEV(QueryExpr): """Returns bounded stdev of a column for each combination of groupby domains. @@ -1130,6 +1153,8 @@ class GroupByBoundedSTDEV(QueryExpr): By DEFAULT, the framework automatically selects an appropriate mechanism. """ + core_mechanism: Optional[NoiseMechanism] = None + """The Core mechanism used for this aggregation. Specified during compilation.""" def __post_init__(self): """Checks arguments to constructor.""" @@ -1160,7 +1185,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_groupby_bounded_stdev(self) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class SuppressAggregates(QueryExpr): """Remove all counts that are less than the threshold.""" diff --git a/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py b/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py index 19dee7d3..1edf5ed1 100644 --- a/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py +++ b/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py @@ -658,129 +658,6 @@ def _validate_approxDP_and_adjust_budget( else: raise AnalyticsInternalError(f"Unknown mechanism {mechanism}.") - def _pick_noise_for_count( - self, query: Union[GroupByCount, GroupByCountDistinct] - ) -> NoiseMechanism: - """Pick the noise mechanism to use for a count or count-distinct query.""" - requested_mechanism: NoiseMechanism - if query.mechanism in 
(CountMechanism.DEFAULT, CountDistinctMechanism.DEFAULT): - if isinstance(self.output_measure, (PureDP, ApproxDP)): - requested_mechanism = NoiseMechanism.LAPLACE - else: # output measure is RhoZCDP - requested_mechanism = NoiseMechanism.DISCRETE_GAUSSIAN - elif query.mechanism in ( - CountMechanism.LAPLACE, - CountDistinctMechanism.LAPLACE, - ): - requested_mechanism = NoiseMechanism.LAPLACE - elif query.mechanism in ( - CountMechanism.GAUSSIAN, - CountDistinctMechanism.GAUSSIAN, - ): - requested_mechanism = NoiseMechanism.DISCRETE_GAUSSIAN - else: - raise ValueError( - f"Did not recognize the mechanism name {query.mechanism}." - " Supported mechanisms are DEFAULT, LAPLACE, and GAUSSIAN." - ) - - if requested_mechanism == NoiseMechanism.LAPLACE: - return NoiseMechanism.GEOMETRIC - elif requested_mechanism == NoiseMechanism.DISCRETE_GAUSSIAN: - return NoiseMechanism.DISCRETE_GAUSSIAN - else: - # This should never happen - raise AnalyticsInternalError( - f"Did not recognize the requested mechanism {requested_mechanism}." - ) - - def _pick_noise_for_non_count( - self, - query: Union[ - GroupByBoundedAverage, - GroupByBoundedSTDEV, - GroupByBoundedSum, - GroupByBoundedVariance, - ], - ) -> NoiseMechanism: - """Pick the noise mechanism for non-count queries. - - GroupByQuantile and GetBounds only supports one noise mechanism, so it is not - included here. 
- """ - measure_column_type = query.child.accept(OutputSchemaVisitor(self.catalog))[ - query.measure_column - ].column_type - requested_mechanism: NoiseMechanism - if query.mechanism in ( - SumMechanism.DEFAULT, - AverageMechanism.DEFAULT, - VarianceMechanism.DEFAULT, - StdevMechanism.DEFAULT, - ): - requested_mechanism = ( - NoiseMechanism.LAPLACE - if isinstance(self.output_measure, (PureDP, ApproxDP)) - else NoiseMechanism.GAUSSIAN - ) - elif query.mechanism in ( - SumMechanism.LAPLACE, - AverageMechanism.LAPLACE, - VarianceMechanism.LAPLACE, - StdevMechanism.LAPLACE, - ): - requested_mechanism = NoiseMechanism.LAPLACE - elif query.mechanism in ( - SumMechanism.GAUSSIAN, - AverageMechanism.GAUSSIAN, - VarianceMechanism.GAUSSIAN, - StdevMechanism.GAUSSIAN, - ): - requested_mechanism = NoiseMechanism.GAUSSIAN - else: - raise ValueError( - f"Did not recognize requested mechanism {query.mechanism}." - " Supported mechanisms are DEFAULT, LAPLACE, and GAUSSIAN." - ) - - # If the query requested a Laplace measure ... - if requested_mechanism == NoiseMechanism.LAPLACE: - if measure_column_type == ColumnType.INTEGER: - return NoiseMechanism.GEOMETRIC - elif measure_column_type == ColumnType.DECIMAL: - return NoiseMechanism.LAPLACE - else: - raise AssertionError( - "Query's measure column should be numeric. This should" - " not happen and is probably a bug; please let us know" - " so we can fix it!" - ) - - # If the query requested a Gaussian measure... - elif requested_mechanism == NoiseMechanism.GAUSSIAN: - if isinstance(self.output_measure, PureDP): - raise ValueError( - "Gaussian noise is not supported under PureDP. " - "Please use RhoZCDP or another measure." - ) - if measure_column_type == ColumnType.DECIMAL: - return NoiseMechanism.GAUSSIAN - elif measure_column_type == ColumnType.INTEGER: - return NoiseMechanism.DISCRETE_GAUSSIAN - else: - raise AssertionError( - "Query's measure column should be numeric. 
This should" - " not happen and is probably a bug; please let us know" - " so we can fix it!" - ) - - # The requested_mechanism should be either LAPLACE or - # GAUSSIAN, so something has gone awry - else: - raise AnalyticsInternalError( - f"Did not recognize requested mechanism {requested_mechanism}." - ) - def _add_special_value_handling_to_query( self, query: Union[ @@ -1058,7 +935,7 @@ def visit_groupby_count(self, expr: GroupByCount) -> Tuple[Measurement, NoiseInf self.adjusted_budget ) - mechanism = self._pick_noise_for_count(expr) + mechanism = expr.core_mechanism child_transformation, child_ref = self._truncate_table( *self._visit_child_transformation(expr.child, mechanism), grouping_columns=groupby_cols, @@ -1142,7 +1019,7 @@ def visit_groupby_count_distinct( self.adjusted_budget ) - mechanism = self._pick_noise_for_count(expr) + mechanism = expr.core_mechanism ( child_transformation, child_ref, @@ -1359,8 +1236,8 @@ def visit_groupby_bounded_sum( self.adjusted_budget ) - mechanism = self._pick_noise_for_non_count(expr) lower, upper = _get_query_bounds(expr) + mechanism = expr.core_mechanism child_transformation, child_ref = self._truncate_table( *self._visit_child_transformation(expr.child, mechanism), @@ -1456,10 +1333,10 @@ def visit_groupby_bounded_average( ) lower, upper = _get_query_bounds(expr) - mechanism = self._pick_noise_for_non_count(expr) + mechanism = expr.core_mechanism child_transformation, child_ref = self._truncate_table( - *self._visit_child_transformation(expr.child, self.default_mechanism), + *self._visit_child_transformation(expr.child, mechanism), grouping_columns=groupby_cols, ) transformation = get_table_from_ref(child_transformation, child_ref) @@ -1552,7 +1429,7 @@ def visit_groupby_bounded_variance( ) lower, upper = _get_query_bounds(expr) - mechanism = self._pick_noise_for_non_count(expr) + mechanism = expr.core_mechanism child_transformation, child_ref = self._truncate_table( *self._visit_child_transformation(expr.child, 
mechanism), @@ -1648,7 +1525,7 @@ def visit_groupby_bounded_stdev( ) lower, upper = _get_query_bounds(expr) - mechanism = self._pick_noise_for_non_count(expr) + mechanism = expr.core_mechanism child_transformation, child_ref = self._truncate_table( *self._visit_child_transformation(expr.child, mechanism), diff --git a/src/tmlt/analytics/_query_expr_compiler/_compiler.py b/src/tmlt/analytics/_query_expr_compiler/_compiler.py index 01d0fffe..8efdcd68 100644 --- a/src/tmlt/analytics/_query_expr_compiler/_compiler.py +++ b/src/tmlt/analytics/_query_expr_compiler/_compiler.py @@ -20,11 +20,12 @@ from tmlt.analytics import AnalyticsInternalError from tmlt.analytics._catalog import Catalog from tmlt.analytics._noise_info import NoiseInfo -from tmlt.analytics._query_expr import QueryExpr +from tmlt.analytics._query_expr import CompilationInfo, QueryExpr from tmlt.analytics._query_expr_compiler._measurement_visitor import MeasurementVisitor from tmlt.analytics._query_expr_compiler._output_schema_visitor import ( OutputSchemaVisitor, ) +from tmlt.analytics._query_expr_compiler._rewrite_rules import rewrite from tmlt.analytics._query_expr_compiler._transformation_visitor import ( TransformationVisitor, ) @@ -139,6 +140,14 @@ def __call__( catalog: The catalog, used only for query validation. table_constraints: A mapping of tables to the existing constraints on them. """ + # First, apply rewrite rules. + compilation_info = CompilationInfo( + output_measure=self._output_measure, + catalog=catalog, + ) + query = rewrite(compilation_info, query) + + # Then, visit the query. 
visitor = MeasurementVisitor( privacy_budget=privacy_budget, stability=stability, @@ -150,8 +159,8 @@ def __call__( catalog=catalog, table_constraints=table_constraints, ) - measurement, noise_info = query.accept(visitor) + if not isinstance(measurement, Measurement): raise AnalyticsInternalError("This query did not create a measurement.") diff --git a/src/tmlt/analytics/query_builder.py b/src/tmlt/analytics/query_builder.py index 8eb802c8..a3f8d1bf 100644 --- a/src/tmlt/analytics/query_builder.py +++ b/src/tmlt/analytics/query_builder.py @@ -288,7 +288,7 @@ def __init__(self, source_id: str): source_id: The source id used in the query_expr. """ self._source_id: str = source_id - self._query_expr: QueryExpr = PrivateSource(source_id) + self._query_expr: QueryExpr = PrivateSource(source_id=source_id) def clone(self) -> QueryBuilder: # noqa: D102 # Returns a new QueryBuilder with the same partial query as the current one. @@ -586,11 +586,11 @@ def join_private( if isinstance(right_operand, str): right_operand = QueryBuilder(right_operand) self._query_expr = JoinPrivate( - self._query_expr, - right_operand._query_expr, - truncation_strategy_left, - truncation_strategy_right, - tuple(join_columns) if join_columns is not None else None, + child=self._query_expr, + right_operand_expr=right_operand._query_expr, + truncation_strategy_left=truncation_strategy_left, + truncation_strategy_right=truncation_strategy_right, + join_columns=tuple(join_columns) if join_columns is not None else None, ) return self @@ -1735,7 +1735,9 @@ def enforce(self, constraint: Constraint) -> "QueryBuilder": constraint: The constraint to enforce. 
""" self._query_expr = EnforceConstraint( - self._query_expr, constraint, options=FrozenDict.from_dict({}) + child=self._query_expr, + constraint=constraint, + options=FrozenDict.from_dict({}) ) return self diff --git a/test/system/session/rows/conftest.py b/test/system/session/rows/conftest.py index 92ffa308..4460a9b2 100644 --- a/test/system/session/rows/conftest.py +++ b/test/system/session/rows/conftest.py @@ -65,7 +65,7 @@ # (Geometric noise gets applied if PureDP; Gaussian noise gets applied if ZCDP) QueryBuilder("private").count(name="total"), GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), output_column="total", ), @@ -75,7 +75,7 @@ # (Geometric noise gets applied if PureDP; Gaussian noise gets applied if ZCDP) QueryBuilder("private").count_distinct(name="total"), GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), output_column="total", ), @@ -84,7 +84,7 @@ ( # Total with LAPLACE (Geometric noise gets applied) QueryBuilder("private").count(name="total", mechanism=CountMechanism.LAPLACE), GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), output_column="total", mechanism=CountMechanism.LAPLACE, @@ -96,7 +96,7 @@ name="total", mechanism=CountDistinctMechanism.LAPLACE ), GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), output_column="total", mechanism=CountDistinctMechanism.LAPLACE, @@ -108,7 +108,7 @@ .groupby(KeySet.from_dict({"A": ["0", "1"], "B": [0, 1]})) .count(), GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"], "B": [0, 1]}), ), pd.DataFrame( @@ -120,7 +120,7 @@ .groupby(KeySet.from_dict({"A": ["0", "1"], "B": [0, 1]})) .count_distinct(), 
GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"], "B": [0, 1]}), ), pd.DataFrame( @@ -136,7 +136,7 @@ .groupby(KeySet.from_dataframe(GET_GROUPBY_TWO_COLUMNS())) .count(), GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dataframe(GET_GROUPBY_TWO_COLUMNS()), ), pd.DataFrame({"A": ["0", "0", "1"], "B": [0, 1, 1], "count": [2, 1, 0]}), @@ -146,7 +146,7 @@ .groupby(KeySet.from_dataframe(GET_GROUPBY_TWO_COLUMNS())) .count_distinct(), GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dataframe(GET_GROUPBY_TWO_COLUMNS()), ), pd.DataFrame( @@ -158,7 +158,7 @@ .groupby(KeySet.from_dict(GROUPBY_ONE_COLUMN_DICT)) .count(), GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict(GROUPBY_ONE_COLUMN_DICT), ), pd.DataFrame({"A": ["0", "1", "2"], "count": [3, 1, 0]}), @@ -168,7 +168,7 @@ .groupby(KeySet.from_dict(GROUPBY_ONE_COLUMN_DICT)) .count_distinct(), GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict(GROUPBY_ONE_COLUMN_DICT), ), pd.DataFrame({"A": ["0", "1", "2"], "count_distinct": [3, 1, 0]}), @@ -178,7 +178,7 @@ .groupby(KeySet.from_dataframe(GET_GROUPBY_WITH_DUPLICATES())) .count(), GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dataframe(GET_GROUPBY_WITH_DUPLICATES()), ), pd.DataFrame({"A": ["0", "1", "2"], "count": [3, 1, 0]}), @@ -188,7 +188,7 @@ .groupby(KeySet.from_dataframe(GET_GROUPBY_WITH_DUPLICATES())) .count_distinct(), GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dataframe(GET_GROUPBY_WITH_DUPLICATES()), ), pd.DataFrame({"A": ["0", 
"1", "2"], "count_distinct": [3, 1, 0]}), @@ -196,7 +196,7 @@ ( # empty public source QueryBuilder("private").groupby(KeySet.from_dict({})).count(), GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), ), pd.DataFrame({"count": [4]}), @@ -204,7 +204,7 @@ ( # empty public source QueryBuilder("private").groupby(KeySet.from_dict({})).count_distinct(), GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), ), pd.DataFrame({"count_distinct": [4]}), @@ -214,7 +214,7 @@ .groupby(KeySet.from_dict({"A": ["0", "1"]})) .sum(column="X", low=0, high=1, name="sum"), GroupByBoundedSum( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0, @@ -232,7 +232,7 @@ child=ReplaceNullAndNan( replace_with=FrozenDict.from_dict({}), child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda _: [{}, {}], max_rows=2, schema_new_columns=Schema({}), @@ -269,7 +269,7 @@ replace_with=FrozenDict.from_dict({}), child=FlatMap( child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"Repeat": 1 if row["A"] == "0" else 2}], max_rows=1, schema_new_columns=Schema({"Repeat": "INTEGER"}), @@ -314,7 +314,7 @@ replace_with=FrozenDict.from_dict({}), child=FlatMap( child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"Repeat": 1 if row["A"] == "0" else 2}], max_rows=1, schema_new_columns=Schema( @@ -359,7 +359,7 @@ replace_with=FrozenDict.from_dict({}), child=FlatMap( child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"Repeat": 1 if row["A"] == "0" else 2}], max_rows=1, schema_new_columns=Schema( @@ -431,7 +431,7 @@ ( # GroupByCount Filter 
QueryBuilder("private").filter("A == '0'").count(), GroupByCount( - child=Filter(child=PrivateSource("private"), condition="A == '0'"), + child=Filter(child=PrivateSource(source_id="private"), condition="A == '0'"), groupby_keys=KeySet.from_dict({}), ), pd.DataFrame({"count": [3]}), @@ -439,7 +439,7 @@ ( # GroupByCountDistinct Filter QueryBuilder("private").filter("A == '0'").count_distinct(), GroupByCountDistinct( - child=Filter(child=PrivateSource("private"), condition="A == '0'"), + child=Filter(child=PrivateSource(source_id="private"), condition="A == '0'"), groupby_keys=KeySet.from_dict({}), ), pd.DataFrame({"count_distinct": [3]}), @@ -447,7 +447,7 @@ ( # GroupByCount Select QueryBuilder("private").select(["A"]).count(), GroupByCount( - child=Select(child=PrivateSource("private"), columns=tuple(["A"])), + child=Select(child=PrivateSource(source_id="private"), columns=tuple(["A"])), groupby_keys=KeySet.from_dict({}), ), pd.DataFrame({"count": [4]}), @@ -455,7 +455,7 @@ ( # GroupByCountDistinct Select QueryBuilder("private").select(["A"]).count_distinct(), GroupByCountDistinct( - child=Select(child=PrivateSource("private"), columns=tuple(["A"])), + child=Select(child=PrivateSource(source_id="private"), columns=tuple(["A"])), groupby_keys=KeySet.from_dict({}), ), pd.DataFrame({"count_distinct": [2]}), @@ -474,7 +474,7 @@ child=ReplaceNullAndNan( replace_with=FrozenDict.from_dict({}), child=Map( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: {"C": 2 * str(row["B"])}, schema_new_columns=Schema({"C": "VARCHAR"}), augment=True, @@ -501,7 +501,7 @@ child=ReplaceNullAndNan( replace_with=FrozenDict.from_dict({}), child=Map( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: {"C": 2 * str(row["B"])}, schema_new_columns=Schema({"C": "VARCHAR"}), augment=True, @@ -520,7 +520,7 @@ .groupby(KeySet.from_dict({"A+B": [0, 1, 2]})) .count(), GroupByCount( - 
child=JoinPublic(child=PrivateSource("private"), public_table="public"), + child=JoinPublic(child=PrivateSource(source_id="private"), public_table="public"), groupby_keys=KeySet.from_dict({"A+B": [0, 1, 2]}), ), pd.DataFrame({"A+B": [0, 1, 2], "count": [3, 4, 1]}), @@ -532,7 +532,7 @@ .count(), GroupByCount( child=JoinPublic( - child=PrivateSource("private"), public_table="public", how="left" + child=PrivateSource(source_id="private"), public_table="public", how="left" ), groupby_keys=KeySet.from_dict({"A+B": [0, 1, 2]}), ), @@ -544,7 +544,7 @@ .groupby(KeySet.from_dict({"A+B": [0, 1, 2]})) .count_distinct(), GroupByCountDistinct( - child=JoinPublic(child=PrivateSource("private"), public_table="public"), + child=JoinPublic(child=PrivateSource(source_id="private"), public_table="public"), groupby_keys=KeySet.from_dict({"A+B": [0, 1, 2]}), ), pd.DataFrame({"A+B": [0, 1, 2], "count_distinct": [3, 4, 1]}), @@ -556,7 +556,7 @@ .count(), GroupByCount( child=JoinPublic( - child=PrivateSource("private"), public_table="join_dtypes" + child=PrivateSource(source_id="private"), public_table="join_dtypes" ), groupby_keys=KeySet.from_dict({"DATE": [_DATE1, _DATE2]}), ), @@ -568,7 +568,7 @@ .count_distinct(columns=["DATE"]), GroupByCountDistinct( child=JoinPublic( - child=PrivateSource("private"), public_table="join_dtypes" + child=PrivateSource(source_id="private"), public_table="join_dtypes" ), columns_to_count=tuple(["DATE"]), output_column="count_distinct(DATE)", @@ -631,7 +631,7 @@ .suppress(1), SuppressAggregates( child=GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"], "B": [0, 1]}), ), column="count", diff --git a/test/system/session/rows/test_add_max_rows.py b/test/system/session/rows/test_add_max_rows.py index c606bbd8..8d4fbd86 100644 --- a/test/system/session/rows/test_add_max_rows.py +++ b/test/system/session/rows/test_add_max_rows.py @@ -121,7 +121,7 @@ def 
test_queries_privacy_budget_infinity_puredp( name="total", mechanism=CountMechanism.GAUSSIAN ), GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), output_column="total", mechanism=CountMechanism.GAUSSIAN, @@ -133,7 +133,7 @@ def test_queries_privacy_budget_infinity_puredp( .groupby(KeySet.from_dict({"A": ["0", "1"]})) .stdev(column="B", low=0, high=1, mechanism=StdevMechanism.GAUSSIAN), GroupByBoundedSTDEV( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="B", low=0, diff --git a/test/system/session/rows/test_add_max_rows_in_max_groups.py b/test/system/session/rows/test_add_max_rows_in_max_groups.py index 9a67b71e..83f860d1 100644 --- a/test/system/session/rows/test_add_max_rows_in_max_groups.py +++ b/test/system/session/rows/test_add_max_rows_in_max_groups.py @@ -94,7 +94,7 @@ def test_max_rows_per_group_stability(self, spark) -> None: [ ( GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1]}), mechanism=CountMechanism.LAPLACE, ), @@ -109,7 +109,7 @@ def test_max_rows_per_group_stability(self, spark) -> None: ), ( GroupByBoundedAverage( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1]}), low=-111, high=234, diff --git a/test/unit/query_expr_compiler/test_measurement_visitor.py b/test/unit/query_expr_compiler/test_measurement_visitor.py index 321271fc..e4e9f669 100644 --- a/test/unit/query_expr_compiler/test_measurement_visitor.py +++ b/test/unit/query_expr_compiler/test_measurement_visitor.py @@ -115,7 +115,7 @@ def chain_to_list(t: ChainTT) -> List[Transformation]: def test_average(lower: float, upper: float) -> None: """Test _get_query_bounds on Average query expr, with lower!=upper.""" average = GroupByBoundedAverage( - 
child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), measure_column="", low=lower, @@ -130,7 +130,7 @@ def test_average(lower: float, upper: float) -> None: def test_stdev(lower: float, upper: float) -> None: """Test _get_query_bounds on STDEV query expr, with lower!=upper.""" stdev = GroupByBoundedSTDEV( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), measure_column="", low=lower, @@ -145,7 +145,7 @@ def test_stdev(lower: float, upper: float) -> None: def test_sum(lower: float, upper: float) -> None: """Test _get_query_bounds on Sum query expr, with lower!=upper.""" sum_query = GroupByBoundedSum( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), measure_column="", low=lower, @@ -160,7 +160,7 @@ def test_sum(lower: float, upper: float) -> None: def test_variance(lower: float, upper: float) -> None: """Test _get_query_bounds on Variance query expr, with lower!=upper.""" variance = GroupByBoundedVariance( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), measure_column="", low=lower, @@ -175,7 +175,7 @@ def test_variance(lower: float, upper: float) -> None: def test_quantile(lower: float, upper: float) -> None: """Test _get_query_bounds on Quantile query expr, with lower!=upper.""" quantile = GroupByQuantile( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), measure_column="", low=lower, @@ -974,7 +974,7 @@ def _check_measurement(self, measurement: Measurement): [ ( GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), mechanism=CountMechanism.DEFAULT, ), @@ -990,7 +990,7 @@ def _check_measurement(self, measurement: Measurement): ), ( GroupByCount( - child=PrivateSource("private"), + 
child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1]}), mechanism=CountMechanism.LAPLACE, output_column="count", @@ -1007,7 +1007,7 @@ def _check_measurement(self, measurement: Measurement): ), ( GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["zero"]}), mechanism=CountMechanism.GAUSSIAN, output_column="custom_count_column", @@ -1024,7 +1024,7 @@ def _check_measurement(self, measurement: Measurement): ), ( GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), mechanism=CountMechanism.DEFAULT, ), @@ -1040,7 +1040,7 @@ def _check_measurement(self, measurement: Measurement): ), ( GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), mechanism=CountMechanism.LAPLACE, ), @@ -1071,7 +1071,7 @@ def test_visit_groupby_count( [ ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), mechanism=CountDistinctMechanism.DEFAULT, ), @@ -1087,7 +1087,7 @@ def test_visit_groupby_count( ), ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1]}), mechanism=CountDistinctMechanism.LAPLACE, output_column="count", @@ -1104,7 +1104,7 @@ def test_visit_groupby_count( ), ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), columns_to_count=tuple(["A"]), ), @@ -1120,7 +1120,7 @@ def test_visit_groupby_count( ), ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["zero"]}), mechanism=CountDistinctMechanism.GAUSSIAN, output_column="custom_count_column", @@ -1137,7 +1137,7 @@ def 
test_visit_groupby_count( ), ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), mechanism=CountDistinctMechanism.DEFAULT, ), @@ -1153,7 +1153,7 @@ def test_visit_groupby_count( ), ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), mechanism=CountDistinctMechanism.LAPLACE, ), @@ -1319,7 +1319,7 @@ def test_visit_groupby_quantile( [ ( GroupByBoundedAverage( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), low=-100, high=100, @@ -1343,7 +1343,7 @@ def test_visit_groupby_quantile( ), ( GroupByBoundedAverage( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1]}), measure_column="X", mechanism=AverageMechanism.DEFAULT, @@ -1367,7 +1367,7 @@ def test_visit_groupby_quantile( ), ( GroupByBoundedAverage( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1]}), measure_column="X", mechanism=AverageMechanism.LAPLACE, @@ -1391,7 +1391,7 @@ def test_visit_groupby_quantile( ), ( GroupByBoundedAverage( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["zero"]}), mechanism=AverageMechanism.DEFAULT, measure_column="B", @@ -1429,7 +1429,7 @@ def test_visit_groupby_bounded_average( [ ( GroupByBoundedSum( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), low=-100, high=100, @@ -1449,7 +1449,7 @@ def test_visit_groupby_bounded_average( ), ( GroupByBoundedSum( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1]}), measure_column="X", mechanism=SumMechanism.DEFAULT, @@ -1469,7 +1469,7 @@ def test_visit_groupby_bounded_average( ), 
( GroupByBoundedSum( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1]}), measure_column="X", mechanism=SumMechanism.LAPLACE, @@ -1489,7 +1489,7 @@ def test_visit_groupby_bounded_average( ), ( GroupByBoundedSum( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["zero"]}), mechanism=SumMechanism.DEFAULT, measure_column="B", @@ -1523,7 +1523,7 @@ def test_visit_groupby_bounded_sum( [ ( GroupByBoundedVariance( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), low=-100, high=100, @@ -1551,7 +1551,7 @@ def test_visit_groupby_bounded_sum( ), ( GroupByBoundedVariance( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1]}), measure_column="X", mechanism=VarianceMechanism.DEFAULT, @@ -1579,7 +1579,7 @@ def test_visit_groupby_bounded_sum( ), ( GroupByBoundedVariance( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1]}), measure_column="X", mechanism=VarianceMechanism.LAPLACE, @@ -1607,7 +1607,7 @@ def test_visit_groupby_bounded_sum( ), ( GroupByBoundedVariance( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["zero"]}), mechanism=VarianceMechanism.DEFAULT, measure_column="B", @@ -1649,7 +1649,7 @@ def test_visit_groupby_bounded_variance( [ ( GroupByBoundedSTDEV( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), low=-100, high=100, @@ -1677,7 +1677,7 @@ def test_visit_groupby_bounded_variance( ), ( GroupByBoundedSTDEV( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1]}), measure_column="X", mechanism=StdevMechanism.DEFAULT, @@ -1705,7 +1705,7 @@ def 
test_visit_groupby_bounded_variance( ), ( GroupByBoundedSTDEV( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1]}), measure_column="X", mechanism=StdevMechanism.LAPLACE, @@ -1733,7 +1733,7 @@ def test_visit_groupby_bounded_variance( ), ( GroupByBoundedSTDEV( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["zero"]}), mechanism=StdevMechanism.DEFAULT, measure_column="B", @@ -1773,18 +1773,18 @@ def test_visit_groupby_bounded_stdev( @pytest.mark.parametrize( "query", [ - (PrivateSource("private")), + (PrivateSource(source_id="private")), ( Rename( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), column_mapper=FrozenDict.from_dict({"A": "A2"}), ) ), - (Filter(child=PrivateSource("private"), condition="B > 2")), - (SelectExpr(child=PrivateSource("private"), columns=tuple(["A"]))), + (Filter(child=PrivateSource(source_id="private"), condition="B > 2")), + (SelectExpr(child=PrivateSource(source_id="private"), columns=tuple(["A"]))), ( Map( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: {"C": "c" + str(row["B"])}, schema_new_columns=Schema({"C": "VARCHAR"}), augment=True, @@ -1792,7 +1792,7 @@ def test_visit_groupby_bounded_stdev( ), ( FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"i": n for n in range(row["B"] + 1)}], schema_new_columns=Schema({"i": "DECIMAL"}), augment=False, @@ -1801,17 +1801,17 @@ def test_visit_groupby_bounded_stdev( ), ( JoinPrivate( - child=PrivateSource("private"), - right_operand_expr=PrivateSource("private_2"), + child=PrivateSource(source_id="private"), + right_operand_expr=PrivateSource(source_id="private_2"), truncation_strategy_left=TruncationStrategy.DropExcess(3), truncation_strategy_right=TruncationStrategy.DropExcess(3), ) ), - (JoinPublic(child=PrivateSource("private"), 
public_table="public")), - (ReplaceNullAndNan(child=PrivateSource("private"))), - (ReplaceInfExpr(child=PrivateSource("private"))), - (DropNullAndNan(child=PrivateSource("private"))), - (DropInfExpr(child=PrivateSource("private"))), + (JoinPublic(child=PrivateSource(source_id="private"), public_table="public")), + (ReplaceNullAndNan(child=PrivateSource(source_id="private"))), + (ReplaceInfExpr(child=PrivateSource(source_id="private"))), + (DropNullAndNan(child=PrivateSource(source_id="private"))), + (DropInfExpr(child=PrivateSource(source_id="private"))), ], ) def test_visit_transformations(self, query: QueryExpr): @@ -1825,7 +1825,7 @@ def test_visit_transformations(self, query: QueryExpr): SuppressAggregates( child=GroupByCount( groupby_keys=KeySet.from_dict({}), - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), output_column="count", ), column="count", @@ -1834,7 +1834,7 @@ def test_visit_transformations(self, query: QueryExpr): SuppressAggregates( child=GroupByCount( groupby_keys=KeySet.from_dict({"B": [0, 1]}), - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), output_column="count", ), column="count", @@ -1856,7 +1856,7 @@ def test_visit_suppress_aggregates(self, query: SuppressAggregates) -> None: ( SuppressAggregates( child=GroupByCount( - child=PrivateSource("private_2"), + child=PrivateSource(source_id="private_2"), groupby_keys=KeySet.from_dict( {"A": ["a0", "a1", "a2", "a3"]} ), @@ -1878,7 +1878,7 @@ def test_visit_suppress_aggregates(self, query: SuppressAggregates) -> None: ( SuppressAggregates( child=GroupByCount( - child=PrivateSource("private_2"), + child=PrivateSource(source_id="private_2"), groupby_keys=KeySet.from_dict( {"A": ["a0", "a1", "a2", "a3"]}, ), diff --git a/test/unit/query_expr_compiler/test_output_schema_visitor.py b/test/unit/query_expr_compiler/test_output_schema_visitor.py index 80a47200..6c1c9fe1 100644 --- a/test/unit/query_expr_compiler/test_output_schema_visitor.py +++ 
b/test/unit/query_expr_compiler/test_output_schema_visitor.py @@ -75,22 +75,22 @@ OUTPUT_SCHEMA_INVALID_QUERY_TESTS = [ ( # Query references public source instead of private source - PrivateSource("public"), + PrivateSource(source_id="public"), "Attempted query on table 'public', which is not a private table", ), ( # JoinPublic has invalid public_id - JoinPublic(child=PrivateSource("private"), public_table="private"), + JoinPublic(child=PrivateSource(source_id="private"), public_table="private"), "Attempted public join on table 'private', which is not a public table", ), ( # JoinPublic references invalid private source JoinPublic( - child=PrivateSource("private_source_not_in_catalog"), public_table="public" + child=PrivateSource(source_id="private_source_not_in_catalog"), public_table="public" ), "Query references nonexistent table 'private_source_not_in_catalog'", ), ( # JoinPublic on columns not common to both tables JoinPublic( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), public_table="public", join_columns=tuple(["B"]), ), @@ -98,36 +98,37 @@ ), ( # JoinPrivate on columns not common to both tables JoinPrivate( - PrivateSource("private"), - Rename(PrivateSource("private"), FrozenDict.from_dict({"B": "Q"})), - TruncationStrategy.DropExcess(1), - TruncationStrategy.DropExcess(1), + child=PrivateSource(source_id="private"), + right_operand_expr=Rename(child=PrivateSource(source_id="private"), column_mapper=FrozenDict.from_dict({"B": "Q"})), + truncation_strategy_left=TruncationStrategy.DropExcess(1), + truncation_strategy_right=TruncationStrategy.DropExcess(1), join_columns=tuple(["B"]), ), "Join columns must be common to both tables", ), ( # JoinPublic on tables with no common columns JoinPublic( - child=Rename(PrivateSource("private"), FrozenDict.from_dict({"A": "Q"})), + child=Rename(child=PrivateSource(source_id="private"), + column_mapper=FrozenDict.from_dict({"A": "Q"})), public_table="public", ), "Tables have no common 
columns to join on", ), ( # JoinPrivate on tables with no common columns JoinPrivate( - PrivateSource("private"), - Rename( - Select(PrivateSource("private"), tuple(["A"])), - FrozenDict.from_dict({"A": "Z"}), + child=PrivateSource(source_id="private"), + right_operand_expr=Rename( + child=Select(child=PrivateSource(source_id="private"), columns=tuple(["A"])), + column_mapper=FrozenDict.from_dict({"A": "Z"}), ), - TruncationStrategy.DropExcess(1), - TruncationStrategy.DropExcess(1), + truncation_strategy_left=TruncationStrategy.DropExcess(1), + truncation_strategy_right=TruncationStrategy.DropExcess(1), ), "Tables have no common columns to join on", ), ( # JoinPublic on column with mismatched types JoinPublic( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), public_table="public", join_columns=tuple(["A"]), ), @@ -138,13 +139,14 @@ ), ( # JoinPrivate on column with mismatched types JoinPrivate( - PrivateSource("private"), - Rename( - Rename(PrivateSource("private"), FrozenDict.from_dict({"A": "Q"})), - FrozenDict.from_dict({"B": "A"}), + child=PrivateSource(source_id="private"), + right_operand_expr=Rename( + child=Rename(child=PrivateSource(source_id="private"), + column_mapper=FrozenDict.from_dict({"A": "Q"})), + column_mapper=FrozenDict.from_dict({"B": "A"}), ), - TruncationStrategy.DropExcess(1), - TruncationStrategy.DropExcess(1), + truncation_strategy_left=TruncationStrategy.DropExcess(1), + truncation_strategy_right=TruncationStrategy.DropExcess(1), join_columns=tuple(["A"]), ), ( @@ -153,31 +155,31 @@ ), ), ( # Filter on invalid column - Filter(child=PrivateSource("private"), condition="NONEXISTENT>1"), + Filter(child=PrivateSource(source_id="private"), condition="NONEXISTENT>1"), "Invalid filter condition 'NONEXISTENT>1'.*", ), ( # Rename on non-existent column Rename( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), column_mapper=FrozenDict.from_dict({"NONEXISTENT": "Z"}), ), "Nonexistent columns 
in rename query: {'NONEXISTENT'}", ), ( # Rename when column exists Rename( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), column_mapper=FrozenDict.from_dict({"A": "B"}), ), "Cannot rename 'A' to 'B': column 'B' already exists", ), ( # Select non-existent column - Select(child=PrivateSource("private"), columns=tuple(["NONEXISTENT"])), + Select(child=PrivateSource(source_id="private"), columns=tuple(["NONEXISTENT"])), "Nonexistent columns in select query: {'NONEXISTENT'}", ), ( # Nested grouping FlatMap FlatMap( child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"i": row["X"]} for i in range(row["Repeat"])], schema_new_columns=Schema({"i": "INTEGER"}, grouping_column="i"), augment=True, @@ -196,7 +198,7 @@ ( # FlatMap with inner grouping FlatMap but outer augment=False FlatMap( child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"i": row["X"]} for i in range(row["Repeat"])], schema_new_columns=Schema({"i": "INTEGER"}, grouping_column="i"), augment=True, @@ -212,7 +214,7 @@ ( # Map with inner grouping FlatMap but outer augment=False Map( child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"i": row["X"]} for i in range(row["Repeat"])], schema_new_columns=Schema({"i": "INTEGER"}, grouping_column="i"), augment=True, @@ -226,7 +228,7 @@ ), ( # ReplaceNullAndNan with a column that doesn't exist ReplaceNullAndNan( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), replace_with=FrozenDict.from_dict({"bad": "new_string"}), ), r"Column 'bad' does not exist in this table, available columns are \[.*\]", @@ -234,7 +236,7 @@ ( # ReplaceNullAndNan with bad replacement type ReplaceNullAndNan( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), replace_with=FrozenDict.from_dict({"B": "not_an_int"}), ), "Column 'B' cannot have nulls 
replaced with 'not_an_int', as .* type INTEGER", @@ -242,7 +244,7 @@ ( # ReplaceInfinity with nonexistent column ReplaceInfinity( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), replace_with=FrozenDict.from_dict({"wrong": (-1, 1)}), ), r"Column 'wrong' does not exist in this table, available columns are \[.*\]", @@ -250,7 +252,7 @@ ( # ReplaceInfinity with non-decimal column ReplaceInfinity( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), replace_with=FrozenDict.from_dict({"A": (-1, 1)}), ), r"Column 'A' has a replacement value provided.*of type VARCHAR \(not DECIMAL\) " @@ -258,17 +260,17 @@ ), ( # DropNullAndNan with column that doesn't exist - DropNullAndNan(child=PrivateSource("private"), columns=tuple(["bad"])), + DropNullAndNan(child=PrivateSource(source_id="private"), columns=tuple(["bad"])), r"Column 'bad' does not exist in this table, available columns are \[.*\]", ), ( # DropInfinity with column that doesn't exist - DropInfinity(child=PrivateSource("private"), columns=tuple(["bad"])), + DropInfinity(child=PrivateSource(source_id="private"), columns=tuple(["bad"])), r"Column 'bad' does not exist in this table, available columns are \[.*\]", ), ( # Type mismatch for the measure column of GroupByQuantile GroupByQuantile( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1, 2]}), measure_column="A", quantile=0.5, @@ -283,7 +285,7 @@ ), ( # Type mismatch for the measure column of GroupByBoundedAverage GroupByBoundedAverage( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), measure_column="A", low=0.0, @@ -297,7 +299,7 @@ ( # Grouping column is set in a FlatMap but not used in a later GroupBy GroupByCount( child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"i": row["B"]} for i in range(row["Repeat"])], 
schema_new_columns=Schema({"i": "INTEGER"}, grouping_column="i"), augment=True, @@ -310,7 +312,7 @@ ( # Grouping column is set but not used in a later groupby_public_source GroupByCount( child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"i": row["B"]} for i in range(row["Repeat"])], schema_new_columns=Schema({"i": "INTEGER"}, grouping_column="i"), augment=True, @@ -428,7 +430,7 @@ def test_invalid_group_by_count( ) -> None: """Test invalid measurement QueryExpr.""" with pytest.raises(exception_type, match=expected_error_msg): - GroupByCount(PrivateSource("private"), groupby_keys).accept(self.visitor) + GroupByCount(child=PrivateSource(source_id="private"), groupby_keys=groupby_keys).accept(self.visitor) @pytest.mark.parametrize( "groupby_keys,exception_type,expected_error_msg", @@ -485,20 +487,24 @@ def test_invalid_group_by_aggregations( GroupByBoundedVariance, ]: with pytest.raises(exception_type, match=expected_error_msg): - DataClass(PrivateSource("private"), groupby_keys, "B", 1.0, 5.0).accept( + DataClass(child=PrivateSource(source_id="private"), + groupby_keys=groupby_keys, measure_column="B", + low=1.0, high=5.0).accept( self.visitor ) with pytest.raises(exception_type, match=expected_error_msg): GroupByQuantile( - PrivateSource("private"), groupby_keys, "B", 0.5, 1.0, 5.0 + child=PrivateSource(source_id="private"), + groupby_keys=groupby_keys, measure_column="B", quantile=0.5, + low=1.0, high=5.0 ).accept(self.visitor) with pytest.raises(exception_type, match=expected_error_msg): GetBounds( - PrivateSource("private"), - groupby_keys, - "B", - "lower_bound", - "upper_bound", + child=PrivateSource(source_id="private"), + groupby_keys=groupby_keys, + measure_column="B", + lower_bound_column="lower_bound", + upper_bound_column="upper_bound", ).accept(self.visitor) @@ -602,7 +608,7 @@ def test_invalid_group_by_count_null( ) -> None: """Test invalid measurement QueryExpr.""" with 
pytest.raises(exception_type, match=expected_error_msg): - GroupByCount(PrivateSource("private"), groupby_keys).accept(self.visitor) + GroupByCount(child=PrivateSource(source_id="private"), groupby_keys=groupby_keys).accept(self.visitor) @pytest.mark.parametrize( "groupby_keys,exception_type,expected_error_msg", @@ -659,25 +665,28 @@ def test_invalid_group_by_aggregations_null( GroupByBoundedVariance, ]: with pytest.raises(exception_type, match=expected_error_msg): - DataClass(PrivateSource("private"), groupby_keys, "B", 1.0, 5.0).accept( + DataClass(child=PrivateSource(source_id="private"), + groupby_keys=groupby_keys, measure_column="B", + low=1.0, high=5.0).accept( self.visitor ) with pytest.raises(exception_type, match=expected_error_msg): GroupByQuantile( - PrivateSource("private"), groupby_keys, "B", 0.5, 1.0, 5.0 + child=PrivateSource(source_id="private"), groupby_keys=groupby_keys, + measure_column="B", quantile=0.5, low=1.0, high=5.0 ).accept(self.visitor) with pytest.raises(exception_type, match=expected_error_msg): GetBounds( - PrivateSource("private"), - groupby_keys, - "B", - "lower_bound", - "upper_bound", + child=PrivateSource(source_id="private"), + groupby_keys=groupby_keys, + measure_column="B", + lower_bound_column="lower_bound", + upper_bound_column="upper_bound", ).accept(self.visitor) def test_visit_private_source(self) -> None: """Test visit_private_source.""" - query = PrivateSource("private") + query = PrivateSource(source_id="private") schema = self.visitor.visit_private_source(query) assert ( schema @@ -758,7 +767,7 @@ def test_visit_rename( ) -> None: """Test visit_rename.""" query = Rename( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), column_mapper=FrozenDict.from_dict(column_mapper), ) schema = self.visitor.visit_rename(query) @@ -767,7 +776,7 @@ def test_visit_rename( @pytest.mark.parametrize("condition", ["B > X", "X < 500", "NOTNULL < 30"]) def test_visit_filter(self, condition: str) -> None: """Test 
visit_filter.""" - query = Filter(child=PrivateSource("private"), condition=condition) + query = Filter(child=PrivateSource(source_id="private"), condition=condition) schema = self.visitor.visit_filter(query) assert ( schema @@ -798,7 +807,7 @@ def test_visit_filter(self, condition: str) -> None: ) def test_visit_select(self, columns: List[str], expected_schema: Schema) -> None: """Test visit_select.""" - query = Select(child=PrivateSource("private"), columns=tuple(columns)) + query = Select(child=PrivateSource(source_id="private"), columns=tuple(columns)) schema = self.visitor.visit_select(query) assert schema == expected_schema @@ -807,7 +816,7 @@ def test_visit_select(self, columns: List[str], expected_schema: Schema) -> None [ ( Map( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: {"NEW": 1234}, schema_new_columns=Schema( {"NEW": ColumnDescriptor(ColumnType.INTEGER, allow_null=False)} @@ -835,7 +844,7 @@ def test_visit_select(self, columns: List[str], expected_schema: Schema) -> None ), ( Map( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: dict.fromkeys( list(row.keys()) + ["NEW"], list(row.values()) + [1234] ), @@ -865,7 +874,7 @@ def test_visit_select(self, columns: List[str], expected_schema: Schema) -> None ), ( Map( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: {f"_{key}_": val for key, val in row.items()}, schema_new_columns=Schema( { @@ -912,7 +921,7 @@ def test_visit_select(self, columns: List[str], expected_schema: Schema) -> None ), ( Map( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: {"ABC": "abc"}, schema_new_columns=Schema( {"ABC": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True)} @@ -923,7 +932,7 @@ def test_visit_select(self, columns: List[str], expected_schema: Schema) -> None ), ( Map( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), 
f=lambda row: {"ABC": "abc"}, schema_new_columns=Schema( # This has allow_null=False, but the output schema @@ -946,7 +955,7 @@ def test_visit_map(self, query: Map, expected_schema: Schema) -> None: [ ( FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"i": i} for i in range(len(row["A"] + 1))], schema_new_columns=Schema( {"i": ColumnDescriptor(ColumnType.INTEGER, allow_null=False)} @@ -975,7 +984,7 @@ def test_visit_map(self, query: Map, expected_schema: Schema) -> None: ), ( FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"i": i} for i in range(len(row["A"] + 1))], schema_new_columns=Schema( {"i": ColumnDescriptor(ColumnType.INTEGER, allow_null=False)} @@ -997,8 +1006,8 @@ def test_visit_flat_map(self, query: FlatMap, expected_schema: Schema) -> None: [ ( JoinPrivate( - child=PrivateSource("private"), - right_operand_expr=PrivateSource("groupby_one_column_private"), + child=PrivateSource(source_id="private"), + right_operand_expr=PrivateSource(source_id="groupby_one_column_private"), truncation_strategy_left=TruncationStrategy.DropExcess(10), truncation_strategy_right=TruncationStrategy.DropExcess(10), ), @@ -1034,7 +1043,7 @@ def test_visit_join_private( [ ( JoinPublic( - child=PrivateSource("private"), public_table="groupby_column_a" + child=PrivateSource(source_id="private"), public_table="groupby_column_a" ), Schema( { @@ -1056,7 +1065,7 @@ def test_visit_join_private( ), ( JoinPublic( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), public_table="groupby_column_a", how="left", ), @@ -1141,8 +1150,8 @@ def test_visit_join_private_nulls(self, left_schema, right_schema, expected_sche catalog.add_private_table("right", right_schema) visitor = OutputSchemaVisitor(catalog) query = JoinPrivate( - child=PrivateSource("left"), - right_operand_expr=PrivateSource("right"), + child=PrivateSource(source_id="left"), + 
right_operand_expr=PrivateSource(source_id="right"), truncation_strategy_left=TruncationStrategy.DropExcess(1), truncation_strategy_right=TruncationStrategy.DropExcess(1), ) @@ -1203,7 +1212,7 @@ def test_visit_join_public_nulls( catalog.add_private_table("private", private_schema) catalog.add_public_table("public", public_schema) visitor = OutputSchemaVisitor(catalog) - query = JoinPublic(child=PrivateSource("private"), public_table="public") + query = JoinPublic(child=PrivateSource(source_id="private"), public_table="public") result_schema = visitor.visit_join_public(query) assert result_schema == expected_schema @@ -1212,7 +1221,7 @@ def test_visit_join_public_nulls( [ ( ReplaceNullAndNan( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), replace_with=FrozenDict.from_dict({}), ), Schema( @@ -1235,7 +1244,7 @@ def test_visit_join_public_nulls( ), ( ReplaceNullAndNan( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), replace_with=FrozenDict.from_dict({"A": "", "B": 0}), ), Schema( @@ -1258,7 +1267,7 @@ def test_visit_join_public_nulls( ), ( ReplaceNullAndNan( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), replace_with=FrozenDict.from_dict( { "A": "this_was_null", @@ -1300,7 +1309,7 @@ def test_visit_replace_null_and_nan( "query,expected_schema", [ ( - DropNullAndNan(child=PrivateSource("private"), columns=tuple()), + DropNullAndNan(child=PrivateSource(source_id="private"), columns=tuple()), Schema( { "A": ColumnDescriptor( @@ -1344,7 +1353,7 @@ def test_visit_replace_null_and_nan( ), ( DropNullAndNan( - child=PrivateSource("private"), columns=tuple(["A", "X", "T"]) + child=PrivateSource(source_id="private"), columns=tuple(["A", "X", "T"]) ), Schema( { @@ -1387,7 +1396,7 @@ def test_visit_drop_null_and_nan( "query,expected_schema", [ ( - DropInfinity(child=PrivateSource("private"), columns=tuple()), + DropInfinity(child=PrivateSource(source_id="private"), columns=tuple()), 
Schema( { "A": ColumnDescriptor( @@ -1430,7 +1439,7 @@ def test_visit_drop_null_and_nan( ), ), ( - DropInfinity(child=PrivateSource("private"), columns=tuple(["X"])), + DropInfinity(child=PrivateSource(source_id="private"), columns=tuple(["X"])), Schema( { "A": ColumnDescriptor( @@ -1473,7 +1482,7 @@ def test_visit_drop_infinity( [ ( GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["a1", "a2"]}), output_column="count", ), @@ -1486,7 +1495,7 @@ def test_visit_drop_infinity( ), ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict( {"A": ["a1", "a2"], "NOTNULL": [1, 2]} ), @@ -1507,7 +1516,7 @@ def test_visit_drop_infinity( ), ( GroupByQuantile( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict( {"D": [datetime.date(1980, 1, 1), datetime.date(2000, 1, 1)]} ), @@ -1528,7 +1537,7 @@ def test_visit_drop_infinity( ), ( GroupByBoundedSum( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict( {"D": [datetime.date(1980, 1, 1), datetime.date(2000, 1, 1)]} ), @@ -1546,7 +1555,7 @@ def test_visit_drop_infinity( ), ( GroupByBoundedAverage( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [1, 2, 3]}), measure_column="NOTNULL", low=-100, @@ -1562,7 +1571,7 @@ def test_visit_drop_infinity( ), ( GroupByBoundedVariance( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["a1", "a2"], "B": [1, 2, 3]}), measure_column="NOTNULL", low=-100, @@ -1581,7 +1590,7 @@ def test_visit_drop_infinity( ), ( GroupByBoundedSTDEV( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict( {"A": ["a1", "a2"], "D": [datetime.date(1999, 12, 31)]} ), @@ -1600,7 
+1609,7 @@ def test_visit_drop_infinity( ), ( GetBounds( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict( {"A": ["a1", "a2"], "D": [datetime.date(1999, 12, 31)]} ), @@ -1641,7 +1650,7 @@ def test_visit_groupby_get_bounds_partition_selection(self) -> None: ) with config.features.auto_partition_selection.enabled(): query = GetBounds( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=tuple("A"), measure_column="NOTNULL", lower_bound_column="lower_bound", @@ -1655,7 +1664,7 @@ def test_visit_groupby_get_bounds_partition_selection(self) -> None: [ SuppressAggregates( child=GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), output_column="count", ), @@ -1664,7 +1673,7 @@ def test_visit_groupby_get_bounds_partition_selection(self) -> None: ), SuppressAggregates( child=GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict( { "A": ["a1", "a2"], diff --git a/test/unit/query_expr_compiler/transformation_visitor/test_add_keys.py b/test/unit/query_expr_compiler/transformation_visitor/test_add_keys.py index 796043cc..93b6851f 100644 --- a/test/unit/query_expr_compiler/transformation_visitor/test_add_keys.py +++ b/test/unit/query_expr_compiler/transformation_visitor/test_add_keys.py @@ -95,7 +95,7 @@ def _validate_transform_basics( @pytest.mark.parametrize("source_id", ["ids1", "ids2"]) def test_visit_private_source(self, source_id: str) -> None: """Test generating transformations from a PrivateSource.""" - query = PrivateSource(source_id) + query = PrivateSource(source_id=source_id) transformation, reference, constraints = query.accept(self.visitor) assert reference.path == [TableCollection("ids"), NamedTable(source_id)] assert isinstance(transformation, IdentityTransformation) @@ -103,7 +103,7 @@ def test_visit_private_source(self, 
source_id: str) -> None: def test_invalid_private_source(self) -> None: """Test that invalid PrivateSource expressions are handled.""" - query = PrivateSource("nonexistent") + query = PrivateSource(source_id="nonexistent") with pytest.raises(ValueError, match="Table 'nonexistent' does not exist"): query.accept(self.visitor) @@ -140,7 +140,7 @@ def test_visit_rename( self, mapper: Dict[str, str], expected_df: DataFrame, grouping_column: str ) -> None: """Test generating transformations from a Rename.""" - query = Rename(PrivateSource("ids1"), FrozenDict.from_dict(mapper)) + query = Rename(child=PrivateSource(source_id="ids1"), column_mapper=FrozenDict.from_dict(mapper)) transformation, reference, constraints = query.accept(self.visitor) self._validate_transform_basics( transformation, reference, query, grouping_column @@ -163,7 +163,7 @@ def test_visit_rename( ) def test_visit_filter(self, filter_expr: str, expected_df: DataFrame) -> None: """Test visit_filter.""" - query = Filter(PrivateSource(source_id="ids1"), filter_expr) + query = Filter(child=PrivateSource(source_id="ids1"), condition=filter_expr) transformation, reference, constraints = query.accept(self.visitor) self._validate_transform_basics(transformation, reference, query) self._validate_result(transformation, reference, expected_df) @@ -181,7 +181,7 @@ def test_visit_filter(self, filter_expr: str, expected_df: DataFrame) -> None: ) def test_visit_select(self, columns: List[str], expected_df: DataFrame) -> None: """Test generating transformations from a Select.""" - query = Select(PrivateSource(source_id="ids1"), tuple(columns)) + query = Select(child=PrivateSource(source_id="ids1"), columns=tuple(columns)) transformation, reference, constraints = query.accept(self.visitor) self._validate_transform_basics(transformation, reference, query) self._validate_result(transformation, reference, expected_df) @@ -192,9 +192,9 @@ def test_visit_select(self, columns: List[str], expected_df: DataFrame) -> None: [ ( 
Map( - PrivateSource("ids1"), - lambda row: {"X": 2 * str(row["S"])}, - Schema({"X": "VARCHAR"}), + child=PrivateSource(source_id="ids1"), + f=lambda row: {"X": 2 * str(row["S"])}, + schema_new_columns=Schema({"X": "VARCHAR"}), augment=True, ), pd.DataFrame( @@ -204,9 +204,9 @@ def test_visit_select(self, columns: List[str], expected_df: DataFrame) -> None: ), ( Map( - PrivateSource("ids1"), - lambda row: {"X": 2 * str(row["S"]), "Y": row["I"] + 2 * row["F"]}, - Schema({"X": "VARCHAR", "Y": "DECIMAL"}), + child=PrivateSource(source_id="ids1"), + f=lambda row: {"X": 2 * str(row["S"]), "Y": row["I"] + 2 * row["F"]}, + schema_new_columns=Schema({"X": "VARCHAR", "Y": "DECIMAL"}), augment=True, ), pd.DataFrame( @@ -225,7 +225,8 @@ def test_visit_map(self, query: Map, expected_df: DataFrame) -> None: def test_visit_map_invalid(self) -> None: """Test that invalid Map expressions are handled.""" - query = Map(PrivateSource("ids1"), lambda row: {}, Schema({}), augment=False) + query = Map(child=PrivateSource(source_id="ids1"), f=lambda row: {}, + schema_new_columns=Schema({}), augment=False) with pytest.raises(ValueError, match="Maps on tables.*must be augmenting"): query.accept(self.visitor) @@ -234,7 +235,7 @@ def test_visit_map_invalid(self) -> None: [ ( FlatMap( - child=PrivateSource("ids1"), + child=PrivateSource(source_id="ids1"), f=lambda row: [{"S_is_zero": 1 if row["S"] == "0" else 2}], schema_new_columns=Schema({"S_is_zero": "INTEGER"}), augment=True, @@ -247,7 +248,7 @@ def test_visit_map_invalid(self) -> None: ), ( FlatMap( - child=PrivateSource("ids1"), + child=PrivateSource(source_id="ids1"), f=lambda row: [{"X": n} for n in range(row["I"] + 4)], schema_new_columns=Schema({"X": "INTEGER"}), augment=True, @@ -275,9 +276,9 @@ def test_visit_flat_map(self, query: FlatMap, expected_df: DataFrame) -> None: def test_visit_flatmap_invalid(self) -> None: """Test that invalid FlatMap expressions are handled.""" query = FlatMap( - PrivateSource("ids1"), - lambda row: 
[{}], - Schema({}), + child=PrivateSource(source_id="ids1"), + f=lambda row: [{}], + schema_new_columns=Schema({}), augment=False, max_rows=1, ) @@ -285,9 +286,9 @@ def test_visit_flatmap_invalid(self) -> None: query.accept(self.visitor) query = FlatMap( - PrivateSource("ids1"), - lambda row: [{"X": row["I"]}], - Schema({"X": "INTEGER"}, "X"), + child=PrivateSource(source_id="ids1"), + f=lambda row: [{"X": row["I"]}], + schema_new_columns=Schema({"X": "INTEGER"}, "X"), augment=True, max_rows=1, ) @@ -295,9 +296,9 @@ def test_visit_flatmap_invalid(self) -> None: query.accept(self.visitor) query = FlatMap( - PrivateSource("ids1"), - lambda row: [{"X": row["I"]}], - Schema({}), + child=PrivateSource(source_id="ids1"), + f=lambda row: [{"X": row["I"]}], + schema_new_columns=Schema({}), augment=True, max_rows=1, ) @@ -309,7 +310,10 @@ def test_visit_flatmap_invalid(self) -> None: [ ( JoinPrivate( - PrivateSource("ids1"), PrivateSource("ids2"), None, None, None + child=PrivateSource(source_id="ids1"), + right_operand_expr=PrivateSource(source_id="ids2"), + truncation_strategy_left=None, + truncation_strategy_right=None, join_columns=None ), pd.DataFrame( [[1, "0", 0, 0.1, DATE1, TIMESTAMP1, "a"]], @@ -318,11 +322,11 @@ def test_visit_flatmap_invalid(self) -> None: ), ( JoinPrivate( - PrivateSource("ids1"), - PrivateSource("ids2"), - None, - None, - tuple(["id"]), + child=PrivateSource(source_id="ids1"), + right_operand_expr=PrivateSource(source_id="ids2"), + truncation_strategy_left=None, + truncation_strategy_right=None, + join_columns=tuple(["id"]), ), pd.DataFrame( [[1, "0", 0, 0, 0.1, DATE1, TIMESTAMP1, "a"]], @@ -331,8 +335,8 @@ def test_visit_flatmap_invalid(self) -> None: ), ( JoinPrivate( - PrivateSource("ids1"), - PrivateSource("ids2"), + child=PrivateSource(source_id="ids1"), + right_operand_expr=PrivateSource(source_id="ids2"), join_columns=tuple(["id"]), ), pd.DataFrame( @@ -355,25 +359,25 @@ def test_visit_join_private( "query", [ JoinPrivate( - 
PrivateSource("ids1"), - PrivateSource("ids2"), - TruncationStrategy.DropExcess(1), - TruncationStrategy.DropExcess(1), - tuple(["id"]), + child=PrivateSource(source_id="ids1"), + right_operand_expr=PrivateSource(source_id="ids2"), + truncation_strategy_left=TruncationStrategy.DropExcess(1), + truncation_strategy_right=TruncationStrategy.DropExcess(1), + join_columns=tuple(["id"]), ), JoinPrivate( - PrivateSource("ids1"), - PrivateSource("ids2"), - TruncationStrategy.DropExcess(1), - None, - tuple(["id"]), + child=PrivateSource(source_id="ids1"), + right_operand_expr=PrivateSource(source_id="ids2"), + truncation_strategy_left=TruncationStrategy.DropExcess(1), + truncation_strategy_right=None, + join_columns=tuple(["id"]), ), JoinPrivate( - PrivateSource("ids1"), - PrivateSource("ids2"), - None, - TruncationStrategy.DropExcess(1), - tuple(["id"]), + child=PrivateSource(source_id="ids1"), + right_operand_expr=PrivateSource(source_id="ids2"), + truncation_strategy_left=None, + truncation_strategy_right=TruncationStrategy.DropExcess(1), + join_columns=tuple(["id"]), ), ], ) @@ -392,14 +396,16 @@ def test_visit_join_private_raises_warning(self, query) -> None: "query,expected_df", [ ( - JoinPublic(PrivateSource("ids1"), "public", None), + JoinPublic(child=PrivateSource(source_id="ids1"), + public_table="public", join_columns=None), pd.DataFrame( [[1, "0", 0, 0.1, DATE1, TIMESTAMP1, "x"]], columns=["id", "S", "I", "F", "D", "T", "public"], ), ), ( - JoinPublic(PrivateSource("ids1"), "public", tuple(["S"])), + JoinPublic(child=PrivateSource(source_id="ids1"), public_table="public", + join_columns=tuple(["S"])), pd.DataFrame( [[1, "0", 0, 0, 0.1, DATE1, TIMESTAMP1, "x"]], columns=["id", "S", "I_left", "I_right", "F", "D", "T", "public"], ), ), ], ) @@ -419,7 +425,8 @@ def test_visit_join_public_str( def test_visit_join_public_df(self) -> None: """Test generating transformations from a JoinPublic using a dataframe.""" query = JoinPublic( - PrivateSource("ids1"),
self.visitor.public_sources["public"], None + child=PrivateSource(source_id="ids1"), + public_table=self.visitor.public_sources["public"], join_columns=None ) expected_df = pd.DataFrame( [[1, "0", 0, 0.1, DATE1, TIMESTAMP1, "x"]], @@ -478,7 +485,7 @@ def test_visit_replace_null_and_nan( ): """Test generating transformations from a ReplaceNullAndNan.""" query = ReplaceNullAndNan( - PrivateSource("ids_infs_nans"), FrozenDict.from_dict(replace_with) + child=PrivateSource(source_id="ids_infs_nans"), replace_with=FrozenDict.from_dict(replace_with) ) transformation, reference, constraints = query.accept(self.visitor) self._validate_transform_basics(transformation, reference, query) @@ -521,7 +528,7 @@ def test_visit_replace_infinity( ): """Test generating transformations from a ReplaceInfinity.""" query = ReplaceInfinity( - PrivateSource("ids_infs_nans"), FrozenDict.from_dict(replace_with) + child=PrivateSource(source_id="ids_infs_nans"), replace_with=FrozenDict.from_dict(replace_with) ) transformation, reference, constraints = query.accept(self.visitor) self._validate_transform_basics(transformation, reference, query) @@ -601,7 +608,7 @@ def test_visit_drop_null_and_nan( expected_nan_cols: List[str], ) -> None: """Test generating transformations from a DropNullAndNan.""" - query = DropNullAndNan(PrivateSource("ids"), tuple(query_columns)) + query = DropNullAndNan(child=PrivateSource(source_id="ids"), columns=tuple(query_columns)) transformation, reference, constraints = query.accept(self.visitor) self._validate_transform_basics(transformation, reference, query) assert constraints == [] @@ -629,7 +636,7 @@ def test_visit_drop_infinity( self, query_columns: List[str], expected_inf_cols: List[str] ) -> None: """Test generating transformations from a DropInfinity.""" - query = DropInfinity(PrivateSource("ids"), tuple(query_columns)) + query = DropInfinity(child=PrivateSource(source_id="ids"), columns=tuple(query_columns)) transformation, reference, constraints = 
query.accept(self.visitor) self._validate_transform_basics(transformation, reference, query) assert constraints == [] diff --git a/test/unit/query_expr_compiler/transformation_visitor/test_add_rows.py b/test/unit/query_expr_compiler/transformation_visitor/test_add_rows.py index 31efba60..0c5abf8e 100644 --- a/test/unit/query_expr_compiler/transformation_visitor/test_add_rows.py +++ b/test/unit/query_expr_compiler/transformation_visitor/test_add_rows.py @@ -227,7 +227,7 @@ def test_visit_invalid_select(self) -> None: [ ( Map( - child=PrivateSource("rows1"), + child=PrivateSource(source_id="rows1"), f=lambda row: {"X": 2 * str(row["S"])}, schema_new_columns=Schema({"X": "VARCHAR"}), augment=True, @@ -239,7 +239,7 @@ def test_visit_invalid_select(self) -> None: ), ( Map( - child=PrivateSource("rows1"), + child=PrivateSource(source_id="rows1"), f=lambda row: {"X": 2 * str(row["S"])}, schema_new_columns=Schema({"X": "VARCHAR"}), augment=False, @@ -262,7 +262,7 @@ def test_visit_map(self, query: Map, expected_df: DataFrame) -> None: [ ( FlatMap( - child=PrivateSource("rows1"), + child=PrivateSource(source_id="rows1"), f=lambda row: [{"S_is_zero": 1 if row["S"] == "0" else 2}], schema_new_columns=Schema({"S_is_zero": "INTEGER"}), augment=True, @@ -275,7 +275,7 @@ def test_visit_map(self, query: Map, expected_df: DataFrame) -> None: ), ( FlatMap( - child=PrivateSource("rows1"), + child=PrivateSource(source_id="rows1"), f=lambda row: [{"i": n for n in range(row["I"] + 1)}], schema_new_columns=Schema({"i": "INTEGER"}), augment=False, @@ -285,7 +285,7 @@ def test_visit_map(self, query: Map, expected_df: DataFrame) -> None: ), ( FlatMap( - child=PrivateSource("rows1"), + child=PrivateSource(source_id="rows1"), f=lambda row: [{"i": n} for n in range(row["I"] + 10)], schema_new_columns=Schema({"i": "INTEGER"}), augment=False, @@ -311,7 +311,7 @@ def test_visit_flat_map_without_grouping( [ ( FlatMap( - child=PrivateSource("rows1"), + child=PrivateSource(source_id="rows1"), 
f=lambda row: [{"group": 0 if row["F"] == 0 else 17}], schema_new_columns=Schema( {"group": ColumnDescriptor(ColumnType.INTEGER)}, @@ -341,7 +341,7 @@ def test_visit_flat_map_with_grouping( def test_visit_flat_map_invalid(self) -> None: """Test visit_flat_map with invalid query.""" query = FlatMap( - child=PrivateSource("rows1"), + child=PrivateSource(source_id="rows1"), f=lambda row: [{"group": 0 if row["F"] == 0 else 17}], schema_new_columns=Schema( {"group": ColumnDescriptor(ColumnType.INTEGER)}, grouping_column="group" @@ -363,8 +363,8 @@ def test_visit_flat_map_invalid(self) -> None: [ ( JoinPrivate( - child=PrivateSource("rows1"), - right_operand_expr=PrivateSource("rows2"), + child=PrivateSource(source_id="rows1"), + right_operand_expr=PrivateSource(source_id="rows2"), truncation_strategy_left=TruncationStrategy.DropExcess(3), truncation_strategy_right=TruncationStrategy.DropExcess(10), ), @@ -375,8 +375,8 @@ def test_visit_flat_map_invalid(self) -> None: ), ( JoinPrivate( - child=PrivateSource("rows2"), - right_operand_expr=PrivateSource("rows1"), + child=PrivateSource(source_id="rows2"), + right_operand_expr=PrivateSource(source_id="rows1"), truncation_strategy_left=TruncationStrategy.DropExcess(3), truncation_strategy_right=TruncationStrategy.DropNonUnique(), join_columns=tuple(["I"]), @@ -428,8 +428,8 @@ class InvalidStrategy(TruncationStrategy.Type): """An invalid truncation strategy.""" query1 = JoinPrivate( - child=PrivateSource("rows1"), - right_operand_expr=PrivateSource("rows2"), + child=PrivateSource(source_id="rows1"), + right_operand_expr=PrivateSource(source_id="rows2"), truncation_strategy_left=InvalidStrategy(), truncation_strategy_right=TruncationStrategy.DropExcess(3), ) @@ -440,8 +440,8 @@ class InvalidStrategy(TruncationStrategy.Type): query1.accept(self.visitor) query2 = JoinPrivate( - child=PrivateSource("rows1"), - right_operand_expr=PrivateSource("rows2"), + child=PrivateSource(source_id="rows1"), + 
right_operand_expr=PrivateSource(source_id="rows2"), truncation_strategy_left=TruncationStrategy.DropExcess(2), truncation_strategy_right=InvalidStrategy(), ) @@ -449,8 +449,8 @@ class InvalidStrategy(TruncationStrategy.Type): query2.accept(self.visitor) query3 = JoinPrivate( - child=PrivateSource("rows1"), - right_operand_expr=PrivateSource("rows2"), + child=PrivateSource(source_id="rows1"), + right_operand_expr=PrivateSource(source_id="rows2"), truncation_strategy_left=None, truncation_strategy_right=None, ) @@ -602,7 +602,7 @@ def test_visit_replace_null_and_nan( def test_visit_replace_null_and_nan_with_grouping_column(self) -> None: """Test behavior of visit_replace_null_and_nan with IfGroupedBy metric.""" flatmap_query = FlatMap( - child=PrivateSource("rows_infs_nans"), + child=PrivateSource(source_id="rows_infs_nans"), f=lambda row: [{"group": 0 if row["inf"] < 0 else 17}], schema_new_columns=Schema( {"group": ColumnDescriptor(ColumnType.INTEGER, allow_null=True)}, @@ -689,7 +689,7 @@ def test_visit_replace_infinity( def test_visit_drop_null_and_nan_with_grouping_column(self) -> None: """Test behavior of visit_drop_null_and_nan with IfGroupedBy metric.""" flatmap_query = FlatMap( - child=PrivateSource("rows_infs_nans"), + child=PrivateSource(source_id="rows_infs_nans"), f=lambda row: [{"group": 0 if row["inf"] < 0 else 17}], schema_new_columns=Schema( {"group": ColumnDescriptor(ColumnType.INTEGER, allow_null=True)}, @@ -726,7 +726,7 @@ def test_visit_drop_null_and_nan_with_grouping_column(self) -> None: def test_visit_drop_infinity_with_grouping_column(self) -> None: """Test behavior of visit_drop_infinity with IfGroupedBy metric.""" flatmap_query = FlatMap( - child=PrivateSource("rows_infs_nans"), + child=PrivateSource(source_id="rows_infs_nans"), f=lambda row: [{"group": 0 if row["inf"] < 0 else 17}], schema_new_columns=Schema( {"group": ColumnDescriptor(ColumnType.INTEGER, allow_null=True)}, @@ -894,7 +894,7 @@ def test_visit_drop_null_and_nan( 
expected_nan_cols: List[str], ) -> None: """Test generating transformations from a DropNullAndNan.""" - query = DropNullAndNan(PrivateSource("rows"), tuple(query_columns)) + query = DropNullAndNan(child=PrivateSource(source_id="rows"), columns=tuple(query_columns)) transformation, reference, constraints = query.accept(self.visitor) self._validate_transform_basics(transformation, reference, query) assert constraints == [] @@ -921,7 +921,7 @@ def test_visit_drop_infinity( self, query_columns: List[str], expected_inf_cols: List[str] ) -> None: """Test generating transformations from a DropInfinity.""" - query = DropInfExpr(child=PrivateSource("rows"), columns=tuple(query_columns)) + query = DropInfExpr(child=PrivateSource(source_id="rows"), columns=tuple(query_columns)) transformation, reference, constraints = query.accept(self.visitor) self._validate_transform_basics(transformation, reference, query) assert constraints == [] diff --git a/test/unit/query_expr_compiler/transformation_visitor/test_constraints.py b/test/unit/query_expr_compiler/transformation_visitor/test_constraints.py index 37d7118b..3592ff68 100644 --- a/test/unit/query_expr_compiler/transformation_visitor/test_constraints.py +++ b/test/unit/query_expr_compiler/transformation_visitor/test_constraints.py @@ -47,7 +47,8 @@ def _test_is_subset(input_df: pd.DataFrame, result_df: pd.DataFrame): def test_max_rows_per_id(self, constraint_max: int): """Test truncation with MaxRowsPerID.""" constraint = MaxRowsPerID(constraint_max) - query = EnforceConstraint(PrivateSource("ids_duplicates"), constraint) + query = EnforceConstraint(child=PrivateSource(source_id="ids_duplicates"), + constraint=constraint) transformation, ref, constraints = query.accept(self.visitor) assert len(constraints) == 1 assert constraints[0] == constraint @@ -69,7 +70,7 @@ def test_max_rows_per_id(self, constraint_max: int): def test_max_groups_per_id(self, grouping_col: str, constraint_max: int): """Test truncation with 
MaxGroupsPerID.""" constraint = MaxGroupsPerID(grouping_col, constraint_max) - query = EnforceConstraint(PrivateSource("ids_duplicates"), constraint) + query = EnforceConstraint(child=PrivateSource(source_id="ids_duplicates"), constraint=constraint) transformation, ref, constraints = query.accept(self.visitor) assert len(constraints) == 1 assert constraints[0] == constraint @@ -92,7 +93,8 @@ def test_max_groups_per_id(self, grouping_col: str, constraint_max: int): def test_max_rows_per_group_per_id(self, constraint_max: int, grouping_col: str): """Test truncation with MaxRowsPerGroupPerID.""" constraint = MaxRowsPerGroupPerID(grouping_col, constraint_max) - query = EnforceConstraint(PrivateSource("ids_duplicates"), constraint) + query = EnforceConstraint(child=PrivateSource(source_id="ids_duplicates"), + constraint=constraint) transformation, ref, constraints = query.accept(self.visitor) assert len(constraints) == 1 assert constraints[0] == constraint @@ -116,8 +118,8 @@ def test_l1_update_metric(self, constraint_max: int): """Test L1 truncation with updating metric.""" constraint = MaxRowsPerID(constraint_max) query = EnforceConstraint( - PrivateSource("ids_duplicates"), - constraint, + child=PrivateSource(source_id="ids_duplicates"), + constraint=constraint, options=FrozenDict.from_dict({"update_metric": True}), ) transformation, ref, constraints = query.accept(self.visitor) @@ -147,12 +149,12 @@ def test_l0_linf_update_metric( ): """Test L0 + L-inf truncation with updating metric.""" query = EnforceConstraint( - EnforceConstraint( - PrivateSource("ids_duplicates"), - MaxGroupsPerID(grouping_col, group_max), + child=EnforceConstraint( + child=PrivateSource(source_id="ids_duplicates"), + constraint=MaxGroupsPerID(grouping_col, group_max), options=FrozenDict.from_dict({"update_metric": True}), ), - MaxRowsPerGroupPerID(grouping_col, row_max), + constraint=MaxRowsPerGroupPerID(grouping_col, row_max), options=FrozenDict.from_dict({"update_metric": True}), ) 
transformation, ref, constraints = query.accept(self.visitor) diff --git a/test/unit/test_query_expr_compiler.py b/test/unit/test_query_expr_compiler.py index 9d2d5d8b..27bfabf0 100644 --- a/test/unit/test_query_expr_compiler.py +++ b/test/unit/test_query_expr_compiler.py @@ -85,7 +85,7 @@ QUERY_EXPR_COMPILER_TESTS = [ ( # Total GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), output_column="total", ), @@ -93,7 +93,7 @@ ), ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), output_column="total", ), @@ -101,7 +101,7 @@ ), ( # Full marginal from domain description GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"], "B": [0, 1]}), ), pd.DataFrame( @@ -110,21 +110,21 @@ ), ( # Incomplete two-column marginal with a dataframe GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dataframe(GET_GROUPBY_TWO()), ), pd.DataFrame({"A": ["0", "0", "1"], "B": [0, 1, 1], "count": [2, 1, 0]}), ), ( # One-column marginal with additional value GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict(GROUPBY_ONE_DICT), ), pd.DataFrame({"A": ["0", "1", "2"], "count": [3, 1, 0]}), ), ( # BoundedAverage GroupByBoundedAverage( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0.0, @@ -134,7 +134,7 @@ ), ( # BoundedSTDEV GroupByBoundedSTDEV( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0.0, @@ -144,7 +144,7 @@ ), ( # BoundedVariance GroupByBoundedVariance( - child=PrivateSource("private"), + 
child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0.0, @@ -155,7 +155,7 @@ ), ( # BoundedSum GroupByBoundedSum( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0.0, @@ -166,14 +166,14 @@ ), ( # Marginal over A GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), ), pd.DataFrame({"A": ["0", "1"], "count": [3, 1]}), ), ( # Marginal over B GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"B": [0, 1]}), ), pd.DataFrame({"B": [0, 1], "count": [3, 1]}), @@ -183,7 +183,7 @@ child=ReplaceNullAndNan( replace_with=FrozenDict.from_dict({}), child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{}, {}], schema_new_columns=Schema({}), augment=True, @@ -203,7 +203,7 @@ replace_with=FrozenDict.from_dict({}), child=FlatMap( child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"Repeat": 1 if row["A"] == "0" else 2}], schema_new_columns=Schema({"Repeat": "INTEGER"}), augment=True, @@ -228,7 +228,7 @@ replace_with=FrozenDict.from_dict({}), child=FlatMap( child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"Repeat": 1 if row["A"] == "0" else 2}], schema_new_columns=Schema( {"Repeat": "INTEGER"}, grouping_column="Repeat" @@ -251,7 +251,7 @@ ), ( # Filter GroupByCount( - child=Filter(child=PrivateSource("private"), condition="A == '0'"), + child=Filter(child=PrivateSource(source_id="private"), condition="A == '0'"), groupby_keys=KeySet.from_dict({}), ), pd.DataFrame({"count": [3]}), @@ -259,7 +259,7 @@ ( # Rename GroupByCount( child=Rename( - child=PrivateSource("private"), + 
child=PrivateSource(source_id="private"), column_mapper=FrozenDict.from_dict({"A": "Z"}), ), groupby_keys=KeySet.from_dict({"Z": ["0", "1"]}), @@ -268,7 +268,7 @@ ), ( # Select GroupByCount( - child=Select(child=PrivateSource("private"), columns=tuple(["A"])), + child=Select(child=PrivateSource(source_id="private"), columns=tuple(["A"])), groupby_keys=KeySet.from_dict({}), ), pd.DataFrame({"count": [4]}), @@ -278,7 +278,7 @@ child=ReplaceNullAndNan( replace_with=FrozenDict.from_dict({}), child=Map( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: {"C": 2 * str(row["B"])}, schema_new_columns=Schema({"C": "VARCHAR"}), augment=True, @@ -293,7 +293,7 @@ ), ( # JoinPublic GroupByCount( - child=JoinPublic(child=PrivateSource("private"), public_table="public"), + child=JoinPublic(child=PrivateSource(source_id="private"), public_table="public"), groupby_keys=KeySet.from_dict({"A+B": [0, 1, 2]}), ), pd.DataFrame({"A+B": [0, 1, 2], "count": [2, 2, 0]}), @@ -301,7 +301,7 @@ ( # JoinPublic with One Join Column GroupByCount( child=JoinPublic( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), public_table="public", join_columns=tuple(["A"]), ), @@ -312,84 +312,89 @@ # Tests on less-common data types ( GroupByCount( - JoinPublic(PrivateSource("private"), "dtypes"), - KeySet.from_dict({"A": ["0", "1"]}), + child=JoinPublic(child=PrivateSource(source_id="private"), public_table="dtypes"), + groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), ), pd.DataFrame({"A": ["0", "1"], "count": [3, 1]}), ), ( GroupByCount( - Filter( - JoinPublic(PrivateSource("private"), "dtypes"), - "date < '2022-01-02'", + child=Filter( + child=JoinPublic(child=PrivateSource(source_id="private"), + public_table="dtypes"), + condition="date < '2022-01-02'", ), - KeySet.from_dict({"A": ["0", "1"]}), + groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), ), pd.DataFrame({"A": ["0", "1"], "count": [3, 0]}), ), ( GroupByCount( - Filter( - 
JoinPublic(PrivateSource("private"), "dtypes"), - "date = '2022-01-02'", + child=Filter( + child=JoinPublic(child=PrivateSource(source_id="private"), + public_table="dtypes"), + condition="date = '2022-01-02'", ), - KeySet.from_dict({"A": ["0", "1"]}), + groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), ), pd.DataFrame({"A": ["0", "1"], "count": [0, 1]}), ), ( GroupByCount( - Filter( - JoinPublic(PrivateSource("private"), "dtypes"), - "timestamp < '2022-01-01T12:40:00'", + child=Filter( + child=JoinPublic(child=PrivateSource(source_id="private"), + public_table="dtypes"), + condition="timestamp < '2022-01-01T12:40:00'", ), - KeySet.from_dict({"A": ["0", "1"]}), + groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), ), pd.DataFrame({"A": ["0", "1"], "count": [3, 0]}), ), ( GroupByCount( - Filter( - JoinPublic(PrivateSource("private"), "dtypes"), - "timestamp >= '2022-01-01T12:45:00'", + child=Filter( + child=JoinPublic(child=PrivateSource(source_id="private"), public_table="dtypes"), + condition="timestamp >= '2022-01-01T12:45:00'", ), - KeySet.from_dict({"A": ["0", "1"]}), + groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), ), pd.DataFrame({"A": ["0", "1"], "count": [0, 1]}), ), ( GroupByBoundedSum( - ReplaceNullAndNan( + child=ReplaceNullAndNan( replace_with=FrozenDict.from_dict({}), child=Map( - JoinPublic(PrivateSource("private"), "dtypes"), - lambda row: {"day": row["date"].day}, - Schema({"day": ColumnDescriptor(ColumnType.INTEGER)}), + child=JoinPublic(child=PrivateSource(source_id="private"), + public_table="dtypes"), + f=lambda row: {"day": row["date"].day}, + schema_new_columns=Schema({"day": ColumnDescriptor(ColumnType.INTEGER)}), augment=True, ), ), - KeySet.from_dict({"A": ["0", "1"]}), - "day", - 0, - 2, + groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), + measure_column="day", + low=0, + high=2, ), pd.DataFrame({"A": ["0", "1"], "sum": [3, 2]}), ), ( GroupByBoundedSum( - ReplaceNullAndNan( + child=ReplaceNullAndNan( 
replace_with=FrozenDict.from_dict({}), child=Map( - JoinPublic(PrivateSource("private"), "dtypes"), - lambda row: {"minute": row["timestamp"].minute}, - Schema({"minute": ColumnDescriptor(ColumnType.INTEGER)}), + child=JoinPublic(child=PrivateSource(source_id="private"), + public_table="dtypes"), + f=lambda row: {"minute": row["timestamp"].minute}, + schema_new_columns=Schema({"minute": ColumnDescriptor(ColumnType.INTEGER)}), augment=True, ), ), - KeySet.from_dict({"A": ["0", "1"]}), - "minute", - 0, - 59, + groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), + measure_column="minute", + low=0, + high=59, ), pd.DataFrame({"A": ["0", "1"], "sum": [90, 45]}), ), @@ -575,7 +580,7 @@ class TestQueryExprCompiler: [ ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), output_column="total", ), @@ -583,7 +588,7 @@ class TestQueryExprCompiler: ), ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), output_column="distinct", columns_to_count=tuple(["B"]), @@ -592,14 +597,14 @@ class TestQueryExprCompiler: ), ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), ), pd.DataFrame([["0", 3], ["1", 1]], columns=["A", "count_distinct"]), ), ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), columns_to_count=tuple(["B"]), ), @@ -607,14 +612,14 @@ class TestQueryExprCompiler: ), ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict(GROUPBY_ONE_DICT), ), pd.DataFrame({"A": ["0", "1", "2"], "count_distinct": [3, 1, 0]}), ), ( GroupByCountDistinct( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), 
groupby_keys=KeySet.from_dict(GROUPBY_ONE_DICT), columns_to_count=tuple(["B"]), ), @@ -684,7 +689,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): [ ( # Total with LAPLACE (Geometric noise gets applied) GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), output_column="total", mechanism=CountMechanism.LAPLACE, @@ -694,7 +699,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ), ( # Total with GAUSSIAN GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), output_column="total", mechanism=CountMechanism.GAUSSIAN, @@ -704,7 +709,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ), ( # BoundedAverage on floating-point valued measure column with LAPLACE GroupByBoundedAverage( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0.0, @@ -716,7 +721,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ), ( # BoundedAverage with integer valued measure column with LAPLACE GroupByBoundedAverage( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="B", low=0, @@ -728,7 +733,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ), ( # BoundedAverage with integer valued measure column with GAUSSIAN GroupByBoundedAverage( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="B", low=0, @@ -740,7 +745,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ), ( # BoundedSTDEV on floating-point valued measure column with LAPLACE GroupByBoundedSTDEV( - child=PrivateSource("private"), + 
child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0.0, @@ -752,7 +757,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ), ( # BoundedSTDEV on integer valued measure column with LAPLACE GroupByBoundedSTDEV( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="B", low=0, @@ -764,7 +769,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ), ( # BoundedSTDEV on integer valued measure column with GAUSSIAN GroupByBoundedSTDEV( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="B", low=0, @@ -776,7 +781,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ), ( # BoundedVariance on floating-point valued measure column with LAPLACE GroupByBoundedVariance( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0.0, @@ -789,7 +794,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ), ( # BoundedVariance on integer valued measure column with LAPLACE GroupByBoundedVariance( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="B", low=0, @@ -802,7 +807,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ), ( # BoundedVariance on integer valued measure column with GAUSSIAN GroupByBoundedVariance( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="B", low=0, @@ -815,7 +820,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ), ( # BoundedSum on floating-point valued measure column with LAPLACE GroupByBoundedSum( - 
child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0.0, @@ -828,7 +833,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ), ( # BoundedSum on integer valued measure column with LAPLACE GroupByBoundedSum( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="B", low=0, @@ -841,7 +846,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ), ( # BoundedSum on integer valued measure column with GAUSSIAN GroupByBoundedSum( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="B", low=0, @@ -858,7 +863,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): replace_with=FrozenDict.from_dict({}), child=FlatMap( child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"Repeat": 1 if row["A"] == "0" else 2}], schema_new_columns=Schema( {"Repeat": "INTEGER"}, grouping_column="Repeat" @@ -886,7 +891,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ( # BoundedAverage with floating-point valued measure column with GAUSSIAN [ GroupByBoundedAverage( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0.0, @@ -900,7 +905,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ( # BoundedSTDEV on floating-point valued measure column with GAUSSIAN [ GroupByBoundedSTDEV( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0.0, @@ -914,7 +919,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ( # BoundedVariance on 
floating-point valued measure column with GAUSSIAN [ GroupByBoundedVariance( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0.0, @@ -929,7 +934,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): ( # BoundedSum on floating-point valued measure column with GAUSSIAN [ GroupByBoundedSum( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0.0, @@ -948,7 +953,7 @@ def test_queries(self, query_expr: QueryExpr, expected: pd.DataFrame): replace_with=FrozenDict.from_dict({}), child=FlatMap( child=FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [ {"Repeat": 1 if row["A"] == "0" else 2} ], @@ -1015,7 +1020,7 @@ def test_join_public_dataframe(self, spark): ).fillna(0) transformation, reference, _constraints = self.compiler.build_transformation( - JoinPublic(PrivateSource("private"), public_sdf), + JoinPublic(child=PrivateSource(source_id="private"), public_table=public_sdf), input_domain=self.input_domain, input_metric=self.input_metric, public_sources={}, @@ -1048,8 +1053,8 @@ def test_join_private(self, spark): ) transformation, reference, _constraints = self.compiler.build_transformation( JoinPrivate( - child=PrivateSource("private"), - right_operand_expr=PrivateSource("private_2"), + child=PrivateSource(source_id="private"), + right_operand_expr=PrivateSource(source_id="private_2"), truncation_strategy_left=TruncationStrategy.DropExcess(3), truncation_strategy_right=TruncationStrategy.DropExcess(3), ), @@ -1087,8 +1092,8 @@ def test_join_private(self, spark): [ ( JoinPrivate( - child=PrivateSource("private"), - right_operand_expr=PrivateSource("private_2"), + child=PrivateSource(source_id="private"), + right_operand_expr=PrivateSource(source_id="private_2"), 
truncation_strategy_left=TruncationStrategy.DropExcess(3), truncation_strategy_right=TruncationStrategy.DropExcess(3), ), @@ -1096,8 +1101,8 @@ def test_join_private(self, spark): ), ( JoinPrivate( - child=PrivateSource("private"), - right_operand_expr=PrivateSource("private_2"), + child=PrivateSource(source_id="private"), + right_operand_expr=PrivateSource(source_id="private_2"), truncation_strategy_left=TruncationStrategy.DropExcess(3), truncation_strategy_right=TruncationStrategy.DropExcess(1), ), @@ -1105,8 +1110,8 @@ def test_join_private(self, spark): ), ( JoinPrivate( - child=PrivateSource("private"), - right_operand_expr=PrivateSource("private_2"), + child=PrivateSource(source_id="private"), + right_operand_expr=PrivateSource(source_id="private_2"), truncation_strategy_left=TruncationStrategy.DropExcess(1), truncation_strategy_right=TruncationStrategy.DropExcess(1), ), @@ -1114,8 +1119,8 @@ def test_join_private(self, spark): ), ( JoinPrivate( - child=PrivateSource("private"), - right_operand_expr=PrivateSource("private_2"), + child=PrivateSource(source_id="private"), + right_operand_expr=PrivateSource(source_id="private_2"), truncation_strategy_left=TruncationStrategy.DropExcess(3), truncation_strategy_right=TruncationStrategy.DropNonUnique(), ), @@ -1123,8 +1128,8 @@ def test_join_private(self, spark): ), ( JoinPrivate( - child=PrivateSource("private"), - right_operand_expr=PrivateSource("private_2"), + child=PrivateSource(source_id="private"), + right_operand_expr=PrivateSource(source_id="private_2"), truncation_strategy_left=TruncationStrategy.DropNonUnique(), truncation_strategy_right=TruncationStrategy.DropNonUnique(), ), @@ -1156,8 +1161,8 @@ class Strategy(TruncationStrategy.Type): """An invalid truncation strategy.""" query = JoinPrivate( - child=PrivateSource("private"), - right_operand_expr=PrivateSource("private_2"), + child=PrivateSource(source_id="private"), + right_operand_expr=PrivateSource(source_id="private_2"), 
truncation_strategy_left=Strategy(), truncation_strategy_right=Strategy(), ) @@ -1179,7 +1184,7 @@ class Strategy(TruncationStrategy.Type): [ ( FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda _: [{"G": "a"}, {"G": "b"}], schema_new_columns=Schema({"G": "VARCHAR"}, grouping_column="G"), augment=True, @@ -1190,7 +1195,7 @@ class Strategy(TruncationStrategy.Type): ), ( FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda _: [{"G": "a"}, {"G": "b"}], schema_new_columns=Schema({"G": "VARCHAR"}), augment=True, @@ -1229,7 +1234,7 @@ def test_float_groupby_sum(self, spark): ) ) query_expr = GroupByBoundedSum( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), measure_column="X", low=0.0, @@ -1257,7 +1262,7 @@ def test_float_groupby_sum(self, spark): ( # Top-level query needs to be instance of measurement QueryExpr FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{}, {}], schema_new_columns=Schema({}), augment=True, @@ -1267,7 +1272,7 @@ def test_float_groupby_sum(self, spark): ( # Query's child has to be transformation QueryExpr GroupByBoundedSum( child=GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({"A": ["0", "1"], "B": [0, 1]}), ), groupby_keys=KeySet.from_dict({}), @@ -1297,13 +1302,13 @@ def test_invalid_queries(self, query_expr: QueryExpr): [ ( GroupByCount( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), groupby_keys=KeySet.from_dict({}), ) ), ( GroupByCount( - child=PrivateSource("doubled"), + child=PrivateSource(source_id="doubled"), groupby_keys=KeySet.from_dict({}), ) ), @@ -1406,7 +1411,7 @@ def test_compile_groupby_quantile( ) query_expr = GroupByQuantile( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), 
groupby_keys=KeySet.from_dict({"Gender": ["M", "F"]}), measure_column="Age", quantile=0.5, diff --git a/test/unit/test_query_expression.py b/test/unit/test_query_expression.py index 09fcca4b..b8ac62e6 100644 --- a/test/unit/test_query_expression.py +++ b/test/unit/test_query_expression.py @@ -74,7 +74,7 @@ def test_invalid_private_source( ): """Tests that invalid private source errors on post-init.""" with pytest.raises(exception_type, match=expected_error_msg): - PrivateSource(invalid_source_id) + PrivateSource(source_id=invalid_source_id) @pytest.mark.parametrize( @@ -87,7 +87,7 @@ def test_invalid_private_source( def test_invalid_rename(column_mapper: Dict[str, str]): """Tests that invalid Rename errors on post-init.""" with pytest.raises(TypeCheckError): - Rename(PrivateSource("private"), FrozenDict.from_dict(column_mapper)) + Rename(child=PrivateSource(source_id="private"), column_mapper=FrozenDict.from_dict(column_mapper)) def test_invalid_rename_empty_string(): @@ -99,13 +99,13 @@ def test_invalid_rename_empty_string(): " are not allowed" ), ): - Rename(PrivateSource("private"), FrozenDict.from_dict({"A": ""})) + Rename(child=PrivateSource(source_id="private"), column_mapper=FrozenDict.from_dict({"A": ""})) def test_invalid_filter(): """Tests that invalid Filter errors on post-init.""" with pytest.raises(TypeCheckError): - Filter(PrivateSource("private"), 0) # type: ignore + Filter(child=PrivateSource(source_id="private"), condition=0) # type: ignore @pytest.mark.parametrize( @@ -121,7 +121,7 @@ def test_invalid_select( ): """Tests that invalid Rename errors on post-init.""" with pytest.raises((ValueError, TypeCheckError)): - Select(PrivateSource("private"), columns) + Select(child=PrivateSource(source_id="private"), columns=columns) @pytest.mark.parametrize( @@ -152,14 +152,15 @@ def test_invalid_map( ): """Tests that invalid Map errors on post-init.""" with pytest.raises((TypeCheckError, ValueError), match=expected_error_msg): - 
Map(PrivateSource("private"), func, schema_new_columns, augment) + Map(child=PrivateSource(source_id="private"), f=func, schema_new_columns=schema_new_columns, + augment=augment) @pytest.mark.parametrize( "child,func,max_rows,schema_new_columns,augment,expected_error_msg", [ ( # Invalid max_rows - PrivateSource("private"), + PrivateSource(source_id="private"), lambda row: [{"i": row["X"]} for i in range(row["Repeat"])], -1, Schema({"i": "INTEGER"}), @@ -168,7 +169,7 @@ def test_invalid_map( ), ( # Invalid augment FlatMap( - child=PrivateSource("private"), + child=PrivateSource(source_id="private"), f=lambda row: [{"Repeat": 1 if row["A"] == "0" else 2}], schema_new_columns=Schema({"Repeat": "INTEGER"}), augment=True, @@ -181,7 +182,7 @@ def test_invalid_map( None, ), ( # Invalid grouping result - PrivateSource("private"), + PrivateSource(source_id="private"), lambda row: [{"i": row["X"]} for i in range(row["Repeat"])], 2, Schema({"i": "INTEGER", "j": "INTEGER"}, grouping_column="i"), @@ -203,7 +204,8 @@ def test_invalid_flatmap( ): """Tests that invalid FlatMap errors on post-init.""" with pytest.raises((TypeCheckError, ValueError), match=expected_error_msg): - FlatMap(child, func, schema_new_columns, augment, max_rows) + FlatMap(child=child, f=func, schema_new_columns=schema_new_columns, + augment=augment, max_rows=max_rows) @pytest.mark.parametrize( @@ -219,8 +221,8 @@ def test_invalid_flat_map_by_id( """FlatMapByID raises an exception when given invalid parameters.""" with pytest.raises(expected_exc): FlatMapByID( - PrivateSource("private"), - lambda rows: rows, + child=PrivateSource(source_id="private"), + f=lambda rows: rows, schema_new_columns=schema_new_columns, ) @@ -236,14 +238,15 @@ def test_invalid_join_columns(join_columns: List[str], expected_error_msg: str): """Tests that JoinPrivate, JoinPublic error with invalid join columns.""" with pytest.raises(ValueError, match=expected_error_msg): JoinPrivate( - PrivateSource("private"), - 
PrivateSource("private2"), - TruncationStrategy.DropExcess(1), - TruncationStrategy.DropExcess(1), - tuple(join_columns), + child=PrivateSource(source_id="private"), + right_operand_expr=PrivateSource(source_id="private2"), + truncation_strategy_left=TruncationStrategy.DropExcess(1), + truncation_strategy_right=TruncationStrategy.DropExcess(1), + join_columns=tuple(join_columns), ) with pytest.raises(ValueError, match=expected_error_msg): - JoinPublic(PrivateSource("private"), "public", tuple(join_columns)) + JoinPublic(child=PrivateSource(source_id="private"), + public_table="public", join_columns=tuple(join_columns)) def test_invalid_how(): @@ -251,7 +254,8 @@ def test_invalid_how(): with pytest.raises( ValueError, match="Invalid join type 'invalid': must be 'inner' or 'left'" ): - JoinPublic(PrivateSource("private"), "public", tuple("A"), how="invalid") + JoinPublic(child=PrivateSource(source_id="private"), public_table="public", + join_columns=tuple("A"), how="invalid") @pytest.mark.parametrize( @@ -269,7 +273,8 @@ def test_invalid_how(): def test_invalid_replace_infinity(replace_with: Any, expected_error_msg: str) -> None: """Test ReplaceInfinity with invalid arguments.""" with pytest.raises((TypeCheckError), match=expected_error_msg): - QueryBuilder("private").replace_infinity(replace_with) + ReplaceInfinity(child=PrivateSource(source_id="private"), + replace_with=FrozenDict.from_dict(replace_with)), @pytest.mark.parametrize( @@ -283,7 +288,7 @@ def test_invalid_replace_infinity(replace_with: Any, expected_error_msg: str) -> def test_invalid_drop_null_and_nan(columns: Any) -> None: """Test DropNullAndNan with invalid arguments.""" with pytest.raises(TypeCheckError): - DropNullAndNan(PrivateSource("private"), columns) + DropNullAndNan(child=PrivateSource(source_id="private"), columns=columns) @pytest.mark.parametrize( @@ -297,14 +302,14 @@ def test_invalid_drop_null_and_nan(columns: Any) -> None: def test_invalid_drop_infinity(columns: Any) -> None: """Test 
DropInfinity with invalid arguments.""" with pytest.raises(TypeCheckError): - DropInfinity(PrivateSource("private"), columns) + DropInfinity(child=PrivateSource(source_id="private"), columns=columns) @pytest.mark.parametrize( "child,keys,output_column", [ ( - PrivateSource("private"), + PrivateSource(source_id="private"), KeySet.from_dict({}), 123, ) @@ -313,7 +318,7 @@ def test_invalid_drop_infinity(columns: Any) -> None: def test_invalid_groupbycount(child: QueryExpr, keys: KeySet, output_column: str): """Tests that invalid GroupByCount errors on post-init.""" with pytest.raises(TypeCheckError): - GroupByCount(child, keys, output_column) + GroupByCount(child=child, groupby_keys=keys, output_column=output_column) @pytest.mark.parametrize( @@ -353,7 +358,8 @@ def test_invalid_groupbyagg( GroupByBoundedSTDEV, ]: with pytest.raises((TypeCheckError, ValueError), match=expected_error_msg): - DataClass(PrivateSource("private"), keys, measure_column, low, high) + DataClass(child=PrivateSource(source_id="private"), groupby_keys=keys, + measure_column=measure_column, low=low, high=high) @pytest.mark.parametrize( @@ -420,7 +426,8 @@ def test_invalid_groupbyquantile( """Test invalid GroupByQuantile.""" with pytest.raises((TypeCheckError, ValueError), match=expected_error_msg): GroupByQuantile( - PrivateSource("private"), keys, measure_column, quantile, low, high + child=PrivateSource(source_id="private"), groupby_keys=keys, + measure_column=measure_column, quantile=quantile, low=low, high=high ) @@ -430,7 +437,7 @@ def test_invalid_groupbyquantile( @pytest.mark.parametrize("source_id", ["private_source", "_Private", "no_space2"]) def test_valid_private_source(source_id: str): """Tests valid private source does not error.""" - PrivateSource(source_id) + PrivateSource(source_id=source_id) @pytest.mark.parametrize("low,high", [(8.0, 10.0), (1, 10), (1.0, 10)]) @@ -443,11 +450,11 @@ def test_clamping_bounds_casting(low: float, high: float): GroupByBoundedSTDEV, ]: query = 
DataClass( - PrivateSource("private"), - KeySet.from_dict({"A": ["0", "1"]}), - "B", - low, - high, + child=PrivateSource(source_id="private"), + groupby_keys=KeySet.from_dict({"A": ["0", "1"]}), + measure_column="B", + low=low, + high=high, ) assert isinstance( query, @@ -464,10 +471,10 @@ def test_clamping_bounds_casting(low: float, high: float): @pytest.mark.parametrize( "child,replace_with", [ - (PrivateSource("private"), {"col": "value", "col2": "value2"}), - (PrivateSource("private"), {}), + (PrivateSource(source_id="private"), {"col": "value", "col2": "value2"}), + (PrivateSource(source_id="private"), {}), ( - PrivateSource("private"), + PrivateSource(source_id="private"), { "A": 1, "B": 2.0, @@ -485,22 +492,22 @@ def test_valid_replace_null_and_nan( ], ): """Test ReplaceNullAndNan creation with valid values.""" - ReplaceNullAndNan(child, FrozenDict.from_dict(replace_with)) + ReplaceNullAndNan(child=child, replace_with=FrozenDict.from_dict(replace_with)) @pytest.mark.parametrize( "child,replace_with", [ - (PrivateSource("private"), {}), - (PrivateSource("private"), {"A": (-100.0, 100.0)}), - (PrivateSource("private"), {"A": (-999.9, 999.9), "B": (123.45, 678.90)}), + (PrivateSource(source_id="private"), {}), + (PrivateSource(source_id="private"), {"A": (-100.0, 100.0)}), + (PrivateSource(source_id="private"), {"A": (-999.9, 999.9), "B": (123.45, 678.90)}), ], ) def test_valid_replace_infinity( child: QueryExpr, replace_with: Dict[str, Tuple[float, float]] ) -> None: """Test ReplaceInfinity with valid values.""" - query = ReplaceInfinity(child, FrozenDict.from_dict(replace_with)) + query = ReplaceInfinity(child=child, replace_with=FrozenDict.from_dict(replace_with)) for v in query.replace_with.values(): # Check that values got converted to floats assert len(v) == 2 @@ -511,27 +518,27 @@ def test_valid_replace_infinity( @pytest.mark.parametrize( "child,columns", [ - (PrivateSource("private"), []), - (PrivateSource("private"), ["A"]), - 
(PrivateSource("different_private_source"), ["A", "B"]), + (PrivateSource(source_id="private"), []), + (PrivateSource(source_id="private"), ["A"]), + (PrivateSource(source_id="different_private_source"), ["A", "B"]), ], ) def test_valid_drop_null_and_nan(child: QueryExpr, columns: List[str]) -> None: """Test DropNullAndNan with valid values.""" - DropInfinity(child, tuple(columns)) + DropInfinity(child=child, columns=tuple(columns)) @pytest.mark.parametrize( "child,columns", [ - (PrivateSource("private"), []), - (PrivateSource("private"), ["A"]), - (PrivateSource("different_private_source"), ["A", "B"]), + (PrivateSource(source_id="private"), []), + (PrivateSource(source_id="private"), ["A"]), + (PrivateSource(source_id="different_private_source"), ["A", "B"]), ], ) def test_valid_drop_infinity(child: QueryExpr, columns: List[str]) -> None: """Test DropInfinity with valid values.""" - DropInfinity(child, tuple(columns)) + DropInfinity(child=child, columns=tuple(columns)) """Tests for JoinPublic with a Spark DataFrame as the public table.""" @@ -540,7 +547,7 @@ def test_valid_drop_infinity(child: QueryExpr, columns: List[str]) -> None: def test_join_public_string_nan(spark): """Test that the string "NaN" is allowed in string-valued columns.""" df = spark.createDataFrame(pd.DataFrame({"col": ["nan", "NaN", "NAN", "Nan"]})) - query_expr = JoinPublic(PrivateSource("a"), df) + query_expr = JoinPublic(child=PrivateSource(source_id="a"), public_table=df) assert isinstance(query_expr.public_table, DataFrame) assert_frame_equal_with_sort(query_expr.public_table.toPandas(), df.toPandas()) @@ -552,26 +559,26 @@ def test_join_public_dataframe_validation_column_type(spark): df = spark.createDataFrame(data, schema) with pytest.raises(ValueError, match="^Unsupported Spark data type.*"): - JoinPublic(PrivateSource("a"), df) + JoinPublic(child=PrivateSource(source_id="a"), public_table=df) @pytest.mark.parametrize( "child,column,threshold,expected_error_msg", [ ( - 
PrivateSource("P"), + PrivateSource(source_id="P"), "count", 0, "SuppressAggregates is only supported on aggregates that are GroupByCounts", ), ( - GroupByCount(PrivateSource("P"), KeySet.from_dict({})), + GroupByCount(child=PrivateSource(source_id="P"), groupby_keys=KeySet.from_dict({})), -17, 0, None, ), ( - GroupByCount(PrivateSource("P"), KeySet.from_dict({})), + GroupByCount(child=PrivateSource(source_id="P"), groupby_keys=KeySet.from_dict({})), "count", "not an int", None, @@ -587,7 +594,7 @@ def test_invalid_suppress_aggregates( ) -> None: """Test that SuppressAggregates rejects invalid arguments.""" with pytest.raises((TypeError, TypeCheckError), match=expected_error_msg): - SuppressAggregates(child, column, threshold) + SuppressAggregates(child=child, column=column, threshold=threshold) @pytest.mark.parametrize( diff --git a/test/unit/test_query_expression_visitor.py b/test/unit/test_query_expression_visitor.py index b48ee30f..6a7d2bbc 100644 --- a/test/unit/test_query_expression_visitor.py +++ b/test/unit/test_query_expression_visitor.py @@ -112,90 +112,106 @@ def visit_groupby_bounded_stdev(self, expr): def visit_suppress_aggregates(self, expr): return "SuppressAggregates" - @pytest.mark.parametrize( "expr,expected", [ - (PrivateSource("P"), "PrivateSource"), - (Rename(PrivateSource("P"), FrozenDict.from_dict({"A": "B"})), "Rename"), - (Filter(PrivateSource("P"), "A Date: Wed, 15 Oct 2025 09:08:23 +0200 Subject: [PATCH 2/2] add new rule rewriting file --- .../_query_expr_compiler/_rewrite_rules.py | 201 ++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py diff --git a/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py b/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py new file mode 100644 index 00000000..2f377897 --- /dev/null +++ b/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py @@ -0,0 +1,201 @@ +"""Rules for rewriting QueryExprs.""" + +# 
# SPDX-License-Identifier: Apache-2.0

from dataclasses import replace
from functools import wraps
from typing import Callable

from tmlt.core.measurements.aggregations import NoiseMechanism
from tmlt.core.measures import ApproxDP, PureDP, RhoZCDP

from tmlt.analytics import AnalyticsInternalError
from tmlt.analytics._query_expr import (
    AverageMechanism,
    CompilationInfo,
    CountDistinctMechanism,
    CountMechanism,
    DropInfinity,
    DropNullAndNan,
    EnforceConstraint,
    Filter,
    FlatMap,
    FlatMapByID,
    GetBounds,
    GetGroups,
    GroupByBoundedAverage,
    GroupByBoundedSTDEV,
    GroupByBoundedSum,
    GroupByBoundedVariance,
    GroupByCount,
    GroupByCountDistinct,
    GroupByQuantile,
    JoinPrivate,
    JoinPublic,
    Map,
    PrivateSource,
    QueryExpr,
    Rename,
    ReplaceInfinity,
    ReplaceNullAndNan,
    Select,
    StdevMechanism,
    SumMechanism,
    SuppressAggregates,
    VarianceMechanism,
)
from tmlt.analytics._query_expr_compiler._output_schema_visitor import (
    OutputSchemaVisitor,
)
from tmlt.analytics._schema import ColumnType

# QueryExpr subtypes whose only sub-expression is their ``child`` attribute.
# JoinPrivate (two private children) and PrivateSource (leaf) are handled
# separately in depth_first.
EXPRS_WITH_ONE_CHILD = (
    DropInfinity,
    DropNullAndNan,
    EnforceConstraint,
    Filter,
    FlatMap,
    FlatMapByID,
    GetBounds,
    GetGroups,
    GroupByBoundedAverage,
    GroupByBoundedSTDEV,
    GroupByBoundedSum,
    GroupByBoundedVariance,
    GroupByCount,
    GroupByCountDistinct,
    GroupByQuantile,
    JoinPublic,
    Map,
    Rename,
    ReplaceInfinity,
    ReplaceNullAndNan,
    Select,
    SuppressAggregates,
)


def depth_first(
    func: Callable[[QueryExpr], QueryExpr]
) -> Callable[[QueryExpr], QueryExpr]:
    """Recursively applies the given method to a QueryExpr, depth-first.

    Children are rewritten before their parents, so ``func`` always sees a
    node whose subtree has already been rewritten.

    Args:
        func: A rewrite rule applied to a single node.

    Returns:
        A function applying ``func`` to every node of a QueryExpr tree.

    Raises:
        AnalyticsInternalError: If the tree contains a QueryExpr subtype that
            is neither a leaf, a single-child expression, nor a JoinPrivate.
    """

    @wraps(func)
    def wrapped(expr: QueryExpr) -> QueryExpr:
        if isinstance(expr, PrivateSource):
            # Leaf node: nothing below it to rewrite.
            return func(expr)
        if isinstance(expr, EXPRS_WITH_ONE_CHILD):
            child = wrapped(expr.child)
            return func(replace(expr, child=child))
        if isinstance(expr, JoinPrivate):
            # JoinPrivate is the only expression with two private children.
            left = wrapped(expr.child)
            right = wrapped(expr.right_operand_expr)
            return func(replace(expr, child=left, right_operand_expr=right))
        raise AnalyticsInternalError(
            f"Unrecognized QueryExpr subtype {type(expr).__qualname__}."
        )

    return wrapped


def add_compilation_info(
    info: CompilationInfo,
) -> Callable[[QueryExpr], QueryExpr]:
    """Adds the compilation info to each node of the QueryExpr.

    Note:
        The return annotation must be ``Callable[[QueryExpr], QueryExpr]``;
        the unparenthesized form raises a TypeError when the module is
        imported, because typing.Callable requires a list of argument types.
    """

    @depth_first
    def add_info(expr: QueryExpr) -> QueryExpr:
        return replace(expr, compilation_info=info)

    return add_info


def select_noise_mechanism(expr: QueryExpr) -> QueryExpr:
    """Changes the default noise type into a concrete noise type for aggregations.

    This requires the QueryExpr to have been annotated with compilation info.

    Raises:
        AnalyticsInternalError: If ``expr`` has no compilation info attached.
        ValueError: If an aggregation carries an unrecognized mechanism.
    """
    if expr.compilation_info is None:
        raise AnalyticsInternalError(
            "select_noise_mechanism requires compilation info;"
            " apply add_compilation_info first."
        )

    if isinstance(expr, SuppressAggregates):
        # The mechanism to rewrite lives on the underlying GroupByCount.
        return replace(expr, child=select_noise_mechanism(expr.child))

    output_measure = expr.compilation_info.output_measure

    if isinstance(expr, (GroupByCount, GroupByCountDistinct)):
        # Counts are integer-valued, so they always use the integer-valued
        # noise distributions (geometric / discrete Gaussian).
        if expr.mechanism in (CountMechanism.DEFAULT, CountDistinctMechanism.DEFAULT):
            core_mechanism = (
                NoiseMechanism.GEOMETRIC
                if isinstance(output_measure, (PureDP, ApproxDP))
                else NoiseMechanism.DISCRETE_GAUSSIAN
            )
        elif expr.mechanism in (
            CountMechanism.LAPLACE,
            CountDistinctMechanism.LAPLACE,
        ):
            core_mechanism = NoiseMechanism.GEOMETRIC
        elif expr.mechanism in (
            CountMechanism.GAUSSIAN,
            CountDistinctMechanism.GAUSSIAN,
        ):
            core_mechanism = NoiseMechanism.DISCRETE_GAUSSIAN
        else:
            raise ValueError(
                f"Did not recognize the mechanism name {expr.mechanism}."
                " Supported mechanisms are DEFAULT, LAPLACE, and GAUSSIAN."
            )
        return replace(expr, core_mechanism=core_mechanism)

    if isinstance(
        expr,
        (
            GroupByBoundedAverage,
            GroupByBoundedSTDEV,
            GroupByBoundedSum,
            GroupByBoundedVariance,
        ),
    ):
        # Distinguish between Laplace/Geometric or (Discrete) Gaussian.
        # Assume floating-point output column type at first.
        if expr.mechanism in (
            SumMechanism.DEFAULT,
            AverageMechanism.DEFAULT,
            VarianceMechanism.DEFAULT,
            StdevMechanism.DEFAULT,
        ):
            core_mechanism = (
                NoiseMechanism.LAPLACE
                if isinstance(output_measure, (PureDP, ApproxDP))
                else NoiseMechanism.GAUSSIAN
            )
        elif expr.mechanism in (
            SumMechanism.LAPLACE,
            AverageMechanism.LAPLACE,
            VarianceMechanism.LAPLACE,
            StdevMechanism.LAPLACE,
        ):
            core_mechanism = NoiseMechanism.LAPLACE
        elif expr.mechanism in (
            SumMechanism.GAUSSIAN,
            AverageMechanism.GAUSSIAN,
            VarianceMechanism.GAUSSIAN,
            StdevMechanism.GAUSSIAN,
        ):
            core_mechanism = NoiseMechanism.GAUSSIAN
        else:
            raise ValueError(
                f"Did not recognize requested mechanism {expr.mechanism}."
                " Supported mechanisms are DEFAULT, LAPLACE, and GAUSSIAN."
            )

        # If the output column type is integer, use integer noise distributions
        # instead.
        catalog = expr.compilation_info.catalog
        schema = expr.child.accept(OutputSchemaVisitor(catalog))
        measure_column_type = schema[expr.measure_column].column_type
        if measure_column_type == ColumnType.INTEGER:
            core_mechanism = (
                NoiseMechanism.GEOMETRIC
                if core_mechanism == NoiseMechanism.LAPLACE
                else NoiseMechanism.DISCRETE_GAUSSIAN
            )

        return replace(expr, core_mechanism=core_mechanism)

    # All other aggregations don't use Core's NoiseMechanism, so they stay
    # unchanged.
    return expr


def rewrite(info: CompilationInfo, expr: QueryExpr) -> QueryExpr:
    """Rewrites the given QueryExpr into a QueryExpr that can be compiled.

    Rules are applied in order: the compilation info is attached to every
    node first, since later rules (noise selection) depend on it.
    """
    rewrite_rules = [
        add_compilation_info(info),
        select_noise_mechanism,
    ]
    for rule in rewrite_rules:
        expr = rule(expr)
    return expr