diff --git a/adapters/humabunrouter/humabunrouter.go b/adapters/humabunrouter/humabunrouter.go
index 834faf98..89a3e217 100644
--- a/adapters/humabunrouter/humabunrouter.go
+++ b/adapters/humabunrouter/humabunrouter.go
@@ -140,6 +140,16 @@ func (c *bunContext) Version() huma.ProtoVersion {
 	}
 }
 
+// WithContext returns a copy of c whose request carries ctx.
+func (c *bunContext) WithContext(ctx context.Context) huma.Context {
+	return &bunContext{
+		op:     c.op,
+		r:      c.r.WithContext(ctx),
+		w:      c.w,
+		status: c.status,
+	}
+}
+
 // NewContext creates a new Huma context from an HTTP request and response.
 func NewContext(op *huma.Operation, r bunrouter.Request, w http.ResponseWriter) huma.Context {
 	return &bunContext{op: op, r: r, w: w}
@@ -243,6 +253,16 @@ func (c *bunCompatContext) Version() huma.ProtoVersion {
 	}
 }
 
+// WithContext returns a copy of c whose request carries ctx.
+func (c *bunCompatContext) WithContext(ctx context.Context) huma.Context {
+	return &bunCompatContext{
+		op:     c.op,
+		r:      c.r.WithContext(ctx),
+		w:      c.w,
+		status: c.status,
+	}
+}
+
 // NewCompatContext creates a new Huma context from an HTTP request and response.
 func NewCompatContext(op *huma.Operation, r *http.Request, w http.ResponseWriter) huma.Context {
 	return &bunCompatContext{op: op, r: r, w: w}
@@ -310,3 +330,20 @@ func NewCompat(r *bunrouter.CompatRouter, config huma.Config) huma.API {
 func New(r *bunrouter.Router, config huma.Config) huma.API {
 	return huma.NewAPI(config, NewAdapter(r))
 }
+
+// middleware adapts a bunrouter middleware to a Huma middleware so it can
+// run inside an operation's middleware chain. The request and writer that mw
+// passes on are wrapped into a fresh Huma context before calling next.
+//
+// Huma middleware cannot return errors, so an error returned by mw is
+// deliberately discarded.
+func middleware(mw bunrouter.MiddlewareFunc) func(ctx huma.Context, next func(huma.Context)) {
+	return func(ctx huma.Context, next func(huma.Context)) {
+		r, w := Unwrap(ctx)
+		f := mw(func(w http.ResponseWriter, r bunrouter.Request) error {
+			next(NewContext(ctx.Operation(), r, w))
+			return nil
+		})
+		_ = f(w, r)
+	}
+}
diff --git a/adapters/humabunrouter/humabunrouter_test.go b/adapters/humabunrouter/humabunrouter_test.go
index c61758c2..761cbafe 100644
--- a/adapters/humabunrouter/humabunrouter_test.go
+++ b/adapters/humabunrouter/humabunrouter_test.go
@@ -14,9 +14,12 @@ import (
 	"testing"
 	"time"
 
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"github.com/uptrace/bunrouter"
 
 	"github.com/danielgtaylor/huma/v2"
+	"github.com/danielgtaylor/huma/v2/humatest"
 )
 
 var lastModified = time.Now()
@@ -320,3 +323,47 @@ func BenchmarkRawBunRouterFast(b *testing.B) {
 		r.ServeHTTP(w, req)
 	}
 }
+
+// See https://github.com/danielgtaylor/huma/issues/859
+func TestWithValueShouldPropagateContext(t *testing.T) {
+	r := bunrouter.New()
+	app := New(r, huma.DefaultConfig("Test", "1.0.0"))
+
+	type (
+		testInput  struct{}
+		testOutput struct{}
+		ctxKey     struct{}
+	)
+
+	ctxValue := "sentinelValue"
+
+	huma.Register(app, huma.Operation{
+		OperationID: "test",
+		Path:        "/test",
+		Method:      http.MethodGet,
+		Middlewares: huma.Middlewares{
+			func(ctx huma.Context, next func(huma.Context)) {
+				ctx = huma.WithValue(ctx, ctxKey{}, ctxValue)
+				next(ctx)
+			},
+			middleware(func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
+				return func(w http.ResponseWriter, r bunrouter.Request) error {
+					val, _ := r.Context().Value(ctxKey{}).(string)
+					_, err := io.WriteString(w, val)
+					return err
+				}
+			}),
+		},
+	}, func(ctx context.Context, input *testInput) (*testOutput, error) {
+		out := &testOutput{}
+		return out, nil
+	})
+
+	tapi := humatest.Wrap(t, app)
+
+	resp := tapi.Get("/test")
+	assert.Equal(t, http.StatusOK, resp.Code)
+	out, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	assert.Equal(t, ctxValue, string(out))
+}
diff --git a/adapters/humachi/humachi.go b/adapters/humachi/humachi.go
index 470d14c0..c1095065 100644
--- a/adapters/humachi/humachi.go
+++ b/adapters/humachi/humachi.go
@@ -146,6 +146,16 @@ func (c *chiContext) Version() huma.ProtoVersion {
 	}
 }
 
