diff --git a/schema.go b/schema.go index 4f664b7b..33bce4a9 100644 --- a/schema.go +++ b/schema.go @@ -658,6 +658,8 @@ type SchemaTransformer interface { // schema := huma.SchemaFromType(registry, reflect.TypeOf(MyType{})) func SchemaFromType(r Registry, t reflect.Type) *Schema { s := schemaFromType(r, t) + t = deref(t) + // Transform generated schema if type implements SchemaTransformer v := reflect.New(t).Interface() if st, ok := v.(SchemaTransformer); ok { diff --git a/schema_test.go b/schema_test.go index 693170cd..e717c4a4 100644 --- a/schema_test.go +++ b/schema_test.go @@ -73,6 +73,17 @@ func (t *TypedArrayWithCustomDesc) TransformSchema(r huma.Registry, s *huma.Sche return s } +var _ huma.SchemaTransformer = (*CustomSchemaPtr)(nil) + +type CustomSchemaPtr struct { + Value string `json:"value"` +} + +func (c *CustomSchemaPtr) TransformSchema(r huma.Registry, s *huma.Schema) *huma.Schema { + s.Description = "custom description" + return s +} + type TypedStringWithCustomLength string func (c TypedStringWithCustomLength) Schema(r huma.Registry) *huma.Schema { @@ -1005,7 +1016,9 @@ func TestSchema(t *testing.T) { "properties":{ "value":{ "format":"int64", - "type": ["integer", "null"] + "type": ["integer", "null"], + "minimum":1, + "maximum":10 } }, "required":["value"], @@ -1058,6 +1071,21 @@ func TestSchema(t *testing.T) { "type":"object" }`, }, + { + name: "schema-transformer-for-ptr", + input: &CustomSchemaPtr{}, + expected: ` { + "additionalProperties":false, + "description":"custom description", + "properties":{ + "value":{ + "type":"string" + } + }, + "required":["value"], + "type":"object" + }`, + }, } for _, c := range cases { @@ -1098,7 +1126,6 @@ func TestSchemaOld(t *testing.T) { r := huma.NewMapRegistry("#/components/schemas/", huma.DefaultSchemaNamer) s := r.Schema(reflect.TypeOf(GreetingInput{}), false, "") - // fmt.Printf("%+v\n", s) assert.Equal(t, "object", s.Type) assert.Len(t, s.Properties, 1) assert.Equal(t, "string", s.Properties["ID"].Type) @@ -1115,9 +1142,6 @@ func TestSchemaOld(t *testing.T) { }, }, &res) assert.Empty(t, res.Errors) - - // b, _ := json.MarshalIndent(r.Map(), "", " ") - // fmt.Println(string(b)) } func TestSchemaGenericNaming(t *testing.T) {