这是indexloc提供的服务,不要输入任何密码
Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions adapters/humabunrouter/humabunrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ func (c *bunContext) Version() huma.ProtoVersion {
}
}

// WithContext returns a new Huma context whose underlying bunrouter request
// carries ctx. The operation, response writer and recorded status are copied
// over unchanged; no field of c is reassigned.
func (c *bunContext) WithContext(ctx context.Context) huma.Context {
	req := c.r.WithContext(ctx)
	return &bunContext{
		op:     c.op,
		status: c.status,
		w:      c.w,
		r:      req,
	}
}

// NewContext creates a new Huma context from an HTTP request and response.
func NewContext(op *huma.Operation, r bunrouter.Request, w http.ResponseWriter) huma.Context {
return &bunContext{op: op, r: r, w: w}
Expand Down Expand Up @@ -243,6 +252,15 @@ func (c *bunCompatContext) Version() huma.ProtoVersion {
}
}

// WithContext returns a new Huma context whose *http.Request is a shallow
// copy of the current one carrying ctx. The operation, response writer and
// recorded status are carried over; no field of c is reassigned.
func (c *bunCompatContext) WithContext(ctx context.Context) huma.Context {
	req := c.r.WithContext(ctx)
	return &bunCompatContext{
		op:     c.op,
		status: c.status,
		w:      c.w,
		r:      req,
	}
}

// NewCompatContext creates a new Huma context from an HTTP request and response.
func NewCompatContext(op *huma.Operation, r *http.Request, w http.ResponseWriter) huma.Context {
return &bunCompatContext{op: op, r: r, w: w}
Expand Down Expand Up @@ -310,3 +328,15 @@ func NewCompat(r *bunrouter.CompatRouter, config huma.Config) huma.API {
// New creates a Huma API on top of the given bunrouter router, using
// NewAdapter to bridge between the two.
func New(r *bunrouter.Router, config huma.Config) huma.API {
	adapter := NewAdapter(r)
	return huma.NewAPI(config, adapter)
}

// middleware adapts a bunrouter middleware so it can be used as a Huma
// middleware.
//
// The bunrouter middleware is invoked with the request and response writer
// unwrapped from ctx. When it calls its next handler, a fresh Huma context is
// built from whatever writer and request it hands over, and the Huma chain
// continues with that context.
//
// NOTE(review): the error returned by the bunrouter middleware is discarded.
func middleware(mw bunrouter.MiddlewareFunc) func(ctx huma.Context, next func(huma.Context)) {
	return func(ctx huma.Context, next func(huma.Context)) {
		req, res := Unwrap(ctx)
		op := ctx.Operation()
		handler := mw(func(w http.ResponseWriter, r bunrouter.Request) error {
			next(NewContext(op, r, w))
			return nil
		})
		_ = handler(res, req)
	}
}
47 changes: 47 additions & 0 deletions adapters/humabunrouter/humabunrouter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/uptrace/bunrouter"

"github.com/danielgtaylor/huma/v2"
"github.com/danielgtaylor/huma/v2/humatest"
)

var lastModified = time.Now()
Expand Down Expand Up @@ -320,3 +323,47 @@ func BenchmarkRawBunRouterFast(b *testing.B) {
r.ServeHTTP(w, req)
}
}

// TestWithValueShouldPropagateContext checks that a value stored with
// huma.WithValue in one Huma middleware is visible through the request
// context inside a bunrouter middleware adapted with middleware.
//
// See https://github.com/danielgtaylor/huma/issues/859
func TestWithValueShouldPropagateContext(t *testing.T) {
	type (
		testInput  struct{}
		testOutput struct{}
		ctxKey     struct{}
	)

	const want = "sentinelValue"

	router := bunrouter.New()
	api := New(router, huma.DefaultConfig("Test", "1.0.0"))

	// Stores the sentinel on the Huma context.
	setValue := func(ctx huma.Context, next func(huma.Context)) {
		next(huma.WithValue(ctx, ctxKey{}, want))
	}

	// Writes the sentinel read from the request context into the body.
	echoValue := middleware(func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		return func(w http.ResponseWriter, req bunrouter.Request) error {
			got, _ := req.Context().Value(ctxKey{}).(string)
			_, err := io.WriteString(w, got)
			return err
		}
	})

	huma.Register(api, huma.Operation{
		OperationID: "test",
		Path:        "/test",
		Method:      http.MethodGet,
		Middlewares: huma.Middlewares{setValue, echoValue},
	}, func(ctx context.Context, input *testInput) (*testOutput, error) {
		return &testOutput{}, nil
	})

	resp := humatest.Wrap(t, api).Get("/test")
	assert.Equal(t, http.StatusOK, resp.Code)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	assert.Equal(t, want, string(body))
}
19 changes: 19 additions & 0 deletions adapters/humachi/humachi.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,15 @@ func (c *chiContext) Version() huma.ProtoVersion {
}
}

