Review notes on this patch. git apply and GNU patch ignore text that precedes
the first "diff --git" header, so these notes do not change what it applies.

huma.go, transformAndWrite:
  * The relocated ctx.SetHeader("Content-Type", ct) now runs after the marshal
    step (it follows the "error marshaling response" return). The marshal step
    presumably writes the body through the context's body writer; confirm this
    in the lines between the two hunks. If so, net/http ignores header changes
    made after WriteHeader/Write, and real clients would receive no
    Content-Type. Set the header before the status and body are written.
  * On the marshal-error path the function returns before SetHeader, so that
    path never sets Content-Type.

huma_test.go:
  * The Content-Type assertions read resp.Header(). httptest.ResponseRecorder
    returns its live header map there, so a header set after the body was
    written still shows up. Assert on resp.Result().Header to check what was
    actually sent.
  * "feature.Config.OpenAPI == nil" is the sentinel for "use DefaultConfig".
    A test case that supplies a Config without OpenAPI has that Config
    silently replaced.

api.go:
  * With NoFormatFallback set, Negotiate returns ErrUnknownAcceptContentType
    whenever SelectQValueFast returns "". Confirm what it returns for an empty
    Accept header and for "*/*". If either yields "", the flag turns requests
    that accept any media type (RFC 9110 section 12.5.1) into 406 responses.
  * The fallback is a.formatKeys[0]. NewAPI appends DefaultFormat to
    formatKeys (presumably as the first key) when it is set. So the flag also
    disables fallback to an explicitly configured DefaultFormat, while the
    Config comment only mentions application/json.
  * With the flag set, NewAPI no longer defaults DefaultFormat to
    application/json. Any other reader of config.DefaultFormat (not visible
    here) is affected as well; verify that this is intended.
  * "if a.formatKeys != nil" followed by "a.formatKeys[0]" panics on a non-nil
    empty slice; "len(a.formatKeys) > 0" is the safer guard.

