From 7c66a3e382d12b7f28eb83cd71736737392cb0a6 Mon Sep 17 00:00:00 2001 From: Arpad Ryszka Date: Sun, 22 Nov 2020 23:22:20 +0100 Subject: [PATCH] handle cyclic references --- debug_test.go | 2 +- notation.go | 21 +++++++++- reflect.go | 110 +++++++++++++++++++++++++++++++++++-------------- sprint_test.go | 75 +++++++++++++++++++++++++++++++++ 4 files changed, 176 insertions(+), 32 deletions(-) diff --git a/debug_test.go b/debug_test.go index 4fc042c..5f6ccb4 100644 --- a/debug_test.go +++ b/debug_test.go @@ -9,7 +9,7 @@ import ( func TestDebugNode(t *testing.T) { const expect = `"foobarbaz"` o := "foobarbaz" - n := reflectValue(none, reflect.ValueOf(o)) + n := reflectValue(none, &pending{values: make(map[valueKey]nodeRef)}, reflect.ValueOf(o)) s := fmt.Sprint(n) if s != expect { t.Fatalf( diff --git a/notation.go b/notation.go index c266a21..8d96fd6 100644 --- a/notation.go +++ b/notation.go @@ -26,6 +26,20 @@ type wrapLen struct { first, max, last int } +type valueKey struct { + typ reflect.Type + ptr uintptr +} + +type nodeRef struct { + id, refCount int +} + +type pending struct { + values map[valueKey]nodeRef + idCounter int +} + type node struct { len int wrapLen wrapLen @@ -151,7 +165,12 @@ func fprintValues(w io.Writer, o opts, v []interface{}) (int, error) { continue } - n := reflectValue(o, reflect.ValueOf(vi)) + n := reflectValue( + o, + &pending{values: make(map[valueKey]nodeRef)}, + reflect.ValueOf(vi), + ) + if o&wrap != 0 { n = nodeLen(tab, n) n = wrapNode(tab, cols0, cols0, cols1, n) diff --git a/reflect.go b/reflect.go index cc467cd..e3241d2 100644 --- a/reflect.go +++ b/reflect.go @@ -60,7 +60,7 @@ func reflectNil(o opts, groupUnnamedType bool, r reflect.Value) node { return nodeOf(reflectType(rt), "(nil)") } -func reflectItems(o opts, prefix string, r reflect.Value) node { +func reflectItems(o opts, p *pending, prefix string, r reflect.Value) node { typ := r.Type() var items wrapper if typ.Elem().Name() == "uint8" { @@ -77,7 +77,7 @@ func reflectItems(o opts, prefix string, r reflect.Value) node { for i := 0; i < r.Len(); i++ { items.items = append( items.items, - reflectValue(itemOpts, r.Index(i)), + reflectValue(itemOpts, p, r.Index(i)), ) } } @@ -101,8 +101,8 @@ func reflectHidden(o opts, hidden string, r reflect.Value) node { return reflectType(r.Type()) } -func reflectArray(o opts, r reflect.Value) node { - return reflectItems(o, fmt.Sprintf("[%d]", r.Len()), r) +func reflectArray(o opts, p *pending, r reflect.Value) node { + return reflectItems(o, p, fmt.Sprintf("[%d]", r.Len()), r) } func reflectChan(o opts, r reflect.Value) node { @@ -113,12 +113,12 @@ func reflectFunc(o opts, r reflect.Value) node { return reflectHidden(o, "func()", r) } -func reflectInterface(o opts, r reflect.Value) node { +func reflectInterface(o opts, p *pending, r reflect.Value) node { if r.IsNil() { return reflectNil(o, false, r) } - e := reflectValue(o, r.Elem()) + e := reflectValue(o, p, r.Elem()) if _, t, _ := withType(o); !t { return e } @@ -131,7 +131,7 @@ func reflectInterface(o opts, r reflect.Value) node { ) } -func reflectMap(o opts, r reflect.Value) node { +func reflectMap(o opts, p *pending, r reflect.Value) node { if r.IsNil() { return reflectNil(o, true, r) } @@ -149,7 +149,7 @@ func reflectMap(o opts, r reflect.Value) node { sn := make(map[string]node) for _, key := range keys { var b bytes.Buffer - nk := reflectValue(itemOpts, key) + nk := reflectValue(itemOpts, p, key) nkeys = append(nkeys, nk) wr := writer{w: &b} fprint(&wr, 0, nk) @@ -169,7 +169,7 @@ func reflectMap(o opts, r reflect.Value) node { nodeOf( sn[skey], ": ", - reflectValue(itemOpts, r.MapIndex(sv[skey])), + reflectValue(itemOpts, p, r.MapIndex(sv[skey])), ), ) } @@ -181,12 +181,12 @@ func reflectMap(o opts, r reflect.Value) node { return nodeOf(reflectType(r.Type()), "{", items, "}") } -func reflectPointer(o opts, r reflect.Value) node { +func reflectPointer(o opts, p *pending, r reflect.Value) node { if r.IsNil() { return reflectNil(o, true, r) } - e := reflectValue(o, r.Elem()) + e := reflectValue(o, p, r.Elem()) if _, t, _ := withType(o); !t { return e } @@ -194,12 +194,12 @@ func reflectPointer(o opts, r reflect.Value) node { return nodeOf("*", e) } -func reflectList(o opts, r reflect.Value) node { +func reflectList(o opts, p *pending, r reflect.Value) node { if r.IsNil() { return reflectNil(o, true, r) } - return reflectItems(o, "[]", r) + return reflectItems(o, p, "[]", r) } func reflectString(o opts, r reflect.Value) node { @@ -248,7 +248,7 @@ func reflectString(o opts, r reflect.Value) node { return nodeOf(tn, "(", wrapper{items: []node{n}}, ")") } -func reflectStruct(o opts, r reflect.Value) node { +func reflectStruct(o opts, p *pending, r reflect.Value) node { wr := wrapper{sep: ", ", suffix: ","} fieldOpts := o | skipTypes @@ -262,6 +262,7 @@ func reflectStruct(o opts, r reflect.Value) node { ": ", reflectValue( fieldOpts, + p, r.FieldByName(name), ), ), @@ -287,17 +288,64 @@ func reflectUnsafePointer(o opts, r reflect.Value) node { return nodeOf(reflectType(r.Type()), "(pointer)") } -func reflectValue(o opts, r reflect.Value) node { +func checkPending(p *pending, r reflect.Value) (applyRef func(node) node, ref node, isPending bool) { + applyRef = func(n node) node { return n } + switch r.Kind() { + case reflect.Slice, reflect.Map: + case reflect.Ptr: + if r.IsNil() { + return + } + default: + return + } + + var nr nodeRef + key := valueKey{typ: r.Type(), ptr: r.Pointer()} + nr, isPending = p.values[key] + if isPending { + nr.refCount++ + p.values[key] = nr + ref = nodeOf("r", nr.id) + return + } + + nr = nodeRef{id: p.idCounter} + p.idCounter++ + p.values[key] = nr + applyRef = func(n node) node { + nr = p.values[key] + if nr.refCount > 0 { + n.parts = append( + []interface{}{"r", nr.id, "="}, + n.parts..., + ) + } + + delete(p.values, key) + return n + } + + return +} + +func reflectValue(o opts, p *pending, r reflect.Value) node { + applyRef, ref, isPending := checkPending(p, r) + if isPending { + return ref + } + + var n node switch r.Kind() { case reflect.Bool: - return reflectPrimitive(o, r, r.Bool(), "bool") + n = reflectPrimitive(o, r, r.Bool(), "bool") case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return reflectPrimitive(o, r, r.Int(), "int") + n = reflectPrimitive(o, r, r.Int(), "int") case reflect.Uint, reflect.Uint8, @@ -305,30 +353,32 @@ func reflectValue(o opts, r reflect.Value) node { reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return reflectPrimitive(o, r, r.Uint()) + n = reflectPrimitive(o, r, r.Uint()) case reflect.Float32, reflect.Float64: - return reflectPrimitive(o, r, r.Float()) + n = reflectPrimitive(o, r, r.Float()) case reflect.Complex64, reflect.Complex128: - return reflectPrimitive(o, r, r.Complex()) + n = reflectPrimitive(o, r, r.Complex()) case reflect.Array: - return reflectArray(o, r) + n = reflectArray(o, p, r) case reflect.Chan: - return reflectChan(o, r) + n = reflectChan(o, r) case reflect.Func: - return reflectFunc(o, r) + n = reflectFunc(o, r) case reflect.Interface: - return reflectInterface(o, r) + n = reflectInterface(o, p, r) case reflect.Map: - return reflectMap(o, r) + n = reflectMap(o, p, r) case reflect.Ptr: - return reflectPointer(o, r) + n = reflectPointer(o, p, r) case reflect.Slice: - return reflectList(o, r) + n = reflectList(o, p, r) case reflect.String: - return reflectString(o, r) + n = reflectString(o, r) case reflect.UnsafePointer: - return reflectUnsafePointer(o, r) + n = reflectUnsafePointer(o, r) default: - return reflectStruct(o, r) + n = reflectStruct(o, p, r) } + + return applyRef(n) } diff --git a/sprint_test.go b/sprint_test.go index f7f29c6..924715a 100644 --- a/sprint_test.go +++ b/sprint_test.go @@ -1001,3 +1001,78 @@ func TestSingleLongString(t *testing.T) { } }) } + +func TestCyclicReferences(t *testing.T) { + t.Run("slice", func(t *testing.T) { + const expect = `r0=[]{r0}` + l := []interface{}{"foo"} + l[0] = l + s := Sprint(l) + if s != expect { + t.Fatalf("expected: %s, got: %s", expect, s) + } + }) + + t.Run("map", func(t *testing.T) { + const expect = `r0=map{"foo": r0}` + m := map[string]interface{}{"foo": "bar"} + m["foo"] = m + s := Sprint(m) + if s != expect { + t.Fatalf("expected: %s, got: %s", expect, s) + } + }) + + t.Run("pointer", func(t *testing.T) { + const expect = `r0=r0` + p := new(interface{}) + *p = p + s := Sprint(p) + if s != expect { + t.Fatalf("expected: %s, got: %s", expect, s) + } + }) + + t.Run("multiple refs", func(t *testing.T) { + const expect = `{f0: r1={f0: r2={f0: nil, f1: r1, f2: r2}, f1: nil, f2: nil}, f1: nil, f2: nil}` + type typ struct{ f0, f1, f2 *typ } + v0 := new(typ) + v1 := new(typ) + v2 := new(typ) + v0.f0 = v1 + v1.f0 = v2 + v2.f1 = v1 + v2.f2 = v2 + s := Sprintw(v0) + if s != expect { + t.Fatalf("expected: %s, got: %s", expect, s) + } + }) + + t.Run("multiple refs, different subtrees", func(t *testing.T) { + const expect = `{ + f0: r1={f0: r2={f0: nil, f1: r1, f2: r2}, f1: nil, f2: nil}, + f1: r3={f0: r4={f0: nil, f1: r3, f2: r4}, f1: nil, f2: nil}, + f2: nil, +}` + + type typ struct{ f0, f1, f2 *typ } + v0 := new(typ) + v11 := new(typ) + v12 := new(typ) + v21 := new(typ) + v22 := new(typ) + v0.f0 = v11 + v11.f0 = v12 + v12.f1 = v11 + v12.f2 = v12 + v0.f1 = v21 + v21.f0 = v22 + v22.f1 = v21 + v22.f2 = v22 + s := Sprintw(v0) + if s != expect { + t.Fatalf("expected: %s, got: %s", expect, s) + } + }) +}