diff --git a/formdata.go b/formdata.go index f153348c..b170d83d 100644 --- a/formdata.go +++ b/formdata.go @@ -85,30 +85,51 @@ func (v MimeTypeValidator) Validate(fh *multipart.FileHeader, location string) ( } } -func (m *MultipartFormFiles[T]) readFile( - fh *multipart.FileHeader, - location string, - validator MimeTypeValidator, -) (FormFile, *ErrorDetail) { - f, err := fh.Open() - if err != nil { - return FormFile{}, &ErrorDetail{Message: "Failed to open file", Location: location} - } - contentType, validationErr := validator.Validate(fh, location) - if validationErr != nil { - return FormFile{}, validationErr +func (m *MultipartFormFiles[T]) Data() *T { + return m.data +} + +// Decodes multipart.Form data into *T, returning []*ErrorDetail if any +// Schema is used to check for validation constraints +func (m *MultipartFormFiles[T]) Decode(opMediaType *MediaType) []error { + var ( + dataType = reflect.TypeOf(m.data).Elem() + value = reflect.New(dataType) + errors []error + ) + for i := 0; i < dataType.NumField(); i++ { + field := value.Elem().Field(i) + structField := dataType.Field(i) + key := structField.Tag.Get("form") + if key == "" { + key = structField.Name + } + fileHeaders := m.Form.File[key] + switch { + case field.Type() == reflect.TypeOf(FormFile{}): + file, err := readSingleFile(fileHeaders, key, opMediaType) + if err != nil { + errors = append(errors, err) + continue + } + field.Set(reflect.ValueOf(file)) + case field.Type() == reflect.TypeOf([]FormFile{}): + files, errs := readMultipleFiles(fileHeaders, key, opMediaType) + if errs != nil { + errors = append(errors, errs...) 
+ continue + } + field.Set(reflect.ValueOf(files)) + + default: + continue + } } - return FormFile{ - File: f, - ContentType: contentType, - IsSet: true, - Size: fh.Size, - Filename: fh.Filename, - }, nil + m.data = value.Interface().(*T) + return errors } -func (m *MultipartFormFiles[T]) readSingleFile(key string, opMediaType *MediaType) (FormFile, *ErrorDetail) { - fileHeaders := m.Form.File[key] +func readSingleFile(fileHeaders []*multipart.FileHeader, key string, opMediaType *MediaType) (FormFile, *ErrorDetail) { if len(fileHeaders) == 0 { if opMediaType.Schema.requiredMap[key] { return FormFile{}, &ErrorDetail{Message: "File required", Location: key} @@ -117,7 +138,7 @@ func (m *MultipartFormFiles[T]) readSingleFile(key string, opMediaType *MediaTyp } } else if len(fileHeaders) == 1 { validator := NewMimeTypeValidator(opMediaType.Encoding[key]) - return m.readFile(fileHeaders[0], key, validator) + return readFile(fileHeaders[0], key, validator) } return FormFile{}, &ErrorDetail{ Message: "Multiple files received but only one was expected", @@ -125,8 +146,7 @@ func (m *MultipartFormFiles[T]) readSingleFile(key string, opMediaType *MediaTyp } } -func (m *MultipartFormFiles[T]) readMultipleFiles(key string, opMediaType *MediaType) ([]FormFile, []error) { - fileHeaders := m.Form.File[key] +func readMultipleFiles(fileHeaders []*multipart.FileHeader, key string, opMediaType *MediaType) ([]FormFile, []error) { var ( files = make([]FormFile, len(fileHeaders)) errors []error @@ -136,7 +156,7 @@ func (m *MultipartFormFiles[T]) readMultipleFiles(key string, opMediaType *Media } validator := NewMimeTypeValidator(opMediaType.Encoding[key]) for i, fh := range fileHeaders { - file, err := m.readFile( + file, err := readFile( fh, fmt.Sprintf("%s[%d]", key, i), validator, @@ -150,47 +170,26 @@ func (m *MultipartFormFiles[T]) readMultipleFiles(key string, opMediaType *Media return files, errors } -func (m *MultipartFormFiles[T]) Data() *T { - return m.data -} - -// Decodes 
multipart.Form data into *T, returning []*ErrorDetail if any -// Schema is used to check for validation constraints -func (m *MultipartFormFiles[T]) Decode(opMediaType *MediaType) []error { - var ( - dataType = reflect.TypeOf(m.data).Elem() - value = reflect.New(dataType) - errors []error - ) - for i := 0; i < dataType.NumField(); i++ { - field := value.Elem().Field(i) - structField := dataType.Field(i) - key := structField.Tag.Get("form") - if key == "" { - key = structField.Name - } - switch { - case field.Type() == reflect.TypeOf(FormFile{}): - file, err := m.readSingleFile(key, opMediaType) - if err != nil { - errors = append(errors, err) - continue - } - field.Set(reflect.ValueOf(file)) - case field.Type() == reflect.TypeOf([]FormFile{}): - files, errs := m.readMultipleFiles(key, opMediaType) - if errs != nil { - errors = append(errors, errs...) - continue - } - field.Set(reflect.ValueOf(files)) - - default: - continue - } +func readFile( + fh *multipart.FileHeader, + location string, + validator MimeTypeValidator, +) (FormFile, *ErrorDetail) { + f, err := fh.Open() + if err != nil { + return FormFile{}, &ErrorDetail{Message: "Failed to open file", Location: location} } - m.data = value.Interface().(*T) - return errors + contentType, validationErr := validator.Validate(fh, location) + if validationErr != nil { + return FormFile{}, validationErr + } + return FormFile{ + File: f, + ContentType: contentType, + IsSet: true, + Size: fh.Size, + Filename: fh.Filename, + }, nil } func formDataFieldName(f reflect.StructField) string { @@ -208,7 +207,7 @@ func multiPartFormFileSchema(t reflect.Type) *Schema { Properties: make(map[string]*Schema, nFields), requiredMap: make(map[string]bool, nFields), } - requiredFields := make([]string, nFields) + requiredFields := make([]string, 0, nFields) for i := 0; i < nFields; i++ { f := t.Field(i) name := formDataFieldName(f) @@ -227,7 +226,7 @@ func multiPartFormFileSchema(t reflect.Type) *Schema { } if _, ok := 
f.Tag.Lookup("required"); ok && boolTag(f, "required", false) { - requiredFields[i] = name + requiredFields = append(requiredFields, name) schema.requiredMap[name] = true } } diff --git a/huma.go b/huma.go index eb0e5457..9de2165a 100644 --- a/huma.go +++ b/huma.go @@ -611,293 +611,35 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) if op.Method == "" || op.Path == "" { panic("method and path must be specified in operation") } + initResponses(&op) inputType := reflect.TypeOf((*I)(nil)).Elem() if inputType.Kind() != reflect.Struct { panic("input must be a struct") } - inputParams := findParams(registry, &op, inputType) - inputBodyIndex := []int{} - hasInputBody := false - if f, ok := inputType.FieldByName("Body"); ok { - hasInputBody = true - inputBodyIndex = f.Index - if op.RequestBody == nil { - op.RequestBody = &RequestBody{} - } - - required := f.Type.Kind() != reflect.Ptr && f.Type.Kind() != reflect.Interface - if f.Tag.Get("required") == "true" { - required = true - } - - contentType := "application/json" - if c := f.Tag.Get("contentType"); c != "" { - contentType = c - } - hint := getHint(inputType, f.Name, op.OperationID+"Request") - if nameHint := f.Tag.Get("nameHint"); nameHint != "" { - hint = nameHint - } - s := SchemaFromField(registry, f, hint) - - op.RequestBody.Required = required - - if op.RequestBody.Content == nil { - op.RequestBody.Content = map[string]*MediaType{} - } - if op.RequestBody.Content[contentType] == nil { - op.RequestBody.Content[contentType] = &MediaType{} - } - op.RequestBody.Content[contentType].Schema = s - - if op.BodyReadTimeout == 0 { - // 5 second default - op.BodyReadTimeout = 5 * time.Second - } - - if op.MaxBodyBytes == 0 { - // 1 MB default - op.MaxBodyBytes = 1024 * 1024 - } - } - rawBodyIndex := []int{} - rawBodyMultipart := false - rawBodyDecodedMultipart := false - if f, ok := inputType.FieldByName("RawBody"); ok { - rawBodyIndex = f.Index - if op.RequestBody == nil { - 
op.RequestBody = &RequestBody{ - Required: true, - } - } - - if op.RequestBody.Content == nil { - op.RequestBody.Content = map[string]*MediaType{} - } - - contentType := "application/octet-stream" - - if f.Type.String() == "multipart.Form" { - contentType = "multipart/form-data" - rawBodyMultipart = true - } - if strings.HasPrefix(f.Type.Name(), "MultipartFormFiles") { - contentType = "multipart/form-data" - rawBodyDecodedMultipart = true - } - - if c := f.Tag.Get("contentType"); c != "" { - contentType = c - } - - switch contentType { - case "multipart/form-data": - if op.RequestBody.Content["multipart/form-data"] != nil { - break - } - if rawBodyMultipart { - op.RequestBody.Content["multipart/form-data"] = &MediaType{ - Schema: &Schema{ - Type: "object", - Properties: map[string]*Schema{ - "name": { - Type: "string", - Description: "general purpose name for multipart form value", - }, - "filename": { - Type: "string", - Format: "binary", - Description: "filename of the file being uploaded", - }, - }, - }, - } - } - if rawBodyDecodedMultipart { - dataField, ok := f.Type.FieldByName("data") - if !ok { - panic("Expected type MultipartFormFiles[T] to have a 'data *T' generic pointer field") - } - op.RequestBody.Content["multipart/form-data"] = &MediaType{ - Schema: multiPartFormFileSchema(dataField.Type.Elem()), - Encoding: multiPartContentEncoding(dataField.Type.Elem()), - } - op.RequestBody.Required = false - } - default: - op.RequestBody.Content[contentType] = &MediaType{ - Schema: &Schema{ - Type: "string", - Format: "binary", - }, - } - } - } - - if op.RequestBody != nil { - for _, mediatype := range op.RequestBody.Content { - if mediatype.Schema != nil { - // Ensure all schema validation errors are set up properly as some - // parts of the schema may have been user-supplied. 
- mediatype.Schema.PrecomputeMessages() - } - } - } - - var inSchema *Schema - if op.RequestBody != nil && op.RequestBody.Content != nil && op.RequestBody.Content["application/json"] != nil && op.RequestBody.Content["application/json"].Schema != nil { - hasInputBody = true - inSchema = op.RequestBody.Content["application/json"].Schema - } - - resolvers := findResolvers(resolverType, inputType) - defaults := findDefaults(registry, inputType) + inputParams, inputBodyIndex, hasInputBody, rawBodyIndex, rbt, inSchema := processInputType(inputType, &op, registry) - if op.Responses == nil { - op.Responses = map[string]*Response{} - } outputType := reflect.TypeOf((*O)(nil)).Elem() if outputType.Kind() != reflect.Struct { panic("output must be a struct") } + outHeaders, outStatusIndex, outBodyIndex, outBodyFunc := processOutputType(outputType, &op, registry) - outStatusIndex := -1 - if f, ok := outputType.FieldByName("Status"); ok { - outStatusIndex = f.Index[0] - if f.Type.Kind() != reflect.Int { - panic("status field must be an int") - } - // TODO: enum tag? - // TODO: register each of the possible responses with the right model - // and headers down below. 
- } - outHeaders := findHeaders(outputType) - outBodyIndex := -1 - outBodyFunc := false - if f, ok := outputType.FieldByName("Body"); ok { - outBodyIndex = f.Index[0] - if f.Type.Kind() == reflect.Func { - outBodyFunc = true - - if f.Type != bodyCallbackType { - panic("body field must be a function with signature func(huma.Context)") - } - } - status := op.DefaultStatus - if status == 0 { - status = http.StatusOK - } - statusStr := strconv.Itoa(status) - if op.Responses[statusStr] == nil { - op.Responses[statusStr] = &Response{} - } - if op.Responses[statusStr].Description == "" { - op.Responses[statusStr].Description = http.StatusText(status) - } - if op.Responses[statusStr].Headers == nil { - op.Responses[statusStr].Headers = map[string]*Param{} - } - if !outBodyFunc { - hint := getHint(outputType, f.Name, op.OperationID+"Response") - if nameHint := f.Tag.Get("nameHint"); nameHint != "" { - hint = nameHint - } - outSchema := SchemaFromField(registry, f, hint) - if op.Responses[statusStr].Content == nil { - op.Responses[statusStr].Content = map[string]*MediaType{} - } - // Check if the field's type implements ContentTypeFilter - contentType := "application/json" - if reflect.PointerTo(f.Type).Implements(reflect.TypeFor[ContentTypeFilter]()) { - instance := reflect.New(f.Type).Interface().(ContentTypeFilter) - contentType = instance.ContentType(contentType) - } - if len(op.Responses[statusStr].Content) == 0 { - op.Responses[statusStr].Content[contentType] = &MediaType{} - } - if op.Responses[statusStr].Content[contentType] != nil && op.Responses[statusStr].Content[contentType].Schema == nil { - op.Responses[statusStr].Content[contentType].Schema = outSchema - } - } - } - if op.DefaultStatus == 0 { - if outBodyIndex != -1 { - op.DefaultStatus = http.StatusOK - } else { - op.DefaultStatus = http.StatusNoContent - } - } - defaultStatusStr := strconv.Itoa(op.DefaultStatus) - if op.Responses[defaultStatusStr] == nil { - op.Responses[defaultStatusStr] = &Response{ - 
Description: http.StatusText(op.DefaultStatus), - } - } - for _, entry := range outHeaders.Paths { - // Document the header's name and type. - if op.Responses[defaultStatusStr].Headers == nil { - op.Responses[defaultStatusStr].Headers = map[string]*Param{} - } - v := entry.Value - f := v.Field - if f.Type.Kind() == reflect.Slice { - f.Type = deref(f.Type.Elem()) - } - if reflect.PointerTo(f.Type).Implements(fmtStringerType) { - // Special case: this field will be written as a string by calling - // `.String()` on the value. - f.Type = stringType - } - op.Responses[defaultStatusStr].Headers[v.Name] = &Header{ - // We need to generate the schema from the field to get validation info - // like min/max and enums. Useful to let the client know possible values. - Schema: SchemaFromField(registry, f, getHint(outputType, f.Name, op.OperationID+defaultStatusStr+v.Name)), - } - } - - if len(op.Errors) > 0 && (len(inputParams.Paths) > 0 || hasInputBody) { - op.Errors = append(op.Errors, http.StatusUnprocessableEntity) - } if len(op.Errors) > 0 { - op.Errors = append(op.Errors, http.StatusInternalServerError) - } - - exampleErr := NewError(0, "") - errContentType := "application/json" - if ctf, ok := exampleErr.(ContentTypeFilter); ok { - errContentType = ctf.ContentType(errContentType) - } - errType := deref(reflect.TypeOf(exampleErr)) - errSchema := registry.Schema(errType, true, getHint(errType, "", "Error")) - for _, code := range op.Errors { - op.Responses[strconv.Itoa(code)] = &Response{ - Description: http.StatusText(code), - Content: map[string]*MediaType{ - errContentType: { - Schema: errSchema, - }, - }, - } - } - if len(op.Responses) <= 1 && len(op.Errors) == 0 { - // No errors are defined, so set a default response. 
- op.Responses["default"] = &Response{ - Description: "Error", - Content: map[string]*MediaType{ - errContentType: { - Schema: errSchema, - }, - }, + if len(inputParams.Paths) > 0 || hasInputBody { + op.Errors = append(op.Errors, http.StatusUnprocessableEntity) } + op.Errors = append(op.Errors, http.StatusInternalServerError) } + defineErrors(&op, registry) if !op.Hidden { oapi.AddOperation(&op) } + resolvers := findResolvers(resolverType, inputType) + defaults := findDefaults(registry, inputType) a := api.Adapter() - a.Handle(&op, api.Middlewares().Handler(op.Middlewares.Handler(func(ctx Context) { var input I @@ -921,15 +663,12 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) if f.Kind() == reflect.Invalid { return } - var value string - switch p.Loc { - case "path": - value = ctx.Param(p.Name) - case "query": - value = ctx.Query(p.Name) - case "header": - value = ctx.Header(p.Name) - case "cookie": + + pb.Reset() + pb.Push(p.Loc) + pb.Push(p.Name) + + if p.Loc == "cookie" { if cookies == nil { // Only parse the cookie headers once, on-demand. cookies = map[string]*http.Cookie{} @@ -937,303 +676,30 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) cookies[c.Name] = c } } - if c, ok := cookies[p.Name]; ok { + if c, ok := cookies[p.Name]; ok && f.Type() == cookieType { // Special case: http.Cookie type, meaning we want the entire parsed // cookie struct, not just the value. - if f.Type() == cookieType { - f.Set(reflect.ValueOf(cookies[p.Name]).Elem()) - return - } - - value = c.Value + f.Set(reflect.ValueOf(c).Elem()) + return } } - pb.Reset() - pb.Push(p.Loc) - pb.Push(p.Name) - - if value == "" && p.Default != "" { - value = p.Default - } - - if !op.SkipValidateParams && p.Required && value == "" { - // Path params are always required. 
- res.Add(pb, "", "required "+p.Loc+" parameter is missing") + value := getParamValue(*p, ctx, cookies) + if value == "" { + if !op.SkipValidateParams && p.Required { + // Path params are always required. + res.Add(pb, "", "required "+p.Loc+" parameter is missing") + } return } - if value != "" { - var pv any - - switch p.Type.Kind() { - case reflect.String: - f.SetString(value) - pv = value - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - v, err := strconv.ParseInt(value, 10, 64) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.SetInt(v) - pv = v - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - v, err := strconv.ParseUint(value, 10, 64) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.SetUint(v) - pv = v - case reflect.Float32, reflect.Float64: - v, err := strconv.ParseFloat(value, 64) - if err != nil { - res.Add(pb, value, "invalid float") - return - } - f.SetFloat(v) - pv = v - case reflect.Bool: - v, err := strconv.ParseBool(value) - if err != nil { - res.Add(pb, value, "invalid boolean") - return - } - f.SetBool(v) - pv = v - default: - if f.Type().Kind() == reflect.Slice { - var values []string - if p.Explode { - u := ctx.URL() - values = (&u).Query()[p.Name] - } else { - values = strings.Split(value, ",") - } - switch f.Type().Elem().Kind() { - - case reflect.String: - if f.Type() == reflect.TypeOf(values) { - f.Set(reflect.ValueOf(values)) - } else { - // Change element type to support slice of string subtypes (enums) - enumValues := reflect.New(f.Type()).Elem() - for _, val := range values { - enumVal := reflect.New(f.Type().Elem()).Elem() - enumVal.SetString(val) - enumValues.Set(reflect.Append(enumValues, enumVal)) - } - f.Set(enumValues) - } - pv = values - - case reflect.Int: - vs, err := parseArrElement(values, func(s string) (int, error) { - val, err := strconv.ParseInt(s, 10, strconv.IntSize) - if err != nil { - return 0, err - } 
- return int(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Int8: - vs, err := parseArrElement(values, func(s string) (int8, error) { - val, err := strconv.ParseInt(s, 10, 8) - if err != nil { - return 0, err - } - return int8(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Int16: - vs, err := parseArrElement(values, func(s string) (int16, error) { - val, err := strconv.ParseInt(s, 10, 16) - if err != nil { - return 0, err - } - return int16(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Int32: - vs, err := parseArrElement(values, func(s string) (int32, error) { - val, err := strconv.ParseInt(s, 10, 32) - if err != nil { - return 0, err - } - return int32(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Int64: - vs, err := parseArrElement(values, func(s string) (int64, error) { - val, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return 0, err - } - return val, nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Uint: - vs, err := parseArrElement(values, func(s string) (uint, error) { - val, err := strconv.ParseUint(s, 10, strconv.IntSize) - if err != nil { - return 0, err - } - return uint(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Uint16: - vs, err := parseArrElement(values, func(s string) (uint16, error) { - val, err := strconv.ParseUint(s, 10, 16) - if err != nil { - return 0, err - } - return uint16(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - 
f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Uint32: - vs, err := parseArrElement(values, func(s string) (uint32, error) { - val, err := strconv.ParseUint(s, 10, 32) - if err != nil { - return 0, err - } - return uint32(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Uint64: - vs, err := parseArrElement(values, func(s string) (uint64, error) { - val, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return 0, err - } - return val, nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Float32: - vs, err := parseArrElement(values, func(s string) (float32, error) { - val, err := strconv.ParseFloat(s, 32) - if err != nil { - return 0, err - } - return float32(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid floating value") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Float64: - vs, err := parseArrElement(values, func(s string) (float64, error) { - val, err := strconv.ParseFloat(s, 64) - if err != nil { - return 0, err - } - return float64(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid floating value") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - } - break - } - - // Special case: time.Time - if f.Type() == timeType { - t, err := time.Parse(p.TimeFormat, value) - if err != nil { - res.Add(pb, value, "invalid date/time for format "+p.TimeFormat) - return - } - f.Set(reflect.ValueOf(t)) - pv = value - break - // Special case: url.URL - } else if f.Type() == urlType { - u, err := url.Parse(value) - if err != nil { - res.Add(pb, value, "invalid url.URL value") - return - } - f.Set(reflect.ValueOf(*u)) - pv = value - break - } - - // Last resort: use the `encoding.TextUnmarshaler` interface. 
- if fn, ok := f.Addr().Interface().(encoding.TextUnmarshaler); ok { - if err := fn.UnmarshalText([]byte(value)); err != nil { - res.Add(pb, value, "invalid value: "+err.Error()) - return - } - pv = value - break - } - - panic("unsupported param type " + p.Type.String()) - } + pv, err := parseInto(ctx, f, value, *p) + if err != nil { + res.Add(pb, value, err.Error()) + } - if !op.SkipValidateParams { - Validate(oapi.Components.Schemas, p.Schema, pb, ModeWriteToServer, pv, res) - } + if !op.SkipValidateParams { + Validate(oapi.Components.Schemas, p.Schema, pb, ModeWriteToServer, pv, res) } }) @@ -1246,161 +712,57 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) ctx.SetReadDeadline(time.Time{}) } - if rawBodyMultipart || rawBodyDecodedMultipart { - form, err := ctx.GetMultipartForm() - if err != nil { - res.Errors = append(res.Errors, &ErrorDetail{ - Location: "body", - Message: "cannot read multipart form: " + err.Error(), - }) - } else { - f := v - for _, i := range rawBodyIndex { - f = f.Field(i) - } - if rawBodyMultipart { - f.Set(reflect.ValueOf(*form)) - } else { - f.FieldByName("Form").Set(reflect.ValueOf(form)) - r := f.Addr(). - MethodByName("Decode"). - Call([]reflect.Value{ - reflect.ValueOf(op.RequestBody.Content["multipart/form-data"]), - }) - errs := r[0].Interface().([]error) - if errs != nil { - WriteErr(api, ctx, http.StatusUnprocessableEntity, "validation failed", errs...) 
- return - } - } + if rbt.isMultipart() { + if cErr := processMultipartMsgBody(op, ctx, v, rbt, rawBodyIndex, res); cErr != nil { + writeErr(api, ctx, cErr, *res) + return } } else { + // Read body buf := bufPool.Get().(*bytes.Buffer) - reader := ctx.BodyReader() - if reader == nil { - reader = bytes.NewReader(nil) - } - if closer, ok := reader.(io.Closer); ok { - defer closer.Close() - } - if op.MaxBodyBytes > 0 { - reader = io.LimitReader(reader, op.MaxBodyBytes) - } - count, err := io.Copy(buf, reader) - if op.MaxBodyBytes > 0 { - if count == op.MaxBodyBytes { - buf.Reset() - bufPool.Put(buf) - WriteErr(api, ctx, http.StatusRequestEntityTooLarge, fmt.Sprintf("request body is too large limit=%d bytes", op.MaxBodyBytes), res.Errors...) - return - } - } - if err != nil { + bufCloser := func() { buf.Reset() bufPool.Put(buf) - - if e, ok := err.(net.Error); ok && e.Timeout() { - WriteErr(api, ctx, http.StatusRequestTimeout, "request body read timeout", res.Errors...) - return - } - - WriteErr(api, ctx, http.StatusInternalServerError, "cannot read request body", err) + } + if cErr := readBody(buf, ctx, op.MaxBodyBytes); cErr != nil { + bufCloser() + writeErr(api, ctx, cErr, *res) return } body := buf.Bytes() + // Store raw body if len(rawBodyIndex) > 0 { - f := v - for _, i := range rawBodyIndex { - f = f.Field(i) - } + f := v.FieldByIndex(rawBodyIndex) f.SetBytes(body) } - if len(body) == 0 { - if op.RequestBody != nil && op.RequestBody.Required { - buf.Reset() - bufPool.Put(buf) - WriteErr(api, ctx, http.StatusBadRequest, "request body is required", res.Errors...) - return - } - } else { - parseErrCount := 0 - if hasInputBody && !op.SkipValidateBody { - // Validate the input. First, parse the body into []any or map[string]any - // or equivalent, which can be easily validated. Then, convert to the - // expected struct type to call the handler. 
- var parsed any - if err := api.Unmarshal(ctx.Header("Content-Type"), body, &parsed); err != nil { - errStatus = http.StatusBadRequest - if errors.Is(err, ErrUnknownContentType) { - errStatus = http.StatusUnsupportedMediaType - } - res.Errors = append(res.Errors, &ErrorDetail{ - Location: "body", - Message: err.Error(), - Value: string(body), - }) - parseErrCount++ - } else { - pb.Reset() - pb.Push("body") - count := len(res.Errors) - Validate(oapi.Components.Schemas, inSchema, pb, ModeWriteToServer, parsed, res) - parseErrCount = len(res.Errors) - count - if parseErrCount > 0 { - errStatus = http.StatusUnprocessableEntity - } - } - } - - if hasInputBody && len(inputBodyIndex) > 0 { - // We need to get the body into the correct type now that it has been - // validated. Benchmarks on Go 1.20 show that using `json.Unmarshal` a - // second time is faster than `mapstructure.Decode` or any of the other - // common reflection-based approaches when using real-world medium-sized - // JSON payloads with lots of strings. - f := v - for _, index := range inputBodyIndex { - f = f.Field(index) - } - if err := api.Unmarshal(ctx.Header("Content-Type"), body, f.Addr().Interface()); err != nil { - if parseErrCount == 0 { - // Hmm, this should have worked... validator missed something? - res.Errors = append(res.Errors, &ErrorDetail{ - Location: "body", - Message: err.Error(), - Value: string(body), - }) - } - } else { - // Set defaults for any fields that were not in the input. 
- defaults.Every(v, func(item reflect.Value, def any) { - if item.IsZero() { - if item.Kind() == reflect.Pointer { - item.Set(reflect.New(item.Type().Elem())) - item = item.Elem() - } - item.Set(reflect.Indirect(reflect.ValueOf(def))) - } - }) - } - } + // Process body + unmarshaler := func(data []byte, v any) error { return api.Unmarshal(ctx.Header("Content-Type"), data, v) } + validator := func(data any, res *ValidateResult) { + pb.Reset() + pb.Push("body") + Validate(oapi.Components.Schemas, inSchema, pb, ModeWriteToServer, data, res) + } + processErrStatus, cErr := processRegularMsgBody(body, op, v, hasInputBody, inputBodyIndex, unmarshaler, validator, defaults, res) + if processErrStatus > 0 { + errStatus = processErrStatus + } + if cErr != nil { + bufCloser() + writeErr(api, ctx, cErr, *res) + return + } - if len(rawBodyIndex) > 0 { - // If the raw body is used, then we must wait until *AFTER* the - // handler has run to return the body byte buffer to the pool, as - // the handler can read and modify this buffer. The safest way is - // to just wait until the end of this handler via defer. - defer bufPool.Put(buf) - defer buf.Reset() - } else { - // No raw body, and the body has already been unmarshalled above, so - // we can return the buffer to the pool now as we don't need the - // bytes any more. - buf.Reset() - bufPool.Put(buf) - } + // Clean up + // If the raw body is used, then we must wait until *AFTER* the + // handler has run to return the body byte buffer to the pool, as + // the handler can read and modify this buffer. The safest way is + // to just wait until the end of this handler via defer. + if len(rawBodyIndex) > 0 { + defer bufCloser() + } else { + bufCloser() } } } @@ -1532,6 +894,775 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) }))) } +// initResponses initializes Responses if it was unset. 
+func initResponses(op *Operation) { + if op.Responses == nil { + op.Responses = map[string]*Response{} + } +} + +// processInputType validates the input type, extracts expected requests and +// defines them on the operation op. +func processInputType(inputType reflect.Type, op *Operation, registry Registry) (*findResult[*paramFieldInfo], []int, bool, []int, rawBodyType, *Schema) { + inputParams := findParams(registry, op, inputType) + inputBodyIndex := []int{} + hasInputBody := false + if f, ok := inputType.FieldByName("Body"); ok { + hasInputBody = true + inputBodyIndex = f.Index + initRequestBody(op) + setRequestBodyFromBody(op, registry, f, inputType) + ensureBodyReadTimeout(op) + ensureMaxBodyBytes(op) + } + rawBodyIndex := []int{} + var rbt rawBodyType + if f, ok := inputType.FieldByName("RawBody"); ok { + rawBodyIndex = f.Index + initRequestBody(op, setRequestBodyRequired) + rbt = setRequestBodyFromRawBody(op, f) + } + + if op.RequestBody != nil { + for _, mediatype := range op.RequestBody.Content { + if mediatype.Schema != nil { + // Ensure all schema validation errors are set up properly as some + // parts of the schema may have been user-supplied. + mediatype.Schema.PrecomputeMessages() + } + } + } + + var inSchema *Schema + if op.RequestBody != nil && op.RequestBody.Content != nil && op.RequestBody.Content["application/json"] != nil && op.RequestBody.Content["application/json"].Schema != nil { + hasInputBody = true + inSchema = op.RequestBody.Content["application/json"].Schema + } + return inputParams, inputBodyIndex, hasInputBody, rawBodyIndex, rbt, inSchema +} + +// ensureMaxBodyBytes sets the MaxBodyBytes to a default value if it was unset. +func ensureMaxBodyBytes(op *Operation) { + if op.MaxBodyBytes == 0 { + // 1 MB default + op.MaxBodyBytes = 1024 * 1024 + } +} + +// ensureBodyReadTimeout sets the BodyReadTimeout to a default value if it was unset. 
+func ensureBodyReadTimeout(op *Operation) { + if op.BodyReadTimeout == 0 { + // 5 second default + op.BodyReadTimeout = 5 * time.Second + } +} + +// setRequestBodyFromBody configures op.RequestBody from the Body field. +func setRequestBodyFromBody(op *Operation, registry Registry, fBody reflect.StructField, inputType reflect.Type) { + if fBody.Tag.Get("required") == "true" || (fBody.Type.Kind() != reflect.Ptr && fBody.Type.Kind() != reflect.Interface) { + setRequestBodyRequired(op.RequestBody) + } + contentType := "application/json" + if c := fBody.Tag.Get("contentType"); c != "" { + contentType = c + } + hint := getHint(inputType, fBody.Name, op.OperationID+"Request") + if nameHint := fBody.Tag.Get("nameHint"); nameHint != "" { + hint = nameHint + } + s := SchemaFromField(registry, fBody, hint) + if op.RequestBody.Content[contentType] == nil { + op.RequestBody.Content[contentType] = &MediaType{} + } + op.RequestBody.Content[contentType].Schema = s + +} + +type rawBodyType int + +const ( + rbtMultipart rawBodyType = iota + 1 + rbtMultipartDecoded + rbtOther +) + +func (r rawBodyType) isMultipart() bool { + return r == rbtMultipart || r == rbtMultipartDecoded +} + +// setRequestBodyFromRawBody configures op.RequestBody from the RawBody field. 
+func setRequestBodyFromRawBody(op *Operation, fRawBody reflect.StructField) rawBodyType { + rbt := rbtOther + contentType := "application/octet-stream" + if fRawBody.Type.String() == "multipart.Form" { + contentType = "multipart/form-data" + rbt = rbtMultipart + } + if strings.HasPrefix(fRawBody.Type.Name(), "MultipartFormFiles") { + contentType = "multipart/form-data" + rbt = rbtMultipartDecoded + } + if c := fRawBody.Tag.Get("contentType"); c != "" { + contentType = c + } + + if contentType != "multipart/form-data" { + op.RequestBody.Content[contentType] = &MediaType{ + Schema: &Schema{ + Type: "string", + Format: "binary", + }, + } + return rbt + } + if op.RequestBody.Content["multipart/form-data"] != nil { + return rbt + } + + switch rbt { + case rbtMultipart: + op.RequestBody.Content["multipart/form-data"] = &MediaType{ + Schema: &Schema{ + Type: "object", + Properties: map[string]*Schema{ + "name": { + Type: "string", + Description: "general purpose name for multipart form value", + }, + "filename": { + Type: "string", + Format: "binary", + Description: "filename of the file being uploaded", + }, + }, + }, + } + case rbtMultipartDecoded: + dataField, ok := fRawBody.Type.FieldByName("data") + if !ok { + panic("Expected type MultipartFormFiles[T] to have a 'data *T' generic pointer field") + } + op.RequestBody.Content["multipart/form-data"] = &MediaType{ + Schema: multiPartFormFileSchema(dataField.Type.Elem()), + Encoding: multiPartContentEncoding(dataField.Type.Elem()), + } + op.RequestBody.Required = false + } + return rbt +} + +// initRequestBody initializes an empty RequestBody and its Content map. 
+func initRequestBody(op *Operation, rbOpts ...func(*RequestBody)) { + if op.RequestBody == nil { + op.RequestBody = &RequestBody{} + } + if op.RequestBody.Content == nil { + op.RequestBody.Content = map[string]*MediaType{} + } + for _, opt := range rbOpts { + opt(op.RequestBody) + } +} + +func setRequestBodyRequired(rb *RequestBody) { + rb.Required = true +} + +// processOutputType validates the output type, extracts possible responses and +// defines them on the operation op. +func processOutputType(outputType reflect.Type, op *Operation, registry Registry) (*findResult[*headerInfo], int, int, bool) { + outStatusIndex := -1 + if f, ok := outputType.FieldByName("Status"); ok { + outStatusIndex = f.Index[0] + if f.Type.Kind() != reflect.Int { + panic("status field must be an int") + } + // TODO: enum tag? + // TODO: register each of the possible responses with the right model + // and headers down below. + } + outBodyIndex := -1 + outBodyFunc := false + if f, ok := outputType.FieldByName("Body"); ok { + outBodyIndex = f.Index[0] + if f.Type.Kind() == reflect.Func { + outBodyFunc = true + + if f.Type != bodyCallbackType { + panic("body field must be a function with signature func(huma.Context)") + } + } + status := op.DefaultStatus + if status == 0 { + status = http.StatusOK + } + statusStr := strconv.Itoa(status) + if op.Responses[statusStr] == nil { + op.Responses[statusStr] = &Response{} + } + if op.Responses[statusStr].Description == "" { + op.Responses[statusStr].Description = http.StatusText(status) + } + if op.Responses[statusStr].Headers == nil { + op.Responses[statusStr].Headers = map[string]*Param{} + } + if !outBodyFunc { + hint := getHint(outputType, f.Name, op.OperationID+"Response") + if nameHint := f.Tag.Get("nameHint"); nameHint != "" { + hint = nameHint + } + outSchema := SchemaFromField(registry, f, hint) + if op.Responses[statusStr].Content == nil { + op.Responses[statusStr].Content = map[string]*MediaType{} + } + // Check if the field's type 
implements ContentTypeFilter + contentType := "application/json" + if reflect.PointerTo(f.Type).Implements(reflect.TypeFor[ContentTypeFilter]()) { + instance := reflect.New(f.Type).Interface().(ContentTypeFilter) + contentType = instance.ContentType(contentType) + } + if len(op.Responses[statusStr].Content) == 0 { + op.Responses[statusStr].Content[contentType] = &MediaType{} + } + if op.Responses[statusStr].Content[contentType] != nil && op.Responses[statusStr].Content[contentType].Schema == nil { + op.Responses[statusStr].Content[contentType].Schema = outSchema + } + } + } + if op.DefaultStatus == 0 { + if outBodyIndex != -1 { + op.DefaultStatus = http.StatusOK + } else { + op.DefaultStatus = http.StatusNoContent + } + } + defaultStatusStr := strconv.Itoa(op.DefaultStatus) + if op.Responses[defaultStatusStr] == nil { + op.Responses[defaultStatusStr] = &Response{ + Description: http.StatusText(op.DefaultStatus), + } + } + outHeaders := findHeaders(outputType) + for _, entry := range outHeaders.Paths { + // Document the header's name and type. + if op.Responses[defaultStatusStr].Headers == nil { + op.Responses[defaultStatusStr].Headers = map[string]*Param{} + } + v := entry.Value + f := v.Field + if f.Type.Kind() == reflect.Slice { + f.Type = deref(f.Type.Elem()) + } + if reflect.PointerTo(f.Type).Implements(fmtStringerType) { + // Special case: this field will be written as a string by calling + // `.String()` on the value. + f.Type = stringType + } + op.Responses[defaultStatusStr].Headers[v.Name] = &Header{ + // We need to generate the schema from the field to get validation info + // like min/max and enums. Useful to let the client know possible values. + Schema: SchemaFromField(registry, f, getHint(outputType, f.Name, op.OperationID+defaultStatusStr+v.Name)), + } + } + return outHeaders, outStatusIndex, outBodyIndex, outBodyFunc +} + +// defineErrors extracts possible error responses and defines them on the +// operation op. 
+func defineErrors(op *Operation, registry Registry) { + exampleErr := NewError(0, "") + errContentType := "application/json" + if ctf, ok := exampleErr.(ContentTypeFilter); ok { + errContentType = ctf.ContentType(errContentType) + } + errType := deref(reflect.TypeOf(exampleErr)) + errSchema := registry.Schema(errType, true, getHint(errType, "", "Error")) + for _, code := range op.Errors { + op.Responses[strconv.Itoa(code)] = &Response{ + Description: http.StatusText(code), + Content: map[string]*MediaType{ + errContentType: { + Schema: errSchema, + }, + }, + } + } + if len(op.Responses) <= 1 && len(op.Errors) == 0 { + // No errors are defined, so set a default response. + op.Responses["default"] = &Response{ + Description: "Error", + Content: map[string]*MediaType{ + errContentType: { + Schema: errSchema, + }, + }, + } + } +} + +// getParamValue extracts the requested parameter from the relevant +// context or cookie source. If unset, the function returns the default value +// for this parameter. +func getParamValue(p paramFieldInfo, ctx Context, cookies map[string]*http.Cookie) string { + var value string + switch p.Loc { + case "path": + value = ctx.Param(p.Name) + case "query": + value = ctx.Query(p.Name) + case "header": + value = ctx.Header(p.Name) + case "cookie": + if c, ok := cookies[p.Name]; ok { + value = c.Value + } + } + if value == "" { + value = p.Default + } + return value +} + +var errUnparsable = errors.New("unparsable value") + +// parseInto converts the string value into the expected type using the +// parameter field information p and sets the result on f. 
func parseInto(ctx Context, f reflect.Value, value string, p paramFieldInfo) (any, error) {
	// Built-in kinds.
	switch p.Type.Kind() {
	case reflect.String:
		f.SetString(value)
		return value, nil
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		v, err := strconv.ParseInt(value, 10, 64)
		// Reject values that do not fit the destination type (e.g. "300" for
		// an int8); reflect.Value.SetInt would otherwise silently truncate.
		if err != nil || f.OverflowInt(v) {
			return nil, errors.New("invalid integer")
		}
		f.SetInt(v)
		return v, nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		v, err := strconv.ParseUint(value, 10, 64)
		if err != nil || f.OverflowUint(v) {
			return nil, errors.New("invalid integer")
		}
		f.SetUint(v)
		return v, nil
	case reflect.Float32, reflect.Float64:
		v, err := strconv.ParseFloat(value, 64)
		if err != nil || f.OverflowFloat(v) {
			return nil, errors.New("invalid float")
		}
		f.SetFloat(v)
		return v, nil
	case reflect.Bool:
		v, err := strconv.ParseBool(value)
		if err != nil {
			return nil, errors.New("invalid boolean")
		}
		f.SetBool(v)
		return v, nil
	case reflect.Slice:
		var values []string
		if p.Explode {
			// Exploded form (?a=1&a=2): read every occurrence from the URL.
			u := ctx.URL()
			values = u.Query()[p.Name]
		} else {
			values = strings.Split(value, ",")
		}
		pv, err := parseSliceInto(f, values)
		if err != nil {
			if errors.Is(err, errUnparsable) {
				// No built-in element parser; try the special types below.
				break
			}
			return nil, err
		}
		return pv, nil
	}

	// Special types.
	switch f.Type() {
	case timeType: // Special case: time.Time
		t, err := time.Parse(p.TimeFormat, value)
		if err != nil {
			return nil, errors.New("invalid date/time for format " + p.TimeFormat)
		}
		f.Set(reflect.ValueOf(t))
		return value, nil
	case urlType: // Special case: url.URL
		u, err := url.Parse(value)
		if err != nil {
			return nil, errors.New("invalid url.URL value")
		}
		f.Set(reflect.ValueOf(*u))
		return value, nil
	}

	// Last resort: use the `encoding.TextUnmarshaler` interface.
	if fn, ok := f.Addr().Interface().(encoding.TextUnmarshaler); ok {
		if err := fn.UnmarshalText([]byte(value)); err != nil {
			return nil, errors.New("invalid value: " + err.Error())
		}
		return value, nil
	}

	panic("unsupported param type " + p.Type.String())
}

// parseSliceInto converts a slice of string values into the expected type of f
// and sets the result on f. It returns the parsed slice, a descriptive error if
// an element fails to parse, or errUnparsable when the element kind has no
// built-in parser.
func parseSliceInto(f reflect.Value, values []string) (any, error) {
	switch f.Type().Elem().Kind() {

	case reflect.String:
		if f.Type() == reflect.TypeOf(values) {
			f.Set(reflect.ValueOf(values))
		} else {
			// Change element type to support slice of string subtypes (enums)
			enumValues := reflect.New(f.Type()).Elem()
			for _, val := range values {
				enumVal := reflect.New(f.Type().Elem()).Elem()
				enumVal.SetString(val)
				enumValues.Set(reflect.Append(enumValues, enumVal))
			}
			f.Set(enumValues)
		}
		return values, nil

	case reflect.Int:
		vs, err := parseArrElement(values, func(s string) (int, error) {
			val, err := strconv.ParseInt(s, 10, strconv.IntSize)
			if err != nil {
				return 0, err
			}
			return int(val), nil
		})
		return storeParsedSlice(f, vs, err, "invalid integer")

	case reflect.Int8:
		vs, err := parseArrElement(values, func(s string) (int8, error) {
			val, err := strconv.ParseInt(s, 10, 8)
			if err != nil {
				return 0, err
			}
			return int8(val), nil
		})
		return storeParsedSlice(f, vs, err, "invalid integer")

	case reflect.Int16:
		vs, err := parseArrElement(values, func(s string) (int16, error) {
			val, err := strconv.ParseInt(s, 10, 16)
			if err != nil {
				return 0, err
			}
			return int16(val), nil
		})
		return storeParsedSlice(f, vs, err, "invalid integer")

	case reflect.Int32:
		vs, err := parseArrElement(values, func(s string) (int32, error) {
			val, err := strconv.ParseInt(s, 10, 32)
			if err != nil {
				return 0, err
			}
			return int32(val), nil
		})
		return storeParsedSlice(f, vs, err, "invalid integer")

	case reflect.Int64:
		vs, err := parseArrElement(values, func(s string) (int64, error) {
			val, err := strconv.ParseInt(s, 10, 64)
			if err != nil {
				return 0, err
			}
			return val, nil
		})
		return storeParsedSlice(f, vs, err, "invalid integer")

	case reflect.Uint:
		vs, err := parseArrElement(values, func(s string) (uint, error) {
			val, err := strconv.ParseUint(s, 10, strconv.IntSize)
			if err != nil {
				return 0, err
			}
			return uint(val), nil
		})
		return storeParsedSlice(f, vs, err, "invalid integer")

	case reflect.Uint16:
		vs, err := parseArrElement(values, func(s string) (uint16, error) {
			val, err := strconv.ParseUint(s, 10, 16)
			if err != nil {
				return 0, err
			}
			return uint16(val), nil
		})
		return storeParsedSlice(f, vs, err, "invalid integer")

	case reflect.Uint32:
		vs, err := parseArrElement(values, func(s string) (uint32, error) {
			val, err := strconv.ParseUint(s, 10, 32)
			if err != nil {
				return 0, err
			}
			return uint32(val), nil
		})
		return storeParsedSlice(f, vs, err, "invalid integer")

	case reflect.Uint64:
		vs, err := parseArrElement(values, func(s string) (uint64, error) {
			val, err := strconv.ParseUint(s, 10, 64)
			if err != nil {
				return 0, err
			}
			return val, nil
		})
		return storeParsedSlice(f, vs, err, "invalid integer")

	case reflect.Float32:
		vs, err := parseArrElement(values, func(s string) (float32, error) {
			val, err := strconv.ParseFloat(s, 32)
			if err != nil {
				return 0, err
			}
			return float32(val), nil
		})
		return storeParsedSlice(f, vs, err, "invalid floating value")

	case reflect.Float64:
		vs, err := parseArrElement(values, func(s string) (float64, error) {
			val, err := strconv.ParseFloat(s, 64)
			if err != nil {
				return 0, err
			}
			return val, nil
		})
		return storeParsedSlice(f, vs, err, "invalid floating value")
	}
	return nil, errUnparsable
}

// storeParsedSlice finishes one numeric branch of parseSliceInto. On a parse
// error it returns errors.New(msg); otherwise it stores vs in f and returns vs.
func storeParsedSlice[T any](f reflect.Value, vs []T, err error, msg string) (any, error) {
	if err != nil {
		return nil, errors.New(msg)
	}
	assignParsedSlice(f, vs)
	return vs, nil
}

// assignParsedSlice stores vs in the slice value f. When f's element type is a
// named type (e.g. []MyEnum with `type MyEnum int`), []T is not assignable to
// f, and reflect.Value.Set would panic. In that case each element is converted
// to f's element type instead.
func assignParsedSlice[T any](f reflect.Value, vs []T) {
	rv := reflect.ValueOf(vs)
	if rv.Type().AssignableTo(f.Type()) {
		f.Set(rv)
		return
	}
	if vs == nil {
		f.Set(reflect.Zero(f.Type()))
		return
	}
	elemType := f.Type().Elem()
	out := reflect.MakeSlice(f.Type(), len(vs), len(vs))
	for i := range vs {
		out.Index(i).Set(rv.Index(i).Convert(elemType))
	}
	f.Set(out)
}

// contextError carries the HTTP status code, message and optional detail
// errors for a failure that happens while processing a request.
type contextError struct {
	Code int
	Msg  string
	Errs []error
}

func (e *contextError) Error() string {
	return e.Msg
}

// writeErr writes cErr to the client. It uses cErr's own detail errors when
// present, otherwise the errors collected in res.
func writeErr(api API, ctx Context, cErr *contextError, res ValidateResult) {
	if cErr.Errs != nil {
		WriteErr(api, ctx, cErr.Code, cErr.Msg, cErr.Errs...)
	} else {
		WriteErr(api, ctx, cErr.Code, cErr.Msg, res.Errors...)
	}
}

// processMultipartMsgBody reads the multipart form from ctx and stores it in
// the RawBody field of inputValue located at rawBodyIndex.
//
// For rbtMultipartDecoded it also invokes the field's Decode method with the
// operation's multipart/form-data media type. A form that cannot be read is
// recorded in res and nil is returned. Decode errors are returned as a 422
// contextError.
func processMultipartMsgBody(op Operation, ctx Context, inputValue reflect.Value, rbt rawBodyType, rawBodyIndex []int, res *ValidateResult) *contextError {
	form, err := ctx.GetMultipartForm()
	if err != nil {
		res.Errors = append(res.Errors, &ErrorDetail{
			Location: "body",
			Message:  "cannot read multipart form: " + err.Error(),
		})
		return nil
	}
	f := inputValue
	for _, i := range rawBodyIndex {
		f = f.Field(i)
	}
	switch rbt {
	case rbtMultipart:
		f.Set(reflect.ValueOf(*form))
	case rbtMultipartDecoded:
		f.FieldByName("Form").Set(reflect.ValueOf(form))
		r := f.Addr().
			MethodByName("Decode").
			Call(
				[]reflect.Value{
					reflect.ValueOf(op.RequestBody.Content["multipart/form-data"]),
				})
		errs := r[0].Interface().([]error)
		if errs != nil {
			return &contextError{Code: http.StatusUnprocessableEntity, Msg: "validation failed", Errs: errs}
		}
	}
	return nil
}

type intoUnmarshaler = func(data []byte, v any) error

// processRegularMsgBody parses the raw body with unmarshaler and validates it
// with validator. Validation errors are recorded in res and the corresponding
// HTTP status code is returned; if no errors were found, the status is -1.
//
// An empty body for a required request body yields a 400 contextError.
func processRegularMsgBody(body []byte, op Operation, v reflect.Value, hasInputBody bool, inputBodyIndex []int, unmarshaler intoUnmarshaler, validator func(data any, res *ValidateResult), defaults *findResult[any], res *ValidateResult) (int, *contextError) {
	errStatus := -1
	// Check preconditions
	if len(body) == 0 {
		if op.RequestBody != nil && op.RequestBody.Required {
			return errStatus, &contextError{Code: http.StatusBadRequest, Msg: "request body is required"}
		}
		return errStatus, nil
	}
	if !hasInputBody {
		return errStatus, nil
	}

	// Validate
	isValid := true
	if !op.SkipValidateBody {
		errStatus = validateBody(body, unmarshaler, validator, res)
		if errStatus > 0 {
			isValid = false
		}
	}

	// Parse into value
	if len(inputBodyIndex) > 0 {
		if err := parseBodyInto(v, inputBodyIndex, unmarshaler, body, defaults); err != nil && isValid {
			// Hmm, this should have worked... validator missed something?
			res.Errors = append(res.Errors, err)
		}
	}
	return errStatus, nil
}

// validateBody parses the raw body with u and validates it with the validator.
// Any errors are documented in res and the corresponding error code is
// returned. If no errors were found, the return value is -1.
func validateBody(body []byte, u intoUnmarshaler, validator func(data any, res *ValidateResult), res *ValidateResult) int {
	errStatus := -1
	// Validate the input. First, parse the body into []any or map[string]any
	// or equivalent, which can be easily validated. Then, convert to the
	// expected struct type to call the handler.
	var parsed any
	if err := u(body, &parsed); err != nil {
		errStatus = http.StatusBadRequest
		if errors.Is(err, ErrUnknownContentType) {
			errStatus = http.StatusUnsupportedMediaType
		}

		res.Errors = append(res.Errors, &ErrorDetail{
			Location: "body",
			Message:  err.Error(),
			Value:    string(body),
		})
	} else {
		preValidationErrCount := len(res.Errors)
		validator(parsed, res)
		if len(res.Errors)-preValidationErrCount > 0 {
			errStatus = http.StatusUnprocessableEntity
		}
	}
	return errStatus
}

// parseBodyInto parses the raw body with u and populates the result in v at
// index bodyIndex. Afterwards, it sets default values on v for all fields that
// were left at their zero value.
func parseBodyInto(v reflect.Value, bodyIndex []int, u intoUnmarshaler, body []byte, defaults *findResult[any]) *ErrorDetail {
	// We need to get the body into the correct type now that it has been
	// validated. Benchmarks on Go 1.20 show that using `json.Unmarshal` a
	// second time is faster than `mapstructure.Decode` or any of the other
	// common reflection-based approaches when using real-world medium-sized
	// JSON payloads with lots of strings.
	f := v.FieldByIndex(bodyIndex)
	if err := u(body, f.Addr().Interface()); err != nil {
		return &ErrorDetail{
			Location: "body",
			Message:  err.Error(),
			Value:    string(body),
		}
	}
	// Set defaults for any fields that were not in the input.
	defaults.Every(v, func(item reflect.Value, def any) {
		if item.IsZero() {
			if item.Kind() == reflect.Pointer {
				item.Set(reflect.New(item.Type().Elem()))
				item = item.Elem()
			}
			item.Set(reflect.Indirect(reflect.ValueOf(def)))
		}
	})
	return nil
}

// readBody copies the request body from ctx into buf and closes the body
// reader if it is an io.Closer.
//
// When maxBytes > 0, a body longer than maxBytes is rejected with 413 Request
// Entity Too Large; a body of exactly maxBytes bytes is accepted. A read
// timeout maps to 408, and any other read failure maps to 500.
func readBody(buf io.Writer, ctx Context, maxBytes int64) *contextError {
	reader := ctx.BodyReader()
	if reader == nil {
		reader = bytes.NewReader(nil)
	}
	if closer, ok := reader.(io.Closer); ok {
		defer closer.Close()
	}
	if maxBytes > 0 {
		// Allow one byte past the limit so that an oversized body can be told
		// apart from one that is exactly maxBytes long.
		limit := maxBytes
		if limit < 1<<63-1 {
			limit++
		}
		reader = io.LimitReader(reader, limit)
	}
	count, err := io.Copy(buf, reader)
	if maxBytes > 0 && count > maxBytes {
		return &contextError{Code: http.StatusRequestEntityTooLarge, Msg: fmt.Sprintf("request body is too large limit=%d bytes", maxBytes)}
	}
	if err != nil {
		var netErr net.Error
		if errors.As(err, &netErr) && netErr.Timeout() {
			return &contextError{Code: http.StatusRequestTimeout, Msg: "request body read timeout"}
		}

		return &contextError{Code: http.StatusInternalServerError, Msg: "cannot read request body", Errs: []error{err}}
	}
	return nil
}

// AutoRegister auto-detects operation registration methods and registers them // with the given API. Any method named `Register...` will be called and // passed the API as the only argument. Since registration happens at diff --git a/huma_test.go b/huma_test.go index 84f6524a..227b3484 100644 --- a/huma_test.go +++ b/huma_test.go @@ -75,6 +75,17 @@ type StructWithDefaultField struct { Field string `json:"field" default:"default"` } +// MyTextUnmarshaler is a custom type that implements the +// `encoding.TextUnmarshaler` interface +type MyTextUnmarshaler struct { + value string +} + +func (m *MyTextUnmarshaler) UnmarshalText(text []byte) error { + m.value = "Hello, World!"
+ return nil +} + func TestFeatures(t *testing.T) { for _, feature := range []struct { Name string @@ -532,6 +543,22 @@ func TestFeatures(t *testing.T) { assert.Equal(t, http.StatusNoContent, resp.Code) }, }, + { + Name: "parse-with-textunmarshaler", + Register: func(t *testing.T, api huma.API) { + huma.Register(api, huma.Operation{ + Method: http.MethodGet, + Path: "/{mytext}", + }, func(ctx context.Context, i *struct { + MyText MyTextUnmarshaler `path:"mytext"` + }) (*struct{}, error) { + assert.Equal(t, "Hello, World!", i.MyText.value) + return nil, nil + }) + }, + Method: http.MethodGet, + URL: "/test", + }, { Name: "request-body", Register: func(t *testing.T, api huma.API) {