diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e5c59355..0d5ccec5 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,8 @@ Changed - Aggregation mechanisms can now be specified as strings instead of enums, e.g. ``"laplace"`` instead of ``CountMechanism.LAPLACE`` or ``SumMechanism.LAPLACE``. - Removed previously deprecated argument ``max_num_rows`` to ``flat_map``. Use ``max_rows`` instead. - Removed previously deprecated argument ``cols`` to ``count_distinct``. Use ``columns`` instead. +- Infinity values are now automatically dropped before a floating-point column is passed to ``get_bounds``. (The documentation previously claimed that this was done, but this was not the case.) +- Fixed the documentation of some numeric aggregations (``sum``, ``average``, ``stdev``, ``variance``, ``quantile``) to match the actual behavior: infinity values are clamped using the specified bounds before being passed to the aggregation function, not dropped. .. _v0.20.2: diff --git a/src/tmlt/analytics/_query_expr.py b/src/tmlt/analytics/_query_expr.py index 67bff179..54a8874d 100644 --- a/src/tmlt/analytics/_query_expr.py +++ b/src/tmlt/analytics/_query_expr.py @@ -1,11 +1,10 @@ """Building blocks of the Tumult Analytics query language. Not for direct use. -Defines the :class:`QueryExpr` class, which represents expressions in the -Tumult Analytics query language. QueryExpr and its subclasses should not be -directly constructed or deconstructed by most users; interfaces such as -:class:`tmlt.analytics.QueryBuilder` to create them and -:class:`tmlt.analytics.Session` to consume them provide more -user-friendly features. +Defines the :class:`QueryExpr` class, which represents expressions in the Tumult +Analytics query language. QueryExpr and its subclasses should not be directly +constructed, but instead built using a :class:`tmlt.analytics.QueryBuilder`. The +documentation of :class:`tmlt.analytics.QueryBuilder` provides more information +about the intended semantics of :class:`QueryExpr` objects. """ # SPDX-License-Identifier: Apache-2.0 @@ -175,14 +174,10 @@ class StdevMechanism(Enum): class QueryExpr(ABC): """A query expression, base class for relational operators. - In most cases, QueryExpr should not be manipulated directly, but rather - created using :class:`tmlt.analytics.QueryBuilder` and then - consumed by :class:`tmlt.analytics.Session`. While they can be - created and modified directly, this is an advanced usage and is not - recommended for typical users. - - QueryExpr are organized in a tree, where each node is an operator which - returns a relation. + QueryExpr objects are organized in a tree, where each node is an operator that returns a + table. They are built using the :class:`tmlt.analytics.QueryBuilder`, then rewritten + during the compilation process. They should not be created directly, except in + tests. """ @abstractmethod @@ -1775,13 +1770,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: @dataclass(frozen=True) class GroupByBoundedSum(SingleChildQueryExpr): - """Returns the bounded sum of a column for each combination of groupby domains. - - If the column to be measured contains null, NaN, or positive or negative infinity, - those values will be dropped (as if dropped explicitly via - :class:`DropNullAndNan` and :class:`DropInfinity`) before the sum is - calculated.
- """ + """Returns the bounded sum of a column for each combination of groupby domains.""" groupby_keys: Union[KeySet, Tuple[str, ...]] """The keys, or columns list to collect keys from, to be grouped on.""" @@ -1842,13 +1831,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: @dataclass(frozen=True) class GroupByBoundedAverage(SingleChildQueryExpr): - """Returns bounded average of a column for each combination of groupby domains. - - If the column to be measured contains null, NaN, or positive or negative infinity, - those values will be dropped (as if dropped explicitly via - :class:`DropNullAndNan` and :class:`DropInfinity`) before the average is - calculated. - """ + """Returns bounded average of a column for each combination of groupby domains.""" groupby_keys: Union[KeySet, Tuple[str, ...]] """The keys, or columns list to collect keys from, to be grouped on.""" @@ -1909,13 +1892,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: @dataclass(frozen=True) class GroupByBoundedVariance(SingleChildQueryExpr): - """Returns bounded variance of a column for each combination of groupby domains. - - If the column to be measured contains null, NaN, or positive or negative infinity, - those values will be dropped (as if dropped explicitly via - :class:`DropNullAndNan` and :class:`DropInfinity`) before the variance is - calculated. - """ + """Returns bounded variance of a column for each combination of groupby domains.""" groupby_keys: Union[KeySet, Tuple[str, ...]] """The keys, or columns list to collect keys from, to be grouped on.""" @@ -1976,13 +1953,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: @dataclass(frozen=True) class GroupByBoundedStdev(SingleChildQueryExpr): - """Returns bounded stdev of a column for each combination of groupby domains. - - If the column to be measured contains null, NaN, or positive or negative infinity, - those values will be dropped (as if dropped explicitly via - :class:`DropNullAndNan` and :class:`DropInfinity`) before the - standard deviation is calculated. - """ + """Returns bounded stdev of a column for each combination of groupby domains.""" groupby_keys: Union[KeySet, Tuple[str, ...]] """The keys, or columns list to collect keys from, to be grouped on.""" diff --git a/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py b/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py index 19b7c52c..554c1e26 100644 --- a/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py +++ b/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py @@ -1,5 +1,4 @@ """Defines a base class for building measurement visitors.""" -import dataclasses import math import warnings from abc import abstractmethod @@ -101,7 +100,7 @@ SuppressAggregates, VarianceMechanism, ) -from tmlt.analytics._schema import ColumnType, FrozenDict, Schema +from tmlt.analytics._schema import Schema from tmlt.analytics._table_identifier import Identifier from tmlt.analytics._table_reference import TableReference from tmlt.analytics._transformation_utils import get_table_from_ref @@ -675,65 +674,6 @@ def _validate_approxDP_and_adjust_budget( else: raise AnalyticsInternalError(f"Unknown mechanism {mechanism}.") - def _add_special_value_handling_to_query( - self, - query: Union[ - GroupByBoundedAverage, - GroupByBoundedStdev, - GroupByBoundedSum, - GroupByBoundedVariance, - GroupByQuantile, - GetBounds, - ], - ): - """Returns a new query that handles nulls, NaNs and infinite values. 
- - If the measure column allows nulls or NaNs, the new query - will drop those values. - - If the measure column allows infinite values, the new query will replace those - values with the low and high values specified in the query. - - These changes are added immediately before the groupby aggregation in the query. - """ - expected_schema = query.child.schema(self.catalog) - - # You can't perform these queries on nulls, NaNs, or infinite values - # so check for those - try: - measure_desc = expected_schema[query.measure_column] - except KeyError as e: - raise KeyError( - f"Measure column {query.measure_column} is not in the input schema." - ) from e - - new_child: QueryExpr - # If null or NaN values are allowed ... - if measure_desc.allow_null or ( - measure_desc.column_type == ColumnType.DECIMAL and measure_desc.allow_nan - ): - # then drop those values - # (but don't mutate the original query) - new_child = DropNullAndNan( - child=query.child, columns=tuple([query.measure_column]) - ) - query = dataclasses.replace(query, child=new_child) - if not isinstance(query, GetBounds): - # If infinite values are allowed... - if ( - measure_desc.column_type == ColumnType.DECIMAL - and measure_desc.allow_inf - ): - # then clamp them (to low/high values) - new_child = ReplaceInfinity( - child=query.child, - replace_with=FrozenDict.from_dict( - {query.measure_column: (query.low, query.high)} - ), - ) - query = dataclasses.replace(query, child=new_child) - return query - def _validate_measurement(self, measurement: Measurement, mid_stability: sp.Expr): """Validate a measurement.""" if isinstance(self.adjusted_budget.value, tuple): @@ -1146,7 +1086,6 @@ def visit_groupby_quantile( # Peek at the schema, to see if there are errors there expr.schema(self.catalog) - expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) @@ -1241,7 +1180,6 @@ def visit_groupby_bounded_sum( # Peek at the schema, to see if there are errors there expr.schema(self.catalog) - expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) @@ -1337,7 +1275,6 @@ def visit_groupby_bounded_average( # Peek at the schema, to see if there are errors there expr.schema(self.catalog) - expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) @@ -1433,7 +1370,6 @@ def visit_groupby_bounded_variance( # Peek at the schema, to see if there are errors there expr.schema(self.catalog) - expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) @@ -1529,7 +1465,6 @@ def visit_groupby_bounded_stdev( # Peek at the schema, to see if there are errors there expr.schema(self.catalog) - expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) @@ -1622,7 +1557,6 @@ def visit_get_bounds(self, expr: GetBounds) -> Tuple[Measurement, NoiseInfo]: # Peek at the schema, to see if there are errors there expr.schema(self.catalog) - expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) keyset_budget = self._get_zero_budget() diff --git 
a/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py b/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py index 7707a2af..7d14173b 100644 --- a/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py +++ b/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py @@ -20,15 +20,21 @@ AverageMechanism, CountDistinctMechanism, CountMechanism, + DropInfinity, + DropNullAndNan, + FrozenDict, + GetBounds, GroupByBoundedAverage, GroupByBoundedStdev, GroupByBoundedSum, GroupByBoundedVariance, GroupByCount, GroupByCountDistinct, + GroupByQuantile, JoinPrivate, PrivateSource, QueryExpr, + ReplaceInfinity, SingleChildQueryExpr, StdevMechanism, SumMechanism, @@ -193,9 +199,68 @@ def select_noise(expr: QueryExpr) -> QueryExpr: return select_noise +def add_special_value_handling( + info: CompilationInfo, +) -> Callable[[QueryExpr], QueryExpr]: + """Rewrites the query to handle nulls, NaNs and infinite values. + + If the measure column allows nulls or NaNs, the rewritten query will drop those + values. If the measure column allows infinite values, the new query will replace + those values with the clamping bounds specified in the query, or drop these values + for :meth:`~tmlt.analytics.QueryBuilder.get_bounds`. + """ + + @depth_first + def handle_special_values(expr: QueryExpr) -> QueryExpr: + if not isinstance( + expr, + ( + GroupByBoundedAverage, + GroupByBoundedStdev, + GroupByBoundedSum, + GroupByBoundedVariance, + GroupByQuantile, + GetBounds, + ), + ): + return expr + schema = expr.child.schema(info.catalog) + measure_desc = schema[expr.measure_column] + # Remove nulls/NaN if necessary + if measure_desc.allow_null or ( + measure_desc.column_type == ColumnType.DECIMAL and measure_desc.allow_nan + ): + expr = replace( + expr, + child=DropNullAndNan(child=expr.child, columns=(expr.measure_column,)), + ) + # Remove infinities if necessary + if measure_desc.column_type == ColumnType.DECIMAL and measure_desc.allow_inf: + if isinstance(expr, GetBounds): + return replace( + expr, + child=DropInfinity( + child=expr.child, columns=(expr.measure_column,) + ), + ) + return replace( + expr, + child=ReplaceInfinity( + child=expr.child, + replace_with=FrozenDict.from_dict( + {expr.measure_column: (expr.low, expr.high)} + ), + ), + ) + return expr + + return handle_special_values + + def rewrite(info: CompilationInfo, expr: QueryExpr) -> QueryExpr: """Rewrites the given QueryExpr into a QueryExpr that can be compiled.""" rewrite_rules = [ + add_special_value_handling(info), select_noise_mechanism(info), ] for rule in rewrite_rules: diff --git a/src/tmlt/analytics/query_builder.py b/src/tmlt/analytics/query_builder.py index a427220f..478e9f33 100644 --- a/src/tmlt/analytics/query_builder.py +++ b/src/tmlt/analytics/query_builder.py @@ -764,7 +764,7 @@ def replace_infinity( ) return self - def drop_null_and_nan(self, columns: Optional[List[str]]) -> "QueryBuilder": + def drop_null_and_nan(self, columns: Optional[List[str]] = None) -> "QueryBuilder": """Removes rows containing null or NaN values. .. note:: @@ -869,7 +869,7 @@ def drop_null_and_nan(self, columns: Optional[List[str]]) -> "QueryBuilder": ) return self - def drop_infinity(self, columns: Optional[List[str]]) -> "QueryBuilder": + def drop_infinity(self, columns: Optional[List[str]] = None) -> "QueryBuilder": """Remove rows containing infinite values. .. @@ -2453,12 +2453,6 @@ def sum( ) -> Query: """Returns a sum query ready to be evaluated. - .. 
note:: - If the column being measured contains NaN or null values, a - :meth:`~drop_null_and_nan` query will be performed first. If the - column being measured contains infinite values, a - :meth:`~drop_infinity` query will be performed first. - .. note:: Regarding the clamping bounds: @@ -2472,6 +2466,12 @@ def sum( Consult the :ref:`Numerical aggregations ` tutorial for more information. + .. note:: + If the column being measured contains NaN or null values, a + :meth:`~drop_null_and_nan` query will be performed first. If the column + being measured contains infinite values, these values will be clamped + between ``low`` and ``high``. + .. >>> from tmlt.analytics import ( ... AddOneRow, @@ -2541,12 +2541,6 @@ def average( ) -> Query: """Returns an average query ready to be evaluated. - .. note:: - If the column being measured contains NaN or null values, a - :meth:`~drop_null_and_nan` query will be performed first. If the - column being measured contains infinite values, a - :meth:`~drop_infinity` query will be performed first. - .. note:: Regarding the clamping bounds: @@ -2560,6 +2554,12 @@ def average( Consult the :ref:`Numerical aggregations ` tutorial for more information. + .. note:: + If the column being measured contains NaN or null values, a + :meth:`~drop_null_and_nan` query will be performed first. If the column + being measured contains infinite values, these values will be clamped + between ``low`` and ``high``. + .. >>> from tmlt.analytics import ( ... AddOneRow, @@ -2629,12 +2629,6 @@ def variance( ) -> Query: """Returns a variance query ready to be evaluated. - .. note:: - If the column being measured contains NaN or null values, a - :meth:`~drop_null_and_nan` query will be performed first. If the - column being measured contains infinite values, a - :meth:`~drop_infinity` query will be performed first. - .. note:: Regarding the clamping bounds: @@ -2648,6 +2642,12 @@ def variance( Consult the :ref:`Numerical aggregations ` tutorial for more information. + .. note:: + If the column being measured contains NaN or null values, a + :meth:`~drop_null_and_nan` query will be performed first. If the column + being measured contains infinite values, these values will be clamped + between ``low`` and ``high``. + .. >>> from tmlt.analytics import ( ... AddOneRow, @@ -2717,12 +2717,6 @@ def stdev( ) -> Query: """Returns a standard deviation query ready to be evaluated. - .. note:: - If the column being measured contains NaN or null values, a - :meth:`~drop_null_and_nan` query will be performed first. If the - column being measured contains infinite values, a - :meth:`~drop_infinity` query will be performed first. - .. note:: Regarding the clamping bounds: @@ -2736,6 +2730,12 @@ def stdev( Consult the :ref:`Numerical aggregations ` tutorial for more information. + .. note:: + If the column being measured contains NaN or null values, a + :meth:`~drop_null_and_nan` query will be performed first. If the column + being measured contains infinite values, these values will be clamped + between ``low`` and ``high``. + .. >>> from tmlt.analytics import ( ... AddOneRow, @@ -2995,10 +2995,9 @@ def quantile( .. note:: If the column being measured contains NaN or null values, a - :meth:`~QueryBuilder.drop_null_and_nan` query will be performed - first. If the column being measured contains infinite values, a - :meth:`~QueryBuilder.drop_infinity` query will be performed first. - + :meth:`~QueryBuilder.drop_null_and_nan` query will be performed first. 
If + the column being measured contains infinite values, these values will be + clamped between ``low`` and ``high``. .. >>> from tmlt.analytics import ( ... AddOneRow, diff --git a/test/system/session/rows/conftest.py b/test/system/session/rows/conftest.py index 265f6e7c..10368575 100644 --- a/test/system/session/rows/conftest.py +++ b/test/system/session/rows/conftest.py @@ -38,7 +38,6 @@ FrozenDict, Schema, analytics_to_spark_columns_descriptor, - analytics_to_spark_schema, ) # Shorthands for some values used in tests @@ -709,56 +708,3 @@ def sess_data(spark, request): analytics_to_spark_columns_descriptor(Schema(sdf_col_types)) ) request.cls.sdf_input_domain = sdf_input_domain - - -###DATA FOR SESSIONS WITH NULLS### -@pytest.fixture(name="null_session_data", scope="class") -def null_setup(spark, request): - """Set up test data for sessions with nulls.""" - # Since Spark gives back timestamps with microsecond accuracy, this - # dataframe needs to make that the default precision for column T. - pdf = pd.DataFrame( - [ - ["a0", 0, 0.0, datetime.date(2000, 1, 1), datetime.datetime(2020, 1, 1)], - [None, 1, 1.0, datetime.date(2001, 1, 1), datetime.datetime(2021, 1, 1)], - ["a2", None, 2.0, datetime.date(2002, 1, 1), datetime.datetime(2022, 1, 1)], - ["a3", 3, None, datetime.date(2003, 1, 1), datetime.datetime(2023, 1, 1)], - ["a4", 4, 4.0, None, datetime.datetime(2024, 1, 1)], - ["a5", 5, 5.0, datetime.date(2005, 1, 1), None], - ], - columns=["A", "I", "X", "D", "T"], - ).astype({"T": "datetime64[us]"}) - - request.cls.pdf = pdf - - sdf_col_types = { - "A": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True), - "I": ColumnDescriptor(ColumnType.INTEGER, allow_null=True), - "X": ColumnDescriptor(ColumnType.DECIMAL, allow_null=True), - "D": ColumnDescriptor(ColumnType.DATE, allow_null=True), - "T": ColumnDescriptor(ColumnType.TIMESTAMP, allow_null=True), - } - - sdf = spark.createDataFrame( - pdf, schema=analytics_to_spark_schema(Schema(sdf_col_types)) - ) - request.cls.sdf = sdf - - -###DATA FOR SESSIONS WITH INF VALUES### -@pytest.fixture(name="infs_test_data", scope="class") -def infs_setup(spark, request): - """Set up tests.""" - pdf = pd.DataFrame( - {"A": ["a0", "a0", "a1", "a1"], "B": [float("-inf"), 2.0, 5.0, float("inf")]} - ) - request.cls.pdf = pdf - - sdf_col_types = { - "A": ColumnDescriptor(ColumnType.VARCHAR), - "B": ColumnDescriptor(ColumnType.DECIMAL, allow_inf=True), - } - sdf = spark.createDataFrame( - pdf, schema=analytics_to_spark_schema(Schema(sdf_col_types)) - ) - request.cls.sdf = sdf diff --git a/test/system/session/rows/test_add_max_rows_infs_nulls.py b/test/system/session/rows/test_add_max_rows_infs_nulls.py deleted file mode 100644 index 20738d87..00000000 --- a/test/system/session/rows/test_add_max_rows_infs_nulls.py +++ /dev/null @@ -1,512 +0,0 @@ -"""System tests for Sessions with Nulls and Infs.""" - -# SPDX-License-Identifier: Apache-2.0 -# Copyright Tumult Labs 2025 - -import datetime -from typing import Any, Dict, List, Mapping, Tuple, Union - -import pandas as pd -import pytest -from pyspark.sql import DataFrame, SparkSession -from pyspark.sql.types import StringType, StructField, StructType -from tmlt.core.measurements.interactive_measurements import SequentialQueryable -from tmlt.core.utils.testing import Case, parametrize - -from tmlt.analytics import ( - AddOneRow, - AnalyticsDefault, - ColumnDescriptor, - ColumnType, - KeySet, - PureDPBudget, - QueryBuilder, - Session, - TruncationStrategy, -) -from tmlt.analytics._table_identifier import NamedTable - 
-from ....conftest import assert_frame_equal_with_sort - - -@pytest.mark.usefixtures("null_session_data") -class TestSessionWithNulls: - """Tests for sessions with Nulls.""" - - pdf: pd.DataFrame - sdf: DataFrame - - def _expected_replace(self, d: Mapping[str, Any]) -> pd.DataFrame: - """The expected value if you replace None with default values in d.""" - new_cols: List[pd.DataFrame] = [] - for col in list(self.pdf.columns): - if col in dict(d): - # make sure I becomes an integer here - if col == "I": - new_cols.append(self.pdf[col].fillna(dict(d)[col]).astype(int)) - else: - new_cols.append(self.pdf[col].fillna(dict(d)[col])) - else: - new_cols.append(self.pdf[col]) - # `axis=1` means that you want to "concatenate" by columns - # i.e., you want your new table to look like this: - # df1 | df2 | df3 | ... - # df1 | df2 | df3 | ... - return pd.concat(new_cols, axis=1) - - def test_expected_replace(self) -> None: - """Test the test method _expected_replace.""" - d = { - "A": "a999", - "I": -999, - "X": 99.9, - "D": datetime.date(1999, 1, 1), - "T": datetime.datetime(2019, 1, 1), - } - expected = pd.DataFrame( - [ - [ - "a0", - 0, - 0.0, - datetime.date(2000, 1, 1), - datetime.datetime(2020, 1, 1), - ], - [ - "a999", - 1, - 1.0, - datetime.date(2001, 1, 1), - datetime.datetime(2021, 1, 1), - ], - [ - "a2", - -999, - 2.0, - datetime.date(2002, 1, 1), - datetime.datetime(2022, 1, 1), - ], - [ - "a3", - 3, - 99.9, - datetime.date(2003, 1, 1), - datetime.datetime(2023, 1, 1), - ], - [ - "a4", - 4, - 4.0, - datetime.date(1999, 1, 1), - datetime.datetime(2024, 1, 1), - ], - [ - "a5", - 5, - 5.0, - datetime.date(2005, 1, 1), - datetime.datetime(2019, 1, 1), - ], - ], - columns=["A", "I", "X", "D", "T"], - ) - assert_frame_equal_with_sort(self.pdf, self._expected_replace({})) - assert_frame_equal_with_sort( - expected, - self._expected_replace(d), - ) - - @pytest.mark.parametrize( - "cols_to_defaults", - [ - ({"A": "aaaaaaa"}), - ({"I": 999}), - ( - { - "A": "aaa", - "I": 999, - "X": -99.9, - "D": datetime.date.fromtimestamp(0), - "T": datetime.datetime.fromtimestamp(0), - } - ), - ], - ) - def test_replace_null_and_nan( - self, - cols_to_defaults: Mapping[ - str, Union[int, float, str, datetime.date, datetime.datetime] - ], - ) -> None: - """Test Session.replace_null_and_nan.""" - session = Session.from_dataframe( - PureDPBudget(float("inf")), - "private", - self.sdf, - protected_change=AddOneRow(), - ) - session.create_view( - QueryBuilder("private").replace_null_and_nan(cols_to_defaults), - "replaced", - cache=False, - ) - # pylint: disable=protected-access - queryable = session._accountant._queryable - assert isinstance(queryable, SequentialQueryable) - data = queryable._data - assert isinstance(data, dict) - assert isinstance(data[NamedTable("replaced")], DataFrame) - # pylint: enable=protected-access - assert_frame_equal_with_sort( - data[NamedTable("replaced")].toPandas(), - self._expected_replace(cols_to_defaults), - ) - - @pytest.mark.parametrize( - "public_df,keyset,expected", - [ - ( - pd.DataFrame( - [[None, 0], [None, 1], ["a2", 1], ["a2", 2]], - columns=["A", "new_column"], - ), - KeySet.from_dict({"new_column": [0, 1, 2]}), - pd.DataFrame([[0, 1], [1, 2], [2, 1]], columns=["new_column", "count"]), - ), - ( - pd.DataFrame( - [["a0", 0, 0], [None, 1, 17], ["a5", 5, 17], ["a5", 5, 400]], - columns=["A", "I", "new_column"], - ), - KeySet.from_dict({"new_column": [0, 17, 400]}), - pd.DataFrame( - [[0, 1], [17, 2], [400, 1]], columns=["new_column", "count"] - ), - ), - ( - pd.DataFrame( - [ 
- [datetime.date(2000, 1, 1), "2000"], - [datetime.date(2001, 1, 1), "2001"], - [None, "none"], - [None, "also none"], - ], - columns=["D", "year"], - ), - KeySet.from_dict( - {"D": [datetime.date(2000, 1, 1), datetime.date(2001, 1, 1), None]} - ), - pd.DataFrame( - [ - [datetime.date(2000, 1, 1), 1], - [datetime.date(2001, 1, 1), 1], - [None, 2], - ], - columns=["D", "count"], - ), - ), - ], - ) - def test_join_public( - self, spark, public_df: pd.DataFrame, keyset: KeySet, expected: pd.DataFrame - ) -> None: - """Test that join_public creates the correct results. - - The query used to evaluate this is a GroupByCount on the new dataframe, - using the keyset provided. - """ - session = Session.from_dataframe( - PureDPBudget(float("inf")), - "private", - self.sdf, - protected_change=AddOneRow(), - ) - session.add_public_dataframe("public", spark.createDataFrame(public_df)) - result = session.evaluate( - QueryBuilder("private").join_public("public").groupby(keyset).count(), - privacy_budget=PureDPBudget(float("inf")), - ) - assert_frame_equal_with_sort(result.toPandas(), expected) - - @pytest.mark.parametrize( - "private_df,keyset,expected", - [ - ( - pd.DataFrame( - [[None, 0], [None, 1], ["a2", 1], ["a2", 2]], - columns=["A", "new_column"], - ), - KeySet.from_dict({"new_column": [0, 1, 2]}), - pd.DataFrame([[0, 1], [1, 2], [2, 1]], columns=["new_column", "count"]), - ), - ( - pd.DataFrame( - [["a0", 0, 0], [None, 1, 17], ["a5", 5, 17], ["a5", 5, 400]], - columns=["A", "I", "new_column"], - ), - KeySet.from_dict({"new_column": [0, 17, 400]}), - pd.DataFrame( - [[0, 1], [17, 2], [400, 1]], columns=["new_column", "count"] - ), - ), - ( - pd.DataFrame( - [ - [datetime.date(2000, 1, 1), "2000"], - [datetime.date(2001, 1, 1), "2001"], - [None, "none"], - [None, "also none"], - ], - columns=["D", "year"], - ), - KeySet.from_dict( - {"D": [datetime.date(2000, 1, 1), datetime.date(2001, 1, 1), None]} - ), - pd.DataFrame( - [ - [datetime.date(2000, 1, 1), 1], - [datetime.date(2001, 1, 1), 1], - [None, 2], - ], - columns=["D", "count"], - ), - ), - ], - ) - def test_join_private( - self, spark, private_df: pd.DataFrame, keyset: KeySet, expected: pd.DataFrame - ) -> None: - """Test that join_private creates the correct results. - - The query used to evaluate this is a GroupByCount on the joined dataframe, - using the keyset provided. 
- """ - session = ( - Session.Builder() - .with_privacy_budget(PureDPBudget(float("inf"))) - .with_private_dataframe("private", self.sdf, AddOneRow()) - .with_private_dataframe( - "private2", spark.createDataFrame(private_df), AddOneRow() - ) - .build() - ) - result = session.evaluate( - QueryBuilder("private") - .join_private( - QueryBuilder("private2"), - TruncationStrategy.DropExcess(100), - TruncationStrategy.DropExcess(100), - ) - .groupby(keyset) - .count(), - PureDPBudget(float("inf")), - ) - assert_frame_equal_with_sort(result.toPandas(), expected) - - @parametrize( - Case("both_allow_nulls")( - public_schema=StructType([StructField("foo", StringType(), True)]), - private_schema=StructType([StructField("foo", StringType(), True)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True) - }, - ), - Case("none_allow_nulls")( - public_schema=StructType([StructField("foo", StringType(), False)]), - private_schema=StructType([StructField("foo", StringType(), False)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=False) - }, - ), - Case("public_only_nulls")( - public_schema=StructType([StructField("foo", StringType(), True)]), - private_schema=StructType([StructField("foo", StringType(), False)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=False) - }, - ), - Case("private_only_nulls")( - public_schema=StructType([StructField("foo", StringType(), False)]), - private_schema=StructType([StructField("foo", StringType(), True)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=False) - }, - ), - ) - def test_public_join_schema_null_propagation( - self, - public_schema: StructType, - private_schema: StructType, - expected_schema: StructType, - spark: SparkSession, - ): - """Tests that join_public correctly handles schemas that allow null values.""" - public_df = spark.createDataFrame([], public_schema) - private_df = spark.createDataFrame([], private_schema) - sess = ( - Session.Builder() - .with_privacy_budget(PureDPBudget(float("inf"))) - .with_private_dataframe("private", private_df, protected_change=AddOneRow()) - .with_public_dataframe("public", public_df) - .build() - ) - sess.create_view( - QueryBuilder("private").join_public("public"), source_id="join", cache=False - ) - assert sess.get_schema("join") == expected_schema - - @parametrize( - Case("both_allow_nulls")( - left_schema=StructType([StructField("foo", StringType(), True)]), - right_schema=StructType([StructField("foo", StringType(), True)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True) - }, - ), - Case("none_allow_nulls")( - left_schema=StructType([StructField("foo", StringType(), False)]), - right_schema=StructType([StructField("foo", StringType(), False)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=False) - }, - ), - Case("public_only_nulls")( - left_schema=StructType([StructField("foo", StringType(), True)]), - right_schema=StructType([StructField("foo", StringType(), False)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=False) - }, - ), - Case("private_only_nulls")( - left_schema=StructType([StructField("foo", StringType(), False)]), - right_schema=StructType([StructField("foo", StringType(), True)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=False) - }, - ), - ) - def test_private_join_schema_null_propagation( - self, - left_schema: StructType, - right_schema: 
StructType, - expected_schema: StructType, - spark: SparkSession, - ): - """Tests that join_private correctly handles schemas that allow null values.""" - left_df = spark.createDataFrame([], left_schema) - right_df = spark.createDataFrame([], right_schema) - sess = ( - Session.Builder() - .with_privacy_budget(PureDPBudget(float("inf"))) - .with_private_dataframe("left", left_df, protected_change=AddOneRow()) - .with_private_dataframe("right", right_df, protected_change=AddOneRow()) - .build() - ) - sess.create_view( - QueryBuilder("left").join_private( - "right", - truncation_strategy_left=TruncationStrategy.DropExcess(1), - truncation_strategy_right=TruncationStrategy.DropExcess(1), - ), - source_id="join", - cache=False, - ) - assert sess.get_schema("join") == expected_schema - - -@pytest.mark.usefixtures("infs_test_data") -class TestSessionWithInfs: - """Tests for Sessions with Infs.""" - - pdf: pd.DataFrame - sdf: DataFrame - - @pytest.mark.parametrize( - "replace_with,", - [ - ({}), - ({"B": (-100.0, 100.0)}), - ({"B": (123.45, 678.90)}), - ({"B": (999.9, 111.1)}), - ], - ) - def test_replace_infinity( - self, replace_with: Dict[str, Tuple[float, float]] - ) -> None: - """Test replace_infinity query.""" - session = Session.from_dataframe( - PureDPBudget(float("inf")), - "private", - self.sdf, - protected_change=AddOneRow(), - ) - session.create_view( - QueryBuilder("private").replace_infinity(replace_with), - "replaced", - cache=False, - ) - # pylint: disable=protected-access - queryable = session._accountant._queryable - assert isinstance(queryable, SequentialQueryable) - data = queryable._data - assert isinstance(data, dict) - assert isinstance(data[NamedTable("replaced")], DataFrame) - # pylint: enable=protected-access - (replace_negative, replace_positive) = replace_with.get( - "B", (AnalyticsDefault.DECIMAL, AnalyticsDefault.DECIMAL) - ) - expected = self.pdf.replace(float("-inf"), replace_negative).replace( - float("inf"), replace_positive - ) - assert_frame_equal_with_sort(data[NamedTable("replaced")].toPandas(), expected) - - @pytest.mark.parametrize( - "replace_with,expected", - [ - ({}, pd.DataFrame([["a0", 2.0], ["a1", 5.0]], columns=["A", "sum"])), - ( - {"B": (-100.0, 100.0)}, - pd.DataFrame([["a0", -98.0], ["a1", 105.0]], columns=["A", "sum"]), - ), - ( - {"B": (500.0, 100.0)}, - pd.DataFrame([["a0", 502.0], ["a1", 105.0]], columns=["A", "sum"]), - ), - ], - ) - def test_sum( - self, replace_with: Dict[str, Tuple[float, float]], expected: pd.DataFrame - ) -> None: - """Test GroupByBoundedSum after replacing infinite values.""" - session = Session.from_dataframe( - PureDPBudget(float("inf")), - "private", - self.sdf, - protected_change=AddOneRow(), - ) - result = session.evaluate( - QueryBuilder("private") - .replace_infinity(replace_with) - .groupby(KeySet.from_dict({"A": ["a0", "a1"]})) - .sum("B", low=-1000, high=1000, name="sum"), - PureDPBudget(float("inf")), - ) - assert_frame_equal_with_sort(result.toPandas(), expected) - - def test_drop_infinity(self): - """Test GroupByBoundedSum after dropping infinite values.""" - session = Session.from_dataframe( - PureDPBudget(float("inf")), - "private", - self.sdf, - protected_change=AddOneRow(), - ) - result = session.evaluate( - QueryBuilder("private") - .drop_infinity(columns=["B"]) - .groupby(KeySet.from_dict({"A": ["a0", "a1"]})) - .sum("B", low=-1000, high=1000, name="sum"), - PureDPBudget(float("inf")), - ) - expected = pd.DataFrame([["a0", 2.0], ["a1", 5.0]], columns=["A", "sum"]) - 
assert_frame_equal_with_sort(result.toPandas(), expected) diff --git a/test/system/session/test_special_values.py b/test/system/session/test_special_values.py new file mode 100644 index 00000000..846af57f --- /dev/null +++ b/test/system/session/test_special_values.py @@ -0,0 +1,953 @@ +"""System tests for tables with special values (nulls, nans, infinities).""" + +# SPDX-License-Identifier: Apache-2.0 +# Copyright Tumult Labs 2025 + +import datetime +from typing import Dict, List, Optional, Tuple, Union + +import pandas as pd +import pytest +from numpy import sqrt +from pyspark.sql import DataFrame +from tmlt.core.utils.testing import Case, parametrize + +from tmlt.analytics import ( + AddOneRow, + AddRowsWithID, + ColumnDescriptor, + ColumnType, + KeySet, + MaxRowsPerID, + ProtectedChange, + PureDPBudget, + Query, + QueryBuilder, + Session, + TruncationStrategy, +) +from tmlt.analytics._schema import Schema, analytics_to_spark_schema + +from ...conftest import assert_frame_equal_with_sort + + +@pytest.fixture(name="sdf_special_values", scope="module") +def special_values_dataframe(spark): + """Set up test data for sessions with special values.""" + sdf_col_types = { + "string_nulls": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True), + "int_no_null": ColumnDescriptor(ColumnType.INTEGER, allow_null=False), + "int_nulls": ColumnDescriptor(ColumnType.INTEGER, allow_null=True), + "float_no_special": ColumnDescriptor( + ColumnType.DECIMAL, + allow_null=False, + allow_nan=False, + allow_inf=False, + ), + "float_nulls": ColumnDescriptor( + ColumnType.DECIMAL, + allow_null=True, + allow_nan=False, + allow_inf=False, + ), + "float_nans": ColumnDescriptor( + ColumnType.DECIMAL, + allow_null=False, + allow_nan=True, + allow_inf=False, + ), + "float_infs": ColumnDescriptor( + ColumnType.DECIMAL, + allow_null=False, + allow_nan=False, + allow_inf=True, + ), + "float_all_special": ColumnDescriptor( + ColumnType.DECIMAL, + allow_null=True, + allow_nan=True, + allow_inf=True, + ), + "date_nulls": ColumnDescriptor(ColumnType.DATE, allow_null=True), + "time_nulls": ColumnDescriptor(ColumnType.TIMESTAMP, allow_null=True), + } + date = datetime.date(2000, 1, 1) + time = datetime.datetime(2020, 1, 1) + sdf = spark.createDataFrame( + [(f"normal_{i}", 1, 1, 1.0, 1.0, 1.0, 1.0, 1.0, date, time) for i in range(20)] + + [ + # Rows with nulls + (None, 1, 1, 1.0, 1.0, 1.0, 1.0, 1.0, date, time), + ("u2", 1, None, 1.0, 1.0, 1.0, 1.0, 1.0, date, time), + ("u3", 1, 1, 1.0, None, 1.0, 1.0, None, date, time), + ("u4", 1, 1, 1.0, 1.0, 1.0, 1.0, 1.0, None, time), + ("u5", 1, 1, 1.0, 1.0, 1.0, 1.0, 1.0, date, None), + # Rows with nans + ("a6", 1, 1, 1.0, 1.0, float("nan"), 1.0, float("nan"), date, time), + # Rows with infinities + ("i7", 1, 1, 1.0, 1.0, 1.0, float("inf"), float("inf"), date, time), + ("i8", 1, 1, 1.0, 1.0, 1.0, -float("inf"), -float("inf"), date, time), + ("i9", 1, 1, 1.0, 1.0, 1.0, float("inf"), 1.0, date, time), + ("i10", 1, 1, 1.0, 1.0, 1.0, -float("inf"), 1.0, date, time), + ], + schema=analytics_to_spark_schema(Schema(sdf_col_types)), + ) + return sdf + + +@parametrize( + [ + Case("int_sum")( + # There are 29 1s in the "int_nulls" column and one null, which should be + # dropped by default. + query=QueryBuilder("private").sum("int_nulls", 0, 1), + expected_df=pd.DataFrame( + [[29]], + columns=["int_nulls_sum"], + ), + ), + Case("count_distinct")( + # Nulls, nans, and infinities count as distinct values in a count_distinct + # query. All other values in the "float_all_special" are 1s. 
+ query=QueryBuilder("private").count_distinct(["float_all_special"]), + expected_df=pd.DataFrame( + [[5]], + columns=["count_distinct(float_all_special)"], + ), + ), + Case("count_distinct_deduplicates")( + # The "float_infs" column contains 1s, two positive infinities, and two + # negative infinities. + query=QueryBuilder("private").count_distinct(["float_infs"]), + expected_df=pd.DataFrame( + [[3]], + columns=["count_distinct(float_infs)"], + ), + ), + Case("float_sum")( + # In the "float_all_special" column, there are 26 1s, one null (dropped), + # one NaN (dropped), one negative infinity (clamped to -100), and one positive + # infinity (clamped to 300). + query=QueryBuilder("private").sum("float_all_special", -100, 300), + expected_df=pd.DataFrame( + [[226]], # 26-100+300 + columns=["float_all_special_sum"], + ), + ), + Case("group_by_null")( + # Nulls can be used as group-by keys. + query=( + QueryBuilder("private") + .groupby( + KeySet.from_dict({"date_nulls": [datetime.date(2000, 1, 1), None]}) + ) + .count() + ), + expected_df=pd.DataFrame( + [[datetime.date(2000, 1, 1), 29], [None, 1]], + columns=["date_nulls", "count"], + ), + ), + ] +) +def test_default_behavior( + sdf_special_values: DataFrame, query: Query, expected_df: pd.DataFrame +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + AddOneRow(), + ) + result = sess.evaluate(query, inf_budget) + assert_frame_equal_with_sort(result.toPandas(), expected_df) + + +@parametrize( + [ + Case("int_noop")( + # Column "int_no_null" has only non-null values, all equal to 1 + replace_with={"int_no_null": 42}, + column="int_no_null", + low=0, + high=1, + expected_average=1, + ), + Case("int_replace_null")( + # Column "int_nulls" has one null value and 29 1s. + replace_with={"int_nulls": 31}, + column="int_nulls", + low=0, + high=100, + expected_average=2.0, # (29+31)/30 + ), + Case("float_replace_null")( + # Column "float_nulls" has one null value and 29 1s. + replace_with={"float_nulls": 61}, + column="float_nulls", + low=0, + high=100, + expected_average=3.0, # (29+61)/30 + ), + Case("float_replace_nan")( + # Column "float_nans" has one nan value and 29 1s. + replace_with={"float_nans": 91}, + column="float_nans", + low=0, + high=100, + expected_average=4.0, # (29+91)/30 + ), + Case("float_replace_both")( + # Column "float_all_special" has 26 1s, one null value, one nan value, one + # negative infinity (clamped to 0), one positive infinity (clamped to 34). + replace_with={"float_all_special": 15}, + column="float_all_special", + low=0, + high=34, + expected_average=3.0, # (26+15+15+0+34)/30 + ), + Case("replace_all_with_none")( + # When called with no argument, replace_null_and_nan should replace all null + # values with the analytics defaults, e.g. 0. + replace_with=None, + column="float_nulls", + low=0, + high=1, + expected_average=29.0 / 30, + ), + Case("replace_all_with_empty_dict")( + # Same thing with an empty dict and with nan values.
+ replace_with={}, + column="float_nans", + low=0, + high=1, + expected_average=29.0 / 30, + ), + ] +) +def test_replace_null_and_nan( + sdf_special_values: DataFrame, + replace_with: Optional[Dict[str, Union[int, float]]], + column: str, + low: Union[int, float], + high: Union[int, float], + expected_average: Union[int, float], +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + AddOneRow(), + ) + base_query = QueryBuilder("private") + query = base_query.replace_null_and_nan(replace_with).average(column, low, high) + result = sess.evaluate(query, inf_budget) + expected_df = pd.DataFrame([[expected_average]], columns=[column + "_average"]) + assert_frame_equal_with_sort(result.toPandas(), expected_df) + + +@parametrize( + [ + # All columns have 30 rows, all non-special values are equal to 1. + Case("int_noop")( + # Column "int_no_null" has only regular values. + affected_columns=["int_no_null"], + measure_column="int_no_null", + low=0, + high=1, + expected_sum=30, + ), + Case("int_drop_nulls")( + # Column "int_nulls" has one null value and 29 1s. + affected_columns=["int_nulls"], + measure_column="int_nulls", + low=0, + high=100, + expected_sum=29, + ), + Case("float_drop_nulls")( + # Column "float_nulls" has one null value and 29 1s. + affected_columns=["float_nulls"], + measure_column="float_nulls", + low=0, + high=100, + expected_sum=29, + ), + Case("float_drop_nan")( + # Column "float_nans" has one nan value and 29 1s. + affected_columns=["float_nans"], + measure_column="float_nans", + low=0, + high=100, + expected_sum=29, + ), + Case("float_drop_both")( + # Column "float_all_special" has 26 1s, one null, one nan, one negative + # infinity (clamped to 0), one positive infinity (clamped to 100). + affected_columns=["float_all_special"], + measure_column="float_all_special", + low=0, + high=100, + expected_sum=126, + ), + Case("drop_other_columns")( + # Column "float_infs" has 26 1s, two negative infinities (clamped to 0) and + # two positive infinities (clamped to 100). But dropping rows from columns + # "string_nulls", "float_nulls", "float_nans", "date_nulls" and "time_nulls" + # should remove five rows, leaving just 21 normal values. + affected_columns=[ + "string_nulls", + "float_nulls", + "float_nans", + "date_nulls", + "time_nulls", + ], + measure_column="float_infs", + low=0, + high=100, + expected_sum=221, + ), + Case("drop_all_with_none")( + # When called with no argument, drop_null_and_nan should drop all rows + # that have null/nan values anywhere, which leaves 24 1s even if we're + # summing a column without nulls. + affected_columns=None, + measure_column="int_no_null", + low=0, + high=1, + expected_sum=24, + ), + Case("drop_all_with_empty_list")( + # Same thing with an empty list.
+ affected_columns=[], + measure_column="float_nulls", + low=0, + high=1, + expected_sum=24.0, + ), + ] +) +def test_drop_null_and_nan( + sdf_special_values: DataFrame, + affected_columns: Optional[List[str]], + measure_column: str, + low: Union[int, float], + high: Union[int, float], + expected_sum: Union[int, float], +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + AddOneRow(), + ) + base_query = QueryBuilder("private") + query = base_query.drop_null_and_nan(affected_columns).sum( + measure_column, low, high + ) + result = sess.evaluate(query, inf_budget) + expected_df = pd.DataFrame([[expected_sum]], columns=[measure_column + "_sum"]) + assert_frame_equal_with_sort(result.toPandas(), expected_df) + + +@parametrize( + # All these tests compute the average of the "float_infs" column of the input table, + # in which there are: + # - 26 non-infinity values, all equal to 1 + # - two negative infinities + # - two positive infinities + # We test this using average and not sum to distinguish between infinities being + # dropped and infinities being replaced with 0. + [ + Case("replace_no_clamp")( + replace_with={"float_infs": (0, 17)}, + low=-100, + high=100, + # 26+0+0+17+17 = 60, divided by 30 is 2 + expected_average=2.0, + ), + Case("replace_clamp")( + replace_with={"float_infs": (-4217, 300)}, + low=-5, + high=22, + # 26-5-5+22+22 = 60, divided by 30 is 2 + expected_average=2.0, + ), + Case("replace_unrelated_column")( + # If we don't explicitly replace infinity in the measure column, then + # infinities should be clamped to the bounds. + replace_with={"float_all_special": (-4217, 300)}, + low=-10, + high=27, + # 26-10-10+27+27 = 60, divided by 30 is 2 + expected_average=2.0, + ), + Case("replace_with_none")( + # If used without any argument, replace_infinity transforms all infinity + # values in all columns of the table to 0. + replace_with=None, + low=-10, + high=10, + expected_average=26.0 / 30.0, + ), + Case("replace_with_empty_dict")( + # Same with an empty dict. + replace_with={}, + low=-10, + high=10, + expected_average=26.0 / 30.0, + ), + ] +) +def test_replace_infinity_average( + sdf_special_values: DataFrame, + replace_with: Optional[Dict[str, Tuple[float, float]]], + low: float, + high: float, + expected_average: float, +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + AddOneRow(), + ) + base_query = QueryBuilder("private") + query = base_query.replace_infinity(replace_with).average("float_infs", low, high) + result = sess.evaluate(query, inf_budget) + expected_df = pd.DataFrame([[expected_average]], columns=["float_infs_average"]) + assert_frame_equal_with_sort(result.toPandas(), expected_df) + + +@parametrize( + [ + Case("all_ones")( + replace_with={"float_infs": (1, 1)}, + expected_sum=30.0, + expected_stdev=0, + expected_variance=0, + ), + Case("one_zero_one_one")( + # If we don't replace infinities in the measure column, then infinity values + # should be clamped to the bounds, namely 0 and 1.
+ replace_with={"float_all_special": (1, 1)}, + expected_sum=28.0, + expected_stdev=sqrt((2 * (28.0 / 30) ** 2 + 28 * (2.0 / 30) ** 2) / 29), + expected_variance=(2 * (28.0 / 30) ** 2 + 28 * (2.0 / 30) ** 2) / 29, + ), + Case("all_zeroes")( + # Without argument, all infinities are replaced by 0 + replace_with=None, + expected_sum=26.0, + expected_stdev=sqrt((4 * (26.0 / 30) ** 2 + 26 * (4.0 / 30) ** 2) / 29), + expected_variance=(4 * (26.0 / 30) ** 2 + 26 * (4.0 / 30) ** 2) / 29, + ), + ] +) +def test_replace_infinity_other_aggregations( + sdf_special_values: DataFrame, + replace_with: Optional[Dict[str, Tuple[float, float]]], + expected_sum: float, + expected_stdev: float, + expected_variance: float, +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + protected_change=AddOneRow(), + ) + + query_sum = ( + QueryBuilder("private").replace_infinity(replace_with).sum("float_infs", 0, 1) + ) + result_sum = sess.evaluate(query_sum, inf_budget) + expected_df = pd.DataFrame([[expected_sum]], columns=["float_infs_sum"]) + assert_frame_equal_with_sort(result_sum.toPandas(), expected_df) + + query_stdev = ( + QueryBuilder("private").replace_infinity(replace_with).stdev("float_infs", 0, 1) + ) + result_stdev = sess.evaluate(query_stdev, inf_budget) + expected_df = pd.DataFrame([[expected_stdev]], columns=["float_infs_stdev"]) + assert_frame_equal_with_sort(result_stdev.toPandas(), expected_df) + + query_variance = ( + QueryBuilder("private") + .replace_infinity(replace_with) + .variance("float_infs", 0, 1) + ) + result_variance = sess.evaluate(query_variance, inf_budget) + expected_df = pd.DataFrame([[expected_variance]], columns=["float_infs_variance"]) + assert_frame_equal_with_sort(result_variance.toPandas(), expected_df) + + +@parametrize( + # All these tests compute the sum of the "float_infs" column of the input table. + [ + Case("drop_rows_in_column")( + # There are 26 non-infinity values in the "float_infs" column. + columns=["float_infs"], + expected_sum=26.0, + ), + Case("drop_no_rows")( + # The call to drop_infinity is a no-op. In the "float_infs" column, there + # are two rows with positive infinities (clamped to 1), and two with + # negative infinities (clamped to 0). + columns=["float_no_special"], + expected_sum=28.0, + ), + Case("drop_some_rows_due_to_other_columns")( + # Two rows with infinite values in the "float_infs" column also have + # infinite values in the "float_all_special" column. We end up with one + # positive infinity value, clamped to 1, and one negative, clamped to 0. + columns=["float_all_special"], + expected_sum=27.0, + ), + Case("drop_rows_in_all_columns")( + # If used without any argument, drop_infinity removes all infinity values in + # all columns of the table. 
+ columns=None, + expected_sum=26.0, + ), + ] +) +def test_drop_infinity( + sdf_special_values: DataFrame, + columns: Optional[List[str]], + expected_sum: float, +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + AddOneRow(), + ) + base_query = QueryBuilder("private") + query = base_query.drop_infinity(columns).sum("float_infs", 0, 1) + result = sess.evaluate(query, inf_budget) + expected_df = pd.DataFrame([[expected_sum]], columns=["float_infs_sum"]) + assert_frame_equal_with_sort(result.toPandas(), expected_df) + + +@parametrize( + [ + Case("works_with_nulls")( + # get_bounds doesn't explode when called on a column with nulls + query=QueryBuilder("private").get_bounds("int_nulls"), + expected_df=pd.DataFrame( + [[-1, 1]], + columns=["int_nulls_lower_bound", "int_nulls_upper_bound"], + ), + ), + Case("works_with_nan")( + # Same with nans + query=QueryBuilder("private").get_bounds("float_nans"), + expected_df=pd.DataFrame( + [[-1, 1]], + columns=["float_nans_lower_bound", "float_nans_upper_bound"], + ), + ), + Case("works_with_infinity")( + # Same with infinities + query=QueryBuilder("private").get_bounds("float_infs"), + expected_df=pd.DataFrame( + [[-1, 1]], + columns=["float_infs_lower_bound", "float_infs_upper_bound"], + ), + ), + Case("drop_and_replace")( + # Dropping nulls & nans removes 6/30 values; replacing the 4 infinity values + # with (-3, 3) ensures that get_bounds returns the interval corresponding + # to the next power of 2, namely (-4, 4). + query=( + QueryBuilder("private") + .drop_null_and_nan() + .replace_infinity({"float_infs": (-3, 3)}) + .get_bounds("float_infs") + ), + expected_df=pd.DataFrame( + [[-4, 4]], + columns=["float_infs_lower_bound", "float_infs_upper_bound"], + ), + ), + ] +) +def test_get_bounds( + sdf_special_values: DataFrame, query: Query, expected_df: pd.DataFrame +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + AddOneRow(), + ) + result = sess.evaluate(query, inf_budget) + assert_frame_equal_with_sort(result.toPandas(), expected_df) + + +@parametrize( + [ + Case("normal_case_explicit")( + # Column "float_all_special" has 26 1s, one null (replaced by 100), one nan + # (replaced by 100), and two infinities (dropped). + query=( + QueryBuilder("private") + .enforce(MaxRowsPerID(1)) + .replace_null_and_nan({"float_all_special": 100.0}) + .drop_infinity(["float_all_special"]) + .sum("float_all_special", 0, 100) + ), + expected_df=pd.DataFrame( + [[226]], # 26+100+100 + columns=["float_all_special_sum"], + ), + ), + Case("normal_case_implicit")( + # Column "float_all_special" has 26 1s, one null, one nan, one negative + # infinity (clamped to -50), one positive infinity (clamped to 100). + query=( + QueryBuilder("private") + .enforce(MaxRowsPerID(1)) + .sum("float_all_special", -50, 100) + ), + expected_df=pd.DataFrame( + [[76]], # 26-50+100 + columns=["float_all_special_sum"], + ), + ), + Case("nulls_are_not_dropped_in_id_column")( + # When called with no argument, drop_null_and_nan should drop all rows + # that have null/nan values anywhere, except in the privacy ID column. This + # should leave 25 1s even if we're summing a column without nulls.
+ query=( + QueryBuilder("private") + .drop_null_and_nan() + .enforce(MaxRowsPerID(1)) + .sum("int_no_null", 0, 1) + ), + expected_df=pd.DataFrame([[25]], columns=["int_no_null_sum"]), + ), + ] +) +def test_privacy_ids( + sdf_special_values: DataFrame, query: Query, expected_df: pd.DataFrame +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + AddRowsWithID("string_nulls"), + ) + result = sess.evaluate(query, inf_budget) + assert_frame_equal_with_sort(result.toPandas(), expected_df) + + +@pytest.fixture(name="sdf_for_joins", scope="module") +def dataframe_for_join(spark): + """Set up test data for sessions with special values. + + This data is then joined with the ``sdf_special_values`` dataframe used previously + in this test suite. + """ + sdf_col_types = { + "string_nulls": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True), + "int_nulls": ColumnDescriptor(ColumnType.INTEGER, allow_null=True), + "float_all_special": ColumnDescriptor( + ColumnType.DECIMAL, + allow_null=True, + allow_nan=True, + allow_inf=True, + ), + "date_nulls": ColumnDescriptor(ColumnType.DATE, allow_null=True), + "time_nulls": ColumnDescriptor(ColumnType.TIMESTAMP, allow_null=True), + "new_int": ColumnDescriptor(ColumnType.INTEGER, allow_null=False), + } + date = datetime.date(2000, 1, 1) + time = datetime.datetime(2020, 1, 1) + sdf = spark.createDataFrame( + [ + # Normal row + ("normal_0", 1, 1.0, date, time, 1), + # Rows with nulls: some whose values appear in `sdf_special_values`… + (None, 1, 1.0, date, time, 1), + ("u2", None, 1.0, date, time, 1), + ("u3", 1, None, date, time, 1), + # … and two identical rows, where the combination of nulls does not appear + # in `sdf_special_values`. + ("u4", 1, 1.0, None, None, 1), + ("u5", 1, 1.0, None, None, 1), + # Row with nans + ("a6", 1, float("nan"), date, time, 1), + # Rows with infinities + ("i7", 1, float("inf"), date, time, 1), + ("i8", 1, -float("inf"), date, time, 1), + ], + schema=analytics_to_spark_schema(Schema(sdf_col_types)), + ) + return sdf + + +@parametrize( + [ + Case("public_join_inner_all_match")( + # Joining on the first three columns, every row of the right table + # should match exactly one row, without duplicates. This checks that tables + # are joined on all three kinds of special values. + protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .join_public( + "public", + ["string_nulls", "int_nulls", "float_all_special"], + "inner", + ) + .sum("new_int", 0, 1) + ), + expected_df=pd.DataFrame( + [[9]], + columns=["new_int_sum"], + ), + ), + Case("public_join_inner_duplicates")( + # Joining with the date and time columns only creates matches for the rows + # where both are specified: 28 in the left table and 7 in the right table. + protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .join_public("public", ["date_nulls", "time_nulls"], "inner") + .sum("new_int", 0, 1) + ), + expected_df=pd.DataFrame( + [[28 * 7]], + columns=["new_int_sum"], + ), + ), + Case("public_join_left_duplicates")( + # Same as before, except we do a left join, so the 2 unmatched rows of the + # left table are also preserved in the join.
+ protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .join_public("public", ["date_nulls", "time_nulls"], "left") + .count() + ), + expected_df=pd.DataFrame( + [[28 * 7 + 2]], + columns=["count"], + ), + ), + Case("private_join_add_rows")( + # Private joins without duplicates should work the same way as the inner + # public join above, leaving the 9 rows in common between the two tables. + protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .join_private( + "private_2", + join_columns=["string_nulls", "int_nulls", "float_all_special"], + truncation_strategy_left=TruncationStrategy.DropNonUnique(), + truncation_strategy_right=TruncationStrategy.DropNonUnique(), + ) + .count() + ), + expected_df=pd.DataFrame([[9]], columns=["count"]), + ), + Case("private_join_ids")( + # Same with a privacy ID column. + protected_change=AddRowsWithID("string_nulls"), + query=( + QueryBuilder("private") + .join_private( + "private_2", + join_columns=["string_nulls", "int_nulls", "float_all_special"], + ) + .enforce(MaxRowsPerID(1)) + .count() + ), + expected_df=pd.DataFrame([[9]], columns=["count"]), + ), + Case("private_join_preserves_special_values")( + # After the join, "float_all_special" should have the same data as in the + # table used for the join: 5 1s, one null (replaced by 100), one nan + # (replaced by 100), and two infinities (dropped). + protected_change=AddRowsWithID("string_nulls"), + query=( + QueryBuilder("private") + .join_private( + "private_2", + join_columns=["string_nulls", "int_nulls", "float_all_special"], + ) + .enforce(MaxRowsPerID(1)) + .drop_infinity(["float_all_special"]) + .replace_null_and_nan({"float_all_special": 100.0}) + .sum("float_all_special", 0, 200) + ), + expected_df=pd.DataFrame( + [[5 + 100 + 100]], columns=["float_all_special_sum"] + ), + ), + Case("public_join_preserves_special_values")( + # Same with a public join. 
+            protected_change=AddOneRow(),
+            query=(
+                QueryBuilder("private")
+                .join_public(
+                    "public",
+                    join_columns=["string_nulls", "int_nulls", "float_all_special"],
+                )
+                .drop_infinity(["float_all_special"])
+                .replace_null_and_nan({"float_all_special": 100.0})
+                .sum("float_all_special", 0, 200)
+            ),
+            expected_df=pd.DataFrame(
+                [[5 + 100 + 100]], columns=["float_all_special_sum"]
+            ),
+        ),
+    ]
+)
+def test_joins(
+    sdf_special_values: DataFrame,
+    sdf_for_joins: DataFrame,
+    protected_change: ProtectedChange,
+    query: Query,
+    expected_df: pd.DataFrame,
+):
+    inf_budget = PureDPBudget.inf()
+    sess = (
+        Session.Builder()
+        .with_id_space("default_id_space")
+        .with_private_dataframe("private", sdf_special_values, protected_change)
+        .with_private_dataframe("private_2", sdf_for_joins, protected_change)
+        .with_public_dataframe("public", sdf_for_joins)
+        .with_privacy_budget(inf_budget)
+        .build()
+    )
+    result = sess.evaluate(query, inf_budget)
+    assert_frame_equal_with_sort(result.toPandas(), expected_df)
+
+
+@parametrize(
+    [
+        Case("private_int_remove_nulls")(
+            # Null joined with no nulls = no nulls
+            protected_change=AddOneRow(),
+            query=(
+                QueryBuilder("private")
+                .rename({"int_no_null": "int_joined"})
+                .join_private(
+                    QueryBuilder("private_2").rename({"int_nulls": "int_joined"}),
+                    join_columns=["int_joined"],
+                    truncation_strategy_left=TruncationStrategy.DropExcess(30),
+                    truncation_strategy_right=TruncationStrategy.DropExcess(30),
+                )
+            ),
+            expected_col=(
+                "int_joined",
+                ColumnDescriptor(ColumnType.INTEGER, allow_null=False),
+            ),
+        ),
+        Case("private_float_remove_both")(
+            # Only infs joined with only nulls & nan = no special values left
+            protected_change=AddOneRow(),
+            query=(
+                QueryBuilder("private")
+                .drop_null_and_nan(["float_all_special"])
+                .join_private(
+                    QueryBuilder("private").drop_infinity(["float_all_special"]),
+                    join_columns=["float_all_special"],
+                    truncation_strategy_left=TruncationStrategy.DropExcess(30),
+                    truncation_strategy_right=TruncationStrategy.DropExcess(30),
+                )
+            ),
+            expected_col=(
+                "float_all_special",
+                ColumnDescriptor(
+                    ColumnType.DECIMAL,
+                    allow_null=False,
+                    allow_nan=False,
+                    allow_inf=False,
+                ),
+            ),
+        ),
+        Case("public_int_remove_nulls_from_right")(
+            # No nulls joined with nulls = no nulls (public version)
+            protected_change=AddOneRow(),
+            query=(
+                QueryBuilder("private")
+                .select(["int_no_null"])
+                .rename({"int_no_null": "int_nulls"})
+                .join_public(
+                    "public",
+                    join_columns=["int_nulls"],
+                )
+            ),
+            expected_col=(
+                "int_nulls",
+                ColumnDescriptor(ColumnType.INTEGER, allow_null=False),
+            ),
+        ),
+        Case("public_int_remove_nulls_from_left")(
+            # Nulls joined with no nulls = no nulls (reverse)
+            protected_change=AddOneRow(),
+            query=(
+                QueryBuilder("private")
+                .rename({"int_nulls": "new_int"})
+                .join_public(
+                    "public",
+                    join_columns=["new_int"],
+                )
+            ),
+            expected_col=(
+                "new_int",
+                ColumnDescriptor(ColumnType.INTEGER, allow_null=False),
+            ),
+        ),
+        Case("public_int_keep_null_on_left_join")(
+            # Nulls *left* joined with no nulls = nulls
+            protected_change=AddOneRow(),
+            query=(
+                QueryBuilder("private")
+                .rename({"int_nulls": "new_int"})
+                .join_public(
+                    "public",
+                    join_columns=["new_int"],
+                    how="left",
+                )
+            ),
+            expected_col=(
+                "int_nulls",
+                ColumnDescriptor(ColumnType.INTEGER, allow_null=True),
+            ),
+        ),
+    ]
+)
+def test_join_schema(
+    sdf_special_values: DataFrame,
+    sdf_for_joins: DataFrame,
+    protected_change: ProtectedChange,
+    query: Query,
+    expected_col: Tuple[str, ColumnDescriptor],
+):
+    inf_budget = PureDPBudget.inf()
+    sess = (
+        Session.Builder()
+        .with_id_space("default_id_space")
+        .with_private_dataframe("private", sdf_special_values, protected_change)
+        .with_private_dataframe("private_2", sdf_for_joins, protected_change)
+        .with_public_dataframe("public", sdf_for_joins)
+        .with_privacy_budget(inf_budget)
+        .build()
+    )
+    sess.create_view(query, "view", cache=False)
+    schema = sess.get_schema("view")
+    assert expected_col in schema.items()
diff --git a/test/unit/query_expr_compiler/test_measurement_visitor.py b/test/unit/query_expr_compiler/test_measurement_visitor.py
index e8791f98..7be38fa6 100644
--- a/test/unit/query_expr_compiler/test_measurement_visitor.py
+++ b/test/unit/query_expr_compiler/test_measurement_visitor.py
@@ -635,24 +635,6 @@ def test_visit_groupby_count_distinct(
                 ]
             ),
         ),
-        (
-            QueryBuilder("private").quantile(
-                low=-100,
-                high=100,
-                name="custom_output_column",
-                column="null_and_nan",
-                quantile=0.1,
-            ),
-            PureDP(),
-            NoiseInfo(
-                [
-                    {
-                        "noise_mechanism": _NoiseMechanism.EXPONENTIAL,
-                        "noise_parameter": 3.3333333333333326,
-                    }
-                ]
-            ),
-        ),
         (
             QueryBuilder("private")
             .groupby(KeySet.from_dict({"B": [0, 1]}))
@@ -673,26 +655,6 @@ def test_visit_groupby_count_distinct(
                 ]
             ),
         ),
-        (
-            QueryBuilder("private")
-            .groupby(KeySet.from_dict({"B": [0, 1]}))
-            .quantile(
-                column="null_and_inf",
-                name="quantile",
-                low=123.345,
-                high=987.65,
-                quantile=0.25,
-            ),
-            PureDP(),
-            NoiseInfo(
-                [
-                    {
-                        "noise_mechanism": _NoiseMechanism.EXPONENTIAL,
-                        "noise_parameter": 3.3333333333333326,
-                    }
-                ]
-            ),
-        ),
         (
             QueryBuilder("private")
             .groupby(KeySet.from_dict({"A": ["zero"]}))
@@ -707,20 +669,6 @@ def test_visit_groupby_count_distinct(
                 ]
             ),
         ),
-        (
-            QueryBuilder("private")
-            .groupby(KeySet.from_dict({"A": ["zero"]}))
-            .quantile(quantile=0.5, low=0, high=1, column="nan_and_inf"),
-            RhoZCDP(),
-            NoiseInfo(
-                [
-                    {
-                        "noise_mechanism": _NoiseMechanism.EXPONENTIAL,
-                        "noise_parameter": 2.9814239699997196,
-                    }
-                ]
-            ),
-        ),
         (
             QueryBuilder("private")
             .groupby(KeySet.from_dict({"A": ["zero"]}))
diff --git a/test/unit/query_expr_compiler/test_rewrite_rules.py b/test/unit/query_expr_compiler/test_rewrite_rules.py
index 8e253e77..c5b4a249 100644
--- a/test/unit/query_expr_compiler/test_rewrite_rules.py
+++ b/test/unit/query_expr_compiler/test_rewrite_rules.py
@@ -13,6 +13,9 @@
     AverageMechanism,
     CountDistinctMechanism,
     CountMechanism,
+    DropInfinity,
+    DropNullAndNan,
+    GetBounds,
     GroupByBoundedAverage,
     GroupByBoundedStdev,
     GroupByBoundedSum,
@@ -22,6 +25,7 @@
     PrivateSource,
     QueryExpr,
     QueryExprVisitor,
+    ReplaceInfinity,
     SingleChildQueryExpr,
     StdevMechanism,
     SumMechanism,
@@ -30,9 +34,10 @@
 )
 from tmlt.analytics._query_expr_compiler._rewrite_rules import (
     CompilationInfo,
+    add_special_value_handling,
     select_noise_mechanism,
 )
-from tmlt.analytics._schema import ColumnDescriptor, ColumnType, Schema
+from tmlt.analytics._schema import ColumnDescriptor, ColumnType, FrozenDict, Schema

 # SPDX-License-Identifier: Apache-2.0
 # Copyright Tumult Labs 2025
@@ -391,3 +396,241 @@ def test_recursive_noise_selection(catalog: Catalog) -> None:
     info = CompilationInfo(output_measure=ApproxDP(), catalog=catalog)
     got_expr = select_noise_mechanism(info)(expr)
     assert got_expr == expected_expr
+
+
+@parametrize(
+    [
+        Case()(agg="count"),
+        Case()(agg="count_distinct"),
+    ]
+)
+@parametrize(
+    [
+        Case()(col_desc=ColumnDescriptor(ColumnType.INTEGER, allow_null=False)),
+        Case()(col_desc=ColumnDescriptor(ColumnType.INTEGER, allow_null=True)),
+        Case()(
+            col_desc=ColumnDescriptor(
+                ColumnType.DECIMAL, allow_null=False, allow_nan=False, allow_inf=False
+            )
+        ),
+        Case()(
+            col_desc=ColumnDescriptor(
+                ColumnType.DECIMAL, allow_null=True, allow_nan=True, allow_inf=True
+            )
+        ),
+    ]
+)
+def test_special_value_handling_count_unaffected(
+    agg: str,
+    col_desc: ColumnDescriptor,
+) -> None:
+    (AggExpr, AggMech) = AGG_CLASSES[agg]
+    expr = AggExpr(
+        child=BASE_EXPR,
+        groupby_keys=KeySet.from_dict({}),
+        mechanism=AggMech["DEFAULT"],
+    )
+    catalog = Catalog()
+    catalog.add_private_table("private", {"col": col_desc})
+    info = CompilationInfo(output_measure=PureDP(), catalog=catalog)
+    got_expr = add_special_value_handling(info)(expr)
+    assert got_expr == expr
+
+
+@parametrize(
+    [
+        # Columns with no special values should be unaffected
+        Case(f"no_op_null_{col_type}")(
+            col_desc=ColumnDescriptor(
+                col_type, allow_null=False, allow_nan=False, allow_inf=False
+            ),
+            new_child=BASE_EXPR,
+        )
+        for col_type in [
+            ColumnType.INTEGER,
+            ColumnType.DECIMAL,
+            ColumnType.DATE,
+            ColumnType.TIMESTAMP,
+        ]
+    ]
+    + [
+        # NaNs and infinities do not matter for non-floats
+        Case(f"no_op_nan_inf_{col_type}")(
+            col_desc=ColumnDescriptor(
+                col_type, allow_null=False, allow_nan=True, allow_inf=True
+            ),
+            new_child=BASE_EXPR,
+        )
+        for col_type in [ColumnType.INTEGER, ColumnType.DATE, ColumnType.TIMESTAMP]
+    ]
+    + [
+        # Nulls must be dropped if needed
+        Case(f"drop_null_{col_type}")(
+            col_desc=ColumnDescriptor(
+                col_type, allow_null=True, allow_nan=False, allow_inf=False
+            ),
+            new_child=DropNullAndNan(child=BASE_EXPR, columns=("col",)),
+        )
+        for col_type in [
+            ColumnType.INTEGER,
+            ColumnType.DECIMAL,
+            ColumnType.DATE,
+            ColumnType.TIMESTAMP,
+        ]
+    ]
+    + [
+        # NaNs must also be dropped if needed
+        Case("drop_nan")(
+            col_desc=ColumnDescriptor(
+                ColumnType.DECIMAL, allow_null=False, allow_nan=True, allow_inf=False
+            ),
+            new_child=DropNullAndNan(child=BASE_EXPR, columns=("col",)),
+        ),
+        # Only one pass is enough to drop both nulls and NaNs
+        Case("drop_both")(
+            col_desc=ColumnDescriptor(
+                ColumnType.DECIMAL, allow_null=True, allow_nan=True, allow_inf=False
+            ),
+            new_child=DropNullAndNan(child=BASE_EXPR, columns=("col",)),
+        ),
+        # If not handled, infinities must be clamped to the clamping bounds
+        Case("clamp_inf")(
+            col_desc=ColumnDescriptor(
+                ColumnType.DECIMAL, allow_null=False, allow_nan=False, allow_inf=True
+            ),
+            new_child=ReplaceInfinity(
+                child=BASE_EXPR, replace_with=FrozenDict.from_dict({"col": (0, 1)})
+            ),
+        ),
+        # Handling both kinds of special values at once. This would fail if the two
+        # value handling exprs are in the wrong order; this is not ideal, but ah well.
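+        # (The expected child below nests DropNullAndNan inside ReplaceInfinity, so
+        # nulls and NaNs are dropped before infinities are clamped.)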
+ Case("drop_nan_clamp_inf")( + col_desc=ColumnDescriptor( + ColumnType.DECIMAL, allow_null=True, allow_nan=True, allow_inf=True + ), + new_child=ReplaceInfinity( + child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), + replace_with=FrozenDict.from_dict({"col": (0, 1)}), + ), + ), + ] +) +@parametrize( + [ + Case()(agg="sum"), + Case()(agg="average"), + Case()(agg="stdev"), + Case()(agg="variance"), + ] +) +def test_special_value_handling_numeric_aggregations( + agg: str, + col_desc: ColumnDescriptor, + new_child: QueryExpr, +) -> None: + (AggExpr, AggMech) = AGG_CLASSES[agg] + expr = AggExpr( + child=BASE_EXPR, + measure_column="col", + low=0, + high=1, + groupby_keys=KeySet.from_dict({}), + mechanism=AggMech["DEFAULT"], + ) + catalog = Catalog() + catalog.add_private_table("private", {"col": col_desc}) + info = CompilationInfo(output_measure=PureDP(), catalog=catalog) + got_expr = add_special_value_handling(info)(expr) + assert got_expr == replace( + expr, + child=new_child, + ) + + +@parametrize( + [ + # Columns with no special values should be unaffected + Case(f"no-op-{col_type}")( + col_desc=ColumnDescriptor( + col_type, allow_null=False, allow_nan=False, allow_inf=False + ), + new_child=BASE_EXPR, + ) + for col_type in [ + ColumnType.INTEGER, + ColumnType.DECIMAL, + ColumnType.DATE, + ColumnType.TIMESTAMP, + ] + ] + + [ + # NaNs and infinities do not matter for non-floats + Case(f"no-op-nan-inf-{col_type}")( + col_desc=ColumnDescriptor( + col_type, allow_null=False, allow_nan=True, allow_inf=True + ), + new_child=BASE_EXPR, + ) + for col_type in [ColumnType.INTEGER, ColumnType.DATE, ColumnType.TIMESTAMP] + ] + + [ + # Nulls must be dropped if needed + Case(f"drop-nulls-{col_type}")( + col_desc=ColumnDescriptor( + col_type, allow_null=True, allow_nan=False, allow_inf=False + ), + new_child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), + ) + for col_type in [ + ColumnType.INTEGER, + ColumnType.DECIMAL, + ColumnType.DATE, + ColumnType.TIMESTAMP, + ] + ] + + [ + # NaNs must also be dropped if needed + Case("drop-nan")( + col_desc=ColumnDescriptor( + ColumnType.DECIMAL, allow_null=False, allow_nan=True, allow_inf=False + ), + new_child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), + ), + # Same for infinities (contrary to other aggregations which use clamping) + Case("drop-inf")( + col_desc=ColumnDescriptor( + ColumnType.DECIMAL, allow_null=False, allow_nan=False, allow_inf=True + ), + new_child=DropInfinity(child=BASE_EXPR, columns=("col",)), + ), + # And both kinds of special values must be handled + Case("drop-nan-and-inf")( + col_desc=ColumnDescriptor( + ColumnType.DECIMAL, allow_null=True, allow_nan=True, allow_inf=True + ), + new_child=DropInfinity( + child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), + columns=("col",), + ), + ), + ] +) +def test_special_value_handling_get_bounds( + col_desc: ColumnDescriptor, + new_child: QueryExpr, +) -> None: + expr = GetBounds( + child=BASE_EXPR, + measure_column="col", + groupby_keys=KeySet.from_dict({}), + lower_bound_column="lower", + upper_bound_column="upper", + ) + catalog = Catalog() + catalog.add_private_table("private", {"col": col_desc}) + info = CompilationInfo(output_measure=PureDP(), catalog=catalog) + got_expr = add_special_value_handling(info)(expr) + assert got_expr == replace( + expr, + child=new_child, + )