// WithContext returns a new Huma context wrapping a shallow copy of the
// current *http.Request that carries ctx. Operation, writer and recorded
// status are carried over; no field of c is reassigned.
func (c *chiContext) WithContext(ctx context.Context) huma.Context {
	req := c.r.WithContext(ctx)
	return &chiContext{
		op:     c.op,
		status: c.status,
		w:      c.w,
		r:      req,
	}
}

// NewContext creates a new Huma context from an HTTP request and response.
func NewContext(op *huma.Operation, r *http.Request, w http.ResponseWriter) huma.Context {
return &chiContext{op: op, r: r, w: w}
Expand Down Expand Up @@ -174,3 +183,13 @@ func NewAdapter(r chi.Router) huma.Adapter {
// New creates a Huma API whose operations are registered on the given chi
// router.
func New(r chi.Router, config huma.Config) huma.API {
	adapter := &chiAdapter{router: r}
	return huma.NewAPI(config, adapter)
}

// middleware adapts a standard net/http middleware (the form chi uses) so it
// can be used as a Huma middleware.
//
// The inner handler given to mw rebuilds a Huma context from the writer and
// request that mw passes on, then continues the Huma chain with it.
func middleware(mw func(http.Handler) http.Handler) func(ctx huma.Context, next func(huma.Context)) {
	return func(ctx huma.Context, next func(huma.Context)) {
		req, res := Unwrap(ctx)
		op := ctx.Operation()
		inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			next(NewContext(op, r, w))
		})
		wrapped := mw(inner)
		wrapped.ServeHTTP(res, req)
	}
}
43 changes: 43 additions & 0 deletions adapters/humachi/humachi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,46 @@ func TestPathParamDecoding(t *testing.T) {
// app.ServeHTTP(w, req)
// }
// }

// TestWithValueShouldPropagateContext checks that a value stored with
// huma.WithValue in one Huma middleware is visible through the request
// context inside a net/http middleware adapted with middleware.
//
// See https://github.com/danielgtaylor/huma/issues/859
func TestWithValueShouldPropagateContext(t *testing.T) {
	r := chi.NewMux()
	app := New(r, huma.DefaultConfig("Test", "1.0.0"))

	type (
		testInput  struct{}
		testOutput struct{}
		ctxKey     struct{}
	)

	ctxValue := "sentinelValue"

	huma.Register(app, huma.Operation{
		OperationID: "test",
		Path:        "/test",
		Method:      http.MethodGet,
		Middlewares: huma.Middlewares{
			func(ctx huma.Context, next func(huma.Context)) {
				ctx = huma.WithValue(ctx, ctxKey{}, ctxValue)
				next(ctx)
			},
			middleware(func(h http.Handler) http.Handler {
				return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					val, _ := r.Context().Value(ctxKey{}).(string)
					// Surface write failures instead of silently ignoring them,
					// matching the error checks in the other adapters' tests.
					if _, err := io.WriteString(w, val); err != nil {
						t.Errorf("writing response body: %v", err)
					}
				})
			}),
		},
	}, func(ctx context.Context, input *testInput) (*testOutput, error) {
		return &testOutput{}, nil
	})

	tapi := humatest.Wrap(t, app)

	resp := tapi.Get("/test")
	assert.Equal(t, http.StatusOK, resp.Code)
	out, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	assert.Equal(t, ctxValue, string(out))
}
22 changes: 22 additions & 0 deletions adapters/humaecho/humaecho.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ func (c *echoCtx) Version() huma.ProtoVersion {
}
}

