From 5c8853f7aee51da76b3ecd0ecf0a5fffda33ba5e Mon Sep 17 00:00:00 2001 From: Arpad Ryszka Date: Sun, 17 Aug 2025 15:59:01 +0200 Subject: [PATCH] testing generate --- .gitignore | 2 +- Makefile | 24 +- generate/generate.go | 366 +++++++++++++++---------- generate/generate_test.go | 74 ++++- internal/tests/src/testpackage/test.go | 27 ++ 5 files changed, 330 insertions(+), 163 deletions(-) diff --git a/.gitignore b/.gitignore index bbec715..904be13 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ .bin -testdocs.go +testdocs_test.go .cover diff --git a/Makefile b/Makefile index f755c41..0021c04 100644 --- a/Makefile +++ b/Makefile @@ -1,27 +1,29 @@ -SOURCES = $(shell find . -name "*.go") +SOURCES = $(shell find . -name "*.go" | grep -v testdocs_test.go) PREFIX ?= ~/bin default: build -build: $(SOURCES) .bin +libdocreflect: $(SOURCES) go build . - go build ./generate - go build -o .bin/docreflect ./cmd/docreflect -.bin/docreflect: build +libgenerate: $(SOURCES) + go build ./generate .bin: mkdir -p .bin -check: $(SOURCES) .bin/docreflect - .bin/docreflect generate docreflect_test code.squareroundforest.org/arpio/docreflect/internal/tests/src/testpackage > testdocs_test.go - go test -count 1 . ./generate - rm -f testdocs_test.go +.bin/docreflect: $(SOURCES) .bin + go build -o .bin/docreflect ./cmd/docreflect -.cover: $(SOURCES) .bin/docreflect +build: libdocreflect libgenerate .bin/docreflect + +testdocs_test.go: $(SOURCES) .bin/docreflect .bin/docreflect generate docreflect_test code.squareroundforest.org/arpio/docreflect/internal/tests/src/testpackage > testdocs_test.go + +.cover: $(SOURCES) testdocs_test.go go test -count 1 -coverprofile .cover . ./generate - rm -f testdocs_test.go + +check: .cover cover: .cover go tool cover -func .cover diff --git a/generate/generate.go b/generate/generate.go index 0d10139..5c8e5bf 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -1,3 +1,5 @@ +// Package generate provides a generator to generate go code from go docs that registers doc entries +// for use with the docreflect package. package generate import ( @@ -80,12 +82,11 @@ func readGomod(wd string) (string, map[string]string) { return "", nil } + var dirs map[string]string mc := modCache() - dirs := make(map[string]string) for _, dep := range m.Require { - pth := dep.Mod.String() - moduleDir := path.Join(mc, pth) - dirs[pth] = moduleDir + p := dep.Mod.String() + dirs = set(dirs, p, path.Join(mc, p)) } name := m.Module.Mod.String() @@ -125,8 +126,11 @@ func splitGopath(p string) (string, string) { } func packagePaths(p []string) []string { - var pp []string - m := make(map[string]bool) + var ( + pp []string + m map[string]bool + ) + for _, pi := range p { ppi, _ := splitGopath(pi) if m[ppi] { @@ -134,7 +138,7 @@ func packagePaths(p []string) []string { } pp = append(pp, ppi) - m[ppi] = true + m = set(m, ppi, true) } return pp @@ -205,7 +209,7 @@ func parserInclude(pkg *build.Package) func(fs.FileInfo) bool { } func parsePackages(pkgs []*build.Package) (map[string][]*ast.Package, error) { - ppkgs := make(map[string][]*ast.Package) + var ppkgs map[string][]*ast.Package for _, p := range pkgs { fset := token.NewFileSet() pm, err := parser.ParseDir(fset, p.Dir, parserInclude(p), parser.ParseComments) @@ -214,9 +218,11 @@ func parsePackages(pkgs []*build.Package) (map[string][]*ast.Package, error) { } for _, pp := range pm { - if pp != nil { - ppkgs[p.ImportPath] = append(ppkgs[p.ImportPath], pp) + if pp == nil { + continue } + + ppkgs = set(ppkgs, p.ImportPath, append(ppkgs[p.ImportPath], pp)) } } @@ -233,33 +239,55 @@ func fixDocPackage(p doc.Package) doc.Package { return p } -func merge(m ...map[string]string) map[string]string { +func set[K comparable, V any](m map[K]V, key K, value V) map[K]V { + if m == nil { + m = make(map[K]V) + } + + m[key] = value + return m +} + +func merge[K comparable, V any](m ...map[K]V) map[K]V { if len(m) == 1 { return m[0] } - mm := make(map[string]string) + var mm map[K]V for key, value := range m[0] { - mm[key] = value + mm = set(mm, key, value) } mr := merge(m[1:]...) for key, value := range mr { - mm[key] = value + mm = set(mm, key, value) } return mm } +func unpack(e ast.Expr) ast.Expr { + p, ok := e.(*ast.StarExpr) + if !ok { + return e + } + + return unpack(p.X) +} + func funcParams(f *doc.Func) string { + if f.Decl == nil || f.Decl.Type == nil || f.Decl.Type.Params == nil { + return "func()" + } + var paramNames []string - if f.Decl != nil && f.Decl.Type != nil && f.Decl.Type.Params != nil { - for _, p := range f.Decl.Type.Params.List { - for _, n := range p.Names { - if n != nil { - paramNames = append(paramNames, n.Name) - } + for _, p := range f.Decl.Type.Params.List { + for _, n := range p.Names { + if n == nil { + continue } + + paramNames = append(paramNames, n.Name) } } @@ -270,20 +298,93 @@ func funcDocs(f *doc.Func) string { return fmt.Sprintf("%s\n%s", f.Doc, funcParams(f)) } +func findFieldDocs(str *ast.StructType, fieldPath []string) (string, bool) { + if len(fieldPath) == 0 || str.Fields == nil { + return "", false + } + + for _, f := range str.Fields.List { + var found bool + for _, fn := range f.Names { + if fn != nil && fn.Name == fieldPath[0] { + found = true + break + } + } + + if !found { + continue + } + + if len(fieldPath) == 1 { + if f.Doc == nil { + return "", true + } + + return f.Doc.Text(), true + } + + te := unpack(f.Type) + fstr, ok := te.(*ast.StructType) + if !ok { + return "", false + } + + return findFieldDocs(fstr, fieldPath[1:]) + } + + return "", false +} + +func structFieldDocs(t *doc.Type, fieldPath []string) (string, bool) { + if t.Decl == nil || len(t.Decl.Specs) != 1 { + return "", false + } + + ts, ok := t.Decl.Specs[0].(*ast.TypeSpec) + if !ok { + return "", false + } + + te := unpack(ts.Type) + str, ok := te.(*ast.StructType) + if !ok { + return "", false + } + + return findFieldDocs(str, fieldPath) +} + +func typeMethodDocs(t *doc.Type, name string) (string, bool) { + for _, m := range t.Methods { + if m.Name != name { + continue + } + + return funcDocs(m), true + } + + return "", false +} + func symbolDocs(pkg *doc.Package, gopath string) (string, bool) { _, s := splitGopath(gopath) symbol := strings.Split(s, ".") if len(symbol) == 1 { for _, c := range pkg.Consts { - if c != nil && slices.Contains(c.Names, symbol[0]) { - return c.Doc, true + if c == nil || !slices.Contains(c.Names, symbol[0]) { + continue } + + return c.Doc, true } for _, v := range pkg.Vars { - if v != nil && slices.Contains(v.Names, symbol[0]) { - return v.Doc, true + if v == nil || !slices.Contains(v.Names, symbol[0]) { + continue } + + return v.Doc, true } for _, f := range pkg.Funcs { @@ -304,73 +405,16 @@ func symbolDocs(pkg *doc.Package, gopath string) (string, bool) { return t.Doc, true } - if t.Decl == nil || len(t.Decl.Specs) != 1 { - continue - } - - ts, ok := t.Decl.Specs[0].(*ast.TypeSpec) - if !ok { - continue - } - - str, ok := ts.Type.(*ast.StructType) - if !ok { - continue - } - - fieldPath := symbol[1:] - for len(fieldPath) > 0 { - var foundField bool - if str.Fields != nil { - for _, f := range str.Fields.List { - for _, fn := range f.Names { - if fn != nil { - if fn.Name != fieldPath[0] { - continue - } - - foundField = true - if len(fieldPath) == 1 { - if f.Doc == nil { - return "", true - } - - return f.Doc.Text(), true - } - - fstr, ok := f.Type.(*ast.StructType) - if !ok { - return "", false - } - - str = fstr - fieldPath = fieldPath[1:] - break - } - } - - if foundField { - break - } - } - } - - if !foundField { - break - } + if d, ok := structFieldDocs(t, symbol[1:]); ok { + return d, true } if len(symbol) != 2 { return "", false } - methodName := symbol[1] - for _, m := range t.Methods { - if m.Name != methodName { - continue - } - - return funcDocs(m), true + if d, ok := typeMethodDocs(t, symbol[1]); ok { + return d, ok } } @@ -382,11 +426,11 @@ func symbolPath(packagePath string, name ...string) string { } func valueDocs(packagePath string, v []*doc.Value) map[string]string { - d := make(map[string]string) + var d map[string]string for _, vi := range v { if vi != nil { for _, n := range vi.Names { - d[symbolPath(packagePath, n)] = vi.Doc + d = set(d, symbolPath(packagePath, n), vi.Doc) } } } @@ -394,71 +438,110 @@ func valueDocs(packagePath string, v []*doc.Value) map[string]string { return d } -func takeFieldDocs(docs map[string]string, packagePath string, prefix []string, f *ast.Field) { +func takeFieldDocs(packagePath string, prefix []string, f *ast.Field) map[string]string { + var docs map[string]string for _, fn := range f.Names { if fn == nil { continue } - docs[symbolPath(packagePath, append(prefix, fn.Name)...)] = f.Doc.Text() - if fst, ok := f.Type.(*ast.StructType); ok { - if fst.Fields != nil { - for _, fi := range fst.Fields.List { - if fi == nil { - continue - } + docs = set(docs, symbolPath(packagePath, append(prefix, fn.Name)...), f.Doc.Text()) + te := unpack(f.Type) + fst, ok := te.(*ast.StructType) + if !ok || fst.Fields == nil { + continue + } - takeFieldDocs(docs, packagePath, append(prefix, fn.Name), fi) - } + for _, fi := range fst.Fields.List { + if fi == nil { + continue } + + docs = merge(docs, takeFieldDocs(packagePath, append(prefix, fn.Name), fi)) } } + + return docs } -func packageDocs(pkg *doc.Package) map[string]string { - d := make(map[string]string) - d[pkg.ImportPath] = pkg.Doc - d = merge(d, valueDocs(pkg.ImportPath, pkg.Consts)) - d = merge(d, valueDocs(pkg.ImportPath, pkg.Vars)) - for _, t := range pkg.Types { - if t != nil { - d[symbolPath(pkg.ImportPath, t.Name)] = t.Doc - } - - if t.Decl != nil && len(t.Decl.Specs) == 1 { - if ts, ok := t.Decl.Specs[0].(*ast.TypeSpec); ok { - if str, ok := ts.Type.(*ast.StructType); ok && str.Fields != nil { - for _, f := range str.Fields.List { - if f == nil { - continue - } - - takeFieldDocs(d, pkg.ImportPath, []string{t.Name}, f) - } - } - } - } - - for _, m := range t.Methods { - if m != nil { - doc := funcDocs(m) - d[symbolPath(pkg.ImportPath, t.Name, m.Name)] = doc - } - } +func structFieldsDocs(importPath string, t *doc.Type) map[string]string { + if t.Decl == nil || len(t.Decl.Specs) != 1 { + return nil } - for _, f := range pkg.Funcs { - if f != nil { - doc := funcDocs(f) - d[symbolPath(pkg.ImportPath, f.Name)] = doc + ts, ok := t.Decl.Specs[0].(*ast.TypeSpec) + if !ok { + return nil + } + + te := unpack(ts.Type) + str, ok := te.(*ast.StructType) + if !ok || str.Fields == nil { + return nil + } + + var d map[string]string + for _, f := range str.Fields.List { + if f == nil { + continue + } + + d = merge(d, takeFieldDocs(importPath, []string{t.Name}, f)) + } + + return d +} + +func methodDocs(importPath string, t *doc.Type) map[string]string { + var d map[string]string + for _, m := range t.Methods { + if m != nil { + d = set(d, symbolPath(importPath, t.Name, m.Name), funcDocs(m)) } } return d } +func typeDocs(importPath string, types []*doc.Type) map[string]string { + var d map[string]string + for _, t := range types { + if t == nil { + continue + } + + d = set(d, symbolPath(importPath, t.Name), t.Doc) + d = merge(d, structFieldsDocs(importPath, t)) + d = merge(d, methodDocs(importPath, t)) + + } + + return d +} + +func packageFuncDocs(importPath string, funcs []*doc.Func) map[string]string { + var d map[string]string + for _, f := range funcs { + if f != nil { + d = set(d, symbolPath(importPath, f.Name), funcDocs(f)) + } + } + + return d +} + +func packageDocs(pkg *doc.Package) map[string]string { + return merge( + map[string]string{pkg.ImportPath: pkg.Doc}, + valueDocs(pkg.ImportPath, pkg.Consts), + valueDocs(pkg.ImportPath, pkg.Vars), + typeDocs(pkg.ImportPath, pkg.Types), + packageFuncDocs(pkg.ImportPath, pkg.Funcs), + ) +} + func takeDocs(pkgs map[string][]*ast.Package, gopaths []string) (map[string]string, error) { - dm := make(map[string]string) + var dm map[string]string for _, gp := range gopaths { pp, _ := splitGopath(gp) isPackage := pp == gp @@ -466,8 +549,7 @@ func takeDocs(pkgs map[string][]*ast.Package, gopaths []string) (map[string]stri dpkg := doc.New(pkg, pp, doc.AllDecls|doc.PreserveAST) *dpkg = fixDocPackage(*dpkg) if isPackage { - pd := packageDocs(dpkg) - dm = merge(dm, pd) + dm = merge(dm, packageDocs(dpkg)) continue } @@ -476,7 +558,7 @@ func takeDocs(pkgs map[string][]*ast.Package, gopaths []string) (map[string]stri return nil, fmt.Errorf("symbol not found: %s", gp) } - dm[gp] = sd + dm = set(dm, gp, sd) } } @@ -534,12 +616,18 @@ func format(w io.Writer, pname string, docs map[string]string) error { printf("docreflect.Register(%s, %s)\n", path, doc) } - println("}") + printf("}") return err } // GenerateRegistry generates a Go code file to the output, including a package init function that // will register the documentation of the declarations specified by their gopath. +// +// The gopath argument accepts any number of package, package level symbol, of struct field paths. +// It is recommended to use package paths unless special circumstances. +// +// Some important gotchas to keep in mind, GenerateRegistry does not resolve type references like +// type aliases, or type definitions based on named types, and it doesn't follow import paths. func GenerateRegistry(w io.Writer, outputPackageName string, gopath ...string) error { o := initOptions() d, err := generate(o, gopath...) @@ -553,9 +641,3 @@ func GenerateRegistry(w io.Writer, outputPackageName string, gopath ...string) e return nil } - -// TODO: -// - type foo = bar -// - type foo bar -// - type (...) -// - pointers in struct fields diff --git a/generate/generate_test.go b/generate/generate_test.go index 04f40d8..2650038 100644 --- a/generate/generate_test.go +++ b/generate/generate_test.go @@ -87,7 +87,7 @@ func TestGenerate(t *testing.T) { }, } - t.Run("with modules", func(t *testing.T) { + t.Run("generate", func(t *testing.T) { t.Run( "stdlib", testGenerate(check("strings.Join", "Join concatenates", "strings.Join", "func(elems, sep)"), "", o, "strings.Join"), @@ -243,6 +243,68 @@ func TestGenerate(t *testing.T) { mainFunc := "code.squareroundforest.org/arpio/docreflect/internal/tests/src/command.main" t.Run("main package", testGenerate(check(mainFunc, "main func"), "", o, mainFunc)) + + anum := fmt.Sprintf("%s.A", packagePath) + bnum := fmt.Sprintf("%s.B", packagePath) + cnum := fmt.Sprintf("%s.C", packagePath) + t.Run( + "grouped types", + testGenerate( + check(anum, "A is a number", bnum, "B is another number", cnum, "C is a third number"), + "", + o, + packagePath, + ), + ) + + t.Run( + "grouped types as symbol", + testGenerate(check(anum, "A is a number"), "", o, anum), + ) + + pfoo := fmt.Sprintf("%s.ExportedType.PFoo", packagePath) + barPBaz := fmt.Sprintf("%s.ExportedType.Bar.PBaz", packagePath) + t.Run( + "pointer field", + testGenerate( + check(pfoo, "PFoo is a pointer field", barPBaz, "PBaz is another pointer field"), + "", + o, + packagePath, + ), + ) + + t.Run( + "pointer field as symbol", + testGenerate( + check(pfoo, "PFoo is a pointer field", barPBaz, "PBaz is another pointer field"), + "", + o, + pfoo, + barPBaz, + ), + ) + + quxQuux := fmt.Sprintf("%s.ExportedType.Qux.Quux", packagePath) + t.Run( + "anonymous pointer struct field", + testGenerate( + check(quxQuux, "Quux is a number"), + "", + o, + packagePath, + ), + ) + + t.Run( + "anonymous pointer struct field as symbol", + testGenerate( + check(quxQuux, "Quux is a number"), + "", + o, + quxQuux, + ), + ) }) t.Run("errors", func(t *testing.T) { @@ -258,7 +320,7 @@ func TestGenerate(t *testing.T) { nil, "symbol", o, - "code.squareroundforest.org/arpio/docreflect/internal/tests/src/testpackage.ExportedType.Qux", + "code.squareroundforest.org/arpio/docreflect/internal/tests/src/testpackage.ExportedType.Corge", ), ) @@ -331,11 +393,6 @@ func TestGenerate(t *testing.T) { } func TestFormat(t *testing.T) { - // header - // import - // few items - // escaping - b := bytes.NewBuffer(nil) d := map[string]string{ "foo": "bar", @@ -352,8 +409,7 @@ import "code.squareroundforest.org/arpio/docreflect" func init() { docreflect.Register("baz", "qux") docreflect.Register("foo", "bar") -} -` { +}` { t.Fatal() } } diff --git a/internal/tests/src/testpackage/test.go b/internal/tests/src/testpackage/test.go index a2a9a1b..e9d8f50 100644 --- a/internal/tests/src/testpackage/test.go +++ b/internal/tests/src/testpackage/test.go @@ -13,25 +13,52 @@ type ExportedType struct { // Foo is a field Foo int + // PFoo is a pointer field + PFoo *int + // Bar is an inline struct type expression Bar struct { // Baz is another field Baz int + + // PBaz is another pointer field + PBaz *int } // Baz is a field of type *ExportedType2 Baz *ExportedType2 + + // Qux is a pointer struct field + Qux *struct{ + + // Quux is a number + Quux int + } } type Foo = ExportedType +type Bar ExportedType + type ExportedType2 struct{ // Foo is a field in ExportedType2 Foo int } +type ( + + // A is a number + A int + + // B is another number + B int + + // C is a third number + C int +) + // C1 is a const const C1 = 42