+// WithContext returns a copy of c whose request carries ctx.
+func (c *chiContext) WithContext(ctx context.Context) huma.Context {
+	return &chiContext{
+		op:     c.op,
+		r:      c.r.WithContext(ctx),
+		w:      c.w,
+		status: c.status,
+	}
+}
+
 // NewContext creates a new Huma context from an HTTP request and response.
 func NewContext(op *huma.Operation, r *http.Request, w http.ResponseWriter) huma.Context {
 	return &chiContext{op: op, r: r, w: w}
@@ -174,3 +184,15 @@ func NewAdapter(r chi.Router) huma.Adapter {
 func New(r chi.Router, config huma.Config) huma.API {
 	return huma.NewAPI(config, &chiAdapter{router: r})
 }
+
+// middleware adapts a standard net/http middleware to a Huma middleware so
+// it can run inside an operation's middleware chain. The request and writer
+// that mw passes on are wrapped into a fresh Huma context before calling next.
+func middleware(mw func(http.Handler) http.Handler) func(ctx huma.Context, next func(huma.Context)) {
+	return func(ctx huma.Context, next func(huma.Context)) {
+		r, w := Unwrap(ctx)
+		mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			next(NewContext(ctx.Operation(), r, w))
+		})).ServeHTTP(w, r)
+	}
+}
diff --git a/adapters/humachi/humachi_test.go b/adapters/humachi/humachi_test.go
index 0b9314c6..0f9d99ba 100644
--- a/adapters/humachi/humachi_test.go
+++ b/adapters/humachi/humachi_test.go
@@ -438,3 +438,47 @@ func TestPathParamDecoding(t *testing.T) {
 // 		app.ServeHTTP(w, req)
 // 	}
 // }
+
+// See https://github.com/danielgtaylor/huma/issues/859
+func TestWithValueShouldPropagateContext(t *testing.T) {
+	r := chi.NewMux()
+	app := New(r, huma.DefaultConfig("Test", "1.0.0"))
+
+	type (
+		testInput  struct{}
+		testOutput struct{}
+		ctxKey     struct{}
+	)
+
+	ctxValue := "sentinelValue"
+
+	huma.Register(app, huma.Operation{
+		OperationID: "test",
+		Path:        "/test",
+		Method:      http.MethodGet,
+		Middlewares: huma.Middlewares{
+			func(ctx huma.Context, next func(huma.Context)) {
+				ctx = huma.WithValue(ctx, ctxKey{}, ctxValue)
+				next(ctx)
+			},
+			middleware(func(h http.Handler) http.Handler {
+				return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+					val, _ := r.Context().Value(ctxKey{}).(string)
+					_, err := io.WriteString(w, val)
+					assert.NoError(t, err)
+				})
+			}),
+		},
+	}, func(ctx context.Context, input *testInput) (*testOutput, error) {
+		out := &testOutput{}
+		return out, nil
+	})
+
+	tapi := humatest.Wrap(t, app)
+
+	resp := tapi.Get("/test")
+	assert.Equal(t, http.StatusOK, resp.Code)
+	out, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	assert.Equal(t, ctxValue, string(out))
+}
diff --git a/adapters/humaecho/humaecho.go b/adapters/humaecho/humaecho.go
index d7c77089..9821719d 100644
--- a/adapters/humaecho/humaecho.go
+++ b/adapters/humaecho/humaecho.go
@@ -138,6 +138,18 @@ func (c *echoCtx) Version() huma.ProtoVersion {
 	}
 }
 
