
implement count aggregates for group by #145

Merged
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -6,7 +6,7 @@ This changelog documents the changes between release versions.

### Added

- You can now group documents for aggregation according to multiple grouping criteria ([#144](https://github.com/hasura/ndc-mongodb/pull/144))
- You can now group documents for aggregation according to multiple grouping criteria ([#144](https://github.com/hasura/ndc-mongodb/pull/144), [#145](https://github.com/hasura/ndc-mongodb/pull/145))

### Changed

@@ -22,7 +22,7 @@ a number of improvements to the spec, and enables features that were previously
not possible. Highlights of those new features include:

- relationships can use a nested object field on the target side as a join key
- grouping result documents, and aggregating on groups of documents (pending implementation in the mongo connector)
- grouping result documents, and aggregating on groups of documents
- queries on fields of nested collections (document fields that are arrays of objects)
- filtering on scalar values inside array document fields - previously it was possible to filter on fields of objects inside arrays, but not on scalars

2 changes: 1 addition & 1 deletion crates/cli/src/native_query/pipeline/mod.rs
@@ -211,7 +211,7 @@ fn infer_type_from_group_stage(
None,
expr.clone(),
)?,
Accumulator::Push(expr) => {
Accumulator::AddToSet(expr) | Accumulator::Push(expr) => {
let t = infer_type_from_aggregation_expression(
context,
&format!("{desired_object_type_name}_push"),
32 changes: 30 additions & 2 deletions crates/integration-tests/src/tests/grouping.rs
@@ -1,7 +1,6 @@
use insta::assert_yaml_snapshot;
use ndc_test_helpers::{
asc, binop, column_aggregate, dimension_column, field, grouping, or, ordered_dimensions, query,
query_request, target, value,
and, asc, binop, column_aggregate, column_count_aggregate, dimension_column, field, grouping, or, ordered_dimensions, query, query_request, star_count_aggregate, target, value
};

use crate::{connector::Connector, run_connector_query};
@@ -40,6 +39,35 @@ async fn runs_single_column_aggregate_on_groups() -> anyhow::Result<()> {
Ok(())
}

#[tokio::test]
async fn counts_column_values_in_groups() -> anyhow::Result<()> {
assert_yaml_snapshot!(
run_connector_query(
Connector::SampleMflix,
query_request().collection("movies").query(
query()
.predicate(and([
binop("_gt", target!("year"), value!(1920)),
binop("_lte", target!("year"), value!(1923)),
]))
.groups(
grouping()
.dimensions([dimension_column("rated")])
.aggregates([
// The distinct count should be 3 or less because we filtered to only 3 years
column_count_aggregate!("year_distinct_count" => "year", distinct: true),
column_count_aggregate!("year_count" => "year", distinct: false),
star_count_aggregate!("count"),
])
.order_by(ordered_dimensions()),
),
),
)
.await?
);
Ok(())
}

#[tokio::test]
async fn groups_by_multiple_dimensions() -> anyhow::Result<()> {
assert_yaml_snapshot!(
@@ -0,0 +1,35 @@
---
source: crates/integration-tests/src/tests/grouping.rs
expression: "run_connector_query(Connector::SampleMflix,\nquery_request().collection(\"movies\").query(query().predicate(and([binop(\"_gt\",\ntarget!(\"year\"), value!(1920)),\nbinop(\"_lte\", target!(\"year\"),\nvalue!(1923)),])).groups(grouping().dimensions([dimension_column(\"rated\")]).aggregates([column_count_aggregate!(\"year_distinct_count\"\n=> \"year\", distinct: true),\ncolumn_count_aggregate!(\"year_count\" => \"year\", distinct: false),\nstar_count_aggregate!(\"count\"),]).order_by(ordered_dimensions()),),),).await?"
---
- groups:
- dimensions:
- ~
aggregates:
count: 6
year_count: 6
year_distinct_count: 3
- dimensions:
- NOT RATED
aggregates:
count: 4
year_count: 4
year_distinct_count: 3
- dimensions:
- PASSED
aggregates:
count: 3
year_count: 3
year_distinct_count: 1
- dimensions:
- TV-PG
aggregates:
count: 1
year_count: 1
year_distinct_count: 1
- dimensions:
- UNRATED
aggregates:
count: 5
year_count: 5
year_distinct_count: 2
62 changes: 47 additions & 15 deletions crates/mongodb-agent-common/src/query/groups.rs
@@ -20,7 +20,7 @@ type Result<T> = std::result::Result<T, MongoAgentError>;
pub fn pipeline_for_groups(grouping: &Grouping) -> Result<Pipeline> {
let group_stage = Stage::Group {
key_expression: dimensions_to_expression(&grouping.dimensions).into(),
accumulators: accumulators_for_aggregates(&grouping.aggregates)?,
accumulators: accumulators_for_aggregates(&grouping.aggregates),
};

// TODO: ENG-1562 This implementation does not fully implement the
@@ -74,23 +74,39 @@ fn dimensions_to_expression(dimensions: &[Dimension]) -> bson::Array {
.collect()
}

// TODO: This function can be infallible once counts are implemented
fn accumulators_for_aggregates(
aggregates: &IndexMap<FieldName, Aggregate>,
) -> Result<BTreeMap<String, Accumulator>> {
) -> BTreeMap<String, Accumulator> {
aggregates
.into_iter()
.map(|(name, aggregate)| Ok((name.to_string(), aggregate_to_accumulator(aggregate)?)))
.map(|(name, aggregate)| (name.to_string(), aggregate_to_accumulator(aggregate)))
.collect()
}

// TODO: This function can be infallible once counts are implemented
fn aggregate_to_accumulator(aggregate: &Aggregate) -> Result<Accumulator> {
fn aggregate_to_accumulator(aggregate: &Aggregate) -> Accumulator {
use Aggregate as A;
match aggregate {
A::ColumnCount { .. } => Err(MongoAgentError::NotImplemented(Cow::Borrowed(
"count aggregates in groups",
))),
A::ColumnCount {
column,
field_path,
distinct,
..
} => {
let field_ref = ColumnRef::from_column_and_field_path(column, field_path.as_ref())
.into_aggregate_expression()
.into_bson();
if *distinct {
Accumulator::AddToSet(field_ref)
} else {
Accumulator::Sum(bson!({
"$cond": {
"if": { "$eq": [field_ref, null] }, // count non-null, non-missing values
"then": 0,
"else": 1,
}
}))
}
}
A::SingleColumn {
column,
field_path,
@@ -103,16 +119,14 @@ fn aggregate_to_accumulator(aggregate: &Aggregate) -> Result<Accumulator> {
.into_aggregate_expression()
.into_bson();

Ok(match function {
match function {
A::Avg => Accumulator::Avg(field_ref),
A::Min => Accumulator::Min(field_ref),
A::Max => Accumulator::Max(field_ref),
A::Sum => Accumulator::Sum(field_ref),
})
}
}
A::StarCount => Err(MongoAgentError::NotImplemented(Cow::Borrowed(
"count aggregates in groups",
))),
A::StarCount => Accumulator::Sum(bson!(1)),
}
}
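The accumulator mapping above can be read as an in-memory simulation. The sketch below is illustrative only, not the connector's code: `group_counts` and the sample rows are invented, and Python `None` stands in for BSON null. A star count adds 1 per document, a non-distinct column count adds 1 only for non-null values (the `$cond`), and a distinct count collects values into a set (the `$addToSet`):

```python
def group_counts(rows, dim, col):
    """Group rows by `dim` and compute the three count shapes per group."""
    groups = {}
    for row in rows:
        g = groups.setdefault(
            row.get(dim), {"count": 0, "col_count": 0, "col_set": set()}
        )
        g["count"] += 1            # StarCount: $sum: 1
        value = row.get(col)
        if value is not None:      # the $cond's null check
            g["col_count"] += 1    # ColumnCount: $sum of the $cond
        g["col_set"].add(value)    # distinct ColumnCount: $addToSet keeps nulls too
    return groups

rows = [
    {"rated": "PASSED", "year": 1921},
    {"rated": "PASSED", "year": 1921},
    {"rated": None, "year": 1922},
    {"rated": "PASSED", "year": None},
]
result = group_counts(rows, "rated", "year")
```

Note that the distinct set deliberately keeps `None`, mirroring `$addToSet`; the later selection stage counts only its non-null members.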

@@ -130,7 +144,25 @@ fn selection_for_grouping_internal(grouping: &Grouping, dimensions_field_name: &
);
let selected_aggregates = grouping.aggregates.iter().map(|(key, aggregate)| {
let column_ref = ColumnRef::from_field(key).into_aggregate_expression();
let selection = convert_aggregate_result_type(column_ref, aggregate);
// Selecting distinct counts requires some post-processing since the $group stage produces
// an array of unique values. We need to count the non-null values in that array.
let value_expression = match aggregate {
Aggregate::ColumnCount { distinct, .. } if *distinct => bson!({
"$reduce": {
"input": column_ref,
"initialValue": 0,
"in": {
"$cond": {
"if": { "$eq": ["$$this", null] },
"then": "$$value",
"else": { "$sum": ["$$value", 1] },
}
},
}
}),
_ => column_ref.into_bson(),
};
let selection = convert_aggregate_result_type(value_expression, aggregate);
(key.to_string(), selection.into())
});
let selection_doc = std::iter::once(dimensions)
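The `$reduce` expression built in `selection_for_grouping_internal` can be read as the following fold, rendered here as a hypothetical Python function (`distinct_count` is an invented name, and `None` stands in for BSON null):

```python
def distinct_count(unique_values):
    """Count the non-null elements of the array produced by $addToSet."""
    value = 0                   # "initialValue": 0
    for this in unique_values:  # "input": the accumulated array
        if this is None:        # "if": { "$eq": ["$$this", null] }
            pass                # "then": "$$value" (leave the tally unchanged)
        else:
            value = value + 1   # "else": { "$sum": ["$$value", 1] }
    return value
```

Because `$addToSet` deduplicates, the fold's result is the number of distinct non-null values in the group.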
6 changes: 6 additions & 0 deletions crates/mongodb-support/src/aggregate/accumulator.rs
@@ -6,6 +6,12 @@ use serde::{Deserialize, Serialize};
/// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/group/#std-label-accumulators-group
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub enum Accumulator {
/// Returns an array of unique expression values for each group. Order of the array elements is undefined.
///
/// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/addToSet/#mongodb-group-grp.-addToSet
#[serde(rename = "$addToSet")]
AddToSet(bson::Bson),

/// Returns an average of numerical values. Ignores non-numeric values.
///
/// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/avg/#mongodb-group-grp.-avg
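Putting the two halves together, the stages generated for a grouping with counts on `year` look roughly like the following. This shape is inferred from the diff, not copied from the connector's output; the dimension (`rated`) and aggregate aliases are illustrative, and Python `None` stands in for BSON null:

```python
# Approximate $group stage: accumulators chosen by aggregate_to_accumulator.
group_stage = {
    "$group": {
        "_id": ["$rated"],
        "count": {"$sum": 1},                        # StarCount
        "year_count": {"$sum": {"$cond": {           # non-distinct ColumnCount
            "if": {"$eq": ["$year", None]},
            "then": 0,
            "else": 1,
        }}},
        "year_distinct_count": {"$addToSet": "$year"},  # distinct ColumnCount
    }
}

# Approximate selection: $reduce turns the $addToSet array into a count.
selection = {
    "year_distinct_count": {"$reduce": {
        "input": "$year_distinct_count",
        "initialValue": 0,
        "in": {"$cond": {
            "if": {"$eq": ["$$this", None]},
            "then": "$$value",
            "else": {"$sum": ["$$value", 1]},
        }},
    }}
}
```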