testing generate

This commit is contained in:
Arpad Ryszka 2025-08-17 15:59:01 +02:00
parent af83b13d7c
commit 5c8853f7ae
5 changed files with 330 additions and 163 deletions

2
.gitignore vendored
View File

@ -1,3 +1,3 @@
.bin .bin
testdocs.go testdocs_test.go
.cover .cover

View File

@ -1,27 +1,29 @@
SOURCES = $(shell find . -name "*.go") SOURCES = $(shell find . -name "*.go" | grep -v testdocs_test.go)
PREFIX ?= ~/bin PREFIX ?= ~/bin
default: build default: build
build: $(SOURCES) .bin libdocreflect: $(SOURCES)
go build . go build .
go build ./generate
go build -o .bin/docreflect ./cmd/docreflect
.bin/docreflect: build libgenerate: $(SOURCES)
go build ./generate
.bin: .bin:
mkdir -p .bin mkdir -p .bin
check: $(SOURCES) .bin/docreflect .bin/docreflect: $(SOURCES) .bin
.bin/docreflect generate docreflect_test code.squareroundforest.org/arpio/docreflect/internal/tests/src/testpackage > testdocs_test.go go build -o .bin/docreflect ./cmd/docreflect
go test -count 1 . ./generate
rm -f testdocs_test.go
.cover: $(SOURCES) .bin/docreflect build: libdocreflect libgenerate .bin/docreflect
testdocs_test.go: $(SOURCES) .bin/docreflect
.bin/docreflect generate docreflect_test code.squareroundforest.org/arpio/docreflect/internal/tests/src/testpackage > testdocs_test.go .bin/docreflect generate docreflect_test code.squareroundforest.org/arpio/docreflect/internal/tests/src/testpackage > testdocs_test.go
.cover: $(SOURCES) testdocs_test.go
go test -count 1 -coverprofile .cover . ./generate go test -count 1 -coverprofile .cover . ./generate
rm -f testdocs_test.go
check: .cover
cover: .cover cover: .cover
go tool cover -func .cover go tool cover -func .cover

View File