+// WithContext returns a Huma context whose request carries ctx. The
+// echo.Context is shared with c, so the request is replaced in place and the
+// change is visible through c.orig too.
+func (c *echoCtx) WithContext(ctx context.Context) huma.Context {
+	c.orig.SetRequest(c.orig.Request().WithContext(ctx))
+	return &echoCtx{
+		op:     c.op,
+		orig:   c.orig,
+		status: c.status,
+	}
+}
+
 type router interface {
 	Add(method, path string, handler echo.HandlerFunc, middlewares ...echo.MiddlewareFunc) *echo.Route
 }
@@ -170,3 +182,22 @@ func New(r *echo.Echo, config huma.Config) huma.API {
 func NewWithGroup(r *echo.Echo, g *echo.Group, config huma.Config) huma.API {
 	return huma.NewAPI(config, &echoAdapter{Handler: r, router: g})
 }
+
+// middleware adapts an Echo middleware to a Huma middleware so it can run
+// inside an operation's middleware chain. The echo.Context that mw passes on,
+// which need not be the incoming one, is wrapped into a Huma context that
+// keeps the operation and any status already set.
+func middleware(mw echo.MiddlewareFunc) func(ctx huma.Context, next func(huma.Context)) {
+	return func(ctx huma.Context, next func(huma.Context)) {
+		eCtx := Unwrap(ctx)
+		f := mw(func(c echo.Context) error {
+			next(&echoCtx{op: ctx.Operation(), orig: c, status: ctx.Status()})
+			return nil
+		})
+		if err := f(eCtx); err != nil {
+			// Huma middleware cannot return errors; let Echo's registered
+			// HTTP error handler render the response instead.
+			eCtx.Error(err)
+		}
+	}
+}
diff --git a/adapters/humaecho/humaecho_test.go b/adapters/humaecho/humaecho_test.go
index 46be5954..73b90126 100644
--- a/adapters/humaecho/humaecho_test.go
+++ b/adapters/humaecho/humaecho_test.go
@@ -13,7 +13,10 @@ import (
 	"time"
 
 	"github.com/danielgtaylor/huma/v2"
+	"github.com/danielgtaylor/huma/v2/humatest"
 	"github.com/labstack/echo/v4"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 var lastModified = time.Now()
@@ -240,3 +243,47 @@ func BenchmarkRawEchoFast(b *testing.B) {
 		r.ServeHTTP(w, req)
 	}
 }
+
+// See https://github.com/danielgtaylor/huma/issues/859
+func TestWithValueShouldPropagateContext(t *testing.T) {
+	r := echo.New()
+	app := New(r, huma.DefaultConfig("Test", "1.0.0"))
+
+	type (
+		testInput  struct{}
+		testOutput struct{}
+		ctxKey     struct{}
+	)
+
+	ctxValue := "sentinelValue"
+
+	huma.Register(app, huma.Operation{
+		OperationID: "test",
+		Path:        "/test",
+		Method:      http.MethodGet,
+		Middlewares: huma.Middlewares{
+			func(ctx huma.Context, next func(huma.Context)) {
+				ctx = huma.WithValue(ctx, ctxKey{}, ctxValue)
+				next(ctx)
+			},
+			middleware(func(next echo.HandlerFunc) echo.HandlerFunc {
+				return func(c echo.Context) error {
+					val, _ := c.Request().Context().Value(ctxKey{}).(string)
+					_, err := io.WriteString(c.Response().Writer, val)
+					return err
+				}
+			}),
+		},
+	}, func(ctx context.Context, input *testInput) (*testOutput, error) {
+		out := &testOutput{}
+		return out, nil
+	})
+
+	tapi := humatest.Wrap(t, app)
+
+	resp := tapi.Get("/test")
+	assert.Equal(t, http.StatusOK, resp.Code)
+	out, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	assert.Equal(t, ctxValue, string(out))
+}
diff --git a/adapters/humafiber/humafiber.go b/adapters/humafiber/humafiber.go
index ba972d9a..12b8b04d 100644
--- a/adapters/humafiber/humafiber.go
+++ b/adapters/humafiber/humafiber.go
@@ -153,6 +153,19 @@ func (c *fiberWrapper) Version() huma.ProtoVersion {
 	}
 }
 
