diff --git a/src/tmlt/analytics/_query_expr.py b/src/tmlt/analytics/_query_expr.py index c816c227..2ed6ae5c 100644 --- a/src/tmlt/analytics/_query_expr.py +++ b/src/tmlt/analytics/_query_expr.py @@ -36,6 +36,7 @@ analytics_to_py_types, analytics_to_spark_columns_descriptor, analytics_to_spark_schema, + spark_dataframe_domain_to_analytics_columns, spark_schema_to_analytics_columns, ) from tmlt.analytics.config import config @@ -761,52 +762,35 @@ def __eq__(self, other: object) -> bool: ) -def _schema_for_join( +def _validate_join( left_schema: Schema, right_schema: Schema, join_columns: Optional[Tuple[str, ...]], - join_id_space: Optional[str] = None, - how: str = "inner", -) -> Schema: - """Return the schema resulting from joining two tables. - - It is assumed that if either schema has an ID column, the one from left_schema - should be used, because this is true for both public and private joins. With private - joins, the ID columns must be compatible; this check must happen outside this - function. +): + """Validates that both tables can be joined by comparing their schemas. - Args: - left_schema: Schema for the left table. - right_schema: Schema for the right table. - join_columns: The set of columns to join on. - join_id_space: The ID space of the resulting join. - how: The type of join to perform. Default is "inner". + This is used for both public and private joins; therefore, this does not check + any properties related to ID columns & ID spaces. """ - if left_schema.grouping_column is None: - grouping_column = right_schema.grouping_column - elif right_schema.grouping_column is None: - grouping_column = left_schema.grouping_column - elif left_schema.grouping_column == right_schema.grouping_column: - grouping_column = left_schema.grouping_column - else: + if ( + left_schema.grouping_column is not None + and right_schema.grouping_column is not None + and left_schema.grouping_column != right_schema.grouping_column + ): raise ValueError( "Joining tables which both have grouping columns is only supported " "if they have the same grouping column" ) - common_columns = set(left_schema) & set(right_schema) - if join_columns is None and not common_columns: - raise ValueError("Tables have no common columns to join on") + if join_columns is not None and not join_columns: # This error case should be caught when constructing the query # expression, so it should never get here. raise AnalyticsInternalError("Empty list of join columns provided.") - join_columns = ( - join_columns - if join_columns - else tuple(sorted(common_columns, key=list(left_schema).index)) - ) - + common_columns = set(left_schema) & set(right_schema) + if join_columns is None and not common_columns: + raise ValueError("Tables have no common columns to join on") + join_columns = join_columns or tuple(common_columns) if not set(join_columns) <= common_columns: raise ValueError("Join columns must be common to both tables") @@ -818,23 +802,35 @@ def _schema_for_join( f"{right_schema[column]} are incompatible" ) - join_column_schemas = {column: left_schema[column] for column in join_columns} - output_schema = { - **join_column_schemas, - **{ - column + ("_left" if column in common_columns else ""): left_schema[column] - for column in left_schema - if column not in join_columns - }, - **{ - column - + ("_right" if column in common_columns else ""): right_schema[column] - for column in right_schema - if column not in join_columns - }, - } - # Use Core's join utilities for determining whether a column can be null - # TODO: This could potentially be used more in this function + +def _schema_for_join( + left_schema: Schema, + right_schema: Schema, + join_columns: Optional[Tuple[str, ...]], + join_id_space: Optional[str], + how: str, +) -> Schema: + """Return the schema resulting from joining two tables. + + It is assumed that if either schema has an ID column, the one from left_schema + should be used, because this is true for both public and private joins. With private + joins, the ID columns must be compatible; this check must happen outside this + function. + + Args: + left_schema: Schema for the left table. + right_schema: Schema for the right table. + join_columns: The set of columns to join on. + join_id_space: The ID space of the resulting join. + how: The type of join to perform. + """ + grouping_column = left_schema.grouping_column or right_schema.grouping_column + common_columns = set(left_schema) & set(right_schema) + join_columns = join_columns or tuple( + sorted(common_columns, key=list(left_schema).index) + ) + + # Get the join schema from the Core convenience method output_domain = domain_after_join( left_domain=SparkDataFrameDomain( analytics_to_spark_columns_descriptor(left_schema) @@ -846,16 +842,8 @@ def _schema_for_join( how=how, nulls_are_equal=True, ) - for column in output_schema: - col_schema = output_schema[column] - output_schema[column] = ColumnDescriptor( - column_type=col_schema.column_type, - allow_null=output_domain.schema[column].allow_null, - allow_nan=col_schema.allow_nan, - allow_inf=col_schema.allow_inf, - ) return Schema( - output_schema, + column_descs=spark_dataframe_domain_to_analytics_columns(output_domain), grouping_column=grouping_column, id_column=left_schema.id_column, id_space=join_id_space, @@ -921,6 +909,7 @@ def _validate(self, left_schema: Schema, right_schema: Schema): "Private joins between tables with the AddRowsWithID protected change" " are only possible when both tables are in the same ID space" ) + _validate_join(left_schema, right_schema, self.join_columns) def schema(self, catalog: Catalog) -> Schema: """Returns the schema resulting from evaluating this QueryExpr. @@ -941,6 +930,7 @@ def schema(self, catalog: Catalog) -> Schema: right_schema=right_schema, join_columns=self.join_columns, join_id_space=left_schema.id_space, + how="inner", ) def accept(self, visitor: "QueryExprVisitor") -> Any: @@ -982,7 +972,7 @@ def __post_init__(self): f"Invalid join type '{self.how}': must be 'inner' or 'left'" ) - def _validate(self, catalog: Catalog): + def _validate(self, catalog: Catalog, left_schema: Schema, right_schema: Schema): """Validation checks for this QueryExpr.""" if isinstance(self.public_table, str): if not isinstance(catalog.tables[self.public_table], PublicTable): @@ -990,6 +980,7 @@ def _validate(self, catalog: Catalog): f"Attempted public join on table '{self.public_table}', " "which is not a public table" ) + _validate_join(left_schema, right_schema, self.join_columns) def schema(self, catalog: Catalog) -> Schema: """Returns the schema resulting from evaluating this QueryExpr. @@ -998,13 +989,13 @@ def schema(self, catalog: Catalog) -> Schema: table is the left table. """ input_schema = self.child.schema(catalog) - self._validate(catalog) if isinstance(self.public_table, str): right_schema = catalog.tables[self.public_table].schema else: right_schema = Schema( spark_schema_to_analytics_columns(self.public_table.schema) ) + self._validate(catalog, input_schema, right_schema) return _schema_for_join( left_schema=input_schema, right_schema=right_schema,