From e59ddc15d599414926e3898154f9df943ece4f4b Mon Sep 17 00:00:00 2001 From: Zev Goldstein Date: Tue, 30 Mar 2021 21:05:45 -0400 Subject: [PATCH] expose internal.handleHTTP as a standard http middleware --- appengine.go | 3 + internal/api.go | 154 +++++++++++++++++++++------------------- internal/api_classic.go | 4 +- internal/api_test.go | 8 +-- internal/main_vm.go | 2 +- 5 files changed, 92 insertions(+), 79 deletions(-) diff --git a/appengine.go b/appengine.go index 8c969767..f65f94a2 100644 --- a/appengine.go +++ b/appengine.go @@ -54,6 +54,9 @@ func Main() { internal.Main() } +// Middleware wraps an http.Handler so that the wrapped handler can make App Engine API calls. +var Middleware func(http.Handler) http.Handler = internal.Middleware + // IsDevAppServer reports whether the App Engine app is running in the // development App Server. func IsDevAppServer() bool { diff --git a/internal/api.go b/internal/api.go index 2748318a..7be64e23 100644 --- a/internal/api.go +++ b/internal/api.go @@ -87,88 +87,98 @@ func apiURL() *url.URL { } } -func handleHTTP(w http.ResponseWriter, r *http.Request) { - c := &context{ - req: r, - outHeader: w.Header(), - apiURL: apiURL(), - } - r = r.WithContext(withContext(r.Context(), c)) - c.req = r - - stopFlushing := make(chan int) +// Middleware wraps an http.Handler so that the wrapped handler can make App Engine API calls. +func Middleware(next http.Handler) http.Handler { + return handleHTTPMiddleware(executeRequestSafelyMiddleware(next)) +} - // Patch up RemoteAddr so it looks reasonable. - if addr := r.Header.Get(userIPHeader); addr != "" { - r.RemoteAddr = addr - } else if addr = r.Header.Get(remoteAddrHeader); addr != "" { - r.RemoteAddr = addr - } else { - // Should not normally reach here, but pick a sensible default anyway. - r.RemoteAddr = "127.0.0.1" - } - // The address in the headers will most likely be of these forms: - // 123.123.123.123 - // 2001:db8::1 - // net/http.Request.RemoteAddr is specified to be in "IP:port" form.
- if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil { - // Assume the remote address is only a host; add a default port. - r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80") - } +func handleHTTPMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c := &context{ + req: r, + outHeader: w.Header(), + apiURL: apiURL(), + } + r = r.WithContext(withContext(r.Context(), c)) + c.req = r + + stopFlushing := make(chan int) + + // Patch up RemoteAddr so it looks reasonable. + if addr := r.Header.Get(userIPHeader); addr != "" { + r.RemoteAddr = addr + } else if addr = r.Header.Get(remoteAddrHeader); addr != "" { + r.RemoteAddr = addr + } else { + // Should not normally reach here, but pick a sensible default anyway. + r.RemoteAddr = "127.0.0.1" + } + // The address in the headers will most likely be of these forms: + // 123.123.123.123 + // 2001:db8::1 + // net/http.Request.RemoteAddr is specified to be in "IP:port" form. + if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil { + // Assume the remote address is only a host; add a default port. + r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80") + } - if logToLogservice() { - // Start goroutine responsible for flushing app logs. - // This is done after adding c to ctx.m (and stopped before removing it) - // because flushing logs requires making an API call. - go c.logFlusher(stopFlushing) - } + if logToLogservice() { + // Start goroutine responsible for flushing app logs. + // This is done after adding c to ctx.m (and stopped before removing it) + // because flushing logs requires making an API call. 
+ go c.logFlusher(stopFlushing) + } - executeRequestSafely(c, r) - c.outHeader = nil // make sure header changes aren't respected any more + next.ServeHTTP(c, r) + c.outHeader = nil // make sure header changes aren't respected any more - flushed := make(chan struct{}) - if logToLogservice() { - stopFlushing <- 1 // any logging beyond this point will be dropped + flushed := make(chan struct{}) + if logToLogservice() { + stopFlushing <- 1 // any logging beyond this point will be dropped - // Flush any pending logs asynchronously. - c.pendingLogs.Lock() - flushes := c.pendingLogs.flushes - if len(c.pendingLogs.lines) > 0 { - flushes++ + // Flush any pending logs asynchronously. + c.pendingLogs.Lock() + flushes := c.pendingLogs.flushes + if len(c.pendingLogs.lines) > 0 { + flushes++ + } + c.pendingLogs.Unlock() + go func() { + defer close(flushed) + // Force a log flush, because with very short requests we + // may not ever flush logs. + c.flushLog(true) + }() + w.Header().Set(logFlushHeader, strconv.Itoa(flushes)) } - c.pendingLogs.Unlock() - go func() { - defer close(flushed) - // Force a log flush, because with very short requests we - // may not ever flush logs. - c.flushLog(true) - }() - w.Header().Set(logFlushHeader, strconv.Itoa(flushes)) - } - // Avoid nil Write call if c.Write is never called. - if c.outCode != 0 { - w.WriteHeader(c.outCode) - } - if c.outBody != nil { - w.Write(c.outBody) - } - if logToLogservice() { - // Wait for the last flush to complete before returning, - // otherwise the security ticket will not be valid. - <-flushed - } + // Avoid nil Write call if c.Write is never called. + if c.outCode != 0 { + w.WriteHeader(c.outCode) + } + if c.outBody != nil { + w.Write(c.outBody) + } + if logToLogservice() { + // Wait for the last flush to complete before returning, + // otherwise the security ticket will not be valid. 
+ <-flushed + } + }) } -func executeRequestSafely(c *context, r *http.Request) { - defer func() { - if x := recover(); x != nil { - logf(c, 4, "%s", renderPanic(x)) // 4 == critical - c.outCode = 500 - } - }() +func executeRequestSafelyMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if x := recover(); x != nil { + c := w.(*context) + logf(c, 4, "%s", renderPanic(x)) // 4 == critical + c.outCode = 500 + } + }() - http.DefaultServeMux.ServeHTTP(c, r) + next.ServeHTTP(w, r) + }) } func renderPanic(x interface{}) string { diff --git a/internal/api_classic.go b/internal/api_classic.go index f0f40b2e..a9beece7 100644 --- a/internal/api_classic.go +++ b/internal/api_classic.go @@ -144,8 +144,8 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) return err } -func handleHTTP(w http.ResponseWriter, r *http.Request) { - panic("handleHTTP called; this should be impossible") +func Middleware(next http.Handler) http.Handler { + panic("Middleware called; this should be impossible") } func logf(c appengine.Context, level int64, format string, args ...interface{}) { diff --git a/internal/api_test.go b/internal/api_test.go index 5b025cf9..d4af31bf 100644 --- a/internal/api_test.go +++ b/internal/api_test.go @@ -302,7 +302,7 @@ func TestDelayedLogFlushing(t *testing.T) { handled := make(chan struct{}) go func() { defer close(handled) - handleHTTP(w, r) + Middleware(http.DefaultServeMux).ServeHTTP(w, r) }() // Check that the log flush eventually comes in. 
time.Sleep(1200 * time.Millisecond) @@ -360,7 +360,7 @@ func TestLogFlushing(t *testing.T) { } w := httptest.NewRecorder() - handleHTTP(w, r) + Middleware(http.DefaultServeMux).ServeHTTP(w, r) const hdr = "X-AppEngine-Log-Flush-Count" if got := w.HeaderMap.Get(hdr); got != tc.wantHeader { t.Errorf("%s header = %q, want %q", hdr, got, tc.wantHeader) @@ -403,7 +403,7 @@ func TestRemoteAddr(t *testing.T) { Header: tc.headers, Body: ioutil.NopCloser(bytes.NewReader(nil)), } - handleHTTP(httptest.NewRecorder(), r) + Middleware(http.DefaultServeMux).ServeHTTP(httptest.NewRecorder(), r) if addr != tc.addr { t.Errorf("Header %v, got %q, want %q", tc.headers, addr, tc.addr) } @@ -420,7 +420,7 @@ func TestPanickingHandler(t *testing.T) { Body: ioutil.NopCloser(bytes.NewReader(nil)), } rec := httptest.NewRecorder() - handleHTTP(rec, r) + Middleware(http.DefaultServeMux).ServeHTTP(rec, r) if rec.Code != 500 { t.Errorf("Panicking handler returned HTTP %d, want HTTP %d", rec.Code, 500) } diff --git a/internal/main_vm.go b/internal/main_vm.go index ddb79a33..2c53dafe 100644 --- a/internal/main_vm.go +++ b/internal/main_vm.go @@ -29,7 +29,7 @@ func Main() { if IsDevAppServer() { host = "127.0.0.1" } - if err := http.ListenAndServe(host+":"+port, http.HandlerFunc(handleHTTP)); err != nil { + if err := http.ListenAndServe(host+":"+port, Middleware(http.DefaultServeMux)); err != nil { log.Fatalf("http.ListenAndServe: %v", err) } }