+// WithContext returns a Huma context whose context is ctx. The underlying
+// *fiber.Ctx is shared with c, so its user context is replaced in place and
+// the change is visible through c.orig too.
+func (c *fiberWrapper) WithContext(ctx context.Context) huma.Context {
+	c.orig.SetUserContext(ctx)
+	return &fiberWrapper{
+		op:     c.op,
+		status: c.status,
+		orig:   c.orig,
+		ctx:    ctx,
+	}
+}
+
 type router interface {
 	Add(method, path string, handlers ...fiber.Handler) fiber.Router
 }
@@ -242,3 +255,20 @@ func New(r *fiber.App, config huma.Config) huma.API {
 func NewWithGroup(r *fiber.App, g fiber.Router, config huma.Config) huma.API {
 	return huma.NewAPI(config, &fiberAdapter{tester: r, router: g})
 }
+
+// middleware adapts a Fiber middleware written in the func(next) form to a
+// Huma middleware so it can run inside an operation's middleware chain.
+func middleware(mw func(next fiber.Handler) fiber.Handler) func(ctx huma.Context, next func(huma.Context)) {
+	return func(ctx huma.Context, next func(huma.Context)) {
+		fCtx := Unwrap(ctx)
+		h := mw(func(_ *fiber.Ctx) error {
+			next(ctx)
+			return nil
+		})
+		if err := h(fCtx); err != nil {
+			// Huma middleware cannot return errors; hand it to the app's
+			// configured error handler so a response is still rendered.
+			_ = fCtx.App().Config().ErrorHandler(fCtx, err)
+		}
+	}
+}
diff --git a/adapters/humafiber/humafiber_test.go b/adapters/humafiber/humafiber_test.go
index 1538d9d2..4cb5261e 100644
--- a/adapters/humafiber/humafiber_test.go
+++ b/adapters/humafiber/humafiber_test.go
@@ -2,11 +2,15 @@ package humafiber
 
 import (
 	"context"
+	"io"
 	"net/http"
 	"testing"
 
 	"github.com/danielgtaylor/huma/v2"
+	"github.com/danielgtaylor/huma/v2/humatest"
 	"github.com/gofiber/fiber/v2"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func BenchmarkHumaFiber(b *testing.B) {
@@ -59,3 +63,47 @@ func BenchmarkNotHuma(b *testing.B) {
 		r.Test(req)
 	}
 }
+
+// See https://github.com/danielgtaylor/huma/issues/859
+func TestWithValueShouldPropagateContext(t *testing.T) {
+	r := fiber.New()
+	app := New(r, huma.DefaultConfig("Test", "1.0.0"))
+
+	type (
+		testInput  struct{}
+		testOutput struct{}
+		ctxKey     struct{}
+	)
+
+	ctxValue := "sentinelValue"
+
+	huma.Register(app, huma.Operation{
+		OperationID: "test",
+		Path:        "/test",
+		Method:      http.MethodGet,
+		Middlewares: huma.Middlewares{
+			func(ctx huma.Context, next func(huma.Context)) {
+				ctx = huma.WithValue(ctx, ctxKey{}, ctxValue)
+				next(ctx)
+			},
+			middleware(func(next fiber.Handler) fiber.Handler {
+				return func(c *fiber.Ctx) error {
+					val, _ := c.UserContext().Value(ctxKey{}).(string)
+					_, err := c.WriteString(val)
+					return err
+				}
+			}),
+		},
+	}, func(ctx context.Context, input *testInput) (*testOutput, error) {
+		out := &testOutput{}
+		return out, nil
+	})
+
+	tapi := humatest.Wrap(t, app)
+
+	resp := tapi.Get("/test")
+	assert.Equal(t, http.StatusOK, resp.Code)
+	out, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	assert.Equal(t, ctxValue, string(out))
+}
diff --git a/adapters/humagin/humagin.go b/adapters/humagin/humagin.go
index 5d47f026..68241e71 100644
--- a/adapters/humagin/humagin.go
+++ b/adapters/humagin/humagin.go
@@ -137,6 +137,18 @@ func (c *ginCtx) Version() huma.ProtoVersion {
 	}
 }
 
+// WithContext returns a Huma context whose request carries ctx. The
+// *gin.Context is shared with c, so its Request field is replaced in place and
+// the change is visible through c.orig too.
+func (c *ginCtx) WithContext(ctx context.Context) huma.Context {
+	c.orig.Request = c.orig.Request.WithContext(ctx)
+	return &ginCtx{
+		op:     c.op,
+		orig:   c.orig,
+		status: c.status,
+	}
+}
+
 // NewContext creates a new Huma context from a Gin context
 func NewContext(op *huma.Operation, c *gin.Context) huma.Context {
 	return &ginCtx{op: op, orig: c}
@@ -174,3 +186,15 @@ func New(r *gin.Engine, config huma.Config) huma.API {
 func NewWithGroup(r *gin.Engine, g *gin.RouterGroup, config huma.Config) huma.API {
 	return huma.NewAPI(config, &ginAdapter{Handler: r, router: g})
 }
+
+// middleware converts a Gin middleware function to a Huma middleware function.
+func middleware(mw func(gin.HandlerFunc) gin.HandlerFunc) func(ctx huma.Context, next func(huma.Context)) {
+	return func(ctx huma.Context, next func(huma.Context)) {
+		c := Unwrap(ctx)
+		f := mw(func(gCtx *gin.Context) {
+			// Rewrap the *gin.Context mw passes on for the rest of the chain.
+			next(NewContext(ctx.Operation(), gCtx))
+		})
+		f(c)
+	}
+}
diff --git a/adapters/humagin/humagin_test.go b/adapters/humagin/humagin_test.go
index 28094cbd..a5e633ef 100644
--- a/adapters/humagin/humagin_test.go
+++ b/adapters/humagin/humagin_test.go
@@ -2,6 +2,7 @@ package humagin
 
 import (
 	"context"
+	"io"
 	"net/http"
 	"net/http/httptest"
 	"strings"
@@ -9,7 +10,10 @@ import (
 	"time"
 
 	"github.com/danielgtaylor/huma/v2"
+	"github.com/danielgtaylor/huma/v2/humatest"
 	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 var lastModified = time.Now()
@@ -71,3 +75,46 @@ func BenchmarkHumaGin(b *testing.B) {
 		}
 	}
 }
+
+// See https://github.com/danielgtaylor/huma/issues/859
+func TestWithValueShouldPropagateContext(t *testing.T) {
+	r := gin.New()
+	app := New(r, huma.DefaultConfig("Test", "1.0.0"))
+
+	type (
+		testInput  struct{}
+		testOutput struct{}
+		ctxKey     struct{}
+	)
+
+	ctxValue := "sentinelValue"
+
+	huma.Register(app, huma.Operation{
+		OperationID: "test",
+		Path:        "/test",
+		Method:      http.MethodGet,
+		Middlewares: huma.Middlewares{
+			func(ctx huma.Context, next func(huma.Context)) {
+				ctx = huma.WithValue(ctx, ctxKey{}, ctxValue)
+				next(ctx)
+			},
+			middleware(func(next gin.HandlerFunc) gin.HandlerFunc {
+				return func(c *gin.Context) {
+					val, _ := c.Request.Context().Value(ctxKey{}).(string)
+					c.String(http.StatusOK, val)
+				}
+			}),
+		},
+	}, func(ctx context.Context, input *testInput) (*testOutput, error) {
+		out := &testOutput{}
+		return out, nil
+	})
+
+	tapi := humatest.Wrap(t, app)
+
+	resp := tapi.Get("/test")
+	assert.Equal(t, http.StatusOK, resp.Code)
+	out, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	assert.Equal(t, ctxValue, string(out))
+}
diff --git a/adapters/humago/humago.go b/adapters/humago/humago.go
index 01e02208..41cbd4bd 100644
--- a/adapters/humago/humago.go
+++ b/adapters/humago/humago.go
@@ -138,6 +138,16 @@ func (c *goContext) Version() huma.ProtoVersion {
 	}
 }
 
+// WithContext returns a copy of c whose request carries ctx.
+func (c *goContext) WithContext(ctx context.Context) huma.Context {
+	return &goContext{
+		op:     c.op,
+		r:      c.r.WithContext(ctx),
+		w:      c.w,
+		status: c.status,
+	}
+}
+
 // NewContext creates a new Huma context from an HTTP request and response.
 func NewContext(op *huma.Operation, r *http.Request, w http.ResponseWriter) huma.Context {
 	return &goContext{op: op, r: r, w: w}
@@ -190,3 +200,15 @@ func NewWithPrefix(m Mux, prefix string, config huma.Config) huma.API {
 	}
 	return huma.NewAPI(config, &goAdapter{m, prefix})
 }
+
+// middleware adapts a standard net/http middleware to a Huma middleware so
+// it can run inside an operation's middleware chain. The request and writer
+// that mw passes on are wrapped into a fresh Huma context before calling next.
+func middleware(mw func(http.Handler) http.Handler) func(ctx huma.Context, next func(huma.Context)) {
+	return func(ctx huma.Context, next func(huma.Context)) {
+		r, w := Unwrap(ctx)
+		mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			next(NewContext(ctx.Operation(), r, w))
+		})).ServeHTTP(w, r)
+	}
+}
diff --git a/adapters/humago/humago_test.go b/adapters/humago/humago_test.go
index 8ca76b5d..71267032 100644
--- a/adapters/humago/humago_test.go
+++ b/adapters/humago/humago_test.go
@@ -13,6 +13,9 @@ import (
 	"time"
 
 	"github.com/danielgtaylor/huma/v2"
+	"github.com/danielgtaylor/huma/v2/humatest"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 var lastModified = time.Now()
@@ -177,3 +180,47 @@ func BenchmarkRawGo(b *testing.B) {
 		}
 	}
 }