// WithContext returns a Huma context whose underlying Echo request carries
// ctx.
//
// The new request is installed with SetRequest on the same echo.Context that
// c wraps. c.orig is only copied as an interface value, so the receiver also
// observes the new request afterwards. This assumes, as with Echo's own
// context, that the dynamic type is a pointer; TODO confirm for custom
// contexts. Only the Huma wrapper (operation and recorded status) is newly
// allocated.
func (c *echoCtx) WithContext(ctx context.Context) huma.Context {
	// Named "orig" rather than "new" to avoid shadowing the predeclared builtin.
	orig := c.orig
	orig.SetRequest(orig.Request().WithContext(ctx))
	return &echoCtx{
		op:     c.op,
		orig:   orig,
		status: c.status,
	}
}

// router is the subset of Echo's route-registration API the adapter relies
// on. *echo.Group satisfies it, as NewWithGroup demonstrates.
type router interface {
	// Add registers handler for the given method and path, with optional
	// route-level middlewares.
	Add(
		method, path string,
		handler echo.HandlerFunc,
		middlewares ...echo.MiddlewareFunc,
	) *echo.Route
}
Expand Down Expand Up @@ -170,3 +180,15 @@ func New(r *echo.Echo, config huma.Config) huma.API {
// NewWithGroup creates a Huma API that registers its operations on the Echo
// group g, while keeping the Echo instance r as the adapter's Handler.
func NewWithGroup(r *echo.Echo, g *echo.Group, config huma.Config) huma.API {
	adapter := &echoAdapter{router: g, Handler: r}
	return huma.NewAPI(config, adapter)
}

// middleware adapts an Echo middleware so it can be used as a Huma middleware.
//
// The Echo middleware is invoked with the echo.Context backing ctx. When it
// calls its next handler, a new Huma context is built from the echo.Context
// it passes along, and the Huma chain continues with that context. This makes
// any context it wraps or replaces visible downstream.
//
// Any error returned by the Echo middleware is reported through
// echo.Context.Error, which invokes the application's registered HTTP error
// handler. Previously the error was silently dropped.
func middleware(mw echo.MiddlewareFunc) func(ctx huma.Context, next func(huma.Context)) {
	return func(ctx huma.Context, next func(huma.Context)) {
		eCtx := Unwrap(ctx)
		op := ctx.Operation()
		h := mw(func(c echo.Context) error {
			// Use c, not eCtx: the middleware may hand its next handler a
			// different (e.g. wrapped) echo.Context than the one it received.
			next(&echoCtx{op: op, orig: c})
			return nil
		})
		if err := h(eCtx); err != nil {
			eCtx.Error(err)
		}
	}
}
47 changes: 47 additions & 0 deletions adapters/humaecho/humaecho_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ import (
"time"

"github.com/danielgtaylor/huma/v2"
"github.com/danielgtaylor/huma/v2/humatest"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var lastModified = time.Now()
Expand Down Expand Up @@ -240,3 +243,47 @@ func BenchmarkRawEchoFast(b *testing.B) {
r.ServeHTTP(w, req)
}
}

// TestWithValueShouldPropagateContext checks that a value stored with
// huma.WithValue in one Huma middleware is visible through the request
// context inside an Echo middleware adapted with middleware.
//
// See https://github.com/danielgtaylor/huma/issues/859
func TestWithValueShouldPropagateContext(t *testing.T) {
	type (
		testInput  struct{}
		testOutput struct{}
		ctxKey     struct{}
	)

	const want = "sentinelValue"

	e := echo.New()
	api := New(e, huma.DefaultConfig("Test", "1.0.0"))

	// Stores the sentinel on the Huma context.
	setValue := func(ctx huma.Context, next func(huma.Context)) {
		next(huma.WithValue(ctx, ctxKey{}, want))
	}

	// Writes the sentinel read from the request context into the body.
	echoValue := middleware(func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			got, _ := c.Request().Context().Value(ctxKey{}).(string)
			_, err := io.WriteString(c.Response().Writer, got)
			return err
		}
	})

	huma.Register(api, huma.Operation{
		OperationID: "test",
		Path:        "/test",
		Method:      http.MethodGet,
		Middlewares: huma.Middlewares{setValue, echoValue},
	}, func(ctx context.Context, input *testInput) (*testOutput, error) {
		return &testOutput{}, nil
	})

	resp := humatest.Wrap(t, api).Get("/test")
	assert.Equal(t, http.StatusOK, resp.Code)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	assert.Equal(t, want, string(body))
}
22 changes: 22 additions & 0 deletions adapters/humafiber/humafiber.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ func (c *fiberWrapper) Version() huma.ProtoVersion {
}
}

