diff --git a/humacli/humacli.go b/humacli/humacli.go
index 2e648710..2ffa5a8f 100644
--- a/humacli/humacli.go
+++ b/humacli/humacli.go
@@ -14,6 +14,7 @@ import (
 
 	"github.com/danielgtaylor/huma/v2/casing"
 	"github.com/spf13/cobra"
+	"github.com/spf13/pflag"
 )
 
 func deref(t reflect.Type) reflect.Type {
@@ -89,6 +90,127 @@ type cli[Options any] struct {
 	stop     func()
 }
 
+// getStringValue resolves a string option with the precedence
+// CLI flag > environment variable > flag default.
+//
+// envValue and hasEnv are the result of looking up the option's environment
+// variable; envValue is only returned when the flag was not set on the
+// command line and hasEnv is true.
+func getStringValue(flags *pflag.FlagSet, flagName, envValue string, hasEnv bool) string {
+	if !flags.Changed(flagName) && hasEnv {
+		// No CLI arg, but the environment variable is set.
+		return envValue
+	}
+	// CLI arg or flag default: the flag set holds the right value in both
+	// cases. The lookup error (unknown flag or wrong flag type) is ignored.
+	value, _ := flags.GetString(flagName)
+	return value
+}
+
+// getIntValue resolves an integer-like option with the precedence
+// CLI flag > environment variable > flag default.
+//
+// When isDuration is true the flag is read as a time.Duration and the
+// environment value is parsed with time.ParseDuration; otherwise the flag is
+// read as an int64 and the environment value is parsed as a base-10 int64.
+// An environment value that fails to parse is ignored in favor of the flag
+// default. The result is a time.Duration or an int64.
+func getIntValue(flags *pflag.FlagSet, flagName, envValue string, hasEnv bool, isDuration bool) any {
+	if !flags.Changed(flagName) && hasEnv {
+		// No CLI arg, but the environment variable is set.
+		if isDuration {
+			if value, err := time.ParseDuration(envValue); err == nil {
+				return value
+			}
+		} else if value, err := strconv.ParseInt(envValue, 10, 64); err == nil {
+			return value
+		}
+		// Parsing failed: fall through to the flag default.
+	}
+
+	// CLI arg or flag default. Lookup errors are ignored.
+	if isDuration {
+		value, _ := flags.GetDuration(flagName)
+		return value
+	}
+	value, _ := flags.GetInt64(flagName)
+	return value
+}
+
+// getBoolValue resolves a boolean option with the precedence
+// CLI flag > environment variable > flag default. An environment value that
+// strconv.ParseBool rejects is ignored in favor of the flag default.
+func getBoolValue(flags *pflag.FlagSet, flagName, envValue string, hasEnv bool) bool {
+	if !flags.Changed(flagName) && hasEnv {
+		// No CLI arg, but the environment variable is set.
+		if value, err := strconv.ParseBool(envValue); err == nil {
+			return value
+		}
+		// Parsing failed: fall through to the flag default.
+	}
+	// CLI arg or flag default. Lookup errors are ignored.
+	value, _ := flags.GetBool(flagName)
+	return value
+}
+
+// getEnvName converts a flag name such as "a-ptr.one" into the name of the
+// environment variable that can supply its value. The derivation matches the
+// one setupOptions uses for defaults so both always consult the same variable.
+func getEnvName(flagName string) string {
+	return "SERVICE_" + casing.Snake(strings.ReplaceAll(flagName, ".", "_"), strings.ToUpper)
+}
+
+// getValueFromType resolves the value of the option named flagName for a
+// field of type fieldType, which may be a pointer to a supported type.
+//
+// It returns the resolved value and true; integer-like values are converted
+// to the dereferenced field type. For kinds other than string, int, int64 and
+// bool it returns (nil, false).
+func getValueFromType(flags *pflag.FlagSet, flagName string, fieldType reflect.Type) (any, bool) {
+	envValue, hasEnv := os.LookupEnv(getEnvName(flagName))
+
+	// Work on the dereferenced type so pointer fields (e.g. *time.Duration)
+	// resolve exactly like their value counterparts.
+	baseType := deref(fieldType)
+	switch baseType.Kind() {
+	case reflect.String:
+		return getStringValue(flags, flagName, envValue, hasEnv), true
+	case reflect.Int, reflect.Int64:
+		// registerOption picks the flag type from the dereferenced type, so
+		// the duration check must use it too.
+		isDuration := baseType == durationType
+		rawValue := getIntValue(flags, flagName, envValue, hasEnv, isDuration)
+		return reflect.ValueOf(rawValue).Convert(baseType).Interface(), true
+	case reflect.Bool:
+		return getBoolValue(flags, flagName, envValue, hasEnv), true
+	default:
+		return nil, false
+	}
+}
+
+// getValueFromFlagOrEnv returns the reflect.Value to store in the field for
+// opt, resolved with the precedence CLI flag > environment variable > default.
+// The value is converted to the field's (dereferenced) type and, for pointer
+// fields, wrapped in a newly allocated pointer. It panics for unsupported types.
+func getValueFromFlagOrEnv(flags *pflag.FlagSet, opt option, fieldType reflect.Type) reflect.Value {
+	value, ok := getValueFromType(flags, opt.name, fieldType)
+	if !ok {
+		// This shouldn't happen if setupOptions validates types properly
+		panic(fmt.Sprintf("unsupported type for option %s: %s", opt.name, fieldType.String()))
+	}
+
+	// Convert so that named types (e.g. `type Level string`) can be assigned
+	// to the field; this is a no-op when the types already match.
+	fv := reflect.ValueOf(value).Convert(deref(fieldType))
+	if fieldType.Kind() == reflect.Ptr {
+		ptr := reflect.New(fv.Type())
+		ptr.Elem().Set(fv)
+		fv = ptr
+	}
+
+	return fv
+}
+
 func (c *cli[Options]) Run() {
 	var o Options
 
@@ -100,32 +222,19 @@ func (c *cli[Options]) Run() {
 	for _, opt := range c.optInfo {
 		f := v
 		for _, i := range opt.path {
-			f = f.Field(i)
-		}
-		var fv reflect.Value
-		switch deref(opt.typ).Kind() {
-		case reflect.String:
-			s, _ := flags.GetString(opt.name)
-			fv = reflect.ValueOf(s)
-		case reflect.Int, reflect.Int64:
-			var i any
-			if opt.typ == durationType {
-				i, _ = flags.GetDuration(opt.name)
-			} else {
-				i, _ = flags.GetInt64(opt.name)
+			// A nested struct may sit behind a pointer: step through it before calling Field
+			if f.Kind() == reflect.Ptr {
+				// Allocate nil pointers so the nested field is addressable
+				if f.IsNil() {
+					f.Set(reflect.New(f.Type().Elem()))
+				}
+				f = f.Elem()
 			}
-			fv = reflect.ValueOf(i).Convert(deref(opt.typ))
-		case reflect.Bool:
-			b, _ := flags.GetBool(opt.name)
-			fv = reflect.ValueOf(b)
-		}
-
-		if opt.typ.Kind() == reflect.Ptr {
-			ptr := reflect.New(fv.Type())
-			ptr.Elem().Set(fv)
-			fv = ptr
+			f = f.Field(i)
 		}
 
+		// Resolve the field value: CLI flag > environment variable > default
+		fv := getValueFromFlagOrEnv(flags, opt, opt.typ)
 		f.Set(fv)
 	}
 
@@ -156,8 +265,67 @@ func (c *cli[O]) OnStop(fn func()) {
 	c.stop = fn
 }
 
-func (c *cli[O]) setupOptions(t reflect.Type, path []int) {
-	var err error
+// registerOption registers a single leaf option with the CLI: it records the
+// option metadata in c.optInfo and defines the matching persistent flag,
+// using defaultValue (when non-empty) as the flag's default.
+//
+// Supported kinds are string, int, int64 (including time.Duration) and bool,
+// after dereferencing the field type. It returns an error when defaultValue
+// cannot be parsed for the field's type or when the type is unsupported.
+func (c *cli[O]) registerOption(flags *pflag.FlagSet, field reflect.StructField, currentPath []int, name, defaultValue string) error {
+	fieldType := deref(field.Type)
+
+	// Store option metadata regardless of type
+	c.optInfo = append(c.optInfo, option{name, field.Type, currentPath})
+
+	// Type-specific flag setup and default parsing
+	switch fieldType.Kind() {
+	case reflect.String:
+		flags.StringP(name, field.Tag.Get("short"), defaultValue, field.Tag.Get("doc"))
+	case reflect.Int, reflect.Int64:
+		var def int64
+		if defaultValue != "" {
+			if fieldType == durationType {
+				t, err := time.ParseDuration(defaultValue)
+				if err != nil {
+					return fmt.Errorf("failed to parse duration for field %s: %w", field.Name, err)
+				}
+				def = int64(t)
+			} else {
+				var err error
+				def, err = strconv.ParseInt(defaultValue, 10, 64)
+				if err != nil {
+					return fmt.Errorf("failed to parse int for field %s: %w", field.Name, err)
+				}
+			}
+		}
+		if fieldType == durationType {
+			flags.DurationP(name, field.Tag.Get("short"), time.Duration(def), field.Tag.Get("doc"))
+		} else {
+			flags.Int64P(name, field.Tag.Get("short"), def, field.Tag.Get("doc"))
+		}
+	case reflect.Bool:
+		var def bool
+		if defaultValue != "" {
+			var err error
+			def, err = strconv.ParseBool(defaultValue)
+			if err != nil {
+				return fmt.Errorf("failed to parse bool for field %q: %w", field.Name, err)
+			}
+		}
+		flags.BoolP(name, field.Tag.Get("short"), def, field.Tag.Get("doc"))
+	default:
+		return fmt.Errorf("unsupported option type for field %q: %q", field.Name, field.Type.Kind().String())
+	}
+
+	return nil
+}
+
+// setupOptions walks the fields of struct type t and registers a flag for
+// each supported option. path holds the field indexes leading from the root
+// options struct to t, and prefix is the dotted flag-name prefix applied to
+// nested fields ("" at the top level). It returns the first error encountered.
+func (c *cli[O]) setupOptions(t reflect.Type, path []int, prefix string) error {
 	flags := c.root.PersistentFlags()
 	for i := 0; i < t.NumField(); i++ {
 		field := t.Field(i)
@@ -175,7 +343,9 @@ func (c *cli[O]) setupOptions(t reflect.Type, path []int) {
 		fieldType := deref(field.Type)
 		if field.Anonymous {
 			// Embedded struct. This enables composition from e.g. company defaults.
-			c.setupOptions(fieldType, currentPath)
+			if err := c.setupOptions(fieldType, currentPath, prefix); err != nil {
+				return err
+			}
 			continue
 		}
 
@@ -184,7 +354,13 @@ func (c *cli[O]) setupOptions(t reflect.Type, path []int) {
 			name = casing.Kebab(field.Name)
 		}
 
-		envName := "SERVICE_" + casing.Snake(name, strings.ToUpper)
+		// Apply prefix for nested fields
+		if prefix != "" {
+			name = prefix + "." + name
+		}
+
+		// Convert dotted names to snake case with underscores for env vars
+		envName := "SERVICE_" + casing.Snake(strings.ReplaceAll(name, ".", "_"), strings.ToUpper)
 		defaultValue := field.Tag.Get("default")
 		if v, ok := os.LookupEnv(envName); ok {
 			// Env vars will override the default value, which is used to document
@@ -192,42 +368,28 @@ func (c *cli[O]) setupOptions(t reflect.Type, path []int) {
 			defaultValue = v
 		}
 
-		c.optInfo = append(c.optInfo, option{name, field.Type, currentPath})
 		switch fieldType.Kind() {
-		case reflect.String:
-			flags.StringP(name, field.Tag.Get("short"), defaultValue, field.Tag.Get("doc"))
-		case reflect.Int, reflect.Int64:
-			var def int64
-			if defaultValue != "" {
-				if fieldType == durationType {
-					var t time.Duration
-					t, err = time.ParseDuration(defaultValue)
-					def = int64(t)
-				} else {
-					def, err = strconv.ParseInt(defaultValue, 10, 64)
-				}
-				if err != nil {
-					panic(err)
-				}
+		case reflect.String, reflect.Int, reflect.Int64, reflect.Bool:
+			if err := c.registerOption(flags, field, currentPath, name, defaultValue); err != nil {
+				return fmt.Errorf("failed to register option %q: %w", field.Name, err)
 			}
-			if fieldType == durationType {
-				flags.DurationP(name, field.Tag.Get("short"), time.Duration(def), field.Tag.Get("doc"))
-			} else {
-				flags.Int64P(name, field.Tag.Get("short"), def, field.Tag.Get("doc"))
+		case reflect.Struct:
+			// Nested struct, or pointer to struct since fieldType is already
+			// dereferenced: recurse with the current name as the prefix.
+			if err := c.setupOptions(fieldType, currentPath, name); err != nil {
+				return fmt.Errorf("failed to setup options for field %q: %w", field.Name, err)
 			}
-		case reflect.Bool:
-			var def bool
-			if defaultValue != "" {
-				def, err = strconv.ParseBool(defaultValue)
-				if err != nil {
-					panic(err)
-				}
-			}
-			flags.BoolP(name, field.Tag.Get("short"), def, field.Tag.Get("doc"))
+		case reflect.Ptr:
+			// fieldType has already been passed through deref, so a field that
+			// is still a pointer here (e.g. a pointer to a pointer) cannot be
+			// mapped to a flag.
+			return fmt.Errorf("unsupported option type for field %q: pointer to %q", field.Name, fieldType.Kind().String())
 		default:
-			panic("Unsupported option type: " + field.Type.Kind().String())
+			return fmt.Errorf("unsupported option type for field %q: %q", field.Name, field.Type.Kind().String())
 		}
 	}
+
+	return nil
 }
 
 // New creates a new CLI. The `onParsed` callback is called after the command
@@ -278,7 +440,9 @@ func New[O any](onParsed func(Hooks, *O)) CLI {
 	}
 
 	var o O
-	c.setupOptions(reflect.TypeOf(o), []int{})
+	if err := c.setupOptions(reflect.TypeOf(o), []int{}, ""); err != nil {
+		panic(err)
+	}
 
 	c.root.Run = func(cmd *cobra.Command, args []string) {
 		done := make(chan struct{}, 1)
diff --git a/humacli/humacli_test.go b/humacli/humacli_test.go
index 1f27b318..e8ced9e6 100644
--- a/humacli/humacli_test.go
+++ b/humacli/humacli_test.go
@@ -249,3 +249,75 @@ func TestCLIBadDefaults(t *testing.T) {
 		humacli.New(func(hooks humacli.Hooks, options *OptionsInt) {})
 	})
 }
+
+// TestCLINestedOptions verifies that options in nested structs (by value and
+// by pointer) can be set via dotted flags and via SERVICE_* env vars.
+func TestCLINestedOptions(t *testing.T) {
+	type OptionsA struct {
+		One int `name:"one"`
+	}
+
+	type OptionsB struct {
+		Two     int       `name:"two"`
+		APtr    *OptionsA `name:"a-ptr"`
+		ADirect OptionsA  `name:"a-direct"`
+	}
+
+	t.Run("cli", func(t *testing.T) {
+		cli := humacli.New(func(hooks humacli.Hooks, options *OptionsB) {
+			assert.Equal(t, 1, options.APtr.One)
+			assert.Equal(t, 2, options.ADirect.One)
+			assert.Equal(t, 3, options.Two)
+			hooks.OnStart(func() {})
+		})
+
+		cli.Root().SetArgs([]string{
+			"--a-ptr.one", "1",
+			"--a-direct.one", "2",
+			"--two", "3",
+		})
+		cli.Run()
+	})
+
+	t.Run("env", func(t *testing.T) {
+		cli := humacli.New(func(hooks humacli.Hooks, options *OptionsB) {
+			assert.Equal(t, 4, options.APtr.One)
+			assert.Equal(t, 5, options.ADirect.One)
+			assert.Equal(t, 6, options.Two)
+			hooks.OnStart(func() {})
+		})
+
+		t.Setenv("SERVICE_A_PTR_ONE", "4")
+		t.Setenv("SERVICE_A_DIRECT_ONE", "5")
+		t.Setenv("SERVICE_TWO", "6")
+
+		cli.Root().SetArgs([]string{})
+		cli.Run()
+	})
+}
+
+// TestCLIPriority verifies that a flag overrides an environment variable and
+// that the environment variable is used when no flag is given.
+func TestCLIPriority(t *testing.T) {
+	type Options struct {
+		WithEnv  int `name:"with-env"`
+		WithFlag int `name:"with-flag"`
+		WithBoth int `name:"with-both"`
+	}
+
+	cli := humacli.New(func(hooks humacli.Hooks, options *Options) {
+		assert.Equal(t, 1, options.WithEnv)
+		assert.Equal(t, 20, options.WithFlag)
+		assert.Equal(t, 30, options.WithBoth)
+		hooks.OnStart(func() {})
+	})
+
+	t.Setenv("SERVICE_WITH_ENV", "1")
+	t.Setenv("SERVICE_WITH_BOTH", "3")
+
+	cli.Root().SetArgs([]string{
+		"--with-flag", "20",
+		"--with-both", "30",
+	})
+	cli.Run()
+}