wand/reflect.go

299 lines
6.1 KiB
Go
Raw Normal View History

2025-08-18 14:24:31 +02:00
package wand
import (
"github.com/iancoleman/strcase"
"reflect"
"strconv"
"io"
)
// packedKind is the constraint used by unpack. It is satisfied by reflection
// handles that can be peeled one layer at a time, such as reflect.Type and
// reflect.Value: Kind reports the kind of the current layer and Elem returns
// the layer underneath it.
type packedKind[T any] interface {
	Elem() T
	Kind() reflect.Kind
}
// field describes one option derived from a struct field by fields().
type field struct {
// name is the kebab-cased Go field name; for fields of a nested,
// non-embedded struct it is prefixed with the parent name ("outer-inner").
name string
// typ is the field's type with all pointer and slice wrappers stripped
// (see unpack).
typ reflect.Type
// acceptsMultiple is true when the declared field type is a slice
// (possibly behind pointers), or when an enclosing non-embedded struct
// field is.
acceptsMultiple bool
}
// Interface types recognized by isReader and isWriter when classifying
// function parameters as input or output streams.
var (
	readerType = reflect.TypeOf((*io.Reader)(nil)).Elem()
	writerType = reflect.TypeOf((*io.Writer)(nil)).Elem()
)
// pack wraps v so that the returned value has exactly type t.
//
// t may be v's own type, any type v is assignable to (for example an empty
// interface such as any), or any nesting of pointers and slices around such
// a type. Each pointer layer is filled with a freshly allocated value, and
// each slice layer becomes a one-element slice holding the packed inner
// value.
//
// It panics if t cannot be reached from v's type in this way.
func pack(v reflect.Value, t reflect.Type) reflect.Value {
	if v.Type() == t {
		return v
	}

	if v.Type().AssignableTo(t) {
		// Materialize the value with the static type t, e.g. a scanned
		// string stored into an `any` field, so callers always receive a
		// t-typed Value. Previously this case fell through to t.Elem(),
		// which panics for interface types.
		tv := reflect.New(t).Elem()
		tv.Set(v)
		return tv
	}

	switch t.Kind() {
	case reflect.Pointer:
		p := reflect.New(t.Elem())
		p.Elem().Set(pack(v, t.Elem()))
		return p
	case reflect.Slice:
		s := reflect.MakeSlice(t, 1, 1)
		s.Index(0).Set(pack(v, t.Elem()))
		return s
	default:
		panic("wand: cannot pack value of type " + v.Type().String() + " into " + t.String())
	}
}
// unpack strips pointer and slice layers from p, one Elem call at a time,
// and returns the first layer whose kind is neither reflect.Pointer nor
// reflect.Slice. A p that is not a pointer or slice is returned unchanged.
//
// NOTE(review): reflect.Value.Elem panics on slice values, so with
// reflect.Value this only unwraps pointers safely; callers in this file use
// it with reflect.Type.
func unpack[T packedKind[T]](p T) T {
	for {
		switch p.Kind() {
		case reflect.Pointer, reflect.Slice:
			p = p.Elem()
		default:
			return p
		}
	}
}
// isReader reports whether t, once pointer and slice wrappers are stripped,
// is exactly the io.Reader interface type.
func isReader(t reflect.Type) bool {
	base := unpack(t)
	return readerType == base
}

