wand/command.go

237 lines
4.7 KiB
Go
Raw Normal View History

2025-08-18 14:24:31 +02:00
package wand
import (
"errors"
"fmt"
"reflect"
"slices"
)
// command assembles a Cmd from its name, its implementation and any number
// of subcommands. No validation is performed here.
func command(name string, impl any, subcmds ...Cmd) Cmd {
	var c Cmd
	c.name = name
	c.impl = impl
	c.subcommands = subcmds
	return c
}
// wrap turns an arbitrary implementation into a Cmd. A value that already is
// a Cmd is returned unchanged; anything else is passed to Command with an
// empty name.
func wrap(impl any) Cmd {
	if existing, isCmd := impl.(Cmd); isCmd {
		return existing
	}

	return Command("", impl)
}
// validateFields reports an error when two fields share a name while their
// types are not compatible according to compatibleTypes. Every field is
// compared only against the most recently seen field with the same name.
func validateFields(f []field) error {
	seen := make(map[string]field)
	for _, current := range f {
		previous, exists := seen[current.name]
		if exists && !compatibleTypes(current.typ, previous.typ) {
			return fmt.Errorf("duplicate fields with different types: %s", current.name)
		}

		seen[current.name] = current
	}

	return nil
}
// validateParameter checks whether a value of type t can be used as a
// command parameter.
//
// Booleans, the sized and unsized signed/unsigned integers, floats and
// strings are accepted as they are. Pointers and slices are judged by the
// type that unpack returns for them (presumably the wrapped element type —
// see unpack). Interfaces are accepted only when they declare no methods.
// Every other kind is rejected.
func validateParameter(t reflect.Type) error {
	switch t.Kind() {
	case reflect.Pointer, reflect.Slice:
		return validateParameter(unpack(t))
	case reflect.Interface:
		if t.NumMethod() == 0 {
			return nil
		}

		return errors.New("'non-empty' interface parameter")
	case reflect.Bool, reflect.String,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64:
		return nil
	}

	return fmt.Errorf("unsupported parameter type: %v", t)
}
// validatePositional checks the positional parameters of the function type t
// against the declared minimum and maximum number of positional arguments.
//
// It verifies that:
//   - at most one reader and at most one writer parameter is present, as
//     reported by ioParameters;
//   - every remaining positional parameter passes validateParameter;
//   - min and max are consistent with the number of fixed (non-I/O,
//     non-variadic) positional parameters and with each other.
//
// A min or max that is zero or negative is treated as "not set" and is not
// checked.
func validatePositional(t reflect.Type, min, max int) error {
	params := positionalParameters(t)
	readers, writers := ioParameters(params)
	if len(readers) > 1 || len(writers) > 1 {
		return errors.New("only zero or one reader and zero or one writer parameters is supported")
	}

	// isIO tells whether the given index was reported as a reader or a writer.
	isIO := func(i int) bool {
		return slices.Contains(readers, i) || slices.Contains(writers, i)
	}

	for i := range params {
		if isIO(i) {
			continue
		}

		if err := validateParameter(params[i]); err != nil {
			return err
		}
	}

	// The trailing parameter only counts as a variadic positional when the
	// function is variadic and that parameter is neither a struct nor an I/O
	// parameter. t.In must not be touched for non-variadic functions, because
	// those may have no parameters at all.
	variadic := false
	if t.IsVariadic() {
		last := t.NumIn() - 1
		variadic = !isStruct(t.In(last)) && !isIO(last)
	}

	fixed := len(params) - len(readers) - len(writers)
	if variadic {
		fixed--
	}

	switch {
	case min > 0 && min < fixed:
		return fmt.Errorf(
			"minimum positional defined as %d but the implementation expects minimum %d fixed parameters",
			min,
			fixed,
		)
	case min > 0 && min > fixed && !variadic:
		return fmt.Errorf(
			"minimum positional defined as %d but the implementation has only %d fixed parameters and no variadic parameter",
			min,
			fixed,
		)
	case max > 0 && max < fixed:
		return fmt.Errorf(
			"maximum positional defined as %d but the implementation expects minimum %d fixed parameters",
			max,
			fixed,
		)
	case min > 0 && max > 0 && min > max:
		return fmt.Errorf(
			"minimum positional defined as larger then the maxmimum positional: %d > %d",
			min,
			max,
		)
	}

	return nil
}
// validateImpl checks the implementation of cmd: after unpack is applied to
// it, it must be a function, the fields collected from its struct parameters
// must pass validateFields, and its positional parameters must be consistent
// with the minimum and maximum positional counts configured on cmd.
func validateImpl(cmd Cmd) error {
	implType := unpack(reflect.ValueOf(cmd.impl)).Type()
	if implType.Kind() != reflect.Func {
		return errors.New("command implementation not a function")
	}

	if err := validateFields(fields(structParameters(implType)...)); err != nil {
		return err
	}

	return validatePositional(implType, cmd.minPositional, cmd.maxPositional)
}
// validateShortForms checks the short form definitions of cmd.
//
// cmd.shortForms is read as a flat list of pairs: a field name followed by
// its short form. The definitions are rejected when:
//   - the list has an odd number of entries, leaving a name without a
//     short form;
//   - a pair refers to a field name that mapFields does not report for the
//     implementation;
//   - a short form is not exactly one lowercase ASCII letter ('a'..'z');
//   - a short form equals the name of an existing field;
//   - the same short form is assigned to two different field names.
func validateShortForms(cmd Cmd) error {
	mf := mapFields(cmd.impl)
	ms := make(map[string]string)
	if len(cmd.shortForms)%2 != 0 {
		return fmt.Errorf(
			"undefined option short form: %s", cmd.shortForms[len(cmd.shortForms)-1],
		)
	}

	for i := 0; i < len(cmd.shortForms); i += 2 {
		fn := cmd.shortForms[i]
		sf := cmd.shortForms[i+1]
		if _, ok := mf[fn]; !ok {
			return fmt.Errorf("undefined field: %s", fn)
		}

		// The length must be checked first and combined with ||: this way an
		// empty short form is rejected before sf[0] is evaluated, and neither
		// multi-character strings nor single non-letter characters slip
		// through.
		if len(sf) != 1 || sf[0] < 'a' || sf[0] > 'z' {
			return fmt.Errorf("invalid short form: %s", sf)
		}

		if _, ok := mf[sf]; ok {
			return fmt.Errorf("short form shadowing field name: %s", sf)
		}

		// Repeating the very same pair is tolerated; only a short form that
		// points to two different fields is ambiguous.
		if lf, ok := ms[sf]; ok && lf != fn {
			return fmt.Errorf("ambigous short form: %s", sf)
		}

		ms[sf] = fn
	}

	return nil
}
// validateCommand recursively validates cmd and all of its subcommands.
//
// Help commands are always accepted. A command with an implementation must
// pass validateImpl and validateShortForms; a command without one must have
// at least one subcommand. Subcommands must be named, their names must be
// unique among their siblings, each must itself be valid, and at most one of
// them may be marked as the default.
func validateCommand(cmd Cmd) error {
	if cmd.isHelp {
		return nil
	}

	switch {
	case cmd.impl != nil:
		if err := validateImpl(cmd); err != nil {
			return fmt.Errorf("%s: %w", cmd.name, err)
		}

		if err := validateShortForms(cmd); err != nil {
			return fmt.Errorf("%s: %w", cmd.name, err)
		}
	case len(cmd.subcommands) == 0:
		return fmt.Errorf("empty command category: %s", cmd.name)
	}

	defaultSeen := false
	taken := make(map[string]struct{})
	for _, sub := range cmd.subcommands {
		if sub.name == "" {
			return fmt.Errorf("unnamed subcommand of: %s", cmd.name)
		}

		if _, dup := taken[sub.name]; dup {
			return fmt.Errorf("subcommand name conflict: %s", sub.name)
		}

		taken[sub.name] = struct{}{}
		if err := validateCommand(sub); err != nil {
			return fmt.Errorf("%s: %w", sub.name, err)
		}

		if sub.isDefault {
			if defaultSeen {
				return fmt.Errorf("multiple default subcommands for: %s", cmd.name)
			}

			defaultSeen = true
		}
	}

	return nil
}