From b0ff605b252d40c638ad447320dba2d69620af2c Mon Sep 17 00:00:00 2001 From: Arpad Ryszka Date: Sun, 31 Aug 2025 00:40:47 +0200 Subject: [PATCH] field binding --- field.go | 392 +++++++++++++++++++++++++++++++++++++++-------- field_test.go | 409 +++++++++++++++++++++++++++++++++++++++++++++++++- lib.go | 29 ++-- notes.txt | 10 +- scalar.go | 8 +- type.go | 77 +++++++--- 6 files changed, 822 insertions(+), 103 deletions(-) diff --git a/field.go b/field.go index 7834d3b..4dfbf4a 100644 --- a/field.go +++ b/field.go @@ -5,12 +5,48 @@ import ( "github.com/iancoleman/strcase" "reflect" "unicode" + "strings" ) +func pathString(f Field) string { + return strings.Join(f.path, ":") +} + +func nameFromPath(p []string) string { + var pp []string + for _, pi := range p { + pp = append(pp, strcase.ToKebab(pi)) + } + + return strings.Join(pp, "-") +} + func exported(name string) bool { return unicode.IsUpper([]rune(name)[0]) } +func filterFields(predicate func(Field) bool, f []Field) ([]Field, []Field) { + var yes, no []Field + for _, fi := range f { + if predicate(fi) { + yes = append(yes, fi) + continue + } + + no = append(no, fi) + } + + return yes, no +} + +func fieldHasCircRef(f Field) bool { + return hasCircularReference(reflect.ValueOf(f.Value())) +} + +func hasPath(f Field) bool { + return len(f.path) > 0 +} + func fields(t reflect.Type) []Field { t = unpackType(t, pointer) list := t.Kind() == reflect.Slice @@ -65,62 +101,41 @@ func fields(t reflect.Type) []Field { return f } -func fieldFromValue(name string, v reflect.Value) Field { +func fieldFromValue(v reflect.Value) Field { var fi Field - fi.name = strcase.ToKebab(name) - fi.path = []string{name} fi.value = v.Interface() fi.isBool = v.Kind() == reflect.Bool return fi } -func scalarMapFields(v reflect.Value) []Field { - var f []Field - for _, key := range v.MapKeys() { - value := v.MapIndex(key) - name := key.Interface().(string) - if value.Kind() == reflect.Slice { - for i := 0; i < value.Len(); i++ { - fi := fieldFromValue(name, value.Index(i)) - fi.free = true - f = append(f, fi) - } - } else { - fi := fieldFromValue(name, value) - fi.free = true - f = append(f, fi) - } - } - - return f -} - func prependFieldName(name string, f []Field) { for i := range f { + if f[i].name == "" { + f[i].name = strcase.ToKebab(name) + f[i].path = []string{name} + continue + } + f[i].name = fmt.Sprintf("%s-%s", strcase.ToKebab(name), f[i].name) f[i].path = append([]string{name}, f[i].path...) } } -func listFieldValues(fieldName string, l reflect.Value) []Field { +func freeFields(f []Field) { + for i := range f { + f[i].free = true + } +} + +func scalarMapFields(v reflect.Value) []Field { var f []Field - for i := 0; i < l.Len(); i++ { - item := l.Index(i) - item = unpackValue(item, pointer|iface|anytype) - switch { - case isScalar(item.Type()): - f = append(f, fieldFromValue(fieldName, item)) - case item.Kind() == reflect.Slice: - f = append(f, listFieldValues(fieldName, item)...) - case isScalarMap(item.Type()): - mf := fieldValues(item) - prependFieldName(fieldName, mf) - f = append(f, mf...) - case item.Kind() == reflect.Struct: - sf := fieldValues(item) - prependFieldName(fieldName, sf) - f = append(f, sf...) - } + for _, key := range v.MapKeys() { + name := key.Interface().(string) + value := v.MapIndex(key) + fk := fieldValues(value) + prependFieldName(name, fk) + freeFields(fk) + f = append(f, fk...) } return f @@ -129,6 +144,11 @@ func listFieldValues(fieldName string, l reflect.Value) []Field { func fieldValues(v reflect.Value) []Field { var f []Field v = unpackValue(v, pointer|anytype|iface) + t := v.Type() + if isScalar(t) { + return []Field{fieldFromValue(v)} + } + if v.Kind() == reflect.Slice { for i := 0; i < v.Len(); i++ { f = append(f, fieldValues(v.Index(i))...) @@ -137,7 +157,6 @@ func fieldValues(v reflect.Value) []Field { return f } - t := v.Type() if isScalarMap(t) { return scalarMapFields(v) } @@ -155,30 +174,234 @@ func fieldValues(v reflect.Value) []Field { vfi := v.Field(i) vfi = unpackValue(vfi, pointer|iface|anytype) switch { - case isScalar(vfi.Type()): - f = append(f, fieldFromValue(tfi.Name, vfi)) - case vfi.Kind() == reflect.Slice: - f = append(f, listFieldValues(tfi.Name, vfi)...) - case isScalarMap(vfi.Type()): - mf := fieldValues(vfi) - prependFieldName(tfi.Name, mf) - f = append(f, mf...) - case vfi.Kind() == reflect.Struct: - sf := fieldValues(vfi) - if !tfi.Anonymous { - prependFieldName(tfi.Name, sf) - } - - f = append(f, sf...) + case vfi.Kind() == reflect.Struct && tfi.Anonymous: + ff := fieldValues(vfi) + f = append(f, ff...) + default: + ff := fieldValues(vfi) + prependFieldName(tfi.Name, ff) + f = append(f, ff...) } } return f } +func takeFieldValues(f []Field) []any { + var v []any + for _, fi := range f { + v = append(v, fi.Value()) + } + + return v +} + +func bindScalarField(receiver reflect.Value, values []Field) bool { + v := takeFieldValues(values) + if !receiver.CanSet() { + return bindScalar(receiver, v) + } + + rv, ok := allocate(receiver.Type(), len(v)) + if !ok { + return false + } + + if ok = bindScalar(rv, v); ok { + receiver.Set(rv) + } + + return ok +} + +func bindListField(receiver reflect.Value, values []Field) bool { + if receiver.Len() < len(values) && !receiver.CanSet() { + return false + } + + if receiver.Len() < len(values) { + newList, ok := allocate(receiver.Type(), len(values)) + if !ok { + return false + } + + reflect.Copy(newList, receiver) + receiver.Set(newList) + } + + for i := range values { + if !bindField(receiver.Index(i), values[i:i+1]) { + return false + } + } + + return true +} + +func bindMapField(receiver reflect.Value, values []Field) bool { + for _, v := range values { + if len(v.path) > 1 { + return false + } + } + + if receiver.IsZero() && !receiver.CanSet() { + return false + } + + var key string + fp, nfp := filterFields(hasPath, values) + if len(fp) > 0 { + key = fp[0].path[0] + } else { + key = nfp[0].name + } + + v := takeFieldValues(values) + t := receiver.Type() + kt := t.Key() + vt := t.Elem() + kv, ok := bindScalarCreate(kt, []any{key}) + if !ok { + return false + } + + vv, ok := bindScalarCreate(vt, v) + if !ok { + return false + } + + if receiver.IsZero() { + rv, ok := allocate(receiver.Type(), 1) + if !ok { + return false + } + + receiver.Set(rv) + } + + receiver.SetMapIndex(kv, vv) + return true +} + +func trimNameAndPath(name string, values []Field) []Field { + name = strcase.ToKebab(name) + v := make([]Field, len(values)) + copy(v, values) + for i := range v { + if len(v[i].path) > 0 { + v[i].path = v[i].path[1:] + } + + if v[i].name == name { + v[i].name = "" + } + + if strings.HasPrefix(v[i].name, fmt.Sprintf("%s-", name)) { + v[i].name = v[i].name[len(name) + 1:] + } + } + + return v +} + +func bindStructField(receiver reflect.Value, values []Field) bool { + var name, pathName string + fp, nfp := filterFields(hasPath, values) + if len(fp) > 0 { + pathName = fp[0].path[0] + } + + if len(nfp) > 0 { + name = nfp[0].name + } + + t := receiver.Type() + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + if sf.Anonymous { + continue + } + + if sf.Name == pathName { + values = trimNameAndPath(pathName, values) + return bindField(receiver.Field(i), values) + } + + sfn := strcase.ToKebab(sf.Name) + if name == sfn || + strings.HasPrefix(name, fmt.Sprintf("%s-", sfn)) { + values = trimNameAndPath(sfn, values) + return bindField(receiver.Field(i), values) + } + } + + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + if !sf.Anonymous { + continue + } + + if bindField(receiver.Field(i), values) { + return true + } + } + + return false +} + +func bindField(receiver reflect.Value, values []Field) bool { + if values[0].name == "" && len(values[0].path) == 0 { + return bindScalarField(receiver, values) + } + + listReceiver := unpackValue(receiver, pointer|iface|anytype) + if listReceiver.Kind() == reflect.Slice { + return bindListField(listReceiver, values) + } + + fieldReceiver := unpackValue(receiver, pointer|slice|iface) + if isScalarMap(fieldReceiver.Type()) { + return bindMapField(fieldReceiver, values) + } + + if fieldReceiver.Kind() == reflect.Struct { + return bindStructField(fieldReceiver, values) + } + + if !receiver.CanSet() { + return false + } + + t := receiver.Type() + ut := unpackType(t, pointer) + if ut.Kind() == reflect.Slice || + isScalarMap(ut) || + ut.Kind() == reflect.Struct { + l := 1 + if ut.Kind() == reflect.Slice { + l = len(values) + } + + rv, ok := allocate(t, l) + if !ok { + return false + } + + if !bindField(rv, values) { + return false + } + + receiver.Set(rv) + return true + } + + return false +} + func fieldsReflect[T any]() []Field { t := reflect.TypeFor[T]() - if hasCircularType(nil, t) { + if hasCircularType(t) { return nil } @@ -187,9 +410,58 @@ func fieldsReflect[T any]() []Field { func fieldValuesReflect(structure any) []Field { v := reflect.ValueOf(structure) - if hasCircularReference(nil, v) { + if hasCircularReference(v) { return nil } return fieldValues(v) } + +func groupFields(f []Field) [][]Field { + withPath, withoutPath := filterFields(hasPath, f) + paths := make(map[string][]Field) + for _, ff := range withPath { + ps := pathString(ff) + paths[ps] = append(paths[ps], ff) + } + + names := make(map[string][]Field) + for _, ff := range withoutPath { + names[ff.name] = append(names[ff.name], ff) + } + + var groups [][]Field + for _, group := range paths { + nfp := nameFromPath(group[0].path) + group = append(group, names[nfp]...) + delete(names, nfp) + groups = append(groups, group) + } + + for _, group := range names { + groups = append(groups, group) + } + + return groups +} + +func bindFieldsReflect(structure any, values []Field) []Field { + receiver := reflect.ValueOf(structure) + if hasCircularReference(receiver) { + return values + } + + if !acceptsFields(receiver.Type()) { + return values + } + + unmatched, try := filterFields(fieldHasCircRef, values) + groups := groupFields(try) + for _, g := range groups { + if !bindField(receiver, g) { + unmatched = append(unmatched, g...) + } + } + + return unmatched +} diff --git a/field_test.go b/field_test.go index 355908a..a9dd0d2 100644 --- a/field_test.go +++ b/field_test.go @@ -128,7 +128,7 @@ func TestField(t *testing.T) { if len(f) != 2 || f[0].Name() != "foo" || f[0].Value() != 21 || !f[0].Free() || f[1].Name() != "bar" || f[1].Value() != 42 || !f[1].Free() { - t.Fatal(notation.Println(f)) + t.Fatal(notation.Sprint(f)) } }) @@ -151,14 +151,14 @@ func TestField(t *testing.T) { f[1].Name() != "foo" || f[1].Value() != 36 || !f[1].Free() || f[2].Name() != "bar" || f[2].Value() != 42 || !f[2].Free() || f[3].Name() != "bar" || f[3].Value() != 72 || !f[3].Free() { - t.Fatal(notation.Println(f)) + t.Fatal(notation.Sprint(f)) } }) t.Run("not a struct", func(t *testing.T) { v := []int{21, 42, 84} f := bind.FieldValues(v) - if len(f) != 0 { + if len(f) != 3 || f[0].Value() != 21 || f[1].Value() != 42 || f[2].Value() != 84 { t.Fatal() } }) @@ -168,10 +168,11 @@ func TestField(t *testing.T) { Foo int bar string } + v := s{Foo: 42, bar: "baz"} f := bind.FieldValues(v) if len(f) != 1 || f[0].Name() != "foo" { - t.Fatal() + t.Fatal(notation.Sprint(f)) } }) @@ -275,4 +276,404 @@ func TestField(t *testing.T) { } }) }) + + t.Run("bind fields", func(t *testing.T) { + t.Run("no circular receiver", func(t *testing.T) { + type s struct{Foo *s; Bar int} + var v s + if len(bind.BindFields(&v, bind.NamedValue("bar", 42))) != 1 { + t.Fatal() + } + }) + + t.Run("no circular valeu", func(t *testing.T) { + type s struct{Foo int} + type p *p + var v p + if len(bind.BindFields(&v, bind.NamedValue("foo", 42))) != 1 { + t.Fatal() + } + }) + + t.Run("set by name", func(t *testing.T) { + type s struct{FooBar int} + var v s + u := bind.BindFields(&v, bind.NamedValue("foo-bar", 42)) + if len(u) != 0 || v.FooBar != 42 { + t.Fatal(notation.Sprint(u), notation.Sprint(v)) + } + }) + + t.Run("set by path", func(t *testing.T) { + type s struct{FooBar int} + var v s + u := bind.BindFields(&v, bind.ValueByPath([]string{"FooBar"}, 42)) + if len(u) != 0 || v.FooBar != 42 { + t.Fatal(notation.Sprint(u), notation.Sprint(v)) + } + }) + + t.Run("fail to bind", func(t *testing.T) { + type s struct{Foo int; Bar int} + var v s + u := bind.BindFields( + &v, + bind.NamedValue("foo", 21), + bind.NamedValue("baz", 42), + ) + + if len(u) != 1 || u[0].Name() != "baz" { + t.Fatal() + } + }) + + t.Run("bind list", func(t *testing.T) { + type s struct{Foo []int} + var v s + u := bind.BindFields( + &v, + bind.NamedValue("foo", 21), + bind.NamedValue("foo", 42), + bind.NamedValue("foo", 84), + ) + + if len(u) != 0 || v.Foo[0] != 21 || v.Foo[1] != 42 || v.Foo[2] != 84 { + t.Fatal(notation.Sprint(u), notation.Sprint(v)) + } + }) + + t.Run("bind list of structs", func(t *testing.T) { + type s struct{Foo []struct{Bar int}} + var v s + u := bind.BindFields( + &v, + bind.NamedValue("foo-bar", 21), + bind.NamedValue("foo-bar", 42), + bind.NamedValue("foo-bar", 84), + ) + + if len(u) != 0 || len(v.Foo) != 3 || v.Foo[0].Bar != 21 || v.Foo[1].Bar != 42 || v.Foo[2].Bar != 84 { + t.Fatal() + } + }) + + t.Run("bind list in list", func(t *testing.T) { + type s struct{Foo []struct{Bar []int}} + var v s + u := bind.BindFields( + &v, + bind.NamedValue("foo-bar", 21), + bind.NamedValue("foo-bar", 42), + bind.NamedValue("foo-bar", 84), + ) + + if len(u) != 0 || len(v.Foo) != 3 || + len(v.Foo[0].Bar) != 1 || v.Foo[0].Bar[0] != 21 || + len(v.Foo[1].Bar) != 1 || v.Foo[1].Bar[0] != 42 || + len(v.Foo[2].Bar) != 1 || v.Foo[2].Bar[0] != 84 { + t.Fatal() + } + }) + + t.Run("list receiver", func(t *testing.T) { + var l []struct{Foo int} + u := bind.BindFields( + &l, + bind.NamedValue("foo", 21), + bind.NamedValue("foo", 42), + bind.NamedValue("foo", 84), + ) + + if len(u) != 0 || len(l) != 3 || l[0].Foo != 21 || l[1].Foo != 42 || l[2].Foo != 84 { + t.Fatal() + } + }) + + t.Run("list short and cannot be set", func(t *testing.T) { + type ( + s0 struct{Bar int} + s1 struct{Foo []s0} + ) + + v := s1{[]s0{{1}, {2}}} + u := bind.BindFields( + v, + bind.NamedValue("foo-bar", 21), + bind.NamedValue("foo-bar", 42), + bind.NamedValue("foo-bar", 84), + ) + + if len(u) != 3 || len(v.Foo) != 2 || v.Foo[0].Bar != 1 || v.Foo[1].Bar != 2 { + t.Fatal(notation.Sprint(u), notation.Sprint(v)) + } + }) + + t.Run("list short and gets reset", func(t *testing.T) { + type ( + s0 struct{Bar int} + s1 struct{Foo []s0} + ) + + v := s1{[]s0{{1}, {2}}} + u := bind.BindFields( + &v, + bind.NamedValue("foo-bar", 21), + bind.NamedValue("foo-bar", 42), + bind.NamedValue("foo-bar", 84), + ) + + if len(u) != 0 || len(v.Foo) != 3 || v.Foo[0].Bar != 21 || v.Foo[1].Bar != 42 || v.Foo[2].Bar != 84 { + t.Fatal(notation.Sprint(u), notation.Sprint(v)) + } + }) + + t.Run("list has invalid type", func(t *testing.T) { + type ( + s0 struct{Bar chan int} + s1 struct{Foo []s0} + ) + + v := s1{[]s0{{nil}, {nil}}} + u := bind.BindFields( + &v, + bind.NamedValue("foo-bar", 21), + bind.NamedValue("foo-bar", 42), + bind.NamedValue("foo-bar", 84), + ) + + if len(u) != 3 { + t.Fatal(notation.Sprint(u), notation.Sprint(v)) + } + }) + + t.Run("bind scalar map", func(t *testing.T) { + m := make(map[string]int) + u := bind.BindFields(m, bind.NamedValue("foo-bar", 21), bind.NamedValue("baz-qux", 42)) + if len(u) != 0 || len(m) != 2 || m["foo-bar"] != 21 || m["baz-qux"] != 42 { + t.Fatal() + } + }) + + t.Run("scalar map key conversion", func(t *testing.T) { + type key string + m := make(map[key]int) + u := bind.BindFields(m, bind.NamedValue("foo", 42)) + if len(u) != 0 || len(m) != 1 || m["foo"] != 42 { + t.Fatal() + } + }) + + t.Run("scalar map pointer key", func(t *testing.T) { + m := make(map[*string]int) + u := bind.BindFields(m, bind.NamedValue("foo", 42)) + if len(u) != 0 || len(m) != 1 { + t.Fatal() + } + + for key, value := range m { + if *key != "foo" || value != 42 { + t.Fatal() + } + } + }) + + t.Run("scalar map list value", func(t *testing.T) { + m := make(map[string][]int) + u := bind.BindFields(m, bind.NamedValue("foo", 21), bind.NamedValue("foo", 42), bind.NamedValue("foo", 84)) + if len(u) != 0 || len(m) != 1 || !slices.Equal(m["foo"], []int{21, 42, 84}) { + t.Fatal(notation.Sprint(u), notation.Sprint(m)) + } + }) + + t.Run("scalar map pointer value", func(t *testing.T) { + m := make(map[string]*int) + u := bind.BindFields(m, bind.NamedValue("foo", 42)) + if len(u) != 0 || len(m) != 1 { + t.Fatal() + } + + for key, value := range m { + if key != "foo" || *value != 42 { + t.Fatal() + } + } + }) + + t.Run("scalar map list pointer value", func(t *testing.T) { + m := make(map[string]*[]int) + u := bind.BindFields(m, bind.NamedValue("foo", 21), bind.NamedValue("foo", 42), bind.NamedValue("foo", 84)) + if len(u) != 0 || len(m) != 1 || !slices.Equal(*m["foo"], []int{21, 42, 84}) { + t.Fatal() + } + }) + + t.Run("allocate scalar map", func(t *testing.T) { + type s struct{Foo map[string]int} + var v s + u := bind.BindFields(&v, bind.NamedValue("foo-bar", 42)) + if len(u) != 0 || len(v.Foo) != 1 || v.Foo["bar"] != 42 { + t.Fatal() + } + }) + + t.Run("scalar map addressing via path", func(t *testing.T) { + type s struct{Foo map[string]int} + var v s + u := bind.BindFields(&v, bind.ValueByPath([]string{"Foo", "Bar"}, 42)) + if len(u) != 0 || len(v.Foo) != 1 || v.Foo["Bar"] != 42 { + t.Fatal() + } + }) + + t.Run("scalar map and too long field path", func(t *testing.T) { + m := make(map[string]int) + u := bind.BindFields(m, bind.ValueByPath([]string{"foo", "bar"}, 42)) + if len(u) != 1 { + t.Fatal() + } + }) + + t.Run("scalar map cannot be set", func(t *testing.T) { + type s struct{Foo map[string]int} + var v s + u := bind.BindFields(v, bind.NamedValue("foo-bar", 42)) + if len(u) != 1 { + t.Fatal() + } + }) + + t.Run("scalar map with wrong value type", func(t *testing.T) { + m := make(map[string]int) + u := bind.BindFields(m, bind.NamedValue("foo", "bar")) + if len(u) != 1 { + t.Fatal() + } + }) + + t.Run("struct fields", func(t *testing.T) { + type s struct{Foo int; Bar struct { Baz string }} + var v s + u := bind.BindFields( + &v, + bind.NamedValue("foo", 42), + bind.NamedValue("bar-baz", "qux"), + ) + + if len(u) != 0 || v.Foo != 42 || v.Bar.Baz != "qux" { + t.Fatal() + } + }) + + t.Run("non-existing field", func(t *testing.T) { + type s struct{Foo int; Bar struct { Baz string }} + var v s + u := bind.BindFields( + &v, + bind.NamedValue("foo", 42), + bind.NamedValue("bar-qux", "qux"), + ) + + if len(u) != 1 || u[0].Name() != "bar-qux" || v.Foo != 42 || v.Bar.Baz != "" { + t.Fatal() + } + }) + + t.Run("too many fields", func(t *testing.T) { + type s struct{Foo int; Bar struct { Baz string }} + var v s + u := bind.BindFields( + &v, + bind.NamedValue("foo", 42), + bind.NamedValue("bar-baz", "qux"), + bind.NamedValue("bar-baz", "quux"), + ) + + if len(u) != 2 || u[0].Name() != "bar-baz" || u[1].Name() != "bar-baz" || v.Foo != 42 { + t.Fatal() + } + }) + + t.Run("pointer fields", func(t *testing.T) { + type s struct{Foo *int; Bar *struct { Baz *string }; Qux *[]struct{Quux string}} + var v s + u := bind.BindFields( + &v, + bind.NamedValue("foo", 42), + bind.NamedValue("bar-baz", "qux"), + bind.NamedValue("qux-quux", "corge"), + ) + + if len(u) != 0 || *v.Foo != 42 || *v.Bar.Baz != "qux" || len(*v.Qux) != 1 || (*v.Qux)[0].Quux != "corge" { + t.Fatal() + } + }) + + t.Run("unsupported pointer fields", func(t *testing.T) { + type s struct{Foo *int; Bar *chan int} + var v s + u := bind.BindFields( + &v, + bind.NamedValue("foo", 42), + bind.NamedValue("bar", "qux"), + bind.NamedValue("bar-baz", "corge"), + ) + + if len(u) != 2 || u[0].Name() != "bar" || u[1].Name() != "bar-baz" || *v.Foo != 42 { + t.Fatal() + } + }) + + t.Run("struct fields by path", func(t *testing.T) { + type s struct{Foo int; Bar struct { Baz string }} + var v s + u := bind.BindFields( + &v, + bind.ValueByPath([]string{"Foo"}, 42), + bind.ValueByPath([]string{"Bar", "Baz"}, "qux"), + ) + + if len(u) != 0 || v.Foo != 42 || v.Bar.Baz != "qux" { + t.Fatal() + } + }) + + t.Run("cannot set field", func(t *testing.T) { + type s struct{Foo int} + var v s + u := bind.BindFields(v, bind.NamedValue("foo", 42)) + if len(u) != 1 { + t.Fatal() + } + }) + + t.Run("struct with anonymous field", func(t *testing.T) { + type ( + s0 struct{Foo int} + s1 struct{s0} + ) + + var v s1 + u := bind.BindFields(&v, bind.NamedValue("foo", 42)) + if len(u) != 0 || v.Foo != 42 { + t.Fatal() + } + }) + + t.Run("receiver cannot be set", func(t *testing.T) { + type s struct{Foo *struct{Bar int}} + var v s + u := bind.BindFields(v, bind.NamedValue("foo-bar", 42)) + if len(u) != 1 { + t.Fatal() + } + }) + + t.Run("receiver not supported", func(t *testing.T) { + v := make(chan int) + u := bind.BindFields(&v, bind.NamedValue("foo-bar", 42)) + if len(u) != 1 { + t.Fatal() + } + }) + }) } diff --git a/lib.go b/lib.go index 82e9844..eaa2c88 100644 --- a/lib.go +++ b/lib.go @@ -5,10 +5,10 @@ package bind type Field struct { - name string path []string - isBool bool + name string list bool + isBool bool free bool value any } @@ -22,11 +22,11 @@ func BindScalarCreate[T any](value ...any) (T, bool) { return bindScalarCreateReflect[T](value) } -func FieldValue(path []string, value any) Field { +func ValueByPath(path []string, value any) Field { return Field{path: path, value: value} } -func FieldValueByName(name string, value any) Field { +func NamedValue(name string, value any) Field { return Field{name: name, value: value} } @@ -36,26 +36,37 @@ func (f Field) Path() []string { return p } -func (f Field) Name() string { return f.name } +func (f Field) Name() string { + if f.name != "" || len(f.path) == 0 { + return f.name + } + + return nameFromPath(f.path) +} + func (f Field) List() bool { return f.list } func (f Field) Bool() bool { return f.isBool } func (f Field) Free() bool { return f.free } func (f Field) Value() any { return f.value } +// it does not return fields with free keys, however, this should be obvious +// non-struct and non-named map values return unnamed fields func Fields[T any]() []Field { return fieldsReflect[T]() } +// the list and bool flags are not set because it is not possible if they are defined by the root type func FieldValues(structure any) []Field { return fieldValuesReflect(structure) } -func BindFields(structure any, values []Field) []Field { - return nil +func BindFields(structure any, values ...Field) []Field { + return bindFieldsReflect(structure, values) } -func BindFieldsCreate[T any](values []Field) (any, []Field) { - return nil, nil +func BindFieldsCreate[T any](values ...Field) (T, []Field) { + var t T + return t, nil } func AcceptsScalar[T any]() bool { diff --git a/notes.txt b/notes.txt index 1908028..8b347f2 100644 --- a/notes.txt +++ b/notes.txt @@ -1,3 +1,7 @@ -maps with string keys and scalar, or list of scalar, values can be supported -fields being set by name, can be converted first to use path -add time and duration support. Can there be any other important quasi-scalars? +add time and duration support +don't change the input value when bind returns false +track down the cases when reflect can panic +test: +- repeated bindings +- cases of allocations +- preallocated and unallocated list sizes diff --git a/scalar.go b/scalar.go index 761bbae..7c5664b 100644 --- a/scalar.go +++ b/scalar.go @@ -65,12 +65,12 @@ func bindScalarCreate(t reflect.Type, values []any) (reflect.Value, bool) { func bindScalarReflect(receiver any, values []any) bool { v := reflect.ValueOf(receiver) - if hasCircularReference(nil, v) { + if hasCircularReference(v) { return false } for _, vi := range values { - if hasCircularReference(nil, reflect.ValueOf(vi)) { + if hasCircularReference(reflect.ValueOf(vi)) { return false } } @@ -80,13 +80,13 @@ func bindScalarReflect(receiver any, values []any) bool { func bindScalarCreateReflect[T any](values []any) (T, bool) { t := reflect.TypeFor[T]() - if hasCircularType(nil, t) { + if hasCircularType(t) { var tt T return tt, false } for _, vi := range values { - if hasCircularReference(nil, reflect.ValueOf(vi)) { + if hasCircularReference(reflect.ValueOf(vi)) { var tt T return tt, false } diff --git a/type.go b/type.go index 7f282f5..bd517a4 100644 --- a/type.go +++ b/type.go @@ -11,6 +11,8 @@ const ( iface ) +var typeCirc = make(map[reflect.Type]bool) + func (f unpackFlag) has(v unpackFlag) bool { return f&v > 0 } @@ -25,7 +27,7 @@ func setVisited[T comparable](visited map[T]bool, k T) map[T]bool { return s } -func hasCircularType(visited map[reflect.Type]bool, t reflect.Type) bool { +func checkHasCircularType(visited map[reflect.Type]bool, t reflect.Type) bool { if visited[t] { return true } @@ -33,11 +35,11 @@ func hasCircularType(visited map[reflect.Type]bool, t reflect.Type) bool { switch t.Kind() { case reflect.Pointer, reflect.Slice: visited = setVisited(visited, t) - return hasCircularType(visited, t.Elem()) + return checkHasCircularType(visited, t.Elem()) case reflect.Struct: visited = setVisited(visited, t) for i := 0; i < t.NumField(); i++ { - if hasCircularType(visited, t.Field(i).Type) { + if checkHasCircularType(visited, t.Field(i).Type) { return true } } @@ -48,8 +50,22 @@ func hasCircularType(visited map[reflect.Type]bool, t reflect.Type) bool { } } -func hasCircularReference(visited map[uintptr]bool, v reflect.Value) bool { - if hasCircularType(nil, v.Type()) { +func hasCircularType(t reflect.Type) bool { + if has, cached := typeCirc[t]; cached { + return has + } + + has := checkHasCircularType(nil, t) + typeCirc[t] = has + return has +} + +func checkHasCircularReference(visited map[uintptr]bool, v reflect.Value) bool { + if !v.IsValid() { + return false + } + + if hasCircularType(v.Type()) { return true } @@ -61,7 +77,7 @@ func hasCircularReference(visited map[uintptr]bool, v reflect.Value) bool { } visited = setVisited(visited, v.Pointer()) - return hasCircularReference(visited, v.Elem()) + return checkHasCircularReference(visited, v.Elem()) case reflect.Slice: p := v.Pointer() if visited[p] { @@ -70,7 +86,7 @@ func hasCircularReference(visited map[uintptr]bool, v reflect.Value) bool { visited = setVisited(visited, v.Pointer()) for i := 0; i < v.Len(); i++ { - if hasCircularReference(visited, v.Index(i)) { + if checkHasCircularReference(visited, v.Index(i)) { return true } } @@ -81,10 +97,10 @@ func hasCircularReference(visited map[uintptr]bool, v reflect.Value) bool { return false } - return hasCircularReference(visited, v.Elem()) + return checkHasCircularReference(visited, v.Elem()) case reflect.Struct: for i := 0; i < v.NumField(); i++ { - if hasCircularReference(visited, v.Field(i)) { + if checkHasCircularReference(visited, v.Field(i)) { return true } } @@ -95,6 +111,10 @@ func hasCircularReference(visited map[uintptr]bool, v reflect.Value) bool { return false } +func hasCircularReference(v reflect.Value) bool { + return checkHasCircularReference(nil, v) +} + func isAny(t reflect.Type) bool { return t.Kind() == reflect.Interface && t.NumMethod() == 0 } @@ -153,6 +173,10 @@ func unpackType(t reflect.Type, unpack unpackFlag) reflect.Type { } func unpackValue(v reflect.Value, unpack unpackFlag) reflect.Value { + if v.IsZero() { + return v + } + if unpack.has(pointer) && v.Kind() == reflect.Pointer { return unpackValue(v.Elem(), unpack) } @@ -161,11 +185,11 @@ func unpackValue(v reflect.Value, unpack unpackFlag) reflect.Value { return unpackValue(v.Index(0), unpack) } - if unpack.has(anytype) && isAny(v.Type()) && !v.IsNil() { + if unpack.has(anytype) && isAny(v.Type()) && !v.IsZero() && !v.IsNil() { return unpackValue(v.Elem(), unpack) } - if unpack.has(iface) && isInterface(v.Type()) && !v.IsNil() { + if unpack.has(iface) && isInterface(v.Type()) && !v.IsZero() && !v.IsNil() { return unpackValue(v.Elem(), unpack) } @@ -196,7 +220,7 @@ func bindable(t reflect.Type) bool { } func acceptsScalarChecked(t reflect.Type) bool { - if hasCircularType(nil, t) { + if hasCircularType(t) { return false } @@ -204,7 +228,7 @@ func acceptsScalarChecked(t reflect.Type) bool { } func acceptsFieldsChecked(t reflect.Type) bool { - if hasCircularType(nil, t) { + if hasCircularType(t) { return false } @@ -212,7 +236,7 @@ func acceptsFieldsChecked(t reflect.Type) bool { } func acceptsListChecked(t reflect.Type) bool { - if hasCircularType(nil, t) { + if hasCircularType(t) { return false } @@ -220,7 +244,7 @@ func acceptsListChecked(t reflect.Type) bool { } func bindableChecked(t reflect.Type) bool { - if hasCircularType(nil, t) { + if hasCircularType(t) { return false } @@ -249,11 +273,15 @@ func allocate(t reflect.Type, len int) (reflect.Value, bool) { return reflect.Zero(t), false } - if isAny(t) || isScalar(t) || t.Kind() == reflect.Struct || isScalarMap(t) { + if isAny(t) || isScalar(t) || t.Kind() == reflect.Struct { p := reflect.New(t) return p.Elem(), len == 1 } + if isScalarMap(t) { + return reflect.MakeMap(t), len == 1 + } + if t.Kind() == reflect.Slice { l := reflect.MakeSlice(t, len, len) for i := 0; i < len; i++ { @@ -268,13 +296,16 @@ func allocate(t reflect.Type, len int) (reflect.Value, bool) { return l, true } - // must be pointer - e, ok := allocate(t.Elem(), len) - if !ok { - return reflect.Zero(t), false + if t.Kind() == reflect.Pointer { + e, ok := allocate(t.Elem(), len) + if !ok { + return reflect.Zero(t), false + } + + p := reflect.New(t.Elem()) + p.Elem().Set(e) + return p, true } - p := reflect.New(t.Elem()) - p.Elem().Set(e) - return p, true + return reflect.Zero(t), false }