这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion router.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ var connContextKey contextKey = "huma-request-conn"
// has finished.
var opIDContextKey contextKey = "huma-operation-id"

// routerContextKey is the context key used to store and retrieve the router associated with the API.
var routerContextKey contextKey = "huma-router"

// GetConn gets the underlying `net.Conn` from a context.
func GetConn(ctx context.Context) net.Conn {
conn := ctx.Value(connContextKey)
Expand All @@ -34,6 +37,15 @@ func GetConn(ctx context.Context) net.Conn {
return nil
}

// GetRouter gets the `*Router` handling API requests from a context. It returns nil when the context carries no router.
func GetRouter(ctx context.Context) *Router {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be used to greatly simplify the code in #80? The context would already have the router injected and there's no need to pass it around everywhere right? 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly. Note that several tests don't have a router (in which case I send in nil). But since the GetRouter function returns nil when it can't find a Router in the context, then it should actually be pretty equivalent. I agree, it would simplify the code. The issue here was that I wanted to provide you with independent PRs in case you decided to reject any. So that's why I don't leverage this in #80. But if you want me to refactor it once this PR is merged, I can certainly do that.

router := ctx.Value(routerContextKey)
if router != nil {
return router.(*Router)
}
return nil
}

// Router is the entrypoint to your API.
type Router struct {
mux *chi.Mux
Expand Down Expand Up @@ -243,6 +255,45 @@ func (r *Router) Resource(path string) *Resource {
return res
}

// GetOperation looks up the operation whose ID equals `id` and returns an
// `OperationInfo` describing it (ID, URL template, summary, and tags). Each
// of the router's resources is searched in turn via getOperation, and the
// first match is returned. It returns nil when no resource yields a match.
func (r *Router) GetOperation(id string) *OperationInfo {
	for _, root := range r.resources {
		// Stop at the first resource that reports a match.
		if info := getOperation(id, root); info != nil {
			return info
		}
	}
	return nil
}

// getOperation searches `res` and, recursively, its sub-resources for the
// operation whose ID equals `id`. Operations registered directly on `res` are
// checked before any sub-resource is visited. It returns a freshly allocated
// `OperationInfo` for the first match, or nil if this part of the resource
// tree contains no such operation.
func getOperation(id string, res *Resource) *OperationInfo {
	// Operations registered directly on this resource take precedence.
	for _, candidate := range res.operations {
		if candidate.id != id {
			continue
		}
		// Copy the tags into a new slice so the result does not alias the
		// resource's own tag slice; the copy is non-nil even when there are
		// no tags.
		tags := append([]string{}, candidate.resource.tags...)
		return &OperationInfo{
			ID:          candidate.id,
			URITemplate: candidate.resource.path,
			Summary:     candidate.summary,
			Tags:        tags,
		}
	}

	// Not found here: descend into each sub-resource. Ranging over an empty
	// or nil collection simply performs no iterations.
	for _, child := range res.subResources {
		if found := getOperation(id, child); found != nil {
			return found
		}
	}

	// Nothing in this part of the tree.
	return nil
}

// Middleware adds a new standard middleware to this router at the root,
// so it will apply to all requests. Middleware can also be applied at the
// resource level.
Expand Down Expand Up @@ -532,7 +583,10 @@ func New(docs, version string) *Router {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// Inject the operation info before other middleware so that the later
// middleware will have access to it.
req = req.WithContext(context.WithValue(req.Context(), opIDContextKey, &OperationInfo{}))
reqContext := req.Context()
withOpID := context.WithValue(reqContext, opIDContextKey, &OperationInfo{})
withRouter := context.WithValue(withOpID, routerContextKey, r)
req = req.WithContext(withRouter)

next.ServeHTTP(w, req)

Expand Down
9 changes: 9 additions & 0 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ func TestStreamingInput(t *testing.T) {
ctx.WriteHeader(http.StatusNoContent)
})

stream := r.GetOperation("stream")
assert.NotNil(t, stream)
assert.Equal(t, *stream, OperationInfo{
ID: "stream",
URITemplate: "/stream",
Summary: "Stream test",
Tags: []string{},
})

w := httptest.NewRecorder()
body := bytes.NewReader(make([]byte, 1024))
req, _ := http.NewRequest(http.MethodPost, "/stream", body)
Expand Down