这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 51 additions & 60 deletions src/tmlt/analytics/_query_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
analytics_to_py_types,
analytics_to_spark_columns_descriptor,
analytics_to_spark_schema,
spark_dataframe_domain_to_analytics_columns,
spark_schema_to_analytics_columns,
)
from tmlt.analytics.config import config
Expand Down Expand Up @@ -761,52 +762,35 @@ def __eq__(self, other: object) -> bool:
)


def _schema_for_join(
def _validate_join(
left_schema: Schema,
right_schema: Schema,
join_columns: Optional[Tuple[str, ...]],
join_id_space: Optional[str] = None,
how: str = "inner",
) -> Schema:
"""Return the schema resulting from joining two tables.

It is assumed that if either schema has an ID column, the one from left_schema
should be used, because this is true for both public and private joins. With private
joins, the ID columns must be compatible; this check must happen outside this
function.
):
"""Validates that both tables can be joined by comparing their schemas.

Args:
left_schema: Schema for the left table.
right_schema: Schema for the right table.
join_columns: The set of columns to join on.
join_id_space: The ID space of the resulting join.
how: The type of join to perform. Default is "inner".
This is used for both public and private joins; therefore, this does not check
any properties related to ID columns & ID spaces.
"""
if left_schema.grouping_column is None:
grouping_column = right_schema.grouping_column
elif right_schema.grouping_column is None:
grouping_column = left_schema.grouping_column
elif left_schema.grouping_column == right_schema.grouping_column:
grouping_column = left_schema.grouping_column
else:
if (
left_schema.grouping_column is not None
and right_schema.grouping_column is not None
and left_schema.grouping_column != right_schema.grouping_column
):
raise ValueError(
"Joining tables which both have grouping columns is only supported "
"if they have the same grouping column"
)
common_columns = set(left_schema) & set(right_schema)
if join_columns is None and not common_columns:
raise ValueError("Tables have no common columns to join on")

if join_columns is not None and not join_columns:
# This error case should be caught when constructing the query
# expression, so it should never get here.
raise AnalyticsInternalError("Empty list of join columns provided.")

join_columns = (
join_columns
if join_columns
else tuple(sorted(common_columns, key=list(left_schema).index))
)

common_columns = set(left_schema) & set(right_schema)
if join_columns is None and not common_columns:
raise ValueError("Tables have no common columns to join on")
join_columns = join_columns or tuple(common_columns)
if not set(join_columns) <= common_columns:
raise ValueError("Join columns must be common to both tables")

Expand All @@ -818,23 +802,35 @@ def _schema_for_join(
f"{right_schema[column]} are incompatible"
)

join_column_schemas = {column: left_schema[column] for column in join_columns}
output_schema = {
**join_column_schemas,
**{
column + ("_left" if column in common_columns else ""): left_schema[column]
for column in left_schema
if column not in join_columns
},
**{
column
+ ("_right" if column in common_columns else ""): right_schema[column]
for column in right_schema
if column not in join_columns
},
}
# Use Core's join utilities for determining whether a column can be null
# TODO: This could potentially be used more in this function

def _schema_for_join(
left_schema: Schema,
right_schema: Schema,
join_columns: Optional[Tuple[str, ...]],
join_id_space: Optional[str],
how: str,
) -> Schema:
"""Return the schema resulting from joining two tables.

It is assumed that if either schema has an ID column, the one from left_schema
should be used, because this is true for both public and private joins. With private
joins, the ID columns must be compatible; this check must happen outside this
function.

Args:
left_schema: Schema for the left table.
right_schema: Schema for the right table.
join_columns: The set of columns to join on.
join_id_space: The ID space of the resulting join.
how: The type of join to perform.
"""
grouping_column = left_schema.grouping_column or right_schema.grouping_column
common_columns = set(left_schema) & set(right_schema)
join_columns = join_columns or tuple(
sorted(common_columns, key=list(left_schema).index)
)

# Get the join schema from the Core convenience method
output_domain = domain_after_join(
left_domain=SparkDataFrameDomain(
analytics_to_spark_columns_descriptor(left_schema)
Expand All @@ -846,16 +842,8 @@ def _schema_for_join(
how=how,
nulls_are_equal=True,
)
for column in output_schema:
col_schema = output_schema[column]
output_schema[column] = ColumnDescriptor(
column_type=col_schema.column_type,
allow_null=output_domain.schema[column].allow_null,
allow_nan=col_schema.allow_nan,
allow_inf=col_schema.allow_inf,
)
return Schema(
output_schema,
column_descs=spark_dataframe_domain_to_analytics_columns(output_domain),
grouping_column=grouping_column,
id_column=left_schema.id_column,
id_space=join_id_space,
Expand Down Expand Up @@ -921,6 +909,7 @@ def _validate(self, left_schema: Schema, right_schema: Schema):
"Private joins between tables with the AddRowsWithID protected change"
" are only possible when both tables are in the same ID space"
)
_validate_join(left_schema, right_schema, self.join_columns)

def schema(self, catalog: Catalog) -> Schema:
"""Returns the schema resulting from evaluating this QueryExpr.
Expand All @@ -941,6 +930,7 @@ def schema(self, catalog: Catalog) -> Schema:
right_schema=right_schema,
join_columns=self.join_columns,
join_id_space=left_schema.id_space,
how="inner",
)

def accept(self, visitor: "QueryExprVisitor") -> Any:
Expand Down Expand Up @@ -982,14 +972,15 @@ def __post_init__(self):
f"Invalid join type '{self.how}': must be 'inner' or 'left'"
)

def _validate(self, catalog: Catalog):
def _validate(self, catalog: Catalog, left_schema: Schema, right_schema: Schema):
"""Validation checks for this QueryExpr."""
if isinstance(self.public_table, str):
if not isinstance(catalog.tables[self.public_table], PublicTable):
raise ValueError(
f"Attempted public join on table '{self.public_table}', "
"which is not a public table"
)
_validate_join(left_schema, right_schema, self.join_columns)

def schema(self, catalog: Catalog) -> Schema:
"""Returns the schema resulting from evaluating this QueryExpr.
Expand All @@ -998,13 +989,13 @@ def schema(self, catalog: Catalog) -> Schema:
table is the left table.
"""
input_schema = self.child.schema(catalog)
self._validate(catalog)
if isinstance(self.public_table, str):
right_schema = catalog.tables[self.public_table].schema
else:
right_schema = Schema(
spark_schema_to_analytics_columns(self.public_table.schema)
)
self._validate(catalog, input_schema, right_schema)
return _schema_for_join(
left_schema=input_schema,
right_schema=right_schema,
Expand Down