diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 67dbd7f4..0e0f0053 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -20,7 +20,7 @@ jobs: uses: golangci/golangci-lint-action@v6 with: version: v1.60.1 - - run: go test -coverprofile=coverage.txt -covermode=atomic ./... + - run: go test -race -coverprofile=coverage.txt -covermode=atomic ./... - uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/adapters/humafiber/humafiber.go b/adapters/humafiber/humafiber.go index 51b05b84..0e189a95 100644 --- a/adapters/humafiber/humafiber.go +++ b/adapters/humafiber/humafiber.go @@ -15,120 +15,70 @@ import ( "github.com/gofiber/fiber/v2" ) -type fiberCtx struct { +type fiberAdapter struct { + tester requestTester + router router +} + +type fiberWrapper struct { op *huma.Operation status int - - /* - * Web framework "fiber" https://gofiber.io/ uses high-performance zero-allocation "fasthttp" server https://github.com/valyala/fasthttp - * - * The underlying fasthttp server prohibits to use or refer to `*fasthttp.RequestCtx` outside handler - * The quote from documentation to fasthttp https://github.com/valyala/fasthttp/blob/master/README.md - * - * > VERY IMPORTANT! Fasthttp disallows holding references to RequestCtx or to its' members after returning from RequestHandler. Otherwise data races are inevitable. Carefully inspect all the net/http request handlers converted to fasthttp whether they retain references to RequestCtx or to its' members after returning - * - * As the result "fiber" prohibits to use or refer to `*fiber.Ctx` outside handler - * The quote from documentation to fiber https://docs.gofiber.io/#zero-allocation - * - * > Because fiber is optimized for high-performance, values returned from fiber.Ctx are not immutable by default and will be re-used across requests. As a rule of thumb, you must only use context values within the handler, and you must not keep any references. As soon as you return from the handler, any values you have obtained from the context will be re-used in future requests and will change below your feet - * - * To deal with these limitations, the contributor of to this adapter @excavador (Oleg Tsarev, email: oleg@tsarev.id, telegram: @oleg_tsarev) is clear variable explicitly in the end of huma.Adapter methods Handle and ServeHTTP - * - * You must NOT use member `unsafeFiberCtx` directly in adapter, but instead use `orig()` private method - */ - unsafeFiberCtx *fiber.Ctx - unsafeGolangCtx context.Context + orig *fiber.Ctx + ctx context.Context } // check that fiberCtx implements huma.Context -var _ huma.Context = &fiberCtx{} -var _ context.Context = &fiberCtx{} - -func (c *fiberCtx) orig() *fiber.Ctx { - var result = c.unsafeFiberCtx - select { - case <-c.unsafeGolangCtx.Done(): - panic("handler was done already") - default: - return result - } -} - -func (c *fiberCtx) Deadline() (deadline time.Time, ok bool) { - return c.unsafeGolangCtx.Deadline() -} - -func (c *fiberCtx) Done() <-chan struct{} { - return c.unsafeGolangCtx.Done() -} - -func (c *fiberCtx) Err() error { - return c.unsafeGolangCtx.Err() -} - -func (c *fiberCtx) Value(key any) any { - var orig = c.unsafeFiberCtx - select { - case <-c.unsafeGolangCtx.Done(): - return nil - default: - var value = orig.UserContext().Value(key) - if value != nil { - return value - } - return orig.Context().Value(key) - } -} +var _ huma.Context = &fiberWrapper{} -func (c *fiberCtx) Operation() *huma.Operation { +func (c *fiberWrapper) Operation() *huma.Operation { return c.op } -func (c *fiberCtx) Matched() string { - return c.orig().Route().Path +func (c *fiberWrapper) Matched() string { + return c.orig.Route().Path } -func (c *fiberCtx) Context() context.Context { - return c +func (c *fiberWrapper) Context() context.Context { + return c.ctx } -func (c *fiberCtx) Method() string { - return c.orig().Method() +func (c *fiberWrapper) Method() string { + return c.orig.Method() } -func (c *fiberCtx) Host() string { - return c.orig().Hostname() +func (c *fiberWrapper) Host() string { + return c.orig.Hostname() } -func (c *fiberCtx) RemoteAddr() string { - return c.orig().Context().RemoteAddr().String() +func (c *fiberWrapper) RemoteAddr() string { + return c.orig.Context().RemoteAddr().String() } -func (c *fiberCtx) URL() url.URL { - u, _ := url.Parse(string(c.orig().Request().RequestURI())) +func (c *fiberWrapper) URL() url.URL { + u, _ := url.Parse(string(c.orig.Request().RequestURI())) return *u } -func (c *fiberCtx) Param(name string) string { - return c.orig().Params(name) +func (c *fiberWrapper) Param(name string) string { + return c.orig.Params(name) } -func (c *fiberCtx) Query(name string) string { - return c.orig().Query(name) +func (c *fiberWrapper) Query(name string) string { + return c.orig.Query(name) } -func (c *fiberCtx) Header(name string) string { - return c.orig().Get(name) +func (c *fiberWrapper) Header(name string) string { + return c.orig.Get(name) } -func (c *fiberCtx) EachHeader(cb func(name, value string)) { - c.orig().Request().Header.VisitAll(func(k, v []byte) { +func (c *fiberWrapper) EachHeader(cb func(name, value string)) { + c.orig.Request().Header.VisitAll(func(k, v []byte) { cb(string(k), string(v)) }) } -func (c *fiberCtx) BodyReader() io.Reader { - var orig = c.orig() +func (c *fiberWrapper) BodyReader() io.Reader { + var orig = c.orig if orig.App().Server().StreamRequestBody { // Streaming is enabled, so send the reader. return orig.Request().BodyStream() @@ -136,47 +86,47 @@ func (c *fiberCtx) BodyReader() io.Reader { return bytes.NewReader(orig.BodyRaw()) } -func (c *fiberCtx) GetMultipartForm() (*multipart.Form, error) { - return c.orig().MultipartForm() +func (c *fiberWrapper) GetMultipartForm() (*multipart.Form, error) { + return c.orig.MultipartForm() } -func (c *fiberCtx) SetReadDeadline(deadline time.Time) error { +func (c *fiberWrapper) SetReadDeadline(deadline time.Time) error { // Note: for this to work properly you need to do two things: // 1. Set the Fiber app's `StreamRequestBody` to `true` // 2. Set the Fiber app's `BodyLimit` to some small value like `1` // Fiber will only call the request handler for streaming once the limit is // reached. This is annoying but currently how things work. - return c.orig().Context().Conn().SetReadDeadline(deadline) + return c.orig.Context().Conn().SetReadDeadline(deadline) } -func (c *fiberCtx) SetStatus(code int) { - var orig = c.orig() +func (c *fiberWrapper) SetStatus(code int) { + var orig = c.orig c.status = code orig.Status(code) } -func (c *fiberCtx) Status() int { +func (c *fiberWrapper) Status() int { return c.status } -func (c *fiberCtx) AppendHeader(name string, value string) { - c.orig().Append(name, value) +func (c *fiberWrapper) AppendHeader(name string, value string) { + c.orig.Append(name, value) } -func (c *fiberCtx) SetHeader(name string, value string) { - c.orig().Set(name, value) +func (c *fiberWrapper) SetHeader(name string, value string) { + c.orig.Set(name, value) } -func (c *fiberCtx) BodyWriter() io.Writer { - return c.orig().Context() +func (c *fiberWrapper) BodyWriter() io.Writer { + return c.orig.Context() } -func (c *fiberCtx) TLS() *tls.ConnectionState { - return c.orig().Context().TLSConnectionState() +func (c *fiberWrapper) TLS() *tls.ConnectionState { + return c.orig.Context().TLSConnectionState() } -func (c *fiberCtx) Version() huma.ProtoVersion { +func (c *fiberWrapper) Version() huma.ProtoVersion { return huma.ProtoVersion{ - Proto: c.orig().Protocol(), + Proto: c.orig.Protocol(), } } @@ -188,9 +138,31 @@ type requestTester interface { Test(*http.Request, ...int) (*http.Response, error) } -type fiberAdapter struct { - tester requestTester - router router +type contextWrapperValue struct { + Key any + Value any +} + +type contextWrapper struct { + values []*contextWrapperValue + context.Context +} + +var ( + _ context.Context = &contextWrapper{} +) + +func (c *contextWrapper) Value(key any) any { + var raw = c.Context.Value(key) + if raw != nil { + return raw + } + for _, pair := range c.values { + if pair.Key == key { + return pair.Value + } + } + return nil } func (a *fiberAdapter) Handle(op *huma.Operation, handler func(huma.Context)) { @@ -199,17 +171,21 @@ func (a *fiberAdapter) Handle(op *huma.Operation, handler func(huma.Context)) { path = strings.ReplaceAll(path, "{", ":") path = strings.ReplaceAll(path, "}", "") a.router.Add(op.Method, path, func(c *fiber.Ctx) error { - var ctx, cancel = context.WithCancel(context.Background()) - var fc = &fiberCtx{ - op: op, - unsafeFiberCtx: c, - unsafeGolangCtx: ctx, - } - defer func() { - cancel() - fc.unsafeFiberCtx = nil - }() - handler(fc) + var values []*contextWrapperValue + c.Context().VisitUserValuesAll(func(key, value any) { + values = append(values, &contextWrapperValue{ + Key: key, + Value: value, + }) + }) + handler(&fiberWrapper{ + op: op, + orig: c, + ctx: &contextWrapper{ + values: values, + Context: c.UserContext(), + }, + }) return nil }) } @@ -218,6 +194,11 @@ func (a *fiberAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) { // b, _ := httputil.DumpRequest(r, true) // fmt.Println(string(b)) resp, err := a.tester.Test(r) + if resp != nil && resp.Body != nil { + defer func() { + _ = resp.Body.Close() + }() + } if err != nil { panic(err) } @@ -228,7 +209,7 @@ func (a *fiberAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } w.WriteHeader(resp.StatusCode) - io.Copy(w, resp.Body) + _, _ = io.Copy(w, resp.Body) } func New(r *fiber.App, config huma.Config) huma.API { diff --git a/adapters/humafiber/humafiber_context_test.go b/adapters/humafiber/humafiber_context_test.go new file mode 100644 index 00000000..41a4b636 --- /dev/null +++ b/adapters/humafiber/humafiber_context_test.go @@ -0,0 +1,377 @@ +package humafiber_test + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "os/signal" + "sync" + "syscall" + "testing" + "time" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humafiber" + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type ( + HelloRequestBody struct { + Name string `json:"name"` + } + + HelloResponseBody struct { + Message string `json:"message"` + FiberUserValue string `json:"fiber-user-value"` + FiberUserContext string `json:"fiber-user-context"` + Huma string `json:"huma"` + } + + HelloRequest struct { + Delay string `query:"huma-fiber-delay"` + Body HelloRequestBody + } + + HelloResponse struct { + Body HelloResponseBody + } + + contextKeyFiberUserValue string + contextKeyFiberUserContext string + contextKeyHuma string +) + +const ( + contextValueFiberUserValue = contextKeyFiberUserValue("context-fiber-user-value") + contextValueFiberUserContext = contextKeyFiberUserContext("context-fiber-user-context") + contextValueHuma = contextKeyHuma("context-huma") +) + +var ( + HeaderNameFiberUserValue = http.CanonicalHeaderKey("fiber-user-value") + HeaderNameFiberUserContext = http.CanonicalHeaderKey("fiber-user-context") + HeaderNameHuma = http.CanonicalHeaderKey("huma") +) + +const ( + PingPath = "/ping" + HelloPath = "/hello" +) + +func PingHandler(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) +} + +func RegisterPing(app *fiber.App) { + _ = app.Get(PingPath, PingHandler) +} + +func CallPing(ctx context.Context, server string, timeout time.Duration) error { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + request, err := http.NewRequestWithContext(ctx, http.MethodGet, server+PingPath, nil) + if err != nil { + return err + } + response, err := http.DefaultClient.Do(request) + if response != nil { + _ = response.Body.Close() + } + if err != nil { + return err + } + if response == nil { + return errors.New("response is empty") + } + if response.StatusCode != fiber.StatusOK { + return fmt.Errorf("unexpected status code %d", response.StatusCode) + } + return nil +} + +func WaitPing(ctx context.Context, server string, timeout time.Duration) error { + for { + after := time.After(timeout) + err := CallPing(ctx, server, timeout) + if err == nil { + return nil + } + select { + case <-ctx.Done(): + return err + case <-after: + } + } +} + +func SimulateAccessToContextOutsideHandler( + global context.Context, + wait *sync.WaitGroup, + timeout time.Duration, + retries int, +) func(ctx context.Context) { + return func(ctx context.Context) { + wait.Add(1) + go func() { + defer wait.Done() + global, cancel := context.WithTimeout(global, timeout*time.Duration(retries)) + defer cancel() + for { + _, _ = ctx.Deadline() + _ = ctx.Done() + _ = ctx.Err() + _ = ctx.Value(contextValueFiberUserValue) + _ = ctx.Value(contextValueFiberUserContext) + _ = ctx.Value(contextValueHuma) + select { + case <-global.Done(): + return + case <-time.After(timeout / 10): + } + } + }() + } +} + +func HelloHandler(simulator func(context.Context)) func(ctx context.Context, request *HelloRequest) (response *HelloResponse, err error) { + return func(ctx context.Context, request *HelloRequest) (response *HelloResponse, err error) { + simulator(ctx) + var delay time.Duration + if request.Delay != "" { + var err error + if delay, err = time.ParseDuration(request.Delay); err != nil { + return nil, err + } + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(delay): + } + var responseBody = HelloResponseBody{ + Message: fmt.Sprintf("Hello, %s!", request.Body.Name), + } + if raw := ctx.Value(contextValueFiberUserValue); raw != nil { + responseBody.FiberUserValue = raw.(string) + } + if raw := ctx.Value(contextValueFiberUserContext); raw != nil { + responseBody.FiberUserContext = raw.(string) + } + if raw := ctx.Value(contextValueHuma); raw != nil { + responseBody.Huma = raw.(string) + } + return &HelloResponse{ + Body: responseBody, + }, nil + } +} + +func HelloOperation() huma.Operation { + return huma.Operation{ + OperationID: "Hello", + Method: fiber.MethodPost, + Path: HelloPath, + Description: "Hello description", + Tags: []string{"hello"}, + DefaultStatus: fiber.StatusOK, + } +} + +func HelloResponseValidate(t *testing.T, expected HelloResponseBody, response *http.Response) { + assert.NotNil(t, response) + assert.Equal(t, fiber.StatusOK, response.StatusCode) + var actual HelloResponseBody + err := json.NewDecoder(response.Body).Decode(&actual) + if assert.NoError(t, err) { + assert.Equal(t, expected, actual) + } +} + +func FiberMiddlewareUserValue(c *fiber.Ctx) error { + headers := c.GetReqHeaders() + if values, found := headers[HeaderNameFiberUserValue]; found && len(values) > 0 { + c.Context().SetUserValue(contextValueFiberUserValue, values[0]) + } + return c.Next() +} + +func FiberMiddlewareUserContext(c *fiber.Ctx) error { + headers := c.GetReqHeaders() + if values, found := headers[HeaderNameFiberUserContext]; found && len(values) > 0 { + var original = c.UserContext() + var result = context.WithValue(original, contextValueFiberUserContext, values[0]) + c.SetUserContext(result) + defer c.SetUserContext(original) + } + return c.Next() +} + +func HumaMiddleware(ctx huma.Context, next func(huma.Context)) { + value := ctx.Header(HeaderNameHuma) + if value != "" { + ctx = huma.WithValue(ctx, contextValueHuma, value) + } + next(ctx) +} + +func TestHumaFiber(t *testing.T) { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + var wait sync.WaitGroup + defer wait.Wait() + + timeout := time.Millisecond * 10 + retries := 10 + simulator := SimulateAccessToContextOutsideHandler(ctx, &wait, timeout, retries) + + ln, err := net.Listen("tcp4", "127.0.0.1:0") + require.NoError(t, err) + port := ln.Addr().(*net.TCPAddr).Port + require.NotZero(t, port) + server := fmt.Sprintf("http://localhost:%d", port) + + app := fiber.New(fiber.Config{ + DisableStartupMessage: true, + }) + app.Use(FiberMiddlewareUserValue) + app.Use(FiberMiddlewareUserContext) + RegisterPing(app) + + config := huma.DefaultConfig("hello", "1.0.0") + api := humafiber.New(app, config) + api.UseMiddleware(HumaMiddleware) + huma.Register(api, HelloOperation(), HelloHandler(simulator)) + + wait.Add(1) + go func() { + defer wait.Done() + err := app.Listener(ln) + assert.NoError(t, err) + }() + defer wait.Wait() + + err = WaitPing(ctx, server, timeout) + require.NoError(t, err) + + name := "Bob" + message := fmt.Sprintf("Hello, %s!", name) + requestBody, err := json.Marshal(HelloRequestBody{ + Name: name, + }) + require.NoError(t, err) + assert.NotEmpty(t, requestBody) + requestBodyReader := bytes.NewReader(requestBody) + expected := HelloResponseBody{ + Message: message, + FiberUserValue: "one", + FiberUserContext: "two", + Huma: "three", + } + + request, err := http.NewRequestWithContext(ctx, fiber.MethodPost, server+HelloPath, requestBodyReader) + require.NoError(t, err) + request.Header.Add(HeaderNameFiberUserValue, "one") + request.Header.Add(HeaderNameFiberUserContext, "two") + request.Header.Add(HeaderNameHuma, "three") + query := request.URL.Query() + query.Add("huma-fiber-delay", timeout.String()) + request.URL.RawQuery = query.Encode() + + // simple check + response, err := http.DefaultClient.Do(request) + if response != nil && response.Body != nil { + defer func() { + _ = response.Body.Close() + }() + } + require.NoError(t, err) + HelloResponseValidate(t, expected, response) + + // check that delay works + doneFirst := make(chan bool) + wait.Add(1) + go func() { + defer wait.Done() + defer close(doneFirst) + response, err := http.DefaultClient.Do(request) + if response != nil && response.Body != nil { + defer func() { + _ = response.Body.Close() + }() + } + assert.NoError(t, err) + HelloResponseValidate(t, expected, response) + }() + select { + case <-ctx.Done(): + return + case <-doneFirst: + assert.Fail(t, "expected other branch") + default: + // ok + } + select { + case <-ctx.Done(): + return + case <-doneFirst: + // ok + case <-time.After(timeout * 2): + assert.Fail(t, "expected other branch") + } + + // check graceful shutdown + doneSecond := make(chan bool) + wait.Add(1) + go func() { + defer wait.Done() + defer close(doneSecond) + response, err := http.DefaultClient.Do(request) + if response != nil && response.Body != nil { + defer func() { + _ = response.Body.Close() + }() + } + assert.NoError(t, err) + HelloResponseValidate(t, expected, response) + }() + + // perform shutdown + doneShutdown := make(chan bool) + wait.Add(1) + go func() { + defer wait.Done() + defer close(doneShutdown) + time.Sleep(timeout) // delay before shutdown to start request processing + err := app.ShutdownWithContext(ctx) + assert.NoError(t, err) + time.Sleep(timeout) // delay after shutdown to catch request processing + }() + + // request should be handled + select { + case <-ctx.Done(): + return + case <-doneSecond: + // ok + case <-doneShutdown: + assert.Fail(t, "expected other branch") + } + + // shutdown should be handled + select { + case <-ctx.Done(): + return + case <-doneShutdown: + // ok + } + + wait.Wait() +}