From 2d1984f84fc0e535404e7dc79435bf066c9f64b6 Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Sat, 10 Sep 2022 14:02:43 -0700 Subject: [PATCH] feat: automatic PATCH operation generation --- README.md | 24 +++++ context.go | 7 +- go.mod | 1 + go.sum | 3 + operation.go | 122 +++++++++++++++-------- patch.go | 266 +++++++++++++++++++++++++++++++++++++++++++++++++ patch_test.go | 172 ++++++++++++++++++++++++++++++++ resolver.go | 10 +- router.go | 14 +++ router_test.go | 59 +++++++++++ 10 files changed, 631 insertions(+), 47 deletions(-) create mode 100644 patch.go create mode 100644 patch_test.go diff --git a/README.md b/README.md index 898befc8..a47a303c 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,9 @@ Features include: - Support for gzip ([RFC 1952](https://tools.ietf.org/html/rfc1952)) & Brotli ([RFC 7932](https://tools.ietf.org/html/rfc7932)) content encoding via the `Accept-Encoding` header. - Support for JSON ([RFC 8259](https://tools.ietf.org/html/rfc8259)), YAML, and CBOR ([RFC 7049](https://tools.ietf.org/html/rfc7049)) content types via the `Accept` header. - Conditional requests support, e.g. `If-Match` or `If-Unmodified-Since` header utilities. +- Optional automatic generation of `PATCH` operations that support: + - [RFC 7386](https://www.rfc-editor.org/rfc/rfc7386) JSON Merge Patch + - [RFC 6902](https://www.rfc-editor.org/rfc/rfc6902) JSON Patch - Annotated Go types for input and output models - Generates JSON Schema from Go types - Automatic input model validation & error handling @@ -435,6 +438,19 @@ Try a request against the service like: $ restish :8888/things/abc123?q=3 -H "Foo: bar" name: Kari ``` +### Multiple Request Bodies + +Request input structs can support multiple body types based on the content type of the request, with an unknown content type defaulting to the first-defined body. This can be used for things like versioned inputs or to support wildly different input types (e.g. JSON Merge Patch vs. JSON Patch). 
Example: + +```go +type MyInput struct { + BodyV2 *MyInputBodyV2 `body:"application/my-type-v2+json"` + BodyV1 *MyInputBodyV1 `body:"application/my-type-v1+json"` +} +``` + +It's your responsibility to check which one is non-`nil` in the operation handler. If not using pointers, you'll need to check a known field to determine which was actually sent by the client. + ### Parameter & Body Validation All supported JSON Schema tags work for parameters and body fields. Validation happens before the request handler is called, and if needed an error response is returned. For example: @@ -624,6 +640,14 @@ app.Resource("/resource").Put("put-resource", "Put a resource", }) ``` +### Automatic PATCH Support + +If a `GET` and a `PUT` exist for the same resource, the `PUT` accepts an object (struct) request body, and no `PATCH` exists at server startup, then by default a `PATCH` operation will be generated for you to make editing more convenient for clients. This behavior can be disabled via `app.DisableAutoPatch()`. + +If the `GET` returns an `ETag` or `Last-Modified` header and the client did not send its own conditional headers, then that value (preferring `ETag`) will be used to make a conditional request on the `PUT` operation to prevent distributed write conflicts that might otherwise overwrite someone else's changes. + +If the `PATCH` request has no `Content-Type` header or uses `application/json` (media type parameters such as `charset` are ignored), then JSON Merge Patch is assumed. Send `application/json-patch+json` to use JSON Patch instead. + ## Validation Go struct tags are used to annotate inputs/output structs with information that gets turned into [JSON Schema](https://json-schema.org/) for documentation and validation. diff --git a/context.go b/context.go index 9d4541d1..508d1e87 100644 --- a/context.go +++ b/context.go @@ -37,7 +37,6 @@ func AddAllowedHeaders(name ...string) { } } - // ContextFromRequest returns a Huma context for a request, useful for // accessing high-level convenience functions from e.g. middleware.
func ContextFromRequest(w http.ResponseWriter, r *http.Request) Context { @@ -310,6 +309,12 @@ func (c *hcontext) writeModel(ct string, status int, model interface{}) { if !found { panic(fmt.Errorf("Invalid model %s, expecting %s for %s %s", modelType, strings.Join(names, ", "), c.r.Method, c.r.URL.Path)) } + } else { + // Some automatically generated responses (e.g. from the auto-PATCH handler) + // are not registered, but may still return an ErrorModel; link its schema too. + if modelType == reflect.TypeOf(&ErrorModel{}) { + modelRef = "/" + modelType.Elem().Name() + } } // If possible, insert a link relation header to the JSON Schema describing diff --git a/go.mod b/go.mod index 003af2b2..fe922982 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/andybalholm/brotli v1.0.4 github.com/benbjohnson/clock v1.3.0 // indirect github.com/danielgtaylor/casing v0.0.0-20210126043903-4e55e6373ac3 + github.com/evanphx/json-patch/v5 v5.6.0 github.com/fatih/structs v1.1.0 github.com/fxamacker/cbor/v2 v2.4.0 github.com/go-chi/chi v4.1.2+incompatible diff --git a/go.sum b/go.sum index 069c1468..b3824023 100644 --- a/go.sum +++ b/go.sum @@ -113,6 +113,8 @@ github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.
github.com/envoyproxy/go-control-plane v0.10.1/go.mod h1:AY7fTTXNdv/aJ2O5jwpxAPOWUZ7hQAEvzN5Pf27BkQQ= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v0.6.2/go.mod h1:2t7qjJNvHPx8IjnBOzl9E9/baC+qXE/TeeyBRzgJDws= +github.com/evanphx/json-patch/v5 v5.6.0 h1:b91NhWfaz02IuVxO9faSllyAtNXHMPkC5J8sJCLunww= +github.com/evanphx/json-patch/v5 v5.6.0/go.mod h1:G79N1coSVB93tBe7j6PhzjmR3/2VvlbKOFpnXhI9Bw4= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= @@ -260,6 +262,7 @@ github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1: github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= diff --git a/operation.go b/operation.go index f273400f..fc5091f4 100644 --- a/operation.go +++ b/operation.go @@ -36,22 +36,26 @@ func GetOperationInfo(ctx context.Context) *OperationInfo { } } +type request struct { + override bool // schema was set explicitly; Run must not regenerate it + model reflect.Type // Go type of the input body field + schema *schema.Schema // schema used for validation (and for docs when overridden) +} + // Operation represents an operation (an HTTP verb, e.g. GET / PUT) against // a resource attached to a router.
type Operation struct { - resource *Resource - method string - id string - summary string - description string - params map[string]oaParam - requestContentType string - requestSchema *schema.Schema - requestSchemaOverride bool - requestModel reflect.Type - responses []Response - maxBodyBytes int64 - bodyReadTimeout time.Duration + resource *Resource + method string + id string + summary string + description string + params map[string]oaParam + defaultContentType string // first-declared body content type, used as the fallback + requests map[string]*request // request body definitions keyed by content type + responses []Response + maxBodyBytes int64 + bodyReadTimeout time.Duration } func newOperation(resource *Resource, method, id, docs string, responses []Response) *Operation { @@ -62,6 +66,7 @@ func newOperation(resource *Resource, method, id, docs string, responses []Respo id: id, summary: summary, description: desc, + requests: map[string]*request{}, responses: responses, // 1 MiB body limit by default maxBodyBytes: 1024 * 1024, @@ -92,18 +97,14 @@ func (o *Operation) toOpenAPI(components *oaComponents) *gabs.Container { } // Request body - if o.requestSchema != nil { - ct := o.requestContentType - if ct == "" { - ct = "application/json" - } + for ct, request := range o.requests { ref := "" - if o.requestSchemaOverride { - ref = components.AddExistingSchema(o.requestSchema, o.id+"-request", !o.resource.router.disableSchemaProperty) + if request.override { + ref = components.AddExistingSchema(request.schema, o.id+"-request", !o.resource.router.disableSchemaProperty) } else { // Regenerate with ModeAll so the same model can be used for both the // input and output when possible.
- ref = components.AddSchema(o.requestModel, schema.ModeAll, o.id+"-request", !o.resource.router.disableSchemaProperty) + ref = components.AddSchema(request.model, schema.ModeAll, o.id+"-request", !o.resource.router.disableSchemaProperty) } doc.Set(ref, "requestBody", "content", ct, "schema", "$ref") } @@ -145,6 +146,15 @@ func (o *Operation) toOpenAPI(components *oaComponents) *gabs.Container { return doc } +func (o *Operation) requestForContentType(ct string) (string, *request) { + // Match on the media type only, ignoring parameters such as "; charset=utf-8". + ct = strings.TrimSpace(strings.Split(ct, ";")[0]) + if req := o.requests[ct]; req != nil { + return ct, req + } + return o.defaultContentType, o.requests[o.defaultContentType] +} + // MaxBodyBytes sets the max number of bytes that the request body size may be // before the request is cancelled. The default is 1MiB. func (o *Operation) MaxBodyBytes(size int64) { @@ -175,8 +185,15 @@ func (o *Operation) NoBodyReadTimeout() { // RequestSchema allows overriding the generated input body schema, giving you // more control over documentation and validation. func (o *Operation) RequestSchema(s *schema.Schema) { - o.requestSchema = s - o.requestSchemaOverride = true + o.RequestSchemaForContentType("application/json", s) +} +// RequestSchemaForContentType overrides the generated body schema for one content type. +func (o *Operation) RequestSchemaForContentType(ct string, s *schema.Schema) { + if o.requests[ct] == nil { + o.requests[ct] = &request{} + } + o.requests[ct].override = true + o.requests[ct].schema = s } // Run registers the handler function for this operation. It should be of the
It should be of the @@ -208,7 +225,6 @@ func (o *Operation) Run(handler interface{}) { t := reflect.TypeOf(handler) if t.Kind() == reflect.Func && t.NumIn() > 1 { - var err error input := t.In(1) // Get parameters @@ -224,29 +240,51 @@ func (o *Operation) Run(handler interface{}) { } possible := []int{http.StatusBadRequest} + foundBody := false + + for i := 0; i < input.NumField(); i++ { + f := input.Field(i) + if ct, ok := f.Tag.Lookup(locationBody); ok || f.Name == strings.Title(locationBody) { + foundBody = true + + if ct == "" || ct == "true" { + // Default to JSON + ct = "application/json" + } + + if o.defaultContentType == "" { + o.defaultContentType = ct + } - if _, ok := input.FieldByName("Body"); ok || len(o.params) > 0 { + if o.requests[ct] == nil { + o.requests[ct] = &request{} + } + + o.requests[ct].model = f.Type + + if !o.requests[ct].override { + s, err := schema.GenerateWithMode(f.Type, schema.ModeWrite, nil) + if o.resource != nil && o.resource.router != nil && !o.resource.router.disableSchemaProperty { + s.AddSchemaField() + } + if err != nil { + panic(fmt.Errorf("unable to generate JSON schema: %w", err)) + } + o.requests[ct].schema = s + } + } + } + + if foundBody || len(o.params) > 0 { // Invalid parameter values or body values can cause a 422. possible = append(possible, http.StatusUnprocessableEntity) } - // Get body if present. 
- if body, ok := input.FieldByName("Body"); ok { - o.requestModel = body.Type + if foundBody { possible = append(possible, http.StatusRequestEntityTooLarge, http.StatusRequestTimeout, ) - - if o.requestSchema == nil { - o.requestSchema, err = schema.GenerateWithMode(body.Type, schema.ModeWrite, nil) - if o.resource != nil && o.resource.router != nil && !o.resource.router.disableSchemaProperty { - o.requestSchema.AddSchemaField() - } - if err != nil { - panic(fmt.Errorf("unable to generate JSON schema: %w", err)) - } - } } // It's possible for the inputs to generate a few different errors, so @@ -305,16 +343,18 @@ func (o *Operation) Run(handler interface{}) { } } + ct, reqDef := o.requestForContentType(r.Header.Get("Content-Type")) + // Set a read deadline for reading/parsing the input request body, but // only for operations that have a request body model. var conn net.Conn - if o.requestModel != nil && o.bodyReadTimeout > 0 { + if reqDef != nil && reqDef.model != nil && o.bodyReadTimeout > 0 { if conn = GetConn(r.Context()); conn != nil { conn.SetReadDeadline(time.Now().Add(o.bodyReadTimeout)) } } - setFields(ctx, ctx.r, input, inputType) + setFields(ctx, ctx.r, input, inputType, ct, reqDef) if !ctx.HasError() { // No errors yet, so any errors that come after should be treated as a // semantic rather than structural error. @@ -329,7 +369,7 @@ func (o *Operation) Run(handler interface{}) { // Clear any body read deadline if one was set as the body has now been // read in. The one exception is when the body is streamed in via an // `io.Reader` so we don't reset the deadline for that. 
- if conn != nil && o.requestModel != readerType { + if conn != nil && reqDef != nil && reqDef.model != readerType { conn.SetReadDeadline(time.Time{}) } diff --git a/patch.go b/patch.go new file mode 100644 index 00000000..a169a010 --- /dev/null +++ b/patch.go @@ -0,0 +1,266 @@ +package huma + +import ( + "bytes" + "io/ioutil" + "net/http" + "net/http/httptest" + "reflect" + "strings" + + "github.com/danielgtaylor/huma/schema" + jsonpatch "github.com/evanphx/json-patch/v5" +) + +// jsonPatchOp describes an RFC 6902 JSON Patch operation. See also: +// https://www.rfc-editor.org/rfc/rfc6902 +type jsonPatchOp struct { + Op string `json:"op" enum:"add,remove,replace,move,copy,test" doc:"Operation name"` + From string `json:"from,omitempty" doc:"JSON Pointer for the source of a move or copy"` + Path string `json:"path" doc:"JSON Pointer to the field being operated on, or the destination of a move/copy operation"` + Value interface{} `json:"value,omitempty" doc:"The value to set"` +} + +var jsonPatchType = reflect.TypeOf([]jsonPatchOp{}) +var jsonPatchSchema, _ = schema.Generate(jsonPatchType) + +// allResources recursively collects and returns all resources/sub-resources +// attached to a router. +func (r *Router) allResources() []*Resource { + resources := []*Resource{} + resources = append(resources, r.resources...) + + for i := 0; i < len(resources); i++ { + if len(resources[i].subResources) > 0 { + resources = append(resources, resources[i].subResources...) + } + } + + return resources +} + +// makeAllOptional recursively makes all fields in a schema optional, useful +// for allowing PATCH operation on just some fields. 
+func makeAllOptional(s *schema.Schema) { + if s.Required != nil { + s.Required = []string{} + } + + if s.Items != nil { + makeAllOptional(s.Items) + } + + for _, props := range []map[string]*schema.Schema{ + s.Properties, + s.PatternProperties, + } { + for _, v := range props { + makeAllOptional(v) + } + } +} + +// AutoPatch generates HTTP PATCH operations for any resource which has a +// GET & PUT but no pre-existing PATCH operation. Generated PATCH operations +// will call GET, apply either `application/merge-patch+json` or +// `application/json-patch+json` patches, then call PUT with the updated +// resource. This method is called automatically on server start-up but can +// be called manually (e.g. for tests) and is idempotent. +func (r *Router) AutoPatch() { + for _, resource := range r.allResources() { + var get *Operation + var put *Operation + hasPatch := false + var kind reflect.Kind + + for _, op := range resource.operations { + switch op.method { + case http.MethodGet: + get = op + case http.MethodPut: + put = op + _, reqDef := put.requestForContentType("application/json") + if reqDef != nil && reqDef.model != nil { // a PUT may declare no request body at all + kind = reqDef.model.Kind() + if kind == reflect.Ptr { + kind = reqDef.model.Elem().Kind() + } + } + case http.MethodPatch: + hasPatch = true + } + } + + // We need a GET and PUT, but also an object (not array) to patch. + if get != nil && put != nil && !hasPatch && kind == reflect.Struct { + generatePatch(resource, get, put) + } + } +} + +// copyHeaders copies all headers from one header object into another, useful +// for creating a new request with headers that match an existing request. +func copyHeaders(from, to http.Header) { + for k, values := range from { + for _, v := range values { + to.Add(k, v) + } + } +} + +// generatePatch is called for each resource which needs a PATCH operation to +// be added. It registers the operation for docs and provides its handler.
+func generatePatch(resource *Resource, get *Operation, put *Operation) { + _, reqDef := put.requestForContentType("application/json") + + s, _ := schema.Generate(reqDef.model) + makeAllOptional(s) + + // Guess a name for this patch operation based on the model. + name := "" + if reqDef.model.Kind() == reflect.Struct { + name = reqDef.model.Name() + } + if reqDef.model.Kind() == reflect.Ptr { + name = reqDef.model.Elem().Name() + } + + // Augment the response list with ones we may return from the PATCH. + responses := append([]Response{}, put.responses...) + for _, code := range []int{ + http.StatusNotModified, + http.StatusBadRequest, + http.StatusUnprocessableEntity, + http.StatusUnsupportedMediaType, + } { + found := false + for _, resp := range responses { + if resp.status == code { + found = true + break + } + } + if !found { + responses = append(responses, NewResponse(code, http.StatusText(code)).Model(&ErrorModel{})) + } + } + + // Manually register the operation so it shows up in the generated OpenAPI. + resource.operations = append(resource.operations, &Operation{ + resource: resource, + method: http.MethodPatch, + id: "patch-" + name, // NOTE(review): not unique if several resources share this model type + summary: "Patch " + name, + params: get.params, + requests: map[string]*request{ + "application/merge-patch+json": { + override: true, + schema: s, + model: reqDef.model, + }, + "application/json-patch+json": { + override: true, + schema: jsonPatchSchema, + model: jsonPatchType, + }, + }, + responses: responses, + }) + + // Manually register the handler with the router. This bypasses the normal + // Huma API since this is easier and we are just calling the other pre-existing + // operations. + resource.router.mux.Patch(resource.path, func(w http.ResponseWriter, r *http.Request) { + ctx := ContextFromRequest(w, r) + + patchData, err := ioutil.ReadAll(r.Body) // NOTE(review): Run() is bypassed, so MaxBodyBytes may not apply here - confirm + if err != nil { + ctx.WriteError(http.StatusBadRequest, "Unable to read request body", err) + return + } + + // Perform the GET using the escaped path so encoded characters (e.g. %2F) round-trip.
+ origReq, err := http.NewRequest(http.MethodGet, r.URL.EscapedPath(), nil) + if err != nil { + ctx.WriteError(http.StatusBadRequest, "Unable to get resource", err) + return + } + copyHeaders(r.Header, origReq.Header) + origReq.Header.Set("Accept", "application/json") + origReq.Header.Set("Accept-Encoding", "") + + // Conditional request headers will be used on the write side, so ignore + // them on the read. + origReq.Header.Del("If-Match") + origReq.Header.Del("If-None-Match") + origReq.Header.Del("If-Modified-Since") + origReq.Header.Del("If-Unmodified-Since") + + origWriter := httptest.NewRecorder() + resource.router.ServeHTTP(origWriter, origReq) + + if origWriter.Code >= 300 { + // This represents an error on the GET side. + copyHeaders(origWriter.Header(), w.Header()) + w.WriteHeader(origWriter.Code) + w.Write(origWriter.Body.Bytes()) + return + } + + // Patch the data! + var patched []byte + switch strings.Split(r.Header.Get("Content-Type"), ";")[0] { + case "application/json-patch+json": + patch, err := jsonpatch.DecodePatch(patchData) + if err != nil { + ctx.WriteError(http.StatusUnprocessableEntity, "Unable to decode patch", err) + return + } + patched, err = patch.Apply(origWriter.Body.Bytes()) + if err != nil { + ctx.WriteError(http.StatusUnprocessableEntity, "Unable to apply patch", err) + return + } + case "application/merge-patch+json", "application/json", "": + // Assume most cases are merge-patch. + patched, err = jsonpatch.MergePatch(origWriter.Body.Bytes(), patchData) + if err != nil { + ctx.WriteError(http.StatusUnprocessableEntity, "Unable to apply patch", err) + return + } + default: + // A content type we explicitly do not support was passed.
+ ctx.WriteError(http.StatusUnsupportedMediaType, "Content type should be one of application/merge-patch+json or application/json-patch+json") + return + } + + if bytes.Equal(patched, origWriter.Body.Bytes()) { + ctx.WriteHeader(http.StatusNotModified) + return + } + + // Write the updated data back to the server! + putReq, err := http.NewRequest(http.MethodPut, r.URL.EscapedPath(), bytes.NewReader(patched)) + if err != nil { + ctx.WriteError(http.StatusInternalServerError, "Unable to put modified resource", err) + return + } + copyHeaders(r.Header, putReq.Header) + h := putReq.Header + if h.Get("If-Match") == "" && h.Get("If-None-Match") == "" && h.Get("If-Unmodified-Since") == "" && h.Get("If-Modified-Since") == "" { + // No conditional headers have been set on the request. Can we set one? + // If we have an ETag or last modified time then we can set a corresponding + // conditional request header to prevent overwriting someone else's + // changes between when we did our GET and are doing our PUT. + // If the PUT honors it, a conflicting write results in a 412 Precondition Failed.
+ oh := origWriter.Header() + if etag := oh.Get("ETag"); etag != "" { + h.Set("If-Match", etag) + } else if modified := oh.Get("Last-Modified"); modified != "" { + h.Set("If-Unmodified-Since", modified) + } + } + + resource.router.ServeHTTP(w, putReq) + }) +} diff --git a/patch_test.go b/patch_test.go new file mode 100644 index 00000000..fe978b69 --- /dev/null +++ b/patch_test.go @@ -0,0 +1,172 @@ +package huma + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +type SaleModel struct { + Location string `json:"location"` + Count int `json:"count"` +} + +func (m SaleModel) String() string { + return fmt.Sprintf("%s%d", m.Location, m.Count) +} + +type ThingModel struct { + ID string `json:"id"` + Price float32 `json:"price,omitempty"` + Sales []SaleModel `json:"sales,omitempty"` + Tags []string `json:"tags,omitempty"` +} + +func (m ThingModel) ETag() string { + return fmt.Sprintf("%s%v%v%v", m.ID, m.Price, m.Sales, m.Tags) +} + +type ThingIDParam struct { + ThingID string `path:"thing-id"` +} + +func TestPatch(t *testing.T) { + db := map[string]*ThingModel{ + "test": { + ID: "test", + Price: 1.00, + Sales: []SaleModel{ + {Location: "US", Count: 123}, + {Location: "EU", Count: 456}, + }, + }, + } + + app := newTestRouter() + + things := app.Resource("/things/{thing-id}") + + // Create the necessary GET/PUT + things.Get("get-thing", "docs", + NewResponse(http.StatusOK, "OK").Headers("ETag").Model(&ThingModel{}), + NewResponse(http.StatusNotFound, "Not Found"), + NewResponse(http.StatusPreconditionFailed, "Failed"), + ).Run(func(ctx Context, input struct { + ThingIDParam + }) { + thing := db[input.ThingID] + if thing == nil { + ctx.WriteError(http.StatusNotFound, "Not found") + return + } + ctx.Header().Set("ETag", thing.ETag()) + ctx.WriteModel(http.StatusOK, thing) + }) + + things.Put("put-thing", "docs", + NewResponse(http.StatusOK, "OK").Headers("ETag").Model(&ThingModel{}),
+ NewResponse(http.StatusPreconditionFailed, "Precondition failed").Model(&ErrorModel{}), + ).Run(func(ctx Context, input struct { + ThingIDParam + Body ThingModel + IfMatch []string `header:"If-Match" doc:"Succeeds if the server's resource matches one of the passed values."` + }) { + if len(input.IfMatch) > 0 { + found := false + if existing := db[input.ThingID]; existing != nil { + for _, possible := range input.IfMatch { + if possible == existing.ETag() { + found = true + break + } + } + } + if !found { + ctx.WriteError(http.StatusPreconditionFailed, "ETag does not match") + return + } + } else { + // Since the GET returns an ETag, and the auto-patch feature should always + // use it when available, we can fail the test if we ever get here. + t.Fatal("No If-Match header set during PUT") + } + db[input.ThingID] = &input.Body + ctx.Header().Set("ETag", db[input.ThingID].ETag()) + ctx.WriteModel(http.StatusOK, db[input.ThingID]) + }) + + // Merge Patch Test + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPatch, "/things/test", strings.NewReader(`{"price": 1.23}`)) + req.Header.Set("Content-Type", "application/merge-patch+json") + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code, w.Body.String()) + assert.Equal(t, "test1.23[US123 EU456][]", w.Result().Header.Get("ETag")) + + // Same change results in a 304 (patches are idempotent) + w = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPatch, "/things/test", strings.NewReader(`{"price": 1.23}`)) + req.Header.Set("Content-Type", "application/merge-patch+json") + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotModified, w.Code, w.Body.String()) + + // New change but with wrong manual ETag, should fail!
+ w = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPatch, "/things/test", strings.NewReader(`{"price": 4.56}`)) + req.Header.Set("Content-Type", "application/merge-patch+json") + req.Header.Set("If-Match", "abc123") + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusPreconditionFailed, w.Code, w.Body.String()) + + // Correct manual ETag should pass! + w = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPatch, "/things/test", strings.NewReader(`{"price": 4.56}`)) + req.Header.Set("Content-Type", "application/merge-patch+json") + req.Header.Set("If-Match", "test1.23[US123 EU456][]") + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code, w.Body.String()) + assert.Equal(t, "test4.56[US123 EU456][]", w.Result().Header.Get("ETag")) + + // Merge Patch: invalid + w = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPatch, "/things/test", strings.NewReader(`{`)) + req.Header.Set("Content-Type", "application/merge-patch+json") + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusUnprocessableEntity, w.Code, w.Body.String()) + + // JSON Patch Test + w = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPatch, "/things/test", strings.NewReader(`[ + {"op": "add", "path": "/tags", "value": ["b"]}, + {"op": "add", "path": "/tags/0", "value": "a"} + ]`)) + req.Header.Set("Content-Type", "application/json-patch+json") + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code, w.Body.String()) + assert.Equal(t, "test4.56[US123 EU456][a b]", w.Result().Header.Get("ETag")) + + // JSON Patch: bad JSON + w = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPatch, "/things/test", strings.NewReader(`[`)) + req.Header.Set("Content-Type", "application/json-patch+json") + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusUnprocessableEntity, w.Code, w.Body.String()) + + // JSON Patch: invalid patch + w = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPatch, "/things/test",
strings.NewReader(`[{"op": "unsupported"}]`)) + req.Header.Set("Content-Type", "application/json-patch+json") + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusUnprocessableEntity, w.Code, w.Body.String()) + + // Bad content type + w = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPatch, "/things/test", strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/unsupported-content-type") + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusUnsupportedMediaType, w.Code, w.Body.String()) +} diff --git a/resolver.go b/resolver.go index 9f81e65f..0cc3d2b1 100644 --- a/resolver.go +++ b/resolver.go @@ -170,7 +170,7 @@ func parseParamValue(ctx Context, location string, name string, typ reflect.Type return pv } -func setFields(ctx *hcontext, req *http.Request, input reflect.Value, t reflect.Type) { +func setFields(ctx *hcontext, req *http.Request, input reflect.Value, t reflect.Type, ct string, reqDef *request) { if t.Kind() == reflect.Ptr { t = t.Elem() } @@ -189,11 +189,11 @@ func setFields(ctx *hcontext, req *http.Request, input reflect.Value, t reflect. if f.Anonymous { // Embedded struct - setFields(ctx, req, inField, f.Type) + setFields(ctx, req, inField, f.Type, ct, reqDef) continue } - if _, ok := f.Tag.Lookup(locationBody); ok || f.Name == strings.Title(locationBody) { + if fct, ok := f.Tag.Lookup(locationBody); (ok && (fct == ct || ((fct == "" || fct == "true") && ct == "application/json"))) || (!ok && f.Name == strings.Title(locationBody)) { // Special case: body field is a reader for streaming if f.Type == readerType { inField.Set(reflect.ValueOf(req.Body)) @@ -237,8 +237,8 @@ func setFields(ctx *hcontext, req *http.Request, input reflect.Value, t reflect. continue }
- if ctx.op.requestSchema != nil && ctx.op.requestSchema.HasValidation() { - if !validAgainstSchema(ctx, locationBody+".", ctx.op.requestSchema, data) { + if reqDef != nil && reqDef.schema != nil && reqDef.schema.HasValidation() { + if !validAgainstSchema(ctx, locationBody+".", reqDef.schema, data) { continue } } diff --git a/router.go b/router.go index 9ee76177..3295bab9 100644 --- a/router.go +++ b/router.go @@ -71,6 +71,9 @@ type Router struct { // Information for creating non-relative links & schema refs. urlPrefix string disableSchemaProperty bool + + // Turn off auto-generation of HTTP PATCH operations + disableAutoPatch bool } // OpenAPI returns an OpenAPI 3 representation of the API, which can be @@ -323,6 +326,11 @@ func replaceRef(schema map[string]interface{}, from, to string) { // Set up the docs & OpenAPI routes. func (r *Router) setupDocs() { + if !r.disableAutoPatch { + // Generate PATCH methods before generating the OpenAPI or docs. + r.AutoPatch() + } + // Precompute the OpenAPI document once on startup and then serve the cached // version of it. spec := r.OpenAPI() @@ -474,6 +482,12 @@ func (r *Router) DisableSchemaProperty() { r.disableSchemaProperty = true } +// DisableAutoPatch disables the automatic generation of HTTP PATCH operations +// whenever a GET/PUT pair exists without a pre-existing PATCH.
+func (r *Router) DisableAutoPatch() { + r.disableAutoPatch = true +} + const ( DefaultDocsSuffix = "docs" DefaultSchemasSuffix = "schemas" diff --git a/router_test.go b/router_test.go index 642cf2b1..521ec785 100644 --- a/router_test.go +++ b/router_test.go @@ -546,3 +546,62 @@ func TestRoundTrip(t *testing.T) { assert.Equal(t, http.StatusOK, w.Result().StatusCode) } + +func TestRequestContentTypes(t *testing.T) { + app := newTestRouter() + app.DisableSchemaProperty() + resource := app.Resource("/") + + type ThingV1 struct { + First string `json:"first"` + Last string `json:"last"` + } + + type ThingV2 struct { + Name string `json:"name"` + } + + resource.Post("post-root", "docs", + NewResponse(http.StatusOK, "").Model(&ThingV2{}), + ).Run(func(ctx Context, input struct { + BodyV2 *ThingV2 `body:"application/thingv2+json"` + BodyV1 *ThingV1 `body:"application/thingv1+json"` + }) { + var thing *ThingV2 + if input.BodyV2 != nil { + thing = input.BodyV2 + } else if input.BodyV1 != nil { + thing = &ThingV2{ + Name: input.BodyV1.First + " " + input.BodyV1.Last, + } + } + ctx.WriteModel(http.StatusOK, thing) + }) + + // Version 1 is selected via its explicit Content-Type and converted to V2. + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"first": "one", "last": "two"}`)) + req.Header.Set("Content-Type", "application/thingv1+json") + app.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code, w.Body.String()) + assert.JSONEq(t, `{"name": "one two"}`, w.Body.String()) + + // Version 2 is selected via its explicit Content-Type and returned as-is. + w = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"name": "one two"}`)) + req.Header.Set("Content-Type", "application/thingv2+json") + app.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code, w.Body.String()) + assert.JSONEq(t, `{"name": "one two"}`, w.Body.String()) + + // Version 2 (implicit, missing content type) should work by selecting the + // first defined body (ThingV2).
+ w = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"name": "one two"}`)) + app.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code, w.Body.String()) + assert.JSONEq(t, `{"name": "one two"}`, w.Body.String()) +}