+
+// See https://github.com/danielgtaylor/huma/issues/859
+func TestWithValueShouldPropagateContext(t *testing.T) {
+	r := http.NewServeMux()
+	app := New(r, huma.DefaultConfig("Test", "1.0.0"))
+
+	type (
+		testInput  struct{}
+		testOutput struct{}
+		ctxKey     struct{}
+	)
+
+	ctxValue := "sentinelValue"
+
+	huma.Register(app, huma.Operation{
+		OperationID: "test",
+		Path:        "/test",
+		Method:      http.MethodGet,
+		Middlewares: huma.Middlewares{
+			func(ctx huma.Context, next func(huma.Context)) {
+				ctx = huma.WithValue(ctx, ctxKey{}, ctxValue)
+				next(ctx)
+			},
+			middleware(func(h http.Handler) http.Handler {
+				return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+					val, _ := r.Context().Value(ctxKey{}).(string)
+					_, err := io.WriteString(w, val)
+					assert.NoError(t, err)
+				})
+			}),
+		},
+	}, func(ctx context.Context, input *testInput) (*testOutput, error) {
+		out := &testOutput{}
+		return out, nil
+	})
+
+	tapi := humatest.Wrap(t, app)
+
+	resp := tapi.Get("/test")
+	assert.Equal(t, http.StatusOK, resp.Code)
+	out, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	assert.Equal(t, ctxValue, string(out))
+}
diff --git a/adapters/humahttprouter/humahttprouter.go b/adapters/humahttprouter/humahttprouter.go
index 1d63d5c8..e970c8c3 100644
--- a/adapters/humahttprouter/humahttprouter.go
+++ b/adapters/humahttprouter/humahttprouter.go
@@ -140,6 +140,17 @@ func (c *httprouterContext) Version() huma.ProtoVersion {
 	}
 }
 
