diff --git a/adapters/adapters_test.go b/adapters/adapters_test.go
index c8caea3d..3f4e6bab 100644
--- a/adapters/adapters_test.go
+++ b/adapters/adapters_test.go
@@ -14,6 +14,7 @@ import (
 	"github.com/danielgtaylor/huma/v2/adapters/humaecho"
 	"github.com/danielgtaylor/huma/v2/adapters/humafiber"
 	"github.com/danielgtaylor/huma/v2/adapters/humagin"
+	"github.com/danielgtaylor/huma/v2/adapters/humago"
 	"github.com/danielgtaylor/huma/v2/adapters/humahttprouter"
 	"github.com/danielgtaylor/huma/v2/adapters/humamux"
 	"github.com/danielgtaylor/huma/v2/humatest"
@@ -27,6 +28,8 @@ import (
 	"github.com/uptrace/bunrouter"
 )
 
+type key struct{}
+
 // Test the various input types (path, query, header, body).
 type TestInput struct {
 	Group string `path:"group"`
@@ -90,10 +93,11 @@ func TestAdapters(t *testing.T) {
 		return huma.DefaultConfig("Test", "1.0.0")
 	}
 
-	wrap := func(h huma.API, isFiber bool) huma.API {
+	wrap := func(h huma.API, isFiber bool, unwrapper func(ctx huma.Context)) huma.API {
 		h.UseMiddleware(func(ctx huma.Context, next func(huma.Context)) {
 			assert.Nil(t, ctx.TLS())
 			v := ctx.Version()
+
 			if !isFiber {
 				assert.Equal(t, "HTTP/1.1", v.Proto)
 				assert.Equal(t, 1, v.ProtoMajor)
@@ -101,6 +105,16 @@ func TestAdapters(t *testing.T) {
 			} else {
 				assert.Equal(t, "http", v.Proto)
 			}
+
+			// Make sure huma.WithContext works correctly
+			ctx = huma.WithContext(ctx, context.WithValue(ctx.Context(), key{}, "value"))
+
+			next(ctx)
+		}, func(ctx huma.Context, next func(huma.Context)) {
+			// The value set via huma.WithContext above must be visible to later middleware
+			assert.Equal(t, "value", ctx.Context().Value(key{}))
+			// Make sure the Unwrap func does not panic even when the context is wrapped by WithContext
+			assert.NotPanics(t, func() { unwrapper(ctx) })
 			next(ctx)
 		})
 		return h
@@ -110,14 +124,35 @@ func TestAdapters(t *testing.T) {
 		name string
 		new  func() huma.API
 	}{
-		{"chi", func() huma.API { return wrap(humachi.New(chi.NewMux(), config()), false) }},
-		{"echo", func() huma.API { return wrap(humaecho.New(echo.New(), config()), false) }},
-		{"fiber", func() huma.API { return wrap(humafiber.New(fiber.New(), config()), true) }},
-		{"gin", func() huma.API { return wrap(humagin.New(gin.New(), config()), false) }},
-		{"httprouter", func() huma.API { return wrap(humahttprouter.New(httprouter.New(), config()), false) }},
-		{"mux", func() huma.API { return wrap(humamux.New(mux.NewRouter(), config()), false) }},
-		{"bunrouter", func() huma.API { return wrap(humabunrouter.New(bunrouter.New(), config()), false) }},
-		{"bunroutercompat", func() huma.API { return wrap(humabunrouter.NewCompat(bunrouter.New().Compat(), config()), false) }},
+		{"chi", func() huma.API {
+			return wrap(humachi.New(chi.NewMux(), config()), false, func(ctx huma.Context) { humachi.Unwrap(ctx) })
+		}},
+		{"echo", func() huma.API {
+			return wrap(humaecho.New(echo.New(), config()), false, func(ctx huma.Context) { humaecho.Unwrap(ctx) })
+		}},
+		{"fiber", func() huma.API {
+			return wrap(humafiber.New(fiber.New(), config()), true, func(ctx huma.Context) { humafiber.Unwrap(ctx) })
+		}},
+		{"go", func() huma.API {
+			return wrap(humago.New(http.NewServeMux(), config()), false, func(ctx huma.Context) { humago.Unwrap(ctx) })
+		}},
+		{"gin", func() huma.API {
+			return wrap(humagin.New(gin.New(), config()), false, func(ctx huma.Context) { humagin.Unwrap(ctx) })
+		}},
+		{"httprouter", func() huma.API {
+			return wrap(humahttprouter.New(httprouter.New(), config()), false, func(ctx huma.Context) { humahttprouter.Unwrap(ctx) })
+		}},
+		{"mux", func() huma.API {
+			return wrap(humamux.New(mux.NewRouter(), config()), false, func(ctx huma.Context) { humamux.Unwrap(ctx) })
+		}},
+		{"bunrouter", func() huma.API {
+			return wrap(humabunrouter.New(bunrouter.New(), config()), false, func(ctx huma.Context) { humabunrouter.Unwrap(ctx) })
+		}},
+		{"bunroutercompat", func() huma.API {
+			return wrap(humabunrouter.NewCompat(bunrouter.New().Compat(), config()), false, func(ctx huma.Context) {
+				// FIXME: humabunrouter.Unwrap(ctx) doesn't work with compat mode
+			})
+		}},
 	} {
 		t.Run(adapter.name, func(t *testing.T) {
 			testAdapter(t, adapter.new())
diff --git a/adapters/humabunrouter/humabunrouter.go b/adapters/humabunrouter/humabunrouter.go
index 21bb5df0..834faf98 100644
--- a/adapters/humabunrouter/humabunrouter.go
+++ b/adapters/humabunrouter/humabunrouter.go
@@ -23,6 +23,14 @@ var MultipartMaxMemory int64 = 8 * 1024
 // Unwrap extracts the underlying HTTP request and response writer from a Huma
 // context. If passed a context from a different adapter it will panic.
 func Unwrap(ctx huma.Context) (bunrouter.Request, http.ResponseWriter) {
+	// Peel off wrappers (such as the one returned by huma.WithContext) first.
+	for {
+		c, ok := ctx.(interface{ Unwrap() huma.Context })
+		if !ok {
+			break
+		}
+		ctx = c.Unwrap()
+	}
 	if c, ok := ctx.(*bunContext); ok {
 		return c.Unwrap()
 	}
diff --git a/adapters/humachi/humachi.go b/adapters/humachi/humachi.go
index f6140365..8e2747e7 100644
--- a/adapters/humachi/humachi.go
+++ b/adapters/humachi/humachi.go
@@ -21,6 +21,14 @@ var MultipartMaxMemory int64 = 8 * 1024
 // Unwrap extracts the underlying HTTP request and response writer from a Huma
 // context. If passed a context from a different adapter it will panic.
 func Unwrap(ctx huma.Context) (*http.Request, http.ResponseWriter) {
+	// Peel off wrappers (such as the one returned by huma.WithContext) first.
+	for {
+		c, ok := ctx.(interface{ Unwrap() huma.Context })
+		if !ok {
+			break
+		}
+		ctx = c.Unwrap()
+	}
 	if c, ok := ctx.(*chiContext); ok {
 		return c.Unwrap()
 	}
diff --git a/adapters/humaecho/humaecho.go b/adapters/humaecho/humaecho.go
index 57ef440c..d7c77089 100644
--- a/adapters/humaecho/humaecho.go
+++ b/adapters/humaecho/humaecho.go
@@ -21,6 +21,14 @@ var MultipartMaxMemory int64 = 8 * 1024
 // Unwrap extracts the underlying Echo context from a Huma context. If passed a
 // context from a different adapter it will panic.
 func Unwrap(ctx huma.Context) echo.Context {
+	// Peel off wrappers (such as the one returned by huma.WithContext) first.
+	for {
+		c, ok := ctx.(interface{ Unwrap() huma.Context })
+		if !ok {
+			break
+		}
+		ctx = c.Unwrap()
+	}
 	if c, ok := ctx.(*echoCtx); ok {
 		return c.Unwrap()
 	}
diff --git a/adapters/humafiber/humafiber.go b/adapters/humafiber/humafiber.go
index 651f4b0e..ba972d9a 100644
--- a/adapters/humafiber/humafiber.go
+++ b/adapters/humafiber/humafiber.go
@@ -21,6 +21,14 @@ import (
 // memory-safety: https://docs.gofiber.io/#zero-allocation. Do not keep
 // references to the underlying context or its values!
 func Unwrap(ctx huma.Context) *fiber.Ctx {
+	// Peel off wrappers (such as the one returned by huma.WithContext) first.
+	for {
+		c, ok := ctx.(interface{ Unwrap() huma.Context })
+		if !ok {
+			break
+		}
+		ctx = c.Unwrap()
+	}
 	if c, ok := ctx.(*fiberWrapper); ok {
 		return c.Unwrap()
 	}
diff --git a/adapters/humaflow/humaflow.go b/adapters/humaflow/humaflow.go
index 5f75bba6..2c371e5f 100644
--- a/adapters/humaflow/humaflow.go
+++ b/adapters/humaflow/humaflow.go
@@ -22,6 +22,14 @@ var MultipartMaxMemory int64 = 8 * 1024
 // Unwrap extracts the underlying HTTP request and response writer from a Huma
 // context. If passed a context from a different adapter it will panic.
 func Unwrap(ctx huma.Context) (*http.Request, http.ResponseWriter) {
+	// Peel off wrappers (such as the one returned by huma.WithContext) first.
+	for {
+		c, ok := ctx.(interface{ Unwrap() huma.Context })
+		if !ok {
+			break
+		}
+		ctx = c.Unwrap()
+	}
 	if c, ok := ctx.(*goContext); ok {
 		return c.Unwrap()
 	}
diff --git a/adapters/humagin/humagin.go b/adapters/humagin/humagin.go
index 1092ffdd..5d47f026 100644
--- a/adapters/humagin/humagin.go
+++ b/adapters/humagin/humagin.go
@@ -21,6 +21,14 @@ var MultipartMaxMemory int64 = 8 * 1024
 // Unwrap extracts the underlying Gin context from a Huma context. If passed a
 // context from a different adapter it will panic.
 func Unwrap(ctx huma.Context) *gin.Context {
+	// Peel off wrappers (such as the one returned by huma.WithContext) first.
+	for {
+		c, ok := ctx.(interface{ Unwrap() huma.Context })
+		if !ok {
+			break
+		}
+		ctx = c.Unwrap()
+	}
 	if c, ok := ctx.(*ginCtx); ok {
 		return c.Unwrap()
 	}
diff --git a/adapters/humago/humago.go b/adapters/humago/humago.go
index 8332a8ca..01e02208 100644
--- a/adapters/humago/humago.go
+++ b/adapters/humago/humago.go
@@ -21,6 +21,14 @@ var MultipartMaxMemory int64 = 8 * 1024
 // Unwrap extracts the underlying HTTP request and response writer from a Huma
 // context. If passed a context from a different adapter it will panic.
 func Unwrap(ctx huma.Context) (*http.Request, http.ResponseWriter) {
+	// Peel off wrappers (such as the one returned by huma.WithContext) first.
+	for {
+		c, ok := ctx.(interface{ Unwrap() huma.Context })
+		if !ok {
+			break
+		}
+		ctx = c.Unwrap()
+	}
 	if c, ok := ctx.(*goContext); ok {
 		return c.Unwrap()
 	}
diff --git a/adapters/humahttprouter/humahttprouter.go b/adapters/humahttprouter/humahttprouter.go
index 01a181aa..1d63d5c8 100644
--- a/adapters/humahttprouter/humahttprouter.go
+++ b/adapters/humahttprouter/humahttprouter.go
@@ -22,6 +22,14 @@ var MultipartMaxMemory int64 = 8 * 1024
 // Unwrap extracts the underlying HTTP request and response writer from a Huma
 // context. If passed a context from a different adapter it will panic.
 func Unwrap(ctx huma.Context) (*http.Request, http.ResponseWriter, httprouter.Params) {
+	// Peel off wrappers (such as the one returned by huma.WithContext) first.
+	for {
+		c, ok := ctx.(interface{ Unwrap() huma.Context })
+		if !ok {
+			break
+		}
+		ctx = c.Unwrap()
+	}
 	if c, ok := ctx.(*httprouterContext); ok {
 		return c.Unwrap()
 	}
diff --git a/adapters/humamux/humamux.go b/adapters/humamux/humamux.go
index 75b57792..a8cc6312 100644
--- a/adapters/humamux/humamux.go
+++ b/adapters/humamux/humamux.go
@@ -21,6 +21,14 @@ var MultipartMaxMemory int64 = 8 * 1024
 // Unwrap extracts the underlying HTTP request and response writer from a Huma
 // context. If passed a context from a different adapter it will panic.
 func Unwrap(ctx huma.Context) (*http.Request, http.ResponseWriter) {
+	// Peel off wrappers (such as the one returned by huma.WithContext) first.
+	for {
+		c, ok := ctx.(interface{ Unwrap() huma.Context })
+		if !ok {
+			break
+		}
+		ctx = c.Unwrap()
+	}
 	if c, ok := ctx.(*gmuxContext); ok {
 		return c.Unwrap()
 	}
diff --git a/api.go b/api.go
index a3e43e35..39acec04 100644
--- a/api.go
+++ b/api.go
@@ -143,6 +143,12 @@ func (c subContext) Context() context.Context {
 	return c.override
 }
 
+// Unwrap returns the underlying Huma context, which enables you to access
+// the original adapter-specific request context.
+func (c subContext) Unwrap() Context {
+	return c.humaContext
+}
+
 // WithContext returns a new `huma.Context` with the underlying `context.Context`
 // replaced with the given one. This is useful for middleware that needs to
 // modify the request context.