wand/command.go

340 lines
7.0 KiB
Go
Raw Normal View History

2025-08-18 14:24:31 +02:00
package wand
import (
"errors"
"fmt"
"reflect"
2025-08-26 03:21:35 +02:00
"regexp"
2025-09-01 02:07:48 +02:00
"code.squareroundforest.org/arpio/bind"
2025-08-18 14:24:31 +02:00
)
2025-08-26 03:21:35 +02:00
// commandNameExpression is the pattern a subcommand name must match: an ASCII
// letter or underscore, followed by any number of ASCII letters, digits or
// underscores, with nothing else before or after.
var commandNameExpression = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z_0-9]*$`)
2025-08-18 14:24:31 +02:00
// wrap turns impl into a Cmd. A value that is already a Cmd is passed through
// as-is; any other value becomes the implementation of an unnamed command
// built with Command("", impl).
func wrap(impl any) Cmd {
	switch c := impl.(type) {
	case Cmd:
		return c
	default:
		return Command("", impl)
	}
}
2025-09-01 02:07:48 +02:00
func validateFields(f []bind.Field, conf Config) error {
2025-08-24 01:45:25 +02:00
hasConfigFromOption := hasConfigFromOption(conf)
2025-09-01 02:07:48 +02:00
mf := make(map[string]bind.Field)
2025-08-18 14:24:31 +02:00
for _, fi := range f {
2025-09-01 02:07:48 +02:00
if ef, ok := mf[fi.Name()]; ok && !compatibleTypes(fi.Type(), ef.Type()) {
return fmt.Errorf("duplicate fields with different types: %s", fi.Name())
2025-08-18 14:24:31 +02:00
}
2025-09-01 02:07:48 +02:00
if hasConfigFromOption && fi.Name() == "config" {
2025-08-24 01:45:25 +02:00
return errors.New("option reserved for config file shadowed by struct field")
}
2025-09-01 02:07:48 +02:00
mf[fi.Name()] = fi
2025-08-18 14:24:31 +02:00
}
return nil
}
2025-09-01 02:07:48 +02:00
// validatePositional checks a command's positional-argument limits against
// the positional parameters p of its implementation.
//
// When variadic is true, the last element of p is the variadic parameter and
// is not counted as a fixed one. A limit that is not positive (min <= 0 or
// max <= 0) is treated as unset and is not checked.
//
// It returns an error when:
//   - min is set below the number of fixed parameters,
//   - min exceeds the number of fixed parameters and there is no variadic
//     parameter to take the extra arguments,
//   - max is set below the number of fixed parameters,
//   - both limits are set and min > max.
//
// The checks run in that order, and the first failing one is reported.
func validatePositional(p []reflect.Type, variadic bool, min, max int) error {
	fixedPositional := len(p)
	if variadic {
		fixedPositional--
	}

	switch {
	case min > 0 && min < fixedPositional:
		return fmt.Errorf(
			"minimum positional arguments defined as %d but the implementation expects minimum %d fixed parameters",
			min,
			fixedPositional,
		)
	case min > 0 && min > fixedPositional && !variadic:
		return fmt.Errorf(
			"minimum positional arguments defined as %d but the implementation has only %d fixed parameters and no variadic parameter",
			min,
			fixedPositional,
		)
	case max > 0 && max < fixedPositional:
		return fmt.Errorf(
			"maximum positional arguments defined as %d but the implementation expects minimum %d fixed parameters",
			max,
			fixedPositional,
		)
	case min > 0 && max > 0 && min > max:
		return fmt.Errorf(
			"minimum positional arguments defined as larger than the maximum positional: %d > %d",
			min,
			max,
		)
	}

	return nil
}
2025-09-01 02:07:48 +02:00
// validateIOParameters accepts an implementation signature only if it has at
// most one reader parameter (ior) and at most one writer parameter (iow).
func validateIOParameters(ior, iow []reflect.Type) error {
	if len(ior) <= 1 && len(iow) <= 1 {
		return nil
	}

	return errors.New("only zero or one reader and zero or one writer parameter is supported")
}
2025-08-24 01:45:25 +02:00
func validateImpl(cmd Cmd, conf Config) error {
2025-09-01 02:07:48 +02:00
if !isFunc(cmd.impl) {
return errors.New("command implementation must be a function or a pointer to a function")
2025-08-18 14:24:31 +02:00
}
2025-09-01 02:07:48 +02:00
p := parameters(cmd.impl)
for _, pi := range p {
if !isReader(pi) && !isWriter(pi) && !bindable(pi) {
return fmt.Errorf("unsupported parameter type: %s", pi.Name())
}
2025-08-24 01:45:25 +02:00
}
2025-09-01 02:07:48 +02:00
f := fields(cmd.impl)
2025-08-24 01:45:25 +02:00
if err := validateFields(f, conf); err != nil {
2025-08-18 14:24:31 +02:00
return err
}
2025-09-01 02:07:48 +02:00
pos, variadic := positional(cmd.impl)
if err := validatePositional(pos, variadic, cmd.minPositional, cmd.maxPositional); err != nil {
2025-08-18 14:24:31 +02:00
return err
}
2025-09-01 02:07:48 +02:00
ior, iow := ioParameters(cmd.impl)
if err := validateIOParameters(ior, iow); err != nil {
return err
2025-08-18 14:24:31 +02:00
}
return nil
}
2025-09-01 02:07:48 +02:00
func validateCommandTree(cmd Cmd, conf Config) error {
2025-08-18 14:24:31 +02:00
if cmd.isHelp {
return nil
}
2025-08-26 03:21:35 +02:00
if cmd.version != "" {
return nil
}
2025-08-26 14:12:18 +02:00
if cmd.impl == nil && !cmd.group {
return fmt.Errorf("command does not have an implementation: %s", cmd.name)
}
2025-08-18 14:24:31 +02:00
if cmd.impl == nil && len(cmd.subcommands) == 0 {
return fmt.Errorf("empty command category: %s", cmd.name)
}
if cmd.impl != nil {
2025-09-01 02:07:48 +02:00
if err := validateImpl(cmd, conf); err != nil {
2025-08-18 14:24:31 +02:00
return fmt.Errorf("%s: %w", cmd.name, err)
}
}
var hasDefault bool
names := make(map[string]bool)
for _, s := range cmd.subcommands {
2025-09-01 02:07:48 +02:00
if !commandNameExpression.MatchString(s.name) {
return fmt.Errorf("command name is not a valid symbol: '%s'", cmd.name)
2025-08-18 14:24:31 +02:00
}
if names[s.name] {
return fmt.Errorf("subcommand name conflict: %s", s.name)
}
names[s.name] = true
2025-08-24 01:45:25 +02:00
if s.isDefault && cmd.impl != nil {
return fmt.Errorf(
2025-09-01 02:07:48 +02:00
"default subcommand defined for a command that has an explicit implementation: %s, %s",
2025-08-24 01:45:25 +02:00
cmd.name,
s.name,
)
}
2025-08-18 14:24:31 +02:00
if s.isDefault && hasDefault {
return fmt.Errorf("multiple default subcommands for: %s", cmd.name)
}
if s.isDefault {
hasDefault = true
}
2025-09-01 02:07:48 +02:00
if err := validateCommandTree(s, conf); err != nil {
return fmt.Errorf("%s: %w", s.name, err)
}
2025-08-18 14:24:31 +02:00
}
return nil
}
2025-08-24 01:45:25 +02:00
2025-09-01 02:07:48 +02:00
// checkShortFormDefinition reports a conflict when short is already bound in
// existing to a long option other than long.
//
// An unbound short form, or a repeat of the identical short->long pair, is
// accepted. existing may be nil.
func checkShortFormDefinition(existing map[string]string, short, long string) error {
	if prev, bound := existing[short]; bound && prev != long {
		return fmt.Errorf(
			"using the same short form for different options is not allowed: %s->%s, %s->%s",
			short, long, short, prev,
		)
	}

	return nil
}
// collectMappedShortForms merges the short->long pairs of from into to and
// returns the merged map.
//
// A nil to is replaced by a fresh map only when from has entries, so merging
// nothing into nil yields nil. A non-nil to is modified in place. The merge
// stops at the first short form that to already binds to a different long
// option; in that case it returns a nil map and the conflict error.
func collectMappedShortForms(to, from map[string]string) (map[string]string, error) {
	if to == nil && len(from) > 0 {
		to = make(map[string]string, len(from))
	}

	for short, long := range from {
		if err := checkShortFormDefinition(to, short, long); err != nil {
			return nil, err
		}

		to[short] = long
	}

	return to, nil
}
func validateShortFormsTree(cmd Cmd) (map[string]string, map[string]string, error) {
var mapped, unmapped map[string]string
2025-08-24 01:45:25 +02:00
for _, sc := range cmd.subcommands {
2025-09-01 02:07:48 +02:00
m, um, err := validateShortFormsTree(sc)
if err != nil {
return nil, nil, err
}
if mapped, err = collectMappedShortForms(mapped, m); err != nil {
return nil, nil, err
}
if unmapped, err = collectMappedShortForms(unmapped, um); err != nil {
return nil, nil, err
}
}
if len(cmd.shortForms) % 2 != 0 {
return nil, nil, fmt.Errorf("unassigned short form: %s", cmd.shortForms[len(cmd.shortForms) - 1])
2025-08-24 01:45:25 +02:00
}
2025-09-01 02:07:48 +02:00
mf := mapFields(cmd.impl)
2025-08-24 01:45:25 +02:00
for i := 0; i < len(cmd.shortForms); i += 2 {
2025-09-01 02:07:48 +02:00
s, l := cmd.shortForms[i], cmd.shortForms[i + 1]
r := []rune(s)
if len(r) != 1 || r[0] < 'a' || r[0] > 'z' {
return nil, nil, fmt.Errorf("invalid short form: %s", s)
}
if err := checkShortFormDefinition(mapped, s, l); err != nil {
return nil, nil, err
}
if err := checkShortFormDefinition(unmapped, s, l); err != nil {
return nil, nil, err
}
_, hasField := mf[l]
_, isMapped := mapped[s]
if !hasField && !isMapped {
if unmapped == nil {
unmapped = make(map[string]string)
}
unmapped[s] = l
continue
}
if mapped == nil {
mapped = make(map[string]string)
}
delete(unmapped, s)
mapped[s] = l
2025-08-24 01:45:25 +02:00
}
2025-09-01 02:07:48 +02:00
return mapped, unmapped, nil
}
func validateShortForms(cmd Cmd) error {
_, um, err := validateShortFormsTree(cmd)
if err != nil {
return err
}
if len(um) != 0 {
return errors.New("unmapped short forms")
}
return nil
2025-08-24 01:45:25 +02:00
}
func validateCommand(cmd Cmd, conf Config) error {
2025-09-01 02:07:48 +02:00
if err := validateCommandTree(cmd, conf); err != nil {
2025-08-24 01:45:25 +02:00
return err
}
2025-09-01 02:07:48 +02:00
if err := validateShortForms(cmd); err != nil {
return err
2025-08-24 01:45:25 +02:00
}
return nil
}
2025-09-01 02:07:48 +02:00
// insertHelpOption ensures "help" is among names. If it is already present,
// names is returned unchanged; otherwise "help" is appended. As with append,
// the result may share its backing array with the argument.
func insertHelpOption(names []string) []string {
	hasHelp := false
	for i := 0; i < len(names) && !hasHelp; i++ {
		hasHelp = names[i] == "help"
	}

	if hasHelp {
		return names
	}

	return append(names, "help")
}
// insertHelpShortForm ensures "h" is among shortForms. If it is missing, "h"
// is appended; otherwise the slice is returned unchanged. As with append, the
// result may share its backing array with the argument.
func insertHelpShortForm(shortForms []string) []string {
	present := false
	for _, form := range shortForms {
		if form == "h" {
			present = true
			break
		}
	}

	if !present {
		shortForms = append(shortForms, "h")
	}

	return shortForms
}
// boolOptions returns the option names of cmd that are boolean, followed by
// their short forms.
//
// The long names are the names of boolFields(fields(cmd.impl)), plus "help"
// if it is not already among them. The short forms are those that
// cmd.shortForms binds to any of those long names, in long-name order, plus
// "h" if it is not already among them.
//
// NOTE(review): presumably used by the argument parser to recognize flags
// that take no value; confirm at the call site.
func boolOptions(cmd Cmd) []string {
	var names []string
	for _, fi := range boolFields(fields(cmd.impl)) {
		names = append(names, fi.Name())
	}

	names = insertHelpOption(names)

	// Index the declared short forms by long name. shortForms is a flat
	// short, long, short, long, ... list. The i+1 bound skips a trailing
	// unpaired entry, so this function cannot index out of range on a
	// malformed list whether or not validation has run first.
	shortsByLong := make(map[string][]string)
	for i := 0; i+1 < len(cmd.shortForms); i += 2 {
		s, l := cmd.shortForms[i], cmd.shortForms[i+1]
		shortsByLong[l] = append(shortsByLong[l], s)
	}

	var shorts []string
	for _, n := range names {
		// Appending a missing (nil) entry is a no-op.
		shorts = append(shorts, shortsByLong[n]...)
	}

	shorts = insertHelpShortForm(shorts)
	return append(names, shorts...)
}