package wand

import (
	"fmt"
	"io"
	"reflect"
	"slices"
	"strconv"
	"strings"

	"github.com/iancoleman/strcase"
)

// packedKind is implemented by types that expose a kind and a wrapped element
// of the same type, such as reflect.Type. It lets unpack strip wrapper layers
// generically.
type packedKind[T any] interface {
	Kind() reflect.Kind
	Elem() T
}

// field describes one option derived from a (possibly nested) struct field.
type field struct {
	// name is the kebab-cased option name. Fields of named nested structs
	// carry the parent field's name as a "-"-separated prefix.
	name string

	// path lists the Go field names leading to this field. For fields
	// promoted from embedded structs, the embedded field itself is not part
	// of the path.
	path []string

	// typ is the field's type with pointer and slice layers unpacked.
	typ reflect.Type

	// acceptsMultiple is set when the field's type, or the type of an
	// enclosing named struct field, is a slice, possibly behind pointers.
	acceptsMultiple bool
}

// Reflect types compared against unpacked parameter types by isReader and
// isWriter.
var (
	readerType = reflect.TypeFor[io.Reader]()
	writerType = reflect.TypeFor[io.Writer]()
)

// pack wraps v into the pointer and slice layers described by t so that the
// result has type t. Each pointer layer gets a freshly allocated pointee and
// each slice layer becomes a one-element slice. Unwrapping t through pointers
// and slices is expected to reach v's type; otherwise a reflect call panics.
func pack(v reflect.Value, t reflect.Type) reflect.Value {
	if v.Type() == t {
		return v
	}

	if t.Kind() == reflect.Pointer {
		pv := pack(v, t.Elem())
		p := reflect.New(t.Elem())
		p.Elem().Set(pv)
		return p
	}

	iv := pack(v, t.Elem())
	s := reflect.MakeSlice(t, 1, 1)
	s.Index(0).Set(iv)
	return s
}

// unpack strips wrapper layers from p for as long as its kind is one of
// kinds. When no kinds are given, pointers and slices are stripped.
func unpack[T packedKind[T]](p T, kinds ...reflect.Kind) T {
	if len(kinds) == 0 {
		kinds = []reflect.Kind{reflect.Pointer, reflect.Slice}
	}

	for slices.Contains(kinds, p.Kind()) {
		p = p.Elem()
	}

	return p
}

// isReader reports whether t, after unpacking pointers and slices, is
// io.Reader.
func isReader(t reflect.Type) bool {
	return unpack(t) == readerType
}

// isWriter reports whether t, after unpacking pointers and slices, is
// io.Writer.
func isWriter(t reflect.Type) bool {
	return unpack(t) == writerType
}

// isStruct reports whether t, after unpacking pointers and slices, is a
// struct type.
func isStruct(t reflect.Type) bool {
	t = unpack(t)
	return t.Kind() == reflect.Struct
}

// parseInt parses s as a signed integer that fits in byteSize bytes. A "0b"
// prefix selects binary, "0x" hexadecimal, and a leading "0" followed by more
// digits octal; anything else, including a lone "0", is parsed as decimal.
func parseInt(s string, byteSize int) (int64, error) {
	bitSize := byteSize * 8
	switch {
	case strings.HasPrefix(s, "0b"):
		return strconv.ParseInt(s[2:], 2, bitSize)
	case strings.HasPrefix(s, "0x"):
		return strconv.ParseInt(s[2:], 16, bitSize)
	case len(s) > 1 && strings.HasPrefix(s, "0"):
		// The length guard keeps a bare "0" out of this branch: stripping
		// its only digit would leave "", which strconv rejects.
		return strconv.ParseInt(s[1:], 8, bitSize)
	default:
		return strconv.ParseInt(s, 10, bitSize)
	}
}

// parseUint parses s as an unsigned integer that fits in byteSize bytes,
// using the same base prefixes as parseInt: "0b" binary, "0x" hexadecimal,
// and a leading "0" followed by more digits octal; otherwise decimal.
func parseUint(s string, byteSize int) (uint64, error) {
	bitSize := byteSize * 8
	switch {
	case strings.HasPrefix(s, "0b"):
		return strconv.ParseUint(s[2:], 2, bitSize)
	case strings.HasPrefix(s, "0x"):
		return strconv.ParseUint(s[2:], 16, bitSize)
	case len(s) > 1 && strings.HasPrefix(s, "0"):
		// See parseInt: a bare "0" must be parsed as decimal zero.
		return strconv.ParseUint(s[1:], 8, bitSize)
	default:
		return strconv.ParseUint(s, 10, bitSize)
	}
}

