这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 34 additions & 25 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,14 +365,12 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}
inputParams := findParams(registry, &op, inputType)
inputBodyIndex := -1
var inSchema *Schema
if f, ok := inputType.FieldByName("Body"); ok {
inputBodyIndex = f.Index[0]
inSchema = registry.Schema(f.Type, true, getHint(inputType, f.Name, op.OperationID+"Request"))
op.RequestBody = &RequestBody{
Content: map[string]*MediaType{
"application/json": {
Schema: inSchema,
Schema: registry.Schema(f.Type, true, getHint(inputType, f.Name, op.OperationID+"Request")),
},
},
}
Expand All @@ -391,6 +389,12 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
if f, ok := inputType.FieldByName("RawBody"); ok {
rawBodyIndex = f.Index[0]
}

var inSchema *Schema
if op.RequestBody != nil && op.RequestBody.Content != nil && op.RequestBody.Content["application/json"] != nil && op.RequestBody.Content["application/json"].Schema != nil {
inSchema = op.RequestBody.Content["application/json"].Schema
}

resolvers := findResolvers(resolverType, inputType)
defaults := findDefaults(inputType)

Expand Down Expand Up @@ -628,7 +632,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
})

// Read input body if defined.
if inputBodyIndex != -1 {
if inputBodyIndex != -1 || rawBodyIndex != -1 {
if op.BodyReadTimeout > 0 {
ctx.SetReadDeadline(time.Now().Add(op.BodyReadTimeout))
} else if op.BodyReadTimeout < 0 {
Expand Down Expand Up @@ -676,7 +680,10 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}

if len(body) == 0 {
kind := v.Field(inputBodyIndex).Kind()
kind := reflect.Slice // []byte by default for raw body
if inputBodyIndex != -1 {
kind = v.Field(inputBodyIndex).Kind()
}
if kind != reflect.Ptr && kind != reflect.Interface {
buf.Reset()
bufPool.Put(buf)
Expand Down Expand Up @@ -711,28 +718,30 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}
}

// We need to get the body into the correct type now that it has been
// validated. Benchmarks on Go 1.20 show that using `json.Unmarshal` a
// second time is faster than `mapstructure.Decode` or any of the other
// common reflection-based approaches when using real-world medium-sized
// JSON payloads with lots of strings.
f := v.Field(inputBodyIndex)
if err := api.Unmarshal(ctx.Header("Content-Type"), body, f.Addr().Interface()); err != nil {
if parseErrCount == 0 {
// Hmm, this should have worked... validator missed something?
res.Errors = append(res.Errors, &ErrorDetail{
Location: "body",
Message: err.Error(),
Value: string(body),
if inputBodyIndex != -1 {
// We need to get the body into the correct type now that it has been
// validated. Benchmarks on Go 1.20 show that using `json.Unmarshal` a
// second time is faster than `mapstructure.Decode` or any of the other
// common reflection-based approaches when using real-world medium-sized
// JSON payloads with lots of strings.
f := v.Field(inputBodyIndex)
if err := api.Unmarshal(ctx.Header("Content-Type"), body, f.Addr().Interface()); err != nil {
if parseErrCount == 0 {
// Hmm, this should have worked... validator missed something?
res.Errors = append(res.Errors, &ErrorDetail{
Location: "body",
Message: err.Error(),
Value: string(body),
})
}
} else {
// Set defaults for any fields that were not in the input.
defaults.Every(v, func(item reflect.Value, def any) {
if item.IsZero() {
item.Set(reflect.Indirect(reflect.ValueOf(def)))
}
})
}
} else {
// Set defaults for any fields that were not in the input.
defaults.Every(v, func(item reflect.Value, def any) {
if item.IsZero() {
item.Set(reflect.Indirect(reflect.ValueOf(def)))
}
})
}

buf.Reset()
Expand Down
58 changes: 58 additions & 0 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,64 @@ func TestFeatures(t *testing.T) {
assert.Equal(t, 256, resp.Code)
},
},
{
Name: "one-of input",
Register: func(t *testing.T, api API) {
// Step 1: create a custom schema
customSchema := &Schema{
OneOf: []*Schema{
{
Type: TypeObject,
Properties: map[string]*Schema{
"foo": {Type: TypeString},
},
},
{
Type: TypeArray,
Items: &Schema{
Type: TypeObject,
Properties: map[string]*Schema{
"foo": {Type: TypeString},
},
},
},
},
}
customSchema.PrecomputeMessages()

Register(api, Operation{
Method: http.MethodPut,
Path: "/one-of",
// Step 2: register an operation with a custom schema
RequestBody: &RequestBody{
Required: true,
Content: map[string]*MediaType{
"application/json": {
Schema: customSchema,
},
},
},
}, func(ctx context.Context, input *struct {
// Step 3: only take in raw bytes
RawBody []byte
}) (*struct{}, error) {
// Step 4: determine which it is and parse it into the right type.
// We will check the first byte but there are other ways to do this.
assert.EqualValues(t, '[', input.RawBody[0])
var parsed []struct {
Foo string `json:"foo"`
}
assert.NoError(t, json.Unmarshal(input.RawBody, &parsed))
assert.Len(t, parsed, 2)
assert.Equal(t, "first", parsed[0].Foo)
assert.Equal(t, "second", parsed[1].Foo)
return nil, nil
})
},
Method: http.MethodPut,
URL: "/one-of",
Body: `[{"foo": "first"}, {"foo": "second"}]`,
},
} {
t.Run(feature.Name, func(t *testing.T) {
r := chi.NewRouter()
Expand Down
21 changes: 21 additions & 0 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ type Schema struct {
Deprecated bool `yaml:"deprecated,omitempty"`
Extensions map[string]any `yaml:",inline"`

OneOf []*Schema `yaml:"oneOf,omitempty"`
AnyOf []*Schema `yaml:"anyOf,omitempty"`
AllOf []*Schema `yaml:"allOf,omitempty"`
Not *Schema `yaml:"not,omitempty"`

patternRe *regexp.Regexp `yaml:"-"`
requiredMap map[string]bool `yaml:"-"`
propertyNames []string `yaml:"-"`
Expand Down Expand Up @@ -162,6 +167,22 @@ func (s *Schema) PrecomputeMessages() {
s.msgRequired[name] = "expected required property " + name + " to be present"
}
}

for _, sub := range s.OneOf {
sub.PrecomputeMessages()
}

for _, sub := range s.AnyOf {
sub.PrecomputeMessages()
}

for _, sub := range s.AllOf {
sub.PrecomputeMessages()
}

if sub := s.Not; sub != nil {
sub.PrecomputeMessages()
}
}

// MarshalJSON marshals the schema into JSON, respecting the `Extensions` map
Expand Down
56 changes: 56 additions & 0 deletions validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,40 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult
}
}

// validateOneOf checks that v validates against exactly one of the
// sub-schemas in s.OneOf.
//
// Each sub-schema is validated into a scratch result (subRes), so the
// individual sub-schema errors are never added to res. Instead, a single
// summary error is added to res when no sub-schema matches or when more
// than one does.
func validateOneOf(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, res *ValidateResult) {
	matches := 0
	subRes := &ValidateResult{}
	for _, sub := range s.OneOf {
		Validate(r, sub, path, mode, v, subRes)
		if len(subRes.Errors) == 0 {
			matches++
		}
		subRes.Reset()
		if matches > 1 {
			// The outcome is already decided. Stop here so the "matched
			// multiple" error is reported once rather than once per extra match.
			break
		}
	}

	switch matches {
	case 0:
		res.Add(path, v, "expected value to match exactly one schema but matched none")
	case 1:
		// Exactly one sub-schema matched: the value is valid.
	default:
		res.Add(path, v, "expected value to match exactly one schema but matched multiple")
	}
}

// validateAnyOf checks that v validates against at least one of the
// sub-schemas in s.AnyOf.
//
// Each sub-schema is validated into a scratch result (subRes) whose errors
// are discarded. If no sub-schema matches (including when s.AnyOf is
// empty), a single summary error is added to res.
func validateAnyOf(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, res *ValidateResult) {
	subRes := &ValidateResult{}
	for _, sub := range s.AnyOf {
		Validate(r, sub, path, mode, v, subRes)
		if len(subRes.Errors) == 0 {
			// One match satisfies anyOf, so there is no need to validate
			// the remaining sub-schemas.
			return
		}
		subRes.Reset()
	}
	res.Add(path, v, "expected value to match at least one schema but matched none")
}

// Validate an input value against a schema, collecting errors in the validation
// result object. If successful, `res.Errors` will be empty. It is suggested
// to use a `sync.Pool` to reuse the PathBuffer and ValidateResult objects,
Expand All @@ -284,6 +318,28 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
s = r.SchemaFromRef(s.Ref)
}

if s.OneOf != nil {
validateOneOf(r, s, path, mode, v, res)
}

if s.AnyOf != nil {
validateAnyOf(r, s, path, mode, v, res)
}

if s.AllOf != nil {
for _, sub := range s.AllOf {
Validate(r, sub, path, mode, v, res)
}
}

if s.Not != nil {
subRes := &ValidateResult{}
Validate(r, s.Not, path, mode, v, subRes)
if len(subRes.Errors) == 0 {
res.Add(path, v, "expected value to not match schema")
}
}

switch s.Type {
case TypeBoolean:
if _, ok := v.(bool); !ok {
Expand Down
Loading