diff --git a/crates/cli/src/native_query/pipeline/mod.rs b/crates/cli/src/native_query/pipeline/mod.rs
index fad8853b..664670ed 100644
--- a/crates/cli/src/native_query/pipeline/mod.rs
+++ b/crates/cli/src/native_query/pipeline/mod.rs
@@ -1,4 +1,5 @@
 mod match_stage;
+mod project_stage;
 
 use std::{collections::BTreeMap, iter::once};
 
@@ -54,7 +55,7 @@ pub fn infer_pipeline_types(
     if let TypeConstraint::Object(stage_type_name) = last_stage_type {
         if let Some(object_type) = context.get_object_type(&stage_type_name) {
             context.insert_object_type(object_type_name.clone(), object_type.into_owned());
-            context.set_stage_doc_type(TypeConstraint::Object(object_type_name))
+            context.set_stage_doc_type(TypeConstraint::Object(object_type_name));
         }
     }
 
@@ -93,9 +94,25 @@ fn infer_stage_output_type(
             None
         }
         Stage::Sort(_) => None,
-        Stage::Limit(_) => None,
+        Stage::Skip(expression) => {
+            infer_type_from_aggregation_expression(
+                context,
+                desired_object_type_name,
+                Some(&TypeConstraint::Scalar(BsonScalarType::Int)),
+                expression.clone(),
+            )?;
+            None
+        }
+        Stage::Limit(expression) => {
+            infer_type_from_aggregation_expression(
+                context,
+                desired_object_type_name,
+                Some(&TypeConstraint::Scalar(BsonScalarType::Int)),
+                expression.clone(),
+            )?;
+            None
+        }
         Stage::Lookup { .. } => todo!("lookup stage"),
-        Stage::Skip(_) => None,
         Stage::Group {
             key_expression,
             accumulators,
@@ -110,7 +127,18 @@ fn infer_stage_output_type(
         }
         Stage::Facet(_) => todo!("facet stage"),
         Stage::Count(_) => todo!("count stage"),
-        Stage::ReplaceWith(selection) => {
+        Stage::Project(doc) => {
+            let augmented_type = project_stage::infer_type_from_project_stage(
+                context,
+                &format!("{desired_object_type_name}_project"),
+                doc,
+            )?;
+            Some(augmented_type)
+        }
+        Stage::ReplaceRoot {
+            new_root: selection,
+        }
+        | Stage::ReplaceWith(selection) => {
             let selection: &Document = selection.into();
             Some(
                 aggregation_expression::infer_type_from_aggregation_expression(
@@ -291,7 +319,11 @@ fn infer_type_from_unwind_stage(
     Ok(TypeConstraint::WithFieldOverrides {
         augmented_object_type_name: format!("{desired_object_type_name}_unwind").into(),
         target_type: Box::new(context.get_input_document_type()?.clone()),
-        fields: unwind_stage_object_type.fields,
+        fields: unwind_stage_object_type
+            .fields
+            .into_iter()
+            .map(|(k, t)| (k, Some(t)))
+            .collect(),
     })
 }
 
@@ -360,7 +392,7 @@ mod tests {
         }))]);
         let config = mflix_config();
         let pipeline_types =
-            infer_pipeline_types(&config, "movies", Some(&("movies".into())), &pipeline).unwrap();
+            infer_pipeline_types(&config, "movies", Some(&("movies".into())), &pipeline)?;
         let expected = [(
             "movies_replaceWith".into(),
             ObjectType {
@@ -415,13 +447,18 @@ mod tests {
                 augmented_object_type_name: "unwind_stage_unwind".into(),
                 target_type: Box::new(TypeConstraint::Variable(input_doc_variable)),
                 fields: [
-                    ("idx".into(), TypeConstraint::Scalar(BsonScalarType::Long)),
+                    (
+                        "idx".into(),
+                        Some(TypeConstraint::Scalar(BsonScalarType::Long))
+                    ),
                     (
                         "words".into(),
-                        TypeConstraint::ElementOf(Box::new(TypeConstraint::FieldOf {
-                            target_type: Box::new(TypeConstraint::Variable(input_doc_variable)),
-                            path: nonempty!["words".into()],
-                        }))
+                        Some(TypeConstraint::ElementOf(Box::new(
+                            TypeConstraint::FieldOf {
+                                target_type: Box::new(TypeConstraint::Variable(input_doc_variable)),
+                                path: nonempty!["words".into()],
+                            }
+                        )))
                     )
                 ]
                 .into(),
diff --git a/crates/cli/src/native_query/pipeline/project_stage.rs b/crates/cli/src/native_query/pipeline/project_stage.rs
new file mode 100644
index 00000000..05bdea41
--- /dev/null
+++ b/crates/cli/src/native_query/pipeline/project_stage.rs
@@ -0,0 +1,449 @@
+use std::{
+    collections::{hash_map::Entry, HashMap},
+    str::FromStr as _,
+};
+
+use itertools::Itertools as _;
+use mongodb::bson::{Bson, Decimal128, Document};
+use mongodb_support::BsonScalarType;
+use ndc_models::{FieldName, ObjectTypeName};
+use nonempty::{nonempty, NonEmpty};
+
+use crate::native_query::{
+    aggregation_expression::infer_type_from_aggregation_expression,
+    error::{Error, Result},
+    pipeline_type_context::PipelineTypeContext,
+    type_constraint::{ObjectTypeConstraint, TypeConstraint},
+};
+
+enum Mode {
+    Exclusion,
+    Inclusion,
+}
+
+// $project has two distinct behaviors:
+//
+// Exclusion mode: if every value in the projection document is `false` or `0` then the output
+// preserves fields from the input except for fields that are specifically excluded. The special
+// value `$$REMOVE` **cannot** be used in this mode.
+//
+// Inclusion (replace) mode: if any value in the projection document specifies a field for
+// inclusion, replaces the value of an input field with a new value, adds a new field with a new
+// value, or removes a field with the special value `$$REMOVE`, then the output excludes input
+// fields that are not specified. The output is composed solely of fields specified in the
+// projection document, plus `_id` unless `_id` is specifically excluded. Values of `false` or `0`
+// are not allowed in this mode except to suppress `_id`.
+//
+// TODO: This implementation does not fully account for uses of $$REMOVE. It does correctly select
+// inclusion mode if $$REMOVE is used. A complete implementation would infer a nullable type for
+// a projection that conditionally resolves to $$REMOVE.
+pub fn infer_type_from_project_stage(
+    context: &mut PipelineTypeContext<'_>,
+    desired_object_type_name: &str,
+    projection: &Document,
+) -> Result<TypeConstraint> {
+    let mode = if projection.values().all(is_false_or_zero) {
+        Mode::Exclusion
+    } else {
+        Mode::Inclusion
+    };
+    match mode {
+        Mode::Exclusion => exclusion_projection_type(context, desired_object_type_name, projection),
+        Mode::Inclusion => inclusion_projection_type(context, desired_object_type_name, projection),
+    }
+}
+
+fn exclusion_projection_type(
+    context: &mut PipelineTypeContext<'_>,
+    desired_object_type_name: &str,
+    projection: &Document,
+) -> Result<TypeConstraint> {
+    // Projection keys can be dot-separated paths to nested fields. In this case a single
+    // object-type output field might be specified by multiple projection keys. We group entries
+    // under each top-level key (the first component of a dot-separated path), and then merge
+    // constraints.
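+    // For example, the keys "tomatoes.critic.rating" and "tomatoes.lastUpdated" are grouped
+    // into a single "tomatoes" entry whose subtree records both exclusions, yielding one field
+    // override on the output type.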
+    let mut specifications: HashMap<FieldName, ProjectionTree<()>> = Default::default();
+
+    for (field_name, _) in projection {
+        let path = field_name.split(".").map(|s| s.into()).collect_vec();
+        ProjectionTree::insert_specification(&mut specifications, &path, ())?;
+    }
+
+    let input_type = context.get_input_document_type()?;
+    Ok(projection_tree_into_field_overrides(
+        input_type,
+        desired_object_type_name,
+        specifications,
+    ))
+}
+
+fn projection_tree_into_field_overrides(
+    input_type: TypeConstraint,
+    desired_object_type_name: &str,
+    specifications: HashMap<FieldName, ProjectionTree<()>>,
+) -> TypeConstraint {
+    let overrides = specifications
+        .into_iter()
+        .map(|(name, spec)| {
+            let field_override = match spec {
+                ProjectionTree::Object(sub_specs) => {
+                    let original_field_type = TypeConstraint::FieldOf {
+                        target_type: Box::new(input_type.clone()),
+                        path: nonempty![name.clone()],
+                    };
+                    Some(projection_tree_into_field_overrides(
+                        original_field_type,
+                        &format!("{desired_object_type_name}_{name}"),
+                        sub_specs,
+                    ))
+                }
+                ProjectionTree::Field(_) => None,
+            };
+            (name, field_override)
+        })
+        .collect();
+
+    TypeConstraint::WithFieldOverrides {
+        augmented_object_type_name: desired_object_type_name.into(),
+        target_type: Box::new(input_type),
+        fields: overrides,
+    }
+}
+
+fn inclusion_projection_type(
+    context: &mut PipelineTypeContext<'_>,
+    desired_object_type_name: &str,
+    projection: &Document,
+) -> Result<TypeConstraint> {
+    let input_type = context.get_input_document_type()?;
+
+    // Projection keys can be dot-separated paths to nested fields. In this case a single
+    // object-type output field might be specified by multiple projection keys. We group entries
+    // under each top-level key (the first component of a dot-separated path), and then merge
+    // constraints.
+    let mut specifications: HashMap<FieldName, ProjectionTree<TypeConstraint>> = Default::default();
+
+    let added_fields = projection
+        .iter()
+        .filter(|(_, spec)| !is_false_or_zero(spec));
+
+    for (field_name, spec) in added_fields {
+        let path = field_name.split(".").map(|s| s.into()).collect_vec();
+        let projected_type = if is_true_or_one(spec) {
+            TypeConstraint::FieldOf {
+                target_type: Box::new(input_type.clone()),
+                path: NonEmpty::from_slice(&path).ok_or_else(|| {
+                    Error::Other("key in $project stage is an empty string".to_string())
+                })?,
+            }
+        } else {
+            let desired_object_type_name = format!("{desired_object_type_name}_{field_name}");
+            infer_type_from_aggregation_expression(
+                context,
+                &desired_object_type_name,
+                None,
+                spec.clone(),
+            )?
+        };
+        ProjectionTree::insert_specification(&mut specifications, &path, projected_type)?;
+    }
+
+    let specifies_id = projection.keys().any(|k| k == "_id");
+    if !specifies_id {
+        ProjectionTree::insert_specification(
+            &mut specifications,
+            &["_id".into()],
+            TypeConstraint::Scalar(BsonScalarType::ObjectId),
+        )?;
+    }
+
+    let object_type_name =
+        projection_tree_into_object_type(context, desired_object_type_name, specifications);
+
+    Ok(TypeConstraint::Object(object_type_name))
+}
+
+fn projection_tree_into_object_type(
+    context: &mut PipelineTypeContext<'_>,
+    desired_object_type_name: &str,
+    specifications: HashMap<FieldName, ProjectionTree<TypeConstraint>>,
+) -> ObjectTypeName {
+    let fields = specifications
+        .into_iter()
+        .map(|(field_name, spec)| {
+            let field_type = match spec {
+                ProjectionTree::Field(field_type) => field_type,
+                ProjectionTree::Object(sub_specs) => {
+                    let desired_object_type_name =
+                        format!("{desired_object_type_name}_{field_name}");
+                    let nested_object_name = projection_tree_into_object_type(
+                        context,
+                        &desired_object_type_name,
+                        sub_specs,
+                    );
+                    TypeConstraint::Object(nested_object_name)
+                }
+            };
+            (field_name, field_type)
+        })
+        .collect();
+    let object_type = ObjectTypeConstraint { fields };
+    let object_type_name = context.unique_type_name(desired_object_type_name);
+    context.insert_object_type(object_type_name.clone(), object_type);
+    object_type_name
+}
+
+enum ProjectionTree<T> {
+    Object(HashMap<FieldName, ProjectionTree<T>>),
+    Field(T),
+}
+
+impl<T> ProjectionTree<T> {
+    fn insert_specification(
+        specifications: &mut HashMap<FieldName, ProjectionTree<T>>,
+        path: &[FieldName],
+        field_type: T,
+    ) -> Result<()> {
+        match path {
+            [] => Err(Error::Other(
+                "invalid $project: a projection key is an empty string".into(),
+            ))?,
+            [field_name] => {
+                let maybe_old_value =
+                    specifications.insert(field_name.clone(), ProjectionTree::Field(field_type));
+                if maybe_old_value.is_some() {
+                    Err(path_collision_error(path))?;
+                };
+            }
+            [first_field_name, rest @ ..] => {
+                let entry = specifications.entry(first_field_name.clone());
+                match entry {
+                    Entry::Occupied(mut e) => match e.get_mut() {
+                        ProjectionTree::Object(sub_specs) => {
+                            Self::insert_specification(sub_specs, rest, field_type)?;
+                        }
+                        ProjectionTree::Field(_) => Err(path_collision_error(path))?,
+                    },
+                    Entry::Vacant(entry) => {
+                        let mut sub_specs = Default::default();
+                        Self::insert_specification(&mut sub_specs, rest, field_type)?;
+                        entry.insert(ProjectionTree::Object(sub_specs));
+                    }
+                };
+            }
+        }
+        Ok(())
+    }
+}
+
+// Experimentation confirms that a zero value of any numeric type is interpreted as suppression of
+// a field.
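+// For example, `{ "$project": { "title": 0 } }`, `{ "$project": { "title": 0.0 } }`, and
+// `{ "$project": { "title": NumberDecimal("0") } }` all suppress the `title` field.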
+fn is_false_or_zero(x: &Bson) -> bool {
+    let decimal_zero = Decimal128::from_str("0").expect("parse 0 as decimal");
+    matches!(
+        x,
+        Bson::Boolean(false) | Bson::Int32(0) | Bson::Int64(0) | Bson::Double(0.0)
+    ) || x == &Bson::Decimal128(decimal_zero)
+}
+
+fn is_true_or_one(x: &Bson) -> bool {
+    let decimal_one = Decimal128::from_str("1").expect("parse 1 as decimal");
+    matches!(
+        x,
+        Bson::Boolean(true) | Bson::Int32(1) | Bson::Int64(1) | Bson::Double(1.0)
+    ) || x == &Bson::Decimal128(decimal_one)
+}
+
+fn path_collision_error(path: impl IntoIterator<Item = impl std::fmt::Display>) -> Error {
+    Error::Other(format!(
+        "invalid $project: path collision at {}",
+        path.into_iter().join(".")
+    ))
+}
+
+#[cfg(test)]
+mod tests {
+    use mongodb::bson::doc;
+    use mongodb_support::BsonScalarType;
+    use nonempty::nonempty;
+    use pretty_assertions::assert_eq;
+    use test_helpers::configuration::mflix_config;
+
+    use crate::native_query::{
+        pipeline_type_context::PipelineTypeContext,
+        type_constraint::{ObjectTypeConstraint, TypeConstraint},
+    };
+
+    #[test]
+    fn infers_type_of_projection_in_inclusion_mode() -> anyhow::Result<()> {
+        let config = mflix_config();
+        let mut context = PipelineTypeContext::new(&config, None);
+        let input_type = context.set_stage_doc_type(TypeConstraint::Object("movies".into()));
+
+        let input = doc! {
+            "title": 1,
+            "tomatoes.critic.rating": true,
+            "tomatoes.critic.meter": true,
+            "tomatoes.lastUpdated": true,
+            "releaseDate": "$released",
+        };
+
+        let inferred_type =
+            super::infer_type_from_project_stage(&mut context, "Movie_project", &input)?;
+
+        assert_eq!(
+            inferred_type,
+            TypeConstraint::Object("Movie_project".into())
+        );
+
+        let object_types = context.object_types();
+        let expected_object_types = [
+            (
+                "Movie_project".into(),
+                ObjectTypeConstraint {
+                    fields: [
+                        (
+                            "_id".into(),
+                            TypeConstraint::Scalar(BsonScalarType::ObjectId),
+                        ),
+                        (
+                            "title".into(),
+                            TypeConstraint::FieldOf {
+                                target_type: Box::new(input_type.clone()),
+                                path: nonempty!["title".into()],
+                            },
+                        ),
+                        (
+                            "tomatoes".into(),
+                            TypeConstraint::Object("Movie_project_tomatoes".into()),
+                        ),
+                        (
+                            "releaseDate".into(),
+                            TypeConstraint::FieldOf {
+                                target_type: Box::new(input_type.clone()),
+                                path: nonempty!["released".into()],
+                            },
+                        ),
+                    ]
+                    .into(),
+                },
+            ),
+            (
+                "Movie_project_tomatoes".into(),
+                ObjectTypeConstraint {
+                    fields: [
+                        (
+                            "critic".into(),
+                            TypeConstraint::Object("Movie_project_tomatoes_critic".into()),
+                        ),
+                        (
+                            "lastUpdated".into(),
+                            TypeConstraint::FieldOf {
+                                target_type: Box::new(input_type.clone()),
+                                path: nonempty!["tomatoes".into(), "lastUpdated".into()],
+                            },
+                        ),
+                    ]
+                    .into(),
+                },
+            ),
+            (
+                "Movie_project_tomatoes_critic".into(),
+                ObjectTypeConstraint {
+                    fields: [
+                        (
+                            "rating".into(),
+                            TypeConstraint::FieldOf {
+                                target_type: Box::new(input_type.clone()),
+                                path: nonempty![
+                                    "tomatoes".into(),
+                                    "critic".into(),
+                                    "rating".into()
+                                ],
+                            },
+                        ),
+                        (
+                            "meter".into(),
+                            TypeConstraint::FieldOf {
+                                target_type: Box::new(input_type.clone()),
+                                path: nonempty!["tomatoes".into(), "critic".into(), "meter".into()],
+                            },
+                        ),
+                    ]
+                    .into(),
+                },
+            ),
+        ]
+        .into();
+
+        assert_eq!(object_types, &expected_object_types);
+
+        Ok(())
+    }
+
+    #[test]
+    fn infers_type_of_projection_in_exclusion_mode() -> anyhow::Result<()> {
+        let config = mflix_config();
+        let mut context = PipelineTypeContext::new(&config, None);
+        let input_type = context.set_stage_doc_type(TypeConstraint::Object("movies".into()));
+
+        let input = doc! {
{ + "title": 0, + "tomatoes.critic.rating": false, + "tomatoes.critic.meter": false, + "tomatoes.lastUpdated": false, + }; + + let inferred_type = + super::infer_type_from_project_stage(&mut context, "Movie_project", &input)?; + + assert_eq!( + inferred_type, + TypeConstraint::WithFieldOverrides { + augmented_object_type_name: "Movie_project".into(), + target_type: Box::new(input_type.clone()), + fields: [ + ("title".into(), None), + ( + "tomatoes".into(), + Some(TypeConstraint::WithFieldOverrides { + augmented_object_type_name: "Movie_project_tomatoes".into(), + target_type: Box::new(TypeConstraint::FieldOf { + target_type: Box::new(input_type.clone()), + path: nonempty!["tomatoes".into()], + }), + fields: [ + ("lastUpdated".into(), None), + ( + "critic".into(), + Some(TypeConstraint::WithFieldOverrides { + augmented_object_type_name: "Movie_project_tomatoes_critic" + .into(), + target_type: Box::new(TypeConstraint::FieldOf { + target_type: Box::new(TypeConstraint::FieldOf { + target_type: Box::new(input_type.clone()), + path: nonempty!["tomatoes".into()], + }), + path: nonempty!["critic".into()], + }), + fields: [("rating".into(), None), ("meter".into(), None),] + .into(), + }) + ) + ] + .into(), + }) + ), + ] + .into(), + } + ); + + Ok(()) + } +} diff --git a/crates/cli/src/native_query/pipeline_type_context.rs b/crates/cli/src/native_query/pipeline_type_context.rs index 56fe56a3..f5460117 100644 --- a/crates/cli/src/native_query/pipeline_type_context.rs +++ b/crates/cli/src/native_query/pipeline_type_context.rs @@ -65,12 +65,17 @@ impl PipelineTypeContext<'_> { }; if let Some(type_name) = input_collection_document_type { - context.set_stage_doc_type(TypeConstraint::Object(type_name)) + context.set_stage_doc_type(TypeConstraint::Object(type_name)); } context } + #[cfg(test)] + pub fn object_types(&self) -> &BTreeMap { + &self.object_types + } + #[cfg(test)] pub fn type_variables(&self) -> &HashMap> { &self.type_variables @@ -240,7 +245,8 @@ impl PipelineTypeContext<'_> { self.constraint_references_variable(target_type, variable) || fields .iter() - .any(|(_, t)| self.constraint_references_variable(t, variable)) + .flat_map(|(_, t)| t) + .any(|t| self.constraint_references_variable(t, variable)) } } } @@ -278,9 +284,10 @@ impl PipelineTypeContext<'_> { ) } - pub fn set_stage_doc_type(&mut self, doc_type: TypeConstraint) { + pub fn set_stage_doc_type(&mut self, doc_type: TypeConstraint) -> TypeConstraint { let variable = self.new_type_variable(Variance::Covariant, [doc_type]); self.input_doc_type = Some(variable); + TypeConstraint::Variable(variable) } pub fn add_warning(&mut self, warning: Error) { diff --git a/crates/cli/src/native_query/tests.rs b/crates/cli/src/native_query/tests.rs index b30d36b0..504ee1e1 100644 --- a/crates/cli/src/native_query/tests.rs +++ b/crates/cli/src/native_query/tests.rs @@ -9,12 +9,13 @@ use configuration::{ Configuration, }; use googletest::prelude::*; +use itertools::Itertools as _; use mongodb::bson::doc; use mongodb_support::{ aggregate::{Accumulator, Pipeline, Selection, Stage}, BsonScalarType, }; -use ndc_models::ObjectTypeName; +use ndc_models::{FieldName, ObjectTypeName}; use pretty_assertions::assert_eq; use test_helpers::configuration::mflix_config; @@ -323,6 +324,120 @@ fn supports_various_aggregation_operators() -> googletest::Result<()> { Ok(()) } +#[googletest::test] +fn supports_project_stage_in_exclusion_mode() -> Result<()> { + let config = mflix_config(); + + let pipeline = Pipeline::new(vec![Stage::Project(doc! 
{ + "title": 0, + "tomatoes.critic.rating": false, + "tomatoes.critic.meter": false, + "tomatoes.lastUpdated": false, + })]); + + let native_query = + native_query_from_pipeline(&config, "project_test", Some("movies".into()), pipeline)?; + + let result_type_name = native_query.result_document_type; + let result_type = &native_query.object_types[&result_type_name]; + + expect_false!(result_type.fields.contains_key("title")); + + let tomatoes_type_name = match result_type.fields.get("tomatoes") { + Some(ObjectField { + r#type: Type::Object(name), + .. + }) => ObjectTypeName::from(name.clone()), + _ => panic!("tomatoes field does not have an object type"), + }; + let tomatoes_type = &native_query.object_types[&tomatoes_type_name]; + expect_that!( + tomatoes_type.fields.keys().collect_vec(), + unordered_elements_are![&&FieldName::from("viewer"), &&FieldName::from("critic")] + ); + expect_eq!( + tomatoes_type.fields["viewer"].r#type, + Type::Object("TomatoesCriticViewer".into()), + ); + + let critic_type_name = match tomatoes_type.fields.get("critic") { + Some(ObjectField { + r#type: Type::Object(name), + .. + }) => ObjectTypeName::from(name.clone()), + _ => panic!("tomatoes.critic field does not have an object type"), + }; + let critic_type = &native_query.object_types[&critic_type_name]; + expect_eq!( + critic_type.fields, + object_fields([("numReviews", Type::Scalar(BsonScalarType::Int))]), + ); + + Ok(()) +} + +#[googletest::test] +fn supports_project_stage_in_inclusion_mode() -> Result<()> { + let config = mflix_config(); + + let pipeline = Pipeline::new(vec![Stage::Project(doc! { + "title": 1, + "tomatoes.critic.rating": true, + "tomatoes.critic.meter": true, + "tomatoes.lastUpdated": true, + "releaseDate": "$released", + })]); + + let native_query = + native_query_from_pipeline(&config, "inclusion", Some("movies".into()), pipeline)?; + + expect_eq!(native_query.result_document_type, "inclusion_project".into()); + + expect_eq!( + native_query.object_types, + [ + ( + "inclusion_project".into(), + ObjectType { + fields: object_fields([ + ("_id", Type::Scalar(BsonScalarType::ObjectId)), + ("title", Type::Scalar(BsonScalarType::String)), + ("tomatoes", Type::Object("inclusion_project_tomatoes".into())), + ("releaseDate", Type::Scalar(BsonScalarType::Date)), + ]), + description: None + } + ), + ( + "inclusion_project_tomatoes".into(), + ObjectType { + fields: object_fields([ + ( + "critic", + Type::Object("inclusion_project_tomatoes_critic".into()) + ), + ("lastUpdated", Type::Scalar(BsonScalarType::Date)), + ]), + description: None + } + ), + ( + "inclusion_project_tomatoes_critic".into(), + ObjectType { + fields: object_fields([ + ("rating", Type::Scalar(BsonScalarType::Double)), + ("meter", Type::Scalar(BsonScalarType::Int)), + ]), + description: None + } + ) + ] + .into(), + ); + + Ok(()) +} + fn object_fields(types: impl IntoIterator) -> BTreeMap where S: Into, diff --git a/crates/cli/src/native_query/type_constraint.rs b/crates/cli/src/native_query/type_constraint.rs index 67c04156..3b046dfc 100644 --- a/crates/cli/src/native_query/type_constraint.rs +++ b/crates/cli/src/native_query/type_constraint.rs @@ -1,6 +1,7 @@ use std::collections::{BTreeMap, BTreeSet}; use configuration::MongoScalarType; +use itertools::Itertools as _; use mongodb_support::BsonScalarType; use ndc_models::{FieldName, ObjectTypeName}; use nonempty::NonEmpty; @@ -81,14 +82,50 @@ pub enum TypeConstraint { path: NonEmpty, }, - /// A type that modifies another type by adding or replacing object fields. 
+    /// A type that modifies another type by adding, replacing, or subtracting object fields.
     WithFieldOverrides {
         augmented_object_type_name: ObjectTypeName,
         target_type: Box<TypeConstraint>,
-        fields: BTreeMap<FieldName, TypeConstraint>,
+        fields: BTreeMap<FieldName, Option<TypeConstraint>>,
     },
 }
 
+impl std::fmt::Display for TypeConstraint {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            TypeConstraint::ExtendedJSON => write!(f, "ExtendedJSON"),
+            TypeConstraint::Scalar(s) => s.fmt(f),
+            TypeConstraint::Object(name) => write!(f, "Object({name})"),
+            TypeConstraint::ArrayOf(t) => write!(f, "[{t}]"),
+            TypeConstraint::Predicate { object_type_name } => {
+                write!(f, "Predicate({object_type_name})")
+            }
+            TypeConstraint::Union(ts) => write!(f, "{}", ts.iter().join(" | ")),
+            TypeConstraint::OneOf(ts) => write!(f, "{}", ts.iter().join(" / ")),
+            TypeConstraint::Variable(v) => v.fmt(f),
+            TypeConstraint::ElementOf(t) => write!(f, "{t}[@]"),
+            TypeConstraint::FieldOf { target_type, path } => {
+                write!(f, "{target_type}.{}", path.iter().join("."))
+            }
+            TypeConstraint::WithFieldOverrides {
+                augmented_object_type_name,
+                target_type,
+                fields,
+            } => {
+                writeln!(f, "{target_type} // {augmented_object_type_name} {{")?;
+                for (name, spec) in fields {
+                    write!(f, "  {name}: ")?;
+                    match spec {
+                        Some(t) => write!(f, "{t}"),
+                        None => write!(f, "-"),
+                    }?;
+                    writeln!(f)?;
+                }
+                write!(f, "}}")
+            }
+        }
+    }
+}
+
 impl TypeConstraint {
     /// Order constraints by complexity to help with type unification
     pub fn complexity(&self) -> usize {
@@ -122,6 +160,7 @@ impl TypeConstraint {
             } => {
                 let overridden_field_complexity: usize = fields
                     .values()
+                    .flatten()
                     .map(|constraint| constraint.complexity())
                     .sum();
                 2 + target_type.complexity() + overridden_field_complexity
diff --git a/crates/cli/src/native_query/type_solver/constraint_to_type.rs b/crates/cli/src/native_query/type_solver/constraint_to_type.rs
index bc0d4557..76d3b4dd 100644
--- a/crates/cli/src/native_query/type_solver/constraint_to_type.rs
+++ b/crates/cli/src/native_query/type_solver/constraint_to_type.rs
@@ -1,4 +1,4 @@
-use std::collections::{BTreeMap, HashMap};
+use std::collections::{BTreeMap, HashMap, VecDeque};
 
 use configuration::{
     schema::{ObjectField, ObjectType, Type},
@@ -16,6 +16,12 @@ use TypeConstraint as C;
 
 /// In cases where there is enough information present in one constraint itself to infer a concrete
 /// type, do that. Returns None if there is not enough information present.
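+///
+/// For example, a `FieldOf` constraint with target `Object("movies")` and path `["title"]` is
+/// self-contained: it resolves to the type of the `title` field of the `movies` object type
+/// from configuration.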
+///
+/// TODO: Most of this logic should be moved to `simplify_one`
 pub fn constraint_to_type(
     configuration: &Configuration,
     solutions: &HashMap<TypeVariable, Type>,
@@ -124,8 +130,9 @@ pub fn constraint_to_type(
                 object_type_constraints,
                 target_type,
             )?;
-            let resolved_field_types: Option<BTreeMap<FieldName, Type>> = fields
+            let added_or_replaced_fields: Option<BTreeMap<FieldName, Type>> = fields
                 .iter()
+                .flat_map(|(field_name, option_t)| option_t.as_ref().map(|t| (field_name, t)))
                 .map(|(field_name, t)| {
                     Ok(constraint_to_type(
                         configuration,
@@ -137,15 +144,23 @@ pub fn constraint_to_type(
                     .map(|t| (field_name.clone(), t)))
                 })
                 .collect::<Result<_>>()?;
-            match (resolved_object_type, resolved_field_types) {
-                (Some(object_type), Some(fields)) => with_field_overrides(
+            let subtracted_fields = fields
+                .iter()
+                .filter_map(|(n, option_t)| match option_t {
+                    Some(_) => None,
+                    None => Some(n),
+                })
+                .collect_vec();
+            match (resolved_object_type, added_or_replaced_fields) {
+                (Some(object_type), Some(added_fields)) => with_field_overrides(
                     configuration,
                     solutions,
                     added_object_types,
                     object_type_constraints,
                     object_type,
                     augmented_object_type_name.clone(),
-                    fields,
+                    added_fields,
+                    subtracted_fields,
                 )?,
                 _ => None,
             }
@@ -242,8 +257,8 @@ fn field_of<'a>(
             return Ok(None);
         };
 
-        let mut path_iter = path.into_iter();
-        let Some(field_name) = path_iter.next() else {
+        let mut path: VecDeque<_> = path.into_iter().collect();
+        let Some(field_name) = path.pop_front() else {
             return Ok(Some(Type::Object(type_name)));
         };
 
@@ -256,7 +271,18 @@ fn field_of<'a>(
                 field_name: field_name.clone(),
             })?;
 
-            Ok(Some(field_type.r#type.clone()))
+            if path.is_empty() {
+                Ok(Some(field_type.r#type.clone()))
+            } else {
+                field_of(
+                    configuration,
+                    solutions,
+                    added_object_types,
+                    object_type_constraints,
+                    field_type.r#type.clone(),
+                    path,
+                )
+            }
         }
         Type::Nullable(t) => {
             let underlying_type = field_of(
@@ -274,14 +300,16 @@ fn field_of<'a>(
     Ok(field_type.map(Type::normalize_type))
 }
 
-fn with_field_overrides(
+#[allow(clippy::too_many_arguments)]
+fn with_field_overrides<'a>(
     configuration: &Configuration,
     solutions: &HashMap<TypeVariable, Type>,
     added_object_types: &mut BTreeMap<ObjectTypeName, ObjectType>,
     object_type_constraints: &mut BTreeMap<ObjectTypeName, ObjectTypeConstraint>,
     object_type: Type,
     augmented_object_type_name: ObjectTypeName,
-    fields: impl IntoIterator<Item = (FieldName, Type)>,
+    added_or_replaced_fields: impl IntoIterator<Item = (FieldName, Type)>,
+    subtracted_fields: impl IntoIterator<Item = &'a FieldName>,
 ) -> Result<Option<Type>> {
     let augmented_object_type = match object_type {
         Type::ExtendedJSON => Some(Type::ExtendedJSON),
@@ -297,7 +325,7 @@ fn with_field_overrides(
             return Ok(None);
         };
         let mut new_object_type = object_type.clone();
-        for (field_name, field_type) in fields.into_iter() {
+        for (field_name, field_type) in added_or_replaced_fields.into_iter() {
             new_object_type.fields.insert(
                 field_name,
                 ObjectField {
@@ -306,6 +334,9 @@ fn with_field_overrides(
                 },
             );
         }
+        for field_name in subtracted_fields {
+            new_object_type.fields.remove(field_name);
+        }
         // We might end up back-tracking in which case this will register an object type that
        // isn't referenced. BUT once solving is complete we should get here again with the
        // same augmented_object_type_name, overwrite the old definition with an identical one,
@@ -321,7 +352,8 @@ fn with_field_overrides(
                 object_type_constraints,
                 *t,
                 augmented_object_type_name,
-                fields,
+                added_or_replaced_fields,
+                subtracted_fields,
             )?;
             underlying_type.map(|t| Type::Nullable(Box::new(t)))
         }
diff --git a/crates/cli/src/native_query/type_solver/mod.rs b/crates/cli/src/native_query/type_solver/mod.rs
index 74897ff0..bc7a8f38 100644
--- a/crates/cli/src/native_query/type_solver/mod.rs
+++ b/crates/cli/src/native_query/type_solver/mod.rs
@@ -35,7 +35,24 @@ pub fn unify(
     }
 
     #[cfg(test)]
-    println!("begin unify:\n  type_variables: {type_variables:?}\n  object_type_constraints: {object_type_constraints:?}\n");
+    {
+        println!("begin unify:");
+        println!("  type_variables:");
+        for (var, constraints) in type_variables.iter() {
+            println!(
+                "  - {var}: {}",
+                constraints.iter().map(|c| format!("{c}")).join("; ")
+            );
+        }
+        println!("  object_type_constraints:");
+        for (name, ot) in object_type_constraints.iter() {
+            println!("  {name} ::");
+            for (field_name, field_type) in ot.fields.iter() {
+                println!("  - {field_name}: {field_type}")
+            }
+        }
+        println!();
+    }
 
     loop {
         let prev_type_variables = type_variables.clone();
diff --git a/crates/cli/src/native_query/type_solver/simplify.rs b/crates/cli/src/native_query/type_solver/simplify.rs
index d41d8e0d..436c0972 100644
--- a/crates/cli/src/native_query/type_solver/simplify.rs
+++ b/crates/cli/src/native_query/type_solver/simplify.rs
@@ -130,6 +130,8 @@ fn simplify_constraint_pair(
         (C::ExtendedJSON, b) if variance == Variance::Contravariant => Ok(b),
         (a, C::ExtendedJSON) if variance == Variance::Contravariant => Ok(a),
 
+        // TODO: If we don't get a solution from solve_scalar, and the variable is covariant, we
+        // want to make a union type
         (C::Scalar(a), C::Scalar(b)) => solve_scalar(variance, a, b),
 
         (C::Union(mut a), C::Union(mut b)) if variance == Variance::Covariant => {
@@ -498,10 +500,14 @@ fn get_object_constraint_field_type(
 
 #[cfg(test)]
 mod tests {
+    use std::collections::BTreeSet;
+
     use googletest::prelude::*;
     use mongodb_support::BsonScalarType;
+    use nonempty::nonempty;
+    use test_helpers::configuration::mflix_config;
 
-    use crate::native_query::type_constraint::{TypeConstraint, Variance};
+    use crate::native_query::type_constraint::{TypeConstraint, TypeVariable, Variance};
 
     #[googletest::test]
     fn multiple_identical_scalar_constraints_resolve_one_constraint() {
@@ -546,4 +552,26 @@ mod tests {
             Ok(TypeConstraint::Scalar(BsonScalarType::Int))
         );
     }
+
+    #[googletest::test]
+    fn simplifies_field_of() -> Result<()> {
+        let config = mflix_config();
+        let result = super::simplify_constraints(
+            &config,
+            &Default::default(),
+            &mut Default::default(),
+            Some(TypeVariable::new(1, Variance::Covariant)),
+            [TypeConstraint::FieldOf {
+                target_type: Box::new(TypeConstraint::Object("movies".into())),
+                path: nonempty!["title".into()],
+            }],
+        );
+        expect_that!(
+            result,
+            matches_pattern!(Ok(&BTreeSet::from_iter([TypeConstraint::Scalar(
+                BsonScalarType::String
+            )])))
+        );
+        Ok(())
+    }
 }
diff --git a/crates/mongodb-agent-common/src/query/pipeline.rs b/crates/mongodb-agent-common/src/query/pipeline.rs
index a831d923..9a515f37 100644
--- a/crates/mongodb-agent-common/src/query/pipeline.rs
+++ b/crates/mongodb-agent-common/src/query/pipeline.rs
@@ -78,7 +78,7 @@ pub fn pipeline_for_non_foreach(
         .map(make_sort_stages)
         .flatten_ok()
         .collect::<Result<Vec<_>, _>>()?;
-    let skip_stage = offset.map(Stage::Skip);
+    let skip_stage = offset.map(Into::into).map(Stage::Skip);
 
     match_stage
         .into_iter()
@@ -132,7 +132,7 @@ pub fn pipeline_for_fields_facet(
         }
     }
 
-    let limit_stage = limit.map(Stage::Limit);
+    let limit_stage = limit.map(Into::into).map(Stage::Limit);
     let replace_with_stage: Stage = Stage::ReplaceWith(selection);
 
     Ok(Pipeline::from_iter(
@@ -245,7 +245,7 @@ fn pipeline_for_aggregate(
             Some(Stage::Match(
                 bson::doc! { column.as_str(): { "$exists": true, "$ne": null } },
             )),
-            limit.map(Stage::Limit),
+            limit.map(Into::into).map(Stage::Limit),
             Some(Stage::Group {
                 key_expression: field_ref(column.as_str()),
                 accumulators: [].into(),
@@ -261,7 +261,7 @@ fn pipeline_for_aggregate(
             Some(Stage::Match(
                 bson::doc! { column.as_str(): { "$exists": true, "$ne": null } },
            )),
-            limit.map(Stage::Limit),
+            limit.map(Into::into).map(Stage::Limit),
             Some(Stage::Count(RESULT_FIELD.to_string())),
         ]
         .into_iter()
@@ -285,7 +285,7 @@ fn pipeline_for_aggregate(
             Some(Stage::Match(
                 bson::doc! { column: { "$exists": true, "$ne": null } },
             )),
-            limit.map(Stage::Limit),
+            limit.map(Into::into).map(Stage::Limit),
             Some(Stage::Group {
                 key_expression: Bson::Null,
                 accumulators: [(RESULT_FIELD.to_string(), accumulator)].into(),
@@ -298,7 +298,7 @@ fn pipeline_for_aggregate(
 
         Aggregate::StarCount {} => Pipeline::from_iter(
             [
-                limit.map(Stage::Limit),
+                limit.map(Into::into).map(Stage::Limit),
                 Some(Stage::Count(RESULT_FIELD.to_string())),
             ]
             .into_iter()
diff --git a/crates/mongodb-agent-common/src/query/relations.rs b/crates/mongodb-agent-common/src/query/relations.rs
index 4018f4c8..44efcc6f 100644
--- a/crates/mongodb-agent-common/src/query/relations.rs
+++ b/crates/mongodb-agent-common/src/query/relations.rs
@@ -855,7 +855,7 @@ mod tests {
                     }
                 },
                 {
-                    "$limit": Bson::Int64(50),
+                    "$limit": Bson::Int32(50),
                 },
                 {
                     "$replaceWith": {
@@ -975,7 +975,7 @@ mod tests {
                     }
                 },
                 {
-                    "$limit": Bson::Int64(50),
+                    "$limit": Bson::Int32(50),
                 },
                 {
                     "$replaceWith": {
diff --git a/crates/mongodb-support/src/aggregate/stage.rs b/crates/mongodb-support/src/aggregate/stage.rs
index 3b45630b..76ee4e93 100644
--- a/crates/mongodb-support/src/aggregate/stage.rs
+++ b/crates/mongodb-support/src/aggregate/stage.rs
@@ -1,6 +1,6 @@
 use std::collections::BTreeMap;
 
-use mongodb::bson;
+use mongodb::bson::{self, Bson};
 use serde::{Deserialize, Serialize};
 
 use super::{Accumulator, Pipeline, Selection, SortDocument};
@@ -50,7 +50,7 @@ pub enum Stage {
     ///
     /// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/limit/#mongodb-pipeline-pipe.-limit
     #[serde(rename = "$limit")]
-    Limit(u32),
+    Limit(Bson),
 
     /// Performs a left outer join to another collection in the same database to filter in
     /// documents from the "joined" collection for processing.
@@ -114,7 +114,7 @@ pub enum Stage {
     ///
     /// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/skip/#mongodb-pipeline-pipe.-skip
     #[serde(rename = "$skip")]
-    Skip(u32),
+    Skip(Bson),
 
     /// Groups input documents by a specified identifier expression and applies the accumulator
     /// expression(s), if specified, to each group. Consumes all input documents and outputs one
@@ -152,6 +152,25 @@ pub enum Stage {
     #[serde(rename = "$count")]
     Count(String),
 
+    /// Reshapes each document in the stream, such as by adding new fields or removing existing
+    /// fields. For each input document, outputs one document.
+    ///
+    /// See also $unset for removing existing fields.
+    ///
+    /// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/project/#mongodb-pipeline-pipe.-project
+    #[serde(rename = "$project")]
+    Project(bson::Document),
+
+    /// Replaces the input document with the specified document. The operation replaces all
+    /// existing fields in the input document, including the _id field. Specify a document
+    /// embedded in the input document to promote the embedded document to the top level.
+    ///
+    /// $replaceWith is an alias for the $replaceRoot stage.
+    ///
+    /// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/replaceRoot/#mongodb-pipeline-pipe.-replaceRoot
+    #[serde(rename = "$replaceRoot", rename_all = "camelCase")]
+    ReplaceRoot { new_root: Selection },
+
     /// Replaces a document with the specified embedded document. The operation replaces all
     /// existing fields in the input document, including the _id field. Specify a document embedded
     /// in the input document to promote the embedded document to the top level.
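
For illustration, a minimal sketch (not part of the diff) of how the reworked stage variants serialize, assuming the `Stage` enum above and the `mongodb::bson` re-export used throughout this PR. It shows why `$limit` and `$skip` now carry `Bson` instead of `u32`: whatever integer width the caller picks (for example `Bson::Int32(50)` in the updated `relations.rs` tests) is preserved in the emitted pipeline document.

    use mongodb::bson::{self, doc, Bson};
    use mongodb_support::aggregate::Stage;

    fn example() {
        // A $project document passes through unchanged under the new Project variant.
        let project = Stage::Project(doc! { "title": 1, "releaseDate": "$released" });
        assert_eq!(
            bson::to_document(&project).unwrap(),
            doc! { "$project": { "title": 1, "releaseDate": "$released" } },
        );

        // Limit now wraps Bson, so the caller's integer width survives serialization.
        let limit = Stage::Limit(Bson::Int32(50));
        assert_eq!(bson::to_document(&limit).unwrap(), doc! { "$limit": 50 });
    }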