这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions examples/sse/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,19 @@ import (

"github.com/danielgtaylor/huma/v2"
"github.com/danielgtaylor/huma/v2/adapters/humachi"
"github.com/danielgtaylor/huma/v2/adapters/humagin"
"github.com/danielgtaylor/huma/v2/humacli"
"github.com/danielgtaylor/huma/v2/sse"
"github.com/gin-gonic/gin"
"github.com/go-chi/chi/v5"

_ "github.com/danielgtaylor/huma/v2/formats/cbor"
)

// Options holds the command-line flags for the SSE example server.
//
// Port is the TCP port to listen on. Router selects which router adapter
// backs the Huma API; only "chi" and "gin" are accepted (enforced by the
// enum tag, with main panicking on anything else as a last resort).
type Options struct {
	Port   int    `help:"Port to listen on" default:"8888"`
	Router string `help:"Router to use" enum:"chi,gin" default:"chi"`
}

// First, define your SSE message types. These can be any struct you want and
Expand Down Expand Up @@ -126,8 +129,20 @@ func main() {
// Create a CLI app which takes a port option.
cli := humacli.New(func(hooks humacli.Hooks, options *Options) {
// Create a new router & API
router := chi.NewMux()
api := humachi.New(router, huma.DefaultConfig("My API", "1.0.0"))
var router http.Handler
var api huma.API

if options.Router == "chi" {
r := chi.NewMux()
api = humachi.New(r, huma.DefaultConfig("My API", "1.0.0"))
router = r
} else if options.Router == "gin" {
r := gin.New()
api = humagin.New(r, huma.DefaultConfig("My API", "1.0.0"))
router = r
} else {
panic("Unknown router " + options.Router)
}

// Create a producer to generate messages for clients.
p := Producer{Cancel: make(chan bool, 1)}
Expand Down
48 changes: 44 additions & 4 deletions sse/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ func deref(t reflect.Type) reflect.Type {
return t
}

type (
	// unwrapper is implemented by response writers that wrap another
	// http.ResponseWriter and expose it through Unwrap, which lets the SSE
	// code walk down a chain of wrappers looking for optional capabilities.
	unwrapper interface {
		Unwrap() http.ResponseWriter
	}

	// writeDeadliner is implemented by response writers that allow a
	// deadline to be set for subsequent writes.
	writeDeadliner interface {
		SetWriteDeadline(time.Time) error
	}
)

// Message is a single SSE message. There is no `event` field as this is
// handled by the `eventTypeMap` when registering the operation.
type Message struct {
Expand Down Expand Up @@ -119,9 +127,41 @@ func Register[I any](api huma.API, op huma.Operation, eventTypeMap map[string]an
ctx.SetHeader("Content-Type", "text/event-stream")
bw := ctx.BodyWriter()
encoder := json.NewEncoder(bw)

// Get the flusher/deadliner from the response writer if possible.
var flusher http.Flusher
flushCheck := bw
for {
if f, ok := flushCheck.(http.Flusher); ok {
flusher = f
break
}
if u, ok := flushCheck.(unwrapper); ok {
flushCheck = u.Unwrap()
} else {
break
}
}

var deadliner writeDeadliner
deadlineCheck := bw
for {
if d, ok := deadlineCheck.(writeDeadliner); ok {
deadliner = d
break
}
if u, ok := deadlineCheck.(unwrapper); ok {
deadlineCheck = u.Unwrap()
} else {
break
}
}
Comment on lines +132 to +158
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Refactor duplicated unwrapping logic into a helper function

The unwrapping logic to obtain http.Flusher and writeDeadliner interfaces from the response writer is duplicated in two places. This duplication can make the code harder to maintain and more error-prone. Consider refactoring this into a reusable helper function to improve code maintainability and reduce duplication.

You could create a generic helper function like this:

// getInterface searches bw and the writers reachable from it through
// repeated Unwrap calls, returning the first one that implements T.
// When no writer in the chain implements T (or the chain ends at a writer
// that cannot be unwrapped), the zero value of T is returned, which is nil
// when T is an interface type.
func getInterface[T any](bw http.ResponseWriter) T {
	for w := bw; ; {
		if found, ok := w.(T); ok {
			return found
		}
		u, ok := w.(unwrapper)
		if !ok {
			break
		}
		w = u.Unwrap()
	}
	var zero T
	return zero
}

Then replace the unwrapping loops with calls to this helper function:

- var flusher http.Flusher
- flushCheck := bw
- for {
-     if f, ok := flushCheck.(http.Flusher); ok {
-         flusher = f
-         break
-     }
-     if u, ok := flushCheck.(unwrapper); ok {
-         flushCheck = u.Unwrap()
-     } else {
-         break
-     }
- }
+ flusher := getInterface[http.Flusher](bw)

And similarly for deadliner:

- var deadliner writeDeadliner
- deadlineCheck := bw
- for {
-     if d, ok := deadlineCheck.(writeDeadliner); ok {
-         deadliner = d
-         break
-     }
-     if u, ok := deadlineCheck.(unwrapper); ok {
-         deadlineCheck = u.Unwrap()
-     } else {
-         break
-     }
- }
+ deadliner := getInterface[writeDeadliner](bw)


send := func(msg Message) error {
if d, ok := bw.(interface{ SetWriteDeadline(time.Time) error }); ok {
d.SetWriteDeadline(time.Now().Add(WriteTimeout))
if deadliner != nil {
if err := deadliner.SetWriteDeadline(time.Now().Add(WriteTimeout)); err != nil {
fmt.Println("warning: unable to set write deadline: " + err.Error())
}
} else {
fmt.Println("warning: unable to set write deadline")
}
Expand Down Expand Up @@ -155,8 +195,8 @@ func Register[I any](api huma.API, op huma.Operation, eventTypeMap map[string]an
return err
}
bw.Write([]byte("\n"))
if f, ok := bw.(http.Flusher); ok {
f.Flush()
if flusher != nil {
flusher.Flush()
} else {
fmt.Println("error: unable to flush")
return fmt.Errorf("unable to flush: %w", http.ErrNotSupported)
Expand Down
21 changes: 18 additions & 3 deletions sse/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ type UserCreatedEvent UserEvent
type UserDeletedEvent UserEvent

// DummyWriter is a minimal http.ResponseWriter used to drive the SSE
// handler in tests without a real network connection.
type DummyWriter struct {
	// writeErr is the error the writer should report from Write.
	// NOTE(review): Write's body is not visible in this view; confirm.
	writeErr error
	// deadlineErr is handed to the WrappedDeadliner returned by Unwrap,
	// which reports it from SetWriteDeadline.
	deadlineErr error
}

func (w *DummyWriter) Header() http.Header {
Expand All @@ -41,8 +42,17 @@ func (w *DummyWriter) Write(p []byte) (n int, err error) {

// WriteHeader satisfies http.ResponseWriter; the status code is discarded.
func (w *DummyWriter) WriteHeader(_ int) {}

func (w *DummyWriter) SetWriteDeadline(t time.Time) error {
return nil
func (w *DummyWriter) Unwrap() http.ResponseWriter {
return &WrappedDeadliner{deadlineErr: w.deadlineErr}
}

// WrappedDeadliner is the inner writer handed out by DummyWriter.Unwrap in
// these tests. It implements SetWriteDeadline itself; every other
// http.ResponseWriter method is promoted from the embedded writer.
//
// Caution: DummyWriter.Unwrap leaves the embedded writer nil, so calling
// Header, Write or WriteHeader on a value built that way would panic.
type WrappedDeadliner struct {
	http.ResponseWriter
	deadlineErr error
}

// SetWriteDeadline ignores the requested deadline and returns deadlineErr,
// which is nil unless a test configured a failure.
func (wd *WrappedDeadliner) SetWriteDeadline(_ time.Time) error {
	return wd.deadlineErr
}

func TestSSE(t *testing.T) {
Expand Down Expand Up @@ -105,4 +115,9 @@ data: {"error": "encode error: json: unsupported type: chan int"}
w = &DummyWriter{}
req, _ = http.NewRequest(http.MethodGet, "/sse", nil)
api.Adapter().ServeHTTP(w, req)

// Test inability to set write deadline due to error doesn't panic
w = &DummyWriter{deadlineErr: errors.New("whoops")}
req, _ = http.NewRequest(http.MethodGet, "/sse", nil)
api.Adapter().ServeHTTP(w, req)
}
Loading