这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 213 additions & 56 deletions humacli/humacli.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/danielgtaylor/huma/v2/casing"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)

func deref(t reflect.Type) reflect.Type {
Expand Down Expand Up @@ -89,6 +90,128 @@ type cli[Options any] struct {
stop func()
}

// getStringValue resolves a string option using the precedence
// CLI flag > environment variable > flag default.
//
// envValue and hasEnv are the result of looking up the option's environment
// variable. The environment value is used only when the flag was not set on
// the command line; otherwise (or when no env var exists) the value stored in
// the flag set is returned. Errors from GetString are discarded.
func getStringValue(flags *pflag.FlagSet, flagName, envValue string, hasEnv bool) string {
	if !flags.Changed(flagName) && hasEnv {
		return envValue
	}
	// Either the CLI set the flag explicitly, or there is no env override and
	// the flag's registered default applies; both live in the flag set.
	value, _ := flags.GetString(flagName)
	return value
}

// getIntValue resolves an integer-like option using the precedence
// CLI flag > environment variable > flag default.
//
// When isDuration is true the result is a time.Duration read via the
// duration flag getter (and env values are parsed with time.ParseDuration);
// otherwise it is an int64 (env values parsed as base-10). An environment
// value that fails to parse is ignored and the flag default is returned.
// Errors from the flag getters are discarded.
func getIntValue(flags *pflag.FlagSet, flagName, envValue string, hasEnv bool, isDuration bool) any {
	// fromFlag reads the value held by the flag set: the CLI argument when one
	// was given, otherwise the registered default.
	fromFlag := func() any {
		if isDuration {
			d, _ := flags.GetDuration(flagName)
			return d
		}
		n, _ := flags.GetInt64(flagName)
		return n
	}

	if flags.Changed(flagName) || !hasEnv {
		return fromFlag()
	}

	// No CLI argument, but an environment variable is present.
	if isDuration {
		if d, err := time.ParseDuration(envValue); err == nil {
			return d
		}
	} else if n, err := strconv.ParseInt(envValue, 10, 64); err == nil {
		return n
	}

	// Unparseable env value: fall back to the flag default.
	return fromFlag()
}

// getBoolValue resolves a boolean option using the precedence
// CLI flag > environment variable > flag default.
//
// The environment value is parsed with strconv.ParseBool and only consulted
// when the flag was not set on the command line; if it does not parse, the
// flag default is used instead. Errors from GetBool are discarded.
func getBoolValue(flags *pflag.FlagSet, flagName, envValue string, hasEnv bool) bool {
	if !flags.Changed(flagName) && hasEnv {
		if parsed, err := strconv.ParseBool(envValue); err == nil {
			return parsed
		}
		// Unparseable env value: fall through to the flag default.
	}
	value, _ := flags.GetBool(flagName)
	return value
}

// getEnvName maps a flag name to the environment variable that can set it.
// Dashes and dots become underscores, the result is upper-cased, and the
// "SERVICE_" prefix is prepended (e.g. "db.max-conns" -> "SERVICE_DB_MAX_CONNS").
//
// NOTE(review): setupOptions builds its env var name with casing.Snake, which
// may not produce the same result for flag names that are not already
// kebab-case (e.g. a camelCase `name` tag). Confirm the two agree.
func getEnvName(flagName string) string {
	separators := strings.NewReplacer("-", "_", ".", "_")
	return strings.ToUpper("SERVICE_" + separators.Replace(flagName))
}

// getValueFromType resolves the current value of option flagName using the
// getter that matches fieldType, applying the precedence
// CLI flag > SERVICE_* environment variable > flag default.
//
// The returned value has the dereferenced field type (e.g. time.Duration for
// a *time.Duration field, or a named string type for a field declared with
// one), so callers can assign it, or a pointer to it, directly to the field.
// The boolean result is false (and the value nil) when the field's kind is
// not string, int, int64 or bool.
func getValueFromType(flags *pflag.FlagSet, flagName string, fieldType reflect.Type) (any, bool) {
	envValue, hasEnv := os.LookupEnv(getEnvName(flagName))

	// Work with the underlying type so pointer fields are handled exactly like
	// their element type, matching how registerOption defines the flag.
	baseType := deref(fieldType)

	var raw any
	switch baseType.Kind() {
	case reflect.String:
		raw = getStringValue(flags, flagName, envValue, hasEnv)
	case reflect.Int, reflect.Int64:
		// Compare the dereferenced type: a *time.Duration field is registered
		// as a duration flag, so it must be read with the duration getter and
		// its env value parsed as a duration, not as an int64.
		raw = getIntValue(flags, flagName, envValue, hasEnv, baseType == durationType)
	case reflect.Bool:
		raw = getBoolValue(flags, flagName, envValue, hasEnv)
	default:
		return nil, false
	}

	// Convert to the (possibly named) base type so the result is assignable to
	// the field; without this, a field like `type Mode string` would receive a
	// plain string and reflect.Value.Set would panic.
	return reflect.ValueOf(raw).Convert(baseType).Interface(), true
}

// getValueFromFlagOrEnv resolves the value for a single option and returns it
// as a reflect.Value ready to be assigned to a field of type fieldType.
//
// Resolution (CLI flag, then environment variable, then flag default) is
// delegated to getValueFromType. When fieldType is a pointer, the resolved
// value is copied into a freshly allocated value and a pointer to it is
// returned.
//
// It panics if getValueFromType reports the type as unsupported.
func getValueFromFlagOrEnv(flags *pflag.FlagSet, opt option, fieldType reflect.Type) reflect.Value {
	resolved, supported := getValueFromType(flags, opt.name, fieldType)
	if !supported {
		// Unsupported kinds are expected to be rejected during option setup,
		// so reaching this indicates a programming error.
		panic(fmt.Sprintf("unsupported type for option %s: %s", opt.name, fieldType.String()))
	}

	val := reflect.ValueOf(resolved)
	if fieldType.Kind() != reflect.Ptr {
		return val
	}

	// Pointer field: box the resolved value in a new allocation.
	boxed := reflect.New(val.Type())
	boxed.Elem().Set(val)
	return boxed
}

func (c *cli[Options]) Run() {
var o Options

Expand All @@ -100,32 +223,19 @@ func (c *cli[Options]) Run() {
for _, opt := range c.optInfo {
f := v
for _, i := range opt.path {
f = f.Field(i)
}
var fv reflect.Value
switch deref(opt.typ).Kind() {
case reflect.String:
s, _ := flags.GetString(opt.name)
fv = reflect.ValueOf(s)
case reflect.Int, reflect.Int64:
var i any
if opt.typ == durationType {
i, _ = flags.GetDuration(opt.name)
} else {
i, _ = flags.GetInt64(opt.name)
// Check if f is a pointer and dereference it before calling Field
if f.Kind() == reflect.Ptr {
// Initialize nil pointers
if f.IsNil() {
f.Set(reflect.New(f.Type().Elem()))
}
f = f.Elem()
}
fv = reflect.ValueOf(i).Convert(deref(opt.typ))
case reflect.Bool:
b, _ := flags.GetBool(opt.name)
fv = reflect.ValueOf(b)
}

if opt.typ.Kind() == reflect.Ptr {
ptr := reflect.New(fv.Type())
ptr.Elem().Set(fv)
fv = ptr
f = f.Field(i)
}

// Get field value from flag or environment variable
fv := getValueFromFlagOrEnv(flags, opt, opt.typ)
f.Set(fv)
}

Expand Down Expand Up @@ -156,8 +266,58 @@ func (c *cli[O]) OnStop(fn func()) {
c.stop = fn
}

func (c *cli[O]) setupOptions(t reflect.Type, path []int) {
var err error
// registerOption defines a persistent flag for a single scalar option field
// and records the metadata Run needs to populate that field later.
//
// flags is the flag set to register on, field is the struct field being
// exposed, currentPath is the field index path from the root options struct,
// name is the fully qualified flag name, and defaultValue is the textual
// default for the flag. The field's "short" and "doc" struct tags supply the
// flag shorthand and usage text.
//
// It returns an error if a flag with the same name is already registered, if
// defaultValue cannot be parsed for the field's type, or if the field's
// (dereferenced) kind is not string, int, int64 or bool. Option metadata is
// recorded only when registration succeeds.
func (c *cli[O]) registerOption(flags *pflag.FlagSet, field reflect.StructField, currentPath []int, name, defaultValue string) error {
	fieldType := deref(field.Type)
	short := field.Tag.Get("short")
	doc := field.Tag.Get("doc")

	// pflag panics when a flag name is redefined; report it as an error that
	// names the offending field instead.
	if flags.Lookup(name) != nil {
		return fmt.Errorf("flag %q for field %q is already defined", name, field.Name)
	}

	switch fieldType.Kind() {
	case reflect.String:
		flags.StringP(name, short, defaultValue, doc)
	case reflect.Int, reflect.Int64:
		if fieldType == durationType {
			var def time.Duration
			if defaultValue != "" {
				var err error
				if def, err = time.ParseDuration(defaultValue); err != nil {
					return fmt.Errorf("failed to parse duration for field %q: %w", field.Name, err)
				}
			}
			flags.DurationP(name, short, def, doc)
		} else {
			var def int64
			if defaultValue != "" {
				var err error
				if def, err = strconv.ParseInt(defaultValue, 10, 64); err != nil {
					return fmt.Errorf("failed to parse int for field %q: %w", field.Name, err)
				}
			}
			flags.Int64P(name, short, def, doc)
		}
	case reflect.Bool:
		var def bool
		if defaultValue != "" {
			var err error
			if def, err = strconv.ParseBool(defaultValue); err != nil {
				return fmt.Errorf("failed to parse bool for field %q: %w", field.Name, err)
			}
		}
		flags.BoolP(name, short, def, doc)
	default:
		return fmt.Errorf("unsupported option type for field %q: %q", field.Name, field.Type.Kind().String())
	}

	// Record metadata only once the flag exists, so optInfo never refers to a
	// flag that was not registered.
	c.optInfo = append(c.optInfo, option{name, field.Type, currentPath})
	return nil
}

func (c *cli[O]) setupOptions(t reflect.Type, path []int, prefix string) error {
flags := c.root.PersistentFlags()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
Expand All @@ -175,7 +335,7 @@ func (c *cli[O]) setupOptions(t reflect.Type, path []int) {
fieldType := deref(field.Type)
if field.Anonymous {
// Embedded struct. This enables composition from e.g. company defaults.
c.setupOptions(fieldType, currentPath)
c.setupOptions(fieldType, currentPath, prefix)
continue
}

Expand All @@ -184,50 +344,45 @@ func (c *cli[O]) setupOptions(t reflect.Type, path []int) {
name = casing.Kebab(field.Name)
}

envName := "SERVICE_" + casing.Snake(name, strings.ToUpper)
// Apply prefix for nested fields
if prefix != "" {
name = prefix + "." + name
}

// Convert dotted names to snake case with underscores for env vars
envName := "SERVICE_" + casing.Snake(strings.ReplaceAll(name, ".", "_"), strings.ToUpper)
defaultValue := field.Tag.Get("default")
if v, ok := os.LookupEnv(envName); ok {
// Env vars will override the default value, which is used to document
// what the value is if no options are passed.
defaultValue = v
}

c.optInfo = append(c.optInfo, option{name, field.Type, currentPath})
switch fieldType.Kind() {
case reflect.String:
flags.StringP(name, field.Tag.Get("short"), defaultValue, field.Tag.Get("doc"))
case reflect.Int, reflect.Int64:
var def int64
if defaultValue != "" {
if fieldType == durationType {
var t time.Duration
t, err = time.ParseDuration(defaultValue)
def = int64(t)
} else {
def, err = strconv.ParseInt(defaultValue, 10, 64)
}
if err != nil {
panic(err)
}
case reflect.String, reflect.Int, reflect.Int64, reflect.Bool:
if err := c.registerOption(flags, field, currentPath, name, defaultValue); err != nil {
return fmt.Errorf("failed to register option %q: %w", field.Name, err)
}
if fieldType == durationType {
flags.DurationP(name, field.Tag.Get("short"), time.Duration(def), field.Tag.Get("doc"))
} else {
flags.Int64P(name, field.Tag.Get("short"), def, field.Tag.Get("doc"))
case reflect.Struct:
// For nested structs, recurse and pass the current name as a prefix
if err := c.setupOptions(fieldType, currentPath, name); err != nil {
return fmt.Errorf("failed to setup options for field %q: %w", field.Name, err)
}
case reflect.Bool:
var def bool
if defaultValue != "" {
def, err = strconv.ParseBool(defaultValue)
if err != nil {
panic(err)
case reflect.Ptr:
// If it's a pointer to a struct, handle it like a struct after dereferencing
if fieldType.Kind() == reflect.Struct {
if err := c.setupOptions(fieldType, currentPath, name); err != nil {
return fmt.Errorf("failed to setup options for field %q: %w", field.Name, err)
}
} else {
return fmt.Errorf("unsupported option type for field %q: pointer to %q", field.Name, fieldType.Kind().String())
}
flags.BoolP(name, field.Tag.Get("short"), def, field.Tag.Get("doc"))
default:
panic("Unsupported option type: " + field.Type.Kind().String())
return fmt.Errorf("unsupported option type for field %q: %q", field.Name, field.Type.Kind().String())
}
}

return nil
}

// New creates a new CLI. The `onParsed` callback is called after the command
Expand Down Expand Up @@ -278,7 +433,9 @@ func New[O any](onParsed func(Hooks, *O)) CLI {
}

var o O
c.setupOptions(reflect.TypeOf(o), []int{})
if err := c.setupOptions(reflect.TypeOf(o), []int{}, ""); err != nil {
panic(err)
}

c.root.Run = func(cmd *cobra.Command, args []string) {
done := make(chan struct{}, 1)
Expand Down
Loading