diff --git a/field.go b/field.go new file mode 100644 index 0000000..7834d3b --- /dev/null +++ b/field.go @@ -0,0 +1,195 @@ +package bind + +import ( + "fmt" + "github.com/iancoleman/strcase" + "reflect" + "unicode" +) + +func exported(name string) bool { + return unicode.IsUpper([]rune(name)[0]) +} + +func fields(t reflect.Type) []Field { + t = unpackType(t, pointer) + list := t.Kind() == reflect.Slice + t = unpackType(t, pointer|slice) + if t.Kind() != reflect.Struct { + return nil + } + + var f []Field + for i := 0; i < t.NumField(); i++ { + tfi := t.Field(i) + if !exported(tfi.Name) && !tfi.Anonymous { + continue + } + + if acceptsScalar(tfi.Type) { + var fi Field + ft := unpackType(tfi.Type, pointer|slice) + fi.isBool = ft.Kind() == reflect.Bool + fi.list = acceptsList(tfi.Type) + fi.name = strcase.ToKebab(tfi.Name) + fi.path = []string{tfi.Name} + f = append(f, fi) + continue + } + + ffi := fields(tfi.Type) + if !tfi.Anonymous { + for i := range ffi { + ffi[i].name = fmt.Sprintf( + "%s-%s", + strcase.ToKebab(tfi.Name), + ffi[i].name, + ) + + ffi[i].path = append( + []string{tfi.Name}, + ffi[i].path..., + ) + } + } + + f = append(f, ffi...) + } + + if list { + for i := range f { + f[i].list = true + } + } + + return f +} + +func fieldFromValue(name string, v reflect.Value) Field { + var fi Field + fi.name = strcase.ToKebab(name) + fi.path = []string{name} + fi.value = v.Interface() + fi.isBool = v.Kind() == reflect.Bool + return fi +} + +func scalarMapFields(v reflect.Value) []Field { + var f []Field + for _, key := range v.MapKeys() { + value := v.MapIndex(key) + name := key.Interface().(string) + if value.Kind() == reflect.Slice { + for i := 0; i < value.Len(); i++ { + fi := fieldFromValue(name, value.Index(i)) + fi.free = true + f = append(f, fi) + } + } else { + fi := fieldFromValue(name, value) + fi.free = true + f = append(f, fi) + } + } + + return f +} + +func prependFieldName(name string, f []Field) { + for i := range f { + f[i].name = fmt.Sprintf("%s-%s", strcase.ToKebab(name), f[i].name) + f[i].path = append([]string{name}, f[i].path...) + } +} + +func listFieldValues(fieldName string, l reflect.Value) []Field { + var f []Field + for i := 0; i < l.Len(); i++ { + item := l.Index(i) + item = unpackValue(item, pointer|iface|anytype) + switch { + case isScalar(item.Type()): + f = append(f, fieldFromValue(fieldName, item)) + case item.Kind() == reflect.Slice: + f = append(f, listFieldValues(fieldName, item)...) + case isScalarMap(item.Type()): + mf := fieldValues(item) + prependFieldName(fieldName, mf) + f = append(f, mf...) + case item.Kind() == reflect.Struct: + sf := fieldValues(item) + prependFieldName(fieldName, sf) + f = append(f, sf...) + } + } + + return f +} + +func fieldValues(v reflect.Value) []Field { + var f []Field + v = unpackValue(v, pointer|anytype|iface) + if v.Kind() == reflect.Slice { + for i := 0; i < v.Len(); i++ { + f = append(f, fieldValues(v.Index(i))...) + } + + return f + } + + t := v.Type() + if isScalarMap(t) { + return scalarMapFields(v) + } + + if t.Kind() != reflect.Struct { + return nil + } + + for i := 0; i < t.NumField(); i++ { + tfi := t.Field(i) + if !exported(tfi.Name) && !tfi.Anonymous { + continue + } + + vfi := v.Field(i) + vfi = unpackValue(vfi, pointer|iface|anytype) + switch { + case isScalar(vfi.Type()): + f = append(f, fieldFromValue(tfi.Name, vfi)) + case vfi.Kind() == reflect.Slice: + f = append(f, listFieldValues(tfi.Name, vfi)...) + case isScalarMap(vfi.Type()): + mf := fieldValues(vfi) + prependFieldName(tfi.Name, mf) + f = append(f, mf...) + case vfi.Kind() == reflect.Struct: + sf := fieldValues(vfi) + if !tfi.Anonymous { + prependFieldName(tfi.Name, sf) + } + + f = append(f, sf...) + } + } + + return f +} + +func fieldsReflect[T any]() []Field { + t := reflect.TypeFor[T]() + if hasCircularType(nil, t) { + return nil + } + + return fields(t) +} + +func fieldValuesReflect(structure any) []Field { + v := reflect.ValueOf(structure) + if hasCircularReference(nil, v) { + return nil + } + + return fieldValues(v) +} diff --git a/field_test.go b/field_test.go new file mode 100644 index 0000000..355908a --- /dev/null +++ b/field_test.go @@ -0,0 +1,278 @@ +package bind_test + +import ( + "code.squareroundforest.org/arpio/bind" + "code.squareroundforest.org/arpio/notation" + "slices" + "sort" + "testing" +) + +func TestField(t *testing.T) { + t.Run("fields", func(t *testing.T) { + type s0 struct { + FieldOne int + } + + type s1 struct { + Foo int + } + + type s2 struct { + s0 + foo int + FooBar int + Baz s1 + Qux *s1 + Chan chan int + Que bool + Hola []int + } + + f := bind.Fields[s2]() + if len(f) != 6 { + t.Fatal(notation.Sprintwt(f)) + } + + m := make(map[string]bind.Field) + for _, fi := range f { + m[fi.Name()] = fi + } + + fieldOne := m["field-one"] + if !slices.Equal(fieldOne.Path(), []string{"FieldOne"}) { + t.Fatal(fieldOne.Name()) + } + + fooBar := m["foo-bar"] + if !slices.Equal(fooBar.Path(), []string{"FooBar"}) { + t.Fatal(fooBar.Name()) + } + + bazFoo := m["baz-foo"] + if !slices.Equal(bazFoo.Path(), []string{"Baz", "Foo"}) { + t.Fatal(bazFoo.Name()) + } + + quxFoo := m["qux-foo"] + if !slices.Equal(quxFoo.Path(), []string{"Qux", "Foo"}) { + t.Fatal(quxFoo.Name()) + } + + que := m["que"] + if !slices.Equal(que.Path(), []string{"Que"}) || !que.Bool() { + t.Fatal(que.Name()) + } + + hola := m["hola"] + if !slices.Equal(hola.Path(), []string{"Hola"}) || !hola.List() { + t.Fatal(hola.Name()) + } + + t.Run("cannot have fields", func(t *testing.T) { + type i []int + f := bind.Fields[i]() + if len(f) != 0 { + t.Fatal() + } + }) + + t.Run("has circular type", func(t *testing.T) { + type s struct{ Foo *s } + if len(bind.Fields[s]()) != 0 { + t.Fatal() + } + }) + + t.Run("list", func(t *testing.T) { + type s struct{ Foo int } + f := bind.Fields[[]s]() + if len(f) != 1 || !f[0].List() { + t.Fatal() + } + }) + }) + + t.Run("field values", func(t *testing.T) { + t.Run("has circular reference", func(t *testing.T) { + type s struct{ Foo any } + var v s + v.Foo = v + if len(bind.FieldValues(v)) != 0 { + t.Fatal() + } + }) + + t.Run("slice", func(t *testing.T) { + type s struct{ Foo int } + f := bind.FieldValues([]s{{21}, {42}}) + if len(f) != 2 || f[0].Value() != 21 || f[1].Value() != 42 { + t.Fatal() + } + }) + + t.Run("scalar map", func(t *testing.T) { + f := bind.FieldValues(map[string]int{"foo": 21, "bar": 42}) + sort.Slice(f, func(i, j int) bool { + if f[i].Name() > f[j].Name() { + return true + } + + if f[i].Value().(int) < f[j].Value().(int) { + return true + } + + return false + }) + + if len(f) != 2 || + f[0].Name() != "foo" || f[0].Value() != 21 || !f[0].Free() || + f[1].Name() != "bar" || f[1].Value() != 42 || !f[1].Free() { + t.Fatal(notation.Println(f)) + } + }) + + t.Run("scalar map with list values", func(t *testing.T) { + f := bind.FieldValues(map[string][]int{"foo": []int{21, 36}, "bar": []int{42, 72}}) + sort.Slice(f, func(i, j int) bool { + if f[i].Name() > f[j].Name() { + return true + } + + if f[i].Value().(int) < f[j].Value().(int) { + return true + } + + return false + }) + + if len(f) != 4 || + f[0].Name() != "foo" || f[0].Value() != 21 || !f[0].Free() || + f[1].Name() != "foo" || f[1].Value() != 36 || !f[1].Free() || + f[2].Name() != "bar" || f[2].Value() != 42 || !f[2].Free() || + f[3].Name() != "bar" || f[3].Value() != 72 || !f[3].Free() { + t.Fatal(notation.Println(f)) + } + }) + + t.Run("not a struct", func(t *testing.T) { + v := []int{21, 42, 84} + f := bind.FieldValues(v) + if len(f) != 0 { + t.Fatal() + } + }) + + t.Run("not exported field", func(t *testing.T) { + type s struct { + Foo int + bar string + } + v := s{Foo: 42, bar: "baz"} + f := bind.FieldValues(v) + if len(f) != 1 || f[0].Name() != "foo" { + t.Fatal() + } + }) + + t.Run("scalar fields", func(t *testing.T) { + type s struct { + Foo int + Bar bool + } + v := s{Foo: 42, Bar: true} + f := bind.FieldValues(v) + if len(f) != 2 || + f[0].Name() != "foo" || !slices.Equal(f[0].Path(), []string{"Foo"}) || f[0].Value() != 42 || f[0].Bool() || + f[1].Name() != "bar" || !slices.Equal(f[1].Path(), []string{"Bar"}) || f[1].Value() != true || !f[1.].Bool() { + t.Fatal() + } + }) + + t.Run("list field", func(t *testing.T) { + type s struct{ Foo []int } + v := s{Foo: []int{21, 42, 84}} + f := bind.FieldValues(v) + if len(f) != 3 || + f[0].Name() != "foo" || f[0].Value() != 21 || + f[1].Name() != "foo" || f[1].Value() != 42 || + f[2].Name() != "foo" || f[2].Value() != 84 { + t.Fatal(notation.Sprintwt(f)) + } + }) + + t.Run("list of lists", func(t *testing.T) { + type s struct{ Foo [][]int } + v := s{Foo: [][]int{{21}, {42}}} + f := bind.FieldValues(v) + if len(f) != 2 || + f[0].Name() != "foo" || f[0].Value() != 21 || + f[1].Name() != "foo" || f[1].Value() != 42 { + t.Fatal() + } + }) + + t.Run("list of scalar maps", func(t *testing.T) { + type s struct{ Foo any } + v := s{Foo: []any{map[string]int{"foo": 42}}} + f := bind.FieldValues(v) + if len(f) != 1 || f[0].Name() != "foo-foo" || f[0].Value() != 42 { + t.Fatal(notation.Sprintwt(f)) + } + }) + + t.Run("list of structs", func(t *testing.T) { + type s struct{ Foo any } + v := s{Foo: []any{s{Foo: 21}, s{Foo: 42}}} + f := bind.FieldValues(v) + if len(f) != 2 || + f[0].Name() != "foo-foo" || f[0].Value() != 21 || + f[1].Name() != "foo-foo" || f[1].Value() != 42 { + t.Fatal() + } + }) + + t.Run("scalar map field", func(t *testing.T) { + type s struct{ Foo map[string]int } + v := s{Foo: map[string]int{"foo": 42}} + f := bind.FieldValues(v) + if len(f) != 1 || f[0].Name() != "foo-foo" || f[0].Value() != 42 { + t.Fatal() + } + }) + + t.Run("anonymous field", func(t *testing.T) { + type s0 struct{ Foo int } + type s1 struct { + s0 + Bar int + } + var v s1 + v.Foo = 21 + v.Bar = 42 + f := bind.FieldValues(v) + if len(f) != 2 || + f[0].Name() != "foo" || f[0].Value() != 21 || + f[1].Name() != "bar" || f[1].Value() != 42 { + t.Fatal() + } + }) + + t.Run("child struct", func(t *testing.T) { + type S0 struct{ Foo int } + type s1 struct { + S0 S0 + Bar int + } + var v s1 + v.S0.Foo = 21 + v.Bar = 42 + f := bind.FieldValues(v) + if len(f) != 2 || + f[0].Name() != "s-0-foo" || f[0].Value() != 21 || + f[1].Name() != "bar" || f[1].Value() != 42 { + t.Fatal(notation.Sprintwt(f)) + } + }) + }) +} diff --git a/go.mod b/go.mod index e7d23b1..278d772 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,8 @@ module code.squareroundforest.org/arpio/bind go 1.25.0 + +require ( + code.squareroundforest.org/arpio/notation v0.0.0-20250826181910-5140794b16b2 // indirect + github.com/iancoleman/strcase v0.3.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..385621e --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +code.squareroundforest.org/arpio/notation v0.0.0-20250826181910-5140794b16b2 h1:S4mjQHL70CuzFg1AGkr0o0d+4M+ZWM0sbnlYq6f0b3I= +code.squareroundforest.org/arpio/notation v0.0.0-20250826181910-5140794b16b2/go.mod h1:ait4Fvg9o0+bq5hlxi9dAcPL5a+/sr33qsZPNpToMLY= +github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= +github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= diff --git a/lib.go b/lib.go index 878aa22..82e9844 100644 --- a/lib.go +++ b/lib.go @@ -1,11 +1,16 @@ // provides more flexible and permissive ways of setting values than reflect.Value.Set +// circular type structures not supported +// it handles scalar fields +// primary use cases by design are: command line options, environment variables, ini file fields, URL query parameters, HTTP form values package bind type Field struct { - name string - path []string - list bool - value any + name string + path []string + isBool bool + list bool + free bool + value any } // the receiver must be addressable @@ -33,14 +38,16 @@ func (f Field) Path() []string { func (f Field) Name() string { return f.name } func (f Field) List() bool { return f.list } +func (f Field) Bool() bool { return f.isBool } +func (f Field) Free() bool { return f.free } func (f Field) Value() any { return f.value } func Fields[T any]() []Field { - return nil + return fieldsReflect[T]() } func FieldValues(structure any) []Field { - return nil + return fieldValuesReflect(structure) } func BindFields(structure any, values []Field) []Field { diff --git a/notes.txt b/notes.txt new file mode 100644 index 0000000..1908028 --- /dev/null +++ b/notes.txt @@ -0,0 +1,3 @@ +maps with string keys and scalar, or list of scalar, values can be supported +fields being set by name, can be converted first to use path +add time and duration support. Can there be any other important quasi-scalars? diff --git a/scalar.go b/scalar.go index 9942e06..761bbae 100644 --- a/scalar.go +++ b/scalar.go @@ -64,14 +64,38 @@ func bindScalarCreate(t reflect.Type, values []any) (reflect.Value, bool) { } func bindScalarReflect(receiver any, values []any) bool { - return bindScalar(reflect.ValueOf(receiver), values) + v := reflect.ValueOf(receiver) + if hasCircularReference(nil, v) { + return false + } + + for _, vi := range values { + if hasCircularReference(nil, reflect.ValueOf(vi)) { + return false + } + } + + return bindScalar(v, values) } func bindScalarCreateReflect[T any](values []any) (T, bool) { - v, ok := bindScalarCreate(reflect.TypeFor[T](), values) + t := reflect.TypeFor[T]() + if hasCircularType(nil, t) { + var tt T + return tt, false + } + + for _, vi := range values { + if hasCircularReference(nil, reflect.ValueOf(vi)) { + var tt T + return tt, false + } + } + + v, ok := bindScalarCreate(t, values) if !ok { - var v T - return v, false + var tt T + return tt, false } return v.Interface().(T), true diff --git a/scalar_test.go b/scalar_test.go index dc41ef6..50312da 100644 --- a/scalar_test.go +++ b/scalar_test.go @@ -146,6 +146,16 @@ func TestScalar(t *testing.T) { t.Fatal() } }) + + t.Run("value has circular reference", func(t *testing.T) { + var r any + var v any + p := &v + *p = p + if bind.BindScalar(&r, p) { + t.Fatal() + } + }) }) t.Run("bind scalar with create", func(t *testing.T) { @@ -156,7 +166,7 @@ func TestScalar(t *testing.T) { }) t.Run("does not accept scalar", func(t *testing.T) { - if _, ok := bind.BindScalarCreate[struct{Foo int}]("42"); ok { + if _, ok := bind.BindScalarCreate[struct{ Foo int }]("42"); ok { t.Fatal() } }) @@ -202,5 +212,21 @@ func TestScalar(t *testing.T) { t.Fatal() } }) + + t.Run("receiver has circular type", func(t *testing.T) { + type s []s + if _, ok := bind.BindScalarCreate[s]("foo"); ok { + t.Fatal() + } + }) + + t.Run("value has circular reference", func(t *testing.T) { + var v any + p := &v + *p = p + if _, ok := bind.BindScalarCreate[any](p); ok { + t.Fatal() + } + }) }) } diff --git a/scan.go b/scan.go index 8f4e679..d6d7e2a 100644 --- a/scan.go +++ b/scan.go @@ -69,6 +69,7 @@ func scanString(t reflect.Type, s string) (any, bool) { } func scan(t reflect.Type, v any) (any, bool) { + // time and duration support if vv, ok := scanConvert(t, v); ok { return vv, true } diff --git a/scan_test.go b/scan_test.go index db0cdc0..5c3e51e 100644 --- a/scan_test.go +++ b/scan_test.go @@ -1,8 +1,8 @@ package bind_test import ( - "testing" "code.squareroundforest.org/arpio/bind" + "testing" ) func TestScan(t *testing.T) { diff --git a/struct.go b/struct.go deleted file mode 100644 index 92a8c5c..0000000 --- a/struct.go +++ /dev/null @@ -1 +0,0 @@ -package bind diff --git a/type.go b/type.go index 36f2344..7f282f5 100644 --- a/type.go +++ b/type.go @@ -15,6 +15,86 @@ func (f unpackFlag) has(v unpackFlag) bool { return f&v > 0 } +func setVisited[T comparable](visited map[T]bool, k T) map[T]bool { + s := make(map[T]bool) + for v := range visited { + s[v] = true + } + + s[k] = true + return s +} + +func hasCircularType(visited map[reflect.Type]bool, t reflect.Type) bool { + if visited[t] { + return true + } + + switch t.Kind() { + case reflect.Pointer, reflect.Slice: + visited = setVisited(visited, t) + return hasCircularType(visited, t.Elem()) + case reflect.Struct: + visited = setVisited(visited, t) + for i := 0; i < t.NumField(); i++ { + if hasCircularType(visited, t.Field(i).Type) { + return true + } + } + + return false + default: + return false + } +} + +func hasCircularReference(visited map[uintptr]bool, v reflect.Value) bool { + if hasCircularType(nil, v.Type()) { + return true + } + + switch v.Kind() { + case reflect.Pointer: + p := v.Pointer() + if visited[p] { + return true + } + + visited = setVisited(visited, v.Pointer()) + return hasCircularReference(visited, v.Elem()) + case reflect.Slice: + p := v.Pointer() + if visited[p] { + return true + } + + visited = setVisited(visited, v.Pointer()) + for i := 0; i < v.Len(); i++ { + if hasCircularReference(visited, v.Index(i)) { + return true + } + } + + return false + case reflect.Interface: + if v.IsNil() { + return false + } + + return hasCircularReference(visited, v.Elem()) + case reflect.Struct: + for i := 0; i < v.NumField(); i++ { + if hasCircularReference(visited, v.Field(i)) { + return true + } + } + + return false + } + + return false +} + func isAny(t reflect.Type) bool { return t.Kind() == reflect.Interface && t.NumMethod() == 0 } @@ -115,20 +195,52 @@ func bindable(t reflect.Type) bool { return acceptsScalar(t) || acceptsFields(t) } +func acceptsScalarChecked(t reflect.Type) bool { + if hasCircularType(nil, t) { + return false + } + + return acceptsScalar(t) +} + +func acceptsFieldsChecked(t reflect.Type) bool { + if hasCircularType(nil, t) { + return false + } + + return acceptsFields(t) +} + +func acceptsListChecked(t reflect.Type) bool { + if hasCircularType(nil, t) { + return false + } + + return acceptsList(t) +} + +func bindableChecked(t reflect.Type) bool { + if hasCircularType(nil, t) { + return false + } + + return bindable(t) +} + func acceptsScalarReflect[T any]() bool { - return acceptsScalar(reflect.TypeFor[T]()) + return acceptsScalarChecked(reflect.TypeFor[T]()) } func acceptsFieldsReflect[T any]() bool { - return acceptsFields(reflect.TypeFor[T]()) + return acceptsFieldsChecked(reflect.TypeFor[T]()) } func acceptsListReflect[T any]() bool { - return acceptsList(reflect.TypeFor[T]()) + return acceptsListChecked(reflect.TypeFor[T]()) } func bindableReflect[T any]() bool { - return bindable(reflect.TypeFor[T]()) + return bindableChecked(reflect.TypeFor[T]()) } // expected to be called with types that can pass the bindable check diff --git a/type_test.go b/type_test.go index 7fbb9e6..5219465 100644 --- a/type_test.go +++ b/type_test.go @@ -275,4 +275,94 @@ func TestTypeChecks(t *testing.T) { } }) }) + + t.Run("circular type", func(t *testing.T) { + t.Run("via pointer", func(t *testing.T) { + type p *p + if bind.Bindable[p]() { + t.Fatal() + } + }) + + t.Run("via slice", func(t *testing.T) { + type s []s + if bind.Bindable[s]() { + t.Fatal() + } + }) + + t.Run("pointer via struct field", func(t *testing.T) { + type s struct{ f *s } + if bind.Bindable[s]() { + t.Fatal() + } + }) + + t.Run("slice via struct field", func(t *testing.T) { + type s struct{ f []s } + if bind.Bindable[s]() { + t.Fatal() + } + }) + }) + + t.Run("circular reference", func(t *testing.T) { + t.Run("via pointer", func(t *testing.T) { + p := new(any) + *p = p + if bind.BindScalar(p, "42") { + t.Fatal() + } + }) + + t.Run("via slice", func(t *testing.T) { + s := make([]any, 1) + s[0] = s + if bind.BindScalar(s, "42") { + t.Fatal() + } + }) + + t.Run("via interface and pointer", func(t *testing.T) { + var v any + v = &v + if bind.BindScalar(v, "42") { + t.Fatal() + } + }) + + t.Run("via struct field and pointer", func(t *testing.T) { + type s struct{ F *s } + var v s + v.F = &v + if len(bind.FieldValues(v)) > 0 { + t.Fatal() + } + }) + + t.Run("via struct field and slice", func(t *testing.T) { + type s struct{ F []s } + var v s + v.F = []s{v} + if len(bind.FieldValues(v)) > 0 { + t.Fatal() + } + }) + + t.Run("via struct field and interface", func(t *testing.T) { + var s struct{ F any } + s.F = s + if len(bind.FieldValues(s)) > 0 { + t.Fatal() + } + }) + + t.Run("value with circular type", func(t *testing.T) { + type p *p + var v p + if len(bind.FieldValues(v)) > 0 { + t.Fatal() + } + }) + }) }