+// WithContext returns a copy of c whose request carries ctx.
+func (c *httprouterContext) WithContext(ctx context.Context) huma.Context {
+	return &httprouterContext{
+		op:     c.op,
+		r:      c.r.WithContext(ctx),
+		w:      c.w,
+		ps:     c.ps,
+		status: c.status,
+	}
+}
+
 type httprouterAdapter struct {
 	router *httprouter.Router
 }
@@ -161,3 +172,22 @@ func (a *httprouterAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 func New(r *httprouter.Router, config huma.Config) huma.API {
 	return huma.NewAPI(config, &httprouterAdapter{router: r})
 }
+
+// middleware adapts an httprouter middleware to a Huma middleware so it can
+// run inside an operation's middleware chain; it is intended for tests.
+func middleware(mw func(next httprouter.Handle) httprouter.Handle) func(ctx huma.Context, next func(huma.Context)) {
+	return func(ctx huma.Context, next func(huma.Context)) {
+		// Unwrap the context to get the httprouter params
+		r, w, ps := Unwrap(ctx)
+		h := mw(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
+			next(&httprouterContext{
+				op:     ctx.Operation(),
+				r:      r,
+				w:      w,
+				ps:     p,
+				status: ctx.Status(),
+			})
+		})
+		h(w, r, ps)
+	}
+}
diff --git a/adapters/humahttprouter/humahttprouter_test.go b/adapters/humahttprouter/humahttprouter_test.go
index a8685013..1de9cd92 100644
--- a/adapters/humahttprouter/humahttprouter_test.go
+++ b/adapters/humahttprouter/humahttprouter_test.go
@@ -2,6 +2,7 @@ package humahttprouter
 
 import (
 	"context"
+	"io"
 	"net/http"
 	"net/http/httptest"
 	"strings"
@@ -9,7 +10,10 @@ import (
 	"time"
 
 	"github.com/danielgtaylor/huma/v2"
+	"github.com/danielgtaylor/huma/v2/humatest"
 	"github.com/julienschmidt/httprouter"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 var lastModified = time.Now()
@@ -70,3 +74,47 @@ func BenchmarkHumaHttprouter(b *testing.B) {
 		}
 	}
 }
