diff --git a/huma.go b/huma.go index 32e075f8..21d8725f 100644 --- a/huma.go +++ b/huma.go @@ -16,6 +16,7 @@ import ( "io" "net" "net/http" + "net/url" "reflect" "regexp" "slices" @@ -1127,6 +1128,16 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) f.Set(reflect.ValueOf(t)) pv = value break + // Special case: url.URL + } else if f.Type() == urlType { + u, err := url.Parse(value) + if err != nil { + res.Add(pb, value, "invalid url.URL value") + return + } + f.Set(reflect.ValueOf(*u)) + pv = value + break } // Last resort: use the `encoding.TextUnmarshaler` interface. diff --git a/huma_test.go b/huma_test.go index 435df780..8f11487b 100644 --- a/huma_test.go +++ b/huma_test.go @@ -11,6 +11,7 @@ import ( "net/http" "net/http/httptest" "net/http/httputil" + "net/url" "os" "strings" "testing" @@ -352,6 +353,7 @@ func TestFeatures(t *testing.T) { QueryDefault float32 `query:"def" default:"135" example:"5"` QueryBefore time.Time `query:"before"` QueryDate time.Time `query:"date" timeFormat:"2006-01-02"` + QueryURL url.URL `query:"url"` QueryUint uint32 `query:"uint"` QueryBool bool `query:"bool"` QueryStrings []string `query:"strings"` @@ -383,6 +385,7 @@ func TestFeatures(t *testing.T) { assert.InDelta(t, 135, input.QueryDefault, 0) assert.True(t, input.QueryBefore.Equal(time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC))) assert.True(t, input.QueryDate.Equal(time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC))) + assert.Equal(t, url.URL{Scheme: "http", Host: "foo.com", Path: "/bar"}, input.QueryURL) assert.EqualValues(t, 1, input.QueryUint) assert.True(t, input.QueryBool) assert.Equal(t, []string{"foo", "bar"}, input.QueryStrings) @@ -410,10 +413,10 @@ func TestFeatures(t *testing.T) { assert.Equal(t, "Some docs", api.OpenAPI().Paths["/test-params/{string}/{int}/{uuid}"].Get.Parameters[0].Description) // `http.Cookie` should be treated as a string. - assert.Equal(t, "string", api.OpenAPI().Paths["/test-params/{string}/{int}/{uuid}"].Get.Parameters[28].Schema.Type) + assert.Equal(t, "string", api.OpenAPI().Paths["/test-params/{string}/{int}/{uuid}"].Get.Parameters[29].Schema.Type) }, Method: http.MethodGet, - URL: "/test-params/foo/123/fba4f46b-4539-4d19-8e3f-a0e629a243b5?string=bar&int=456&before=2023-01-01T12:00:00Z&date=2023-01-01&uint=1&bool=true&strings=foo,bar&ints=2,3&ints8=4,5&ints16=4,5&ints32=4,5&ints64=4,5&uints=1,2&uints16=10,15&uints32=10,15&uints64=10,15&floats32=2.2,2.3&floats64=3.2,3.3&exploded=foo&exploded=bar", + URL: "/test-params/foo/123/fba4f46b-4539-4d19-8e3f-a0e629a243b5?string=bar&int=456&before=2023-01-01T12:00:00Z&date=2023-01-01&url=http%3A%2F%2Ffoo.com%2Fbar&uint=1&bool=true&strings=foo,bar&ints=2,3&ints8=4,5&ints16=4,5&ints32=4,5&ints64=4,5&uints=1,2&uints16=10,15&uints32=10,15&uints64=10,15&floats32=2.2,2.3&floats64=3.2,3.3&exploded=foo&exploded=bar", Headers: map[string]string{ "string": "baz", "int": "789", @@ -434,6 +437,7 @@ func TestFeatures(t *testing.T) { QueryFloat float32 `query:"float"` QueryBefore time.Time `query:"before"` QueryDate time.Time `query:"date" timeFormat:"2006-01-02"` + QueryURL url.URL `query:"url"` QueryUint uint32 `query:"uint"` QueryBool bool `query:"bool"` QueryInts []int `query:"ints"` @@ -454,7 +458,7 @@ func TestFeatures(t *testing.T) { }) }, Method: http.MethodGet, - URL: "/test-params/bad/not-a-uuid?int=bad&float=bad&before=bad&date=bad&uint=bad&bool=bad&ints=bad&ints8=bad&ints16=bad&ints32=bad&ints64=bad&uints=bad&uints16=bad&uints32=bad&uints64=bad&floats32=bad&floats64=bad", + URL: "/test-params/bad/not-a-uuid?int=bad&float=bad&before=bad&date=bad&url=:&uint=bad&bool=bad&ints=bad&ints8=bad&ints16=bad&ints32=bad&ints64=bad&uints=bad&uints16=bad&uints32=bad&uints64=bad&floats32=bad&floats64=bad", Assert: func(t *testing.T, resp *httptest.ResponseRecorder) { assert.Equal(t, http.StatusUnprocessableEntity, resp.Code) @@ -462,6 +466,7 @@ func TestFeatures(t *testing.T) { assert.Contains(t, resp.Body.String(), "invalid value: invalid UUID length: 10") assert.Contains(t, resp.Body.String(), "invalid float") assert.Contains(t, resp.Body.String(), "invalid date/time") + assert.Contains(t, resp.Body.String(), "invalid url.URL") assert.Contains(t, resp.Body.String(), "invalid bool") assert.Contains(t, resp.Body.String(), "required query parameter is missing") assert.Contains(t, resp.Body.String(), "required header parameter is missing")