diff --git a/com.go b/com.go
index 86ade70..934bc98 100644
--- a/com.go
+++ b/com.go
@@ -12,6 +12,7 @@ package commandeer
 
 
 import (
+	"errors"
 	"flag"
 	"fmt"
 	"net"
@@ -42,18 +43,43 @@ import (
 // field. It should be a single ascii character. This will only be used if the
 // Flagger is also a PFlagger.
 func Flags(flags Flagger, main interface{}) error {
+	// Flags cannot hand back the cleanup function, so values parsed into map
+	// entries are never copied back; use FlagsMap when main contains maps.
+	_, err := FlagsMap(flags, main)
+	return err
+}
+
+// FlagsMap is like Flags, but can also handle maps. Since map values aren't
+// addressable, flags for map entries are bound to temporary copies, and a
+// follow-up pass is needed to copy any parsed values back into the map. The
+// returned cleanup function performs that pass; call it after the flag set
+// has been parsed.
+func FlagsMap(flags Flagger, main interface{}) (cleanup func(), err error) {
 	typ := reflect.TypeOf(main)
-	if typ.Kind() != reflect.Ptr {
-		return fmt.Errorf("value must be pointer to struct, but is %s", typ.Kind())
-	}
+	var mainVal reflect.Value
+	var mainTyp reflect.Type
+	tracker := newFlagTracker(flags)
 
-	mainVal := reflect.ValueOf(main).Elem()
-	mainTyp := mainVal.Type()
-	if mainTyp.Kind() != reflect.Struct {
-		return fmt.Errorf("value must be pointer to struct, but is pointer to %s", typ.Kind())
+	switch typ.Kind() {
+	case reflect.Map:
+		// a pointer isn't required for maps: map values are references, so
+		// cleanup can write entries back through main directly.
+		return tracker.cleanup, setMapFlags(tracker, main, "")
+	case reflect.Ptr:
+		mainVal = reflect.ValueOf(main).Elem()
+		mainTyp = mainVal.Type()
+	default:
+		return nil, fmt.Errorf("value must be map or pointer to struct, but is %s", typ.Kind())
 	}
 
-	return setFlags(newFlagTracker(flags), main, "")
+	switch mainTyp.Kind() {
+	case reflect.Struct:
+		return tracker.cleanup, setStructFlags(tracker, main, "")
+	case reflect.Map:
+		// main points to a map; register the entries of the map itself.
+		return tracker.cleanup, setMapFlags(tracker, mainVal.Interface(), "")
+	}
+	return nil, fmt.Errorf("value must be pointer to struct, but is pointer to %s", typ.Kind())
 }
 
 // Run runs "main" which must be a pointer to a struct which implements the
@@ -66,7 +92,7 @@ func Run(main interface{}) error {
 // RunArgs is similar to Run, but the caller must specify their own flag set and
 // args to be parsed by that flag set.
 func RunArgs(flags Flagger, main interface{}, args []string) error {
-	err := Flags(flags, main)
+	cleanup, err := FlagsMap(flags, main)
 	if err != nil {
 		return fmt.Errorf("calling Flags: %v", err)
 	}
@@ -74,14 +100,161 @@ func RunArgs(flags Flagger, main interface{}, args []string) error {
 	if err != nil {
 		return fmt.Errorf("parsing flags: %v", err)
 	}
+	if cleanup != nil {
+		cleanup()
+	}
 
 	if main, ok := main.(Runner); ok {
 		return main.Run()
 	}
-	return fmt.Errorf("called 'Run' with something which doesn't implement the 'Run() error' method.")
+	return fmt.Errorf("called 'Run' with something which doesn't implement the 'Run() error' method")
 }
 
-func setFlags(flags *flagTracker, main interface{}, prefix string) error {
+// setFlag registers a flag named flagName that is bound to f, which must be
+// addressable. typ is f's type; ft is the struct field f came from (nil for
+// map entries) and is only consulted for help text.
+func setFlag(flags *flagTracker, f reflect.Value, typ reflect.Type, ft *reflect.StructField, flagName string, shorthand string) error {
+	// first check supported concrete types
+	switch f.Interface().(type) {
+	case time.Duration:
+		p := f.Addr().Interface().(*time.Duration)
+		flags.duration(p, flagName, shorthand, time.Duration(f.Int()), flagHelp(ft))
+		return nil
+	case net.IPMask:
+		if !flags.pflag {
+			return fmt.Errorf("cannot support net.IPMask field at '%v' with stdlib flag pkg", flagName)
+		}
+		p := f.Addr().Interface().(*net.IPMask)
+		flags.ipMask(p, flagName, shorthand, *p, flagHelp(ft))
+		return nil
+	case net.IPNet:
+		if !flags.pflag {
+			return fmt.Errorf("cannot support net.IPNet field at '%v' with stdlib flag pkg", flagName)
+		}
+		p := f.Addr().Interface().(*net.IPNet)
+		flags.ipNet(p, flagName, shorthand, *p, flagHelp(ft))
+		return nil
+	case net.IP:
+		if !flags.pflag {
+			return fmt.Errorf("cannot support net.IP field at '%v' with stdlib flag pkg", flagName)
+		}
+		p := f.Addr().Interface().(*net.IP)
+		flags.ip(p, flagName, shorthand, *p, flagHelp(ft))
+		return nil
+	case []net.IP:
+		if !flags.pflag {
+			return fmt.Errorf("cannot support []net.IP field at '%v' with stdlib flag pkg", flagName)
+		}
+		p := f.Addr().Interface().(*[]net.IP)
+		flags.ipSlice(p, flagName, shorthand, *p, flagHelp(ft))
+		return nil
+	}
+
+	// now check basic kinds
+	switch typ.Kind() {
+	case reflect.String:
+		p := f.Addr().Interface().(*string)
+		flags.string(p, flagName, shorthand, f.String(), flagHelp(ft))
+	case reflect.Bool:
+		p := f.Addr().Interface().(*bool)
+		flags.bool(p, flagName, shorthand, f.Bool(), flagHelp(ft))
+	case reflect.Int:
+		p := f.Addr().Interface().(*int)
+		val := int(f.Int())
+		flags.int(p, flagName, shorthand, val, flagHelp(ft))
+	case reflect.Int64:
+		p := f.Addr().Interface().(*int64)
+		flags.int64(p, flagName, shorthand, f.Int(), flagHelp(ft))
+	case reflect.Float64:
+		p := f.Addr().Interface().(*float64)
+		flags.float64(p, flagName, shorthand, f.Float(), flagHelp(ft))
+	case reflect.Uint:
+		p := f.Addr().Interface().(*uint)
+		val := uint(f.Uint())
+		flags.uint(p, flagName, shorthand, val, flagHelp(ft))
+	case reflect.Uint64:
+		p := f.Addr().Interface().(*uint64)
+		flags.uint64(p, flagName, shorthand, f.Uint(), flagHelp(ft))
+	case reflect.Slice:
+		if !flags.pflag {
+			return fmt.Errorf("cannot support slice field at '%v' with stdlib flag pkg", flagName)
+		}
+		switch typ.Elem().Kind() { // ft is nil for map entries, so use typ
+		case reflect.String:
+			p := f.Addr().Interface().(*[]string)
+			flags.stringSlice(p, flagName, shorthand, *p, flagHelp(ft))
+		case reflect.Bool:
+			p := f.Addr().Interface().(*[]bool)
+			flags.boolSlice(p, flagName, shorthand, *p, flagHelp(ft))
+		case reflect.Int:
+			p := f.Addr().Interface().(*[]int)
+			flags.intSlice(p, flagName, shorthand, *p, flagHelp(ft))
+		case reflect.Uint:
+			p := f.Addr().Interface().(*[]uint)
+			flags.uintSlice(p, flagName, shorthand, *p, flagHelp(ft))
+		default:
+			return fmt.Errorf("encountered unsupported slice type/kind: %#v at %s", f, flagName)
+		}
+	case reflect.Float32:
+		if !flags.pflag {
+			return fmt.Errorf("cannot support float32 field at '%v' with stdlib flag pkg", flagName)
+		}
+		p := f.Addr().Interface().(*float32)
+		flags.float32(p, flagName, shorthand, *p, flagHelp(ft))
+	case reflect.Int16:
+		if !flags.pflag {
+			return fmt.Errorf("cannot support int16 field at '%v' with stdlib flag pkg", flagName)
+		}
+		p := f.Addr().Interface().(*int16)
+		flags.int16(p, flagName, shorthand, *p, flagHelp(ft))
+	case reflect.Int32:
+		if !flags.pflag {
+			return fmt.Errorf("cannot support int32 field at '%v' with stdlib flag pkg", flagName)
+		}
+		p := f.Addr().Interface().(*int32)
+		flags.int32(p, flagName, shorthand, *p, flagHelp(ft))
+	case reflect.Uint16:
+		if !flags.pflag {
+			return fmt.Errorf("cannot support uint16 field at '%v' with stdlib flag pkg", flagName)
+		}
+		p := f.Addr().Interface().(*uint16)
+		flags.uint16(p, flagName, shorthand, *p, flagHelp(ft))
+	case reflect.Uint32:
+		if !flags.pflag {
+			return fmt.Errorf("cannot support uint32 field at '%v' with stdlib flag pkg", flagName)
+		}
+		p := f.Addr().Interface().(*uint32)
+		flags.uint32(p, flagName, shorthand, *p, flagHelp(ft))
+	case reflect.Uint8:
+		if !flags.pflag {
+			return fmt.Errorf("cannot support uint8 field at '%v' with stdlib flag pkg", flagName)
+		}
+		p := f.Addr().Interface().(*uint8)
+		flags.uint8(p, flagName, shorthand, *p, flagHelp(ft))
+	case reflect.Int8:
+		if !flags.pflag {
+			return fmt.Errorf("cannot support int8 field at '%v' with stdlib flag pkg", flagName)
+		}
+		p := f.Addr().Interface().(*int8)
+		flags.int8(p, flagName, shorthand, *p, flagHelp(ft))
+	case reflect.Struct:
+		err := setStructFlags(flags, f.Addr().Interface(), flagName)
+		if err != nil {
+			return err
+		}
+	case reflect.Map:
+		err := setMapFlags(flags, f.Interface(), flagName)
+		if err != nil {
+			return err
+		}
+	default:
+		return fmt.Errorf("encountered unsupported field type/kind: %#v at %s", f, flagName)
+	}
+	return nil
+}
+
+// setStructFlags registers flags for the exported fields of the struct main points to.
+func setStructFlags(flags *flagTracker, main interface{}, prefix string) error {
 	// TODO add tracking of flag names to ensure no duplicates
 	mainVal := reflect.ValueOf(main).Elem()
 	mainTyp := mainVal.Type()
@@ -92,7 +265,7 @@ func setFlags(flags *flagTracker, main interface{}, prefix string) error {
 		if ft.PkgPath != "" {
 			continue // this field is unexported
 		}
-		flagName := flagName(ft)
+		flagName := structFlagName(ft)
 		if flagName == "-" || flagName == "" {
 			continue // explicitly ignored
 		}
@@ -103,152 +276,77 @@ func setFlags(flags *flagTracker, main interface{}, prefix string) error {
 		if prefix != "" {
 			flagName = prefix + "." + flagName
 		}
+		err = setFlag(flags, f, ft.Type, &ft, flagName, shorthand)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
 
-		// first check supported concrete types
-		switch f.Interface().(type) {
-		case time.Duration:
-			p := f.Addr().Interface().(*time.Duration)
-			flags.duration(p, flagName, shorthand, time.Duration(f.Int()), flagHelp(ft))
-			continue
-		case net.IPMask:
-			if !flags.pflag {
-				return fmt.Errorf("cannot support net.IPMask field at '%v' with stdlib flag pkg.", flagName)
-			}
-			p := f.Addr().Interface().(*net.IPMask)
-			flags.ipMask(p, flagName, shorthand, *p, flagHelp(ft))
-			continue
-		case net.IPNet:
-			if !flags.pflag {
-				return fmt.Errorf("cannot support net.IPNet field at '%v' with stdlib flag pkg.", flagName)
-			}
-			p := f.Addr().Interface().(*net.IPNet)
-			flags.ipNet(p, flagName, shorthand, *p, flagHelp(ft))
-			continue
-		case net.IP:
-			if !flags.pflag {
-				return fmt.Errorf("cannot support net.IP field at '%v' with stdlib flag pkg.", flagName)
-			}
-			p := f.Addr().Interface().(*net.IP)
-			flags.ip(p, flagName, shorthand, *p, flagHelp(ft))
-			continue
-		case []net.IP:
-			if !flags.pflag {
-				return fmt.Errorf("cannot support []net.IP field at '%v' with stdlib flag pkg.", flagName)
-			}
-			p := f.Addr().Interface().(*[]net.IP)
-			flags.ipSlice(p, flagName, shorthand, *p, flagHelp(ft))
-			continue
+// stringerType is the reflect.Type of the fmt.Stringer interface.
+var stringerType = reflect.TypeOf((*fmt.Stringer)(nil)).Elem()
+
+// stringifyStringValue renders a map key whose kind is reflect.String.
+func stringifyStringValue(v reflect.Value) string {
+	return v.String()
+}
+
+// stringifyStringerValue renders a map key through its fmt.Stringer
+// implementation; setMapFlags only selects it for Stringer key types.
+func stringifyStringerValue(v reflect.Value) string {
+	if s, ok := v.Interface().(fmt.Stringer); ok {
+		return s.String()
+	}
+	return ""
+}
+
+// setMapFlags registers one flag per entry of the map main, named after the
+// stringified key and joined to prefix with a dot when prefix is non-empty.
+func setMapFlags(flags *flagTracker, main interface{}, prefix string) error {
+	mainVal := reflect.ValueOf(main)
+	mainTyp := mainVal.Type()
+	keyTyp := mainTyp.Key()
+	var stringify func(reflect.Value) string
+	switch {
+	case keyTyp.Kind() == reflect.String:
+		stringify = stringifyStringValue
+	case keyTyp.Implements(stringerType):
+		stringify = stringifyStringerValue
+	default:
+		if prefix != "" {
+			return fmt.Errorf("map keys must be strings or implement fmt.Stringer at %s", prefix)
 		}
+		return errors.New("map keys must be strings or implement fmt.Stringer")
+	}
 
-		// now check basic kinds
-		switch ft.Type.Kind() {
-		case reflect.String:
-			p := f.Addr().Interface().(*string)
-			flags.string(p, flagName, shorthand, f.String(), flagHelp(ft))
-		case reflect.Bool:
-			p := f.Addr().Interface().(*bool)
-			flags.bool(p, flagName, shorthand, f.Bool(), flagHelp(ft))
-		case reflect.Int:
-			p := f.Addr().Interface().(*int)
-			val := int(f.Int())
-			flags.int(p, flagName, shorthand, val, flagHelp(ft))
-		case reflect.Int64:
-			p := f.Addr().Interface().(*int64)
-			flags.int64(p, flagName, shorthand, f.Int(), flagHelp(ft))
-		case reflect.Float64:
-			p := f.Addr().Interface().(*float64)
-			flags.float64(p, flagName, shorthand, f.Float(), flagHelp(ft))
-		case reflect.Uint:
-			p := f.Addr().Interface().(*uint)
-			val := uint(f.Uint())
-			flags.uint(p, flagName, shorthand, val, flagHelp(ft))
-		case reflect.Uint64:
-			p := f.Addr().Interface().(*uint64)
-			flags.uint64(p, flagName, shorthand, f.Uint(), flagHelp(ft))
-		case reflect.Slice:
-			if !flags.pflag {
-				return fmt.Errorf("cannot support slice field at '%v' with stdlib flag pkg.", flagName)
-			}
-			switch ft.Type.Elem().Kind() {
-			case reflect.String:
-				p := f.Addr().Interface().(*[]string)
-				flags.stringSlice(p, flagName, shorthand, *p, flagHelp(ft))
-			case reflect.Bool:
-				p := f.Addr().Interface().(*[]bool)
-				flags.boolSlice(p, flagName, shorthand, *p, flagHelp(ft))
-			case reflect.Int:
-				p := f.Addr().Interface().(*[]int)
-				flags.intSlice(p, flagName, shorthand, *p, flagHelp(ft))
-			case reflect.Uint:
-				p := f.Addr().Interface().(*[]uint)
-				flags.uintSlice(p, flagName, shorthand, *p, flagHelp(ft))
-			default:
-				return fmt.Errorf("encountered unsupported slice type/kind: %#v at %s", f, prefix)
-			}
-		case reflect.Float32:
-			if !flags.pflag {
-				return fmt.Errorf("cannot support float32 field at '%v' with stdlib flag pkg.", flagName)
-			}
-			p := f.Addr().Interface().(*float32)
-			flags.float32(p, flagName, shorthand, *p, flagHelp(ft))
-		case reflect.Int16:
-			if !flags.pflag {
-				return fmt.Errorf("cannot support int16 field at '%v' with stdlib flag pkg.", flagName)
-			}
-			p := f.Addr().Interface().(*int16)
-			flags.int16(p, flagName, shorthand, *p, flagHelp(ft))
-		case reflect.Int32:
-			if !flags.pflag {
-				return fmt.Errorf("cannot support int32 field at '%v' with stdlib flag pkg.", flagName)
-			}
-			p := f.Addr().Interface().(*int32)
-			flags.int32(p, flagName, shorthand, *p, flagHelp(ft))
-		case reflect.Uint16:
-			if !flags.pflag {
-				return fmt.Errorf("cannot support uint16 field at '%v' with stdlib flag pkg.", flagName)
-			}
-			p := f.Addr().Interface().(*uint16)
-			flags.uint16(p, flagName, shorthand, *p, flagHelp(ft))
-		case reflect.Uint32:
-			if !flags.pflag {
-				return fmt.Errorf("cannot support uint32 field at '%v' with stdlib flag pkg.", flagName)
-			}
-			p := f.Addr().Interface().(*uint32)
-			flags.uint32(p, flagName, shorthand, *p, flagHelp(ft))
-		case reflect.Uint8:
-			if !flags.pflag {
-				return fmt.Errorf("cannot support uint8 field at '%v' with stdlib flag pkg.", flagName)
-			}
-			p := f.Addr().Interface().(*uint8)
-			flags.uint8(p, flagName, shorthand, *p, flagHelp(ft))
-		case reflect.Int8:
-			if !flags.pflag {
-				return fmt.Errorf("cannot support int8 field at '%v' with stdlib flag pkg.", flagName)
-			}
-			p := f.Addr().Interface().(*int8)
-			flags.int8(p, flagName, shorthand, *p, flagHelp(ft))
-		case reflect.Struct:
-			var newprefix string
-			if prefix != "" {
-				newprefix = prefix + "." + flagName
-			} else {
-				newprefix = flagName
-			}
-			err := setFlags(flags, f.Addr().Interface(), newprefix)
-			if err != nil {
-				return err
-			}
-		default:
-			return fmt.Errorf("encountered unsupported field type/kind: %#v at %s", f, prefix)
+	// flags are bound to copies of the map's element type
+	mapTyp := mainTyp.Elem()
+
+	for _, key := range mainVal.MapKeys() {
+		entry := mapEntry{mapValue: mainVal, mapKey: key}
+		flagName := stringify(key)
+		if prefix != "" {
+			flagName = prefix + "." + flagName
+		}
+		val := mainVal.MapIndex(key)
+		entry.assignable = reflect.New(mapTyp)
+		// seed the copy with the entry's current value so it becomes the flag default
+		entry.assignable.Elem().Set(val)
+		flags.mapped[flagName] = entry
+		err := setFlag(flags, entry.assignable.Elem(), mapTyp, nil, flagName, "")
+
+		if err != nil {
+			return err
 		}
 	}
 	return nil
 }
 
-// flagName finds a field's flag name. It first looks for a "flag" tag, then
+// structFlagName finds a field's flag name. It first looks for a "flag" tag, then
 // tries to use the "json" tag, and final falls back to using the name of the
 // field after running it through "downcaseAndDash".
-func flagName(field reflect.StructField) (flagname string) {
+func structFlagName(field reflect.StructField) (flagname string) {
 	var ok bool
 	if flagname, ok = field.Tag.Lookup("flag"); ok {
 		return flagname
@@ -289,13 +387,25 @@ func downcaseAndDash(input string) string {
 }
 
 // flagHelp gets the help text from a field's tag or returns an empty string.
-func flagHelp(field reflect.StructField) (flaghelp string) {
+func flagHelp(field *reflect.StructField) (flaghelp string) {
+	if field == nil {
+		return ""
+	}
 	if flaghelp, ok := field.Tag.Lookup("help"); ok {
 		return flaghelp
 	}
 	return ""
 }
 
+// mapEntry records a flag bound to a map element. Map elements aren't
+// addressable, so the flag is bound to a freshly allocated copy (assignable
+// points to it), and flagTracker.cleanup stores it back under mapKey.
+type mapEntry struct {
+	mapValue   reflect.Value
+	mapKey     reflect.Value
+	assignable reflect.Value
+}
+
 // flagTracker has methods for managing the set up of flags - it will utilize
 // pflag methods if flagger is a PFlagger, and set up short flags as well.
 type flagTracker struct {
@@ -303,6 +413,14 @@ type flagTracker struct {
 	pflagger PFlagger
 	pflag    bool
 	shorts   map[rune]struct{}
+	mapped   map[string]mapEntry
+}
+
+// cleanup copies every tracked map-entry copy back into its map.
+func (f *flagTracker) cleanup() {
+	for _, v := range f.mapped {
+		v.mapValue.SetMapIndex(v.mapKey, v.assignable.Elem())
+	}
 }
 
 // newFlagTracker sets up a flagTracker based on a flagger.
@@ -312,6 +430,7 @@ func newFlagTracker(flagger Flagger) *flagTracker {
 		shorts: map[rune]struct{}{
 			'h': {}, // "h" is always used for help, so we can't set it.
 		},
+		mapped: make(map[string]mapEntry),
 	}
 	fTr.pflagger, fTr.pflag = flagger.(PFlagger)
 	return fTr
diff --git a/com_test.go b/com_test.go
index 9a1233f..c07c50b 100644
--- a/com_test.go
+++ b/com_test.go
@@ -24,6 +24,22 @@ func TestZeroStruct(t *testing.T) {
 	}
 }
 
+func TestSetMap(t *testing.T) {
+	fs := pflag.NewFlagSet("myset", pflag.ContinueOnError)
+	mm := &test.MyMain{AMap: map[test.Stringable]int{1: 2}, Thing: "Don't Error"}
+	err := RunArgs(fs, mm, []string{"--a-map.1=3"})
+	if err != nil {
+		t.Fatalf("parsing map assignment: %v", err)
+	}
+	if mm.AMap[1] != 3 {
+		t.Fatalf("map assignment didn't work: %#v", mm.AMap)
+	}
+	err = fs.Parse([]string{"-h"})
+	if err != nil && err != pflag.ErrHelp {
+		t.Fatalf("parsing help flag: %v", err)
+	}
+}
+
 func TestNonStruct(t *testing.T) {
 	var a int = 4
 	err := Run(&a)
@@ -127,11 +143,11 @@ func TestRun(t *testing.T) {
 		},
 		{
 			main: &NonRunner{},
-			err:  "called 'Run' with something which doesn't implement the 'Run() error' method.",
+			err:  "called 'Run' with something which doesn't implement the 'Run() error' method",
 		},
 		{
 			main: test.MyMain{},
-			err:  "calling Flags: value must be pointer to struct, but is struct",
+			err:  "calling Flags: value must be map or pointer to struct, but is struct",
 		},
 	}
 	for i, tst := range tests {
@@ -334,6 +350,21 @@ func TestRunMyMain(t *testing.T) {
 		t.Fatalf("couldn't lookup 'subthing.a-bool'")
 	}
 
+	if f := flags.Lookup("subthing.sub.b-bool"); f != nil {
+		if f.DefValue != "true" {
+			t.Fatalf("'subthing.sub.b-bool' not defined properly, got '%v'", f.DefValue)
+		}
+	} else {
+		t.Fatalf("couldn't lookup 'subthing.sub.b-bool'")
+	}
+
+	if f := flags.Lookup("a-map.1"); f != nil {
+		if f.DefValue != "1" {
+			t.Fatalf("'a-map.1' not defined properly, got '%v'", f.DefValue)
+		}
+	} else {
+		t.Fatalf("couldn't lookup 'a-map.1'")
+	}
 }
 
 func TestRunSimpleMain(t *testing.T) {
diff --git a/test/test.go b/test/test.go
index 154837b..3ab4fec 100644
--- a/test/test.go
+++ b/test/test.go
@@ -6,6 +6,13 @@ import (
 	"time"
 )
 
+// Stringable is an int type implementing fmt.Stringer, used as a map key in tests.
+type Stringable int
+
+func (s Stringable) String() string {
+	return fmt.Sprintf("%d", s)
+}
+
 // MyMain defines a variety of different field types and exercises various
 // different tags.
 type MyMain struct {
@@ -38,6 +45,8 @@ type MyMain struct {
 	AIntSlice  []int
 	AUintSlice []uint
 
+	AMap map[Stringable]int
+
 	SubThing SubThing `flag:"subthing"`
 }
 
@@ -69,20 +78,37 @@ func NewMyMain() *MyMain {
 		AIntSlice:  []int{9, -8, 7},
 		AUintSlice: []uint{7, 8, 9},
 
+		AMap: map[Stringable]int{
+			1: 1,
+			2: 2,
+		},
+
 		SubThing: SubThing{
 			SubBool: true,
+			Sub: SubberThing{
+				SubBool2: true,
+			},
 		},
 	}
 }
 
 // SubThing exists to test nested structs.
 type SubThing struct {
-	SubBool bool `flag:"a-bool" help:"nested boolean flag"`
+	SubBool bool        `flag:"a-bool" help:"nested boolean flag"`
+	Sub     SubberThing `help:"further nested struct"`
+}
+
+// SubberThing exists to test even more nested structs.
+type SubberThing struct {
+	SubBool2 bool `flag:"b-bool" help:"more nested boolean flag"`
 }
 
 // Run implements the Runner interface.
 func (m *MyMain) Run() error {
-	return fmt.Errorf("mymain error")
+	if m == nil || m.Thing != "Don't Error" {
+		return fmt.Errorf("mymain error")
+	}
+	return nil
 }
 
 type SimpleMain struct {