diff --git a/adapters/adapters_test.go b/adapters/adapters_test.go index c8caea3d..129949a2 100644 --- a/adapters/adapters_test.go +++ b/adapters/adapters_test.go @@ -91,17 +91,19 @@ func TestAdapters(t *testing.T) { } wrap := func(h huma.API, isFiber bool) huma.API { - h.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - assert.Nil(t, ctx.TLS()) - v := ctx.Version() - if !isFiber { - assert.Equal(t, "HTTP/1.1", v.Proto) - assert.Equal(t, 1, v.ProtoMajor) - assert.Equal(t, 1, v.ProtoMinor) - } else { - assert.Equal(t, "http", v.Proto) + h.UseMiddleware(func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + assert.Nil(t, ctx.TLS()) + v := ctx.Version() + if !isFiber { + assert.Equal(t, "HTTP/1.1", v.Proto) + assert.Equal(t, 1, v.ProtoMajor) + assert.Equal(t, 1, v.ProtoMinor) + } else { + assert.Equal(t, "http", v.Proto) + } + next(ctx) } - next(ctx) }) return h } diff --git a/adapters/humafiber/humafiber_context_test.go b/adapters/humafiber/humafiber_context_test.go index 41a4b636..91da4f87 100644 --- a/adapters/humafiber/humafiber_context_test.go +++ b/adapters/humafiber/humafiber_context_test.go @@ -213,12 +213,14 @@ func FiberMiddlewareUserContext(c *fiber.Ctx) error { return c.Next() } -func HumaMiddleware(ctx huma.Context, next func(huma.Context)) { - value := ctx.Header(HeaderNameHuma) - if value != "" { - ctx = huma.WithValue(ctx, contextValueHuma, value) +func HumaMiddleware(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + value := ctx.Header(HeaderNameHuma) + if value != "" { + ctx = huma.WithValue(ctx, contextValueHuma, value) + } + next(ctx) } - next(ctx) } func TestHumaFiber(t *testing.T) { diff --git a/api.go b/api.go index a3e43e35..02abf29b 100644 --- a/api.go +++ b/api.go @@ -239,7 +239,7 @@ type API interface { // route to a specific handler, which provides opportunity to respond early, // change the course of the request execution, or set request-scoped values 
for // the next Middleware. - UseMiddleware(middlewares ...func(ctx Context, next func(Context))) + UseMiddleware(middlewares ...Middleware) // Middlewares returns a slice of middleware handler functions that will be // run for all operations. Middleware are run in the order they are added. @@ -328,7 +328,7 @@ func (a *api) Marshal(w io.Writer, ct string, v any) error { return f.Marshal(w, v) } -func (a *api) UseMiddleware(middlewares ...func(ctx Context, next func(Context))) { +func (a *api) UseMiddleware(middlewares ...Middleware) { a.middlewares = append(a.middlewares, middlewares...) } diff --git a/api_test.go b/api_test.go index 4ffbe650..b9c322dc 100644 --- a/api_test.go +++ b/api_test.go @@ -71,11 +71,13 @@ func ExampleAdapter_handle() { func TestContextValue(t *testing.T) { _, api := humatest.New(t) - api.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - // Make an updated context available to the handler. - ctx = huma.WithValue(ctx, "foo", "bar") - next(ctx) - assert.Equal(t, http.StatusNoContent, ctx.Status()) + api.UseMiddleware(func(next func(huma.Context)) func(huma.Context) { + return func(ctx huma.Context) { + // Make an updated context available to the handler. + ctx = huma.WithValue(ctx, "foo", "bar") + next(ctx) + assert.Equal(t, http.StatusNoContent, ctx.Status()) + } }) // Register a simple hello world operation in the API. 
diff --git a/autopatch/autopatch.go b/autopatch/autopatch.go index 7c533b73..3935382a 100644 --- a/autopatch/autopatch.go +++ b/autopatch/autopatch.go @@ -11,7 +11,7 @@ package autopatch import ( "bytes" "encoding/json" - "fmt" + "errors" "io" "net/http" "net/http/httptest" @@ -329,7 +329,7 @@ func PatchResource(api huma.API, path *huma.PathItem) { var nullabilitySettings MergePatchNullabilitySettings if extension, ok := oapi.Extensions[MergePatchNullabilityExtension]; ok { if nullabilitySettings, ok = extension.(MergePatchNullabilitySettings); !ok { - huma.WriteErr(api, ctx, http.StatusInternalServerError, "Unable to parse nullability settings", fmt.Errorf("invalid nullability settings type")) + huma.WriteErr(api, ctx, http.StatusInternalServerError, "Unable to parse nullability settings", errors.New("invalid nullability settings type")) return } else if nullabilitySettings.Enabled { preserveNullsInMergePatch = true diff --git a/autopatch/autopatch_test.go b/autopatch/autopatch_test.go index 4c7bd794..2a58507c 100644 --- a/autopatch/autopatch_test.go +++ b/autopatch/autopatch_test.go @@ -10,6 +10,7 @@ import ( "testing/iotest" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/danielgtaylor/huma/v2" "github.com/danielgtaylor/huma/v2/humatest" @@ -520,7 +521,7 @@ func TestReplaceNulls(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { result, err := replaceNulls([]byte(tc.input), settings) - assert.NoError(t, err) + require.NoError(t, err) assert.JSONEq(t, tc.expected, string(result)) }) } @@ -566,7 +567,7 @@ func TestRestoreNulls(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { result, err := restoreNulls([]byte(tc.input), settings) - assert.NoError(t, err) + require.NoError(t, err) assert.JSONEq(t, tc.expected, string(result)) }) } diff --git a/chain.go b/chain.go index 752efecf..eea4efe9 100644 --- a/chain.go +++ b/chain.go @@ -1,8 +1,12 @@ package huma +// Middleware wraps the next handler in the chain and returns a new handler that +// may run code before or after calling next, or respond early without calling it. +type
Middleware func(next func(Context)) func(Context) + // Middlewares is a list of middleware functions that can be attached to an // API and will be called for all incoming requests. -type Middlewares []func(ctx Context, next func(Context)) +type Middlewares []Middleware // Handler builds and returns a handler func from the chain of middlewares, // with `endpoint func` as the final handler. @@ -10,13 +14,6 @@ func (m Middlewares) Handler(endpoint func(Context)) func(Context) { return m.chain(endpoint) } -// wrap user middleware func with the next func to one func -func wrap(fn func(Context, func(Context)), next func(Context)) func(Context) { - return func(ctx Context) { - fn(ctx, next) - } -} - -// chain builds a Middleware composed of an inline middleware stack and endpoint +// chain builds a handler composed of an inline middleware stack and endpoint // handler in the order they are passed. func (m Middlewares) chain(endpoint func(Context)) func(Context) { @@ -26,9 +23,9 @@ func (m Middlewares) chain(endpoint func(Context)) func(Context) { } // Wrap the end handler with the middleware chain - w := wrap(m[len(m)-1], endpoint) + w := m[len(m)-1](endpoint) for i := len(m) - 2; i >= 0; i-- { - w = wrap(m[i], w) + w = m[i](w) } return w } diff --git a/group.go b/group.go index d21f4acc..1e431e2f 100644 --- a/group.go +++ b/group.go @@ -167,7 +167,7 @@ func (g *Group) ModifyOperation(op *Operation, next func(*Operation)) { // UseMiddleware adds one or more middleware functions to the group that will be // run on all operations in the group. Use this to add common functionality to // all operations in the group, e.g. authentication/authorization. -func (g *Group) UseMiddleware(middlewares ...func(ctx Context, next func(Context))) { +func (g *Group) UseMiddleware(middlewares ...Middleware) { g.middlewares = append(g.middlewares, middlewares...)
} diff --git a/group_test.go b/group_test.go index a6bc594c..16f7feaf 100644 --- a/group_test.go +++ b/group_test.go @@ -147,13 +147,17 @@ func TestGroupCustomizations(t *testing.T) { opModifier1Called = true }) - grp.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - middleware1Called = true - next(ctx) - }) - grp.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - middleware2Called = true - next(ctx) + grp.UseMiddleware(func(next func(huma.Context)) func(huma.Context) { + return func(ctx huma.Context) { + middleware1Called = true + next(ctx) + } + }) + grp.UseMiddleware(func(next func(huma.Context)) func(huma.Context) { + return func(ctx huma.Context) { + middleware2Called = true + next(ctx) + } }) grp.UseTransformer(func(ctx huma.Context, status string, v any) (any, error) { diff --git a/huma_test.go b/huma_test.go index 127c71fe..b75c3c2f 100644 --- a/huma_test.go +++ b/huma_test.go @@ -118,13 +118,17 @@ func TestFeatures(t *testing.T) { { Name: "middleware", Register: func(t *testing.T, api huma.API) { - api.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - // Just a do-nothing passthrough. Shows that chaining works. - next(ctx) + api.UseMiddleware(func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + // Just a do-nothing passthrough. Shows that chaining works. + next(ctx) + } }) - api.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - // Return an error response, never calling the next handler. - ctx.SetStatus(299) + api.UseMiddleware(func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + // Return an error response, never calling the next handler. 
+ ctx.SetStatus(299) + } }) huma.Register(api, huma.Operation{ Method: http.MethodGet, @@ -144,12 +148,14 @@ func TestFeatures(t *testing.T) { { Name: "middleware-cookie", Register: func(t *testing.T, api huma.API) { - api.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - cookie, err := huma.ReadCookie(ctx, "foo") - require.NoError(t, err) - assert.Equal(t, "bar", cookie.Value) + api.UseMiddleware(func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + cookie, err := huma.ReadCookie(ctx, "foo") + require.NoError(t, err) + assert.Equal(t, "bar", cookie.Value) - next(ctx) + next(ctx) + } }) huma.Register(api, huma.Operation{ Method: http.MethodGet, @@ -167,12 +173,14 @@ func TestFeatures(t *testing.T) { { Name: "middleware-empty-cookie", Register: func(t *testing.T, api huma.API) { - api.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - cookie, err := huma.ReadCookie(ctx, "foo") - assert.Nil(t, cookie) - require.ErrorIs(t, err, http.ErrNoCookie) + api.UseMiddleware(func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + cookie, err := huma.ReadCookie(ctx, "foo") + assert.Nil(t, cookie) + require.ErrorIs(t, err, http.ErrNoCookie) - next(ctx) + next(ctx) + } }) huma.Register(api, huma.Operation{ Method: http.MethodGet, @@ -190,12 +198,14 @@ func TestFeatures(t *testing.T) { { Name: "middleware-cookie-only-semicolon", Register: func(t *testing.T, api huma.API) { - api.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - cookie, err := huma.ReadCookie(ctx, "foo") - assert.Nil(t, cookie) - require.ErrorIs(t, err, http.ErrNoCookie) + api.UseMiddleware(func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + cookie, err := huma.ReadCookie(ctx, "foo") + assert.Nil(t, cookie) + require.ErrorIs(t, err, http.ErrNoCookie) - next(ctx) + next(ctx) + } }) huma.Register(api, huma.Operation{ Method: http.MethodGet, @@ -213,12 +223,14 @@ 
func TestFeatures(t *testing.T) { { Name: "middleware-read-no-cookie-in-header", Register: func(t *testing.T, api huma.API) { - api.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - cookie, err := huma.ReadCookie(ctx, "foo") - assert.Nil(t, cookie) - require.ErrorIs(t, err, http.ErrNoCookie) + api.UseMiddleware(func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + cookie, err := huma.ReadCookie(ctx, "foo") + assert.Nil(t, cookie) + require.ErrorIs(t, err, http.ErrNoCookie) - next(ctx) + next(ctx) + } }) huma.Register(api, huma.Operation{ Method: http.MethodGet, @@ -233,12 +245,14 @@ func TestFeatures(t *testing.T) { { Name: "middleware-cookie-invalid-name", Register: func(t *testing.T, api huma.API) { - api.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - cookie, err := huma.ReadCookie(ctx, "foo") - assert.Nil(t, cookie) - require.ErrorIs(t, err, http.ErrNoCookie) + api.UseMiddleware(func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + cookie, err := huma.ReadCookie(ctx, "foo") + assert.Nil(t, cookie) + require.ErrorIs(t, err, http.ErrNoCookie) - next(ctx) + next(ctx) + } }) huma.Register(api, huma.Operation{ Method: http.MethodGet, @@ -256,12 +270,14 @@ func TestFeatures(t *testing.T) { { Name: "middleware-cookie-filter-skip", Register: func(t *testing.T, api huma.API) { - api.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - cookie, err := huma.ReadCookie(ctx, "foo") - assert.Nil(t, cookie) - require.ErrorIs(t, err, http.ErrNoCookie) + api.UseMiddleware(func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + cookie, err := huma.ReadCookie(ctx, "foo") + assert.Nil(t, cookie) + require.ErrorIs(t, err, http.ErrNoCookie) - next(ctx) + next(ctx) + } }) huma.Register(api, huma.Operation{ Method: http.MethodGet, @@ -279,12 +295,14 @@ func TestFeatures(t *testing.T) { { Name: "middleware-cookie-parse-double-quote", 
Register: func(t *testing.T, api huma.API) { - api.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - cookie, err := huma.ReadCookie(ctx, "bar") - require.NoError(t, err) - assert.NotNil(t, cookie) + api.UseMiddleware(func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + cookie, err := huma.ReadCookie(ctx, "bar") + require.NoError(t, err) + assert.NotNil(t, cookie) - next(ctx) + next(ctx) + } }) huma.Register(api, huma.Operation{ Method: http.MethodGet, @@ -302,12 +320,14 @@ func TestFeatures(t *testing.T) { { Name: "middleware-cookie-invalid-value-byte-with-semicolon", Register: func(t *testing.T, api huma.API) { - api.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - cookie, err := huma.ReadCookie(ctx, "bar") - assert.Nil(t, cookie) - require.ErrorIs(t, err, http.ErrNoCookie) + api.UseMiddleware(func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + cookie, err := huma.ReadCookie(ctx, "bar") + assert.Nil(t, cookie) + require.ErrorIs(t, err, http.ErrNoCookie) - next(ctx) + next(ctx) + } }) huma.Register(api, huma.Operation{ Method: http.MethodGet, @@ -325,12 +345,14 @@ func TestFeatures(t *testing.T) { { Name: "middleware-cookie-invalid-value-byte-with-double-backslash", Register: func(t *testing.T, api huma.API) { - api.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - cookie, err := huma.ReadCookie(ctx, "bar") - assert.Nil(t, cookie) - require.ErrorIs(t, err, http.ErrNoCookie) + api.UseMiddleware(func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + cookie, err := huma.ReadCookie(ctx, "bar") + assert.Nil(t, cookie) + require.ErrorIs(t, err, http.ErrNoCookie) - next(ctx) + next(ctx) + } }) huma.Register(api, huma.Operation{ Method: http.MethodGet, @@ -352,13 +374,17 @@ func TestFeatures(t *testing.T) { Method: http.MethodGet, Path: "/middleware", Middlewares: huma.Middlewares{ - func(ctx huma.Context, next 
func(huma.Context)) { - // Just a do-nothing passthrough. Shows that chaining works. - next(ctx) + func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + // Just a do-nothing passthrough. Shows that chaining works. + next(ctx) + } }, - func(ctx huma.Context, next func(huma.Context)) { - // Return an error response, never calling the next handler. - ctx.SetStatus(299) + func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + // Return an error response, never calling the next handler. + ctx.SetStatus(299) + } }, }, }, func(ctx context.Context, input *struct{}) (*struct{}, error) { @@ -2181,18 +2207,20 @@ Content-Type: text/plain }, }, Register: func(t *testing.T, api huma.API) { - api.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) { - called := false - defer func() { - if err := recover(); err != nil { - // Ensure the error is the one we expect, possibly wrapped with - // additional info. - assert.ErrorIs(t, err.(error), http.ErrNotSupported) - } - called = true - }() - next(ctx) - assert.True(t, called) + api.UseMiddleware(func(next func(huma.Context)) func(ctx huma.Context) { + return func(ctx huma.Context) { + called := false + defer func() { + if err := recover(); err != nil { + // Ensure the error is the one we expect, possibly wrapped with + // additional info. + assert.ErrorIs(t, err.(error), http.ErrNotSupported) + } + called = true + }() + next(ctx) + assert.True(t, called) + } }) huma.Register(api, huma.Operation{ Method: http.MethodGet,