+
+// See https://github.com/danielgtaylor/huma/issues/859
+func TestWithValueShouldPropagateContext(t *testing.T) {
+	r := httprouter.New()
+	app := New(r, huma.DefaultConfig("Test", "1.0.0"))
+
+	type (
+		testInput  struct{}
+		testOutput struct{}
+		ctxKey     struct{}
+	)
+
+	ctxValue := "sentinelValue"
+
+	huma.Register(app, huma.Operation{
+		OperationID: "test",
+		Path:        "/test",
+		Method:      http.MethodGet,
+		Middlewares: huma.Middlewares{
+			func(ctx huma.Context, next func(huma.Context)) {
+				ctx = huma.WithValue(ctx, ctxKey{}, ctxValue)
+				next(ctx)
+			},
+			middleware(func(next httprouter.Handle) httprouter.Handle {
+				return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+					val, _ := r.Context().Value(ctxKey{}).(string)
+					_, err := w.Write([]byte(val))
+					assert.NoError(t, err)
+				}
+			}),
+		},
+	}, func(ctx context.Context, input *testInput) (*testOutput, error) {
+		out := &testOutput{}
+		return out, nil
+	})
+
+	tapi := humatest.Wrap(t, app)
+
+	resp := tapi.Get("/test")
+	assert.Equal(t, http.StatusOK, resp.Code)
+	out, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	assert.Equal(t, ctxValue, string(out))
+}
diff --git a/adapters/humamux/humagmux_test.go b/adapters/humamux/humagmux_test.go
index 4d95beda..361cc771 100644
--- a/adapters/humamux/humagmux_test.go
+++ b/adapters/humamux/humagmux_test.go
@@ -3,6 +3,7 @@ package humamux
 import (
 	"context"
 	"fmt"
+	"io"
 	"net/http"
 	"net/http/httptest"
 	"strings"
@@ -13,6 +14,7 @@ import (
 	"github.com/danielgtaylor/huma/v2/humatest"
 	"github.com/gorilla/mux"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 var lastModified = time.Now()
@@ -133,3 +135,47 @@ func BenchmarkHumaGorillaMux(b *testing.B) {
 		}
 	}
 }
