diff --git a/router.go b/router.go
index 3295bab9..33953ea2 100644
--- a/router.go
+++ b/router.go
@@ -25,6 +25,9 @@ var connContextKey contextKey = "huma-request-conn"
 // has finished.
 var opIDContextKey contextKey = "huma-operation-id"
 
+// routerContextKey is the context key under which the request's `*Router` is stored.
+var routerContextKey contextKey = "huma-router"
+
 // GetConn gets the underlying `net.Conn` from a context.
 func GetConn(ctx context.Context) net.Conn {
 	conn := ctx.Value(connContextKey)
@@ -34,6 +37,15 @@ func GetConn(ctx context.Context) net.Conn {
 	return nil
 }
 
+// GetRouter returns the `*Router` handling the current API request, or nil
+// if ctx did not come from a request served by the router.
+func GetRouter(ctx context.Context) *Router {
+	if router, ok := ctx.Value(routerContextKey).(*Router); ok {
+		return router
+	}
+	return nil
+}
+
 // Router is the entrypoint to your API.
 type Router struct {
 	mux *chi.Mux
@@ -243,6 +255,45 @@ func (r *Router) Resource(path string) *Resource {
 	return res
 }
 
+// GetOperation returns an `OperationInfo` for the operation whose ID is `id`,
+// describing its URI template, summary and tags, or nil if no resource
+// (including nested sub-resources) has an operation with that ID.
+func (r *Router) GetOperation(id string) *OperationInfo {
+	// Walk each top-level resource tree depth-first and return the first
+	// operation found with a matching ID.
+	for _, res := range r.resources {
+		if info := getOperation(id, res); info != nil {
+			return info
+		}
+	}
+	return nil
+}
+
+// getOperation searches res and, recursively, its sub-resources for an
+// operation whose ID equals id, returning nil when none matches.
+func getOperation(id string, res *Resource) *OperationInfo {
+	for _, op := range res.operations {
+		if op.id == id {
+			// Copy the tags so callers cannot mutate the resource's slice;
+			// appending to an empty literal keeps Tags non-nil when there are none.
+			return &OperationInfo{
+				ID:          op.id,
+				URITemplate: op.resource.path,
+				Summary:     op.summary,
+				Tags:        append([]string{}, op.resource.tags...),
+			}
+		}
+	}
+	// Not found at this level, so descend into the sub-resources. Ranging over
+	// a nil slice or map is a no-op, so no explicit nil check is needed.
+	for _, sub := range res.subResources {
+		if info := getOperation(id, sub); info != nil {
+			return info
+		}
+	}
+	return nil
+}
+
 // Middleware adds a new standard middleware to this router at the root,
 // so it will apply to all requests. Middleware can also be applied at the
 // resource level.
@@ -532,7 +583,10 @@ func New(docs, version string) *Router {
 		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 			// Inject the operation info before other middleware so that the later
 			// middleware will have access to it.
-			req = req.WithContext(context.WithValue(req.Context(), opIDContextKey, &OperationInfo{}))
+			// Also store the router so handlers can retrieve it via GetRouter.
+			reqCtx := context.WithValue(req.Context(), opIDContextKey, &OperationInfo{})
+			reqCtx = context.WithValue(reqCtx, routerContextKey, r)
+			req = req.WithContext(reqCtx)
 
 			next.ServeHTTP(w, req)
 
diff --git a/router_test.go b/router_test.go
index e23f90a9..836d4803 100644
--- a/router_test.go
+++ b/router_test.go
@@ -64,6 +64,15 @@ func TestStreamingInput(t *testing.T) {
 		ctx.WriteHeader(http.StatusNoContent)
 	})
 
+	stream := r.GetOperation("stream")
+	assert.NotNil(t, stream)
+	assert.Equal(t, &OperationInfo{
+		ID:          "stream",
+		URITemplate: "/stream",
+		Summary:     "Stream test",
+		Tags:        []string{},
+	}, stream)
+
 	w := httptest.NewRecorder()
 	body := bytes.NewReader(make([]byte, 1024))
 	req, _ := http.NewRequest(http.MethodPost, "/stream", body)