diff --git a/api.go b/api.go index 39acec04..e93c88fd 100644 --- a/api.go +++ b/api.go @@ -23,6 +23,8 @@ var rxSchema = regexp.MustCompile(`#/components/schemas/([^"]+)`) var ErrUnknownContentType = errors.New("unknown content type") +var ErrUnknownAcceptContentType = errors.New("unknown accept content type") + // Resolver runs a `Resolve` function after a request has been parsed, enabling // you to run custom validation or other code that can modify the request and / // or return errors. @@ -201,6 +203,11 @@ type Config struct { // chosen from the keys of `Formats`. DefaultFormat string + // NoFormatFallback disables the fallback to application/json (if available) + // when the client requests an unknown content type. If set and no format is + // negotiated, then a 406 Not Acceptable response will be returned. + NoFormatFallback bool + // Transformers are a way to modify a response body before it is serialized. Transformers []Transformer @@ -302,8 +309,13 @@ func (a *api) Unmarshal(contentType string, data []byte, v any) error { func (a *api) Negotiate(accept string) (string, error) { ct := negotiation.SelectQValueFast(accept, a.formatKeys) - if ct == "" && a.formatKeys != nil { - ct = a.formatKeys[0] + if ct == "" { + if a.config.NoFormatFallback { + return "", ErrUnknownAcceptContentType + } + if a.formatKeys != nil { + ct = a.formatKeys[0] + } } if _, ok := a.formats[ct]; !ok { return ct, fmt.Errorf("%w: %s", ErrUnknownContentType, ct) @@ -395,8 +407,10 @@ func NewAPI(config Config, a Adapter) API { config.Components.Schemas = NewMapRegistry("#/components/schemas/", DefaultSchemaNamer) } - if config.DefaultFormat == "" && config.Formats["application/json"].Marshal != nil { - config.DefaultFormat = "application/json" + if config.DefaultFormat == "" && !config.NoFormatFallback { + if config.Formats["application/json"].Marshal != nil { + config.DefaultFormat = "application/json" + } } if config.DefaultFormat != "" { newAPI.formatKeys = append(newAPI.formatKeys, 
config.DefaultFormat) diff --git a/huma.go b/huma.go index 4ea4b031..1887d135 100644 --- a/huma.go +++ b/huma.go @@ -499,8 +499,6 @@ func writeResponse(api API, ctx Context, status int, ct string, body any) error if ctf, ok := body.(ContentTypeFilter); ok { ct = ctf.ContentType(ct) } - - ctx.SetHeader("Content-Type", ct) } if err := transformAndWrite(api, ctx, status, ct, body); err != nil { @@ -544,6 +542,9 @@ func transformAndWrite(api API, ctx Context, status int, ct string, body any) er return fmt.Errorf("error marshaling response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, merr) } } + + ctx.SetHeader("Content-Type", ct) + return nil } diff --git a/huma_test.go b/huma_test.go index 127c71fe..a0ea9221 100644 --- a/huma_test.go +++ b/huma_test.go @@ -108,6 +108,7 @@ func TestFeatures(t *testing.T) { for _, feature := range []struct { Name string Transformers []huma.Transformer + Config huma.Config Register func(t *testing.T, api huma.API) Method string URL string @@ -1769,7 +1770,8 @@ Content-Type: text/plain assert.Equal(t, "form.myInt", errors.Errors[1].Location) } }, - }, { + }, + { Name: "request-body-multipart-file-decoded-with-formvalue-invalid", Register: func(t *testing.T, api huma.API) { huma.Register(api, huma.Operation{ @@ -1999,6 +2001,76 @@ Content-Type: text/plain assert.Equal(t, "application/custom-type", resp.Header().Get("Content-Type")) }, }, + { + Name: "unknown accept header with JSON format fallback", + Config: huma.Config{ + OpenAPI: &huma.OpenAPI{ + OpenAPI: "3.1.0", + }, + Formats: huma.DefaultFormats, + NoFormatFallback: false, // Explicitly setting false for clarity. 
+ }, + Register: func(t *testing.T, api huma.API) { + type Resp struct { + Body struct { + Value string `json:"value"` + } + } + huma.Register(api, huma.Operation{ + Method: http.MethodGet, + Path: "/response-accept", + }, func(ctx context.Context, input *struct{}) (*Resp, error) { + out := new(Resp) + out.Body.Value = "hello" + return out, nil + }) + }, + Method: http.MethodGet, + Headers: map[string]string{ + "Accept": "custom/dne", + }, + URL: "/response-accept", + Assert: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, "application/json", resp.Header().Get("Content-Type")) + assert.JSONEq(t, `{"value":"hello"}`, resp.Body.String()) + }, + }, + { + Name: "unknown accept header with disabled format fallback returns 406", + Config: huma.Config{ + OpenAPI: &huma.OpenAPI{ + OpenAPI: "3.1.0", + }, + Formats: huma.DefaultFormats, + NoFormatFallback: true, + }, + Register: func(t *testing.T, api huma.API) { + type Resp struct { + Body struct { + Value string `json:"value"` + } + } + huma.Register(api, huma.Operation{ + Method: http.MethodGet, + Path: "/response-accept", + }, func(ctx context.Context, input *struct{}) (*Resp, error) { + out := new(Resp) + out.Body.Value = "hello" + return out, nil + }) + }, + Method: http.MethodGet, + Headers: map[string]string{ + "Accept": "custom/dne", + }, + URL: "/response-accept", + Assert: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusNotAcceptable, resp.Code) + assert.Equal(t, "application/json", resp.Header().Get("Content-Type")) + assert.JSONEq(t, `{"title":"Not Acceptable","status":406,"detail":"unable to marshal response","errors":[{"message":"unknown accept content type"}]}`, resp.Body.String()) + }, + }, { Name: "response-body-nameHint", Register: func(t *testing.T, api huma.API) { @@ -2418,11 +2490,13 @@ Content-Type: text/plain } { t.Run(feature.Name, func(t *testing.T) { r := http.NewServeMux() - config := 
huma.DefaultConfig("Features Test API", "1.0.0") + if feature.Config.OpenAPI == nil { + feature.Config = huma.DefaultConfig("Features Test API", "1.0.0") + } if feature.Transformers != nil { - config.Transformers = append(config.Transformers, feature.Transformers...) + feature.Config.Transformers = append(feature.Config.Transformers, feature.Transformers...) } - api := humatest.Wrap(t, humago.New(r, config)) + api := humatest.Wrap(t, humago.New(r, feature.Config)) feature.Register(t, api) var body io.Reader = nil