// isWriter reports whether t, once pointer and slice wrappers are stripped,
// is exactly the io.Writer interface type.
func isWriter(t reflect.Type) bool {
	base := unpack(t)
	return writerType == base
}
// isStruct reports whether t is a struct type once any pointer and slice
// wrappers are stripped.
func isStruct(t reflect.Type) bool {
	return unpack(t).Kind() == reflect.Struct
}
// canScan reports whether s can be parsed as a value of t's kind:
//   - bool: anything strconv.ParseBool accepts;
//   - signed/unsigned integers: a base-10 number that fits t's bit size;
//   - float32/float64: anything strconv.ParseFloat accepts at t's bit size;
//   - string: always.
//
// Every other kind (including uintptr, interfaces and structs) is reported
// as not scannable.
func canScan(t reflect.Type, s string) bool {
	bits := int(t.Size()) * 8
	var err error
	switch t.Kind() {
	case reflect.String:
		return true
	case reflect.Bool:
		_, err = strconv.ParseBool(s)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		_, err = strconv.ParseInt(s, 10, bits)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		_, err = strconv.ParseUint(s, 10, bits)
	case reflect.Float32, reflect.Float64:
		_, err = strconv.ParseFloat(s, bits)
	default:
		return false
	}
	return err == nil
}
// scan parses s into a new value of type t and returns it boxed in an any.
//
// Bools, base-10 integers and floats are parsed with strconv at t's bit
// size; every other kind is produced by converting the string s to t, which
// panics for kinds a string cannot be converted to.
//
// Parse errors are deliberately ignored: on failure the value strconv
// returns (zero, or the clamped extreme on a range error) is stored.
// NOTE(review): presumably callers validate s with canScan first — confirm.
func scan(t reflect.Type, s string) any {
	out := reflect.New(t).Elem()
	bits := int(t.Size()) * 8
	switch t.Kind() {
	case reflect.Bool:
		b, _ := strconv.ParseBool(s)
		out.SetBool(b)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		n, _ := strconv.ParseInt(s, 10, bits)
		out.SetInt(n)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		n, _ := strconv.ParseUint(s, 10, bits)
		out.SetUint(n)
	case reflect.Float32, reflect.Float64:
		x, _ := strconv.ParseFloat(s, bits)
		out.SetFloat(x)
	default:
		out.Set(reflect.ValueOf(s).Convert(t))
	}
	return out.Interface()
}
// fields returns the option fields declared by each struct type in s,
// concatenated in the order the types are given. It returns nil when s is
// empty.
//
// For a single struct type, every exported field whose type, after
// stripping pointer and slice wrappers, is a bool, an int*/uint* (not
// uintptr), a float32/float64, a string, or an empty interface becomes one
// field named after the kebab-cased Go field name. Fields of nested struct
// type are expanded recursively:
//   - an embedded (anonymous) struct contributes its fields unchanged;
//   - a named struct field prefixes them with its own name ("outer-inner")
//     and marks them as accepting multiple values when the struct field
//     itself is a slice.
//
// Within one struct type, names are unique. A directly declared field
// overrides a promoted one of the same name; otherwise the later field wins.
// The result is in first-appearance order (promoted fields first), so it is
// deterministic. The previous implementation ranged over a map, which made
// the order random.
//
// Unexported fields are skipped because reflection cannot set them. The
// exception is embedded structs, whose fields are still promoted. A struct
// type that is already being expanded higher up the recursion is skipped,
// so self-referential types (e.g. a Next *Node field) do not recurse
// forever.
func fields(s ...reflect.Type) []field {
	var f []field
	for _, t := range s {
		f = append(f, collectStructFields(t, make(map[reflect.Type]bool))...)
	}
	return f
}

// collectStructFields returns the deduplicated option fields of the struct
// type t. visiting holds the struct types currently being expanded on the
// recursion path; t is part of it for the duration of the call.
func collectStructFields(t reflect.Type, visiting map[reflect.Type]bool) []field {
	visiting[t] = true
	defer delete(visiting, t)

	var anonFields, plainFields []field
	for i := 0; i < t.NumField(); i++ {
		sf := t.Field(i)
		am := acceptsMultiple(sf.Type)
		ft := unpack(sf.Type)

		embeddedStruct := sf.Anonymous && ft.Kind() == reflect.Struct
		if !sf.IsExported() && !embeddedStruct {
			// reflection cannot set unexported fields
			continue
		}

		name := strcase.ToKebab(sf.Name)
		switch ft.Kind() {
		case reflect.Bool,
			reflect.Int,
			reflect.Int8,
			reflect.Int16,
			reflect.Int32,
			reflect.Int64,
			reflect.Uint,
			reflect.Uint8,
			reflect.Uint16,
			reflect.Uint32,
			reflect.Uint64,
			reflect.Float32,
			reflect.Float64,
			reflect.String:
			plainFields = append(plainFields, field{name: name, typ: ft, acceptsMultiple: am})
		case reflect.Interface:
			// non-empty interfaces are not treated as options
			if ft.NumMethod() == 0 {
				plainFields = append(plainFields, field{name: name, typ: ft, acceptsMultiple: am})
			}
		case reflect.Struct:
			if visiting[ft] {
				// cycle: this struct type is already being expanded
				continue
			}
			nested := collectStructFields(ft, visiting)
			if sf.Anonymous {
				anonFields = append(anonFields, nested...)
				continue
			}
			for j := range nested {
				nested[j].name = name + "-" + nested[j].name
				nested[j].acceptsMultiple = nested[j].acceptsMultiple || am
			}
			plainFields = append(plainFields, nested...)
		}
	}
	return mergeFieldsByName(anonFields, plainFields)
}