// WithContext returns a new Huma context wrapper that stores ctx and also
// installs it as the Fiber user context.
//
// c.orig is presumably the *fiber.Ctx shared with the receiver; verify
// against the struct definition. If so, SetUserContext is visible through
// both the receiver's and the returned wrapper's Fiber context. Only the
// wrapper itself (operation, status, stored context) is newly allocated.
func (c *fiberWrapper) WithContext(ctx context.Context) huma.Context {
	// Named "orig" rather than "new" to avoid shadowing the predeclared builtin.
	orig := c.orig
	orig.SetUserContext(ctx)
	return &fiberWrapper{
		op:     c.op,
		status: c.status,
		orig:   orig,
		ctx:    ctx,
	}
}

// router is the subset of Fiber's route-registration API the adapter relies
// on. A fiber.Router group satisfies it, as NewWithGroup demonstrates.
type router interface {
	// Add registers handlers for the given method and path.
	Add(
		method, path string,
		handlers ...fiber.Handler,
	) fiber.Router
}
Expand Down Expand Up @@ -242,3 +253,14 @@ func New(r *fiber.App, config huma.Config) huma.API {
// NewWithGroup creates a Huma API that registers its operations on the Fiber
// router group g, keeping the app r as the adapter's tester.
func NewWithGroup(r *fiber.App, g fiber.Router, config huma.Config) huma.API {
	adapter := &fiberAdapter{router: g, tester: r}
	return huma.NewAPI(config, adapter)
}

// middleware adapts a Fiber middleware so it can be used as a Huma
// middleware.
//
// The Fiber middleware runs against the *fiber.Ctx unwrapped from ctx. When
// it calls its next handler, the Huma chain continues with the original ctx.
//
// NOTE(review): any error returned by the Fiber middleware is discarded.
func middleware(mw func(next fiber.Handler) fiber.Handler) func(ctx huma.Context, next func(huma.Context)) {
	return func(ctx huma.Context, next func(huma.Context)) {
		fc := Unwrap(ctx)
		continueChain := func(*fiber.Ctx) error {
			next(ctx)
			return nil
		}
		handler := mw(continueChain)
		_ = handler(fc)
	}
}
47 changes: 47 additions & 0 deletions adapters/humafiber/humafiber_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ package humafiber

import (
"context"
"io"
"net/http"
"testing"

"github.com/danielgtaylor/huma/v2"
"github.com/danielgtaylor/huma/v2/humatest"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func BenchmarkHumaFiber(b *testing.B) {
Expand Down Expand Up @@ -59,3 +63,46 @@ func BenchmarkNotHuma(b *testing.B) {
r.Test(req)
}
}

// TestWithValueShouldPropagateContext checks that a value stored with
// huma.WithValue in one Huma middleware is visible through the Fiber user
// context inside a Fiber middleware adapted with middleware.
//
// See https://github.com/danielgtaylor/huma/issues/859
func TestWithValueShouldPropagateContext(t *testing.T) {
	type (
		testInput  struct{}
		testOutput struct{}
		ctxKey     struct{}
	)

	const want = "sentinelValue"

	fapp := fiber.New()
	api := New(fapp, huma.DefaultConfig("Test", "1.0.0"))

	// Stores the sentinel on the Huma context.
	setValue := func(ctx huma.Context, next func(huma.Context)) {
		next(huma.WithValue(ctx, ctxKey{}, want))
	}

	// Writes the sentinel read from the Fiber user context into the body.
	echoValue := middleware(func(next fiber.Handler) fiber.Handler {
		return func(c *fiber.Ctx) error {
			got, _ := c.UserContext().Value(ctxKey{}).(string)
			_, err := c.WriteString(got)
			return err
		}
	})

	huma.Register(api, huma.Operation{
		OperationID: "test",
		Path:        "/test",
		Method:      http.MethodGet,
		Middlewares: huma.Middlewares{setValue, echoValue},
	}, func(ctx context.Context, input *testInput) (*testOutput, error) {
		return &testOutput{}, nil
	})

	resp := humatest.Wrap(t, api).Get("/test")
	assert.Equal(t, http.StatusOK, resp.Code)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	assert.Equal(t, want, string(body))
}
Loading
Loading