From af83b13d7c92aae06f6be7bee5c985cce8e4a8de Mon Sep 17 00:00:00 2001 From: Arpad Ryszka Date: Sat, 9 Aug 2025 20:39:45 +0200 Subject: [PATCH] fix struct field docs --- docs.go | 78 +++++++++++++++++++++++++- docs_test.go | 18 ++++++ generate/generate.go | 49 ++++++++++++---- generate/generate_test.go | 7 ++- internal/tests/src/testpackage/test.go | 9 ++- 5 files changed, 142 insertions(+), 19 deletions(-) diff --git a/docs.go b/docs.go index 10d8a3c..710445f 100644 --- a/docs.go +++ b/docs.go @@ -13,11 +13,24 @@ import ( "strings" ) +type pointer[T any] interface { + Kind() reflect.Kind + Elem() T +} + var ( registry = make(map[string]string) funcParamsExp = regexp.MustCompile("^func[(]([a-zA-Z_][a-zA-Z0-9_]+(, [a-zA-Z_][a-zA-Z0-9_]+)*)?[)]$") ) +func unpack[T pointer[T]](p T) T { + if p.Kind() != reflect.Pointer { + return p + } + + return unpack(p.Elem()) +} + func docs(p string) string { return registry[p] } @@ -91,6 +104,7 @@ func Docs(gopath string) string { // Function returns the documentation for a package level function. func Function(v reflect.Value) string { + v = unpack(v) if v.Kind() != reflect.Func { return "" } @@ -100,6 +114,7 @@ func Function(v reflect.Value) string { // FunctionParams returns the list of the parameter names of a package level function. func FunctionParams(v reflect.Value) []string { + v = unpack(v) if v.Kind() != reflect.Func { return nil } @@ -110,27 +125,83 @@ func FunctionParams(v reflect.Value) []string { // Type returns the docuemntation for a package level type. func Type(t reflect.Type) string { + t = unpack(t) p := fmt.Sprintf("%s.%s", t.PkgPath(), t.Name()) return docs(p) } +func structField(st reflect.Type, name string) (reflect.StructField, bool) { + for i := 0; i < st.NumField(); i++ { + f := st.Field(i) + if f.Name == name { + return f, true + } + } + + return reflect.StructField{}, false +} + // Field returns the docuemntation for a struct field. func Field(t reflect.Type, fieldPath ...string) string { if len(fieldPath) == 0 { return "" } + t = unpack(t) if t.Kind() != reflect.Struct { return "" } - p := strings.Join(append([]string{t.PkgPath(), t.Name()}, fieldPath...), ".") - println(p) - return docs(p) + f, found := structField(t, fieldPath[0]) + if !found { + return "" + } + + if len(fieldPath) == 1 { + println("returning docs", strings.Join([]string{t.PkgPath(), t.Name(), fieldPath[0]}, ".")) + return docs(strings.Join([]string{t.PkgPath(), t.Name(), fieldPath[0]}, ".")) + } + + ft := unpack(f.Type) + if ft.Kind() != reflect.Struct { + return "" + } + + if ft.Name() != "" { + return Field(ft, fieldPath[1:]...) + } + + st := ft + path := fieldPath[1:] + for { + f, found = structField(st, path[0]) + if !found { + return "" + } + + if len(path) == 1 { + return docs(strings.Join(append([]string{t.PkgPath(), t.Name()}, fieldPath...), ".")) + } + + path = path[1:] + if len(path) == 0 { + return "" + } + + st = unpack(f.Type) + if st.Kind() != reflect.Struct { + return "" + } + + if st.Name() != "" { + return Field(st, path...) + } + } } // Method returns the documentation for a type method. func Method(t reflect.Type, name string) string { + t = unpack(t) if t.Kind() != reflect.Struct { return "" } @@ -141,6 +212,7 @@ func Method(t reflect.Type, name string) string { // MethodParams returns the list of the parameter names of a type method. func MethodParams(t reflect.Type, name string) []string { + t = unpack(t) if t.Kind() != reflect.Struct { return nil } diff --git a/docs_test.go b/docs_test.go index 345eec6..d072400 100644 --- a/docs_test.go +++ b/docs_test.go @@ -96,6 +96,24 @@ func Test(t *testing.T) { } }) + t.Run("field in inline struct type", func(t *testing.T) { + s := &testpackage.ExportedType{} + typ := reflect.TypeOf(s) + d := docreflect.Field(typ, "Bar", "Baz") + if !strings.Contains(d, "Baz is another field") { + t.Fatal(d) + } + }) + + t.Run("field in another type", func(t *testing.T) { + s := &testpackage.ExportedType{} + typ := reflect.TypeOf(s) + d := docreflect.Field(typ, "Baz", "Foo") + if !strings.Contains(d, "Foo is a field in ExportedType2") { + t.Fatal(d) + } + }) + t.Run("method", func(t *testing.T) { s := testpackage.ExportedType{} typ := reflect.TypeOf(s) diff --git a/generate/generate.go b/generate/generate.go index 812b52d..0d10139 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -320,7 +320,7 @@ func symbolDocs(pkg *doc.Package, gopath string) (string, bool) { fieldPath := symbol[1:] for len(fieldPath) > 0 { - var found bool + var foundField bool if str.Fields != nil { for _, f := range str.Fields.List { for _, fn := range f.Names { @@ -329,6 +329,7 @@ func symbolDocs(pkg *doc.Package, gopath string) (string, bool) { continue } + foundField = true if len(fieldPath) == 1 { if f.Doc == nil { return "", true @@ -339,19 +340,22 @@ func symbolDocs(pkg *doc.Package, gopath string) (string, bool) { fstr, ok := f.Type.(*ast.StructType) if !ok { - break + return "", false } - found = true str = fstr fieldPath = fieldPath[1:] break } } + + if foundField { + break + } } } - if !found { + if !foundField { break } } @@ -390,6 +394,27 @@ func valueDocs(packagePath string, v []*doc.Value) map[string]string { return d } +func takeFieldDocs(docs map[string]string, packagePath string, prefix []string, f *ast.Field) { + for _, fn := range f.Names { + if fn == nil { + continue + } + + docs[symbolPath(packagePath, append(prefix, fn.Name)...)] = f.Doc.Text() + if fst, ok := f.Type.(*ast.StructType); ok { + if fst.Fields != nil { + for _, fi := range fst.Fields.List { + if fi == nil { + continue + } + + takeFieldDocs(docs, packagePath, append(prefix, fn.Name), fi) + } + } + } + } +} + func packageDocs(pkg *doc.Package) map[string]string { d := make(map[string]string) d[pkg.ImportPath] = pkg.Doc @@ -404,17 +429,11 @@ func packageDocs(pkg *doc.Package) map[string]string { if ts, ok := t.Decl.Specs[0].(*ast.TypeSpec); ok { if str, ok := ts.Type.(*ast.StructType); ok && str.Fields != nil { for _, f := range str.Fields.List { - if f == nil || f.Doc == nil { + if f == nil { continue } - for _, fn := range f.Names { - if fn == nil { - continue - } - - d[symbolPath(pkg.ImportPath, t.Name, fn.Name)] = f.Doc.Text() - } + takeFieldDocs(d, pkg.ImportPath, []string{t.Name}, f) } } } @@ -534,3 +553,9 @@ func GenerateRegistry(w io.Writer, outputPackageName string, gopath ...string) e return nil } + +// TODO: +// - type foo = bar +// - type foo bar +// - type (...) +// - pointers in struct fields diff --git a/generate/generate_test.go b/generate/generate_test.go index d191bc3..04f40d8 100644 --- a/generate/generate_test.go +++ b/generate/generate_test.go @@ -63,9 +63,9 @@ func testGenerate(check map[string]string, errstr string, o options, gopath ...s if failed { t.Log("failed to get the right documentation") t.Log("expected matches:") - t.Log(notation.Sprint(check)) + t.Log(notation.Sprintw(check)) t.Log("documetation got:") - t.Log(notation.Sprint(d)) + t.Log(notation.Sprintw(d)) t.Fatal() } } @@ -121,7 +121,8 @@ func TestGenerate(t *testing.T) { testGenerate( check( packagePath, "Package testpackage is a test package", - fmt.Sprintf("%s.%s", packagePath, "ExportedType"), "ExportedType has docs", + fmt.Sprintf("%s.ExportedType", packagePath), "ExportedType has docs", + fmt.Sprintf("%s.ExportedType.Bar.Baz", packagePath), "Baz is another field", ), "", o, diff --git a/internal/tests/src/testpackage/test.go b/internal/tests/src/testpackage/test.go index 00f5546..a2a9a1b 100644 --- a/internal/tests/src/testpackage/test.go +++ b/internal/tests/src/testpackage/test.go @@ -19,11 +19,18 @@ type ExportedType struct { // Baz is another field Baz int } + + // Baz is a field of type *ExportedType2 + Baz *ExportedType2 } type Foo = ExportedType -type ExportedType2 struct{} +type ExportedType2 struct{ + + // Foo is a field in ExportedType2 + Foo int +} // C1 is a const const C1 = 42