diff --git a/registry.go b/registry.go
index 6ee69153..55bb0b8f 100644
--- a/registry.go
+++ b/registry.go
@@ -59,13 +59,28 @@ func DefaultSchemaNamer(t reflect.Type, hint string) string {
 	return name
 }
 
+// MapRegistryOption configures optional behavior of a registry created by
+// NewMapRegistry.
+type MapRegistryOption func(*mapRegistry)
+
+// WithPrimitiveTypeReuse makes the registry treat named types whose
+// underlying kind is a primitive (bool, integer, float or string), such as
+// `type CustomHeader string`, like structs: every usage gets a `$ref` to a
+// single shared schema instead of an inline copy.
+func WithPrimitiveTypeReuse() MapRegistryOption {
+	return func(r *mapRegistry) {
+		r.reusePrimitive = true
+	}
+}
+
 type mapRegistry struct {
-	prefix  string
-	schemas map[string]*Schema
-	types   map[string]reflect.Type
-	seen    map[reflect.Type]bool
-	namer   func(reflect.Type, string) string
-	aliases map[reflect.Type]reflect.Type
+	prefix         string
+	schemas        map[string]*Schema
+	types          map[string]reflect.Type
+	seen           map[reflect.Type]bool
+	namer          func(reflect.Type, string) string
+	aliases        map[reflect.Type]reflect.Type
+	reusePrimitive bool // enabled by WithPrimitiveTypeReuse
 }
 
 func (r *mapRegistry) Schema(t reflect.Type, allowRef bool, hint string) *Schema {
@@ -83,6 +98,21 @@ func (r *mapRegistry) Schema(t reflect.Type, allowRef bool, hint string) *Schema
 	}
 
 	getsRef := t.Kind() == reflect.Struct
+
+	// Special case: when reuse is enabled, named primitive-based types (e.g. a
+	// user-declared `type CustomHeader string`) are referenced via `$ref`
+	// instead of being inlined. Predeclared types such as `string` have an
+	// empty PkgPath and are never affected.
+	if r.reusePrimitive && t.PkgPath() != "" && t.Name() != "" {
+		switch t.Kind() {
+		case reflect.Bool,
+			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
+			reflect.Float32, reflect.Float64, reflect.String:
+			getsRef = true
+		}
+	}
+
 	if t == timeType {
 		// Special case: time.Time is always a string.
 		getsRef = false
@@ -162,13 +192,22 @@ func (r *mapRegistry) RegisterTypeAlias(t reflect.Type, alias reflect.Type) {
 
 // NewMapRegistry creates a new registry that stores schemas in a map and
 // returns references to them using the given prefix.
-func NewMapRegistry(prefix string, namer func(t reflect.Type, hint string) string) Registry {
-	return &mapRegistry{
+//
+// Options such as WithPrimitiveTypeReuse may be passed to enable optional
+// behavior.
+func NewMapRegistry(prefix string, namer func(t reflect.Type, hint string) string, opts ...MapRegistryOption) Registry {
+	r := &mapRegistry{
 		prefix:  prefix,
 		schemas: map[string]*Schema{},
 		types:   map[string]reflect.Type{},
 		seen:    map[reflect.Type]bool{},
 		aliases: map[reflect.Type]reflect.Type{},
 		namer:   namer,
 	}
+
+	for _, opt := range opts {
+		opt(r)
+	}
+
+	return r
 }
diff --git a/registry_test.go b/registry_test.go
index 8fd014fc..297b7b85 100644
--- a/registry_test.go
+++ b/registry_test.go
@@ -73,3 +73,55 @@ func TestSchemaAlias(t *testing.T) {
 	schemaWithString := registry.Schema(reflect.TypeOf(StructWithString{}), false, "")
 	assert.Equal(t, schemaWithString, schemaWithContainer)
 }
+
+func TestReusePrimitiveType(t *testing.T) {
+	type (
+		CustomHeader string
+		CustomTags   []string
+
+		firstRequest struct {
+			Header CustomHeader `json:"header" description:"A custom header"`
+			Tags   CustomTags   `json:"tags"`
+		}
+
+		secondRequest struct {
+			AnotherHeader CustomHeader `json:"another_header" description:"Another custom header"`
+		}
+	)
+
+	// Default settings: named primitive types are inlined.
+	registry := NewMapRegistry("#/components/schemas", DefaultSchemaNamer)
+
+	first := SchemaFromType(registry, reflect.TypeOf(firstRequest{}))
+	second := SchemaFromType(registry, reflect.TypeOf(secondRequest{}))
+
+	if first.Properties["header"].Ref != "" {
+		t.Errorf("Expected header to be defined inline, but got a ref: %s", first.Properties["header"].Ref)
+	}
+	if second.Properties["another_header"].Ref != "" {
+		t.Errorf("Expected another_header to be defined inline, but got a ref: %s", second.Properties["another_header"].Ref)
+	}
+
+	// Reusing primitive types enabled: named primitive types become refs.
+	registry = NewMapRegistry("#/components/schemas", DefaultSchemaNamer, WithPrimitiveTypeReuse())
+
+	first = SchemaFromType(registry, reflect.TypeOf(firstRequest{}))
+	second = SchemaFromType(registry, reflect.TypeOf(secondRequest{}))
+
+	if first.Properties["header"].Ref == "" {
+		t.Errorf("Expected header to use a ref, but it's defined inline")
+	}
+	if second.Properties["another_header"].Ref == "" {
+		t.Errorf("Expected another_header to use a ref, but it's defined inline")
+	}
+
+	if first.Properties["header"].Ref != second.Properties["another_header"].Ref {
+		t.Errorf("Expected both properties to use the same ref, but got %s and %s",
+			first.Properties["header"].Ref, second.Properties["another_header"].Ref)
+	}
+
+	// Named non-primitive types (e.g. slices) are not affected by the option.
+	if first.Properties["tags"].Ref != "" {
+		t.Errorf("Expected tags to be defined inline, but got a ref: %s", first.Properties["tags"].Ref)
+	}
+}