diff --git a/context.go b/context.go
index a6ece894..c84102c5 100644
--- a/context.go
+++ b/context.go
@@ -342,7 +342,7 @@ func (c *hcontext) writeModel(ct string, status int, model interface{}) {
 		}
 		encoded, err = mode.Marshal(model)
 		if err != nil {
-			panic(fmt.Errorf("Unable to marshal JSON: %w", err))
+			panic(fmt.Errorf("Unable to marshal CBOR: %w", err))
 		}
 	}
 
diff --git a/humatest/humatest.go b/humatest/humatest.go
index 793fba1c..9b6a68bb 100644
--- a/humatest/humatest.go
+++ b/humatest/humatest.go
@@ -26,13 +26,15 @@ func NewRouter(t testing.TB) *huma.Router {
 func NewRouterObserver(t testing.TB) (*huma.Router, *observer.ObservedLogs) {
 	core, logs := observer.New(zapcore.DebugLevel)
 
-	router := huma.New("Test API", "1.0.0")
-	router.Middleware(middleware.DefaultChain)
-
 	middleware.NewLogger = func() (*zap.Logger, error) {
 		l := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore(func(zapcore.Core) zapcore.Core { return core })))
 		return l, nil
 	}
 
+	// Build the router only after overriding NewLogger; the middleware chain
+	// appears to obtain its logger during setup.
+	router := huma.New("Test API", "1.0.0")
+	router.Middleware(middleware.DefaultChain)
+
 	return router, logs
 }
diff --git a/humatest/humatest_test.go b/humatest/humatest_test.go
index cf817ec5..00e5ee39 100644
--- a/humatest/humatest_test.go
+++ b/humatest/humatest_test.go
@@ -7,6 +7,7 @@ import (
 
 	"github.com/danielgtaylor/huma"
 	"github.com/danielgtaylor/huma/humatest"
+	"github.com/danielgtaylor/huma/middleware"
 	"github.com/danielgtaylor/huma/responses"
 	"github.com/stretchr/testify/assert"
 )
@@ -51,3 +52,32 @@ func TestPackage(t *testing.T) {
 	assert.Equal(t, http.StatusOK, w.Code)
 	assert.Equal(t, "Hello, test!", w.Body.String())
 }
+
+func TestCapturedLog(t *testing.T) {
+	// Create the test router. Log entries are captured by the returned
+	// observer instead of being written to the test output.
+	r, logs := humatest.NewRouterObserver(t)
+
+	// Set up routes & handlers.
+	r.Resource("/test").Get("test", "Test get",
+		responses.OK().ContentType("text/plain"),
+	).Run(func(ctx huma.Context) {
+		logger := middleware.GetLogger(ctx)
+		logger.With("foo", "bar").Info("Just a test")
+		ctx.Write([]byte(""))
+	})
+
+	// Make a test request.
+	w := httptest.NewRecorder()
+	req, _ := http.NewRequest(http.MethodGet, "/test", nil)
+	r.ServeHTTP(w, req)
+	assert.Equal(t, http.StatusOK, w.Code)
+
+	// The observer keeps every debug-or-higher entry, not only the handler's,
+	// so look it up by message. assert.Len also reports "nothing captured" as
+	// a failure instead of panicking on an out-of-range index.
+	entries := logs.FilterMessage("Just a test").All()
+	if assert.Len(t, entries, 1) {
+		assert.Equal(t, "bar", entries[0].ContextMap()["foo"])
+	}
+}
diff --git a/middleware/recovery.go b/middleware/recovery.go
index 44175099..e643cb15 100644
--- a/middleware/recovery.go
+++ b/middleware/recovery.go
@@ -114,12 +114,6 @@ func Recovery(onPanic PanicFunc) func(http.Handler) http.Handler {
 				r = r.WithContext(context.WithValue(r.Context(), bufContextKey, buf))
 			}
 
-			for _, v := range RemovedHeaders {
-				if r.Header.Get(v) != "" {
-					r.Header.Set(v, redacted)
-				}
-			}
-
 			// Recovering comes *after* the above so the buffer is not returned to
 			// the pool until after we print out its contents. This deferred func
 			// is used to recover from panics and deliberately left in-line.
@@ -134,6 +128,17 @@ func Recovery(onPanic PanicFunc) func(http.Handler) http.Handler {
 						r.Body = ioutil.NopCloser(io.LimitReader(r.Body, MaxLogBodyBytes))
 					}
 
+					// Redact sensitive headers only now, right before the request is
+					// dumped into the panic log; redacting up front would also hide the
+					// real values from the handler. NOTE(review): this edits r.Header in
+					// place and WithContext does not copy that map, so the request passed
+					// into this middleware sees the redacted values as well.
+					for _, v := range RemovedHeaders {
+						if r.Header.Get(v) != "" {
+							r.Header.Set(v, redacted)
+						}
+					}
+
 					request, _ := httputil.DumpRequest(r, true)
 
 					if _, ok := err.(error); !ok {
diff --git a/middleware/recovery_test.go b/middleware/recovery_test.go
index 107fb152..024b5f53 100644
--- a/middleware/recovery_test.go
+++ b/middleware/recovery_test.go
@@ -83,6 +83,35 @@ func TestRecoveryMiddlewareLogBody(t *testing.T) {
 	assert.Contains(t, log.All()[0].ContextMap()["http.request"], `{"foo": "bar"}`)
 }
 
+func TestRecoveryMiddlewareLogBodySensitive(t *testing.T) {
+	app, log := newTestRouter(t)
+
+	app.Resource("/panic").Put("panic", "Panic recovery test",
+		responses.NoContent(),
+	).Run(func(ctx huma.Context, input struct {
+		Authorization string `header:"authorization"`
+		Body          struct {
+			Foo string `json:"foo"`
+		}
+	}) {
+		assert.Equal(t, "secrets!", input.Authorization)
+		panic(fmt.Errorf("Some error"))
+	})
+
+	w := httptest.NewRecorder()
+	req, _ := http.NewRequest(http.MethodPut, "/panic", strings.NewReader(`{"foo": "bar"}`))
+	req.Header.Set("Authorization", "secrets!")
+	app.ServeHTTP(w, req)
+	assert.Equal(t, http.StatusInternalServerError, w.Code)
+	assert.Equal(t, "application/problem+json", w.Result().Header.Get("content-type"))
+
+	logged := log.All()[0].ContextMap()["http.request"]
+	assert.Contains(t, logged, `{"foo": "bar"}`)
+	// The header must still appear in the dump, just with its value redacted.
+	assert.Contains(t, logged, "Authorization: "+redacted)
+	assert.NotContains(t, logged, "secrets!")
+}
+
 func TestLogBodyWithoutPanic(t *testing.T) {
 	app, _ := newTestRouter(t)
 