wand/reflect.go

367 lines
7.6 KiB
Go
Raw Normal View History

2025-08-18 14:24:31 +02:00
package wand
import (
2025-08-24 01:45:25 +02:00
"fmt"
2025-08-18 14:24:31 +02:00
"github.com/iancoleman/strcase"
2025-08-24 01:45:25 +02:00
"io"
2025-08-18 14:24:31 +02:00
"reflect"
2025-08-24 04:46:54 +02:00
"slices"
2025-08-18 14:24:31 +02:00
"strconv"
2025-08-24 01:45:25 +02:00
"strings"
2025-08-18 14:24:31 +02:00
)
// packedKind is the type constraint used by unpack: any type that can report
// its reflect.Kind and return its element as the same type T. In this file it
// is instantiated with reflect.Type.
type packedKind[T any] interface {
// Kind reports the kind of the wrapped type or value.
Kind() reflect.Kind
// Elem returns the element of the wrapper, again as a T.
Elem() T
}
type field struct {
name string
2025-08-24 01:45:25 +02:00
path []string
2025-08-18 14:24:31 +02:00
typ reflect.Type
acceptsMultiple bool
}
// Cached reflect.Type values of the standard I/O interfaces, compared against
// by isReader and isWriter.
var (
// readerType is the reflect.Type of the io.Reader interface.
readerType = reflect.TypeFor[io.Reader]()
// writerType is the reflect.Type of the io.Writer interface.
writerType = reflect.TypeFor[io.Writer]()
)
func pack(v reflect.Value, t reflect.Type) reflect.Value {
if v.Type() == t {
return v
}
if t.Kind() == reflect.Pointer {
pv := pack(v, t.Elem())
p := reflect.New(t.Elem())
p.Elem().Set(pv)
return p
}
iv := pack(v, t.Elem())
s := reflect.MakeSlice(t, 1, 1)
s.Index(0).Set(iv)
return s
}
2025-08-24 04:46:54 +02:00
func unpack[T packedKind[T]](p T, kinds ...reflect.Kind) T {
if len(kinds) == 0 {
kinds = []reflect.Kind{reflect.Pointer, reflect.Slice}
2025-08-18 14:24:31 +02:00
}
2025-08-24 04:46:54 +02:00
if slices.Contains(kinds, p.Kind()) {
return unpack(p.Elem(), kinds...)
}
return p
2025-08-18 14:24:31 +02:00
}
// isReader reports whether t, once pointer and slice wrappers are removed,
// is exactly the io.Reader interface type.
func isReader(t reflect.Type) bool {
	base := unpack(t)
	return base == readerType
}
// isWriter reports whether t, once pointer and slice wrappers are removed,
// is exactly the io.Writer interface type.
func isWriter(t reflect.Type) bool {
	base := unpack(t)
	return base == writerType
}
// isStruct reports whether t is a struct type, looking through any pointer
// and slice wrappers.
func isStruct(t reflect.Type) bool {
	return unpack(t).Kind() == reflect.Struct
}
2025-08-24 01:45:25 +02:00
// parseInt parses s as a signed integer that has to fit into byteSize bytes.
//
// The base is chosen from the prefix of s:
//
//   - "0b": binary
//   - "0x": hexadecimal
//   - "0" followed by at least one more character: octal
//   - anything else, including a lone "0" and signed input: decimal
//
// It returns the error reported by strconv.ParseInt when s is malformed or
// the value does not fit into byteSize*8 bits.
func parseInt(s string, byteSize int) (int64, error) {
	bitSize := byteSize * 8
	switch {
	case strings.HasPrefix(s, "0b"):
		return strconv.ParseInt(s[2:], 2, bitSize)
	case strings.HasPrefix(s, "0x"):
		return strconv.ParseInt(s[2:], 16, bitSize)
	case len(s) > 1 && strings.HasPrefix(s, "0"):
		// The length check keeps a lone "0" out of the octal branch, where
		// it would be parsed as an empty string and rejected.
		return strconv.ParseInt(s[1:], 8, bitSize)
	default:
		// Unprefixed input is decimal. The whole string is parsed, so short
		// or empty input yields a syntax error instead of a slicing panic.
		return strconv.ParseInt(s, 10, bitSize)
	}
}
// parseUint parses s as an unsigned integer that has to fit into byteSize
// bytes.
//
// The base is chosen from the prefix of s:
//
//   - "0b": binary
//   - "0x": hexadecimal
//   - "0" followed by at least one more character: octal
//   - anything else, including a lone "0": decimal
//
// It returns the error reported by strconv.ParseUint when s is malformed or
// the value does not fit into byteSize*8 bits.
func parseUint(s string, byteSize int) (uint64, error) {
	bitSize := byteSize * 8
	switch {
	case strings.HasPrefix(s, "0b"):
		return strconv.ParseUint(s[2:], 2, bitSize)
	case strings.HasPrefix(s, "0x"):
		return strconv.ParseUint(s[2:], 16, bitSize)
	case len(s) > 1 && strings.HasPrefix(s, "0"):
		// The length check keeps a lone "0" out of the octal branch, where
		// it would be parsed as an empty string and rejected.
		return strconv.ParseUint(s[1:], 8, bitSize)
	default:
		// Unprefixed input is decimal. The whole string is parsed, so short
		// or empty input yields a syntax error instead of a slicing panic.
		return strconv.ParseUint(s, 10, bitSize)
	}
}
2025-08-18 14:24:31 +02:00
func canScan(t reflect.Type, s string) bool {
switch t.Kind() {
case reflect.Bool:
_, err := strconv.ParseBool(s)
return err == nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
2025-08-24 01:45:25 +02:00
_, err := parseInt(s, int(t.Size()))
2025-08-18 14:24:31 +02:00
return err == nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
2025-08-24 01:45:25 +02:00
_, err := parseUint(s, int(t.Size()))
2025-08-18 14:24:31 +02:00
return err == nil
case reflect.Float32, reflect.Float64:
_, err := strconv.ParseFloat(s, int(t.Size())*8)
return err == nil
case reflect.String:
return true
default:
return false
}
}
func scan(t reflect.Type, s string) any {
p := reflect.New(t)
switch t.Kind() {
case reflect.Bool:
v, _ := strconv.ParseBool(s)
p.Elem().Set(reflect.ValueOf(v).Convert(t))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
2025-08-24 01:45:25 +02:00
v, _ := parseInt(s, int(t.Size()))
2025-08-18 14:24:31 +02:00
p.Elem().Set(reflect.ValueOf(v).Convert(t))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
2025-08-24 01:45:25 +02:00
v, _ := parseUint(s, int(t.Size()))
2025-08-18 14:24:31 +02:00
p.Elem().Set(reflect.ValueOf(v).Convert(t))
case reflect.Float32, reflect.Float64:
v, _ := strconv.ParseFloat(s, int(t.Size())*8)
p.Elem().Set(reflect.ValueOf(v).Convert(t))
default:
p.Elem().Set(reflect.ValueOf(s).Convert(t))
}
return p.Elem().Interface()
}
2025-08-24 01:45:25 +02:00
func fieldsChecked(visited map[reflect.Type]bool, s ...reflect.Type) ([]field, error) {
2025-08-18 14:24:31 +02:00
if len(s) == 0 {
2025-08-24 01:45:25 +02:00
return nil, nil
2025-08-18 14:24:31 +02:00
}
var (
anonFields []field
plainFields []field
)
for i := 0; i < s[0].NumField(); i++ {
sf := s[0].Field(i)
sft := sf.Type
am := acceptsMultiple(sft)
sft = unpack(sft)
sfn := sf.Name
sfn = strcase.ToKebab(sfn)
switch sft.Kind() {
case reflect.Bool,
reflect.Int,
reflect.Int8,
reflect.Int16,
reflect.Int32,
reflect.Int64,
reflect.Uint,
reflect.Uint8,
reflect.Uint16,
reflect.Uint32,
reflect.Uint64,
reflect.Float32,
reflect.Float64,
reflect.String:
2025-08-24 01:45:25 +02:00
plainFields = append(plainFields, field{
name: sfn,
path: []string{sf.Name},
typ: sft,
acceptsMultiple: am,
})
2025-08-18 14:24:31 +02:00
case reflect.Interface:
if sft.NumMethod() == 0 {
2025-08-24 01:45:25 +02:00
plainFields = append(plainFields, field{
name: sfn,
path: []string{sf.Name},
typ: sft,
acceptsMultiple: am,
})
2025-08-18 14:24:31 +02:00
}
case reflect.Struct:
2025-08-24 01:45:25 +02:00
if visited[sft] {
return nil, fmt.Errorf("circular type definitions not allowed: %s", sft.Name())
}
if visited == nil {
visited = make(map[reflect.Type]bool)
}
visited[sft] = true
sff, err := fieldsChecked(visited, sft)
if err != nil {
return nil, err
}
2025-08-18 14:24:31 +02:00
if sf.Anonymous {
anonFields = append(anonFields, sff...)
} else {
for i := range sff {
sff[i].name = sfn + "-" + sff[i].name
2025-08-24 01:45:25 +02:00
sff[i].path = append([]string{sf.Name}, sff[i].path...)
2025-08-18 14:24:31 +02:00
sff[i].acceptsMultiple = sff[i].acceptsMultiple || am
}
plainFields = append(plainFields, sff...)
}
}
}
mf := make(map[string]field)
for _, fi := range anonFields {
mf[fi.name] = fi
}
for _, fi := range plainFields {
mf[fi.name] = fi
}
var f []field
for _, fi := range mf {
f = append(f, fi)
}
2025-08-24 01:45:25 +02:00
ff, err := fieldsChecked(visited, s[1:]...)
if err != nil {
return nil, err
}
return append(f, ff...), nil
}
func fields(s ...reflect.Type) []field {
f, _ := fieldsChecked(nil, s...)
return f
2025-08-18 14:24:31 +02:00
}
// boolFields returns the entries of f whose type is of kind bool, keeping
// their original order. The result is nil when there are none.
func boolFields(f []field) []field {
	var bools []field
	for i := range f {
		if f[i].typ.Kind() != reflect.Bool {
			continue
		}

		bools = append(bools, f[i])
	}

	return bools
}
// mapFields groups the fields of all struct parameters of impl by field
// name.
//
// The type of impl is taken after removing pointer and slice wrappers; its
// parameters are then listed with structParameters, so impl is expected to
// be a function.
func mapFields(impl any) map[string][]field {
	implType := unpack(reflect.ValueOf(impl).Type())
	all := fields(structParameters(implType)...)

	byName := make(map[string][]field)
	for _, fi := range all {
		byName[fi.name] = append(byName[fi.name], fi)
	}

	return byName
}
// filterParameters walks the input parameters of the function type t,
// removes pointer and slice wrappers from each parameter type, and returns
// those unwrapped types for which f returns true, in parameter order.
func filterParameters(t reflect.Type, f func(reflect.Type) bool) []reflect.Type {
	var matched []reflect.Type
	for i, n := 0, t.NumIn(); i < n; i++ {
		if p := unpack(t.In(i)); f(p) {
			matched = append(matched, p)
		}
	}

	return matched
}
// positionalParameters returns the unwrapped parameter types of the function
// type t that are not structs.
func positionalParameters(t reflect.Type) []reflect.Type {
	notStruct := func(p reflect.Type) bool {
		return p.Kind() != reflect.Struct
	}

	return filterParameters(t, notStruct)
}
// ioParameters returns the indices in p of the types that are io.Reader
// (first result) and io.Writer (second result). Each type is tested as a
// reader first, so an index is never listed in both results.
func ioParameters(p []reflect.Type) ([]int, []int) {
	var readers, writers []int
	for idx := range p {
		if isReader(p[idx]) {
			readers = append(readers, idx)
			continue
		}

		if isWriter(p[idx]) {
			writers = append(writers, idx)
		}
	}

	return readers, writers
}
// structParameters returns the unwrapped parameter types of the function
// type t that are structs.
func structParameters(t reflect.Type) []reflect.Type {
	onlyStruct := func(p reflect.Type) bool {
		return p.Kind() == reflect.Struct
	}

	return filterParameters(t, onlyStruct)
}
func compatibleTypes(t ...reflect.Type) bool {
if len(t) < 2 {
return true
}
t0 := t[0]
t1 := t[1]
switch t0.Kind() {
case reflect.Bool:
return t1.Kind() == reflect.Bool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch t1.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
default:
return false
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
switch t1.Kind() {
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
default:
return false
}
case reflect.Float32, reflect.Float64:
switch t1.Kind() {
case reflect.Float32, reflect.Float64:
return true
default:
return false
}
case reflect.String:
return t1.Kind() == reflect.String
case reflect.Interface:
return t1.Kind() == reflect.Interface && t0.NumMethod() == 0 && t1.NumMethod() == 0
default:
return false
}
}
func acceptsMultiple(t reflect.Type) bool {
if t.Kind() == reflect.Slice {
return true
}
switch t.Kind() {
case reflect.Pointer, reflect.Slice:
return acceptsMultiple(t.Elem())
default:
return false
}
}