From d93b47699bbb43bd1cef5ac622db6e613f4aa2e7 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Fri, 24 Oct 2025 18:14:03 +0200 Subject: [PATCH 01/16] create rewriting logic, fix docstrings --- src/tmlt/analytics/_query_expr.py | 55 ++++----------- .../_query_expr_compiler/_rewrite_rules.py | 67 +++++++++++++++++++ src/tmlt/analytics/query_builder.py | 55 ++++++++------- 3 files changed, 107 insertions(+), 70 deletions(-) diff --git a/src/tmlt/analytics/_query_expr.py b/src/tmlt/analytics/_query_expr.py index d5c58323..0c486d8c 100644 --- a/src/tmlt/analytics/_query_expr.py +++ b/src/tmlt/analytics/_query_expr.py @@ -1,11 +1,10 @@ """Building blocks of the Tumult Analytics query language. Not for direct use. -Defines the :class:`QueryExpr` class, which represents expressions in the -Tumult Analytics query language. QueryExpr and its subclasses should not be -directly constructed or deconstructed by most users; interfaces such as -:class:`tmlt.analytics.QueryBuilder` to create them and -:class:`tmlt.analytics.Session` to consume them provide more -user-friendly features. +Defines the :class:`QueryExpr` class, which represents expressions in the Tumult +Analytics query language. QueryExpr and its subclasses should not be directly +constructed; but instead built using a :class:`tmlt.analytics.QueryBuilder`. The +documentation of the :class:`tmlt.analytics.QueryBuilder` provides more information +about the intended semantics of :class:`QueryExpr` objects. """ # SPDX-License-Identifier: Apache-2.0 @@ -174,14 +173,10 @@ class StdevMechanism(Enum): class QueryExpr(ABC): """A query expression, base class for relational operators. - In most cases, QueryExpr should not be manipulated directly, but rather - created using :class:`tmlt.analytics.QueryBuilder` and then - consumed by :class:`tmlt.analytics.Session`. While they can be - created and modified directly, this is an advanced usage and is not - recommended for typical users. - - QueryExpr are organized in a tree, where each node is an operator which - returns a relation. + QueryExpr are organized in a tree, where each node is an operator that returns a + table. They are built using the :class:`tmlt.analytics.QueryBuilder`, then rewritten + during the compilation process. They should not be created directly, except in + tests. """ @abstractmethod @@ -1784,13 +1779,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: @dataclass(frozen=True) class GroupByBoundedSum(SingleChildQueryExpr): - """Returns the bounded sum of a column for each combination of groupby domains. - - If the column to be measured contains null, NaN, or positive or negative infinity, - those values will be dropped (as if dropped explicitly via - :class:`DropNullAndNan` and :class:`DropInfinity`) before the sum is - calculated. - """ + """Returns the bounded sum of a column for each combination of groupby domains.""" groupby_keys: Union[KeySet, Tuple[str, ...]] """The keys, or columns list to collect keys from, to be grouped on.""" @@ -1851,13 +1840,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: @dataclass(frozen=True) class GroupByBoundedAverage(SingleChildQueryExpr): - """Returns bounded average of a column for each combination of groupby domains. - - If the column to be measured contains null, NaN, or positive or negative infinity, - those values will be dropped (as if dropped explicitly via - :class:`DropNullAndNan` and :class:`DropInfinity`) before the average is - calculated. 
- """ + """Returns bounded average of a column for each combination of groupby domains.""" groupby_keys: Union[KeySet, Tuple[str, ...]] """The keys, or columns list to collect keys from, to be grouped on.""" @@ -1918,13 +1901,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: @dataclass(frozen=True) class GroupByBoundedVariance(SingleChildQueryExpr): - """Returns bounded variance of a column for each combination of groupby domains. - - If the column to be measured contains null, NaN, or positive or negative infinity, - those values will be dropped (as if dropped explicitly via - :class:`DropNullAndNan` and :class:`DropInfinity`) before the variance is - calculated. - """ + """Returns bounded variance of a column for each combination of groupby domains.""" groupby_keys: Union[KeySet, Tuple[str, ...]] """The keys, or columns list to collect keys from, to be grouped on.""" @@ -1985,13 +1962,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: @dataclass(frozen=True) class GroupByBoundedSTDEV(SingleChildQueryExpr): - """Returns bounded stdev of a column for each combination of groupby domains. - - If the column to be measured contains null, NaN, or positive or negative infinity, - those values will be dropped (as if dropped explicitly via - :class:`DropNullAndNan` and :class:`DropInfinity`) before the - standard deviation is calculated. - """ + """Returns bounded stdev of a column for each combination of groupby domains.""" groupby_keys: Union[KeySet, Tuple[str, ...]] """The keys, or columns list to collect keys from, to be grouped on.""" diff --git a/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py b/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py index 6a8e9006..bf2924af 100644 --- a/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py +++ b/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py @@ -20,15 +20,21 @@ AverageMechanism, CountDistinctMechanism, CountMechanism, + DropInfinity, + DropNullAndNan, + FrozenDict, + GetBounds, GroupByBoundedAverage, GroupByBoundedSTDEV, GroupByBoundedSum, GroupByBoundedVariance, GroupByCount, GroupByCountDistinct, + GroupByQuantile, JoinPrivate, PrivateSource, QueryExpr, + ReplaceInfinity, SingleChildQueryExpr, StdevMechanism, SumMechanism, @@ -193,9 +199,70 @@ def select_noise(expr: QueryExpr) -> QueryExpr: return select_noise +def add_special_value_handling( + info: CompilationInfo, +) -> Callable[[QueryExpr], QueryExpr]: + """Rewrites the query to handle nulls, NaNs and infinite values. + + If the measure column allows nulls or NaNs, the rewritten query will drop those + values. If the measure column allows infinite values, the new query will replace + those values with the clamping bounds specified in the query, or drop these values + for :meth:`~tmlt.analytics.QueryBuilder.get_bounds`. 
+ """ + + @depth_first + def handle_special_values(expr: QueryExpr) -> QueryExpr: + if not isinstance( + expr, + ( + GroupByBoundedAverage, + GroupByBoundedSTDEV, + GroupByBoundedSum, + GroupByBoundedVariance, + GroupByQuantile, + GetBounds, + ), + ): + return expr + schema = expr.child.schema(info.catalog) + measure_desc = schema[expr.measure_column] + # Remove nulls/NaN if necessary + if measure_desc.allow_null or ( + measure_desc.column_type == ColumnType.DECIMAL and measure_desc.allow_nan + ): + expr = replace( + expr, + child=DropNullAndNan( + child=expr.child, columns=tuple([expr.measure_column]) + ), + ) + # Remove infinities if necessary + if measure_desc.column_type == ColumnType.DECIMAL and measure_desc.allow_inf: + if isinstance(expr, GetBounds): + return replace( + expr, + child=DropInfinity( + child=expr.child, columns=tuple([expr.measure_column]) + ), + ) + return replace( + expr, + child=ReplaceInfinity( + child=expr.child, + replace_with=FrozenDict.from_dict( + {expr.measure_column: (expr.low, expr.high)} + ), + ), + ) + return expr + + return handle_special_values + + def rewrite(info: CompilationInfo, expr: QueryExpr) -> QueryExpr: """Rewrites the given QueryExpr into a QueryExpr that can be compiled.""" rewrite_rules = [ + add_special_value_handling(info), select_noise_mechanism(info), ] for rule in rewrite_rules: diff --git a/src/tmlt/analytics/query_builder.py b/src/tmlt/analytics/query_builder.py index b3c1976a..80e0eb56 100644 --- a/src/tmlt/analytics/query_builder.py +++ b/src/tmlt/analytics/query_builder.py @@ -2449,12 +2449,6 @@ def sum( ) -> Query: """Returns a sum query ready to be evaluated. - .. note:: - If the column being measured contains NaN or null values, a - :meth:`~drop_null_and_nan` query will be performed first. If the - column being measured contains infinite values, a - :meth:`~drop_infinity` query will be performed first. - .. note:: Regarding the clamping bounds: @@ -2468,6 +2462,12 @@ def sum( Consult the :ref:`Numerical aggregations ` tutorial for more information. + .. note:: + If the column being measured contains NaN or null values, a + :meth:`~drop_null_and_nan` query will be performed first. If the column + being measured contains infinite values, these values will be clamped + between ``low`` and ``high``. + .. >>> from tmlt.analytics import ( ... AddOneRow, @@ -2535,12 +2535,6 @@ def average( ) -> Query: """Returns an average query ready to be evaluated. - .. note:: - If the column being measured contains NaN or null values, a - :meth:`~drop_null_and_nan` query will be performed first. If the - column being measured contains infinite values, a - :meth:`~drop_infinity` query will be performed first. - .. note:: Regarding the clamping bounds: @@ -2554,6 +2548,12 @@ def average( Consult the :ref:`Numerical aggregations ` tutorial for more information. + .. note:: + If the column being measured contains NaN or null values, a + :meth:`~drop_null_and_nan` query will be performed first. If the column + being measured contains infinite values, these values will be clamped + between ``low`` and ``high``. + .. >>> from tmlt.analytics import ( ... AddOneRow, @@ -2621,12 +2621,6 @@ def variance( ) -> Query: """Returns a variance query ready to be evaluated. - .. note:: - If the column being measured contains NaN or null values, a - :meth:`~drop_null_and_nan` query will be performed first. If the - column being measured contains infinite values, a - :meth:`~drop_infinity` query will be performed first. - .. 
note:: Regarding the clamping bounds: @@ -2640,6 +2634,12 @@ def variance( Consult the :ref:`Numerical aggregations ` tutorial for more information. + .. note:: + If the column being measured contains NaN or null values, a + :meth:`~drop_null_and_nan` query will be performed first. If the column + being measured contains infinite values, these values will be clamped + between ``low`` and ``high``. + .. >>> from tmlt.analytics import ( ... AddOneRow, @@ -2707,12 +2707,6 @@ def stdev( ) -> Query: """Returns a standard deviation query ready to be evaluated. - .. note:: - If the column being measured contains NaN or null values, a - :meth:`~drop_null_and_nan` query will be performed first. If the - column being measured contains infinite values, a - :meth:`~drop_infinity` query will be performed first. - .. note:: Regarding the clamping bounds: @@ -2726,6 +2720,12 @@ def stdev( Consult the :ref:`Numerical aggregations ` tutorial for more information. + .. note:: + If the column being measured contains NaN or null values, a + :meth:`~drop_null_and_nan` query will be performed first. If the column + being measured contains infinite values, these values will be clamped + between ``low`` and ``high``. + .. >>> from tmlt.analytics import ( ... AddOneRow, @@ -2978,10 +2978,9 @@ def quantile( .. note:: If the column being measured contains NaN or null values, a - :meth:`~QueryBuilder.drop_null_and_nan` query will be performed - first. If the column being measured contains infinite values, a - :meth:`~QueryBuilder.drop_infinity` query will be performed first. - + :meth:`~drop_null_and_nan` query will be performed first. If the column + being measured contains infinite values, these values will be clamped + between ``low`` and ``high``. .. >>> from tmlt.analytics import ( ... 
AddOneRow, From 9dcedc28c37a66a3c59fc4f37d848d4e9f3206b4 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Mon, 27 Oct 2025 14:44:50 +0100 Subject: [PATCH 02/16] rewrite session tests for special values --- .../_base_measurement_visitor.py | 68 +-- test/system/session/rows/conftest.py | 54 -- .../rows/test_add_max_rows_infs_nulls.py | 512 ------------------ test/system/session/test_special_values.py | 486 +++++++++++++++++ 4 files changed, 487 insertions(+), 633 deletions(-) delete mode 100644 test/system/session/rows/test_add_max_rows_infs_nulls.py create mode 100644 test/system/session/test_special_values.py diff --git a/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py b/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py index 26427d02..b5d94565 100644 --- a/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py +++ b/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py @@ -1,5 +1,4 @@ """Defines a base class for building measurement visitors.""" -import dataclasses import math import warnings from abc import abstractmethod @@ -101,7 +100,7 @@ SuppressAggregates, VarianceMechanism, ) -from tmlt.analytics._schema import ColumnType, FrozenDict, Schema +from tmlt.analytics._schema import Schema from tmlt.analytics._table_identifier import Identifier from tmlt.analytics._table_reference import TableReference from tmlt.analytics._transformation_utils import get_table_from_ref @@ -675,65 +674,6 @@ def _validate_approxDP_and_adjust_budget( else: raise AnalyticsInternalError(f"Unknown mechanism {mechanism}.") - def _add_special_value_handling_to_query( - self, - query: Union[ - GroupByBoundedAverage, - GroupByBoundedSTDEV, - GroupByBoundedSum, - GroupByBoundedVariance, - GroupByQuantile, - GetBounds, - ], - ): - """Returns a new query that handles nulls, NaNs and infinite values. - - If the measure column allows nulls or NaNs, the new query - will drop those values. - - If the measure column allows infinite values, the new query will replace those - values with the low and high values specified in the query. - - These changes are added immediately before the groupby aggregation in the query. - """ - expected_schema = query.child.schema(self.catalog) - - # You can't perform these queries on nulls, NaNs, or infinite values - # so check for those - try: - measure_desc = expected_schema[query.measure_column] - except KeyError as e: - raise KeyError( - f"Measure column {query.measure_column} is not in the input schema." - ) from e - - new_child: QueryExpr - # If null or NaN values are allowed ... - if measure_desc.allow_null or ( - measure_desc.column_type == ColumnType.DECIMAL and measure_desc.allow_nan - ): - # then drop those values - # (but don't mutate the original query) - new_child = DropNullAndNan( - child=query.child, columns=tuple([query.measure_column]) - ) - query = dataclasses.replace(query, child=new_child) - if not isinstance(query, GetBounds): - # If infinite values are allowed... 
- if ( - measure_desc.column_type == ColumnType.DECIMAL - and measure_desc.allow_inf - ): - # then clamp them (to low/high values) - new_child = ReplaceInfinity( - child=query.child, - replace_with=FrozenDict.from_dict( - {query.measure_column: (query.low, query.high)} - ), - ) - query = dataclasses.replace(query, child=new_child) - return query - def _validate_measurement(self, measurement: Measurement, mid_stability: sp.Expr): """Validate a measurement.""" if isinstance(self.adjusted_budget.value, tuple): @@ -1146,7 +1086,6 @@ def visit_groupby_quantile( # Peek at the schema, to see if there are errors there expr.schema(self.catalog) - expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) @@ -1241,7 +1180,6 @@ def visit_groupby_bounded_sum( # Peek at the schema, to see if there are errors there expr.schema(self.catalog) - expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) @@ -1337,7 +1275,6 @@ def visit_groupby_bounded_average( # Peek at the schema, to see if there are errors there expr.schema(self.catalog) - expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) @@ -1433,7 +1370,6 @@ def visit_groupby_bounded_variance( # Peek at the schema, to see if there are errors there expr.schema(self.catalog) - expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) @@ -1529,7 +1465,6 @@ def visit_groupby_bounded_stdev( # Peek at the schema, to see if there are errors there expr.schema(self.catalog) - expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) @@ -1622,7 +1557,6 @@ def visit_get_bounds(self, expr: GetBounds) -> Tuple[Measurement, NoiseInfo]: # Peek at the schema, to see if there are errors there expr.schema(self.catalog) - expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) keyset_budget = self._get_zero_budget() diff --git a/test/system/session/rows/conftest.py b/test/system/session/rows/conftest.py index 92ffa308..e2cbd04d 100644 --- a/test/system/session/rows/conftest.py +++ b/test/system/session/rows/conftest.py @@ -38,7 +38,6 @@ FrozenDict, Schema, analytics_to_spark_columns_descriptor, - analytics_to_spark_schema, ) # Shorthands for some values used in tests @@ -708,56 +707,3 @@ def sess_data(spark, request): analytics_to_spark_columns_descriptor(Schema(sdf_col_types)) ) request.cls.sdf_input_domain = sdf_input_domain - - -###DATA FOR SESSIONS WITH NULLS### -@pytest.fixture(name="null_session_data", scope="class") -def null_setup(spark, request): - """Set up test data for sessions with nulls.""" - # Since Spark gives back timestamps with microsecond accuracy, this - # dataframe needs to make that the default precision for column T. 
- pdf = pd.DataFrame( - [ - ["a0", 0, 0.0, datetime.date(2000, 1, 1), datetime.datetime(2020, 1, 1)], - [None, 1, 1.0, datetime.date(2001, 1, 1), datetime.datetime(2021, 1, 1)], - ["a2", None, 2.0, datetime.date(2002, 1, 1), datetime.datetime(2022, 1, 1)], - ["a3", 3, None, datetime.date(2003, 1, 1), datetime.datetime(2023, 1, 1)], - ["a4", 4, 4.0, None, datetime.datetime(2024, 1, 1)], - ["a5", 5, 5.0, datetime.date(2005, 1, 1), None], - ], - columns=["A", "I", "X", "D", "T"], - ).astype({"T": "datetime64[us]"}) - - request.cls.pdf = pdf - - sdf_col_types = { - "A": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True), - "I": ColumnDescriptor(ColumnType.INTEGER, allow_null=True), - "X": ColumnDescriptor(ColumnType.DECIMAL, allow_null=True), - "D": ColumnDescriptor(ColumnType.DATE, allow_null=True), - "T": ColumnDescriptor(ColumnType.TIMESTAMP, allow_null=True), - } - - sdf = spark.createDataFrame( - pdf, schema=analytics_to_spark_schema(Schema(sdf_col_types)) - ) - request.cls.sdf = sdf - - -###DATA FOR SESSIONS WITH INF VALUES### -@pytest.fixture(name="infs_test_data", scope="class") -def infs_setup(spark, request): - """Set up tests.""" - pdf = pd.DataFrame( - {"A": ["a0", "a0", "a1", "a1"], "B": [float("-inf"), 2.0, 5.0, float("inf")]} - ) - request.cls.pdf = pdf - - sdf_col_types = { - "A": ColumnDescriptor(ColumnType.VARCHAR), - "B": ColumnDescriptor(ColumnType.DECIMAL, allow_inf=True), - } - sdf = spark.createDataFrame( - pdf, schema=analytics_to_spark_schema(Schema(sdf_col_types)) - ) - request.cls.sdf = sdf diff --git a/test/system/session/rows/test_add_max_rows_infs_nulls.py b/test/system/session/rows/test_add_max_rows_infs_nulls.py deleted file mode 100644 index 20738d87..00000000 --- a/test/system/session/rows/test_add_max_rows_infs_nulls.py +++ /dev/null @@ -1,512 +0,0 @@ -"""System tests for Sessions with Nulls and Infs.""" - -# SPDX-License-Identifier: Apache-2.0 -# Copyright Tumult Labs 2025 - -import datetime -from typing import Any, Dict, List, Mapping, Tuple, Union - -import pandas as pd -import pytest -from pyspark.sql import DataFrame, SparkSession -from pyspark.sql.types import StringType, StructField, StructType -from tmlt.core.measurements.interactive_measurements import SequentialQueryable -from tmlt.core.utils.testing import Case, parametrize - -from tmlt.analytics import ( - AddOneRow, - AnalyticsDefault, - ColumnDescriptor, - ColumnType, - KeySet, - PureDPBudget, - QueryBuilder, - Session, - TruncationStrategy, -) -from tmlt.analytics._table_identifier import NamedTable - -from ....conftest import assert_frame_equal_with_sort - - -@pytest.mark.usefixtures("null_session_data") -class TestSessionWithNulls: - """Tests for sessions with Nulls.""" - - pdf: pd.DataFrame - sdf: DataFrame - - def _expected_replace(self, d: Mapping[str, Any]) -> pd.DataFrame: - """The expected value if you replace None with default values in d.""" - new_cols: List[pd.DataFrame] = [] - for col in list(self.pdf.columns): - if col in dict(d): - # make sure I becomes an integer here - if col == "I": - new_cols.append(self.pdf[col].fillna(dict(d)[col]).astype(int)) - else: - new_cols.append(self.pdf[col].fillna(dict(d)[col])) - else: - new_cols.append(self.pdf[col]) - # `axis=1` means that you want to "concatenate" by columns - # i.e., you want your new table to look like this: - # df1 | df2 | df3 | ... - # df1 | df2 | df3 | ... 
- return pd.concat(new_cols, axis=1) - - def test_expected_replace(self) -> None: - """Test the test method _expected_replace.""" - d = { - "A": "a999", - "I": -999, - "X": 99.9, - "D": datetime.date(1999, 1, 1), - "T": datetime.datetime(2019, 1, 1), - } - expected = pd.DataFrame( - [ - [ - "a0", - 0, - 0.0, - datetime.date(2000, 1, 1), - datetime.datetime(2020, 1, 1), - ], - [ - "a999", - 1, - 1.0, - datetime.date(2001, 1, 1), - datetime.datetime(2021, 1, 1), - ], - [ - "a2", - -999, - 2.0, - datetime.date(2002, 1, 1), - datetime.datetime(2022, 1, 1), - ], - [ - "a3", - 3, - 99.9, - datetime.date(2003, 1, 1), - datetime.datetime(2023, 1, 1), - ], - [ - "a4", - 4, - 4.0, - datetime.date(1999, 1, 1), - datetime.datetime(2024, 1, 1), - ], - [ - "a5", - 5, - 5.0, - datetime.date(2005, 1, 1), - datetime.datetime(2019, 1, 1), - ], - ], - columns=["A", "I", "X", "D", "T"], - ) - assert_frame_equal_with_sort(self.pdf, self._expected_replace({})) - assert_frame_equal_with_sort( - expected, - self._expected_replace(d), - ) - - @pytest.mark.parametrize( - "cols_to_defaults", - [ - ({"A": "aaaaaaa"}), - ({"I": 999}), - ( - { - "A": "aaa", - "I": 999, - "X": -99.9, - "D": datetime.date.fromtimestamp(0), - "T": datetime.datetime.fromtimestamp(0), - } - ), - ], - ) - def test_replace_null_and_nan( - self, - cols_to_defaults: Mapping[ - str, Union[int, float, str, datetime.date, datetime.datetime] - ], - ) -> None: - """Test Session.replace_null_and_nan.""" - session = Session.from_dataframe( - PureDPBudget(float("inf")), - "private", - self.sdf, - protected_change=AddOneRow(), - ) - session.create_view( - QueryBuilder("private").replace_null_and_nan(cols_to_defaults), - "replaced", - cache=False, - ) - # pylint: disable=protected-access - queryable = session._accountant._queryable - assert isinstance(queryable, SequentialQueryable) - data = queryable._data - assert isinstance(data, dict) - assert isinstance(data[NamedTable("replaced")], DataFrame) - # pylint: enable=protected-access - assert_frame_equal_with_sort( - data[NamedTable("replaced")].toPandas(), - self._expected_replace(cols_to_defaults), - ) - - @pytest.mark.parametrize( - "public_df,keyset,expected", - [ - ( - pd.DataFrame( - [[None, 0], [None, 1], ["a2", 1], ["a2", 2]], - columns=["A", "new_column"], - ), - KeySet.from_dict({"new_column": [0, 1, 2]}), - pd.DataFrame([[0, 1], [1, 2], [2, 1]], columns=["new_column", "count"]), - ), - ( - pd.DataFrame( - [["a0", 0, 0], [None, 1, 17], ["a5", 5, 17], ["a5", 5, 400]], - columns=["A", "I", "new_column"], - ), - KeySet.from_dict({"new_column": [0, 17, 400]}), - pd.DataFrame( - [[0, 1], [17, 2], [400, 1]], columns=["new_column", "count"] - ), - ), - ( - pd.DataFrame( - [ - [datetime.date(2000, 1, 1), "2000"], - [datetime.date(2001, 1, 1), "2001"], - [None, "none"], - [None, "also none"], - ], - columns=["D", "year"], - ), - KeySet.from_dict( - {"D": [datetime.date(2000, 1, 1), datetime.date(2001, 1, 1), None]} - ), - pd.DataFrame( - [ - [datetime.date(2000, 1, 1), 1], - [datetime.date(2001, 1, 1), 1], - [None, 2], - ], - columns=["D", "count"], - ), - ), - ], - ) - def test_join_public( - self, spark, public_df: pd.DataFrame, keyset: KeySet, expected: pd.DataFrame - ) -> None: - """Test that join_public creates the correct results. - - The query used to evaluate this is a GroupByCount on the new dataframe, - using the keyset provided. 
- """ - session = Session.from_dataframe( - PureDPBudget(float("inf")), - "private", - self.sdf, - protected_change=AddOneRow(), - ) - session.add_public_dataframe("public", spark.createDataFrame(public_df)) - result = session.evaluate( - QueryBuilder("private").join_public("public").groupby(keyset).count(), - privacy_budget=PureDPBudget(float("inf")), - ) - assert_frame_equal_with_sort(result.toPandas(), expected) - - @pytest.mark.parametrize( - "private_df,keyset,expected", - [ - ( - pd.DataFrame( - [[None, 0], [None, 1], ["a2", 1], ["a2", 2]], - columns=["A", "new_column"], - ), - KeySet.from_dict({"new_column": [0, 1, 2]}), - pd.DataFrame([[0, 1], [1, 2], [2, 1]], columns=["new_column", "count"]), - ), - ( - pd.DataFrame( - [["a0", 0, 0], [None, 1, 17], ["a5", 5, 17], ["a5", 5, 400]], - columns=["A", "I", "new_column"], - ), - KeySet.from_dict({"new_column": [0, 17, 400]}), - pd.DataFrame( - [[0, 1], [17, 2], [400, 1]], columns=["new_column", "count"] - ), - ), - ( - pd.DataFrame( - [ - [datetime.date(2000, 1, 1), "2000"], - [datetime.date(2001, 1, 1), "2001"], - [None, "none"], - [None, "also none"], - ], - columns=["D", "year"], - ), - KeySet.from_dict( - {"D": [datetime.date(2000, 1, 1), datetime.date(2001, 1, 1), None]} - ), - pd.DataFrame( - [ - [datetime.date(2000, 1, 1), 1], - [datetime.date(2001, 1, 1), 1], - [None, 2], - ], - columns=["D", "count"], - ), - ), - ], - ) - def test_join_private( - self, spark, private_df: pd.DataFrame, keyset: KeySet, expected: pd.DataFrame - ) -> None: - """Test that join_private creates the correct results. - - The query used to evaluate this is a GroupByCount on the joined dataframe, - using the keyset provided. - """ - session = ( - Session.Builder() - .with_privacy_budget(PureDPBudget(float("inf"))) - .with_private_dataframe("private", self.sdf, AddOneRow()) - .with_private_dataframe( - "private2", spark.createDataFrame(private_df), AddOneRow() - ) - .build() - ) - result = session.evaluate( - QueryBuilder("private") - .join_private( - QueryBuilder("private2"), - TruncationStrategy.DropExcess(100), - TruncationStrategy.DropExcess(100), - ) - .groupby(keyset) - .count(), - PureDPBudget(float("inf")), - ) - assert_frame_equal_with_sort(result.toPandas(), expected) - - @parametrize( - Case("both_allow_nulls")( - public_schema=StructType([StructField("foo", StringType(), True)]), - private_schema=StructType([StructField("foo", StringType(), True)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True) - }, - ), - Case("none_allow_nulls")( - public_schema=StructType([StructField("foo", StringType(), False)]), - private_schema=StructType([StructField("foo", StringType(), False)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=False) - }, - ), - Case("public_only_nulls")( - public_schema=StructType([StructField("foo", StringType(), True)]), - private_schema=StructType([StructField("foo", StringType(), False)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=False) - }, - ), - Case("private_only_nulls")( - public_schema=StructType([StructField("foo", StringType(), False)]), - private_schema=StructType([StructField("foo", StringType(), True)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=False) - }, - ), - ) - def test_public_join_schema_null_propagation( - self, - public_schema: StructType, - private_schema: StructType, - expected_schema: StructType, - spark: SparkSession, - ): - """Tests that join_public correctly handles 
schemas that allow null values.""" - public_df = spark.createDataFrame([], public_schema) - private_df = spark.createDataFrame([], private_schema) - sess = ( - Session.Builder() - .with_privacy_budget(PureDPBudget(float("inf"))) - .with_private_dataframe("private", private_df, protected_change=AddOneRow()) - .with_public_dataframe("public", public_df) - .build() - ) - sess.create_view( - QueryBuilder("private").join_public("public"), source_id="join", cache=False - ) - assert sess.get_schema("join") == expected_schema - - @parametrize( - Case("both_allow_nulls")( - left_schema=StructType([StructField("foo", StringType(), True)]), - right_schema=StructType([StructField("foo", StringType(), True)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True) - }, - ), - Case("none_allow_nulls")( - left_schema=StructType([StructField("foo", StringType(), False)]), - right_schema=StructType([StructField("foo", StringType(), False)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=False) - }, - ), - Case("public_only_nulls")( - left_schema=StructType([StructField("foo", StringType(), True)]), - right_schema=StructType([StructField("foo", StringType(), False)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=False) - }, - ), - Case("private_only_nulls")( - left_schema=StructType([StructField("foo", StringType(), False)]), - right_schema=StructType([StructField("foo", StringType(), True)]), - expected_schema={ - "foo": ColumnDescriptor(ColumnType.VARCHAR, allow_null=False) - }, - ), - ) - def test_private_join_schema_null_propagation( - self, - left_schema: StructType, - right_schema: StructType, - expected_schema: StructType, - spark: SparkSession, - ): - """Tests that join_private correctly handles schemas that allow null values.""" - left_df = spark.createDataFrame([], left_schema) - right_df = spark.createDataFrame([], right_schema) - sess = ( - Session.Builder() - .with_privacy_budget(PureDPBudget(float("inf"))) - .with_private_dataframe("left", left_df, protected_change=AddOneRow()) - .with_private_dataframe("right", right_df, protected_change=AddOneRow()) - .build() - ) - sess.create_view( - QueryBuilder("left").join_private( - "right", - truncation_strategy_left=TruncationStrategy.DropExcess(1), - truncation_strategy_right=TruncationStrategy.DropExcess(1), - ), - source_id="join", - cache=False, - ) - assert sess.get_schema("join") == expected_schema - - -@pytest.mark.usefixtures("infs_test_data") -class TestSessionWithInfs: - """Tests for Sessions with Infs.""" - - pdf: pd.DataFrame - sdf: DataFrame - - @pytest.mark.parametrize( - "replace_with,", - [ - ({}), - ({"B": (-100.0, 100.0)}), - ({"B": (123.45, 678.90)}), - ({"B": (999.9, 111.1)}), - ], - ) - def test_replace_infinity( - self, replace_with: Dict[str, Tuple[float, float]] - ) -> None: - """Test replace_infinity query.""" - session = Session.from_dataframe( - PureDPBudget(float("inf")), - "private", - self.sdf, - protected_change=AddOneRow(), - ) - session.create_view( - QueryBuilder("private").replace_infinity(replace_with), - "replaced", - cache=False, - ) - # pylint: disable=protected-access - queryable = session._accountant._queryable - assert isinstance(queryable, SequentialQueryable) - data = queryable._data - assert isinstance(data, dict) - assert isinstance(data[NamedTable("replaced")], DataFrame) - # pylint: enable=protected-access - (replace_negative, replace_positive) = replace_with.get( - "B", (AnalyticsDefault.DECIMAL, 
AnalyticsDefault.DECIMAL) - ) - expected = self.pdf.replace(float("-inf"), replace_negative).replace( - float("inf"), replace_positive - ) - assert_frame_equal_with_sort(data[NamedTable("replaced")].toPandas(), expected) - - @pytest.mark.parametrize( - "replace_with,expected", - [ - ({}, pd.DataFrame([["a0", 2.0], ["a1", 5.0]], columns=["A", "sum"])), - ( - {"B": (-100.0, 100.0)}, - pd.DataFrame([["a0", -98.0], ["a1", 105.0]], columns=["A", "sum"]), - ), - ( - {"B": (500.0, 100.0)}, - pd.DataFrame([["a0", 502.0], ["a1", 105.0]], columns=["A", "sum"]), - ), - ], - ) - def test_sum( - self, replace_with: Dict[str, Tuple[float, float]], expected: pd.DataFrame - ) -> None: - """Test GroupByBoundedSum after replacing infinite values.""" - session = Session.from_dataframe( - PureDPBudget(float("inf")), - "private", - self.sdf, - protected_change=AddOneRow(), - ) - result = session.evaluate( - QueryBuilder("private") - .replace_infinity(replace_with) - .groupby(KeySet.from_dict({"A": ["a0", "a1"]})) - .sum("B", low=-1000, high=1000, name="sum"), - PureDPBudget(float("inf")), - ) - assert_frame_equal_with_sort(result.toPandas(), expected) - - def test_drop_infinity(self): - """Test GroupByBoundedSum after dropping infinite values.""" - session = Session.from_dataframe( - PureDPBudget(float("inf")), - "private", - self.sdf, - protected_change=AddOneRow(), - ) - result = session.evaluate( - QueryBuilder("private") - .drop_infinity(columns=["B"]) - .groupby(KeySet.from_dict({"A": ["a0", "a1"]})) - .sum("B", low=-1000, high=1000, name="sum"), - PureDPBudget(float("inf")), - ) - expected = pd.DataFrame([["a0", 2.0], ["a1", 5.0]], columns=["A", "sum"]) - assert_frame_equal_with_sort(result.toPandas(), expected) diff --git a/test/system/session/test_special_values.py b/test/system/session/test_special_values.py new file mode 100644 index 00000000..2a3d9b1c --- /dev/null +++ b/test/system/session/test_special_values.py @@ -0,0 +1,486 @@ +"""System tests for tables with special values (nulls, nans, infinities).""" + +# SPDX-License-Identifier: Apache-2.0 +# Copyright Tumult Labs 2025 + +import datetime +from typing import Dict, List, Optional, Tuple, Union + +import pandas as pd +import pytest +from numpy import sqrt +from pyspark.sql import DataFrame +from tmlt.core.utils.testing import Case, parametrize + +from tmlt.analytics import ( + AddOneRow, + AddRowsWithID, + ColumnDescriptor, + ColumnType, + MaxRowsPerID, + ProtectedChange, + PureDPBudget, + QueryBuilder, + Session, +) +from tmlt.analytics._schema import Schema, analytics_to_spark_schema + +from ...conftest import assert_frame_equal_with_sort + + +@pytest.fixture(name="sdf_special_values", scope="module") +def null_setup(spark): + """Set up test data for sessions with special values.""" + sdf_col_types = { + "string": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True), + "int_no_null": ColumnDescriptor(ColumnType.INTEGER, allow_null=False), + "int_nulls": ColumnDescriptor(ColumnType.INTEGER, allow_null=True), + "float_no_special": ColumnDescriptor( + ColumnType.DECIMAL, + allow_null=False, + allow_nan=False, + allow_inf=False, + ), + "float_nulls": ColumnDescriptor( + ColumnType.DECIMAL, + allow_null=True, + allow_nan=False, + allow_inf=False, + ), + "float_nans": ColumnDescriptor( + ColumnType.DECIMAL, + allow_null=False, + allow_nan=True, + allow_inf=False, + ), + "float_infs": ColumnDescriptor( + ColumnType.DECIMAL, + allow_null=False, + allow_nan=False, + allow_inf=True, + ), + "float_all_special": ColumnDescriptor( + ColumnType.DECIMAL, + 
allow_null=True, + allow_nan=True, + allow_inf=True, + ), + "date": ColumnDescriptor(ColumnType.DATE, allow_null=True), + "time": ColumnDescriptor(ColumnType.TIMESTAMP, allow_null=True), + } + date = datetime.date(2000, 1, 1) + time = datetime.datetime(2020, 1, 1) + sdf = spark.createDataFrame( + [(f"normal_{i}", 1, 1, 1.0, 1.0, 1.0, 1.0, 1.0, date, time) for i in range(20)] + + [ + # Rows with nulls + (None, 1, 1, 1.0, 1.0, 1.0, 1.0, 1.0, date, time), + ("u2", 1, None, 1.0, 1.0, 1.0, 1.0, 1.0, date, time), + ("u3", 1, 1, 1.0, None, 1.0, 1.0, None, date, time), + ("u4", 1, 1, 1.0, 1.0, 1.0, 1.0, 1.0, None, time), + ("u5", 1, 1, 1.0, 1.0, 1.0, 1.0, 1.0, date, None), + # Rows with nans + ("a6", 1, 1, 1.0, 1.0, float("nan"), 1.0, float("nan"), date, time), + # Rows with infinities + ("i7", 1, 1, 1.0, 1.0, 1.0, float("inf"), float("inf"), date, time), + ("i8", 1, 1, 1.0, 1.0, 1.0, -float("inf"), -float("inf"), date, time), + ("i9", 1, 1, 1.0, 1.0, 1.0, float("inf"), 1.0, date, time), + ("i10", 1, 1, 1.0, 1.0, 1.0, -float("inf"), 1.0, date, time), + ], + schema=analytics_to_spark_schema(Schema(sdf_col_types)), + ) + return sdf + + +@parametrize( + [ + Case("int_noop")( + # Column "int_no_null" has only non-null values, all equal to 1 + replace_with={"int_no_null": 42}, + column="int_no_null", + low=0, + high=1, + expected=1, + ), + Case("int_replace_null")( + # Column "int_nulls" has one null value and 29 non-nulls. + replace_with={"int_nulls": 31}, + column="int_nulls", + low=0, + high=100, + expected=2.0, # (29+31)/30 + ), + Case("float_replace_null")( + # Column "float_nulls" has one null value and 29 non-nulls. + replace_with={"float_nulls": 61}, + column="float_nulls", + low=0, + high=100, + expected=3.0, # (29+61)/30 + ), + Case("float_replace_nan")( + # Column "float_nulls" has one null value and 29 non-nulls. + replace_with={"float_nans": 91}, + column="float_nans", + low=0, + high=100, + expected=4.0, # (29+91)/30 + ), + Case("float_replace_both")( + # Column "float_all_special" has 26 regular values, one null value, one + # nan-value, one negative infinity (clamped to 0), and one positive infinity + # (clamped to 34). + replace_with={"float_all_special": 15}, + column="float_all_special", + low=0, + high=34, + expected=3.0, # (26+15+15+34)/30 + ), + Case("replace_all_with_none")( + # When called with no argument, replace_null_and_nan should replace all null + # values by analytics defaults, e.g. 0. + replace_with=None, + column="float_nulls", + low=0, + high=1, + expected=29.0 / 30, + ), + Case("replace_all_with_empty_dict")( + # Same thing with an empty dict and with nan values. 
+ replace_with={}, + column="float_nans", + low=0, + high=1, + expected=29.0 / 30, + ), + ] +) +@parametrize( + [ + Case("one_row")(protected_change=AddOneRow()), + # Case("ids")(protected_change=AddRowsWithID("string")), + ] +) +def test_replace_null_and_nan( + sdf_special_values: DataFrame, + replace_with: Optional[Dict[str, Union[int, float]]], + column: str, + low: Union[int, float], + high: Union[int, float], + expected: Union[int, float], + protected_change: ProtectedChange, +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + protected_change, + ) + base_query = QueryBuilder("private") + if isinstance(protected_change, AddRowsWithID): + base_query = base_query.enforce(MaxRowsPerID(1)) + query = base_query.replace_null_and_nan(replace_with).average(column, low, high) + result = sess.evaluate(query, inf_budget) + expected_df = pd.DataFrame([[expected]], columns=[column + "_average"]) + assert_frame_equal_with_sort(result.toPandas(), expected_df) + + +@parametrize( + [ + # All columns have 30 rows, all non-special values are equal to 1. + Case("int_noop")( + # Column "int_no_null" has only regular values. + affected_columns=["int_no_null"], + measure_column="int_no_null", + low=0, + high=1, + expected=30, + ), + Case("int_drop_nulls")( + # Column "int_nulls" has one null value. + affected_columns=["int_nulls"], + measure_column="int_nulls", + low=0, + high=100, + expected=29, + ), + Case("float_drop_nulls")( + # Column "float_nulls" has one null value. + affected_columns=["float_nulls"], + measure_column="float_nulls", + low=0, + high=100, + expected=29, + ), + Case("float_drop_nan")( + # Column "float_nans" has one nan value. + affected_columns=["float_nans"], + measure_column="float_nans", + low=0, + high=100, + expected=29, + ), + Case("float_drop_both")( + # Column "float_all_special" has 26 normal values, one null, one nan, one + # negative infinity (clamped to 0), one positive infinity (clamped to 100). + affected_columns=["float_all_special"], + measure_column="float_all_special", + low=0, + high=100, + expected=126, + ), + Case("drop_other_columns")( + # Column "float_infs" has 26 normal values, two negative infinities (clamped + # to 0) and two positive infinities (clamped to 100). But dropping rows from + # columns "string", "float_nulls", "float_nans", "date" and "time" should + # remove five rows, leaving just 21 normal values. + affected_columns=["string", "float_nulls", "float_nans", "date", "time"], + measure_column="float_infs", + low=0, + high=100, + expected=221, + ), + Case("drop_all_with_none")( + # When called with no argument, replace_null_and_nan should drop all rows + # that have null/nan values anywhere, which leaves 24 rows even if we're + # summing a column without nulls. + affected_columns=None, + measure_column="int_no_null", + low=0, + high=1, + expected=24, + ), + Case("drop_all_with_empty_list")( + # Same thing with an empty list. 
+ affected_columns=[], + measure_column="float_nulls", + low=0, + high=1, + expected=24.0, + ), + ] +) +def test_drop_null_and_nan( + sdf_special_values: DataFrame, + affected_columns: Optional[List[str]], + measure_column: str, + low: Union[int, float], + high: Union[int, float], + expected: Union[int, float], +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + AddOneRow(), + ) + base_query = QueryBuilder("private") + query = base_query.drop_null_and_nan(affected_columns).sum( + measure_column, low, high + ) + result = sess.evaluate(query, inf_budget) + expected_df = pd.DataFrame([[expected]], columns=[measure_column + "_sum"]) + assert_frame_equal_with_sort(result.toPandas(), expected_df) + + +@parametrize( + # All these tests compute the average of the "float_infs" column of the input table, + # in which there are: + # - 26 non-infinity values, all equal to 1 + # - two negative infinity + # - two positive infinity + # We test this using average and not sum to distinguish between infinities being + # removed from infinities being changed to 0. + [ + Case("replace_no_clamp")( + replace_with={"float_infs": (0, 17)}, + low=-100, + high=100, + # 26+0+17+17 = 60, divided by 30 is 2 + expected=2.0, + ), + Case("replace_clamp")( + replace_with={"float_infs": (-4217, 300)}, + low=-5, + high=22, + # 26-10+44 = 60, divided by 30 is 2 + expected=2.0, + ), + Case("replace_unrelated_column")( + # If we don't explicitly replace infinity in the measure column, then + # infinities should be clamped to the bounds. + replace_with={"float_all_special": (-4217, 300)}, + low=-10, + high=27, + # 26-20+54 = 60, divided by 30 is 2 + expected=2.0, + ), + Case("replace_with_none")( + # If used without any argument, replace_infinity transforms all infinity + # values in all columns of the table to 0. + replace_with=None, + low=-10, + high=10, + expected=26.0 / 30.0, + ), + Case("replace_with_empty_dict")( + # Same with an empty dict. + replace_with=None, + low=-10, + high=10, + expected=26.0 / 30.0, + ), + ] +) +@parametrize( + [ + Case("one_row")(protected_change=AddOneRow()), + # Case("ids")(protected_change=AddRowsWithID("string")), + ] +) +def test_replace_infinity_average( + sdf_special_values: DataFrame, + replace_with: Optional[Dict[str, Tuple[float, float]]], + low: float, + high: float, + expected: float, + protected_change: ProtectedChange, +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + protected_change, + ) + base_query = QueryBuilder("private") + if isinstance(protected_change, AddRowsWithID): + base_query = base_query.enforce(MaxRowsPerID(1)) + query = base_query.replace_infinity(replace_with).average("float_infs", low, high) + result = sess.evaluate(query, inf_budget) + expected_df = pd.DataFrame([[expected]], columns=["float_infs_average"]) + assert_frame_equal_with_sort(result.toPandas(), expected_df) + + +@parametrize( + [ + Case("all_ones")( + replace_with={"float_infs": (1, 1)}, + expected_sum=30.0, + expected_stdev=0, + expected_variance=0, + ), + Case("one_zero_one_one")( + # If we don't replace infinities in the measure column, then infinity values + # should be clamped to the bounds, namely 0 and 1. 
+ replace_with={"float_all_special": (1, 1)}, + expected_sum=28.0, + expected_stdev=sqrt((2 * (28.0 / 30) ** 2 + 28 * (2.0 / 30) ** 2) / 29), + expected_variance=(2 * (28.0 / 30) ** 2 + 28 * (2.0 / 30) ** 2) / 29, + ), + Case("all_zeroes")( + # Without argument, all infinities are replaced by 0 + replace_with=None, + expected_sum=26.0, + expected_stdev=sqrt((4 * (26.0 / 30) ** 2 + 26 * (4.0 / 30) ** 2) / 29), + expected_variance=(4 * (26.0 / 30) ** 2 + 26 * (4.0 / 30) ** 2) / 29, + ), + ] +) +def test_replace_infinity_other_aggregations( + sdf_special_values: DataFrame, + replace_with: Optional[Dict[str, Tuple[float, float]]], + expected_sum: float, + expected_stdev: float, + expected_variance: float, +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + protected_change=AddOneRow(), + ) + + query_sum = ( + QueryBuilder("private").replace_infinity(replace_with).sum("float_infs", 0, 1) + ) + result_sum = sess.evaluate(query_sum, inf_budget) + expected_df = pd.DataFrame([[expected_sum]], columns=["float_infs_sum"]) + assert_frame_equal_with_sort(result_sum.toPandas(), expected_df) + + query_stdev = ( + QueryBuilder("private").replace_infinity(replace_with).stdev("float_infs", 0, 1) + ) + result_stdev = sess.evaluate(query_stdev, inf_budget) + expected_df = pd.DataFrame([[expected_stdev]], columns=["float_infs_stdev"]) + assert_frame_equal_with_sort(result_stdev.toPandas(), expected_df) + + query_variance = ( + QueryBuilder("private") + .replace_infinity(replace_with) + .variance("float_infs", 0, 1) + ) + result_variance = sess.evaluate(query_variance, inf_budget) + expected_df = pd.DataFrame([[expected_variance]], columns=["float_infs_variance"]) + assert_frame_equal_with_sort(result_variance.toPandas(), expected_df) + + +@parametrize( + # All these tests compute the sum of the "float_infs" column of the input table. + [ + Case("drop_rows_in_column")( + # There are 26 non-infinity values in the "float_infs" column. + columns=["float_infs"], + expected=26.0, + ), + Case("drop_no_rows")( + # The call to drop_infinity is a no-op. In the "float_infs" column, there + # are two rows with positive infinities (clamped to 1), and two with + # negative infinities (clamped to 0). + columns=["float_no_special"], + expected=28.0, + ), + Case("drop_some_rows_due_to_other_columns")( + # Two rows with infinite values in the "float_infs" column also have + # infinite values in the "float_all_special" column. We end up with one + # positive infinity value, clamped to 1, and one negative, clamped to 0. + columns=["float_all_special"], + expected=27.0, + ), + Case("drop_rows_in_all_columns")( + # If used without any argument, drop_infinity removes all infinity values in + # all columns of the table. 
+ columns=None, + expected=26.0, + ), + ] +) +@parametrize( + [ + Case("one_row")(protected_change=AddOneRow()), + # Case("ids")(protected_change=AddRowsWithID("string")), + ] +) +def test_drop_infinity( + sdf_special_values: DataFrame, + columns: Optional[List[str]], + expected: float, + protected_change: ProtectedChange, +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + protected_change, + ) + base_query = QueryBuilder("private") + if isinstance(protected_change, AddRowsWithID): + base_query = base_query.enforce(MaxRowsPerID(1)) + query = base_query.drop_infinity(columns).sum("float_infs", 0, 1) + result = sess.evaluate(query, inf_budget) + expected_df = pd.DataFrame([[expected]], columns=["float_infs_sum"]) + assert_frame_equal_with_sort(result.toPandas(), expected_df) From d86149214c10cb13049c261899af23179d552d00 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Mon, 27 Oct 2025 17:11:26 +0100 Subject: [PATCH 03/16] add a test for get_bounds --- src/tmlt/analytics/query_builder.py | 4 +- test/system/session/test_special_values.py | 69 ++++++++++++++++++---- 2 files changed, 61 insertions(+), 12 deletions(-) diff --git a/src/tmlt/analytics/query_builder.py b/src/tmlt/analytics/query_builder.py index 80e0eb56..ad1be8f7 100644 --- a/src/tmlt/analytics/query_builder.py +++ b/src/tmlt/analytics/query_builder.py @@ -764,7 +764,7 @@ def replace_infinity( ) return self - def drop_null_and_nan(self, columns: Optional[List[str]]) -> "QueryBuilder": + def drop_null_and_nan(self, columns: Optional[List[str]] = None) -> "QueryBuilder": """Removes rows containing null or NaN values. .. note:: @@ -869,7 +869,7 @@ def drop_null_and_nan(self, columns: Optional[List[str]]) -> "QueryBuilder": ) return self - def drop_infinity(self, columns: Optional[List[str]]) -> "QueryBuilder": + def drop_infinity(self, columns: Optional[List[str]] = None) -> "QueryBuilder": """Remove rows containing infinite values. .. 
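The notes above describe the user-visible effect of the new rewrite rule: for sum, average, variance, stdev, and quantile queries, null and NaN values in the measure column are dropped before aggregation, and infinite values are clamped to the clamping bounds, without the caller issuing an explicit drop_null_and_nan or replace_infinity query. A minimal sketch of that behavior, assuming a Spark DataFrame `sdf` whose "float_infs" column allows infinite values (as in the fixture added in the test patch); the table name and bounds here are illustrative only:

    from tmlt.analytics import AddOneRow, PureDPBudget, QueryBuilder, Session

    budget = PureDPBudget(float("inf"))
    sess = Session.from_dataframe(budget, "private", sdf, protected_change=AddOneRow())

    # Infinities in "float_infs" are clamped to the [0, 1] clamping bounds before
    # summing; nulls and NaNs (if the column allowed them) would be dropped first.
    clamped = sess.evaluate(
        QueryBuilder("private").sum("float_infs", low=0, high=1), budget
    )

    # With the new default argument, drop_infinity() with no columns drops rows
    # containing an infinite value in any column before the aggregation.
    dropped = sess.evaluate(
        QueryBuilder("private").drop_infinity().sum("float_infs", low=0, high=1),
        budget,
    )
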
diff --git a/test/system/session/test_special_values.py b/test/system/session/test_special_values.py index 2a3d9b1c..8eff7890 100644 --- a/test/system/session/test_special_values.py +++ b/test/system/session/test_special_values.py @@ -20,6 +20,7 @@ MaxRowsPerID, ProtectedChange, PureDPBudget, + Query, QueryBuilder, Session, ) @@ -458,29 +459,77 @@ def test_replace_infinity_other_aggregations( ), ] ) -@parametrize( - [ - Case("one_row")(protected_change=AddOneRow()), - # Case("ids")(protected_change=AddRowsWithID("string")), - ] -) def test_drop_infinity( sdf_special_values: DataFrame, columns: Optional[List[str]], expected: float, - protected_change: ProtectedChange, ): inf_budget = PureDPBudget(float("inf")) sess = Session.from_dataframe( inf_budget, "private", sdf_special_values, - protected_change, + AddOneRow(), ) base_query = QueryBuilder("private") - if isinstance(protected_change, AddRowsWithID): - base_query = base_query.enforce(MaxRowsPerID(1)) query = base_query.drop_infinity(columns).sum("float_infs", 0, 1) result = sess.evaluate(query, inf_budget) expected_df = pd.DataFrame([[expected]], columns=["float_infs_sum"]) assert_frame_equal_with_sort(result.toPandas(), expected_df) + + +@parametrize( + [ + Case("works_with_nulls")( + # Check that + query=QueryBuilder("private").get_bounds("int_nulls"), + expected_df=pd.DataFrame( + [[-1, 1]], + columns=["int_nulls_lower_bound", "int_nulls_upper_bound"], + ), + ), + Case("works_with_nan")( + # Check that + query=QueryBuilder("private").get_bounds("float_nans"), + expected_df=pd.DataFrame( + [[-1, 1]], + columns=["float_nans_lower_bound", "float_nans_upper_bound"], + ), + ), + Case("works_with_infinity")( + # Check that + query=QueryBuilder("private").get_bounds("float_infs"), + expected_df=pd.DataFrame( + [[-1, 1]], + columns=["float_infs_lower_bound", "float_infs_upper_bound"], + ), + ), + Case("drop_and_replace")( + # Dropping nulls & nans removes 6/30 values, replacing 4 infinity values by + # (-3,3) guarantees ensures get_bounds should get the interval corresponding + # to next power of 2, namely 4 + query=( + QueryBuilder("private") + .drop_null_and_nan() + .replace_infinity({"float_infs": (-3, 3)}) + .get_bounds("float_infs") + ), + expected_df=pd.DataFrame( + [[-4, 4]], + columns=["float_infs_lower_bound", "float_infs_upper_bound"], + ), + ), + ] +) +def test_get_bounds( + sdf_special_values: DataFrame, query: Query, expected_df: pd.DataFrame +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + AddOneRow(), + ) + result = sess.evaluate(query, inf_budget) + assert_frame_equal_with_sort(result.toPandas(), expected_df) From 671864cc2c743158ab58ee7bcdf9510e1f3e0e49 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Mon, 27 Oct 2025 18:14:45 +0100 Subject: [PATCH 04/16] unit tests for the new rewrite rule --- .../query_expr_compiler/test_rewrite_rules.py | 243 +++++++++++++++++- 1 file changed, 242 insertions(+), 1 deletion(-) diff --git a/test/unit/query_expr_compiler/test_rewrite_rules.py b/test/unit/query_expr_compiler/test_rewrite_rules.py index 5fdef85d..f7dce901 100644 --- a/test/unit/query_expr_compiler/test_rewrite_rules.py +++ b/test/unit/query_expr_compiler/test_rewrite_rules.py @@ -13,6 +13,9 @@ AverageMechanism, CountDistinctMechanism, CountMechanism, + DropInfinity, + GetBounds, + DropNullAndNan, GroupByBoundedAverage, GroupByBoundedSTDEV, GroupByBoundedSum, @@ -22,6 +25,7 @@ PrivateSource, QueryExpr, QueryExprVisitor, + ReplaceInfinity, 
SingleChildQueryExpr, StdevMechanism, SumMechanism, @@ -30,9 +34,10 @@ ) from tmlt.analytics._query_expr_compiler._rewrite_rules import ( CompilationInfo, + add_special_value_handling, select_noise_mechanism, ) -from tmlt.analytics._schema import ColumnDescriptor, ColumnType, Schema +from tmlt.analytics._schema import ColumnDescriptor, ColumnType, FrozenDict, Schema # SPDX-License-Identifier: Apache-2.0 # Copyright Tumult Labs 2025 @@ -391,3 +396,239 @@ def test_recursive_noise_selection(catalog: Catalog) -> None: info = CompilationInfo(output_measure=ApproxDP(), catalog=catalog) got_expr = select_noise_mechanism(info)(expr) assert got_expr == expected_expr + + +@parametrize( + [ + Case()(agg="count"), + Case()(agg="count_distinct"), + ] +) +@parametrize( + [ + Case()(col_desc=ColumnDescriptor(ColumnType.INTEGER, allow_null=False)), + Case()(col_desc=ColumnDescriptor(ColumnType.INTEGER, allow_null=True)), + Case()( + col_desc=ColumnDescriptor( + ColumnType.DECIMAL, allow_null=False, allow_nan=False, allow_inf=False + ) + ), + Case()( + col_desc=ColumnDescriptor( + ColumnType.DECIMAL, allow_null=True, allow_nan=True, allow_inf=True + ) + ), + ] +) +def test_special_value_handling_count_unaffected( + agg: str, + col_desc: ColumnDescriptor, +) -> None: + (AggExpr, AggMech) = AGG_CLASSES[agg] + expr = AggExpr( + child=BASE_EXPR, + groupby_keys=KeySet.from_dict({}), + mechanism=AggMech["DEFAULT"], + ) + catalog = Catalog() + catalog.add_private_table("private", {"col": col_desc}) + info = CompilationInfo(output_measure=PureDP(), catalog=catalog) + got_expr = add_special_value_handling(info)(expr) + assert got_expr == expr + + +@parametrize( + [ + # Columns with no special values should be unaffected + Case(f"no_op_null_{col_type}")( + col_desc=ColumnDescriptor( + col_type, allow_null=False, allow_nan=False, allow_inf=False + ), + new_child=BASE_EXPR, + ) + for col_type in [ + ColumnType.INTEGER, + ColumnType.DECIMAL, + ColumnType.DATE, + ColumnType.TIMESTAMP, + ] + ] + + [ + # NaNs and infinities do not matter for non-floats + Case(f"no_op_nan_inf_{col_type}")( + col_desc=ColumnDescriptor( + col_type, allow_null=False, allow_nan=True, allow_inf=True + ), + new_child=BASE_EXPR, + ) + for col_type in [ColumnType.INTEGER, ColumnType.DATE, ColumnType.TIMESTAMP] + ] + + [ + # Nulls must be dropped if needed + Case(f"drop_null_{col_type}")( + col_desc=ColumnDescriptor( + col_type, allow_null=True, allow_nan=False, allow_inf=False + ), + new_child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), + ) + for col_type in [ + ColumnType.INTEGER, + ColumnType.DECIMAL, + ColumnType.DATE, + ColumnType.TIMESTAMP, + ] + ] + + [ + # NaNs must also be dropped added if needed + Case("drop_nan")( + col_desc=ColumnDescriptor( + ColumnType.DECIMAL, allow_null=False, allow_nan=True, allow_inf=False + ), + new_child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), + ), + # Only one pass is enough to drop both nulls and NaNs + Case("drop_both")( + col_desc=ColumnDescriptor( + ColumnType.DECIMAL, allow_null=True, allow_nan=True, allow_inf=False + ), + new_child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), + ), + # If not handled, infinities must be clamped to the clamping bounds + Case("clamp_inf")( + col_desc=ColumnDescriptor( + ColumnType.DECIMAL, allow_null=False, allow_nan=False, allow_inf=True + ), + new_child=ReplaceInfinity( + child=BASE_EXPR, replace_with=FrozenDict.from_dict({"col": (0, 1)}) + ), + ), + # Handling both kinds of special values at once. 
This would fail if the two + # value handling exprs are in the wrong order; this is not ideal, but ah well. + Case("drop_nan_clamp_inf")( + col_desc=ColumnDescriptor( + ColumnType.DECIMAL, allow_null=True, allow_nan=True, allow_inf=True + ), + new_child=ReplaceInfinity( + child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), + replace_with=FrozenDict.from_dict({"col": (0, 1)}), + ), + ) + ] +) +@parametrize( + [ + Case()(agg="sum"), + Case()(agg="average"), + Case()(agg="stdev"), + Case()(agg="variance"), + ] +) +def test_special_value_handling_numeric_aggregations( + agg: str, + col_desc: ColumnDescriptor, + new_child: QueryExpr, +) -> None: + (AggExpr, AggMech) = AGG_CLASSES[agg] + expr = AggExpr( + child=BASE_EXPR, + measure_column="col", + low=0, + high=1, + groupby_keys=KeySet.from_dict({}), + mechanism=AggMech["DEFAULT"], + ) + catalog = Catalog() + catalog.add_private_table("private", {"col": col_desc}) + info = CompilationInfo(output_measure=PureDP(), catalog=catalog) + got_expr = add_special_value_handling(info)(expr) + assert got_expr == replace( + expr, + child=new_child, + ) + +@parametrize( + [ + # Columns with no special values should be unaffected + Case(f"no-op-{col_type}")( + col_desc=ColumnDescriptor( + col_type, allow_null=False, allow_nan=False, allow_inf=False + ), + new_child=BASE_EXPR, + ) + for col_type in [ + ColumnType.INTEGER, + ColumnType.DECIMAL, + ColumnType.DATE, + ColumnType.TIMESTAMP, + ] + ] + + [ + # NaNs and infinities do not matter for non-floats + Case(f"no-op-nan-inf-{col_type}")( + col_desc=ColumnDescriptor( + col_type, allow_null=False, allow_nan=True, allow_inf=True + ), + new_child=BASE_EXPR, + ) + for col_type in [ColumnType.INTEGER, ColumnType.DATE, ColumnType.TIMESTAMP] + ] + + [ + # Nulls must be dropped if needed + Case(f"drop-nulls-{col_type}")( + col_desc=ColumnDescriptor( + col_type, allow_null=True, allow_nan=False, allow_inf=False + ), + new_child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), + ) + for col_type in [ + ColumnType.INTEGER, + ColumnType.DECIMAL, + ColumnType.DATE, + ColumnType.TIMESTAMP, + ] + ] + + [ + # NaNs must also be dropped added if needed + Case("drop-nan")( + col_desc=ColumnDescriptor( + ColumnType.DECIMAL, allow_null=False, allow_nan=True, allow_inf=False + ), + new_child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), + ), + # If not handled, infinities must be clamped to the clamping bounds + Case("drop-inf")( + col_desc=ColumnDescriptor( + ColumnType.DECIMAL, allow_null=False, allow_nan=False, allow_inf=True + ), + new_child=DropInfinity( child=BASE_EXPR, columns=("col",)), + ), + # And both kinds of special values must be handled + Case("drop-nan-and-inf")( + col_desc=ColumnDescriptor( + ColumnType.DECIMAL, allow_null=True, allow_nan=True, allow_inf=True + ), + new_child=DropInfinity( + child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), + columns=("col",)), + ) + ] +) +def test_special_value_handling_get_bounds( + col_desc: ColumnDescriptor, + new_child: QueryExpr, +) -> None: + expr = GetBounds( + child=BASE_EXPR, + measure_column="col", + groupby_keys=KeySet.from_dict({}), + lower_bound_column="lower", + upper_bound_column="upper", + ) + catalog = Catalog() + catalog.add_private_table("private", {"col": col_desc}) + info = CompilationInfo(output_measure=PureDP(), catalog=catalog) + got_expr = add_special_value_handling(info)(expr) + assert got_expr == replace( + expr, + child=new_child, + ) From 9915c1a2732f5ad6e9f3da6bcc733d523b780b94 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: 
Mon, 27 Oct 2025 18:15:05 +0100 Subject: [PATCH 05/16] lint --- test/unit/query_expr_compiler/test_rewrite_rules.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/unit/query_expr_compiler/test_rewrite_rules.py b/test/unit/query_expr_compiler/test_rewrite_rules.py index f7dce901..8a149928 100644 --- a/test/unit/query_expr_compiler/test_rewrite_rules.py +++ b/test/unit/query_expr_compiler/test_rewrite_rules.py @@ -14,8 +14,8 @@ CountDistinctMechanism, CountMechanism, DropInfinity, - GetBounds, DropNullAndNan, + GetBounds, GroupByBoundedAverage, GroupByBoundedSTDEV, GroupByBoundedSum, @@ -512,7 +512,7 @@ def test_special_value_handling_count_unaffected( child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), replace_with=FrozenDict.from_dict({"col": (0, 1)}), ), - ) + ), ] ) @parametrize( @@ -546,6 +546,7 @@ def test_special_value_handling_numeric_aggregations( child=new_child, ) + @parametrize( [ # Columns with no special values should be unaffected @@ -600,7 +601,7 @@ def test_special_value_handling_numeric_aggregations( col_desc=ColumnDescriptor( ColumnType.DECIMAL, allow_null=False, allow_nan=False, allow_inf=True ), - new_child=DropInfinity( child=BASE_EXPR, columns=("col",)), + new_child=DropInfinity(child=BASE_EXPR, columns=("col",)), ), # And both kinds of special values must be handled Case("drop-nan-and-inf")( @@ -609,8 +610,9 @@ def test_special_value_handling_numeric_aggregations( ), new_child=DropInfinity( child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), - columns=("col",)), - ) + columns=("col",), + ), + ), ] ) def test_special_value_handling_get_bounds( From b18bbbddb32d8340ec125bdcb0e34b45f2848257 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Tue, 28 Oct 2025 09:29:06 +0100 Subject: [PATCH 06/16] add todo for changelog --- .../_query_expr_compiler/_rewrite_rules.py | 6 +- test/system/session/test_special_values.py | 146 +++++++++--------- .../query_expr_compiler/test_rewrite_rules.py | 4 +- 3 files changed, 74 insertions(+), 82 deletions(-) diff --git a/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py b/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py index bf2924af..fb4a3038 100644 --- a/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py +++ b/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py @@ -232,9 +232,7 @@ def handle_special_values(expr: QueryExpr) -> QueryExpr: ): expr = replace( expr, - child=DropNullAndNan( - child=expr.child, columns=tuple([expr.measure_column]) - ), + child=DropNullAndNan(child=expr.child, columns=(expr.measure_column,)), ) # Remove infinities if necessary if measure_desc.column_type == ColumnType.DECIMAL and measure_desc.allow_inf: @@ -242,7 +240,7 @@ def handle_special_values(expr: QueryExpr) -> QueryExpr: return replace( expr, child=DropInfinity( - child=expr.child, columns=tuple([expr.measure_column]) + child=expr.child, columns=(expr.measure_column,) ), ) return replace( diff --git a/test/system/session/test_special_values.py b/test/system/session/test_special_values.py index 8eff7890..c556ab42 100644 --- a/test/system/session/test_special_values.py +++ b/test/system/session/test_special_values.py @@ -33,7 +33,7 @@ def null_setup(spark): """Set up test data for sessions with special values.""" sdf_col_types = { - "string": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True), + "string_nulls": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True), "int_no_null": ColumnDescriptor(ColumnType.INTEGER, allow_null=False), "int_nulls": 
ColumnDescriptor(ColumnType.INTEGER, allow_null=True), "float_no_special": ColumnDescriptor( @@ -66,8 +66,8 @@ def null_setup(spark): allow_nan=True, allow_inf=True, ), - "date": ColumnDescriptor(ColumnType.DATE, allow_null=True), - "time": ColumnDescriptor(ColumnType.TIMESTAMP, allow_null=True), + "date_nulls": ColumnDescriptor(ColumnType.DATE, allow_null=True), + "time_nulls": ColumnDescriptor(ColumnType.TIMESTAMP, allow_null=True), } date = datetime.date(2000, 1, 1) time = datetime.datetime(2020, 1, 1) @@ -101,41 +101,40 @@ def null_setup(spark): column="int_no_null", low=0, high=1, - expected=1, + expected_average=1, ), Case("int_replace_null")( - # Column "int_nulls" has one null value and 29 non-nulls. + # Column "int_nulls" has one null value and 29 1s. replace_with={"int_nulls": 31}, column="int_nulls", low=0, high=100, - expected=2.0, # (29+31)/30 + expected_average=2.0, # (29+31)/30 ), Case("float_replace_null")( - # Column "float_nulls" has one null value and 29 non-nulls. + # Column "float_nulls" has one null value and 29 1s. replace_with={"float_nulls": 61}, column="float_nulls", low=0, high=100, - expected=3.0, # (29+61)/30 + expected_average=3.0, # (29+61)/30 ), Case("float_replace_nan")( - # Column "float_nulls" has one null value and 29 non-nulls. + # Column "float_nulls" has one null value and 29 1s. replace_with={"float_nans": 91}, column="float_nans", low=0, high=100, - expected=4.0, # (29+91)/30 + expected_average=4.0, # (29+91)/30 ), Case("float_replace_both")( - # Column "float_all_special" has 26 regular values, one null value, one - # nan-value, one negative infinity (clamped to 0), and one positive infinity - # (clamped to 34). + # Column "float_all_special" has 26 1s, one null value, one nan-value, one + # negative infinity (clamped to 0), one positive infinity (clamped to 34). replace_with={"float_all_special": 15}, column="float_all_special", low=0, high=34, - expected=3.0, # (26+15+15+34)/30 + expected_average=3.0, # (26+15+15+34)/30 ), Case("replace_all_with_none")( # When called with no argument, replace_null_and_nan should replace all null @@ -144,7 +143,7 @@ def null_setup(spark): column="float_nulls", low=0, high=1, - expected=29.0 / 30, + expected_average=29.0 / 30, ), Case("replace_all_with_empty_dict")( # Same thing with an empty dict and with nan values. 
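A quick worked check of the arithmetic behind the "float_replace_both" case above — purely illustrative, with the value list reconstructed from that case's comment (26 ordinary 1s, the null and the NaN replaced with 15, negative infinity clamped to the lower bound 0, positive infinity clamped to the upper bound 34):

    # Reconstructed measure column after replace_null_and_nan and clamping.
    values = [1.0] * 26 + [15.0, 15.0, 0.0, 34.0]
    # 30 rows in total, averaging to the expected_average of 3.0.
    assert len(values) == 30
    assert sum(values) / len(values) == 3.0
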
@@ -152,38 +151,29 @@ def null_setup(spark): column="float_nans", low=0, high=1, - expected=29.0 / 30, + expected_average=29.0 / 30, ), ] ) -@parametrize( - [ - Case("one_row")(protected_change=AddOneRow()), - # Case("ids")(protected_change=AddRowsWithID("string")), - ] -) def test_replace_null_and_nan( sdf_special_values: DataFrame, replace_with: Optional[Dict[str, Union[int, float]]], column: str, low: Union[int, float], high: Union[int, float], - expected: Union[int, float], - protected_change: ProtectedChange, + expected_average: Union[int, float], ): inf_budget = PureDPBudget(float("inf")) sess = Session.from_dataframe( inf_budget, "private", sdf_special_values, - protected_change, + AddOneRow(), ) base_query = QueryBuilder("private") - if isinstance(protected_change, AddRowsWithID): - base_query = base_query.enforce(MaxRowsPerID(1)) query = base_query.replace_null_and_nan(replace_with).average(column, low, high) result = sess.evaluate(query, inf_budget) - expected_df = pd.DataFrame([[expected]], columns=[column + "_average"]) + expected_df = pd.DataFrame([[expected_average]], columns=[column + "_average"]) assert_frame_equal_with_sort(result.toPandas(), expected_df) @@ -196,61 +186,67 @@ def test_replace_null_and_nan( measure_column="int_no_null", low=0, high=1, - expected=30, + expected_sum=30, ), Case("int_drop_nulls")( - # Column "int_nulls" has one null value. + # Column "int_nulls" has one null value and 29 1s. affected_columns=["int_nulls"], measure_column="int_nulls", low=0, high=100, - expected=29, + expected_sum=29, ), Case("float_drop_nulls")( - # Column "float_nulls" has one null value. + # Column "float_nulls" has one null value and 29 1s. affected_columns=["float_nulls"], measure_column="float_nulls", low=0, high=100, - expected=29, + expected_sum=29, ), Case("float_drop_nan")( - # Column "float_nans" has one nan value. + # Column "float_nans" has one nan value and 29 1s. affected_columns=["float_nans"], measure_column="float_nans", low=0, high=100, - expected=29, + expected_sum=29, ), Case("float_drop_both")( - # Column "float_all_special" has 26 normal values, one null, one nan, one - # negative infinity (clamped to 0), one positive infinity (clamped to 100). + # Column "float_all_special" has 26 1s, one null, one nan, one negative + # infinity (clamped to 0), one positive infinity (clamped to 100). affected_columns=["float_all_special"], measure_column="float_all_special", low=0, high=100, - expected=126, + expected_sum=126, ), Case("drop_other_columns")( - # Column "float_infs" has 26 normal values, two negative infinities (clamped - # to 0) and two positive infinities (clamped to 100). But dropping rows from - # columns "string", "float_nulls", "float_nans", "date" and "time" should - # remove five rows, leaving just 21 normal values. - affected_columns=["string", "float_nulls", "float_nans", "date", "time"], + # Column "float_infs" has 26 1s, two negative infinities (clamped to 0) and + # two positive infinities (clamped to 100). But dropping rows from columns + # "string_nulls", "float_nulls", "float_nans", "date_nulls" and "time_nulls" + # should remove five rows, leaving just 21 normal values. 
+ affected_columns=[ + "string_nulls", + "float_nulls", + "float_nans", + "date_nulls", + "time_nulls", + ], measure_column="float_infs", low=0, high=100, - expected=221, + expected_sum=221, ), Case("drop_all_with_none")( # When called with no argument, replace_null_and_nan should drop all rows - # that have null/nan values anywhere, which leaves 24 rows even if we're + # that have null/nan values anywhere, which leaves 24 1s even if we're # summing a column without nulls. affected_columns=None, measure_column="int_no_null", low=0, high=1, - expected=24, + expected_sum=24, ), Case("drop_all_with_empty_list")( # Same thing with an empty list. @@ -258,7 +254,7 @@ def test_replace_null_and_nan( measure_column="float_nulls", low=0, high=1, - expected=24.0, + expected_sum=24.0, ), ] ) @@ -268,7 +264,7 @@ def test_drop_null_and_nan( measure_column: str, low: Union[int, float], high: Union[int, float], - expected: Union[int, float], + expected_sum: Union[int, float], ): inf_budget = PureDPBudget(float("inf")) sess = Session.from_dataframe( @@ -282,7 +278,7 @@ def test_drop_null_and_nan( measure_column, low, high ) result = sess.evaluate(query, inf_budget) - expected_df = pd.DataFrame([[expected]], columns=[measure_column + "_sum"]) + expected_df = pd.DataFrame([[expected_sum]], columns=[measure_column + "_sum"]) assert_frame_equal_with_sort(result.toPandas(), expected_df) @@ -299,15 +295,15 @@ def test_drop_null_and_nan( replace_with={"float_infs": (0, 17)}, low=-100, high=100, - # 26+0+17+17 = 60, divided by 30 is 2 - expected=2.0, + # 26+0+0+17+17 = 60, divided by 30 is 2 + expected_average=2.0, ), Case("replace_clamp")( replace_with={"float_infs": (-4217, 300)}, low=-5, high=22, - # 26-10+44 = 60, divided by 30 is 2 - expected=2.0, + # 26-5+5+22+22 = 60, divided by 30 is 2 + expected_average=2.0, ), Case("replace_unrelated_column")( # If we don't explicitly replace infinity in the measure column, then @@ -315,8 +311,8 @@ def test_drop_null_and_nan( replace_with={"float_all_special": (-4217, 300)}, low=-10, high=27, - # 26-20+54 = 60, divided by 30 is 2 - expected=2.0, + # 26-10+10+27+27 = 60, divided by 30 is 2 + expected_average=2.0, ), Case("replace_with_none")( # If used without any argument, replace_infinity transforms all infinity @@ -324,44 +320,35 @@ def test_drop_null_and_nan( replace_with=None, low=-10, high=10, - expected=26.0 / 30.0, + expected_average=26.0 / 30.0, ), Case("replace_with_empty_dict")( # Same with an empty dict. 
- replace_with=None, + replace_with={}, low=-10, high=10, - expected=26.0 / 30.0, + expected_average=26.0 / 30.0, ), ] ) -@parametrize( - [ - Case("one_row")(protected_change=AddOneRow()), - # Case("ids")(protected_change=AddRowsWithID("string")), - ] -) def test_replace_infinity_average( sdf_special_values: DataFrame, replace_with: Optional[Dict[str, Tuple[float, float]]], low: float, high: float, - expected: float, - protected_change: ProtectedChange, + expected_average: float, ): inf_budget = PureDPBudget(float("inf")) sess = Session.from_dataframe( inf_budget, "private", sdf_special_values, - protected_change, + AddOneRow(), ) base_query = QueryBuilder("private") - if isinstance(protected_change, AddRowsWithID): - base_query = base_query.enforce(MaxRowsPerID(1)) query = base_query.replace_infinity(replace_with).average("float_infs", low, high) result = sess.evaluate(query, inf_budget) - expected_df = pd.DataFrame([[expected]], columns=["float_infs_average"]) + expected_df = pd.DataFrame([[expected_average]], columns=["float_infs_average"]) assert_frame_equal_with_sort(result.toPandas(), expected_df) @@ -435,34 +422,34 @@ def test_replace_infinity_other_aggregations( Case("drop_rows_in_column")( # There are 26 non-infinity values in the "float_infs" column. columns=["float_infs"], - expected=26.0, + expected_sum=26.0, ), Case("drop_no_rows")( # The call to drop_infinity is a no-op. In the "float_infs" column, there # are two rows with positive infinities (clamped to 1), and two with # negative infinities (clamped to 0). columns=["float_no_special"], - expected=28.0, + expected_sum=28.0, ), Case("drop_some_rows_due_to_other_columns")( # Two rows with infinite values in the "float_infs" column also have # infinite values in the "float_all_special" column. We end up with one # positive infinity value, clamped to 1, and one negative, clamped to 0. columns=["float_all_special"], - expected=27.0, + expected_sum=27.0, ), Case("drop_rows_in_all_columns")( # If used without any argument, drop_infinity removes all infinity values in # all columns of the table. 
columns=None, - expected=26.0, + expected_sum=26.0, ), ] ) def test_drop_infinity( sdf_special_values: DataFrame, columns: Optional[List[str]], - expected: float, + expected_sum: float, ): inf_budget = PureDPBudget(float("inf")) sess = Session.from_dataframe( @@ -474,7 +461,7 @@ def test_drop_infinity( base_query = QueryBuilder("private") query = base_query.drop_infinity(columns).sum("float_infs", 0, 1) result = sess.evaluate(query, inf_budget) - expected_df = pd.DataFrame([[expected]], columns=["float_infs_sum"]) + expected_df = pd.DataFrame([[expected_sum]], columns=["float_infs_sum"]) assert_frame_equal_with_sort(result.toPandas(), expected_df) @@ -533,3 +520,10 @@ def test_get_bounds( ) result = sess.evaluate(query, inf_budget) assert_frame_equal_with_sort(result.toPandas(), expected_df) + +# TODO: add tests with AddRowsWithID, checking behavior when removing/replacing nulls in +# the ID column + +# TODO: add tests for public and private joins + +# TODO: add changelog diff --git a/test/unit/query_expr_compiler/test_rewrite_rules.py b/test/unit/query_expr_compiler/test_rewrite_rules.py index 8a149928..072e1e5b 100644 --- a/test/unit/query_expr_compiler/test_rewrite_rules.py +++ b/test/unit/query_expr_compiler/test_rewrite_rules.py @@ -589,14 +589,14 @@ def test_special_value_handling_numeric_aggregations( ] ] + [ - # NaNs must also be dropped added if needed + # NaNs must also be dropped if needed Case("drop-nan")( col_desc=ColumnDescriptor( ColumnType.DECIMAL, allow_null=False, allow_nan=True, allow_inf=False ), new_child=DropNullAndNan(child=BASE_EXPR, columns=("col",)), ), - # If not handled, infinities must be clamped to the clamping bounds + # Same for infinities (contrary to other aggregations which use clamping) Case("drop-inf")( col_desc=ColumnDescriptor( ColumnType.DECIMAL, allow_null=False, allow_nan=False, allow_inf=True From f727574ec5f70e0002ac7cdffa8043d7f3c29c84 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Wed, 29 Oct 2025 10:09:26 +0100 Subject: [PATCH 07/16] add tests with privacy IDs --- test/system/session/test_special_values.py | 59 +++++++++++++++++++++- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/test/system/session/test_special_values.py b/test/system/session/test_special_values.py index c556ab42..4471c133 100644 --- a/test/system/session/test_special_values.py +++ b/test/system/session/test_special_values.py @@ -521,8 +521,63 @@ def test_get_bounds( result = sess.evaluate(query, inf_budget) assert_frame_equal_with_sort(result.toPandas(), expected_df) -# TODO: add tests with AddRowsWithID, checking behavior when removing/replacing nulls in -# the ID column + +@parametrize( + [ + Case("normal_case_explicit")( + # Column "float_all_special" has 26 1s, one null (replaced by 100), one nan + # (replaced by 100), and two infinities (dropped). + query=( + QueryBuilder("private") + .enforce(MaxRowsPerID(1)) + .replace_null_and_nan({"float_all_special": 100.0}) + .drop_infinity(["float_all_special"]) + .sum("float_all_special", 0, 100) + ), + expected_df=pd.DataFrame( + [[226]], # 26+100+100 + columns=["float_all_special_sum"], + ), + ), + Case("normal_case_implicit")( + # Column "float_all_special" has 26 1s, one null, one nan, one negative + # infinity (clamped to -50), one positive infinity (clamped to 100). 
+ query=( + QueryBuilder("private") + .enforce(MaxRowsPerID(1)) + .sum("float_all_special", -50, 100) + ), + expected_df=pd.DataFrame( + [[76]], # 26-50+100 + columns=["float_all_special_sum"], + ), + ), + Case("nulls_are_not_dropped_in_id_column")( + # When called with no argument, replace_null_and_nan should drop all rows + # that have null/nan values anywhere, except in the privacy ID column. This + # should leave 25 1s even if we're summing a column without nulls. + query=( + QueryBuilder("private") + .drop_null_and_nan() + .enforce(MaxRowsPerID(1)) + .sum("int_no_null", 0, 1) + ), + expected_df=pd.DataFrame( [[25]], columns=["int_no_null_sum"]), + ] +) +def test_privacy_ids( + sdf_special_values: DataFrame, query: Query, expected_df: pd.DataFrame +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + AddRowsWithID("string_nulls"), + ) + result = sess.evaluate(query, inf_budget) + assert_frame_equal_with_sort(result.toPandas(), expected_df) + # TODO: add tests for public and private joins From c31415e016b7964168d9bd4b775892ac6c3b7b34 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Wed, 29 Oct 2025 16:10:03 +0100 Subject: [PATCH 08/16] changelog, tests tests tests --- CHANGELOG.rst | 2 + .../_query_expr_compiler/_rewrite_rules.py | 2 +- test/system/session/test_special_values.py | 148 +++++++++++++++++- 3 files changed, 147 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e5c59355..0d5ccec5 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,8 @@ Changed - Aggregation mechanisms can now be specified as strings instead of enums, e.g. ``"laplace"`` instead of ``CountMechanism.LAPLACE`` or ``SumMechanism.LAPLACE``. - Removed previously deprecated argument ``max_num_rows`` to ``flat_map``. Use ``max_rows`` instead. - Removed previously deprecated argument ``cols`` to ``count_distinct``. Use ``columns`` instead. +- Infinity values are now automatically dropped before a floating-point column is passed to `get_bounds`. (The documentation previously claimed that this was done, but this was not the case.) +- Fixed the documentation of the behavior of some numeric aggregations (`sum`, `average`, `stdev`, `variance`, `quantile`) to match the actual behavior: infinity values are clamped using the specified bounds before being passed to the aggregation function, not dropped. .. 
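Concretely, the special-value handling that the changelog entries above document for infinite values (and that the rewrite rules in this series also apply to nulls and NaNs) can be pictured as the following implicit-versus-explicit query pairs. This is an illustrative sketch only: the table name "private", the column "x", and the bounds are placeholders, and "x" is assumed to be a floating-point column that may contain nulls, NaNs, and infinities::

    from tmlt.analytics import QueryBuilder

    # Numeric aggregations (sum, average, stdev, variance, quantile): nulls and
    # NaNs in the measure column are dropped, and infinite values are clamped to
    # the clamping bounds before the aggregation runs.
    implicit_sum = QueryBuilder("private").sum("x", 0.0, 10.0)
    explicit_sum = (
        QueryBuilder("private")
        .drop_null_and_nan(["x"])
        .replace_infinity({"x": (0.0, 10.0)})
        .sum("x", 0.0, 10.0)
    )

    # get_bounds: nulls, NaNs, and infinite values are all dropped before the
    # bounds are estimated.
    implicit_bounds = QueryBuilder("private").get_bounds("x")
    explicit_bounds = (
        QueryBuilder("private")
        .drop_null_and_nan(["x"])
        .drop_infinity(["x"])
        .get_bounds("x")
    )
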
_v0.20.2: diff --git a/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py b/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py index 96243a97..ba3a0860 100644 --- a/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py +++ b/src/tmlt/analytics/_query_expr_compiler/_rewrite_rules.py @@ -216,7 +216,7 @@ def handle_special_values(expr: QueryExpr) -> QueryExpr: expr, ( GroupByBoundedAverage, - GroupByBoundedSTDEV, + GroupByBoundedStdev, GroupByBoundedSum, GroupByBoundedVariance, GroupByQuantile, diff --git a/test/system/session/test_special_values.py b/test/system/session/test_special_values.py index 4471c133..755324e3 100644 --- a/test/system/session/test_special_values.py +++ b/test/system/session/test_special_values.py @@ -23,6 +23,7 @@ Query, QueryBuilder, Session, + TruncationStrategy, ) from tmlt.analytics._schema import Schema, analytics_to_spark_schema @@ -30,7 +31,7 @@ @pytest.fixture(name="sdf_special_values", scope="module") -def null_setup(spark): +def special_values_dataframe(spark): """Set up test data for sessions with special values.""" sdf_col_types = { "string_nulls": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True), @@ -562,7 +563,8 @@ def test_get_bounds( .enforce(MaxRowsPerID(1)) .sum("int_no_null", 0, 1) ), - expected_df=pd.DataFrame( [[25]], columns=["int_no_null_sum"]), + expected_df=pd.DataFrame([[25]], columns=["int_no_null_sum"]), + ), ] ) def test_privacy_ids( @@ -579,6 +581,144 @@ def test_privacy_ids( assert_frame_equal_with_sort(result.toPandas(), expected_df) -# TODO: add tests for public and private joins +@pytest.fixture(name="sdf_for_joins", scope="module") +def dataframe_for_join(spark): + """Set up test data for sessions with special values.""" + sdf_col_types = { + "string_nulls": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True), + "int_nulls": ColumnDescriptor(ColumnType.INTEGER, allow_null=True), + "float_all_special": ColumnDescriptor( + ColumnType.DECIMAL, + allow_null=True, + allow_nan=True, + allow_inf=True, + ), + "date_nulls": ColumnDescriptor(ColumnType.DATE, allow_null=True), + "time_nulls": ColumnDescriptor(ColumnType.TIMESTAMP, allow_null=True), + "new_int": ColumnDescriptor(ColumnType.INTEGER, allow_null=False), + } + date = datetime.date(2000, 1, 1) + time = datetime.datetime(2020, 1, 1) + sdf = spark.createDataFrame( + [ + # Normal row + ("normal_0", 1, 1.0, date, time, 1), + # Rows with nulls: some whose values appear in the data… + (None, 1, 1.0, date, time, 1), + ("u2", None, 1.0, date, time, 1), + ("u3", 1, None, date, time, 1), + # … and two identical rows, where the combination of nulls does not appear + # in the data + ("u4", 1, 1.0, None, None, 1), + ("u5", 1, 1.0, None, None, 1), + # Row with nans + ("a6", 1, float("nan"), date, time, 1), + # Rows with infinities + ("i7", 1, float("inf"), date, time, 1), + ("i8", 1, -float("inf"), date, time, 1), + ], + schema=analytics_to_spark_schema(Schema(sdf_col_types)), + ) + return sdf + -# TODO: add changelog +@parametrize( + [ + Case("public_join_inner_all_match")( + # Joining with the first three columns, all columns of the right table + # should match exactly one row, without duplicates. This checks that tables + # are joined on all three kinds of special values. 
+ protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .join_public( + "public", + ["string_nulls", "int_nulls", "float_all_special"], + "inner", + ) + .sum("new_int", 0, 1) + ), + expected_df=pd.DataFrame( + [[9]], + columns=["new_int_sum"], + ), + ), + Case("public_join_inner_duplicates")( + # Joining with the date and time columns only creates matches for the rows + # where both are specified: 28 in the left table and 7 in the right table. + protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .join_public("public", ["date_nulls", "time_nulls"], "inner") + .sum("new_int", 0, 1) + ), + expected_df=pd.DataFrame( + [[28 * 7]], + columns=["new_int_sum"], + ), + ), + Case("public_join_left_duplicates")( + # Same as before, except we do a left join, so 2 rows in the original table + # are preserved in the join. + protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .join_public("public", ["date_nulls", "time_nulls"], "left") + .count() + ), + expected_df=pd.DataFrame( + [[28 * 7 + 2]], + columns=["count"], + ), + ), + Case("private_join_add_rows")( + # Private joins without duplicates should work the same way as the inner + # public join above, leaving the 9 rows in common between the two tables. + protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .join_private( + "private_2", + join_columns=["string_nulls", "int_nulls", "float_all_special"], + truncation_strategy_left=TruncationStrategy.DropNonUnique(), + truncation_strategy_right=TruncationStrategy.DropNonUnique(), + ) + .count() + ), + expected_df=pd.DataFrame([[9]], columns=["count"]), + ), + Case("private_join_ids")( + # Same with the implicit join on the privacy ID column. + protected_change=AddRowsWithID("string_nulls"), + query=( + QueryBuilder("private") + .join_private( + "private_2", + join_columns=["string_nulls", "int_nulls", "float_all_special"], + ) + .enforce(MaxRowsPerID(1)) + .count() + ), + expected_df=pd.DataFrame([[9]], columns=["count"]), + ), + ] +) +def test_joins( + sdf_special_values: DataFrame, + sdf_for_joins: DataFrame, + protected_change: ProtectedChange, + query: Query, + expected_df: pd.DataFrame, +): + inf_budget = PureDPBudget(float("inf")) + sess = ( + Session.Builder() + .with_id_space("default_id_space") + .with_private_dataframe("private", sdf_special_values, protected_change) + .with_private_dataframe("private_2", sdf_for_joins, protected_change) + .with_public_dataframe("public", sdf_for_joins) + .with_privacy_budget(inf_budget) + .build() + ) + result = sess.evaluate(query, inf_budget) + assert_frame_equal_with_sort(result.toPandas(), expected_df) From bee4860f6527e1578e2993143d7e652132c7becb Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Fri, 31 Oct 2025 18:15:32 +0100 Subject: [PATCH 09/16] more tests. send help --- test/system/session/test_special_values.py | 168 ++++++++++++++++++++- 1 file changed, 163 insertions(+), 5 deletions(-) diff --git a/test/system/session/test_special_values.py b/test/system/session/test_special_values.py index 755324e3..e932c83a 100644 --- a/test/system/session/test_special_values.py +++ b/test/system/session/test_special_values.py @@ -583,7 +583,11 @@ def test_privacy_ids( @pytest.fixture(name="sdf_for_joins", scope="module") def dataframe_for_join(spark): - """Set up test data for sessions with special values.""" + """Set up test data for sessions with special values. + + This data is then joined with the ``sdf_special_values`` dataframe used previously + in this test suite. 
+ """ sdf_col_types = { "string_nulls": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True), "int_nulls": ColumnDescriptor(ColumnType.INTEGER, allow_null=True), @@ -603,12 +607,12 @@ def dataframe_for_join(spark): [ # Normal row ("normal_0", 1, 1.0, date, time, 1), - # Rows with nulls: some whose values appear in the data… + # Rows with nulls: some whose values appear in `sdf_special_values`… (None, 1, 1.0, date, time, 1), ("u2", None, 1.0, date, time, 1), ("u3", 1, None, date, time, 1), # … and two identical rows, where the combination of nulls does not appear - # in the data + # in `sdf_special_values`. ("u4", 1, 1.0, None, None, 1), ("u5", 1, 1.0, None, None, 1), # Row with nans @@ -688,7 +692,7 @@ def dataframe_for_join(spark): expected_df=pd.DataFrame([[9]], columns=["count"]), ), Case("private_join_ids")( - # Same with the implicit join on the privacy ID column. + # Same with a privacy ID column. protected_change=AddRowsWithID("string_nulls"), query=( QueryBuilder("private") @@ -701,6 +705,43 @@ def dataframe_for_join(spark): ), expected_df=pd.DataFrame([[9]], columns=["count"]), ), + Case("private_join_preserves_special_values")( + # After the join, "float_all_special" should have the same data as in the + # table used for the join: 5 1s, one null (replaced by 100), one nan + # (replaced by 100), and two infinities (dropped). + protected_change=AddRowsWithID("string_nulls"), + query=( + QueryBuilder("private") + .join_private( + "private_2", + join_columns=["string_nulls", "int_nulls", "float_all_special"], + ) + .enforce(MaxRowsPerID(1)) + .drop_infinity(["float_all_special"]) + .replace_null_and_nan({"float_all_special": 100.0}) + .sum("float_all_special", 0, 200) + ), + expected_df=pd.DataFrame( + [[5 + 100 + 100]], columns=["float_all_special_sum"] + ), + ), + Case("public_join_preserves_special_values")( + # Same with a public join. 
+ protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .join_public( + "public", + join_columns=["string_nulls", "int_nulls", "float_all_special"], + ) + .drop_infinity(["float_all_special"]) + .replace_null_and_nan({"float_all_special": 100.0}) + .sum("float_all_special", 0, 200) + ), + expected_df=pd.DataFrame( + [[5 + 100 + 100]], columns=["float_all_special_sum"] + ), + ), ] ) def test_joins( @@ -710,7 +751,7 @@ def test_joins( query: Query, expected_df: pd.DataFrame, ): - inf_budget = PureDPBudget(float("inf")) + inf_budget = PureDPBudget.inf() sess = ( Session.Builder() .with_id_space("default_id_space") @@ -722,3 +763,120 @@ def test_joins( ) result = sess.evaluate(query, inf_budget) assert_frame_equal_with_sort(result.toPandas(), expected_df) + + +@parametrize( + [ + Case("private_int_remove_nulls")( + protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .rename({"int_no_null": "int_joined"}) + .join_private( + QueryBuilder("private_2").rename({"int_nulls": "int_joined"}), + join_columns=["int_joined"], + truncation_strategy_left=TruncationStrategy.DropExcess(30), + truncation_strategy_right=TruncationStrategy.DropExcess(30), + ) + ), + expected_col=( + "int_joined", + ColumnDescriptor(ColumnType.INTEGER, allow_null=False), + ), + ), + Case( + "private_float_remove_both", + marks=pytest.mark.xfail( + True, reason="https://github.com/opendp/tumult-analytics/issues/108" + ), + )( + protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .drop_null_and_nan(["float_all_special"]) + .join_private( + QueryBuilder("private").drop_infinity(["float_all_special"]), + join_columns=["float_all_special"], + truncation_strategy_left=TruncationStrategy.DropExcess(30), + truncation_strategy_right=TruncationStrategy.DropExcess(30), + ) + ), + expected_col=( + "float_all_special", + ColumnDescriptor( + ColumnType.DECIMAL, + allow_null=False, + allow_nan=False, + allow_inf=False, + ), + ), + ), + Case("public_int_remove_nulls_from_right")( + protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .select(["int_no_null"]) + .rename({"int_no_null": "int_nulls"}) + .join_public( + "public", + join_columns=["int_nulls"], + ) + ), + expected_col=( + "int_nulls", + ColumnDescriptor(ColumnType.INTEGER, allow_null=False), + ), + ), + Case("public_int_remove_nulls_from_left")( + protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .rename({"int_nulls": "new_int"}) + .join_public( + "public", + join_columns=["new_int"], + ) + ), + expected_col=( + "new_int", + ColumnDescriptor(ColumnType.INTEGER, allow_null=False), + ), + ), + Case("public_int_keep_null_on_left_join")( + protected_change=AddOneRow(), + query=( + QueryBuilder("private") + .rename({"int_nulls": "new_int"}) + .join_public( + "public", + join_columns=["new_int"], + how="left", + ) + ), + expected_col=( + "int_nulls", + ColumnDescriptor(ColumnType.INTEGER, allow_null=True), + ), + ), + ] +) +def test_join_schema( + sdf_special_values: DataFrame, + sdf_for_joins: DataFrame, + protected_change: ProtectedChange, + query: Query, + expected_col: Dict["str", ColumnDescriptor], +): + inf_budget = PureDPBudget.inf() + sess = ( + Session.Builder() + .with_id_space("default_id_space") + .with_private_dataframe("private", sdf_special_values, protected_change) + .with_private_dataframe("private_2", sdf_for_joins, protected_change) + .with_public_dataframe("public", sdf_for_joins) + .with_privacy_budget(inf_budget) + .build() + ) + sess.create_view(query, "view", cache=False) + schema = 
sess.get_schema("view") + assert expected_col in schema.items() From 41e971e4b56b8fa4d37760ec416ccfa188055fa9 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Sat, 1 Nov 2025 17:16:11 +0100 Subject: [PATCH 10/16] yay the xfail is a pass now --- test/system/session/test_special_values.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/system/session/test_special_values.py b/test/system/session/test_special_values.py index e932c83a..dacfdc2c 100644 --- a/test/system/session/test_special_values.py +++ b/test/system/session/test_special_values.py @@ -784,12 +784,7 @@ def test_joins( ColumnDescriptor(ColumnType.INTEGER, allow_null=False), ), ), - Case( - "private_float_remove_both", - marks=pytest.mark.xfail( - True, reason="https://github.com/opendp/tumult-analytics/issues/108" - ), - )( + Case("private_float_remove_both")( protected_change=AddOneRow(), query=( QueryBuilder("private") From aadba5004287fb488feb28f9fd690bb9c00387cb Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Mon, 3 Nov 2025 21:36:51 +0100 Subject: [PATCH 11/16] more tests to please the review gods --- test/system/session/test_special_values.py | 73 ++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/test/system/session/test_special_values.py b/test/system/session/test_special_values.py index dacfdc2c..9e300e2f 100644 --- a/test/system/session/test_special_values.py +++ b/test/system/session/test_special_values.py @@ -17,6 +17,7 @@ AddRowsWithID, ColumnDescriptor, ColumnType, + KeySet, MaxRowsPerID, ProtectedChange, PureDPBudget, @@ -94,6 +95,78 @@ def special_values_dataframe(spark): return sdf +@parametrize( + [ + Case("int_sum")( + # There are 29 1s in the "int_nulls" column and one null, which should be + # dropped by default. + query=QueryBuilder("private").sum("int_nulls", 0, 1), + expected_df=pd.DataFrame( + [[29]], + columns=["int_nulls_sum"], + ), + ), + Case("count_distinct")( + # Nulls, nans, and infinities count as distinct values in a count_distinct + # query. All other values in the "float_all_special" are 1s. + query=QueryBuilder("private").count_distinct(["float_all_special"]), + expected_df=pd.DataFrame( + [[5]], + columns=["count_distinct(float_all_special)"], + ), + ), + Case("count_distinct_deduplicates")( + # The "float_infs" column contains 1s, two positive infinities, and two + # negative infinities. + query=QueryBuilder("private").count_distinct(["float_infs"]), + expected_df=pd.DataFrame( + [[3]], + columns=["count_distinct(float_infs)"], + ), + ), + Case("float_average")( + # In the "float_all_special" column, there are 26 1s, one null (dropped), + # one NaN (dropped), one negative infinity (clamped to 50), and one positive + # infinity (clamed to 100). 
+ query=QueryBuilder("private").sum("float_all_special", -100, 300), + expected_df=pd.DataFrame( + [[226]], # 26-100+300 + columns=["float_all_special_sum"], + ), + ), + Case("group_by_null")( + # Nulls can be used as group-by + query=( + QueryBuilder("private") + .groupby( + KeySet.from_dict( + {"date_nulls": [datetime.date(2000, 1, 1), None]} + ) + ) + .count() + ), + expected_df=pd.DataFrame( + [[datetime.date(2000, 1, 1), 29], [None, 1]], + columns=["date_nulls", "count"], + ), + ), + ] +) +def test_default_behavior( + sdf_special_values: DataFrame, query: Query, expected_df: pd.DataFrame +): + inf_budget = PureDPBudget(float("inf")) + sess = Session.from_dataframe( + inf_budget, + "private", + sdf_special_values, + AddOneRow(), + ) + result = sess.evaluate(query, inf_budget) + print(expected_df) + assert_frame_equal_with_sort(result.toPandas(), expected_df) + + @parametrize( [ Case("int_noop")( From d1236b457a057cd6ad487815f513cb4a7ca67985 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Mon, 3 Nov 2025 21:41:28 +0100 Subject: [PATCH 12/16] better test comments --- test/system/session/test_special_values.py | 12 +++++++++--- test/unit/query_expr_compiler/test_rewrite_rules.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/test/system/session/test_special_values.py b/test/system/session/test_special_values.py index 9e300e2f..36608731 100644 --- a/test/system/session/test_special_values.py +++ b/test/system/session/test_special_values.py @@ -542,7 +542,7 @@ def test_drop_infinity( @parametrize( [ Case("works_with_nulls")( - # Check that + # get_bounds doesn't explode when called on a null column query=QueryBuilder("private").get_bounds("int_nulls"), expected_df=pd.DataFrame( [[-1, 1]], @@ -550,7 +550,7 @@ def test_drop_infinity( ), ), Case("works_with_nan")( - # Check that + # Same with nans query=QueryBuilder("private").get_bounds("float_nans"), expected_df=pd.DataFrame( [[-1, 1]], @@ -558,7 +558,7 @@ def test_drop_infinity( ), ), Case("works_with_infinity")( - # Check that + # Same with infinities query=QueryBuilder("private").get_bounds("float_infs"), expected_df=pd.DataFrame( [[-1, 1]], @@ -841,6 +841,7 @@ def test_joins( @parametrize( [ Case("private_int_remove_nulls")( + # Null joined with no nulls = no nulls protected_change=AddOneRow(), query=( QueryBuilder("private") @@ -858,6 +859,7 @@ def test_joins( ), ), Case("private_float_remove_both")( + # All special joined with only nulls & nan = only nulls & nan protected_change=AddOneRow(), query=( QueryBuilder("private") @@ -880,6 +882,7 @@ def test_joins( ), ), Case("public_int_remove_nulls_from_right")( + # No nulls joined with nulls = no nulls (public version) protected_change=AddOneRow(), query=( QueryBuilder("private") @@ -896,6 +899,8 @@ def test_joins( ), ), Case("public_int_remove_nulls_from_left")( + # Nulls joined with no nulls = no nulls (reverse) + protected_change=AddOneRow(), protected_change=AddOneRow(), query=( QueryBuilder("private") @@ -911,6 +916,7 @@ def test_joins( ), ), Case("public_int_keep_null_on_left_join")( + # Nulls *left* joined with no nulls = nulls protected_change=AddOneRow(), query=( QueryBuilder("private") diff --git a/test/unit/query_expr_compiler/test_rewrite_rules.py b/test/unit/query_expr_compiler/test_rewrite_rules.py index a0a96912..c5b4a249 100644 --- a/test/unit/query_expr_compiler/test_rewrite_rules.py +++ b/test/unit/query_expr_compiler/test_rewrite_rules.py @@ -479,7 +479,7 @@ def test_special_value_handling_count_unaffected( ] ] + [ - # NaNs must also be 
dropped added if needed + # NaNs must also be dropped if needed Case("drop_nan")( col_desc=ColumnDescriptor( ColumnType.DECIMAL, allow_null=False, allow_nan=True, allow_inf=False From 0e34fdb756b94ddf352cfa51fcc025636f207ee4 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Mon, 3 Nov 2025 21:43:07 +0100 Subject: [PATCH 13/16] lint --- test/system/session/test_special_values.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/system/session/test_special_values.py b/test/system/session/test_special_values.py index 36608731..c00765e3 100644 --- a/test/system/session/test_special_values.py +++ b/test/system/session/test_special_values.py @@ -139,9 +139,7 @@ def special_values_dataframe(spark): query=( QueryBuilder("private") .groupby( - KeySet.from_dict( - {"date_nulls": [datetime.date(2000, 1, 1), None]} - ) + KeySet.from_dict({"date_nulls": [datetime.date(2000, 1, 1), None]}) ) .count() ), From 30d34b934490e563818afdf59116e191e97348da Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Mon, 3 Nov 2025 21:54:12 +0100 Subject: [PATCH 14/16] whoops where did that come from --- test/system/session/test_special_values.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/system/session/test_special_values.py b/test/system/session/test_special_values.py index c00765e3..846af57f 100644 --- a/test/system/session/test_special_values.py +++ b/test/system/session/test_special_values.py @@ -899,7 +899,6 @@ def test_joins( Case("public_int_remove_nulls_from_left")( # Nulls joined with no nulls = no nulls (reverse) protected_change=AddOneRow(), - protected_change=AddOneRow(), query=( QueryBuilder("private") .rename({"int_nulls": "new_int"}) From 14345cd18cddd5a7a652855d7960fed815c68510 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Mon, 3 Nov 2025 22:10:15 +0100 Subject: [PATCH 15/16] why would this start failing now?! --- src/tmlt/analytics/query_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tmlt/analytics/query_builder.py b/src/tmlt/analytics/query_builder.py index ea1a76a1..478e9f33 100644 --- a/src/tmlt/analytics/query_builder.py +++ b/src/tmlt/analytics/query_builder.py @@ -2995,9 +2995,9 @@ def quantile( .. note:: If the column being measured contains NaN or null values, a - :meth:`~drop_null_and_nan` query will be performed first. If the column - being measured contains infinite values, these values will be clamped - between ``low`` and ``high``. + :meth:`~QueryBuilder.drop_null_and_nan` query will be performed first. If + the column being measured contains infinite values, these values will be + clamped between ``low`` and ``high``. .. >>> from tmlt.analytics import ( ... 
AddOneRow, From 801d2876575f3122e18e773281fdd0d70c81b34a Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Mon, 3 Nov 2025 23:15:05 +0100 Subject: [PATCH 16/16] remove visitor tests checking for special values; they are now handled in the rewriting rules instead --- .../test_measurement_visitor.py | 52 ------------------- 1 file changed, 52 deletions(-) diff --git a/test/unit/query_expr_compiler/test_measurement_visitor.py b/test/unit/query_expr_compiler/test_measurement_visitor.py index e8791f98..7be38fa6 100644 --- a/test/unit/query_expr_compiler/test_measurement_visitor.py +++ b/test/unit/query_expr_compiler/test_measurement_visitor.py @@ -635,24 +635,6 @@ def test_visit_groupby_count_distinct( ] ), ), - ( - QueryBuilder("private").quantile( - low=-100, - high=100, - name="custom_output_column", - column="null_and_nan", - quantile=0.1, - ), - PureDP(), - NoiseInfo( - [ - { - "noise_mechanism": _NoiseMechanism.EXPONENTIAL, - "noise_parameter": 3.3333333333333326, - } - ] - ), - ), ( QueryBuilder("private") .groupby(KeySet.from_dict({"B": [0, 1]})) @@ -673,26 +655,6 @@ def test_visit_groupby_count_distinct( ] ), ), - ( - QueryBuilder("private") - .groupby(KeySet.from_dict({"B": [0, 1]})) - .quantile( - column="null_and_inf", - name="quantile", - low=123.345, - high=987.65, - quantile=0.25, - ), - PureDP(), - NoiseInfo( - [ - { - "noise_mechanism": _NoiseMechanism.EXPONENTIAL, - "noise_parameter": 3.3333333333333326, - } - ] - ), - ), ( QueryBuilder("private") .groupby(KeySet.from_dict({"A": ["zero"]})) @@ -707,20 +669,6 @@ def test_visit_groupby_count_distinct( ] ), ), - ( - QueryBuilder("private") - .groupby(KeySet.from_dict({"A": ["zero"]})) - .quantile(quantile=0.5, low=0, high=1, column="nan_and_inf"), - RhoZCDP(), - NoiseInfo( - [ - { - "noise_mechanism": _NoiseMechanism.EXPONENTIAL, - "noise_parameter": 2.9814239699997196, - } - ] - ), - ), ( QueryBuilder("private") .groupby(KeySet.from_dict({"A": ["zero"]}))