package wand import ( "errors" "fmt" "reflect" "slices" ) func command(name string, impl any, subcmds ...Cmd) Cmd { return Cmd{ name: name, impl: impl, subcommands: subcmds, } } func wrap(impl any) Cmd { cmd, ok := impl.(Cmd) if ok { return cmd } return Command("", impl) } func validateFields(f []field) error { mf := make(map[string]field) for _, fi := range f { if ef, ok := mf[fi.name]; ok && !compatibleTypes(fi.typ, ef.typ) { return fmt.Errorf("duplicate fields with different types: %s", fi.name) } mf[fi.name] = fi } return nil } func validateParameter(t reflect.Type) error { switch t.Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String: return nil case reflect.Pointer, reflect.Slice: t = unpack(t) return validateParameter(t) case reflect.Interface: if t.NumMethod() > 0 { return errors.New("'non-empty' interface parameter") } return nil default: return fmt.Errorf("unsupported parameter type: %v", t) } } func validatePositional(t reflect.Type, min, max int) error { p := positionalParameters(t) ior, iow := ioParameters(p) if len(ior) > 1 || len(iow) > 1 { return errors.New("only zero or one reader and zero or one writer parameters is supported") } for i, pi := range p { if slices.Contains(ior, i) || slices.Contains(iow, i) { continue } if err := validateParameter(pi); err != nil { return err } } last := t.NumIn()-1 lastVariadic := t.IsVariadic() && !isStruct(t.In(last)) && !slices.Contains(ior, last) && !slices.Contains(iow, last) fixedPositional := len(p) - len(ior) - len(iow) if lastVariadic { fixedPositional-- } if min > 0 && min < fixedPositional { return fmt.Errorf( "minimum positional defined as %d but the implementation expects minimum %d fixed parameters", min, fixedPositional, ) } if min > 0 && min > fixedPositional && !lastVariadic { return fmt.Errorf( "minimum positional defined as %d but the implementation has only %d fixed parameters and no variadic parameter", min, fixedPositional, ) } if max > 0 && max < fixedPositional { return fmt.Errorf( "maximum positional defined as %d but the implementation expects minimum %d fixed parameters", max, fixedPositional, ) } if min > 0 && max > 0 && min > max { return fmt.Errorf( "minimum positional defined as larger then the maxmimum positional: %d > %d", min, max, ) } return nil } func validateImpl(cmd Cmd) error { v := reflect.ValueOf(cmd.impl) v = unpack(v) t := v.Type() if t.Kind() != reflect.Func { return errors.New("command implementation not a function") } s := structParameters(t) f := fields(s...) if err := validateFields(f); err != nil { return err } if err := validatePositional(t, cmd.minPositional, cmd.maxPositional); err != nil { return err } return nil } func validateShortForms(cmd Cmd) error { mf := mapFields(cmd.impl) ms := make(map[string]string) if len(cmd.shortForms)%2 != 0 { return fmt.Errorf( "undefined option short form: %s", cmd.shortForms[len(cmd.shortForms)-1], ) } for i := 0; i < len(cmd.shortForms); i += 2 { fn := cmd.shortForms[i] sf := cmd.shortForms[i+1] if _, ok := mf[fn]; !ok { return fmt.Errorf("undefined field: %s", fn) } if len(sf) != 1 && (sf[0] < 'a' || sf[0] > 'z') { return fmt.Errorf("invalid short form: %s", sf) } if _, ok := mf[sf]; ok { return fmt.Errorf("short form shadowing field name: %s", sf) } if lf, ok := ms[sf]; ok && lf != fn { return fmt.Errorf("ambigous short form: %s", sf) } ms[sf] = fn } return nil } func validateCommand(cmd Cmd) error { if cmd.isHelp { return nil } if cmd.impl != nil { if err := validateImpl(cmd); err != nil { return fmt.Errorf("%s: %w", cmd.name, err) } } if cmd.impl == nil && len(cmd.subcommands) == 0 { return fmt.Errorf("empty command category: %s", cmd.name) } if cmd.impl != nil { if err := validateShortForms(cmd); err != nil { return fmt.Errorf("%s: %w", cmd.name, err) } } var hasDefault bool names := make(map[string]bool) for _, s := range cmd.subcommands { if s.name == "" { return fmt.Errorf("unnamed subcommand of: %s", cmd.name) } if names[s.name] { return fmt.Errorf("subcommand name conflict: %s", s.name) } names[s.name] = true if err := validateCommand(s); err != nil { return fmt.Errorf("%s: %w", s.name, err) } if s.isDefault && hasDefault { return fmt.Errorf("multiple default subcommands for: %s", cmd.name) } if s.isDefault { hasDefault = true } } return nil }