From 6c340d5fcfdc42bb1a2f03e04cda54a3437e1f17 Mon Sep 17 00:00:00 2001
From: Damien Desfontaines
Date: Mon, 13 Oct 2025 17:03:11 +0200
Subject: [PATCH 1/8] Removes OutputSchemaVisitor, adds schema() as a QueryExpr
 method
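
Schema computation for query expressions now lives on each QueryExpr
subclass instead of a separate visitor class. Call sites change from

    expr.accept(OutputSchemaVisitor(catalog))

to

    expr.schema(catalog)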
---
 src/tmlt/analytics/_query_expr.py             | 838 +++++++++++++++-
 .../_base_measurement_visitor.py              |  31 +-
 .../_base_transformation_visitor.py           |   5 +-
 .../_query_expr_compiler/_compiler.py         |  11 +-
 .../_measurement_visitor.py                   |   8 +-
 .../_output_schema_visitor.py                 | 933 ------------------
 .../test_measurement_visitor.py               |   7 +-
 .../transformation_visitor/test_add_keys.py   |  13 +-
 .../transformation_visitor/test_add_rows.py   |  15 +-
 ...tor.py => test_query_expression_schema.py} | 123 ++-
 10 files changed, 922 insertions(+), 1062 deletions(-)
 delete mode 100644 src/tmlt/analytics/_query_expr_compiler/_output_schema_visitor.py
 rename test/unit/{query_expr_compiler/test_output_schema_visitor.py => test_query_expression_schema.py} (95%)

diff --git a/src/tmlt/analytics/_query_expr.py b/src/tmlt/analytics/_query_expr.py
index 46085b23..0a64e0e4 100644
--- a/src/tmlt/analytics/_query_expr.py
+++ b/src/tmlt/analytics/_query_expr.py
@@ -13,18 +13,31 @@
 import datetime
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from collections.abc import Collection
+from dataclasses import dataclass, replace
 from enum import Enum, auto
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
-from pyspark.sql import DataFrame
+from pyspark.sql import DataFrame, SparkSession
+from tmlt.core.domains.spark_domains import SparkDataFrameDomain
+from tmlt.core.utils.join import domain_after_join
 from typeguard import check_type
 
 from tmlt.analytics import AnalyticsInternalError
+from tmlt.analytics._catalog import Catalog, PrivateTable, PublicTable
 from tmlt.analytics._coerce_spark_schema import coerce_spark_schema_or_fail
-from tmlt.analytics._schema import FrozenDict, Schema
+from tmlt.analytics._schema import (
+    ColumnDescriptor,
+    ColumnType,
+    FrozenDict,
+    Schema,
+    analytics_to_py_types,
+    analytics_to_spark_columns_descriptor,
+    analytics_to_spark_schema,
+    spark_schema_to_analytics_columns,
+)
 from tmlt.analytics.config import config
-from tmlt.analytics.constraints import Constraint
+from tmlt.analytics.constraints import Constraint, MaxGroupsPerID, MaxRowsPerGroupPerID
 from tmlt.analytics.keyset import KeySet
 from tmlt.analytics.truncation_strategy import TruncationStrategy
@@ -193,8 +206,20 @@ def __post_init__(self):
             " (_), and it cannot start with a number, or contain any spaces."
         )
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        if self.source_id not in catalog.tables:
+            raise ValueError(f"Query references nonexistent table '{self.source_id}'")
+        table = catalog.tables[self.source_id]
+        if not isinstance(table, PrivateTable):
+            raise ValueError(
+                f"Attempted query on table '{self.source_id}', which is "
+                "not a private table."
+            )
+        return table.schema
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
-        """Visit this QueryExpr with visitor."""
+        """Visits this QueryExpr with visitor."""
         return visitor.visit_private_source(self)
 
 
@@ -217,6 +242,29 @@ def __post_init__(self):
         check_type(self.child, QueryExpr)
         check_type(self.columns, Optional[Tuple[str, ...]])
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+        if self.columns:
+            nonexistent_columns = set(self.columns) - set(input_schema)
+            if nonexistent_columns:
+                raise ValueError(
+                    f"Nonexistent columns in get_groups query: {nonexistent_columns}"
+                )
+            input_schema = Schema(
+                {column: input_schema[column] for column in self.columns}
+            )
+        else:
+            input_schema = Schema(
+                {
+                    column: input_schema[column]
+                    for column in input_schema
+                    if column != input_schema.id_column
+                }
+            )
+        return input_schema
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_get_groups(self)
@@ -247,6 +295,10 @@ def __post_init__(self):
         check_type(self.lower_bound_column, str)
         check_type(self.upper_bound_column, str)
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        return _schema_for_groupby(self, catalog)
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_get_bounds(self)
@@ -279,6 +331,37 @@ def __post_init__(self):
             ' "" are not allowed'
         )
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+        grouping_column = input_schema.grouping_column
+        id_column = input_schema.id_column
+        id_space = input_schema.id_space
+        nonexistent_columns = set(self.column_mapper) - set(input_schema)
+        if nonexistent_columns:
+            raise ValueError(
+                f"Nonexistent columns in rename query: {nonexistent_columns}"
+            )
+        for old, new in self.column_mapper.items():
+            if new in input_schema and new != old:
+                raise ValueError(
+                    f"Cannot rename '{old}' to '{new}': column '{new}' already exists"
+                )
+            if old == grouping_column:
+                grouping_column = new
+            if old == id_column:
+                id_column = new
+
+        return Schema(
+            {
+                self.column_mapper.get(column, column): input_schema[column]
+                for column in input_schema
+            },
+            grouping_column=grouping_column,
+            id_column=id_column,
+            id_space=id_space,
+        )
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_rename(self)
@@ -302,6 +385,19 @@ def __post_init__(self):
         check_type(self.child, QueryExpr)
         check_type(self.condition, str)
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+        spark = SparkSession.builder.getOrCreate()
+        test_df = spark.createDataFrame(
+            [], schema=analytics_to_spark_schema(input_schema)
+        )
+        try:
+            test_df.filter(self.condition)
+        except Exception as e:
+            raise ValueError(f"Invalid filter condition '{self.condition}': {e}") from e
+        return input_schema
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_filter(self)
@@ -323,6 +419,34 @@ def __post_init__(self):
         if len(self.columns) != len(set(self.columns)):
             raise ValueError(f"Column name appears more than once in {self.columns}")
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+        grouping_column = input_schema.grouping_column
+        id_column = input_schema.id_column
+        if grouping_column is not None and grouping_column not in self.columns:
+            raise ValueError(
+                f"Grouping column '{grouping_column}' may not "
+                "be dropped by select query"
+            )
+        if id_column is not None and id_column not in self.columns:
+            raise ValueError(
+                f"ID column '{id_column}' may not be dropped by select query"
+            )
+
+        nonexistent_columns = set(self.columns) - set(input_schema)
+        if nonexistent_columns:
+            raise ValueError(
+                f"Nonexistent columns in select query: {nonexistent_columns}"
+            )
+
+        return Schema(
+            {column: input_schema[column] for column in self.columns},
+            grouping_column=grouping_column,
+            id_column=id_column,
+            id_space=input_schema.id_space,
+        )
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_select(self)
@@ -354,6 +478,46 @@ def __post_init__(self):
         if self.schema_new_columns.grouping_column is not None:
             raise ValueError("Map cannot be be used to create grouping columns")
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+        new_columns = self.schema_new_columns.column_descs
+        # Any column created by Map could contain a null value
+        for name in list(new_columns.keys()):
+            new_columns[name] = replace(new_columns[name], allow_null=True)
+
+        if self.augment:
+            overlapping_columns = set(input_schema.keys()) & set(new_columns.keys())
+            if overlapping_columns:
+                raise ValueError(
+                    "New columns in augmenting map must not overwrite "
+                    "existing columns, but found new columns that "
+                    f"already exist: {', '.join(overlapping_columns)}"
+                )
+            return Schema(
+                {**input_schema, **new_columns},
+                grouping_column=input_schema.grouping_column,
+                id_column=input_schema.id_column,
+                id_space=input_schema.id_space,
+            )
+        elif input_schema.grouping_column:
+            raise ValueError(
+                "Map must set augment=True to ensure that "
+                f"grouping column '{input_schema.grouping_column}' is not lost."
+            )
+        elif input_schema.id_column:
+            raise ValueError(
+                "Map must set augment=True to ensure that "
+                f"ID column '{input_schema.id_column}' is not lost."
+            )
+        return Schema(
+            new_columns,
+            grouping_column=self.schema_new_columns.grouping_column,
+            id_column=self.schema_new_columns.id_column,
+            id_space=self.schema_new_columns.id_space,
+        )
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_map(self)
@@ -420,6 +584,61 @@ def __post_init__(self):
             "columns, grouping flat map can only result in 1 new column"
         )
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+        if self.schema_new_columns.grouping_column is not None:
+            if input_schema.grouping_column:
+                raise ValueError(
+                    "Multiple grouping transformations are used in this query. "
+                    "Only one grouping transformation is allowed."
+                )
+            if input_schema.id_column:
+                raise ValueError(
+                    "Grouping flat map cannot be used on tables with "
+                    "the AddRowsWithID protected change."
+                )
+            grouping_column = self.schema_new_columns.grouping_column
+        else:
+            grouping_column = input_schema.grouping_column
+
+        new_columns = self.schema_new_columns.column_descs
+        # Any column created by the FlatMap could contain a null value
+        for name in list(new_columns.keys()):
+            new_columns[name] = replace(new_columns[name], allow_null=True)
+        if self.augment:
+            overlapping_columns = set(input_schema.keys()) & set(new_columns.keys())
+            if overlapping_columns:
+                raise ValueError(
+                    "New columns in augmenting map must not overwrite "
+                    "existing columns, but found new columns that "
+                    f"already exist: {', '.join(overlapping_columns)}"
+                )
+            return Schema(
+                {**input_schema, **new_columns},
+                grouping_column=grouping_column,
+                id_column=input_schema.id_column,
+                id_space=input_schema.id_space,
+            )
+        elif input_schema.grouping_column:
+            raise ValueError(
+                "Flat map must set augment=True to ensure that "
+                f"grouping column '{input_schema.grouping_column}' is not lost."
+            )
+        elif input_schema.id_column:
+            raise ValueError(
+                "Flat map must set augment=True to ensure that "
+                f"ID column '{input_schema.id_column}' is not lost."
+            )
+
+        return Schema(
+            new_columns,
+            grouping_column=grouping_column,
+            id_column=self.schema_new_columns.id_column,
+            id_space=self.schema_new_columns.id_space,
+        )
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_flat_map(self)
@@ -470,6 +689,34 @@ def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_flat_map_by_id(self)
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+        id_column = input_schema.id_column
+        new_columns = self.schema_new_columns.column_descs
+
+        if not id_column:
+            raise ValueError(
+                "Flat-map-by-ID may only be used on tables with ID columns."
+            )
+        if input_schema.grouping_column:
+            raise AnalyticsInternalError(
+                "Encountered table with both an ID column and a grouping column."
+            )
+        if id_column in new_columns:
+            raise ValueError(
+                "Flat-map-by-ID mapping function output cannot include ID column."
+            )
+
+        for name in list(new_columns.keys()):
+            new_columns[name] = replace(new_columns[name], allow_null=True)
+        return Schema(
+            {id_column: input_schema[id_column], **new_columns},
+            grouping_column=None,
+            id_column=id_column,
+            id_space=input_schema.id_space,
+        )
+
     def __eq__(self, other: object) -> bool:
         """Returns true iff self == other.
 
@@ -486,6 +733,106 @@ def __eq__(self, other: object) -> bool:
         )
 
 
+def _schema_for_join(
+    left_schema: Schema,
+    right_schema: Schema,
+    join_columns: Optional[Tuple[str, ...]],
+    join_id_space: Optional[str] = None,
+    how: str = "inner",
+) -> Schema:
+    """Return the resulting schema from joining two tables.
+
+    It is assumed that if either schema has an ID column, the one from
+    left_schema should be used. This is because the appropriate behavior here
+    depends on the type of join being performed, so checks for compatibility of
+    ID columns must happen outside this function.
+
+    Args:
+        left_schema: Schema for the left table.
+        right_schema: Schema for the right table.
+        join_columns: The set of columns to join on.
+        join_id_space: The ID space of the resulting join.
+        how: The type of join to perform. Default is "inner".
+    """
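+    # Determine the output grouping column: a join may carry at most one
+    # grouping column, so two distinct grouping columns are an error.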
+ """ + if left_schema.grouping_column is None: + grouping_column = right_schema.grouping_column + elif right_schema.grouping_column is None: + grouping_column = left_schema.grouping_column + elif left_schema.grouping_column == right_schema.grouping_column: + grouping_column = left_schema.grouping_column + else: + raise ValueError( + "Joining tables which both have grouping columns is only supported " + "if they have the same grouping column" + ) + common_columns = set(left_schema) & set(right_schema) + if join_columns is None and not common_columns: + raise ValueError("Tables have no common columns to join on") + if join_columns is not None and not join_columns: + # This error case should be caught when constructing the query + # expression, so it should never get here. + raise AnalyticsInternalError("Empty list of join columns provided.") + + join_columns = ( + join_columns + if join_columns + else tuple(sorted(common_columns, key=list(left_schema).index)) + ) + + if not set(join_columns) <= common_columns: + raise ValueError("Join columns must be common to both tables") + + for column in join_columns: + if left_schema[column].column_type != right_schema[column].column_type: + raise ValueError( + "Join columns must have identical types on both tables, " + f"but column '{column}' does not: {left_schema[column]} and " + f"{right_schema[column]} are incompatible" + ) + + join_column_schemas = {column: left_schema[column] for column in join_columns} + output_schema = { + **join_column_schemas, + **{ + column + ("_left" if column in common_columns else ""): left_schema[column] + for column in left_schema + if column not in join_columns + }, + **{ + column + + ("_right" if column in common_columns else ""): right_schema[column] + for column in right_schema + if column not in join_columns + }, + } + # Use Core's join utilities for determining whether a column can be null + # TODO: This could potentially be used more in this function + output_domain = domain_after_join( + left_domain=SparkDataFrameDomain( + analytics_to_spark_columns_descriptor(left_schema) + ), + right_domain=SparkDataFrameDomain( + analytics_to_spark_columns_descriptor(right_schema) + ), + on=list(join_columns), + how=how, + nulls_are_equal=True, + ) + for column in output_schema: + col_schema = output_schema[column] + output_schema[column] = ColumnDescriptor( + column_type=col_schema.column_type, + allow_null=output_domain.schema[column].allow_null, + allow_nan=col_schema.allow_nan, + allow_inf=col_schema.allow_inf, + ) + return Schema( + output_schema, + grouping_column=grouping_column, + id_column=left_schema.id_column, + id_space=join_id_space, + ) + @dataclass(frozen=True) class JoinPrivate(QueryExpr): """Returns the join of two private tables. @@ -527,6 +874,48 @@ def __post_init__(self): if len(self.join_columns) != len(set(self.join_columns)): raise ValueError("Join columns must be distinct") + def schema(self, catalog: Catalog) -> Schema: + """Returns the resulting schema from evaluating this QueryExpr. + + The ordering of output columns are: + + 1. The join columns + 2. Columns that are only in the left table + 3. Columns that are only in the right table + 4. Columns that are in both tables, but not included in the join columns. 
+        left_schema = self.child.schema(catalog)
+        right_schema = self.right_operand_expr.schema(catalog)
+        if left_schema.id_column != right_schema.id_column:
+            if left_schema.id_column is None or right_schema.id_column is None:
+                raise ValueError(
+                    "Private joins can only be performed between two tables "
+                    "with the same type of protected change"
+                )
+            raise ValueError(
+                "Private joins between tables with the AddRowsWithID "
+                "protected change are only possible when the ID columns of "
+                "the two tables have the same name"
+            )
+        if (
+            left_schema.id_space
+            and right_schema.id_space
+            and left_schema.id_space != right_schema.id_space
+        ):
+            raise ValueError(
+                "Private joins between tables with the AddRowsWithID protected change"
+                " are only possible when both tables are in the same ID space"
+            )
+        join_id_space: Optional[str] = None
+        if left_schema.id_space and right_schema.id_space:
+            join_id_space = left_schema.id_space
+        return _schema_for_join(
+            left_schema=left_schema,
+            right_schema=right_schema,
+            join_columns=self.join_columns,
+            join_id_space=join_id_space,
+        )
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_join_private(self)
@@ -568,6 +957,32 @@ def __post_init__(self):
                 f"Invalid join type '{self.how}': must be 'inner' or 'left'"
             )
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr.
+
+        Has analogous behavior to :meth:`JoinPrivate.schema`, where the private
+        table is the left table.
+        """
+        input_schema = self.child.schema(catalog)
+        if isinstance(self.public_table, str):
+            public_table = catalog.tables[self.public_table]
+            if not isinstance(public_table, PublicTable):
+                raise ValueError(
+                    f"Attempted public join on table '{self.public_table}', "
+                    "which is not a public table"
+                )
+            right_schema = public_table.schema
+        else:
+            right_schema = Schema(
+                spark_schema_to_analytics_columns(self.public_table.schema)
+            )
+        return _schema_for_join(
+            left_schema=input_schema,
+            right_schema=right_schema,
+            join_columns=self.join_columns,
+            join_id_space=input_schema.id_space,
+            how=self.how,
+        )
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_join_public(self)
@@ -671,6 +1086,69 @@ def __post_init__(self):
             FrozenDict,
         )
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+        if (
+            input_schema.grouping_column
+            and input_schema.grouping_column in self.replace_with
+        ):
+            raise ValueError(
+                "Cannot replace null values in column "
+                f"'{input_schema.grouping_column}', as it is a grouping column."
+            )
+        if input_schema.id_column and input_schema.id_column in self.replace_with:
+            raise ValueError(
+                f"Cannot replace null values in column '{input_schema.id_column}', "
+                "as it is an ID column."
+            )
+        if input_schema.id_column and (len(self.replace_with) == 0):
+            raise RuntimeWarning(
+                f"Replacing null values in the ID column '{input_schema.id_column}' "
+                "is not allowed, so the ID column may still contain null values."
+            )
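+        # Explicitly-provided replacement values are checked against the
+        # column types below; as the one exception, an int may replace nulls
+        # in a DECIMAL column.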
+
+        if len(self.replace_with) != 0:
+            pytypes = analytics_to_py_types(input_schema)
+            for col, val in self.replace_with.items():
+                if col not in input_schema.keys():
+                    raise ValueError(
+                        f"Column '{col}' does not exist in this table, "
+                        f"available columns are {list(input_schema.keys())}"
+                    )
+                if not isinstance(val, pytypes[col]):
+                    # it's okay to use an int as a float
+                    # so don't raise an error in that case
+                    if not (isinstance(val, int) and pytypes[col] == float):
+                        raise ValueError(
+                            f"Column '{col}' cannot have nulls replaced with "
+                            f"{repr(val)}, as that value's type does not match the "
+                            f"column type {input_schema[col].column_type.name}"
+                        )
+
+        columns_to_change = list(dict(self.replace_with).keys())
+        if len(columns_to_change) == 0:
+            columns_to_change = [
+                col
+                for col in input_schema.column_descs.keys()
+                if (input_schema[col].allow_null or input_schema[col].allow_nan)
+                and not (col in [input_schema.grouping_column, input_schema.id_column])
+            ]
+        return Schema(
+            {
+                name: ColumnDescriptor(
+                    column_type=cd.column_type,
+                    allow_null=(cd.allow_null and not name in columns_to_change),
+                    allow_nan=(cd.allow_nan and not name in columns_to_change),
+                    allow_inf=cd.allow_inf,
+                )
+                for name, cd in input_schema.column_descs.items()
+            },
+            grouping_column=input_schema.grouping_column,
+            id_column=input_schema.id_column,
+            id_space=input_schema.id_space,
+        )
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_replace_null_and_nan(self)
@@ -710,6 +1188,61 @@ def __init__(
         object.__setattr__(self, "replace_with", FrozenDict.from_dict(updated_dict))
         object.__setattr__(self, "child", child)
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+
+        if (
+            input_schema.grouping_column
+            and input_schema.grouping_column in self.replace_with
+        ):
+            raise ValueError(
+                "Cannot replace infinite values in column "
+                f"'{input_schema.grouping_column}', as it is a grouping column"
+            )
+        # Float-valued columns cannot be ID columns, but include this to be safe.
+        if input_schema.id_column and input_schema.id_column in self.replace_with:
+            raise ValueError(
+                f"Cannot replace infinite values in column '{input_schema.id_column}', "
+                "as it is an ID column"
+            )
+
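+        # With no explicit columns, default to every DECIMAL column, since
+        # only DECIMAL columns can contain infinite values.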
+        columns_to_change = list(self.replace_with.keys())
+        if len(columns_to_change) == 0:
+            columns_to_change = [
+                col
+                for col in input_schema.column_descs.keys()
+                if input_schema[col].column_type == ColumnType.DECIMAL
+            ]
+        else:
+            for name in self.replace_with:
+                if name not in input_schema.keys():
+                    raise ValueError(
+                        f"Column '{name}' does not exist in this table, "
+                        f"available columns are {list(input_schema.keys())}"
+                    )
+                if input_schema[name].column_type != ColumnType.DECIMAL:
+                    raise ValueError(
+                        f"Column '{name}' has a replacement value provided, but it is "
+                        f"of type {input_schema[name].column_type.name} (not "
+                        f"{ColumnType.DECIMAL.name}) and so cannot "
+                        "contain infinite values"
+                    )
+        return Schema(
+            {
+                name: ColumnDescriptor(
+                    column_type=cd.column_type,
+                    allow_null=cd.allow_null,
+                    allow_nan=cd.allow_nan,
+                    allow_inf=(cd.allow_inf and not name in columns_to_change),
+                )
+                for name, cd in input_schema.column_descs.items()
+            },
+            grouping_column=input_schema.grouping_column,
+            id_column=input_schema.id_column,
+            id_space=input_schema.id_space,
+        )
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_replace_infinity(self)
@@ -740,6 +1273,57 @@ def __post_init__(self) -> None:
         check_type(self.child, QueryExpr)
         check_type(self.columns, Tuple[str, ...])
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+        if (
+            input_schema.grouping_column
+            and input_schema.grouping_column in self.columns
+        ):
+            raise ValueError(
+                f"Cannot drop null values in column '{input_schema.grouping_column}', "
+                "as it is a grouping column"
+            )
+        if input_schema.id_column and input_schema.id_column in self.columns:
+            raise ValueError(
+                f"Cannot drop null values in column '{input_schema.id_column}', "
+                "as it is an ID column."
+            )
+        if input_schema.id_column and len(self.columns) == 0:
+            raise RuntimeWarning(
+                f"Dropping null values in the ID column '{input_schema.id_column}' "
+                "is not allowed, so the ID column may still contain null values."
+            )
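+        # An empty column list means: drop nulls and NaNs from every column
+        # that can contain them, except the grouping and ID columns.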
+        columns = self.columns
+        if len(columns) == 0:
+            columns = tuple(
+                name
+                for name, cd in input_schema.column_descs.items()
+                if (cd.allow_null or cd.allow_nan)
+                and not name in [input_schema.grouping_column, input_schema.id_column]
+            )
+        else:
+            for name in columns:
+                if name not in input_schema.keys():
+                    raise ValueError(
+                        f"Column '{name}' does not exist in this table, "
+                        f"available columns are {list(input_schema.keys())}"
+                    )
+        return Schema(
+            {
+                name: ColumnDescriptor(
+                    column_type=cd.column_type,
+                    allow_null=(cd.allow_null and not name in columns),
+                    allow_nan=(cd.allow_nan and not name in columns),
+                    allow_inf=(cd.allow_inf),
+                )
+                for name, cd in input_schema.column_descs.items()
+            },
+            grouping_column=input_schema.grouping_column,
+            id_column=input_schema.id_column,
+            id_space=input_schema.id_space,
+        )
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_drop_null_and_nan(self)
@@ -763,6 +1347,63 @@ def __post_init__(self) -> None:
         check_type(self.child, QueryExpr)
         check_type(self.columns, Tuple[str, ...])
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+
+        if (
+            input_schema.grouping_column
+            and input_schema.grouping_column in self.columns
+        ):
+            raise ValueError(
+                "Cannot drop infinite values in column "
+                f"'{input_schema.grouping_column}', as it is a grouping column"
+            )
+        # Float-valued columns cannot be ID columns, but include this to be safe.
+        if input_schema.id_column and input_schema.id_column in self.columns:
+            raise ValueError(
+                f"Cannot drop infinite values in column '{input_schema.id_column}', "
+                "as it is an ID column"
+            )
+
+        columns = self.columns
+        if len(columns) == 0:
+            columns = tuple(
+                name
+                for name, cd in input_schema.column_descs.items()
+                if (cd.allow_inf) and not name == input_schema.grouping_column
+            )
+        else:
+            for name in columns:
+                if name not in input_schema.keys():
+                    raise ValueError(
+                        f"Column '{name}' does not exist in this table, "
+                        f"available columns are {list(input_schema.keys())}"
+                    )
+                if input_schema[name].column_type != ColumnType.DECIMAL:
+                    raise ValueError(
+                        f"Column '{name}' was given as a column to drop "
+                        "infinite values from, but it is of type "
+                        f"{input_schema[name].column_type.name} (not "
+                        f"{ColumnType.DECIMAL.name}) and so cannot "
+                        "contain infinite values"
+                    )
+
+        return Schema(
+            {
+                name: ColumnDescriptor(
+                    column_type=cd.column_type,
+                    allow_null=cd.allow_null,
+                    allow_nan=cd.allow_nan,
+                    allow_inf=(cd.allow_inf and not name in columns),
+                )
+                for name, cd in input_schema.column_descs.items()
+            },
+            grouping_column=input_schema.grouping_column,
+            id_column=input_schema.id_column,
+            id_space=input_schema.id_space,
+        )
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_drop_infinity(self)
@@ -782,10 +1423,163 @@ class EnforceConstraint(QueryExpr):
         Appropriate values here vary depending on the constraint. These options
         are to support advanced use cases, and generally should not be used."""
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+
+        if not input_schema.id_column:
+            raise ValueError(
+                f"Constraint {self.constraint} can only be applied to tables"
+                " with the AddRowsWithID protected change"
+            )
+        if isinstance(self.constraint, (MaxGroupsPerID, MaxRowsPerGroupPerID)):
+            grouping_column = self.constraint.grouping_column
+            if grouping_column not in input_schema:
+                raise ValueError(
+                    f"The grouping column of constraint {self.constraint}"
+                    " does not exist in this table; available columns"
+                    f" are: {', '.join(input_schema.keys())}"
+                )
+            if grouping_column == input_schema.id_column:
+                raise ValueError(
+                    f"The grouping column of constraint {self.constraint} cannot be"
+                    " the ID column of the table it is applied to"
+                )
+        return input_schema
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_enforce_constraint(self)
 
+
+def _schema_for_groupby(
+    query: Union[
+        "GetBounds",
+        "GroupByBoundedAverage",
+        "GroupByBoundedSTDEV",
+        "GroupByBoundedSum",
+        "GroupByBoundedVariance",
+        "GroupByCount",
+        "GroupByCountDistinct",
+        "GroupByQuantile",
+    ],
+    catalog: Catalog,
+) -> Schema:
+    """Validates and returns the schema of a group-by QueryExpr.
+
+    Args:
+        query: Query expression to be validated.
+        catalog: The catalog.
+
+    Returns:
+        Output schema of current QueryExpr
+    """
+    input_schema = query.child.schema(catalog)
+
+    # Validating group-by columns
+    if isinstance(query.groupby_keys, KeySet):
+        # Checks that the KeySet is valid
+        schema = query.groupby_keys.schema()
+        groupby_columns: Collection[str] = schema.keys()
+
+        for column_name, column_desc in schema.items():
+            try:
+                input_column_desc = input_schema[column_name]
+            except KeyError as e:
+                raise KeyError(
+                    f"Groupby column '{column_name}' is not in the input schema."
+                ) from e
+            if column_desc.column_type != input_column_desc.column_type:
+                raise ValueError(
+                    f"Groupby column '{column_name}' has type"
+                    f" '{column_desc.column_type.name}', but the column with the same "
+                    f"name in the input data has type "
+                    f"'{input_column_desc.column_type.name}' instead."
+                )
+    elif isinstance(query.groupby_keys, tuple):
+        # Checks that the listed groupby columns exist in the schema
+        for col in query.groupby_keys:
+            if col not in input_schema:
+                raise ValueError(f"Groupby column '{col}' is not in the input schema.")
+        groupby_columns = query.groupby_keys
+    else:
+        raise AnalyticsInternalError(
+            f"Unexpected groupby_keys type: {type(query.groupby_keys)}."
+        )
+
+    # Validating compatibility between grouping columns and group-by columns
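+    # (A grouping column created by a grouping transformation, such as a
+    # grouping flat map, must appear in the group-by columns of the
+    # aggregation; aggregating over a grouped column is rejected as it is
+    # almost certainly a user error.)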
+    grouping_column = input_schema.grouping_column
+    if grouping_column is not None and grouping_column not in groupby_columns:
+        raise ValueError(
+            f"Column '{grouping_column}' produced by grouping transformation "
+            f"is not in groupby columns {list(groupby_columns)}."
+        )
+    if (
+        not isinstance(query, (GroupByCount, GroupByCountDistinct))
+        and query.measure_column in groupby_columns
+    ):
+        raise ValueError(
+            "Column to aggregate must be a non-grouped column, not "
+            f"'{query.measure_column}'."
+        )
+
+    # Validating the measure column
+    if isinstance(query, (GetBounds, GroupByQuantile, GroupByBoundedSum,
+                          GroupByBoundedSTDEV, GroupByBoundedAverage,
+                          GroupByBoundedVariance)):
+        if query.measure_column not in input_schema:
+            raise ValueError(
+                f"{type(query).__name__} query's measure column "
+                f"'{query.measure_column}' does not exist in the table."
+            )
+        if input_schema[query.measure_column].column_type not in [
+            ColumnType.INTEGER,
+            ColumnType.DECIMAL,
+        ]:
+            raise ValueError(
+                f"{type(query).__name__} query's measure column "
+                f"'{query.measure_column}' has invalid type "
+                f"'{input_schema[query.measure_column].column_type.name}'. "
+                "Expected types: 'INTEGER' or 'DECIMAL'."
+            )
+        if input_schema.id_column and (input_schema.id_column == query.measure_column):
+            raise ValueError(
+                f"{type(query).__name__} query's measure column is the same as the "
+                f"privacy ID column ({input_schema.id_column}) on a table with the "
+                "AddRowsWithID protected change."
+            )
+
+    # Determining the output column types & names
+    if isinstance(query, (GroupByCount, GroupByCountDistinct)):
+        output_column_type = ColumnType.INTEGER
+    elif isinstance(query, (GetBounds, GroupByBoundedSum)):
+        output_column_type = input_schema[query.measure_column].column_type
+    elif isinstance(query, (GroupByQuantile, GroupByBoundedSTDEV,
+                            GroupByBoundedAverage, GroupByBoundedVariance)):
+        output_column_type = ColumnType.DECIMAL
+    else:
+        raise AnalyticsInternalError(f"Unexpected QueryExpr type: {type(query)}.")
+    if isinstance(query, GetBounds):
+        output_columns = {
+            query.lower_bound_column: ColumnDescriptor(
+                output_column_type, allow_null=False
+            ),
+            query.upper_bound_column: ColumnDescriptor(
+                output_column_type, allow_null=False
+            ),
+        }
+    else:
+        output_columns = {
+            query.output_column: ColumnDescriptor(
+                output_column_type, allow_null=False
+            ),
+        }
+
+    return Schema(
+        {
+            **{column: input_schema[column] for column in groupby_columns},
+            **output_columns
+        },
+        grouping_column=None,
+        id_column=None,
+    )
+
+
 @dataclass(frozen=True)
 class GroupByCount(QueryExpr):
     """Returns the count of each combination of the groupby domains.
@@ -813,6 +1607,10 @@ def __post_init__(self):
         check_type(self.output_column, str)
         check_type(self.mechanism, CountMechanism)
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        return _schema_for_groupby(self, catalog)
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_groupby_count(self)
@@ -849,6 +1647,10 @@ def __post_init__(self):
         check_type(self.output_column, str)
         check_type(self.mechanism, CountDistinctMechanism)
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        return _schema_for_groupby(self, catalog)
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_groupby_count_distinct(self)
@@ -910,6 +1712,10 @@ def __post_init__(self):
             f"the upper bound '{self.high}'."
         )
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        return _schema_for_groupby(self, catalog)
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_groupby_quantile(self)
@@ -971,6 +1777,10 @@ def __post_init__(self):
             f"the upper bound '{self.high}'."
         )
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        return _schema_for_groupby(self, catalog)
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_groupby_bounded_sum(self)
@@ -1032,6 +1842,10 @@ def __post_init__(self):
             f"the upper bound '{self.high}'."
         )
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        return _schema_for_groupby(self, catalog)
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_groupby_bounded_average(self)
@@ -1093,6 +1907,10 @@ def __post_init__(self):
             f"the upper bound '{self.high}'."
         )
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        return _schema_for_groupby(self, catalog)
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_groupby_bounded_variance(self)
@@ -1155,6 +1973,10 @@ def __post_init__(self):
             f"the upper bound '{self.high}'."
         )
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        return _schema_for_groupby(self, catalog)
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_groupby_bounded_stdev(self)
@@ -1187,11 +2009,16 @@ def __post_init__(self) -> None:
         check_type(self.column, str)
         check_type(self.threshold, float)
 
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the resulting schema from evaluating this QueryExpr."""
+        return self.child.schema(catalog)
+
     def accept(self, visitor: "QueryExprVisitor") -> Any:
         """Visit this QueryExpr with visitor."""
         return visitor.visit_suppress_aggregates(self)
 
 
+
 class QueryExprVisitor(ABC):
     """A base class for implementing visitors for :class:`QueryExpr`."""
 
@@ -1314,3 +2141,4 @@ def visit_groupby_bounded_stdev(self, expr: GroupByBoundedSTDEV) -> Any:
     def visit_suppress_aggregates(self, expr: SuppressAggregates) -> Any:
         """Visit a :class:`SuppressAggregates`."""
         raise NotImplementedError
+
diff --git a/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py b/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py
index 19dee7d3..4849e4bc 100644
--- a/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py
+++ b/src/tmlt/analytics/_query_expr_compiler/_base_measurement_visitor.py
@@ -101,9 +101,6 @@
     SuppressAggregates,
     VarianceMechanism,
 )
-from tmlt.analytics._query_expr_compiler._output_schema_visitor import (
-    OutputSchemaVisitor,
-)
 from tmlt.analytics._schema import ColumnType, FrozenDict, Schema
 from tmlt.analytics._table_identifier import Identifier
 from tmlt.analytics._table_reference import TableReference
@@ -708,7 +705,7 @@ def _pick_noise_for_non_count(
         GroupByQuantile and GetBounds only supports one noise mechanism, so it
         is not included here.
         """
-        measure_column_type = query.child.accept(OutputSchemaVisitor(self.catalog))[
+        measure_column_type = query.child.schema(self.catalog)[
             query.measure_column
         ].column_type
         requested_mechanism: NoiseMechanism
@@ -802,7 +799,7 @@ def _add_special_value_handling_to_query(
         These changes are added immediately before the groupby aggregation in
         the query.
""" - expected_schema = query.child.accept(OutputSchemaVisitor(self.catalog)) + expected_schema = query.child.schema(self.catalog) # You can't perform these queries on nulls, NaNs, or infinite values # so check for those @@ -1046,7 +1043,7 @@ def visit_groupby_count(self, expr: GroupByCount) -> Tuple[Measurement, NoiseInf self._validate_approxDP_and_adjust_budget(expr) # Peek at the schema, to see if there are errors there - expr.accept(OutputSchemaVisitor(self.catalog)) + expr.schema(self.catalog) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) @@ -1130,7 +1127,7 @@ def visit_groupby_count_distinct( self._validate_approxDP_and_adjust_budget(expr) # Peek at the schema, to see if there are errors there - expr.accept(OutputSchemaVisitor(self.catalog)) + expr.schema(self.catalog) if isinstance(expr.groupby_keys, KeySet): groupby_cols = tuple(expr.groupby_keys.dataframe().columns) @@ -1150,7 +1147,7 @@ def visit_groupby_count_distinct( ) = self._visit_child_transformation(expr.child, mechanism) constrained_query = _generate_constrained_count_distinct( expr, - expr.child.accept(OutputSchemaVisitor(self.catalog)), + expr.child.schema(self.catalog), child_constraints, ) if constrained_query is not None: @@ -1251,7 +1248,7 @@ def visit_groupby_quantile( self._validate_approxDP_and_adjust_budget(expr) # Peek at the schema, to see if there are errors there - expr.accept(OutputSchemaVisitor(self.catalog)) + expr.schema(self.catalog) expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): @@ -1264,7 +1261,7 @@ def visit_groupby_quantile( self.adjusted_budget ) # Peek at the schema, to see if there are errors there - expr.accept(OutputSchemaVisitor(self.catalog)) + expr.schema(self.catalog) child_transformation, child_ref = self._truncate_table( *self._visit_child_transformation(expr.child, self.default_mechanism), @@ -1346,7 +1343,7 @@ def visit_groupby_bounded_sum( self._validate_approxDP_and_adjust_budget(expr) # Peek at the schema, to see if there are errors there - expr.accept(OutputSchemaVisitor(self.catalog)) + expr.schema(self.catalog) expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): @@ -1442,7 +1439,7 @@ def visit_groupby_bounded_average( self._validate_approxDP_and_adjust_budget(expr) # Peek at the schema, to see if there are errors there - expr.accept(OutputSchemaVisitor(self.catalog)) + expr.schema(self.catalog) expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): @@ -1538,7 +1535,7 @@ def visit_groupby_bounded_variance( self._validate_approxDP_and_adjust_budget(expr) # Peek at the schema, to see if there are errors there - expr.accept(OutputSchemaVisitor(self.catalog)) + expr.schema(self.catalog) expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): @@ -1634,7 +1631,7 @@ def visit_groupby_bounded_stdev( self._validate_approxDP_and_adjust_budget(expr) # Peek at the schema, to see if there are errors there - expr.accept(OutputSchemaVisitor(self.catalog)) + expr.schema(self.catalog) expr = self._add_special_value_handling_to_query(expr) if isinstance(expr.groupby_keys, KeySet): @@ -1726,7 +1723,7 @@ def visit_get_bounds(self, expr: GetBounds) -> Tuple[Measurement, NoiseInfo]: self._validate_approxDP_and_adjust_budget(expr) # Peek at the schema, to see if there are errors there - expr.accept(OutputSchemaVisitor(self.catalog)) + expr.schema(self.catalog) expr = 
         expr = self._add_special_value_handling_to_query(expr)
 
         if isinstance(expr.groupby_keys, KeySet):
@@ -1740,7 +1737,7 @@ def visit_get_bounds(self, expr: GetBounds) -> Tuple[Measurement, NoiseInfo]:
         )
 
         # Peek at the schema, to see if there are errors there
-        expr.accept(OutputSchemaVisitor(self.catalog))
+        expr.schema(self.catalog)
 
         child_transformation, child_ref = self._truncate_table(
             *self._visit_child_transformation(expr.child, NoiseMechanism.GEOMETRIC),
@@ -1823,7 +1820,7 @@ def visit_suppress_aggregates(
         self, expr: SuppressAggregates
     ) -> Tuple[Measurement, NoiseInfo]:
         """Create a measurement from a SuppressAggregates query expression."""
-        expr.accept(OutputSchemaVisitor(self.catalog))
+        expr.schema(self.catalog)
 
         child_measurement, noise_info = expr.child.accept(self)
         if not isinstance(child_measurement, Measurement):
diff --git a/src/tmlt/analytics/_query_expr_compiler/_base_transformation_visitor.py b/src/tmlt/analytics/_query_expr_compiler/_base_transformation_visitor.py
index 599530af..1180d04c 100644
--- a/src/tmlt/analytics/_query_expr_compiler/_base_transformation_visitor.py
+++ b/src/tmlt/analytics/_query_expr_compiler/_base_transformation_visitor.py
@@ -157,9 +157,6 @@
     propagate_select,
     propagate_unmodified,
 )
-from tmlt.analytics._query_expr_compiler._output_schema_visitor import (
-    OutputSchemaVisitor,
-)
 from tmlt.analytics._schema import (
     ColumnDescriptor,
     ColumnType,
@@ -245,7 +242,7 @@ def validate_transformation(
     catalog: Catalog,
 ):
     """Ensure that a query's transformation is valid on a given catalog."""
-    expected_schema = query.accept(OutputSchemaVisitor(catalog))
+    expected_schema = query.schema(catalog)
     expected_output_domain = SparkDataFrameDomain(
         analytics_to_spark_columns_descriptor(expected_schema)
     )
diff --git a/src/tmlt/analytics/_query_expr_compiler/_compiler.py b/src/tmlt/analytics/_query_expr_compiler/_compiler.py
index 01d0fffe..c1f8077c 100644
--- a/src/tmlt/analytics/_query_expr_compiler/_compiler.py
+++ b/src/tmlt/analytics/_query_expr_compiler/_compiler.py
@@ -22,9 +22,6 @@
 from tmlt.analytics._noise_info import NoiseInfo
 from tmlt.analytics._query_expr import QueryExpr
 from tmlt.analytics._query_expr_compiler._measurement_visitor import MeasurementVisitor
-from tmlt.analytics._query_expr_compiler._output_schema_visitor import (
-    OutputSchemaVisitor,
-)
 from tmlt.analytics._query_expr_compiler._transformation_visitor import (
     TransformationVisitor,
 )
@@ -108,13 +105,13 @@ def output_measure(self) -> Union[PureDP, ApproxDP, RhoZCDP]:
     @staticmethod
     def query_schema(query: QueryExpr, catalog: Catalog) -> Schema:
         """Return the schema created by a given query."""
-        result = query.accept(OutputSchemaVisitor(catalog=catalog))
-        if not isinstance(result, Schema):
+        schema = query.schema(catalog)
+        if not isinstance(schema, Schema):
             raise AnalyticsInternalError(
                 "Schema for this query is not a Schema but is instead a(n) "
-                f"{type(result)}."
+                f"{type(schema)}."
             )
-        return result
+        return schema
 
     def __call__(
         self,
@@ -207,7 +204,7 @@ def build_transformation(
             catalog: The catalog, used only for query validation.
             table_constraints: A mapping of tables to the existing constraints on them.
""" - query.accept(OutputSchemaVisitor(catalog)) + query.schema(catalog) transformation_visitor = TransformationVisitor( input_domain=input_domain, diff --git a/src/tmlt/analytics/_query_expr_compiler/_measurement_visitor.py b/src/tmlt/analytics/_query_expr_compiler/_measurement_visitor.py index 5118a1e0..43e6c60e 100644 --- a/src/tmlt/analytics/_query_expr_compiler/_measurement_visitor.py +++ b/src/tmlt/analytics/_query_expr_compiler/_measurement_visitor.py @@ -26,9 +26,6 @@ from tmlt.analytics._query_expr_compiler._base_measurement_visitor import ( BaseMeasurementVisitor, ) -from tmlt.analytics._query_expr_compiler._output_schema_visitor import ( - OutputSchemaVisitor, -) from tmlt.analytics._query_expr_compiler._transformation_visitor import ( TransformationVisitor, ) @@ -75,10 +72,7 @@ def visit_get_groups(self, expr: GetGroups) -> Tuple[Measurement, NoiseInfo]: if not isinstance(self.budget, ApproxDPBudget): raise ValueError("GetGroups is only supported with ApproxDPBudgets.") - # Peek at the schema, to see if there are errors there - expr.accept(OutputSchemaVisitor(self.catalog)) - - schema = expr.child.accept(OutputSchemaVisitor(self.catalog)) + schema = expr.child.schema(self.catalog) # Set the columns if no columns were provided. if expr.columns: diff --git a/src/tmlt/analytics/_query_expr_compiler/_output_schema_visitor.py b/src/tmlt/analytics/_query_expr_compiler/_output_schema_visitor.py deleted file mode 100644 index 51f1ad9d..00000000 --- a/src/tmlt/analytics/_query_expr_compiler/_output_schema_visitor.py +++ /dev/null @@ -1,933 +0,0 @@ -"""Defines a visitor for determining the output schemas of query expressions.""" - -# SPDX-License-Identifier: Apache-2.0 -# Copyright Tumult Labs 2025 - -from collections.abc import Collection -from dataclasses import replace -from typing import Optional, Tuple, Union - -from pyspark.sql import SparkSession -from tmlt.core.domains.spark_domains import SparkDataFrameDomain -from tmlt.core.utils.join import domain_after_join - -from tmlt.analytics import AnalyticsInternalError -from tmlt.analytics._catalog import Catalog, PrivateTable, PublicTable -from tmlt.analytics._query_expr import ( - DropInfinity, - DropNullAndNan, - EnforceConstraint, - Filter, - FlatMap, - FlatMapByID, - GetBounds, - GetGroups, - GroupByBoundedAverage, - GroupByBoundedSTDEV, - GroupByBoundedSum, - GroupByBoundedVariance, - GroupByCount, - GroupByCountDistinct, - GroupByQuantile, - JoinPrivate, - JoinPublic, - Map, - PrivateSource, - QueryExprVisitor, - Rename, - ReplaceInfinity, - ReplaceNullAndNan, - Select, - SuppressAggregates, -) -from tmlt.analytics._schema import ( - ColumnDescriptor, - ColumnType, - Schema, - analytics_to_py_types, - analytics_to_spark_columns_descriptor, - analytics_to_spark_schema, - spark_schema_to_analytics_columns, -) -from tmlt.analytics.constraints import MaxGroupsPerID, MaxRowsPerGroupPerID -from tmlt.analytics.keyset import KeySet - - -def _output_schema_for_join( - left_schema: Schema, - right_schema: Schema, - join_columns: Optional[Tuple[str, ...]], - join_id_space: Optional[str] = None, - how: str = "inner", -) -> Schema: - """Return the resulting schema from joining two tables. - - It is assumed that if either schema has an ID column, the one from - left_schema should be used. This is because the appropriate behavior here - depends on the type of join being performed, so checks for compatibility of - ID columns must happen outside this function. - - Args: - left_schema: Schema for the left table. 
-        right_schema: Schema for the right table.
-        join_columns: The set of columns to join on.
-        join_id_space: The ID space of the resulting join.
-        how: The type of join to perform. Default is "inner".
-    """
-    if left_schema.grouping_column is None:
-        grouping_column = right_schema.grouping_column
-    elif right_schema.grouping_column is None:
-        grouping_column = left_schema.grouping_column
-    elif left_schema.grouping_column == right_schema.grouping_column:
-        grouping_column = left_schema.grouping_column
-    else:
-        raise ValueError(
-            "Joining tables which both have grouping columns is only supported "
-            "if they have the same grouping column"
-        )
-    common_columns = set(left_schema) & set(right_schema)
-    if join_columns is None and not common_columns:
-        raise ValueError("Tables have no common columns to join on")
-    if join_columns is not None and not join_columns:
-        # This error case should be caught when constructing the query
-        # expression, so it should never get here.
-        raise AnalyticsInternalError("Empty list of join columns provided.")
-
-    join_columns = (
-        join_columns
-        if join_columns
-        else tuple(sorted(common_columns, key=list(left_schema).index))
-    )
-
-    if not set(join_columns) <= common_columns:
-        raise ValueError("Join columns must be common to both tables")
-
-    for column in join_columns:
-        if left_schema[column].column_type != right_schema[column].column_type:
-            raise ValueError(
-                "Join columns must have identical types on both tables, "
-                f"but column '{column}' does not: {left_schema[column]} and "
-                f"{right_schema[column]} are incompatible"
-            )
-
-    join_column_schemas = {column: left_schema[column] for column in join_columns}
-    output_schema = {
-        **join_column_schemas,
-        **{
-            column + ("_left" if column in common_columns else ""): left_schema[column]
-            for column in left_schema
-            if column not in join_columns
-        },
-        **{
-            column
-            + ("_right" if column in common_columns else ""): right_schema[column]
-            for column in right_schema
-            if column not in join_columns
-        },
-    }
-    # Use Core's join utilities for determining whether a column can be null
-    # TODO: This could potentially be used more in this function
-    output_domain = domain_after_join(
-        left_domain=SparkDataFrameDomain(
-            analytics_to_spark_columns_descriptor(left_schema)
-        ),
-        right_domain=SparkDataFrameDomain(
-            analytics_to_spark_columns_descriptor(right_schema)
-        ),
-        on=list(join_columns),
-        how=how,
-        nulls_are_equal=True,
-    )
-    for column in output_schema:
-        col_schema = output_schema[column]
-        output_schema[column] = ColumnDescriptor(
-            column_type=col_schema.column_type,
-            allow_null=output_domain.schema[column].allow_null,
-            allow_nan=col_schema.allow_nan,
-            allow_inf=col_schema.allow_inf,
-        )
-    return Schema(
-        output_schema,
-        grouping_column=grouping_column,
-        id_column=left_schema.id_column,
-        id_space=join_id_space,
-    )
-
-
-def _validate_groupby(
-    query: Union[
-        GroupByBoundedAverage,
-        GroupByBoundedSTDEV,
-        GroupByBoundedSum,
-        GroupByBoundedVariance,
-        GroupByCount,
-        GroupByCountDistinct,
-        GroupByQuantile,
-        GetBounds,
-    ],
-    output_schema_visitor: "OutputSchemaVisitor",
-) -> Schema:
-    """Validate groupby aggregate query.
-
-    Args:
-        query: Query expression to be validated.
-        output_schema_visitor: A visitor to get the output schema of an expression.
-
-    Returns:
-        Output schema of current QueryExpr
-    """
-    input_schema = query.child.accept(output_schema_visitor)
-
-    if isinstance(query.groupby_keys, KeySet):
-        # Checks that the KeySet is valid
-        schema = query.groupby_keys.schema()
-        groupby_columns: Collection[str] = schema.keys()
-
-        for column_name, column_desc in schema.items():
-            try:
-                input_column_desc = input_schema[column_name]
-            except KeyError as e:
-                raise KeyError(
-                    f"Groupby column '{column_name}' is not in the input schema."
-                ) from e
-            if column_desc.column_type != input_column_desc.column_type:
-                raise ValueError(
-                    f"Groupby column '{column_name}' has type"
-                    f" '{column_desc.column_type.name}', but the column with the same "
-                    f"name in the input data has type "
-                    f"'{input_column_desc.column_type.name}' instead."
-                )
-    elif isinstance(query.groupby_keys, tuple):
-        # Checks that the listed groupby columns exist in the schema
-        for col in query.groupby_keys:
-            if col not in input_schema:
-                raise ValueError(f"Groupby column '{col}' is not in the input schema.")
-        groupby_columns = query.groupby_keys
-    else:
-        raise AnalyticsInternalError(
-            f"Unexpected groupby_keys type: {type(query.groupby_keys)}."
-        )
-
-    grouping_column = input_schema.grouping_column
-    if grouping_column is not None and grouping_column not in groupby_columns:
-        raise ValueError(
-            f"Column '{grouping_column}' produced by grouping transformation "
-            f"is not in groupby columns {list(groupby_columns)}."
-        )
-    if (
-        not isinstance(query, (GroupByCount, GroupByCountDistinct))
-        and query.measure_column in groupby_columns
-    ):
-        raise ValueError(
-            "Column to aggregate must be a non-grouped column, not "
-            f"'{query.measure_column}'."
-        )
-
-    if isinstance(query, (GroupByCount, GroupByCountDistinct)):
-        output_column_type = ColumnType.INTEGER
-    elif isinstance(query, GetBounds):
-        # Measure column type check not needed, since we check it early in
-        # OutputSchemaVisitor.visit_get_bounds
-        output_column_type = input_schema[query.measure_column].column_type
-    elif isinstance(query, GroupByQuantile):
-        if input_schema[query.measure_column].column_type not in [
-            ColumnType.INTEGER,
-            ColumnType.DECIMAL,
-        ]:
-            raise ValueError(
-                f"Quantile query's measure column '{query.measure_column}' has invalid"
-                f" type '{input_schema[query.measure_column].column_type.name}'."
-                " Expected types: 'INTEGER' or 'DECIMAL'."
-            )
-        output_column_type = ColumnType.DECIMAL
-    elif isinstance(
-        query,
-        (
-            GroupByBoundedSum,
-            GroupByBoundedSTDEV,
-            GroupByBoundedAverage,
-            GroupByBoundedVariance,
-        ),
-    ):
-        if input_schema[query.measure_column].column_type not in [
-            ColumnType.INTEGER,
-            ColumnType.DECIMAL,
-        ]:
-            raise ValueError(
-                f"{type(query).__name__} query's measure column "
-                f"'{query.measure_column}' has invalid type "
-                f"'{input_schema[query.measure_column].column_type.name}'. "
-                "Expected types: 'INTEGER' or 'DECIMAL'."
-            )
-        output_column_type = (
-            input_schema[query.measure_column].column_type
-            if isinstance(query, GroupByBoundedSum)
-            else ColumnType.DECIMAL
-        )
-    else:
-        raise AssertionError(
-            "Unexpected QueryExpr type. This should not happen and is"
-            "probably a bug; please let us know so we can fix it!"
-        )
-    if isinstance(query, GetBounds):
-        output_schema = Schema(
-            {
-                **{column: input_schema[column] for column in groupby_columns},
-                **{
-                    query.lower_bound_column: ColumnDescriptor(
-                        output_column_type, allow_null=False
-                    )
-                },
-                **{
-                    query.upper_bound_column: ColumnDescriptor(
-                        output_column_type, allow_null=False
-                    )
-                },
-            },
-            grouping_column=None,
-            id_column=None,
-        )
-    else:
-        output_schema = Schema(
-            {
-                **{column: input_schema[column] for column in groupby_columns},
-                **{
-                    query.output_column: ColumnDescriptor(
-                        output_column_type, allow_null=False
-                    )
-                },
-            },
-            grouping_column=None,
-            id_column=None,
-        )
-    return output_schema
-
-
-class OutputSchemaVisitor(QueryExprVisitor):
-    """A visitor to get the output schema of a query expression."""
-
-    def __init__(self, catalog: Catalog):
-        """Visitor constructor.
-
-        Args:
-            catalog: The catalog defining schemas and relations between tables.
-        """
-        self._catalog = catalog
-
-    def visit_private_source(self, expr: PrivateSource) -> Schema:
-        """Return the resulting schema from evaluating a PrivateSource."""
-        if expr.source_id not in self._catalog.tables:
-            raise ValueError(f"Query references nonexistent table '{expr.source_id}'")
-        table = self._catalog.tables[expr.source_id]
-        if not isinstance(table, PrivateTable):
-            raise ValueError(
-                f"Attempted query on table '{expr.source_id}', which is "
-                "not a private table"
-            )
-        return table.schema
-
-    def visit_rename(self, expr: Rename) -> Schema:
-        """Returns the resulting schema from evaluating a Rename."""
-        input_schema = expr.child.accept(self)
-        grouping_column = input_schema.grouping_column
-        id_column = input_schema.id_column
-        id_space = input_schema.id_space
-        nonexistent_columns = set(expr.column_mapper) - set(input_schema)
-        if nonexistent_columns:
-            raise ValueError(
-                f"Nonexistent columns in rename query: {nonexistent_columns}"
-            )
-        for old, new in expr.column_mapper.items():
-            if new in input_schema and new != old:
-                raise ValueError(
-                    f"Cannot rename '{old}' to '{new}': column '{new}' already exists"
-                )
-            if old == grouping_column:
-                grouping_column = new
-            if old == id_column:
-                id_column = new
-
-        return Schema(
-            {
-                expr.column_mapper.get(column, column): input_schema[column]
-                for column in input_schema
-            },
-            grouping_column=grouping_column,
-            id_column=id_column,
-            id_space=id_space,
-        )
-
-    def visit_filter(self, expr: Filter) -> Schema:
-        """Returns the resulting schema from evaluating a Filter."""
-        input_schema = expr.child.accept(self)
-        spark = SparkSession.builder.getOrCreate()
-        test_df = spark.createDataFrame(
-            [], schema=analytics_to_spark_schema(input_schema)
-        )
-        try:
-            test_df.filter(expr.condition)
-        except Exception as e:
-            raise ValueError(f"Invalid filter condition '{expr.condition}': {e}") from e
-        return input_schema
-
-    def visit_select(self, expr: Select) -> Schema:
-        """Returns the resulting schema from evaluating a Select."""
-        input_schema = expr.child.accept(self)
-
-        grouping_column = input_schema.grouping_column
-        id_column = input_schema.id_column
-        if grouping_column is not None and grouping_column not in expr.columns:
-            raise ValueError(
-                f"Grouping column '{grouping_column}' may not "
-                "be dropped by select query"
-            )
-        if id_column is not None and id_column not in expr.columns:
-            raise ValueError(
-                f"ID column '{id_column}' may not be dropped by select query"
-            )
-
-        nonexistent_columns = set(expr.columns) - set(input_schema)
-        if nonexistent_columns:
-            raise ValueError(
{nonexistent_columns}" - ) - - return Schema( - {column: input_schema[column] for column in expr.columns}, - grouping_column=grouping_column, - id_column=id_column, - id_space=input_schema.id_space, - ) - - def visit_map(self, expr: Map) -> Schema: - """Returns the resulting schema from evaluating a Map.""" - input_schema = expr.child.accept(self) - new_columns = expr.schema_new_columns.column_descs - # Any column created by Map could contain a null value - for name in list(new_columns.keys()): - new_columns[name] = replace(new_columns[name], allow_null=True) - - if expr.augment: - overlapping_columns = set(input_schema.keys()) & set(new_columns.keys()) - if overlapping_columns: - raise ValueError( - "New columns in augmenting map must not overwrite " - "existing columns, but found new columns that " - f"already exist: {', '.join(overlapping_columns)}" - ) - return Schema( - {**input_schema, **new_columns}, - grouping_column=input_schema.grouping_column, - id_column=input_schema.id_column, - id_space=input_schema.id_space, - ) - elif input_schema.grouping_column: - raise ValueError( - "Map must set augment=True to ensure that " - f"grouping column '{input_schema.grouping_column}' is not lost." - ) - elif input_schema.id_column: - raise ValueError( - "Map must set augment=True to ensure that " - f"ID column '{input_schema.id_column}' is not lost." - ) - return Schema( - new_columns, - grouping_column=expr.schema_new_columns.grouping_column, - id_column=expr.schema_new_columns.id_column, - id_space=expr.schema_new_columns.id_space, - ) - - def visit_flat_map(self, expr: FlatMap) -> Schema: - """Returns the resulting schema from evaluating a FlatMap.""" - input_schema = expr.child.accept(self) - if expr.schema_new_columns.grouping_column is not None: - if input_schema.grouping_column: - raise ValueError( - "Multiple grouping transformations are used in this query. " - "Only one grouping transformation is allowed." - ) - if input_schema.id_column: - raise ValueError( - "Grouping flat map cannot be used on tables with " - "the AddRowsWithID protected change." - ) - grouping_column = expr.schema_new_columns.grouping_column - else: - grouping_column = input_schema.grouping_column - - new_columns = expr.schema_new_columns.column_descs - # Any column created by the FlatMap could contain a null value - for name in list(new_columns.keys()): - new_columns[name] = replace(new_columns[name], allow_null=True) - if expr.augment: - overlapping_columns = set(input_schema.keys()) & set(new_columns.keys()) - if overlapping_columns: - raise ValueError( - "New columns in augmenting map must not overwrite " - "existing columns, but found new columns that " - f"already exist: {', '.join(overlapping_columns)}" - ) - return Schema( - {**input_schema, **new_columns}, - grouping_column=grouping_column, - id_column=input_schema.id_column, - id_space=input_schema.id_space, - ) - elif input_schema.grouping_column: - raise ValueError( - "Flat map must set augment=True to ensure that " - f"grouping column '{input_schema.grouping_column}' is not lost." - ) - elif input_schema.id_column: - raise ValueError( - "Flat map must set augment=True to ensure that " - f"ID column '{input_schema.id_column}' is not lost." 
- ) - - return Schema( - new_columns, - grouping_column=grouping_column, - id_column=expr.schema_new_columns.id_column, - id_space=expr.schema_new_columns.id_space, - ) - - def visit_flat_map_by_id(self, expr: FlatMapByID) -> Schema: - """Returns the resulting schema from evaluating a FlatMapByID.""" - input_schema = expr.child.accept(self) - id_column = input_schema.id_column - new_columns = expr.schema_new_columns.column_descs - - if not id_column: - raise ValueError( - "Flat-map-by-ID may only be used on tables with ID columns." - ) - if input_schema.grouping_column: - raise AnalyticsInternalError( - "Encountered table with both an ID column and a grouping column." - ) - if id_column in new_columns: - raise ValueError( - "Flat-map-by-ID mapping function output cannot include ID column." - ) - - for name in list(new_columns.keys()): - new_columns[name] = replace(new_columns[name], allow_null=True) - return Schema( - {id_column: input_schema[id_column], **new_columns}, - grouping_column=None, - id_column=id_column, - id_space=input_schema.id_space, - ) - - def visit_join_private(self, expr: JoinPrivate) -> Schema: - """Returns the resulting schema from evaluating a JoinPrivate. - - The ordering of output columns are: - - 1. The join columns - 2. Columns that are only in the left table - 3. Columns that are only in the right table - 4. Columns that are in both tables, but not included in the join columns. These - columns are included with _left and _right suffixes. - """ - left_schema = expr.child.accept(self) - right_schema = expr.right_operand_expr.accept(self) - if left_schema.id_column != right_schema.id_column: - if left_schema.id_column is None or right_schema.id_column is None: - raise ValueError( - "Private joins can only be performed between two tables " - "with the same type of protected change" - ) - raise ValueError( - "Private joins between tables with the AddRowsWithID " - "protected change are only possible when the ID columns of " - "the two tables have the same name" - ) - if ( - left_schema.id_space - and right_schema.id_space - and left_schema.id_space != right_schema.id_space - ): - raise ValueError( - "Private joins between tables with the AddRowsWithID protected change" - " are only possible when both tables are in the same ID space" - ) - join_id_space: Optional[str] = None - if left_schema.id_space and right_schema.id_space: - join_id_space = left_schema.id_space - return _output_schema_for_join( - left_schema=left_schema, - right_schema=right_schema, - join_columns=expr.join_columns, - join_id_space=join_id_space, - ) - - def visit_join_public(self, expr: JoinPublic) -> Schema: - """Returns the resulting schema from evaluating a JoinPublic. - - Has analogous behavior to :meth:`OutputSchemaVisitor.visit_join_private`, - where the private table is the left table. 
- """ - input_schema = expr.child.accept(self) - if isinstance(expr.public_table, str): - public_table = self._catalog.tables[expr.public_table] - if not isinstance(public_table, PublicTable): - raise ValueError( - f"Attempted public join on table '{expr.public_table}', " - "which is not a public table" - ) - right_schema = public_table.schema - else: - right_schema = Schema( - spark_schema_to_analytics_columns(expr.public_table.schema) - ) - return _output_schema_for_join( - left_schema=input_schema, - right_schema=right_schema, - join_columns=expr.join_columns, - join_id_space=input_schema.id_space, - how=expr.how, - ) - - def visit_replace_null_and_nan(self, expr: ReplaceNullAndNan) -> Schema: - """Returns the resulting schema from evaluating a ReplaceNullAndNan.""" - input_schema = expr.child.accept(self) - if ( - input_schema.grouping_column - and input_schema.grouping_column in expr.replace_with - ): - raise ValueError( - "Cannot replace null values in column " - f"'{input_schema.grouping_column}', as it is a grouping column." - ) - if input_schema.id_column and input_schema.id_column in expr.replace_with: - raise ValueError( - f"Cannot replace null values in column '{input_schema.id_column}', " - "as it is an ID column." - ) - if input_schema.id_column and (len(expr.replace_with) == 0): - raise RuntimeWarning( - f"Replacing null values in the ID column '{input_schema.id_column}' " - "is not allowed, so the ID column may still contain null values." - ) - - if len(expr.replace_with) != 0: - pytypes = analytics_to_py_types(input_schema) - for col, val in expr.replace_with.items(): - if col not in input_schema.keys(): - raise ValueError( - f"Column '{col}' does not exist in this table, " - f"available columns are {list(input_schema.keys())}" - ) - if not isinstance(val, pytypes[col]): - # it's okay to use an int as a float - # so don't raise an error in that case - if not (isinstance(val, int) and pytypes[col] == float): - raise ValueError( - f"Column '{col}' cannot have nulls replaced with " - f"{repr(val)}, as that value's type does not match the " - f"column type {input_schema[col].column_type.name}" - ) - - columns_to_change = list(dict(expr.replace_with).keys()) - if len(columns_to_change) == 0: - columns_to_change = [ - col - for col in input_schema.column_descs.keys() - if (input_schema[col].allow_null or input_schema[col].allow_nan) - and not (col in [input_schema.grouping_column, input_schema.id_column]) - ] - return Schema( - { - name: ColumnDescriptor( - column_type=cd.column_type, - allow_null=(cd.allow_null and not name in columns_to_change), - allow_nan=(cd.allow_nan and not name in columns_to_change), - allow_inf=cd.allow_inf, - ) - for name, cd in input_schema.column_descs.items() - }, - grouping_column=input_schema.grouping_column, - id_column=input_schema.id_column, - id_space=input_schema.id_space, - ) - - def visit_replace_infinity(self, expr: ReplaceInfinity) -> Schema: - """Returns the resulting schema from evaluating a ReplaceInfinity.""" - input_schema = expr.child.accept(self) - - if ( - input_schema.grouping_column - and input_schema.grouping_column in expr.replace_with - ): - raise ValueError( - "Cannot replace infinite values in column " - f"'{input_schema.grouping_column}', as it is a grouping column" - ) - # Float-valued columns cannot be ID columns, but include this to be safe. 
- if input_schema.id_column and input_schema.id_column in expr.replace_with: - raise ValueError( - f"Cannot replace infinite values in column '{input_schema.id_column}', " - "as it is an ID column" - ) - - columns_to_change = list(expr.replace_with.keys()) - if len(columns_to_change) == 0: - columns_to_change = [ - col - for col in input_schema.column_descs.keys() - if input_schema[col].column_type == ColumnType.DECIMAL - ] - else: - for name in expr.replace_with: - if name not in input_schema.keys(): - raise ValueError( - f"Column '{name}' does not exist in this table, " - f"available columns are {list(input_schema.keys())}" - ) - if input_schema[name].column_type != ColumnType.DECIMAL: - raise ValueError( - f"Column '{name}' has a replacement value provided, but it is " - f"of type {input_schema[name].column_type.name} (not " - f"{ColumnType.DECIMAL.name}) and so cannot " - "contain infinite values" - ) - return Schema( - { - name: ColumnDescriptor( - column_type=cd.column_type, - allow_null=cd.allow_null, - allow_nan=cd.allow_nan, - allow_inf=(cd.allow_inf and not name in columns_to_change), - ) - for name, cd in input_schema.column_descs.items() - }, - grouping_column=input_schema.grouping_column, - id_column=input_schema.id_column, - id_space=input_schema.id_space, - ) - - def visit_drop_null_and_nan(self, expr: DropNullAndNan) -> Schema: - """Returns the resulting schema from evaluating a DropNullAndNan.""" - input_schema = expr.child.accept(self) - if ( - input_schema.grouping_column - and input_schema.grouping_column in expr.columns - ): - raise ValueError( - f"Cannot drop null values in column '{input_schema.grouping_column}', " - "as it is a grouping column" - ) - if input_schema.id_column and input_schema.id_column in expr.columns: - raise ValueError( - f"Cannot drop null values in column '{input_schema.id_column}', " - "as it is an ID column." - ) - if input_schema.id_column and len(expr.columns) == 0: - raise RuntimeWarning( - f"Replacing null values in the ID column '{input_schema.id_column}' " - "is not allowed, so the ID column may still contain null values." - ) - columns = expr.columns - if len(columns) == 0: - columns = tuple( - name - for name, cd in input_schema.column_descs.items() - if (cd.allow_null or cd.allow_nan) - and not name in [input_schema.grouping_column, input_schema.id_column] - ) - else: - for name in columns: - if name not in input_schema.keys(): - raise ValueError( - f"Column '{name}' does not exist in this table, " - f"available columns are {list(input_schema.keys())}" - ) - return Schema( - { - name: ColumnDescriptor( - column_type=cd.column_type, - allow_null=(cd.allow_null and not name in columns), - allow_nan=(cd.allow_nan and not name in columns), - allow_inf=(cd.allow_inf), - ) - for name, cd in input_schema.column_descs.items() - }, - grouping_column=input_schema.grouping_column, - id_column=input_schema.id_column, - id_space=input_schema.id_space, - ) - - def visit_drop_infinity(self, expr: DropInfinity) -> Schema: - """Returns the resulting schema from evaluating a DropInfinity.""" - input_schema = expr.child.accept(self) - - if ( - input_schema.grouping_column - and input_schema.grouping_column in expr.columns - ): - raise ValueError( - "Cannot drop infinite values in column " - f"'{input_schema.grouping_column}', as it is a grouping column" - ) - # Float-valued columns cannot be ID columns, but include this to be safe. 
- if input_schema.id_column and input_schema.id_column in expr.columns: - raise ValueError( - f"Cannot drop infinite values in column '{input_schema.id_column}', " - "as it is an ID column" - ) - - columns = expr.columns - if len(columns) == 0: - columns = tuple( - name - for name, cd in input_schema.column_descs.items() - if (cd.allow_inf) and not name == input_schema.grouping_column - ) - else: - for name in columns: - if name not in input_schema.keys(): - raise ValueError( - f"Column '{name}' does not exist in this table, " - f"available columns are {list(input_schema.keys())}" - ) - if input_schema[name].column_type != ColumnType.DECIMAL: - raise ValueError( - f"Column '{name}' was given as a column to drop " - "infinite values from, but it is of type" - f"{input_schema[name].column_type.name} (not " - f"{ColumnType.DECIMAL.name}) and so cannot " - "contain infinite values" - ) - - return Schema( - { - name: ColumnDescriptor( - column_type=cd.column_type, - allow_null=cd.allow_null, - allow_nan=cd.allow_nan, - allow_inf=(cd.allow_inf and not name in columns), - ) - for name, cd in input_schema.column_descs.items() - }, - grouping_column=input_schema.grouping_column, - id_column=input_schema.id_column, - id_space=input_schema.id_space, - ) - - def visit_enforce_constraint(self, expr: EnforceConstraint) -> Schema: - """Returns the resulting schema from evaluating an EnforceConstraint.""" - input_schema = expr.child.accept(self) - constraint = expr.constraint - - if not input_schema.id_column: - raise ValueError( - f"Constraint {expr.constraint} can only be applied to tables" - " with the AddRowsWithID protected change" - ) - if isinstance(constraint, (MaxGroupsPerID, MaxRowsPerGroupPerID)): - grouping_column = constraint.grouping_column - if grouping_column not in input_schema: - raise ValueError( - f"The grouping column of constraint {constraint}" - " does not exist in this table; available columns" - f" are: {', '.join(input_schema.keys())}" - ) - if grouping_column == input_schema.id_column: - raise ValueError( - f"The grouping column of constraint {constraint} cannot be" - " the ID column of the table it is applied to" - ) - - # No current constraints modify the schema. If that changes in the - # future, the logic for it may have to be pushed into the Constraint - # type (like how constraint._enforce() works), but for now this works. 
- return input_schema - - def visit_get_groups(self, expr: GetGroups) -> Schema: - """Returns the resulting schema from GetGroups.""" - input_schema = expr.child.accept(self) - - if expr.columns: - nonexistent_columns = set(expr.columns) - set(input_schema) - if nonexistent_columns: - raise ValueError( - f"Nonexistent columns in get_groups query: {nonexistent_columns}" - ) - input_schema = Schema( - {column: input_schema[column] for column in expr.columns} - ) - - else: - input_schema = Schema( - { - column: input_schema[column] - for column in input_schema - if column != input_schema.id_column - } - ) - - return input_schema - - def visit_get_bounds(self, expr: GetBounds) -> Schema: - """Returns the resulting schema from GetBounds.""" - input_schema = expr.child.accept(self) - - if expr.measure_column not in set(input_schema): - raise ValueError( - f"Cannot get bounds for column '{expr.measure_column}', which " - "does not exist" - ) - - column = input_schema[expr.measure_column] - if column.column_type not in [ - ColumnType.INTEGER, - ColumnType.DECIMAL, - ]: - raise ValueError( - f"Cannot get bounds for column '{expr.measure_column}'," - f" which is of type {column.column_type.name}; only columns of" - f" numerical type are supported." - ) - - # Check if we're trying to get the bounds of the ID column. - if input_schema.id_column and (input_schema.id_column == expr.measure_column): - raise ValueError( - "get_bounds cannot be used on the privacy ID column" - f" ({input_schema.id_column}) of a table with the AddRowsWithID" - " protected change." - ) - return _validate_groupby(expr, self) - - def visit_groupby_count(self, expr: GroupByCount) -> Schema: - """Returns the resulting schema from evaluating a GroupByCount.""" - return _validate_groupby(expr, self) - - def visit_groupby_count_distinct(self, expr: GroupByCountDistinct) -> Schema: - """Returns the resulting schema from evaluating a GroupByCountDistinct.""" - return _validate_groupby(expr, self) - - def visit_groupby_quantile(self, expr: GroupByQuantile) -> Schema: - """Returns the resulting schema from evaluating a GroupByQuantile.""" - return _validate_groupby(expr, self) - - def visit_groupby_bounded_sum(self, expr: GroupByBoundedSum) -> Schema: - """Returns the resulting schema from evaluating a GroupByBoundedSum.""" - return _validate_groupby(expr, self) - - def visit_groupby_bounded_average(self, expr: GroupByBoundedAverage) -> Schema: - """Returns the resulting schema from evaluating a GroupByBoundedAverage.""" - return _validate_groupby(expr, self) - - def visit_groupby_bounded_variance(self, expr: GroupByBoundedVariance) -> Schema: - """Returns the resulting schema from evaluating a GroupByBoundedVariance.""" - return _validate_groupby(expr, self) - - def visit_groupby_bounded_stdev(self, expr: GroupByBoundedSTDEV) -> Schema: - """Returns the resulting schema from evaluating a GroupByBoundedSTDEV.""" - return _validate_groupby(expr, self) - - def visit_suppress_aggregates(self, expr: SuppressAggregates) -> Schema: - """Returns the resulting schema from evaluating a SuppressAggregates.""" - return expr.child.accept(self) diff --git a/test/unit/query_expr_compiler/test_measurement_visitor.py b/test/unit/query_expr_compiler/test_measurement_visitor.py index 321271fc..8b51779a 100644 --- a/test/unit/query_expr_compiler/test_measurement_visitor.py +++ b/test/unit/query_expr_compiler/test_measurement_visitor.py @@ -77,9 +77,6 @@ _get_query_bounds, ) from tmlt.analytics._query_expr_compiler._measurement_visitor import 
MeasurementVisitor -from tmlt.analytics._query_expr_compiler._output_schema_visitor import ( - OutputSchemaVisitor, -) from tmlt.analytics._schema import ( ColumnDescriptor, ColumnType, @@ -329,9 +326,7 @@ def run_with_empty_data_and_check_schema( self, query: QueryExpr, output_measure: Union[PureDP, RhoZCDP] ): """Run a query and check the schema of the result.""" - expected_column_types = query.accept( - OutputSchemaVisitor(self.catalog) - ).column_types + expected_column_types = query.schema(self.catalog).column_types self.visitor.output_measure = output_measure measurement, _ = query.accept(self.visitor) empty_data = create_empty_input(measurement.input_domain) diff --git a/test/unit/query_expr_compiler/transformation_visitor/test_add_keys.py b/test/unit/query_expr_compiler/transformation_visitor/test_add_keys.py index 796043cc..09de96a4 100644 --- a/test/unit/query_expr_compiler/transformation_visitor/test_add_keys.py +++ b/test/unit/query_expr_compiler/transformation_visitor/test_add_keys.py @@ -35,9 +35,6 @@ ReplaceNullAndNan, Select, ) -from tmlt.analytics._query_expr_compiler._output_schema_visitor import ( - OutputSchemaVisitor, -) from tmlt.analytics._query_expr_compiler._transformation_visitor import ( TransformationVisitor, ) @@ -80,7 +77,7 @@ def _validate_transform_basics( first_transform = chain_to_list(transformation)[0] assert isinstance(first_transform, IdentityTransformation) - expected_schema = query.accept(OutputSchemaVisitor(self.catalog)) + expected_schema = query.schema(self.catalog) assert expected_schema.grouping_column == grouping_column expected_output_domain = SparkDataFrameDomain( @@ -485,11 +482,11 @@ def test_visit_replace_null_and_nan( self._validate_result(transformation, reference, expected_df) assert constraints == [] - expected_output_schema = query.accept(OutputSchemaVisitor(self.catalog)) + expected_output_schema = query.schema(self.catalog) expected_output_domain = SparkDataFrameDomain( schema=analytics_to_spark_columns_descriptor(expected_output_schema) ) - expected_output_schema = query.accept(OutputSchemaVisitor(self.catalog)) + expected_output_schema = query.schema(self.catalog) expected_output_domain = SparkDataFrameDomain( schema=analytics_to_spark_columns_descriptor(expected_output_schema) ) @@ -528,7 +525,7 @@ def test_visit_replace_infinity( self._validate_result(transformation, reference, expected_df) assert constraints == [] - expected_output_schema = query.accept(OutputSchemaVisitor(self.catalog)) + expected_output_schema = query.schema(self.catalog) expected_output_domain = SparkDataFrameDomain( schema=analytics_to_spark_columns_descriptor(expected_output_schema) ) @@ -556,7 +553,7 @@ def _validate_transform_basics( first_transform = chain_to_list(transformation)[0] assert isinstance(first_transform, IdentityTransformation) - expected_schema = query.accept(OutputSchemaVisitor(self.catalog)) + expected_schema = query.schema(self.catalog) assert expected_schema.grouping_column == "id" expected_output_domain = SparkDataFrameDomain( diff --git a/test/unit/query_expr_compiler/transformation_visitor/test_add_rows.py b/test/unit/query_expr_compiler/transformation_visitor/test_add_rows.py index 31efba60..3d4fcc92 100644 --- a/test/unit/query_expr_compiler/transformation_visitor/test_add_rows.py +++ b/test/unit/query_expr_compiler/transformation_visitor/test_add_rows.py @@ -44,9 +44,6 @@ ReplaceNullAndNan, Select, ) -from tmlt.analytics._query_expr_compiler._output_schema_visitor import ( - OutputSchemaVisitor, -) from 
tmlt.analytics._query_expr_compiler._transformation_visitor import ( TransformationVisitor, ) @@ -86,7 +83,7 @@ def _validate_transform_basics( first_transform = chain_to_list(t)[0] assert isinstance(first_transform, IdentityTransformation) - expected_schema = query.accept(OutputSchemaVisitor(self.catalog)) + expected_schema = query.schema(self.catalog) expected_output_domain = SparkDataFrameDomain( analytics_to_spark_columns_descriptor(expected_schema) ) @@ -396,7 +393,7 @@ def test_visit_join_private( assert transformation.input_domain == self.visitor.input_domain assert transformation.input_metric == self.visitor.input_metric - expected_schema = query.accept(OutputSchemaVisitor(self.catalog)) + expected_schema = query.schema(self.catalog) expected_output_domain = SparkDataFrameDomain( analytics_to_spark_columns_descriptor(expected_schema) ) @@ -585,11 +582,11 @@ def test_visit_replace_null_and_nan( self._validate_transform_basics(transformation, reference, query) assert isinstance(transformation, ChainTT) assert isinstance(transformation.transformation2, AugmentDictTransformation) - expected_output_schema = query.accept(OutputSchemaVisitor(self.catalog)) + expected_output_schema = query.schema(self.catalog) expected_output_domain = SparkDataFrameDomain( schema=analytics_to_spark_columns_descriptor(expected_output_schema) ) - expected_output_schema = query.accept(OutputSchemaVisitor(self.catalog)) + expected_output_schema = query.schema(self.catalog) expected_output_domain = SparkDataFrameDomain( schema=analytics_to_spark_columns_descriptor(expected_output_schema) ) @@ -672,7 +669,7 @@ def test_visit_replace_infinity( transformation, reference, constraints = query.accept(self.visitor) self._validate_transform_basics(transformation, reference, query) - expected_output_schema = query.accept(OutputSchemaVisitor(self.catalog)) + expected_output_schema = query.schema(self.catalog) expected_output_domain = SparkDataFrameDomain( schema=analytics_to_spark_columns_descriptor(expected_output_schema) ) @@ -847,7 +844,7 @@ def _validate_transform_basics( assert t.input_domain == self.visitor.input_domain assert t.input_metric == self.visitor.input_metric - expected_schema = query.accept(OutputSchemaVisitor(self.catalog)) + expected_schema = query.schema(self.catalog) expected_output_domain = SparkDataFrameDomain( analytics_to_spark_columns_descriptor(expected_schema) ) diff --git a/test/unit/query_expr_compiler/test_output_schema_visitor.py b/test/unit/test_query_expression_schema.py similarity index 95% rename from test/unit/query_expr_compiler/test_output_schema_visitor.py rename to test/unit/test_query_expression_schema.py index 80a47200..33064e9a 100644 --- a/test/unit/query_expr_compiler/test_output_schema_visitor.py +++ b/test/unit/test_query_expression_schema.py @@ -1,4 +1,4 @@ -"""Tests for OutputSchemaVisitor.""" +"""Tests for QueryExpression schema determination.""" # SPDX-License-Identifier: Apache-2.0 # Copyright Tumult Labs 2025 @@ -37,9 +37,6 @@ Select, SuppressAggregates, ) -from tmlt.analytics._query_expr_compiler._output_schema_visitor import ( - OutputSchemaVisitor, -) from tmlt.analytics._schema import ( ColumnDescriptor, ColumnType, @@ -325,7 +322,7 @@ ###TESTS FOR QUERY VALIDATION### -@pytest.fixture(name="validation_visitor", scope="class") +@pytest.fixture(name="validation_catalog", scope="class") def setup_validation(request): """Set up test data.""" catalog = Catalog() @@ -361,15 +358,14 @@ def setup_validation(request): catalog.add_private_table( 
"groupby_one_column_private", {"A": ColumnDescriptor(ColumnType.VARCHAR)} ) - visitor = OutputSchemaVisitor(catalog) - request.cls.visitor = visitor + request.cls.catalog = catalog -@pytest.mark.usefixtures("validation_visitor") +@pytest.mark.usefixtures("validation_catalog") class TestValidation: - """Test Validation with Visitor.""" + """Test Validation with Catalog.""" - visitor: OutputSchemaVisitor + catalog: Catalog @pytest.mark.parametrize( "query_expr,expected_error_msg", OUTPUT_SCHEMA_INVALID_QUERY_TESTS @@ -379,7 +375,7 @@ def test_invalid_query_expr( ) -> None: """Check that appropriate exceptions are raised on invalid queries.""" with pytest.raises(ValueError, match=expected_error_msg): - query_expr.accept(self.visitor) + query_expr.schema(self.catalog) @pytest.mark.parametrize( "groupby_keys,exception_type,expected_error_msg", @@ -428,7 +424,7 @@ def test_invalid_group_by_count( ) -> None: """Test invalid measurement QueryExpr.""" with pytest.raises(exception_type, match=expected_error_msg): - GroupByCount(PrivateSource("private"), groupby_keys).accept(self.visitor) + GroupByCount(PrivateSource("private"), groupby_keys).schema(self.catalog) @pytest.mark.parametrize( "groupby_keys,exception_type,expected_error_msg", @@ -485,13 +481,13 @@ def test_invalid_group_by_aggregations( GroupByBoundedVariance, ]: with pytest.raises(exception_type, match=expected_error_msg): - DataClass(PrivateSource("private"), groupby_keys, "B", 1.0, 5.0).accept( - self.visitor + DataClass(PrivateSource("private"), groupby_keys, "B", 1.0, 5.0).schema( + self.catalog ) with pytest.raises(exception_type, match=expected_error_msg): GroupByQuantile( PrivateSource("private"), groupby_keys, "B", 0.5, 1.0, 5.0 - ).accept(self.visitor) + ).schema(self.catalog) with pytest.raises(exception_type, match=expected_error_msg): GetBounds( PrivateSource("private"), @@ -499,12 +495,12 @@ def test_invalid_group_by_aggregations( "B", "lower_bound", "upper_bound", - ).accept(self.visitor) + ).schema(self.catalog) ###QUERY VALIDATION WITH NULLS### @pytest.fixture(name="test_data_nulls", scope="class") -def setup_visitor_with_nulls(request) -> None: +def setup_catalog_with_nulls(request) -> None: """Set up test data.""" catalog = Catalog() catalog.add_private_table( @@ -543,15 +539,14 @@ def setup_visitor_with_nulls(request) -> None: "groupby_one_column_private", {"A": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True)}, ) - visitor = OutputSchemaVisitor(catalog) - request.cls.visitor = visitor + request.cls.catalog = catalog @pytest.mark.usefixtures("test_data_nulls") class TestValidationWithNulls: """Test Validation with Nulls.""" - visitor: OutputSchemaVisitor + catalog: Catalog @pytest.mark.parametrize( "query_expr,expected_error_msg", OUTPUT_SCHEMA_INVALID_QUERY_TESTS @@ -561,7 +556,7 @@ def test_invalid_query_expr_null( ) -> None: """Check that appropriate exceptions are raised on invalid queries.""" with pytest.raises(ValueError, match=expected_error_msg): - query_expr.accept(self.visitor) + query_expr.schema(self.catalog) @pytest.mark.parametrize( "groupby_keys,exception_type,expected_error_msg", @@ -602,7 +597,7 @@ def test_invalid_group_by_count_null( ) -> None: """Test invalid measurement QueryExpr.""" with pytest.raises(exception_type, match=expected_error_msg): - GroupByCount(PrivateSource("private"), groupby_keys).accept(self.visitor) + GroupByCount(PrivateSource("private"), groupby_keys).schema(self.catalog) @pytest.mark.parametrize( "groupby_keys,exception_type,expected_error_msg", @@ -659,13 +654,11 @@ def 
test_invalid_group_by_aggregations_null( GroupByBoundedVariance, ]: with pytest.raises(exception_type, match=expected_error_msg): - DataClass(PrivateSource("private"), groupby_keys, "B", 1.0, 5.0).accept( - self.visitor - ) + DataClass(PrivateSource("private"), groupby_keys, "B", 1.0, 5.0).schema( self.catalog) with pytest.raises(exception_type, match=expected_error_msg): GroupByQuantile( PrivateSource("private"), groupby_keys, "B", 0.5, 1.0, 5.0 - ).accept(self.visitor) + ).schema(self.catalog) with pytest.raises(exception_type, match=expected_error_msg): GetBounds( PrivateSource("private"), @@ -673,15 +666,15 @@ def test_invalid_group_by_aggregations_null( "B", "lower_bound", "upper_bound", - ).accept(self.visitor) + ).schema(self.catalog) - def test_visit_private_source(self) -> None: + def test_schema_private_source(self) -> None: """Test visit_private_source.""" query = PrivateSource("private") - schema = self.visitor.visit_private_source(query) + schema = query.schema(self.catalog) assert ( schema - == self.visitor._catalog.tables[ # pylint: disable=protected-access + == self.catalog.tables[ # pylint: disable=protected-access "private" ].schema ) @@ -753,7 +746,7 @@ def test_visit_private_source(self) -> None: ), ], ) - def test_visit_rename( + def test_schema_rename( self, column_mapper: Dict[str, str], expected_schema: Schema ) -> None: """Test visit_rename.""" @@ -761,17 +754,17 @@ def test_visit_rename( child=PrivateSource("private"), column_mapper=FrozenDict.from_dict(column_mapper), ) - schema = self.visitor.visit_rename(query) + schema = query.schema(self.catalog) assert schema == expected_schema @pytest.mark.parametrize("condition", ["B > X", "X < 500", "NOTNULL < 30"]) - def test_visit_filter(self, condition: str) -> None: + def test_schema_filter(self, condition: str) -> None: """Test visit_filter.""" query = Filter(child=PrivateSource("private"), condition=condition) - schema = self.visitor.visit_filter(query) + schema = query.schema(self.catalog) assert ( schema - == self.visitor._catalog.tables[ # pylint: disable=protected-access + == self.catalog.tables[ # pylint: disable=protected-access "private" ].schema ) @@ -796,10 +789,10 @@ def test_visit_filter(self, condition: str) -> None: ), ], ) - def test_visit_select(self, columns: List[str], expected_schema: Schema) -> None: + def test_schema_select(self, columns: List[str], expected_schema: Schema) -> None: """Test visit_select.""" query = Select(child=PrivateSource("private"), columns=tuple(columns)) - schema = self.visitor.visit_select(query) + schema = query.schema(self.catalog) assert schema == expected_schema @pytest.mark.parametrize( @@ -936,9 +929,9 @@ def test_visit_select(self, columns: List[str], expected_schema: Schema) -> None ), ], ) - def test_visit_map(self, query: Map, expected_schema: Schema) -> None: + def test_schema_map(self, query: Map, expected_schema: Schema) -> None: """Test visit_map.""" - schema = self.visitor.visit_map(query) + schema = query.schema(self.catalog) assert schema == expected_schema @pytest.mark.parametrize( @@ -987,9 +980,9 @@ def test_visit_map(self, query: Map, expected_schema: Schema) -> None: ), ], ) - def test_visit_flat_map(self, query: FlatMap, expected_schema: Schema) -> None: + def test_schema_flat_map(self, query: FlatMap, expected_schema: Schema) -> None: """Test visit_flat_map.""" - schema = self.visitor.visit_flat_map(query) + schema = query.schema(self.catalog) assert schema == expected_schema @pytest.mark.parametrize( @@ -1022,11 +1015,11 @@ def 
test_visit_flat_map(self, query: FlatMap, expected_schema: Schema) -> None: ) ], ) - def test_visit_join_private( + def test_schema_join_private( self, query: JoinPrivate, expected_schema: Schema ) -> None: """Test visit_join_private.""" - schema = self.visitor.visit_join_private(query) + schema = query.schema(self.catalog) assert schema == expected_schema @pytest.mark.parametrize( @@ -1081,11 +1074,11 @@ def test_visit_join_private( ), ], ) - def test_visit_join_public( + def test_schema_join_public( self, query: JoinPublic, expected_schema: Schema ) -> None: """Test visit_join_public.""" - schema = self.visitor.visit_join_public(query) + schema = query.schema(self.catalog) assert schema == expected_schema @parametrize( @@ -1134,19 +1127,18 @@ def test_visit_join_public( ), ), ) - def test_visit_join_private_nulls(self, left_schema, right_schema, expected_schema): - """Test that OutputSchemaVisitor correctly propagates nulls through a join.""" + def test_schema_join_private_nulls(self, left_schema, right_schema, expected_schema): + """Test that schema correctly propagates nulls through a join.""" catalog = Catalog() catalog.add_private_table("left", left_schema) catalog.add_private_table("right", right_schema) - visitor = OutputSchemaVisitor(catalog) query = JoinPrivate( child=PrivateSource("left"), right_operand_expr=PrivateSource("right"), truncation_strategy_left=TruncationStrategy.DropExcess(1), truncation_strategy_right=TruncationStrategy.DropExcess(1), ) - result_schema = visitor.visit_join_private(query) + result_schema = query.schema(catalog) assert result_schema == expected_schema @parametrize( @@ -1195,16 +1187,15 @@ def test_visit_join_private_nulls(self, left_schema, right_schema, expected_sche ), ), ) - def test_visit_join_public_nulls( + def test_schema_join_public_nulls( self, private_schema, public_schema, expected_schema ): - """Test that OutputSchemaVisitor correctly propagates nulls through a join.""" + """Test that schema correctly propagates nulls through a join.""" catalog = Catalog() catalog.add_private_table("private", private_schema) catalog.add_public_table("public", public_schema) - visitor = OutputSchemaVisitor(catalog) query = JoinPublic(child=PrivateSource("private"), public_table="public") - result_schema = visitor.visit_join_public(query) + result_schema = query.schema(catalog) assert result_schema == expected_schema @pytest.mark.parametrize( @@ -1289,11 +1280,11 @@ def test_visit_join_public_nulls( ), ], ) - def test_visit_replace_null_and_nan( + def test_schema_replace_null_and_nan( self, query: ReplaceNullAndNan, expected_schema: Schema ) -> None: """Test visit_replace_null_and_nan.""" - schema = self.visitor.visit_replace_null_and_nan(query) + schema = query.schema(self.catalog) assert schema == expected_schema @pytest.mark.parametrize( @@ -1376,11 +1367,11 @@ def test_visit_replace_null_and_nan( ), ], ) - def test_visit_drop_null_and_nan( + def test_schema_drop_null_and_nan( self, query: DropNullAndNan, expected_schema: Schema ) -> None: """Test visit_drop_null_and_nan.""" - schema = self.visitor.visit_drop_null_and_nan(query) + schema = query.schema(self.catalog) assert schema == expected_schema @pytest.mark.parametrize( @@ -1461,11 +1452,11 @@ def test_visit_drop_null_and_nan( ), ], ) - def test_visit_drop_infinity( + def test_schema_drop_infinity( self, query: DropInfinity, expected_schema: Schema ) -> None: """Test visit_drop_infinity.""" - schema = self.visitor.visit_drop_infinity(query) + schema = query.schema(self.catalog) assert schema == 
expected_schema @pytest.mark.parametrize( @@ -1623,14 +1614,14 @@ def test_visit_drop_infinity( ), ], ) - def test_visit_groupby_queries( + def test_schema_groupby_queries( self, query: QueryExpr, expected_schema: Schema ) -> None: """Test visit_groupby_*.""" - schema = query.accept(self.visitor) + schema = query.schema(self.catalog) assert schema == expected_schema - def test_visit_groupby_get_bounds_partition_selection(self) -> None: + def test_schema_groupby_get_bounds_partition_selection(self) -> None: """Test visit_get_bounds with auto partition selection enabled.""" expected_schema = Schema( { @@ -1647,7 +1638,7 @@ def test_visit_groupby_get_bounds_partition_selection(self) -> None: lower_bound_column="lower_bound", upper_bound_column="upper_bound", ) - schema = query.accept(self.visitor) + schema = query.schema(self.catalog) assert schema == expected_schema @pytest.mark.parametrize( @@ -1678,8 +1669,8 @@ def test_visit_groupby_get_bounds_partition_selection(self) -> None: ), ], ) - def test_visit_suppress_aggregates(self, query: SuppressAggregates) -> None: + def test_schema_suppress_aggregates(self, query: SuppressAggregates) -> None: """Test visit_suppress_aggregates.""" - expected_schema = query.child.accept(self.visitor) - got_schema = query.accept(self.visitor) + expected_schema = query.child.schema(self.catalog) + got_schema = query.schema(self.catalog) assert expected_schema == got_schema From 2aee15ed18297f8c39aae86d1c510c1c9a7a0c5c Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Mon, 13 Oct 2025 17:12:39 +0200 Subject: [PATCH 2/8] make linters happy --- src/tmlt/analytics/_query_expr.py | 110 +++++++++++------- .../_query_expr_compiler/_compiler.py | 2 +- test/unit/test_query_expression_schema.py | 22 ++-- 3 files changed, 75 insertions(+), 59 deletions(-) diff --git a/src/tmlt/analytics/_query_expr.py b/src/tmlt/analytics/_query_expr.py index 0a64e0e4..eea49e00 100644 --- a/src/tmlt/analytics/_query_expr.py +++ b/src/tmlt/analytics/_query_expr.py @@ -182,6 +182,11 @@ class QueryExpr(ABC): returns a relation. 
""" + @abstractmethod + def schema(self, catalog: Catalog) -> Any: + """Returns the schema resulting from evaluating this QueryExpr.""" + raise NotImplementedError() + @abstractmethod def accept(self, visitor: "QueryExprVisitor") -> Any: """Dispatch methods on a visitor based on the QueryExpr type.""" @@ -207,7 +212,7 @@ def __post_init__(self): ) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" if self.source_id not in catalog.tables: raise ValueError(f"Query references nonexistent table '{self.source_id}'") table = catalog.tables[self.source_id] @@ -243,7 +248,7 @@ def __post_init__(self): check_type(self.columns, Optional[Tuple[str, ...]]) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" input_schema = self.child.schema(catalog) if self.columns: nonexistent_columns = set(self.columns) - set(input_schema) @@ -296,7 +301,7 @@ def __post_init__(self): check_type(self.upper_bound_column, str) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" return _schema_for_groupby(self, catalog) def accept(self, visitor: "QueryExprVisitor") -> Any: @@ -332,7 +337,7 @@ def __post_init__(self): ) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" input_schema = self.child.schema(catalog) grouping_column = input_schema.grouping_column id_column = input_schema.id_column @@ -386,7 +391,7 @@ def __post_init__(self): check_type(self.condition, str) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" input_schema = self.child.schema(catalog) spark = SparkSession.builder.getOrCreate() test_df = spark.createDataFrame( @@ -420,7 +425,7 @@ def __post_init__(self): raise ValueError(f"Column name appears more than once in {self.columns}") def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" input_schema = self.child.schema(catalog) grouping_column = input_schema.grouping_column id_column = input_schema.id_column @@ -479,7 +484,7 @@ def __post_init__(self): raise ValueError("Map cannot be be used to create grouping columns") def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" input_schema = self.child.schema(catalog) new_columns = self.schema_new_columns.column_descs # Any column created by Map could contain a null value @@ -517,7 +522,6 @@ def schema(self, catalog: Catalog) -> Schema: id_space=self.schema_new_columns.id_space, ) - def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" return visitor.visit_map(self) @@ -585,7 +589,7 @@ def __post_init__(self): ) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" input_schema = 
self.child.schema(catalog) if self.schema_new_columns.grouping_column is not None: if input_schema.grouping_column: @@ -638,7 +642,6 @@ def schema(self, catalog: Catalog) -> Schema: id_space=self.schema_new_columns.id_space, ) - def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" return visitor.visit_flat_map(self) @@ -690,7 +693,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_flat_map_by_id(self) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this Queryself.""" + """Returns the schema resulting from evaluating this Queryself.""" input_schema = self.child.schema(catalog) id_column = input_schema.id_column new_columns = self.schema_new_columns.column_descs @@ -740,7 +743,7 @@ def _schema_for_join( join_id_space: Optional[str] = None, how: str = "inner", ) -> Schema: - """Return the resulting schema from joining two tables. + """Return the schema resulting from joining two tables. It is assumed that if either schema has an ID column, the one from left_schema should be used. This is because the appropriate behavior here @@ -833,6 +836,7 @@ def _schema_for_join( id_space=join_id_space, ) + @dataclass(frozen=True) class JoinPrivate(QueryExpr): """Returns the join of two private tables. @@ -875,7 +879,7 @@ def __post_init__(self): raise ValueError("Join columns must be distinct") def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr. + """Returns the schema resulting from evaluating this QueryExpr. The ordering of output columns are: @@ -883,7 +887,8 @@ def schema(self, catalog: Catalog) -> Schema: 2. Columns that are only in the left table 3. Columns that are only in the right table 4. Columns that are in both tables, but not included in the join columns. These - columns are included with _left and _right suffixes.""" + columns are included with _left and _right suffixes. + """ left_schema = self.child.schema(catalog) right_schema = self.right_operand_expr.schema(catalog) if left_schema.id_column != right_schema.id_column: @@ -958,10 +963,11 @@ def __post_init__(self): ) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr. + """Returns the schema resulting from evaluating this QueryExpr. Has analogous behavior to :meth:`JoinPrivate.schema`, where the private - table is the left table.""" + table is the left table. 
+ """ input_schema = self.child.schema(catalog) if isinstance(self.public_table, str): public_table = catalog.tables[self.public_table] @@ -1087,7 +1093,7 @@ def __post_init__(self): ) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" input_schema = self.child.schema(catalog) if ( input_schema.grouping_column @@ -1189,7 +1195,7 @@ def __init__( object.__setattr__(self, "child", child) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this Queryself.""" + """Returns the schema resulting from evaluating this Queryself.""" input_schema = self.child.schema(catalog) if ( @@ -1274,7 +1280,7 @@ def __post_init__(self) -> None: check_type(self.columns, Tuple[str, ...]) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" input_schema = self.child.schema(catalog) if ( input_schema.grouping_column @@ -1348,7 +1354,7 @@ def __post_init__(self) -> None: check_type(self.columns, Tuple[str, ...]) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this Queryself.""" + """Returns the schema resulting from evaluating this Queryself.""" input_schema = self.child.schema(catalog) if ( @@ -1424,11 +1430,7 @@ class EnforceConstraint(QueryExpr): to support advanced use cases, and generally should not be used.""" def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" - input_schema = self.child.schema(catalog) - - def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" input_schema = self.child.schema(catalog) if not input_schema.id_column: @@ -1455,8 +1457,10 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" return visitor.visit_enforce_constraint(self) + def _schema_for_groupby( query: Union[ + "GetBounds", "GroupByBoundedAverage", "GroupByBoundedSTDEV", "GroupByBoundedSum", @@ -1526,9 +1530,17 @@ def _schema_for_groupby( ) # Validating the measure column - if isinstance(query, (GetBounds, GroupByQuantile, GroupByBoundedSum, - GroupByBoundedSTDEV, GroupByBoundedAverage, - GroupByBoundedVariance)): + if isinstance( + query, + ( + GetBounds, + GroupByQuantile, + GroupByBoundedSum, + GroupByBoundedSTDEV, + GroupByBoundedAverage, + GroupByBoundedVariance, + ), + ): if query.measure_column not in input_schema: raise ValueError( f"{type(query).__name__} query's measure column " @@ -1556,16 +1568,27 @@ def _schema_for_groupby( output_column_type = ColumnType.INTEGER elif isinstance(query, (GetBounds, GroupByBoundedSum)): output_column_type = input_schema[query.measure_column].column_type - elif isinstance(query, (GroupByQuantile, GroupByBoundedSum, - GroupByBoundedSTDEV, GroupByBoundedAverage, - GroupByBoundedVariance)): + elif isinstance( + query, + ( + GroupByQuantile, + GroupByBoundedSum, + GroupByBoundedSTDEV, + GroupByBoundedAverage, + GroupByBoundedVariance, + ), + ): output_column_type = ColumnType.DECIMAL else: raise AnalyticsInternalError("Unexpected QueryExpr type: {type(query)}.") if isinstance(query, GetBounds): output_columns = { - query.lower_bound_column: ColumnDescriptor(output_column_type, allow_null=False), - 
query.upper_bound_column: ColumnDescriptor(output_column_type, allow_null=False), + query.lower_bound_column: ColumnDescriptor( + output_column_type, allow_null=False + ), + query.upper_bound_column: ColumnDescriptor( + output_column_type, allow_null=False + ), } else: output_columns = { @@ -1575,12 +1598,13 @@ def _schema_for_groupby( return Schema( { **{column: input_schema[column] for column in groupby_columns}, - **output_columns + **output_columns, }, grouping_column=None, id_column=None, ) + @dataclass(frozen=True) class GroupByCount(QueryExpr): """Returns the count of each combination of the groupby domains.""" @@ -1608,7 +1632,7 @@ def __post_init__(self): check_type(self.mechanism, CountMechanism) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" return _schema_for_groupby(self, catalog) def accept(self, visitor: "QueryExprVisitor") -> Any: @@ -1648,7 +1672,7 @@ def __post_init__(self): check_type(self.mechanism, CountDistinctMechanism) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" return _schema_for_groupby(self, catalog) def accept(self, visitor: "QueryExprVisitor") -> Any: @@ -1713,7 +1737,7 @@ def __post_init__(self): ) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" return _schema_for_groupby(self, catalog) def accept(self, visitor: "QueryExprVisitor") -> Any: @@ -1778,7 +1802,7 @@ def __post_init__(self): ) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" return _schema_for_groupby(self, catalog) def accept(self, visitor: "QueryExprVisitor") -> Any: @@ -1843,7 +1867,7 @@ def __post_init__(self): ) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" return _schema_for_groupby(self, catalog) def accept(self, visitor: "QueryExprVisitor") -> Any: @@ -1908,7 +1932,7 @@ def __post_init__(self): ) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" return _schema_for_groupby(self, catalog) def accept(self, visitor: "QueryExprVisitor") -> Any: @@ -1974,7 +1998,7 @@ def __post_init__(self): ) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" return _schema_for_groupby(self, catalog) def accept(self, visitor: "QueryExprVisitor") -> Any: @@ -2010,7 +2034,7 @@ def __post_init__(self) -> None: check_type(self.threshold, float) def schema(self, catalog: Catalog) -> Schema: - """Returns the resulting schema from evaluating this QueryExpr.""" + """Returns the schema resulting from evaluating this QueryExpr.""" return self.child.schema(catalog) def accept(self, visitor: "QueryExprVisitor") -> Any: @@ -2018,7 +2042,6 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_suppress_aggregates(self) - class QueryExprVisitor(ABC): """A base class for 
implementing visitors for :class:`QueryExpr`.""" @@ -2141,4 +2164,3 @@ def visit_groupby_bounded_stdev(self, expr: GroupByBoundedSTDEV) -> Any: def visit_suppress_aggregates(self, expr: SuppressAggregates) -> Any: """Visit a :class:`SuppressAggregates`.""" raise NotImplementedError - diff --git a/src/tmlt/analytics/_query_expr_compiler/_compiler.py b/src/tmlt/analytics/_query_expr_compiler/_compiler.py index c1f8077c..b76281df 100644 --- a/src/tmlt/analytics/_query_expr_compiler/_compiler.py +++ b/src/tmlt/analytics/_query_expr_compiler/_compiler.py @@ -109,7 +109,7 @@ def query_schema(query: QueryExpr, catalog: Catalog) -> Schema: if not isinstance(schema, Schema): raise AnalyticsInternalError( "Schema for this query is not a Schema but is instead a(n) " - f"{type(result)}." + f"{type(schema)}." ) return schema diff --git a/test/unit/test_query_expression_schema.py b/test/unit/test_query_expression_schema.py index 33064e9a..c1071bd0 100644 --- a/test/unit/test_query_expression_schema.py +++ b/test/unit/test_query_expression_schema.py @@ -654,7 +654,9 @@ def test_invalid_group_by_aggregations_null( GroupByBoundedVariance, ]: with pytest.raises(exception_type, match=expected_error_msg): - DataClass(PrivateSource("private"), groupby_keys, "B", 1.0, 5.0).schema( self.catalog) + DataClass(PrivateSource("private"), groupby_keys, "B", 1.0, 5.0).schema( + self.catalog + ) with pytest.raises(exception_type, match=expected_error_msg): GroupByQuantile( PrivateSource("private"), groupby_keys, "B", 0.5, 1.0, 5.0 @@ -672,12 +674,7 @@ def test_schema_private_source(self) -> None: """Test visit_private_source.""" query = PrivateSource("private") schema = query.schema(self.catalog) - assert ( - schema - == self.catalog.tables[ # pylint: disable=protected-access - "private" - ].schema - ) + assert schema == self.catalog.tables["private"].schema @pytest.mark.parametrize( "column_mapper,expected_schema", @@ -762,12 +759,7 @@ def test_schema_filter(self, condition: str) -> None: """Test visit_filter.""" query = Filter(child=PrivateSource("private"), condition=condition) schema = query.schema(self.catalog) - assert ( - schema - == self.catalog.tables[ # pylint: disable=protected-access - "private" - ].schema - ) + assert schema == self.catalog.tables["private"].schema @pytest.mark.parametrize( "columns,expected_schema", @@ -1127,7 +1119,9 @@ def test_schema_join_public( ), ), ) - def test_schema_join_private_nulls(self, left_schema, right_schema, expected_schema): + def test_schema_join_private_nulls( + self, left_schema, right_schema, expected_schema + ): """Test that schema correctly propagates nulls through a join.""" catalog = Catalog() catalog.add_private_table("left", left_schema) From 2f21dc66777dc6bce4a0894ddd34b68f3572a76f Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Wed, 15 Oct 2025 12:00:39 +0200 Subject: [PATCH 3/8] first batch of comments --- src/tmlt/analytics/_query_expr.py | 27 ++++++++++++---------- test/unit/test_query_expression_schema.py | 28 +++++++++++------------ 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/src/tmlt/analytics/_query_expr.py b/src/tmlt/analytics/_query_expr.py index eea49e00..2a5678dc 100644 --- a/src/tmlt/analytics/_query_expr.py +++ b/src/tmlt/analytics/_query_expr.py @@ -12,6 +12,7 @@ # Copyright Tumult Labs 2025 import datetime +import warnings from abc import ABC, abstractmethod from collections.abc import Collection from dataclasses import dataclass, replace @@ -693,7 +694,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: 
return visitor.visit_flat_map_by_id(self) def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this Queryself.""" + """Returns the schema resulting from evaluating this QueryExpr.""" input_schema = self.child.schema(catalog) id_column = input_schema.id_column new_columns = self.schema_new_columns.column_descs @@ -745,10 +746,10 @@ def _schema_for_join( ) -> Schema: """Return the schema resulting from joining two tables. - It is assumed that if either schema has an ID column, the one from - left_schema should be used. This is because the appropriate behavior here - depends on the type of join being performed, so checks for compatibility of - ID columns must happen outside this function. + It is assumed that if either schema has an ID column, the one from left_schema + should be used, because this is true for both public and private joins. With private + joins, the ID columns must be compatible; this check must happen outside this + function. Args: left_schema: Schema for the left table. @@ -1109,9 +1110,10 @@ def schema(self, catalog: Catalog) -> Schema: "as it is an ID column." ) if input_schema.id_column and (len(self.replace_with) == 0): - raise RuntimeWarning( + warnings.warn( f"Replacing null values in the ID column '{input_schema.id_column}' " - "is not allowed, so the ID column may still contain null values." + "is not allowed, so the ID column may still contain null values.", + RuntimeWarning, ) if len(self.replace_with) != 0: @@ -1195,7 +1197,7 @@ def __init__( object.__setattr__(self, "child", child) def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this Queryself.""" + """Returns the schema resulting from evaluating this QueryExpr.""" input_schema = self.child.schema(catalog) if ( @@ -1296,9 +1298,10 @@ def schema(self, catalog: Catalog) -> Schema: "as it is an ID column." ) if input_schema.id_column and len(self.columns) == 0: - raise RuntimeWarning( + warnings.warn( f"Replacing null values in the ID column '{input_schema.id_column}' " 
+ "is not allowed, so the ID column may still contain null values.", + RuntimeWarning, ) columns = self.columns if len(columns) == 0: @@ -1354,7 +1357,7 @@ def __post_init__(self) -> None: check_type(self.columns, Tuple[str, ...]) def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this Queryself.""" + """Returns the schema resulting from evaluating this QueryExpr.""" input_schema = self.child.schema(catalog) if ( @@ -1580,7 +1583,7 @@ def _schema_for_groupby( ): output_column_type = ColumnType.DECIMAL else: - raise AnalyticsInternalError("Unexpected QueryExpr type: {type(query)}.") + raise AnalyticsInternalError(f"Unexpected QueryExpr type: {type(query)}.") if isinstance(query, GetBounds): output_columns = { query.lower_bound_column: ColumnDescriptor( diff --git a/test/unit/test_query_expression_schema.py b/test/unit/test_query_expression_schema.py index c1071bd0..2d13f3f1 100644 --- a/test/unit/test_query_expression_schema.py +++ b/test/unit/test_query_expression_schema.py @@ -671,7 +671,7 @@ def test_invalid_group_by_aggregations_null( ).schema(self.catalog) def test_schema_private_source(self) -> None: - """Test visit_private_source.""" + """Test schema for private_source.""" query = PrivateSource("private") schema = query.schema(self.catalog) assert schema == self.catalog.tables["private"].schema @@ -746,7 +746,7 @@ def test_schema_private_source(self) -> None: def test_schema_rename( self, column_mapper: Dict[str, str], expected_schema: Schema ) -> None: - """Test visit_rename.""" + """Test schema for rename.""" query = Rename( child=PrivateSource("private"), column_mapper=FrozenDict.from_dict(column_mapper), @@ -756,7 +756,7 @@ def test_schema_rename( @pytest.mark.parametrize("condition", ["B > X", "X < 500", "NOTNULL < 30"]) def test_schema_filter(self, condition: str) -> None: - """Test visit_filter.""" + """Test schema for filter.""" query = Filter(child=PrivateSource("private"), condition=condition) schema = query.schema(self.catalog) assert schema == self.catalog.tables["private"].schema @@ -782,7 +782,7 @@ def test_schema_filter(self, condition: str) -> None: ], ) def test_schema_select(self, columns: List[str], expected_schema: Schema) -> None: - """Test visit_select.""" + """Test schema for select.""" query = Select(child=PrivateSource("private"), columns=tuple(columns)) schema = query.schema(self.catalog) assert schema == expected_schema @@ -922,7 +922,7 @@ def test_schema_select(self, columns: List[str], expected_schema: Schema) -> Non ], ) def test_schema_map(self, query: Map, expected_schema: Schema) -> None: - """Test visit_map.""" + """Test schema for map.""" schema = query.schema(self.catalog) assert schema == expected_schema @@ -973,7 +973,7 @@ def test_schema_map(self, query: Map, expected_schema: Schema) -> None: ], ) def test_schema_flat_map(self, query: FlatMap, expected_schema: Schema) -> None: - """Test visit_flat_map.""" + """Test schema for flat_map.""" schema = query.schema(self.catalog) assert schema == expected_schema @@ -1010,7 +1010,7 @@ def test_schema_flat_map(self, query: FlatMap, expected_schema: Schema) -> None: def test_schema_join_private( self, query: JoinPrivate, expected_schema: Schema ) -> None: - """Test visit_join_private.""" + """Test schema for join_private.""" schema = query.schema(self.catalog) assert schema == expected_schema @@ -1069,7 +1069,7 @@ def test_schema_join_private( def test_schema_join_public( self, query: JoinPublic, expected_schema: Schema ) -> None: - """Test 
visit_join_public.""" + """Test schema for join_public.""" schema = query.schema(self.catalog) assert schema == expected_schema @@ -1277,7 +1277,7 @@ def test_schema_join_public_nulls( def test_schema_replace_null_and_nan( self, query: ReplaceNullAndNan, expected_schema: Schema ) -> None: - """Test visit_replace_null_and_nan.""" + """Test schema for replace_null_and_nan.""" schema = query.schema(self.catalog) assert schema == expected_schema @@ -1364,7 +1364,7 @@ def test_schema_replace_null_and_nan( def test_schema_drop_null_and_nan( self, query: DropNullAndNan, expected_schema: Schema ) -> None: - """Test visit_drop_null_and_nan.""" + """Test schema for drop_null_and_nan.""" schema = query.schema(self.catalog) assert schema == expected_schema @@ -1449,7 +1449,7 @@ def test_schema_drop_null_and_nan( def test_schema_drop_infinity( self, query: DropInfinity, expected_schema: Schema ) -> None: - """Test visit_drop_infinity.""" + """Test schema for drop_infinity.""" schema = query.schema(self.catalog) assert schema == expected_schema @@ -1611,12 +1611,12 @@ def test_schema_drop_infinity( def test_schema_groupby_queries( self, query: QueryExpr, expected_schema: Schema ) -> None: - """Test visit_groupby_*.""" + """Test schema for groupby_*.""" schema = query.schema(self.catalog) assert schema == expected_schema def test_schema_groupby_get_bounds_partition_selection(self) -> None: - """Test visit_get_bounds with auto partition selection enabled.""" + """Test schema for get_bounds with auto partition selection enabled.""" expected_schema = Schema( { "A": ColumnDescriptor(ColumnType.VARCHAR, allow_null=True), @@ -1664,7 +1664,7 @@ def test_schema_groupby_get_bounds_partition_selection(self) -> None: ], ) def test_schema_suppress_aggregates(self, query: SuppressAggregates) -> None: - """Test visit_suppress_aggregates.""" + """Test schema for suppress_aggregates.""" expected_schema = query.child.schema(self.catalog) got_schema = query.schema(self.catalog) assert expected_schema == got_schema From 568fc6da64b6719d7d3622c73a5d17d5e59c459c Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Wed, 15 Oct 2025 19:32:15 +0200 Subject: [PATCH 4/8] review comments, mostly splitting validation --- src/tmlt/analytics/_query_expr.py | 497 ++++++++++-------- .../session/ids/test_id_col_operations.py | 4 +- test/system/session/rows/test_add_max_rows.py | 11 +- 3 files changed, 291 insertions(+), 221 deletions(-) diff --git a/src/tmlt/analytics/_query_expr.py b/src/tmlt/analytics/_query_expr.py index 2a5678dc..3c20180e 100644 --- a/src/tmlt/analytics/_query_expr.py +++ b/src/tmlt/analytics/_query_expr.py @@ -212,17 +212,20 @@ def __post_init__(self): " (_), and it cannot start with a number, or contain any spaces." ) - def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this QueryExpr.""" + def _validate(self, catalog: Catalog): + """Validation checks for this QueryExpr.""" if self.source_id not in catalog.tables: raise ValueError(f"Query references nonexistent table '{self.source_id}'") - table = catalog.tables[self.source_id] - if not isinstance(table, PrivateTable): + if not isinstance(catalog.tables[self.source_id], PrivateTable): raise ValueError( f"Attempted query on table '{self.source_id}', which is " "not a private table." 
) - return table.schema + + def schema(self, catalog: Catalog) -> Schema: + """Returns the schema resulting from evaluating this QueryExpr.""" + self._validate(catalog) + return catalog.tables[self.source_id].schema def accept(self, visitor: "QueryExprVisitor") -> Any: """Visits this QueryExpr with visitor.""" @@ -248,28 +251,29 @@ def __post_init__(self): check_type(self.child, QueryExpr) check_type(self.columns, Optional[Tuple[str, ...]]) - def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this QueryExpr.""" - input_schema = self.child.schema(catalog) + def _validate(self, input_schema: Schema): + """Validation checks for this QueryExpr.""" if self.columns: nonexistent_columns = set(self.columns) - set(input_schema) if nonexistent_columns: raise ValueError( f"Nonexistent columns in get_groups query: {nonexistent_columns}" ) - input_schema = Schema( - {column: input_schema[column] for column in self.columns} - ) - else: - input_schema = Schema( - { - column: input_schema[column] - for column in input_schema - if column != input_schema.id_column - } - ) - return input_schema + def schema(self, catalog: Catalog) -> Schema: + """Returns the schema resulting from evaluating this QueryExpr.""" + input_schema = self.child.schema(catalog) + self._validate(input_schema) + + if self.columns: + return Schema({column: input_schema[column] for column in self.columns}) + return Schema( + { + column: input_schema[column] + for column in input_schema + if column != input_schema.id_column + } + ) def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" @@ -303,7 +307,9 @@ def __post_init__(self): def schema(self, catalog: Catalog) -> Schema: """Returns the schema resulting from evaluating this QueryExpr.""" - return _schema_for_groupby(self, catalog) + input_schema = self.child.schema(catalog) + _validate_groupby(self, input_schema) + return _schema_for_groupby(self, input_schema) def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" @@ -337,12 +343,8 @@ def __post_init__(self): ' "" are not allowed' ) - def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this QueryExpr.""" - input_schema = self.child.schema(catalog) - grouping_column = input_schema.grouping_column - id_column = input_schema.id_column - id_space = input_schema.id_space + def _validate(self, input_schema: Schema): + """Validation checks for this QueryExpr.""" nonexistent_columns = set(self.column_mapper) - set(input_schema) if nonexistent_columns: raise ValueError( @@ -353,10 +355,19 @@ def schema(self, catalog: Catalog) -> Schema: raise ValueError( f"Cannot rename '{old}' to '{new}': column '{new}' already exists" ) - if old == grouping_column: - grouping_column = new - if old == id_column: - id_column = new + + def schema(self, catalog: Catalog) -> Schema: + """Returns the schema resulting from evaluating this QueryExpr.""" + input_schema = self.child.schema(catalog) + self._validate(input_schema) + + grouping_column = input_schema.grouping_column + if grouping_column in self.column_mapper: + grouping_column = self.column_mapper[grouping_column] + + id_column = input_schema.id_column + if id_column in self.column_mapper: + id_column = self.column_mapper[id_column] return Schema( { @@ -365,7 +376,7 @@ def schema(self, catalog: Catalog) -> Schema: }, grouping_column=grouping_column, id_column=id_column, - id_space=id_space, + id_space=input_schema.id_space, ) def 
accept(self, visitor: "QueryExprVisitor") -> Any: @@ -391,9 +402,8 @@ def __post_init__(self): check_type(self.child, QueryExpr) check_type(self.condition, str) - def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this QueryExpr.""" - input_schema = self.child.schema(catalog) + def _validate(self, input_schema: Schema): + """Validation checks for this QueryExpr.""" spark = SparkSession.builder.getOrCreate() test_df = spark.createDataFrame( [], schema=analytics_to_spark_schema(input_schema) @@ -402,6 +412,11 @@ def schema(self, catalog: Catalog) -> Schema: test_df.filter(self.condition) except Exception as e: raise ValueError(f"Invalid filter condition '{self.condition}': {e}") from e + + def schema(self, catalog: Catalog) -> Schema: + """Returns the schema resulting from evaluating this QueryExpr.""" + input_schema = self.child.schema(catalog) + self._validate(input_schema) return input_schema def accept(self, visitor: "QueryExprVisitor") -> Any: @@ -425,9 +440,8 @@ def __post_init__(self): if len(self.columns) != len(set(self.columns)): raise ValueError(f"Column name appears more than once in {self.columns}") - def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this QueryExpr.""" - input_schema = self.child.schema(catalog) + def _validate(self, input_schema: Schema): + """Validation checks for this QueryExpr.""" grouping_column = input_schema.grouping_column id_column = input_schema.id_column if grouping_column is not None and grouping_column not in self.columns: @@ -439,17 +453,20 @@ def schema(self, catalog: Catalog) -> Schema: raise ValueError( f"ID column '{id_column}' may not be dropped by select query" ) - nonexistent_columns = set(self.columns) - set(input_schema) if nonexistent_columns: raise ValueError( f"Nonexistent columns in select query: {nonexistent_columns}" ) + def schema(self, catalog: Catalog) -> Schema: + """Returns the schema resulting from evaluating this QueryExpr.""" + input_schema = self.child.schema(catalog) + self._validate(input_schema) return Schema( {column: input_schema[column] for column in self.columns}, - grouping_column=grouping_column, - id_column=id_column, + grouping_column=input_schema.grouping_column, + id_column=input_schema.id_column, id_space=input_schema.id_space, ) @@ -484,14 +501,9 @@ def __post_init__(self): if self.schema_new_columns.grouping_column is not None: raise ValueError("Map cannot be be used to create grouping columns") - def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this QueryExpr.""" - input_schema = self.child.schema(catalog) + def _validate(self, input_schema: Schema): + """Validation checks for this QueryExpr.""" new_columns = self.schema_new_columns.column_descs - # Any column created by Map could contain a null value - for name in list(new_columns.keys()): - new_columns[name] = replace(new_columns[name], allow_null=True) - if self.augment: overlapping_columns = set(input_schema.keys()) & set(new_columns.keys()) if overlapping_columns: @@ -500,28 +512,36 @@ def schema(self, catalog: Catalog) -> Schema: "existing columns, but found new columns that " f"already exist: {', '.join(overlapping_columns)}" ) - return Schema( - {**input_schema, **new_columns}, - grouping_column=input_schema.grouping_column, - id_column=input_schema.id_column, - id_space=input_schema.id_space, - ) - elif input_schema.grouping_column: + return + if input_schema.grouping_column: raise ValueError( "Map must set 
augment=True to ensure that " f"grouping column '{input_schema.grouping_column}' is not lost." ) - elif input_schema.id_column: + if input_schema.id_column: raise ValueError( "Map must set augment=True to ensure that " f"ID column '{input_schema.id_column}' is not lost." ) - return Schema( - new_columns, - grouping_column=self.schema_new_columns.grouping_column, - id_column=self.schema_new_columns.id_column, - id_space=self.schema_new_columns.id_space, - ) + + def schema(self, catalog: Catalog) -> Schema: + """Returns the schema resulting from evaluating this QueryExpr.""" + input_schema = self.child.schema(catalog) + self._validate(input_schema) + new_columns = self.schema_new_columns.column_descs + # Any column created by Map could contain a null value + for name in list(new_columns.keys()): + new_columns[name] = replace(new_columns[name], allow_null=True) + + if self.augment: + return Schema( + {**input_schema, **new_columns}, + grouping_column=input_schema.grouping_column, + id_column=input_schema.id_column, + id_space=input_schema.id_space, + ) + # If augment=False, there is no grouping column nor ID column + return Schema(new_columns) def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" @@ -589,9 +609,8 @@ def __post_init__(self): "columns, grouping flat map can only result in 1 new column" ) - def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this QueryExpr.""" - input_schema = self.child.schema(catalog) + def _validate(self, input_schema): + """Validation checks for this QueryExpr.""" if self.schema_new_columns.grouping_column is not None: if input_schema.grouping_column: raise ValueError( @@ -603,14 +622,8 @@ def schema(self, catalog: Catalog) -> Schema: "Grouping flat map cannot be used on tables with " "the AddRowsWithID protected change." ) - grouping_column = self.schema_new_columns.grouping_column - else: - grouping_column = input_schema.grouping_column new_columns = self.schema_new_columns.column_descs - # Any column created by the FlatMap could contain a null value - for name in list(new_columns.keys()): - new_columns[name] = replace(new_columns[name], allow_null=True) if self.augment: overlapping_columns = set(input_schema.keys()) & set(new_columns.keys()) if overlapping_columns: @@ -619,29 +632,42 @@ def schema(self, catalog: Catalog) -> Schema: "existing columns, but found new columns that " f"already exist: {', '.join(overlapping_columns)}" ) - return Schema( - {**input_schema, **new_columns}, - grouping_column=grouping_column, - id_column=input_schema.id_column, - id_space=input_schema.id_space, - ) - elif input_schema.grouping_column: + return + if input_schema.grouping_column: raise ValueError( "Flat map must set augment=True to ensure that " f"grouping column '{input_schema.grouping_column}' is not lost." ) - elif input_schema.id_column: + if input_schema.id_column: raise ValueError( "Flat map must set augment=True to ensure that " f"ID column '{input_schema.id_column}' is not lost." 
) - return Schema( - new_columns, - grouping_column=grouping_column, - id_column=self.schema_new_columns.id_column, - id_space=self.schema_new_columns.id_space, + def schema(self, catalog: Catalog) -> Schema: + """Returns the schema resulting from evaluating this QueryExpr.""" + input_schema = self.child.schema(catalog) + self._validate(input_schema) + + grouping_column = ( + self.schema_new_columns.grouping_column + if self.schema_new_columns.grouping_column is not None + else input_schema.grouping_column ) + new_columns = self.schema_new_columns.column_descs + # Any column created by the FlatMap could contain a null value + for name in list(new_columns.keys()): + new_columns[name] = replace(new_columns[name], allow_null=True) + + if self.augment: + return Schema( + {**input_schema, **new_columns}, + grouping_column=grouping_column, + id_column=input_schema.id_column, + id_space=input_schema.id_space, + ) + # If augment=False, there is no grouping column nor ID column + return Schema(new_columns) def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" @@ -693,13 +719,9 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" return visitor.visit_flat_map_by_id(self) - def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this QueryExpr.""" - input_schema = self.child.schema(catalog) - id_column = input_schema.id_column - new_columns = self.schema_new_columns.column_descs - - if not id_column: + def _validate(self, input_schema: Schema): + """Validation checks for this QueryExpr.""" + if not input_schema.id_column: raise ValueError( "Flat-map-by-ID may only be used on tables with ID columns." ) @@ -707,16 +729,23 @@ def schema(self, catalog: Catalog) -> Schema: raise AnalyticsInternalError( "Encountered table with both an ID column and a grouping column." ) - if id_column in new_columns: + if input_schema.id_column in self.schema_new_columns.column_descs: raise ValueError( "Flat-map-by-ID mapping function output cannot include ID column." ) + def schema(self, catalog: Catalog) -> Schema: + """Returns the schema resulting from evaluating this QueryExpr.""" + input_schema = self.child.schema(catalog) + self._validate(input_schema) + + id_column = input_schema.id_column + new_columns = self.schema_new_columns.column_descs + for name in list(new_columns.keys()): new_columns[name] = replace(new_columns[name], allow_null=True) return Schema( {id_column: input_schema[id_column], **new_columns}, - grouping_column=None, id_column=id_column, id_space=input_schema.id_space, ) @@ -879,19 +908,8 @@ def __post_init__(self): if len(self.join_columns) != len(set(self.join_columns)): raise ValueError("Join columns must be distinct") - def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this QueryExpr. - - The ordering of output columns are: - - 1. The join columns - 2. Columns that are only in the left table - 3. Columns that are only in the right table - 4. Columns that are in both tables, but not included in the join columns. These - columns are included with _left and _right suffixes. 
- """ - left_schema = self.child.schema(catalog) - right_schema = self.right_operand_expr.schema(catalog) + def _validate(self, left_schema: Schema, right_schema: Schema): + """Validation checks for this QueryExpr.""" if left_schema.id_column != right_schema.id_column: if left_schema.id_column is None or right_schema.id_column is None: raise ValueError( @@ -903,23 +921,31 @@ def schema(self, catalog: Catalog) -> Schema: "protected change are only possible when the ID columns of " "the two tables have the same name" ) - if ( - left_schema.id_space - and right_schema.id_space - and left_schema.id_space != right_schema.id_space - ): + if left_schema.id_space != right_schema.id_space: raise ValueError( "Private joins between tables with the AddRowsWithID protected change" " are only possible when both tables are in the same ID space" ) - join_id_space: Optional[str] = None - if left_schema.id_space and right_schema.id_space: - join_id_space = left_schema.id_space + + def schema(self, catalog: Catalog) -> Schema: + """Returns the schema resulting from evaluating this QueryExpr. + + The ordering of output columns are: + + 1. The join columns + 2. Columns that are only in the left table + 3. Columns that are only in the right table + 4. Columns that are in both tables, but not included in the join columns. These + columns are included with _left and _right suffixes. + """ + left_schema = self.child.schema(catalog) + right_schema = self.right_operand_expr.schema(catalog) + self._validate(left_schema, right_schema) return _schema_for_join( left_schema=left_schema, right_schema=right_schema, join_columns=self.join_columns, - join_id_space=join_id_space, + join_id_space=left_schema.id_space, ) def accept(self, visitor: "QueryExprVisitor") -> Any: @@ -963,6 +989,15 @@ def __post_init__(self): f"Invalid join type '{self.how}': must be 'inner' or 'left'" ) + def _validate(self, catalog: Catalog): + """Validation checks for this QueryExpr.""" + if isinstance(self.public_table, str): + if not isinstance(catalog.tables[self.public_table], PublicTable): + raise ValueError( + f"Attempted public join on table '{self.public_table}', " + "which is not a public table" + ) + def schema(self, catalog: Catalog) -> Schema: """Returns the schema resulting from evaluating this QueryExpr. @@ -970,14 +1005,9 @@ def schema(self, catalog: Catalog) -> Schema: table is the left table. 
""" input_schema = self.child.schema(catalog) + self._validate(catalog) if isinstance(self.public_table, str): - public_table = catalog.tables[self.public_table] - if not isinstance(public_table, PublicTable): - raise ValueError( - f"Attempted public join on table '{self.public_table}', " - "which is not a public table" - ) - right_schema = public_table.schema + right_schema = catalog.tables[self.public_table].schema else: right_schema = Schema( spark_schema_to_analytics_columns(self.public_table.schema) @@ -1093,9 +1123,8 @@ def __post_init__(self): FrozenDict, ) - def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this QueryExpr.""" - input_schema = self.child.schema(catalog) + def _validate(self, input_schema: Schema): + """Validation checks for this QueryExpr.""" if ( input_schema.grouping_column and input_schema.grouping_column in self.replace_with @@ -1116,23 +1145,26 @@ def schema(self, catalog: Catalog) -> Schema: RuntimeWarning, ) - if len(self.replace_with) != 0: - pytypes = analytics_to_py_types(input_schema) - for col, val in self.replace_with.items(): - if col not in input_schema.keys(): + pytypes = analytics_to_py_types(input_schema) + for col, val in self.replace_with.items(): + if col not in input_schema.keys(): + raise ValueError( + f"Column '{col}' does not exist in this table, " + f"available columns are {list(input_schema.keys())}" + ) + if not isinstance(val, pytypes[col]): + # Using an int as a float is OK + if not (isinstance(val, int) and pytypes[col] == float): raise ValueError( - f"Column '{col}' does not exist in this table, " - f"available columns are {list(input_schema.keys())}" + f"Column '{col}' cannot have nulls replaced with " + f"{repr(val)}, as that value's type does not match the " + f"column type {input_schema[col].column_type.name}" ) - if not isinstance(val, pytypes[col]): - # it's okay to use an int as a float - # so don't raise an error in that case - if not (isinstance(val, int) and pytypes[col] == float): - raise ValueError( - f"Column '{col}' cannot have nulls replaced with " - f"{repr(val)}, as that value's type does not match the " - f"column type {input_schema[col].column_type.name}" - ) + + def schema(self, catalog: Catalog) -> Schema: + """Returns the schema resulting from evaluating this QueryExpr.""" + input_schema = self.child.schema(catalog) + self._validate(input_schema) columns_to_change = list(dict(self.replace_with).keys()) if len(columns_to_change) == 0: @@ -1196,10 +1228,8 @@ def __init__( object.__setattr__(self, "replace_with", FrozenDict.from_dict(updated_dict)) object.__setattr__(self, "child", child) - def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this QueryExpr.""" - input_schema = self.child.schema(catalog) - + def _validate(self, input_schema: Schema): + """Validation checks for this QueryExpr.""" if ( input_schema.grouping_column and input_schema.grouping_column in self.replace_with @@ -1208,13 +1238,31 @@ def schema(self, catalog: Catalog) -> Schema: "Cannot replace infinite values in column " f"'{input_schema.grouping_column}', as it is a grouping column" ) - # Float-valued columns cannot be ID columns, but include this to be safe. 
if input_schema.id_column and input_schema.id_column in self.replace_with: raise ValueError( f"Cannot replace infinite values in column '{input_schema.id_column}', " "as it is an ID column" ) + for name in self.replace_with: + if name not in input_schema.keys(): + raise ValueError( + f"Column '{name}' does not exist in this table, " + f"available columns are {list(input_schema.keys())}" + ) + if input_schema[name].column_type != ColumnType.DECIMAL: + raise ValueError( + f"Column '{name}' has a replacement value provided, but it is " + f"of type {input_schema[name].column_type.name} (not " + f"{ColumnType.DECIMAL.name}) and so cannot " + "contain infinite values" + ) + + def schema(self, catalog: Catalog) -> Schema: + """Returns the schema resulting from evaluating this QueryExpr.""" + input_schema = self.child.schema(catalog) + self._validate(input_schema) + columns_to_change = list(self.replace_with.keys()) if len(columns_to_change) == 0: columns_to_change = [ @@ -1222,20 +1270,6 @@ def schema(self, catalog: Catalog) -> Schema: for col in input_schema.column_descs.keys() if input_schema[col].column_type == ColumnType.DECIMAL ] - else: - for name in self.replace_with: - if name not in input_schema.keys(): - raise ValueError( - f"Column '{name}' does not exist in this table, " - f"available columns are {list(input_schema.keys())}" - ) - if input_schema[name].column_type != ColumnType.DECIMAL: - raise ValueError( - f"Column '{name}' has a replacement value provided, but it is " - f"of type {input_schema[name].column_type.name} (not " - f"{ColumnType.DECIMAL.name}) and so cannot " - "contain infinite values" - ) return Schema( { name: ColumnDescriptor( @@ -1281,9 +1315,8 @@ def __post_init__(self) -> None: check_type(self.child, QueryExpr) check_type(self.columns, Tuple[str, ...]) - def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this QueryExpr.""" - input_schema = self.child.schema(catalog) + def _validate(self, input_schema: Schema): + """Validation checks for this QueryExpr.""" if ( input_schema.grouping_column and input_schema.grouping_column in self.columns @@ -1298,11 +1331,23 @@ def schema(self, catalog: Catalog) -> Schema: "as it is an ID column." 
             )
         if input_schema.id_column and len(self.columns) == 0:
-            warning.warn(
+            warnings.warn(
                 f"Replacing null values in the ID column '{input_schema.id_column}' "
                 "is not allowed, so the ID column may still contain null values.",
                 RuntimeWarning,
             )
+        for name in self.columns:
+            if name not in input_schema.keys():
+                raise ValueError(
+                    f"Column '{name}' does not exist in this table, "
+                    f"available columns are {list(input_schema.keys())}"
+                )
+
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the schema resulting from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+        self._validate(input_schema)
+
         columns = self.columns
         if len(columns) == 0:
             columns = tuple(
                 name
                 for name, cd in input_schema.column_descs.items()
                 if (cd.allow_null or cd.allow_nan)
                 and not name in [input_schema.grouping_column, input_schema.id_column]
             )
-        else:
-            for name in columns:
-                if name not in input_schema.keys():
-                    raise ValueError(
-                        f"Column '{name}' does not exist in this table, "
-                        f"available columns are {list(input_schema.keys())}"
-                    )
+
         return Schema(
             {
                 name: ColumnDescriptor(
@@ -1356,10 +1395,8 @@ def __post_init__(self) -> None:
         check_type(self.child, QueryExpr)
         check_type(self.columns, Tuple[str, ...])
 
-    def schema(self, catalog: Catalog) -> Schema:
-        """Returns the schema resulting from evaluating this QueryExpr."""
-        input_schema = self.child.schema(catalog)
-
+    def _validate(self, input_schema: Schema):
+        """Validation checks for this QueryExpr."""
         if (
             input_schema.grouping_column
             and input_schema.grouping_column in self.columns
@@ -1374,6 +1411,25 @@ def schema(self, catalog: Catalog) -> Schema:
                 f"Cannot drop infinite values in column '{input_schema.id_column}', "
                 "as it is an ID column"
             )
+        for name in self.columns:
+            if name not in input_schema.keys():
+                raise ValueError(
+                    f"Column '{name}' does not exist in this table, "
+                    f"available columns are {list(input_schema.keys())}"
+                )
+            if input_schema[name].column_type != ColumnType.DECIMAL:
+                raise ValueError(
+                    f"Column '{name}' was given as a column to drop "
+                    "infinite values from, but it is of type"
+                    f"{input_schema[name].column_type.name} (not "
+                    f"{ColumnType.DECIMAL.name}) and so cannot "
+                    "contain infinite values"
+                )
+
+    def schema(self, catalog: Catalog) -> Schema:
+        """Returns the schema resulting from evaluating this QueryExpr."""
+        input_schema = self.child.schema(catalog)
+        self._validate(input_schema)
 
         columns = self.columns
         if len(columns) == 0:
@@ -1382,21 +1438,6 @@ def schema(self, catalog: Catalog) -> Schema:
                 for name, cd in input_schema.column_descs.items()
                 if (cd.allow_inf) and not name == input_schema.grouping_column
             )
-        else:
-            for name in columns:
-                if name not in input_schema.keys():
-                    raise ValueError(
-                        f"Column '{name}' does not exist in this table, "
-                        f"available columns are {list(input_schema.keys())}"
-                    )
-                if input_schema[name].column_type != ColumnType.DECIMAL:
-                    raise ValueError(
-                        f"Column '{name}' was given as a column to drop "
-                        "infinite values from, but it is of type"
-                        f"{input_schema[name].column_type.name} (not "
-                        f"{ColumnType.DECIMAL.name}) and so cannot "
-                        "contain infinite values"
-                    )
 
         return Schema(
             {
@@ -1432,10 +1473,8 @@ class EnforceConstraint(QueryExpr):
        Appropriate values here vary depending on the constraint. 
These options are to support advanced use cases, and generally should not be used.""" - def schema(self, catalog: Catalog) -> Schema: - """Returns the schema resulting from evaluating this QueryExpr.""" - input_schema = self.child.schema(catalog) - + def _validate(self, input_schema: Schema): + """Validation checks for this QueryExpr.""" if not input_schema.id_column: raise ValueError( f"Constraint {self.constraint} can only be applied to tables" @@ -1454,6 +1493,11 @@ def schema(self, catalog: Catalog) -> Schema: f"The grouping column of constraint {self.constraint} cannot be" " the ID column of the table it is applied to" ) + + def schema(self, catalog: Catalog) -> Schema: + """Returns the schema resulting from evaluating this QueryExpr.""" + input_schema = self.child.schema(catalog) + self._validate(input_schema) return input_schema def accept(self, visitor: "QueryExprVisitor") -> Any: @@ -1461,7 +1505,7 @@ def accept(self, visitor: "QueryExprVisitor") -> Any: return visitor.visit_enforce_constraint(self) -def _schema_for_groupby( +def _validate_groupby( query: Union[ "GetBounds", "GroupByBoundedAverage", @@ -1472,19 +1516,9 @@ def _schema_for_groupby( "GroupByCountDistinct", "GroupByQuantile", ], - catalog: Catalog, -) -> Schema: - """Validates and returns the schema of a group-by QueryExpr. - - Args: - query: Query expression to be validated. - catalog: The catalog. - - Returns: - Output schema of current QueryExpr - """ - input_schema = query.child.schema(catalog) - + input_schema: Schema, +): + """Validates the arguments of a group-by QueryExpr.""" # Validating group-by columns if isinstance(query.groupby_keys, KeySet): # Checks that the KeySet is valid @@ -1566,6 +1600,27 @@ def _schema_for_groupby( "AddRowsWithID protected change." ) + +def _schema_for_groupby( + query: Union[ + "GetBounds", + "GroupByBoundedAverage", + "GroupByBoundedSTDEV", + "GroupByBoundedSum", + "GroupByBoundedVariance", + "GroupByCount", + "GroupByCountDistinct", + "GroupByQuantile", + ], + input_schema: Schema, +) -> Schema: + """Returns the schema of a group-by QueryExpr.""" + groupby_columns = ( + query.groupby_keys.schema().keys() + if isinstance(query.groupby_keys, KeySet) + else query.groupby_keys + ) + # Determining the output column types & names if isinstance(query, (GroupByCount, GroupByCountDistinct)): output_column_type = ColumnType.INTEGER @@ -1636,7 +1691,9 @@ def __post_init__(self): def schema(self, catalog: Catalog) -> Schema: """Returns the schema resulting from evaluating this QueryExpr.""" - return _schema_for_groupby(self, catalog) + input_schema = self.child.schema(catalog) + _validate_groupby(self, input_schema) + return _schema_for_groupby(self, input_schema) def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" @@ -1676,7 +1733,9 @@ def __post_init__(self): def schema(self, catalog: Catalog) -> Schema: """Returns the schema resulting from evaluating this QueryExpr.""" - return _schema_for_groupby(self, catalog) + input_schema = self.child.schema(catalog) + _validate_groupby(self, input_schema) + return _schema_for_groupby(self, input_schema) def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" @@ -1741,7 +1800,9 @@ def __post_init__(self): def schema(self, catalog: Catalog) -> Schema: """Returns the schema resulting from evaluating this QueryExpr.""" - return _schema_for_groupby(self, catalog) + input_schema = self.child.schema(catalog) + _validate_groupby(self, input_schema) + return 
_schema_for_groupby(self, input_schema) def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" @@ -1806,7 +1867,9 @@ def __post_init__(self): def schema(self, catalog: Catalog) -> Schema: """Returns the schema resulting from evaluating this QueryExpr.""" - return _schema_for_groupby(self, catalog) + input_schema = self.child.schema(catalog) + _validate_groupby(self, input_schema) + return _schema_for_groupby(self, input_schema) def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" @@ -1871,7 +1934,9 @@ def __post_init__(self): def schema(self, catalog: Catalog) -> Schema: """Returns the schema resulting from evaluating this QueryExpr.""" - return _schema_for_groupby(self, catalog) + input_schema = self.child.schema(catalog) + _validate_groupby(self, input_schema) + return _schema_for_groupby(self, input_schema) def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" @@ -1936,7 +2001,9 @@ def __post_init__(self): def schema(self, catalog: Catalog) -> Schema: """Returns the schema resulting from evaluating this QueryExpr.""" - return _schema_for_groupby(self, catalog) + input_schema = self.child.schema(catalog) + _validate_groupby(self, input_schema) + return _schema_for_groupby(self, input_schema) def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" @@ -2002,7 +2069,9 @@ def __post_init__(self): def schema(self, catalog: Catalog) -> Schema: """Returns the schema resulting from evaluating this QueryExpr.""" - return _schema_for_groupby(self, catalog) + input_schema = self.child.schema(catalog) + _validate_groupby(self, input_schema) + return _schema_for_groupby(self, input_schema) def accept(self, visitor: "QueryExprVisitor") -> Any: """Visit this QueryExpr with visitor.""" diff --git a/test/system/session/ids/test_id_col_operations.py b/test/system/session/ids/test_id_col_operations.py index f2963144..a0069deb 100644 --- a/test/system/session/ids/test_id_col_operations.py +++ b/test/system/session/ids/test_id_col_operations.py @@ -104,7 +104,7 @@ def test_replace_null_and_nan_raises_error( ) def test_replace_null_and_nan_raises_warning(session, query: QueryBuilder): """Tests that replace nulls/nans raises warning on IDs table with empty mapping.""" - with pytest.raises( + with pytest.warns( RuntimeWarning, match="the ID column may still contain null values." ): session.evaluate( @@ -121,7 +121,7 @@ def test_replace_null_and_nan_raises_warning(session, query: QueryBuilder): ) def test_drop_null_and_nan_raises_warning(session, query: QueryBuilder): """Tests that replace nulls/nans raises warning on IDs table with empty list.""" - with pytest.raises( + with pytest.warns( RuntimeWarning, match="the ID column may still contain null values." ): session.evaluate( diff --git a/test/system/session/rows/test_add_max_rows.py b/test/system/session/rows/test_add_max_rows.py index c606bbd8..78655126 100644 --- a/test/system/session/rows/test_add_max_rows.py +++ b/test/system/session/rows/test_add_max_rows.py @@ -529,8 +529,8 @@ def test_get_bounds_inf_budget_sum(self, spark, data): column="str_column", protected_change=AddOneRow(), error_type=ValueError, - message="Cannot get bounds for column 'str_column'," - " which is of type VARCHAR", + message="GetBounds query's measure column 'str_column' has invalid type" + " 'VARCHAR'. 
Expected types: 'INTEGER' or 'DECIMAL'", ), Case("missing_column")( data=pd.DataFrame( @@ -540,8 +540,8 @@ def test_get_bounds_inf_budget_sum(self, spark, data): column="column_does_not_exist", protected_change=AddOneRow(), error_type=ValueError, - message="Cannot get bounds for column 'column_does_not_exist'," - " which does not exist", + message="GetBounds query's measure column 'column_does_not_exist'" + " does not exist", ), Case("id_column")( data=pd.DataFrame( @@ -551,7 +551,8 @@ def test_get_bounds_inf_budget_sum(self, spark, data): column="id_column", protected_change=AddRowsWithID("id_column"), error_type=ValueError, - message="get_bounds cannot be used on the privacy ID column", + message="GetBounds query's measure column is the same as the privacy ID " + "column\(id_column\)", ), ) def test_get_bounds_invalid_columns( From e63c6d2268069da2255032de3105d279c08f5439 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Wed, 15 Oct 2025 19:43:41 +0200 Subject: [PATCH 5/8] actually perform validation. otherwise validation is not performed. --- src/tmlt/analytics/_query_expr_compiler/_compiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tmlt/analytics/_query_expr_compiler/_compiler.py b/src/tmlt/analytics/_query_expr_compiler/_compiler.py index b76281df..1e7f5f06 100644 --- a/src/tmlt/analytics/_query_expr_compiler/_compiler.py +++ b/src/tmlt/analytics/_query_expr_compiler/_compiler.py @@ -136,6 +136,8 @@ def __call__( catalog: The catalog, used only for query validation. table_constraints: A mapping of tables to the existing constraints on them. """ + query.schema(catalog) + visitor = MeasurementVisitor( privacy_budget=privacy_budget, stability=stability, From e76ceec24c2926796c0025af2dccd64a74b67d28 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Wed, 15 Oct 2025 19:50:28 +0200 Subject: [PATCH 6/8] ah it was just because get_bounds has no transformation! doing it this way makes a lot more sense --- src/tmlt/analytics/_query_expr_compiler/_compiler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tmlt/analytics/_query_expr_compiler/_compiler.py b/src/tmlt/analytics/_query_expr_compiler/_compiler.py index 1e7f5f06..c7e7de9a 100644 --- a/src/tmlt/analytics/_query_expr_compiler/_compiler.py +++ b/src/tmlt/analytics/_query_expr_compiler/_compiler.py @@ -136,6 +136,7 @@ def __call__( catalog: The catalog, used only for query validation. table_constraints: A mapping of tables to the existing constraints on them. """ + # Computing the schema validates that the query is well-formed. query.schema(catalog) visitor = MeasurementVisitor( @@ -206,8 +207,6 @@ def build_transformation( catalog: The catalog, used only for query validation. table_constraints: A mapping of tables to the existing constraints on them. 
""" - query.schema(catalog) - transformation_visitor = TransformationVisitor( input_domain=input_domain, input_metric=input_metric, From 5d3e4b42236c3d3f789559f5b7975d791ebb4c46 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Thu, 16 Oct 2025 10:19:06 +0200 Subject: [PATCH 7/8] check the schema for *both* transformations and measurements --- src/tmlt/analytics/_query_expr.py | 16 +++++++++------- .../analytics/_query_expr_compiler/_compiler.py | 5 +++++ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/tmlt/analytics/_query_expr.py b/src/tmlt/analytics/_query_expr.py index 3c20180e..8efe0958 100644 --- a/src/tmlt/analytics/_query_expr.py +++ b/src/tmlt/analytics/_query_expr.py @@ -1169,10 +1169,10 @@ def schema(self, catalog: Catalog) -> Schema: columns_to_change = list(dict(self.replace_with).keys()) if len(columns_to_change) == 0: columns_to_change = [ - col - for col in input_schema.column_descs.keys() - if (input_schema[col].allow_null or input_schema[col].allow_nan) - and not (col in [input_schema.grouping_column, input_schema.id_column]) + name + for name, cd in input_schema.column_descs.items() + if (cd.allow_null or cd.allow_nan) + and not (name in [input_schema.grouping_column, input_schema.id_column]) ] return Schema( { @@ -1267,8 +1267,9 @@ def schema(self, catalog: Catalog) -> Schema: if len(columns_to_change) == 0: columns_to_change = [ col - for col in input_schema.column_descs.keys() - if input_schema[col].column_type == ColumnType.DECIMAL + for name, cd in input_schema.column_descs.items() + if cd.column_type == ColumnType.DECIMAL and cd.allow_inf + and not (name in [input_schema.grouping_column, input_schema.id_column]) ] return Schema( { @@ -1436,7 +1437,8 @@ def schema(self, catalog: Catalog) -> Schema: columns = tuple( name for name, cd in input_schema.column_descs.items() - if (cd.allow_inf) and not name == input_schema.grouping_column + if cd.column_type == ColumnType.DECIMAL and cd.allow_inf + and not name in (input_schema.grouping_column, input_schema.id_column) ) return Schema( diff --git a/src/tmlt/analytics/_query_expr_compiler/_compiler.py b/src/tmlt/analytics/_query_expr_compiler/_compiler.py index c7e7de9a..37475d60 100644 --- a/src/tmlt/analytics/_query_expr_compiler/_compiler.py +++ b/src/tmlt/analytics/_query_expr_compiler/_compiler.py @@ -207,6 +207,11 @@ def build_transformation( catalog: The catalog, used only for query validation. table_constraints: A mapping of tables to the existing constraints on them. """ + # Computing the schema validates that the query is well-formed. It's useful to + # perform this check here in addition to __call__ so validation errors can be + # raised at view creation, not just query evaluation. 
+ query.schema(catalog) + transformation_visitor = TransformationVisitor( input_domain=input_domain, input_metric=input_metric, From 3ea3afdca6951e36bf68af2387de0c37f80f48a5 Mon Sep 17 00:00:00 2001 From: Damien Desfontaines Date: Thu, 16 Oct 2025 11:01:17 +0200 Subject: [PATCH 8/8] lint --- src/tmlt/analytics/_query_expr.py | 8 +++++--- test/system/session/rows/test_add_max_rows.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/tmlt/analytics/_query_expr.py b/src/tmlt/analytics/_query_expr.py index 8efe0958..dbcee359 100644 --- a/src/tmlt/analytics/_query_expr.py +++ b/src/tmlt/analytics/_query_expr.py @@ -1266,9 +1266,10 @@ def schema(self, catalog: Catalog) -> Schema: columns_to_change = list(self.replace_with.keys()) if len(columns_to_change) == 0: columns_to_change = [ - col + name for name, cd in input_schema.column_descs.items() - if cd.column_type == ColumnType.DECIMAL and cd.allow_inf + if cd.column_type == ColumnType.DECIMAL + and cd.allow_inf and not (name in [input_schema.grouping_column, input_schema.id_column]) ] return Schema( @@ -1437,7 +1438,8 @@ def schema(self, catalog: Catalog) -> Schema: columns = tuple( name for name, cd in input_schema.column_descs.items() - if cd.column_type == ColumnType.DECIMAL and cd.allow_inf + if cd.column_type == ColumnType.DECIMAL + and cd.allow_inf and not name in (input_schema.grouping_column, input_schema.id_column) ) diff --git a/test/system/session/rows/test_add_max_rows.py b/test/system/session/rows/test_add_max_rows.py index 78655126..252cc6ba 100644 --- a/test/system/session/rows/test_add_max_rows.py +++ b/test/system/session/rows/test_add_max_rows.py @@ -552,7 +552,7 @@ def test_get_bounds_inf_budget_sum(self, spark, data): protected_change=AddRowsWithID("id_column"), error_type=ValueError, message="GetBounds query's measure column is the same as the privacy ID " - "column\(id_column\)", + "column\\(id_column\\)", ), ) def test_get_bounds_invalid_columns(