// canScan reports whether s can be parsed as a value of type t. Bool,
// integer and float kinds are checked by parsing; strings always scan; every
// other kind is rejected.
func canScan(t reflect.Type, s string) bool {
	switch t.Kind() {
	case reflect.Bool:
		_, err := strconv.ParseBool(s)
		return err == nil
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		_, err := parseInt(s, int(t.Size()))
		return err == nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		_, err := parseUint(s, int(t.Size()))
		return err == nil
	case reflect.Float32, reflect.Float64:
		_, err := strconv.ParseFloat(s, int(t.Size())*8)
		return err == nil
	case reflect.String:
		return true
	default:
		return false
	}
}

// scan parses s into a value of type t and returns it as an any holding
// exactly type t. Parse errors are ignored, so s is expected to have passed
// canScan for t first. Kinds other than bool, integers and floats are set by
// converting the string with reflect, which panics if a string is not
// convertible to t.
func scan(t reflect.Type, s string) any {
	p := reflect.New(t)
	switch t.Kind() {
	case reflect.Bool:
		v, _ := strconv.ParseBool(s)
		p.Elem().Set(reflect.ValueOf(v).Convert(t))
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		v, _ := parseInt(s, int(t.Size()))
		p.Elem().Set(reflect.ValueOf(v).Convert(t))
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		v, _ := parseUint(s, int(t.Size()))
		p.Elem().Set(reflect.ValueOf(v).Convert(t))
	case reflect.Float32, reflect.Float64:
		v, _ := strconv.ParseFloat(s, int(t.Size())*8)
		p.Elem().Set(reflect.ValueOf(v).Convert(t))
	default:
		p.Elem().Set(reflect.ValueOf(s).Convert(t))
	}

	return p.Elem().Interface()
}

// fieldsChecked derives the option fields of the struct types in s.
//
// Fields of bool, integer, float and string kind, and of the empty
// interface, become options named after the kebab-cased Go field name
// (pointer and slice layers are unpacked first). Fields of struct type are
// expanded recursively. Fields of embedded structs are promoted under their
// own names. Fields of named struct fields are prefixed with the parent's
// name and inherit its acceptsMultiple flag. All other fields are ignored.
//
// Within one struct type, direct and prefixed fields override promoted fields
// with the same name, and that type's fields are sorted by name. The results
// for the types in s are concatenated in order.
//
// visited holds the struct types on the current descent path and may be nil.
// An error is returned when a struct type contains itself, directly or through
// other structs, pointers or slices.
func fieldsChecked(visited map[reflect.Type]bool, s ...reflect.Type) ([]field, error) {
	if len(s) == 0 {
		return nil, nil
	}

	var (
		anonFields  []field
		plainFields []field
	)

	for i := 0; i < s[0].NumField(); i++ {
		sf := s[0].Field(i)
		sft := sf.Type
		am := acceptsMultiple(sft)
		sft = unpack(sft)
		sfn := strcase.ToKebab(sf.Name)
		switch sft.Kind() {
		case reflect.Bool,
			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
			reflect.Float32, reflect.Float64,
			reflect.String:
			plainFields = append(plainFields, field{
				name:            sfn,
				path:            []string{sf.Name},
				typ:             sft,
				acceptsMultiple: am,
			})
		case reflect.Interface:
			// Only the empty interface is accepted; interfaces with methods
			// are skipped.
			if sft.NumMethod() == 0 {
				plainFields = append(plainFields, field{
					name:            sfn,
					path:            []string{sf.Name},
					typ:             sft,
					acceptsMultiple: am,
				})
			}
		case reflect.Struct:
			if visited[sft] {
				return nil, fmt.Errorf("circular type definitions not allowed: %s", sft.Name())
			}

			if visited == nil {
				visited = make(map[reflect.Type]bool)
			}

			// Mark the type only while descending into it. A type that
			// appears again in a sibling field, or in another top-level
			// struct, is not a cycle and must not be reported as one.
			visited[sft] = true
			sff, err := fieldsChecked(visited, sft)
			delete(visited, sft)
			if err != nil {
				return nil, err
			}

			if sf.Anonymous {
				anonFields = append(anonFields, sff...)
				continue
			}

			for j := range sff {
				sff[j].name = sfn + "-" + sff[j].name
				sff[j].path = append([]string{sf.Name}, sff[j].path...)
				sff[j].acceptsMultiple = sff[j].acceptsMultiple || am
			}

			plainFields = append(plainFields, sff...)
		}
	}

	// Plain fields are inserted last so they win over promoted fields with
	// the same name; within each group the later field wins.
	byName := make(map[string]field)
	for _, fi := range anonFields {
		byName[fi.name] = fi
	}

	for _, fi := range plainFields {
		byName[fi.name] = fi
	}

	var f []field
	for _, fi := range byName {
		f = append(f, fi)
	}

	// Map iteration order is random; sort to keep the result deterministic.
	slices.SortFunc(f, func(a, b field) int {
		return strings.Compare(a.name, b.name)
	})

	ff, err := fieldsChecked(visited, s[1:]...)
	if err != nil {
		return nil, err
	}

	return append(f, ff...), nil
}

