diff --git a/huma.go b/huma.go
index 7c20f304..d4952c8b 100644
--- a/huma.go
+++ b/huma.go
@@ -365,14 +365,12 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
 	}
 	inputParams := findParams(registry, &op, inputType)
 	inputBodyIndex := -1
-	var inSchema *Schema
 	if f, ok := inputType.FieldByName("Body"); ok {
 		inputBodyIndex = f.Index[0]
-		inSchema = registry.Schema(f.Type, true, getHint(inputType, f.Name, op.OperationID+"Request"))
 		op.RequestBody = &RequestBody{
 			Content: map[string]*MediaType{
 				"application/json": {
-					Schema: inSchema,
+					Schema: registry.Schema(f.Type, true, getHint(inputType, f.Name, op.OperationID+"Request")),
 				},
 			},
 		}
@@ -391,6 +389,16 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
 	if f, ok := inputType.FieldByName("RawBody"); ok {
 		rawBodyIndex = f.Index[0]
 	}
+
+	// The request body schema may have been generated from the `Body` field
+	// above or supplied directly by the caller via `op.RequestBody`.
+	var inSchema *Schema
+	if op.RequestBody != nil {
+		if mt := op.RequestBody.Content["application/json"]; mt != nil {
+			inSchema = mt.Schema
+		}
+	}
+
 	resolvers := findResolvers(resolverType, inputType)
 	defaults := findDefaults(inputType)
 
@@ -628,7 +636,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
 		})
 
 		// Read input body if defined.
