这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
936 changes: 931 additions & 5 deletions src/tmlt/analytics/_query_expr.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,6 @@
SuppressAggregates,
VarianceMechanism,
)
from tmlt.analytics._query_expr_compiler._output_schema_visitor import (
OutputSchemaVisitor,
)
from tmlt.analytics._schema import ColumnType, FrozenDict, Schema
from tmlt.analytics._table_identifier import Identifier
from tmlt.analytics._table_reference import TableReference
Expand Down Expand Up @@ -708,7 +705,7 @@ def _pick_noise_for_non_count(
GroupByQuantile and GetBounds only supports one noise mechanism, so it is not
included here.
"""
measure_column_type = query.child.accept(OutputSchemaVisitor(self.catalog))[
measure_column_type = query.child.schema(self.catalog)[
query.measure_column
].column_type
requested_mechanism: NoiseMechanism
Expand Down Expand Up @@ -802,7 +799,7 @@ def _add_special_value_handling_to_query(

These changes are added immediately before the groupby aggregation in the query.
"""
expected_schema = query.child.accept(OutputSchemaVisitor(self.catalog))
expected_schema = query.child.schema(self.catalog)

# You can't perform these queries on nulls, NaNs, or infinite values
# so check for those
Expand Down Expand Up @@ -1046,7 +1043,7 @@ def visit_groupby_count(self, expr: GroupByCount) -> Tuple[Measurement, NoiseInf
self._validate_approxDP_and_adjust_budget(expr)

# Peek at the schema, to see if there are errors there
expr.accept(OutputSchemaVisitor(self.catalog))
expr.schema(self.catalog)

if isinstance(expr.groupby_keys, KeySet):
groupby_cols = tuple(expr.groupby_keys.dataframe().columns)
Expand Down Expand Up @@ -1130,7 +1127,7 @@ def visit_groupby_count_distinct(
self._validate_approxDP_and_adjust_budget(expr)

# Peek at the schema, to see if there are errors there
expr.accept(OutputSchemaVisitor(self.catalog))
expr.schema(self.catalog)

if isinstance(expr.groupby_keys, KeySet):
groupby_cols = tuple(expr.groupby_keys.dataframe().columns)
Expand All @@ -1150,7 +1147,7 @@ def visit_groupby_count_distinct(
) = self._visit_child_transformation(expr.child, mechanism)
constrained_query = _generate_constrained_count_distinct(
expr,
expr.child.accept(OutputSchemaVisitor(self.catalog)),
expr.child.schema(self.catalog),
child_constraints,
)
if constrained_query is not None:
Expand Down Expand Up @@ -1251,7 +1248,7 @@ def visit_groupby_quantile(
self._validate_approxDP_and_adjust_budget(expr)

# Peek at the schema, to see if there are errors there
expr.accept(OutputSchemaVisitor(self.catalog))
expr.schema(self.catalog)
expr = self._add_special_value_handling_to_query(expr)

if isinstance(expr.groupby_keys, KeySet):
Expand All @@ -1264,7 +1261,7 @@ def visit_groupby_quantile(
self.adjusted_budget
)
# Peek at the schema, to see if there are errors there
expr.accept(OutputSchemaVisitor(self.catalog))
expr.schema(self.catalog)

child_transformation, child_ref = self._truncate_table(
*self._visit_child_transformation(expr.child, self.default_mechanism),
Expand Down Expand Up @@ -1346,7 +1343,7 @@ def visit_groupby_bounded_sum(
self._validate_approxDP_and_adjust_budget(expr)

# Peek at the schema, to see if there are errors there
expr.accept(OutputSchemaVisitor(self.catalog))
expr.schema(self.catalog)
expr = self._add_special_value_handling_to_query(expr)

if isinstance(expr.groupby_keys, KeySet):
Expand Down Expand Up @@ -1442,7 +1439,7 @@ def visit_groupby_bounded_average(
self._validate_approxDP_and_adjust_budget(expr)

# Peek at the schema, to see if there are errors there
expr.accept(OutputSchemaVisitor(self.catalog))
expr.schema(self.catalog)
expr = self._add_special_value_handling_to_query(expr)

if isinstance(expr.groupby_keys, KeySet):
Expand Down Expand Up @@ -1538,7 +1535,7 @@ def visit_groupby_bounded_variance(
self._validate_approxDP_and_adjust_budget(expr)

# Peek at the schema, to see if there are errors there
expr.accept(OutputSchemaVisitor(self.catalog))
expr.schema(self.catalog)
expr = self._add_special_value_handling_to_query(expr)

if isinstance(expr.groupby_keys, KeySet):
Expand Down Expand Up @@ -1634,7 +1631,7 @@ def visit_groupby_bounded_stdev(
self._validate_approxDP_and_adjust_budget(expr)

# Peek at the schema, to see if there are errors there
expr.accept(OutputSchemaVisitor(self.catalog))
expr.schema(self.catalog)
expr = self._add_special_value_handling_to_query(expr)

if isinstance(expr.groupby_keys, KeySet):
Expand Down Expand Up @@ -1726,7 +1723,7 @@ def visit_get_bounds(self, expr: GetBounds) -> Tuple[Measurement, NoiseInfo]:
self._validate_approxDP_and_adjust_budget(expr)

# Peek at the schema, to see if there are errors there
expr.accept(OutputSchemaVisitor(self.catalog))
expr.schema(self.catalog)

expr = self._add_special_value_handling_to_query(expr)
if isinstance(expr.groupby_keys, KeySet):
Expand All @@ -1740,7 +1737,7 @@ def visit_get_bounds(self, expr: GetBounds) -> Tuple[Measurement, NoiseInfo]:
)

# Peek at the schema, to see if there are errors there
expr.accept(OutputSchemaVisitor(self.catalog))
expr.schema(self.catalog)

child_transformation, child_ref = self._truncate_table(
*self._visit_child_transformation(expr.child, NoiseMechanism.GEOMETRIC),
Expand Down Expand Up @@ -1823,7 +1820,7 @@ def visit_suppress_aggregates(
self, expr: SuppressAggregates
) -> Tuple[Measurement, NoiseInfo]:
"""Create a measurement from a SuppressAggregates query expression."""
expr.accept(OutputSchemaVisitor(self.catalog))
expr.schema(self.catalog)

child_measurement, noise_info = expr.child.accept(self)
if not isinstance(child_measurement, Measurement):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,6 @@
propagate_select,
propagate_unmodified,
)
from tmlt.analytics._query_expr_compiler._output_schema_visitor import (
OutputSchemaVisitor,
)
from tmlt.analytics._schema import (
ColumnDescriptor,
ColumnType,
Expand Down Expand Up @@ -245,7 +242,7 @@ def validate_transformation(
catalog: Catalog,
):
"""Ensure that a query's transformation is valid on a given catalog."""
expected_schema = query.accept(OutputSchemaVisitor(catalog))
expected_schema = query.schema(catalog)
expected_output_domain = SparkDataFrameDomain(
analytics_to_spark_columns_descriptor(expected_schema)
)
Expand Down
19 changes: 11 additions & 8 deletions src/tmlt/analytics/_query_expr_compiler/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
from tmlt.analytics._noise_info import NoiseInfo
from tmlt.analytics._query_expr import QueryExpr
from tmlt.analytics._query_expr_compiler._measurement_visitor import MeasurementVisitor
from tmlt.analytics._query_expr_compiler._output_schema_visitor import (
OutputSchemaVisitor,
)
from tmlt.analytics._query_expr_compiler._transformation_visitor import (
TransformationVisitor,
)
Expand Down Expand Up @@ -108,13 +105,13 @@ def output_measure(self) -> Union[PureDP, ApproxDP, RhoZCDP]:
@staticmethod
def query_schema(query: QueryExpr, catalog: Catalog) -> Schema:
"""Return the schema created by a given query."""
result = query.accept(OutputSchemaVisitor(catalog=catalog))
if not isinstance(result, Schema):
schema = query.schema(catalog)
if not isinstance(schema, Schema):
raise AnalyticsInternalError(
"Schema for this query is not a Schema but is instead a(n) "
f"{type(result)}."
f"{type(schema)}."
)
return result
return schema

def __call__(
self,
Expand All @@ -139,6 +136,9 @@ def __call__(
catalog: The catalog, used only for query validation.
table_constraints: A mapping of tables to the existing constraints on them.
"""
# Computing the schema validates that the query is well-formed.
query.schema(catalog)

visitor = MeasurementVisitor(
privacy_budget=privacy_budget,
stability=stability,
Expand Down Expand Up @@ -207,7 +207,10 @@ def build_transformation(
catalog: The catalog, used only for query validation.
table_constraints: A mapping of tables to the existing constraints on them.
"""
query.accept(OutputSchemaVisitor(catalog))
# Computing the schema validates that the query is well-formed. It's useful to
# perform this check here in addition to __call__ so validation errors can be
# raised at view creation, not just query evaluation.
query.schema(catalog)

transformation_visitor = TransformationVisitor(
input_domain=input_domain,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
from tmlt.analytics._query_expr_compiler._base_measurement_visitor import (
BaseMeasurementVisitor,
)
from tmlt.analytics._query_expr_compiler._output_schema_visitor import (
OutputSchemaVisitor,
)
from tmlt.analytics._query_expr_compiler._transformation_visitor import (
TransformationVisitor,
)
Expand Down Expand Up @@ -75,10 +72,7 @@ def visit_get_groups(self, expr: GetGroups) -> Tuple[Measurement, NoiseInfo]:
if not isinstance(self.budget, ApproxDPBudget):
raise ValueError("GetGroups is only supported with ApproxDPBudgets.")

# Peek at the schema, to see if there are errors there
expr.accept(OutputSchemaVisitor(self.catalog))

schema = expr.child.accept(OutputSchemaVisitor(self.catalog))
schema = expr.child.schema(self.catalog)

# Set the columns if no columns were provided.
if expr.columns:
Expand Down
Loading