// fields is fieldsChecked without error reporting: when fieldsChecked fails
// (a circular struct definition), it returns nil.
func fields(s ...reflect.Type) []field {
	f, _ := fieldsChecked(nil, s...)
	return f
}

// boolFields returns the fields of f whose unpacked type is of bool kind,
// preserving their order.
func boolFields(f []field) []field {
	var b []field
	for _, fi := range f {
		if fi.typ.Kind() == reflect.Bool {
			b = append(b, fi)
		}
	}

	return b
}

// mapFields groups the option fields of impl's struct parameters by option
// name. impl's type is unpacked from pointers and slices and is expected to
// be a function type; several struct parameters may contribute fields with
// the same name, hence the slice values.
func mapFields(impl any) map[string][]field {
	v := reflect.ValueOf(impl)
	t := v.Type()
	t = unpack(t)
	s := structParameters(t)
	f := fields(s...)
	mf := make(map[string][]field)
	for _, fi := range f {
		mf[fi.name] = append(mf[fi.name], fi)
	}

	return mf
}

// filterParameters returns the parameter types of the function type t,
// unpacked from pointers and slices, for which f returns true, in parameter
// order.
func filterParameters(t reflect.Type, f func(reflect.Type) bool) []reflect.Type {
	var s []reflect.Type
	for i := 0; i < t.NumIn(); i++ {
		p := t.In(i)
		p = unpack(p)
		if f(p) {
			s = append(s, p)
		}
	}

	return s
}

// positionalParameters returns the unpacked parameter types of the function
// type t that are not structs.
func positionalParameters(t reflect.Type) []reflect.Type {
	return filterParameters(t, func(p reflect.Type) bool {
		return p.Kind() != reflect.Struct
	})
}

// ioParameters returns the indexes in p of io.Reader parameters and of
// io.Writer parameters, in that order.
func ioParameters(p []reflect.Type) ([]int, []int) {
	var (
		reader []int
		writer []int
	)

	for i, pi := range p {
		switch {
		case isReader(pi):
			reader = append(reader, i)
		case isWriter(pi):
			writer = append(writer, i)
		}
	}

	return reader, writer
}

// structParameters returns the unpacked parameter types of the function type
// t that are structs.
func structParameters(t reflect.Type) []reflect.Type {
	return filterParameters(t, func(p reflect.Type) bool {
		return p.Kind() == reflect.Struct
	})
}

// compatibleTypes reports whether all the given types fall into the same
// class: all bool, all signed integers, all unsigned integers, all floats,
// all strings, or all empty interfaces. Fewer than two types are trivially
// compatible; any other kind is never compatible.
func compatibleTypes(t ...reflect.Type) bool {
	if len(t) < 2 {
		return true
	}

	t0, t1 := t[0], t[1]
	var compatible bool
	switch t0.Kind() {
	case reflect.Bool:
		compatible = t1.Kind() == reflect.Bool
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		switch t1.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			compatible = true
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		switch t1.Kind() {
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			compatible = true
		}
	case reflect.Float32, reflect.Float64:
		switch t1.Kind() {
		case reflect.Float32, reflect.Float64:
			compatible = true
		}
	case reflect.String:
		compatible = t1.Kind() == reflect.String
	case reflect.Interface:
		compatible = t1.Kind() == reflect.Interface && t0.NumMethod() == 0 && t1.NumMethod() == 0
	}

	if !compatible {
		return false
	}

	// The classes above partition the types, so checking every adjacent pair
	// covers the whole list rather than only the first two entries.
	return compatibleTypes(t[1:]...)
}

// acceptsMultiple reports whether t is a slice, possibly behind any number
// of pointers.
func acceptsMultiple(t reflect.Type) bool {
	switch t.Kind() {
	case reflect.Slice:
		return true
	case reflect.Pointer:
		return acceptsMultiple(t.Elem())
	default:
		return false
	}
}