-		if inputBodyIndex != -1 {
+		if inputBodyIndex != -1 || rawBodyIndex != -1 {
 			if op.BodyReadTimeout > 0 {
 				ctx.SetReadDeadline(time.Now().Add(op.BodyReadTimeout))
 			} else if op.BodyReadTimeout < 0 {
@@ -676,7 +684,10 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
 			}
 
 			if len(body) == 0 {
-				kind := v.Field(inputBodyIndex).Kind()
+				kind := reflect.Slice // []byte by default for raw body
+				if inputBodyIndex != -1 {
+					kind = v.Field(inputBodyIndex).Kind()
+				}
 				if kind != reflect.Ptr && kind != reflect.Interface {
 					buf.Reset()
 					bufPool.Put(buf)
@@ -711,28 +722,30 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
 				}
 			}
 
-			// We need to get the body into the correct type now that it has been
-			// validated. Benchmarks on Go 1.20 show that using `json.Unmarshal` a
-			// second time is faster than `mapstructure.Decode` or any of the other
-			// common reflection-based approaches when using real-world medium-sized
-			// JSON payloads with lots of strings.
-			f := v.Field(inputBodyIndex)
-			if err := api.Unmarshal(ctx.Header("Content-Type"), body, f.Addr().Interface()); err != nil {
-				if parseErrCount == 0 {
-					// Hmm, this should have worked... validator missed something?
-					res.Errors = append(res.Errors, &ErrorDetail{
-						Location: "body",
-						Message:  err.Error(),
-						Value:    string(body),
+			if inputBodyIndex != -1 {
+				// We need to get the body into the correct type now that it has been
+				// validated. Benchmarks on Go 1.20 show that using `json.Unmarshal` a
+				// second time is faster than `mapstructure.Decode` or any of the other
+				// common reflection-based approaches when using real-world medium-sized
+				// JSON payloads with lots of strings.
+				f := v.Field(inputBodyIndex)
+				if err := api.Unmarshal(ctx.Header("Content-Type"), body, f.Addr().Interface()); err != nil {
+					if parseErrCount == 0 {
+						// Hmm, this should have worked... validator missed something?
+						res.Errors = append(res.Errors, &ErrorDetail{
+							Location: "body",
+							Message:  err.Error(),
+							Value:    string(body),
+						})
+					}
+				} else {
+					// Set defaults for any fields that were not in the input.
+					defaults.Every(v, func(item reflect.Value, def any) {
+						if item.IsZero() {
+							item.Set(reflect.Indirect(reflect.ValueOf(def)))
+						}
 					})
 				}
-			} else {
-				// Set defaults for any fields that were not in the input.
-				defaults.Every(v, func(item reflect.Value, def any) {
-					if item.IsZero() {
-						item.Set(reflect.Indirect(reflect.ValueOf(def)))
-					}
-				})
 			}
 
 			buf.Reset()
diff --git a/huma_test.go b/huma_test.go
index 07c47b6b..efef9582 100644
--- a/huma_test.go
+++ b/huma_test.go
@@ -434,6 +434,64 @@ func TestFeatures(t *testing.T) {
 				assert.Equal(t, 256, resp.Code)
 			},
 		},
+		{
+			Name: "one-of input",
+			Register: func(t *testing.T, api API) {
+				// Step 1: create a custom schema
+				customSchema := &Schema{
+					OneOf: []*Schema{
+						{
+							Type: TypeObject,
+							Properties: map[string]*Schema{
+								"foo": {Type: TypeString},
+							},
+						},
+						{
+							Type: TypeArray,
+							Items: &Schema{
+								Type: TypeObject,
+								Properties: map[string]*Schema{
+									"foo": {Type: TypeString},
+								},
+							},
+						},
+					},
+				}
+				customSchema.PrecomputeMessages()
+
+				Register(api, Operation{
+					Method: http.MethodPut,
+					Path:   "/one-of",
+					// Step 2: register an operation with a custom schema
+					RequestBody: &RequestBody{
+						Required: true,
+						Content: map[string]*MediaType{
+							"application/json": {
+								Schema: customSchema,
+							},
+						},
+					},
+				}, func(ctx context.Context, input *struct {
+					// Step 3: only take in raw bytes
+					RawBody []byte
+				}) (*struct{}, error) {
+					// Step 4: determine which it is and parse it into the right type.
+					// We will check the first byte but there are other ways to do this.
+					assert.EqualValues(t, '[', input.RawBody[0])
+					var parsed []struct {
+						Foo string `json:"foo"`
+					}
+					assert.NoError(t, json.Unmarshal(input.RawBody, &parsed))
+					assert.Len(t, parsed, 2)
+					assert.Equal(t, "first", parsed[0].Foo)
+					assert.Equal(t, "second", parsed[1].Foo)
+					return nil, nil
+				})
+			},
+			Method: http.MethodPut,
+			URL:    "/one-of",
+			Body:   `[{"foo": "first"}, {"foo": "second"}]`,
+		},
 	} {
 		t.Run(feature.Name, func(t *testing.T) {
 			r := chi.NewRouter()
diff --git a/schema.go b/schema.go
index 7d886ca1..31823090 100644
--- a/schema.go
+++ b/schema.go
@@ -88,6 +88,11 @@ type Schema struct {
 	Deprecated           bool                `yaml:"deprecated,omitempty"`
 	Extensions           map[string]any      `yaml:",inline"`
 
+	OneOf []*Schema `yaml:"oneOf,omitempty"`
+	AnyOf []*Schema `yaml:"anyOf,omitempty"`
+	AllOf []*Schema `yaml:"allOf,omitempty"`
+	Not   *Schema   `yaml:"not,omitempty"`
+
 	patternRe     *regexp.Regexp  `yaml:"-"`
 	requiredMap   map[string]bool `yaml:"-"`
 	propertyNames []string        `yaml:"-"`
@@ -162,6 +167,22 @@ func (s *Schema) PrecomputeMessages() {
 			s.msgRequired[name] = "expected required property " + name + " to be present"
 		}
 	}
+
+	for _, sub := range s.OneOf {
+		sub.PrecomputeMessages()
+	}
+
+	for _, sub := range s.AnyOf {
+		sub.PrecomputeMessages()
+	}
+
+	for _, sub := range s.AllOf {
+		sub.PrecomputeMessages()
+	}
+
+	if sub := s.Not; sub != nil {
+		sub.PrecomputeMessages()
+	}
 }
 
 // MarshalJSON marshals the schema into JSON, respecting the `Extensions` map
diff --git a/validate.go b/validate.go
index 78f9c5c1..fd21c6b7 100644
--- a/validate.go
+++ b/validate.go
@@ -262,6 +262,45 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult
 	}
 }
 
+// validateOneOf checks that v matches exactly one of the schemas in s.OneOf.
+// Errors from the individual sub-schemas are discarded; a single summary
+// error is added to res when zero or more than one sub-schema matches.
+func validateOneOf(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, res *ValidateResult) {
+	matches := 0
+	subRes := &ValidateResult{}
+	for _, sub := range s.OneOf {
+		Validate(r, sub, path, mode, v, subRes)
+		if len(subRes.Errors) == 0 {
+			matches++
+		}
+		subRes.Reset()
+	}
+
+	switch {
+	case matches == 0:
+		res.Add(path, v, "expected value to match exactly one schema but matched none")
+	case matches > 1:
+		res.Add(path, v, "expected value to match exactly one schema but matched multiple")
+	}
+}
+
+// validateAnyOf checks that v matches at least one of the schemas in
+// s.AnyOf, stopping at the first match. Errors from the individual
+// sub-schemas are discarded; a single summary error is added to res when
+// nothing matches.
+func validateAnyOf(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, res *ValidateResult) {
+	subRes := &ValidateResult{}
+	for _, sub := range s.AnyOf {
+		Validate(r, sub, path, mode, v, subRes)
+		if len(subRes.Errors) == 0 {
+			return
+		}
+		subRes.Reset()
+	}
+
+	res.Add(path, v, "expected value to match at least one schema but matched none")
+}
+
 // Validate an input value against a schema, collecting errors in the validation
 // result object. If successful, `res.Errors` will be empty. It is suggested
 // to use a `sync.Pool` to reuse the PathBuffer and ValidateResult objects,
@@ -284,6 +323,29 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
 		s = r.SchemaFromRef(s.Ref)
 	}
 
+	// Schema composition keywords. Each one is checked in addition to any
+	// type-specific validation below.
+	if s.OneOf != nil {
+		validateOneOf(r, s, path, mode, v, res)
+	}
+
+	if s.AnyOf != nil {
+		validateAnyOf(r, s, path, mode, v, res)
+	}
+
+	// Every allOf sub-schema must pass, so its errors are reported directly.
+	for _, sub := range s.AllOf {
+		Validate(r, sub, path, mode, v, res)
+	}
+
+	if s.Not != nil {
+		subRes := &ValidateResult{}
+		Validate(r, s.Not, path, mode, v, subRes)
+		if len(subRes.Errors) == 0 {
+			res.Add(path, v, "expected value to not match schema")
+		}
+	}
+
 	switch s.Type {
 	case TypeBoolean:
 		if _, ok := v.(bool); !ok {
diff --git a/validate_test.go b/validate_test.go
index c05988c4..15e575a8 100644
--- a/validate_test.go
+++ b/validate_test.go
@@ -11,9 +11,14 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
+func Ptr[T any](v T) *T {
+	return &v
+}
+
 var validateTests = []struct {
 	name  string
 	typ   reflect.Type
+	s     *Schema
 	input any
 	mode  ValidateMode
 	errs  []string
@@ -727,6 +732,105 @@ var validateTests = []struct {
 		input: map[string]any{"value": ""},
 		errs:  []string{"expected length >= 1"},
 	},
+	{
+		name: "oneOf success bool",
+		s: &Schema{
+			OneOf: []*Schema{
+				{Type: TypeBoolean},
+				{Type: TypeString},
+			},
+		},
+		input: true,
+	},
+	{
+		name: "oneOf success string",
+		s: &Schema{
+			OneOf: []*Schema{
+				{Type: TypeBoolean},
+				{Type: TypeString},
+			},
+		},
+		input: "hello",
+	},
+	{
+		name: "oneOf fail zero",
+		s: &Schema{
+			OneOf: []*Schema{
+				{Type: TypeBoolean},
+				{Type: TypeString},
+			},
+		},
+		input: 123,
+		errs:  []string{"expected value to match exactly one schema but matched none"},
+	},
+	{
+		name: "oneOf fail multi",
+		s: &Schema{
+			OneOf: []*Schema{
+				{Type: TypeNumber, Minimum: Ptr(float64(5))},
+				{Type: TypeNumber, Maximum: Ptr(float64(10))},
+			},
+		},
+		input: 8,
+		errs:  []string{"expected value to match exactly one schema but matched multiple"},
+	},
+	{
+		name: "anyOf success",
+		s: &Schema{
+			AnyOf: []*Schema{
+				{Type: TypeNumber, Minimum: Ptr(float64(5))},
+				{Type: TypeNumber, Maximum: Ptr(float64(10))},
+			},
+		},
+		input: 8,
+	},
+	{
+		name: "anyOf fail",
+		s: &Schema{
+			AnyOf: []*Schema{
+				{Type: TypeNumber, Minimum: Ptr(float64(5))},
+				{Type: TypeNumber, Minimum: Ptr(float64(10))},
+			},
+		},
+		input: 1,
+		errs:  []string{"expected value to match at least one schema but matched none"},
+	},
+	{
+		name: "allOf success",
+		s: &Schema{
+			AllOf: []*Schema{
+				{Type: TypeNumber, Minimum: Ptr(float64(5))},
+				{Type: TypeNumber, Maximum: Ptr(float64(10))},
+			},
+		},
+		input: 8,
+	},
+	{
+		name: "allOf fail",
+		s: &Schema{
+			AllOf: []*Schema{
+				{Type: TypeNumber, Minimum: Ptr(float64(5))},
+				{Type: TypeNumber, Maximum: Ptr(float64(10))},
+			},
+		},
+		input: 12,
+		errs:  []string{"expected number <= 10"},
+	},
+	{
+		name: "not success",
+		s: &Schema{
+			Not: &Schema{Type: TypeNumber},
+		},
+		input: "hello",
+	},
+	{
+		name: "not fail",
+		s: &Schema{
+			Not: &Schema{Type: TypeNumber},
+		},
+		input: 5,
+		errs:  []string{"expected value to not match schema"},
+	},
 }
 
 func TestValidate(t *testing.T) {
@@ -744,7 +848,10 @@ func TestValidate(t *testing.T) {
 				})
 				return
-			} else {
-				s = registry.Schema(test.typ, false, "TestInput")
+			} else if test.s != nil {
+				s = test.s
+				s.PrecomputeMessages()
+			} else {
+				s = registry.Schema(test.typ, false, "TestInput")
 			}
 
 			pb.Reset()