+
+// See https://github.com/danielgtaylor/huma/issues/859
+func TestWithValueShouldPropagateContext(t *testing.T) {
+	r := mux.NewRouter()
+	app := New(r, huma.DefaultConfig("Test", "1.0.0"))
+
+	type (
+		testInput  struct{}
+		testOutput struct{}
+		ctxKey     struct{}
+	)
+
+	ctxValue := "sentinelValue"
+
+	huma.Register(app, huma.Operation{
+		OperationID: "test",
+		Path:        "/test",
+		Method:      http.MethodGet,
+		Middlewares: huma.Middlewares{
+			func(ctx huma.Context, next func(huma.Context)) {
+				ctx = huma.WithValue(ctx, ctxKey{}, ctxValue)
+				next(ctx)
+			},
+			middleware(func(next http.Handler) http.Handler {
+				return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+					val, _ := r.Context().Value(ctxKey{}).(string)
+					_, err := w.Write([]byte(val))
+					assert.NoError(t, err)
+				})
+			}),
+		},
+	}, func(ctx context.Context, input *testInput) (*testOutput, error) {
+		out := &testOutput{}
+		return out, nil
+	})
+
+	tapi := humatest.Wrap(t, app)
+
+	resp := tapi.Get("/test")
+	assert.Equal(t, http.StatusOK, resp.Code)
+	out, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	assert.Equal(t, ctxValue, string(out))
+}
diff --git a/adapters/humamux/humamux.go b/adapters/humamux/humamux.go
index a8cc6312..c94ba9e1 100644
--- a/adapters/humamux/humamux.go
+++ b/adapters/humamux/humamux.go
@@ -96,6 +96,16 @@ func (c *gmuxContext) Version() huma.ProtoVersion {
 	}
 }
 
+// WithContext returns a copy of c whose request carries ctx.
+func (c *gmuxContext) WithContext(ctx context.Context) huma.Context {
+	return &gmuxContext{
+		op:     c.op,
+		r:      c.r.WithContext(ctx),
+		w:      c.w,
+		status: c.status,
+	}
+}
+
 func (c *gmuxContext) EachHeader(cb func(name, value string)) {
 	for name, values := range c.r.Header {
 		for _, value := range values {
@@ -163,3 +173,21 @@ func (a *gMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 func New(r *mux.Router, config huma.Config, options ...Option) huma.API {
 	return huma.NewAPI(config, &gMux{router: r, options: parseOptions(options)})
 }
+
+// middleware adapts a standard net/http middleware, as used with gorilla/mux,
+// to a Huma middleware so it can run inside an operation's middleware chain.
+// The request and writer that mw passes on are wrapped into a new Huma
+// context that keeps the operation and any status already set.
+func middleware(mw func(next http.Handler) http.Handler) func(ctx huma.Context, next func(huma.Context)) {
+	return func(ctx huma.Context, next func(huma.Context)) {
+		r, w := Unwrap(ctx)
+		mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			next(&gmuxContext{
+				op:     ctx.Operation(),
+				r:      r,
+				w:      w,
+				status: ctx.Status(),
+			})
+		})).ServeHTTP(w, r)
+	}
+}
diff --git a/api.go b/api.go
index e93c88fd..71a490fa 100644
--- a/api.go
+++ b/api.go
@@ -155,6 +155,12 @@ func (c subContext) Unwrap() Context {
 // replaced with the given one. This is useful for middleware that needs to
 // modify the request context.
 func WithContext(ctx Context, override context.Context) Context {
+	// Adapters may implement WithContext themselves so that the new
+	// context.Context is attached to the underlying router request, where
+	// router-native middleware and handlers can read it.
+	if sub, ok := ctx.(interface{ WithContext(context.Context) Context }); ok {
+		return sub.WithContext(override)
+	}
 	return subContext{humaContext: ctx, override: override}
 }
 