// mergeFieldsByName merges promoted (anon) and directly declared (plain)
// fields into a list with unique names, in first-appearance order. Plain
// fields override promoted ones; within a group a later field overrides an
// earlier one with the same name.
func mergeFieldsByName(anon, plain []field) []field {
	pos := make(map[string]int)
	var merged []field
	for _, group := range [][]field{anon, plain} {
		for _, fi := range group {
			if i, ok := pos[fi.name]; ok {
				merged[i] = fi
				continue
			}
			pos[fi.name] = len(merged)
			merged = append(merged, fi)
		}
	}
	return merged
}
// boolFields returns, in their original order, the fields of f whose type
// is of kind bool. It returns nil when there are none; f is not modified.
func boolFields(f []field) []field {
	var flags []field
	for i := range f {
		if f[i].typ.Kind() != reflect.Bool {
			continue
		}
		flags = append(flags, f[i])
	}
	return flags
}
// mapFields collects the option fields of every struct parameter of impl's
// function type and groups them by option name. A name declared by several
// struct parameters maps to one entry per declaration, in parameter order.
//
// impl must be a function, possibly wrapped in pointers/slices;
// structParameters calls NumIn on the unwrapped type, which panics for
// non-function types.
func mapFields(impl any) map[string][]field {
	fnType := unpack(reflect.ValueOf(impl).Type())
	byName := make(map[string][]field)
	for _, fi := range fields(structParameters(fnType)...) {
		byName[fi.name] = append(byName[fi.name], fi)
	}
	return byName
}
// filterParameters returns, in parameter order, the input parameter types
// of the function type t for which f reports true. Each type has its
// pointer and slice wrappers stripped (via unpack) both before being passed
// to f and in the result.
func filterParameters(t reflect.Type, f func(reflect.Type) bool) []reflect.Type {
	var matched []reflect.Type
	for i, n := 0, t.NumIn(); i < n; i++ {
		if base := unpack(t.In(i)); f(base) {
			matched = append(matched, base)
		}
	}
	return matched
}
// positionalParameters returns the unwrapped types of the function type t's
// parameters that are not structs, in parameter order.
func positionalParameters(t reflect.Type) []reflect.Type {
	notStruct := func(p reflect.Type) bool {
		return p.Kind() != reflect.Struct
	}
	return filterParameters(t, notStruct)
}
// ioParameters returns two index lists into p: the positions whose
// (unwrapped) type is io.Reader, and the positions whose type is io.Writer.
// Reader matches are checked first; the two lists never share an index.
func ioParameters(p []reflect.Type) ([]int, []int) {
	var in, out []int
	for i := range p {
		if isReader(p[i]) {
			in = append(in, i)
			continue
		}
		if isWriter(p[i]) {
			out = append(out, i)
		}
	}
	return in, out
}
// structParameters returns the unwrapped types of the function type t's
// parameters that are structs, in parameter order.
func structParameters(t reflect.Type) []reflect.Type {
	isStructKind := func(p reflect.Type) bool {
		return p.Kind() == reflect.Struct
	}
	return filterParameters(t, isStructKind)
}
// compatibleTypes reports whether every type in t belongs to the same value
// class. The classes are:
//   - bool;
//   - signed integers (int, int8-int64);
//   - unsigned integers (uint, uint8-uint64);
//   - float32/float64;
//   - string;
//   - interface types with no methods.
//
// A type outside these classes (including uintptr, structs and non-empty
// interfaces) is compatible with nothing, not even itself. Fewer than two
// types are trivially compatible.
//
// Previously only t[0] and t[1] were compared and any further types were
// ignored. Now every type is compared against t[0]. That is sufficient
// because belonging to the same class is an equivalence relation.
func compatibleTypes(t ...reflect.Type) bool {
	if len(t) < 2 {
		return true
	}

	// kindClass maps a type to a class id; 0 means "compatible with nothing".
	kindClass := func(x reflect.Type) int {
		switch x.Kind() {
		case reflect.Bool:
			return 1
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			return 2
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			return 3
		case reflect.Float32, reflect.Float64:
			return 4
		case reflect.String:
			return 5
		case reflect.Interface:
			if x.NumMethod() == 0 {
				return 6
			}
		}
		return 0
	}

	first := kindClass(t[0])
	if first == 0 {
		return false
	}
	for _, ti := range t[1:] {
		if kindClass(ti) != first {
			return false
		}
	}
	return true
}
// acceptsMultiple reports whether a value of type t can hold more than one
// item. That is true when t is a slice, or a chain of pointers that ends in
// a slice.
//
// The previous version had an unreachable reflect.Slice case after an
// early slice return; this loop expresses the same rule directly.
func acceptsMultiple(t reflect.Type) bool {
	for t.Kind() == reflect.Pointer {
		t = t.Elem()
	}
	return t.Kind() == reflect.Slice
}