@ -1,3 +1,5 @@
// Package generate provides a generator to generate go code from go docs that registers doc entries
// for use with the docreflect package.
package generate package generate
import ( import (
@ -80,12 +82,11 @@ func readGomod(wd string) (string, map[string]string) {
return "", nil return "", nil
} }
var dirs map[string]string
mc := modCache() mc := modCache()
dirs := make(map[string]string)
for _, dep := range m.Require { for _, dep := range m.Require {
pth := dep.Mod.String() p := dep.Mod.String()
moduleDir := path.Join(mc, pth) dirs = set(dirs, p, path.Join(mc, p))
dirs[pth] = moduleDir
} }
name := m.Module.Mod.String() name := m.Module.Mod.String()
@ -125,8 +126,11 @@ func splitGopath(p string) (string, string) {
} }
func packagePaths(p []string) []string { func packagePaths(p []string) []string {
var pp []string var (
m := make(map[string]bool) pp []string
m map[string]bool
)
for _, pi := range p { for _, pi := range p {
ppi, _ := splitGopath(pi) ppi, _ := splitGopath(pi)
if m[ppi] { if m[ppi] {
@ -134,7 +138,7 @@ func packagePaths(p []string) []string {
} }
pp = append(pp, ppi) pp = append(pp, ppi)
m[ppi] = true m = set(m, ppi, true)
} }
return pp return pp
@ -205,7 +209,7 @@ func parserInclude(pkg *build.Package) func(fs.FileInfo) bool {
} }
func parsePackages(pkgs []*build.Package) (map[string][]*ast.Package, error) { func parsePackages(pkgs []*build.Package) (map[string][]*ast.Package, error) {
ppkgs := make(map[string][]*ast.Package) var ppkgs map[string][]*ast.Package
for _, p := range pkgs { for _, p := range pkgs {
fset := token.NewFileSet() fset := token.NewFileSet()
pm, err := parser.ParseDir(fset, p.Dir, parserInclude(p), parser.ParseComments) pm, err := parser.ParseDir(fset, p.Dir, parserInclude(p), parser.ParseComments)
@ -214,9 +218,11 @@ func parsePackages(pkgs []*build.Package) (map[string][]*ast.Package, error) {
} }
for _, pp := range pm { for _, pp := range pm {
if pp != nil { if pp == nil {
ppkgs[p.ImportPath] = append(ppkgs[p.ImportPath], pp) continue
} }
ppkgs = set(ppkgs, p.ImportPath, append(ppkgs[p.ImportPath], pp))
} }
} }
@ -233,35 +239,57 @@ func fixDocPackage(p doc.Package) doc.Package {
return p return p
} }
func merge(m ...map[string]string) map[string]string { func set[K comparable, V any](m map[K]V, key K, value V) map[K]V {
if m == nil {
m = make(map[K]V)
}
m[key] = value
return m
}
func merge[K comparable, V any](m ...map[K]V) map[K]V {
if len(m) == 1 { if len(m) == 1 {
return m[0] return m[0]
} }
mm := make(map[string]string) var mm map[K]V
for key, value := range m[0] { for key, value := range m[0] {
mm[key] = value mm = set(mm, key, value)
} }
mr := merge(m[1:]...) mr := merge(m[1:]...)
for key, value := range mr { for key, value := range mr {
mm[key] = value mm = set(mm, key, value)
} }
return mm return mm
} }
func unpack(e ast.Expr) ast.Expr {
p, ok := e.(*ast.StarExpr)
if !ok {
return e
}
return unpack(p.X)
}
func funcParams(f *doc.Func) string { func funcParams(f *doc.Func) string {
if f.Decl == nil || f.Decl.Type == nil || f.Decl.Type.Params == nil {
return "func()"
}
var paramNames []string var paramNames []string
if f.Decl != nil && f.Decl.Type != nil && f.Decl.Type.Params != nil {
for _, p := range f.Decl.Type.Params.List { for _, p := range f.Decl.Type.Params.List {
for _, n := range p.Names { for _, n := range p.Names {
if n != nil { if n == nil {
continue
}
paramNames = append(paramNames, n.Name) paramNames = append(paramNames, n.Name)
} }
} }
}
}
return fmt.Sprintf("func(%s)", strings.Join(paramNames, ", ")) return fmt.Sprintf("func(%s)", strings.Join(paramNames, ", "))
} }
@ -270,20 +298,93 @@ func funcDocs(f *doc.Func) string {
return fmt.Sprintf("%s\n%s", f.Doc, funcParams(f)) return fmt.Sprintf("%s\n%s", f.Doc, funcParams(f))
} }
func findFieldDocs(str *ast.StructType, fieldPath []string) (string, bool) {
if len(fieldPath) == 0 || str.Fields == nil {
return "", false
}
for _, f := range str.Fields.List {
var found bool
for _, fn := range f.Names {
if fn != nil && fn.Name == fieldPath[0] {
found = true
break
}
}
if !found {
continue
}
if len(fieldPath) == 1 {
if f.Doc == nil {
return "", true
}
return f.Doc.Text(), true
}
te := unpack(f.Type)
fstr, ok := te.(*ast.StructType)
if !ok {
return "", false
}
return findFieldDocs(fstr, fieldPath[1:])
}
return "", false
}
func structFieldDocs(t *doc.Type, fieldPath []string) (string, bool) {
if t.Decl == nil || len(t.Decl.Specs) != 1 {
return "", false
}
ts, ok := t.Decl.Specs[0].(*ast.TypeSpec)
if !ok {
return "", false
}
te := unpack(ts.Type)
str, ok := te.(*ast.StructType)
if !ok {
return "", false
}
return findFieldDocs(str, fieldPath)
}
func typeMethodDocs(t *doc.Type, name string) (string, bool) {
for _, m := range t.Methods {
if m.Name != name {
continue
}
return funcDocs(m), true
}
return "", false
}
func symbolDocs(pkg *doc.Package, gopath string) (string, bool) { func symbolDocs(pkg *doc.Package, gopath string) (string, bool) {
_, s := splitGopath(gopath) _, s := splitGopath(gopath)
symbol := strings.Split(s, ".") symbol := strings.Split(s, ".")
if len(symbol) == 1 { if len(symbol) == 1 {
for _, c := range pkg.Consts { for _, c := range pkg.Consts {
if c != nil && slices.Contains(c.Names, symbol[0]) { if c == nil || !slices.Contains(c.Names, symbol[0]) {
return c.Doc, true continue
} }
return c.Doc, true
} }
for _, v := range pkg.Vars { for _, v := range pkg.Vars {
if v != nil && slices.Contains(v.Names, symbol[0]) { if v == nil || !slices.Contains(v.Names, symbol[0]) {
return v.Doc, true continue
} }
return v.Doc, true
} }
for _, f := range pkg.Funcs { for _, f := range pkg.Funcs {
@ -304,73 +405,16 @@ func symbolDocs(pkg *doc.Package, gopath string) (string, bool) {
return t.Doc, true return t.Doc, true
} }
if t.Decl == nil || len(t.Decl.Specs) != 1 { if d, ok := structFieldDocs(t, symbol[1:]); ok {
continue return d, true
}
ts, ok := t.Decl.Specs[0].(*ast.TypeSpec)
if !ok {
continue
}
str, ok := ts.Type.(*ast.StructType)
if !ok {
continue
}
fieldPath := symbol[1:]
for len(fieldPath) > 0 {
var foundField bool
if str.Fields != nil {
for _, f := range str.Fields.List {
for _, fn := range f.Names {
if fn != nil {
if fn.Name != fieldPath[0] {
continue
}
foundField = true
if len(fieldPath) == 1 {
if f.Doc == nil {
return "", true
}
return f.Doc.Text(), true
}
fstr, ok := f.Type.(*ast.StructType)
if !ok {
return "", false
}
str = fstr
fieldPath = fieldPath[1:]
break
}
}
if foundField {
break
}
}
}
if !foundField {
break
}
} }
if len(symbol) != 2 { if len(symbol) != 2 {
return "", false return "", false
} }
methodName := symbol[1] if d, ok := typeMethodDocs(t, symbol[1]); ok {
for _, m := range t.Methods { return d, ok
if m.Name != methodName {
continue
}
return funcDocs(m), true
} }
} }
@ -382,11 +426,11 @@ func symbolPath(packagePath string, name ...string) string {
} }
func valueDocs(packagePath string, v []*doc.Value) map[string]string { func valueDocs(packagePath string, v []*doc.Value) map[string]string {
d := make(map[string]string) var d map[string]string
for _, vi := range v { for _, vi := range v {
if vi != nil { if vi != nil {
for _, n := range vi.Names { for _, n := range vi.Names {
d[symbolPath(packagePath, n)] = vi.Doc d = set(d, symbolPath(packagePath, n), vi.Doc)
} }
} }
} }
@ -394,71 +438,110 @@ func valueDocs(packagePath string, v []*doc.Value) map[string]string {
return d return d
} }
func takeFieldDocs(docs map[string]string, packagePath string, prefix []string, f *ast.Field) { func takeFieldDocs(packagePath string, prefix []string, f *ast.Field) map[string]string {
var docs map[string]string
for _, fn := range f.Names { for _, fn := range f.Names {
if fn == nil { if fn == nil {
continue continue
} }
docs[symbolPath(packagePath, append(prefix, fn.Name)...)] = f.Doc.Text() docs = set(docs, symbolPath(packagePath, append(prefix, fn.Name)...), f.Doc.Text())
if fst, ok := f.Type.(*ast.StructType); ok { te := unpack(f.Type)
if fst.Fields != nil { fst, ok := te.(*ast.StructType)
if !ok || fst.Fields == nil {
continue
}
for _, fi := range fst.Fields.List { for _, fi := range fst.Fields.List {
if fi == nil { if fi == nil {
continue continue
} }
takeFieldDocs(docs, packagePath, append(prefix, fn.Name), fi) docs = merge(docs, takeFieldDocs(packagePath, append(prefix, fn.Name), fi))
}
}
} }
} }
return docs
} }
func packageDocs(pkg *doc.Package) map[string]string { func structFieldsDocs(importPath string, t *doc.Type) map[string]string {
d := make(map[string]string) if t.Decl == nil || len(t.Decl.Specs) != 1 {
d[pkg.ImportPath] = pkg.Doc return nil
d = merge(d, valueDocs(pkg.ImportPath, pkg.Consts))
d = merge(d, valueDocs(pkg.ImportPath, pkg.Vars))
for _, t := range pkg.Types {
if t != nil {
d[symbolPath(pkg.ImportPath, t.Name)] = t.Doc
} }
if t.Decl != nil && len(t.Decl.Specs) == 1 { ts, ok := t.Decl.Specs[0].(*ast.TypeSpec)
if ts, ok := t.Decl.Specs[0].(*ast.TypeSpec); ok { if !ok {
if str, ok := ts.Type.(*ast.StructType); ok && str.Fields != nil { return nil
}
te := unpack(ts.Type)
str, ok := te.(*ast.StructType)
if !ok || str.Fields == nil {
return nil
}
var d map[string]string
for _, f := range str.Fields.List { for _, f := range str.Fields.List {
if f == nil { if f == nil {
continue continue
} }
takeFieldDocs(d, pkg.ImportPath, []string{t.Name}, f) d = merge(d, takeFieldDocs(importPath, []string{t.Name}, f))
}
}
}
} }
return d
}
func methodDocs(importPath string, t *doc.Type) map[string]string {
var d map[string]string
for _, m := range t.Methods { for _, m := range t.Methods {
if m != nil { if m != nil {
doc := funcDocs(m) d = set(d, symbolPath(importPath, t.Name, m.Name), funcDocs(m))
d[symbolPath(pkg.ImportPath, t.Name, m.Name)] = doc
}
}
}
for _, f := range pkg.Funcs {
if f != nil {
doc := funcDocs(f)
d[symbolPath(pkg.ImportPath, f.Name)] = doc
} }
} }
return d return d
} }
func typeDocs(importPath string, types []*doc.Type) map[string]string {
var d map[string]string
for _, t := range types {
if t == nil {
continue
}
d = set(d, symbolPath(importPath, t.Name), t.Doc)
d = merge(d, structFieldsDocs(importPath, t))
d = merge(d, methodDocs(importPath, t))
}
return d
}
func packageFuncDocs(importPath string, funcs []*doc.Func) map[string]string {
var d map[string]string
for _, f := range funcs {
if f != nil {
d = set(d, symbolPath(importPath, f.Name), funcDocs(f))
}
}
return d
}
func packageDocs(pkg *doc.Package) map[string]string {
return merge(
map[string]string{pkg.ImportPath: pkg.Doc},
valueDocs(pkg.ImportPath, pkg.Consts),
valueDocs(pkg.ImportPath, pkg.Vars),
typeDocs(pkg.ImportPath, pkg.Types),
packageFuncDocs(pkg.ImportPath, pkg.Funcs),
)
}
func takeDocs(pkgs map[string][]*ast.Package, gopaths []string) (map[string]string, error) { func takeDocs(pkgs map[string][]*ast.Package, gopaths []string) (map[string]string, error) {
dm := make(map[string]string) var dm map[string]string
for _, gp := range gopaths { for _, gp := range gopaths {
pp, _ := splitGopath(gp) pp, _ := splitGopath(gp)
isPackage := pp == gp isPackage := pp == gp
@ -466,8 +549,7 @@ func takeDocs(pkgs map[string][]*ast.Package, gopaths []string) (map[string]stri
dpkg := doc.New(pkg, pp, doc.AllDecls|doc.PreserveAST) dpkg := doc.New(pkg, pp, doc.AllDecls|doc.PreserveAST)
*dpkg = fixDocPackage(*dpkg) *dpkg = fixDocPackage(*dpkg)
if isPackage { if isPackage {
pd := packageDocs(dpkg) dm = merge(dm, packageDocs(dpkg))
dm = merge(dm, pd)
continue continue
} }
@ -476,7 +558,7 @@ func takeDocs(pkgs map[string][]*ast.Package, gopaths []string) (map[string]stri
return nil, fmt.Errorf("symbol not found: %s", gp) return nil, fmt.Errorf("symbol not found: %s", gp)
} }
dm[gp] = sd dm = set(dm, gp, sd)
} }
} }
@ -534,12 +616,18 @@ func format(w io.Writer, pname string, docs map[string]string) error {
printf("docreflect.Register(%s, %s)\n", path, doc) printf("docreflect.Register(%s, %s)\n", path, doc)
} }
println("}") printf("}")
return err return err
} }
// GenerateRegistry generates a Go code file to the output, including a package init function that // GenerateRegistry generates a Go code file to the output, including a package init function that
// will register the documentation of the declarations specified by their gopath. // will register the documentation of the declarations specified by their gopath.
//
// The gopath argument accepts any number of package, package level symbol, of struct field paths.
// It is recommended to use package paths unless special circumstances.
//
// Some important gotchas to keep in mind, GenerateRegistry does not resolve type references like
// type aliases, or type definitions based on named types, and it doesn't follow import paths.
func GenerateRegistry(w io.Writer, outputPackageName string, gopath ...string) error { func GenerateRegistry(w io.Writer, outputPackageName string, gopath ...string) error {
o := initOptions() o := initOptions()
d, err := generate(o, gopath...) d, err := generate(o, gopath...)
@ -553,9 +641,3 @@ func GenerateRegistry(w io.Writer, outputPackageName string, gopath ...string) e
return nil return nil
} }
// TODO:
// - type foo = bar
// - type foo bar
// - type (...)
// - pointers in struct fields

View File

@ -87,7 +87,7 @@ func TestGenerate(t *testing.T) {
}, },
} }
t.Run("with modules", func(t *testing.T) { t.Run("generate", func(t *testing.T) {
t.Run( t.Run(
"stdlib", "stdlib",
testGenerate(check("strings.Join", "Join concatenates", "strings.Join", "func(elems, sep)"), "", o, "strings.Join"), testGenerate(check("strings.Join", "Join concatenates", "strings.Join", "func(elems, sep)"), "", o, "strings.Join"),
@ -243,6 +243,68 @@ func TestGenerate(t *testing.T) {
mainFunc := "code.squareroundforest.org/arpio/docreflect/internal/tests/src/command.main" mainFunc := "code.squareroundforest.org/arpio/docreflect/internal/tests/src/command.main"
t.Run("main package", testGenerate(check(mainFunc, "main func"), "", o, mainFunc)) t.Run("main package", testGenerate(check(mainFunc, "main func"), "", o, mainFunc))
anum := fmt.Sprintf("%s.A", packagePath)
bnum := fmt.Sprintf("%s.B", packagePath)
cnum := fmt.Sprintf("%s.C", packagePath)
t.Run(
"grouped types",
testGenerate(
check(anum, "A is a number", bnum, "B is another number", cnum, "C is a third number"),
"",
o,
packagePath,
),
)
t.Run(
"grouped types as symbol",
testGenerate(check(anum, "A is a number"), "", o, anum),
)
pfoo := fmt.Sprintf("%s.ExportedType.PFoo", packagePath)
barPBaz := fmt.Sprintf("%s.ExportedType.Bar.PBaz", packagePath)
t.Run(
"pointer field",
testGenerate(
check(pfoo, "PFoo is a pointer field", barPBaz, "PBaz is another pointer field"),
"",
o,
packagePath,
),
)
t.Run(
"pointer field as symbol",
testGenerate(
check(pfoo, "PFoo is a pointer field", barPBaz, "PBaz is another pointer field"),
"",
o,
pfoo,
barPBaz,
),
)
quxQuux := fmt.Sprintf("%s.ExportedType.Qux.Quux", packagePath)
t.Run(
"anonymous pointer struct field",
testGenerate(
check(quxQuux, "Quux is a number"),
"",
o,
packagePath,
),
)
t.Run(
"anonymous pointer struct field as symbol",
testGenerate(
check(quxQuux, "Quux is a number"),
"",
o,
quxQuux,
),
)
}) })
t.Run("errors", func(t *testing.T) { t.Run("errors", func(t *testing.T) {
@ -258,7 +320,7 @@ func TestGenerate(t *testing.T) {
nil, nil,
"symbol", "symbol",
o, o,
"code.squareroundforest.org/arpio/docreflect/internal/tests/src/testpackage.ExportedType.Qux", "code.squareroundforest.org/arpio/docreflect/internal/tests/src/testpackage.ExportedType.Corge",
), ),
) )
@ -331,11 +393,6 @@ func TestGenerate(t *testing.T) {
} }
func TestFormat(t *testing.T) { func TestFormat(t *testing.T) {
// header
// import
// few items
// escaping
b := bytes.NewBuffer(nil) b := bytes.NewBuffer(nil)
d := map[string]string{ d := map[string]string{
"foo": "bar", "foo": "bar",
@ -352,8 +409,7 @@ import "code.squareroundforest.org/arpio/docreflect"
func init() { func init() {
docreflect.Register("baz", "qux") docreflect.Register("baz", "qux")
docreflect.Register("foo", "bar") docreflect.Register("foo", "bar")
} }` {
` {
t.Fatal() t.Fatal()
} }
} }

View File

@ -13,25 +13,52 @@ type ExportedType struct {
// Foo is a field // Foo is a field
Foo int Foo int
// PFoo is a pointer field
PFoo *int
// Bar is an inline struct type expression // Bar is an inline struct type expression
Bar struct { Bar struct {
// Baz is another field // Baz is another field
Baz int Baz int
// PBaz is another pointer field
PBaz *int
} }
// Baz is a field of type *ExportedType2 // Baz is a field of type *ExportedType2
Baz *ExportedType2 Baz *ExportedType2
// Qux is a pointer struct field
Qux *struct{
// Quux is a number
Quux int
}
} }
type Foo = ExportedType type Foo = ExportedType
type Bar ExportedType
type ExportedType2 struct{ type ExportedType2 struct{
// Foo is a field in ExportedType2 // Foo is a field in ExportedType2
Foo int Foo int
} }
type (
// A is a number
A int
// B is another number
B int
// C is a third number
C int
)
// C1 is a const // C1 is a const
const C1 = 42 const C1 = 42