From 40bd1879753efc3c100f7ac8426391b83e40799b Mon Sep 17 00:00:00 2001 From: Arpad Ryszka Date: Tue, 13 Jan 2026 23:28:46 +0100 Subject: [PATCH] support main package docs --- cmd/docreflect/main.go | 2 +- generate/lib.go | 48 ++++++++++++++++++++++++++++++++++++------ generate/lib_test.go | 17 +++++++++++++++ 3 files changed, 60 insertions(+), 7 deletions(-) diff --git a/cmd/docreflect/main.go b/cmd/docreflect/main.go index 09403ba..5fe4e61 100644 --- a/cmd/docreflect/main.go +++ b/cmd/docreflect/main.go @@ -13,7 +13,7 @@ func main() { } packageName, args := args[0], args[1:] - if err := generate.GenerateRegistry(os.Stdout, packageName, args...); err != nil { + if err := generate.GenerateRegistry(generate.Options{}, os.Stdout, packageName, args...); err != nil { log.Fatalln(err) } } diff --git a/generate/lib.go b/generate/lib.go index 75b152f..7f5919f 100644 --- a/generate/lib.go +++ b/generate/lib.go @@ -27,10 +27,19 @@ Generated with https://code.squareroundforest.org/arpio/docreflect ` +// Options contains options for the generator. +type Options struct { + + // Main indicates that the docs for the symbols will be lookded up as part of the main package of an + // executable. + Main bool +} + type options struct { wd string goroot string modules map[string]string + isMain bool } func getGoroot() string { @@ -544,7 +553,24 @@ func packageDocs(pkg *doc.Package) map[string]string { ) } -func takeDocs(pkgs map[string][]*ast.Package, gopaths []string) (map[string]string, error) { +func replacePath(p, prefix, replace string) string { + if !strings.HasPrefix(p, prefix) { + return p + } + + return fmt.Sprintf("%s%s", replace, p[len(prefix):]) +} + +func replacePaths(m map[string]string, prefix, replace string) map[string]string { + var mm map[string]string + for p, d := range m { + mm = set(mm, replacePath(p, prefix, replace), d) + } + + return mm +} + +func takeDocs(o options, pkgs map[string][]*ast.Package, gopaths []string) (map[string]string, error) { var dm map[string]string for _, gp := range gopaths { pp, _ := splitGopath(gp) @@ -553,7 +579,12 @@ func takeDocs(pkgs map[string][]*ast.Package, gopaths []string) (map[string]stri dpkg := doc.New(pkg, pp, doc.AllDecls|doc.PreserveAST) *dpkg = fixDocPackage(*dpkg) if isPackage { - dm = merge(dm, packageDocs(dpkg)) + pd := packageDocs(dpkg) + if o.isMain { + pd = replacePaths(pd, pp, "main") + } + + dm = merge(dm, pd) continue } @@ -562,6 +593,10 @@ func takeDocs(pkgs map[string][]*ast.Package, gopaths []string) (map[string]stri return nil, fmt.Errorf("symbol not found: %s", gp) } + if o.isMain { + gp = replacePath(gp, pp, "main") + } + dm = set(dm, gp, sd) } } @@ -583,7 +618,7 @@ func generate(o options, gopaths ...string) (map[string]string, error) { return nil, err } - return takeDocs(ppkgs, gopaths) + return takeDocs(o, ppkgs, gopaths) } func format(w io.Writer, pname string, docs map[string]string) error { @@ -633,9 +668,10 @@ func format(w io.Writer, pname string, docs map[string]string) error { // // Some important gotchas to keep in mind, GenerateRegistry does not resolve type references like // type aliases, or type definitions based on named types, and it doesn't follow import paths. -func GenerateRegistry(w io.Writer, outputPackageName string, gopath ...string) error { - o := initOptions() - d, err := generate(o, gopath...) +func GenerateRegistry(o Options, w io.Writer, outputPackageName string, gopath ...string) error { + oo := initOptions() + oo.isMain = o.Main + d, err := generate(oo, gopath...) if err != nil { return err } diff --git a/generate/lib_test.go b/generate/lib_test.go index c9b6fcf..27c0136 100644 --- a/generate/lib_test.go +++ b/generate/lib_test.go @@ -86,6 +86,9 @@ func TestGenerate(t *testing.T) { }, } + oMain := o + oMain.isMain = true + t.Run("generate", func(t *testing.T) { t.Run( "stdlib", @@ -129,6 +132,20 @@ func TestGenerate(t *testing.T) { ), ) + t.Run( + "main", + testGenerate( + check( + "main", "Package testpackage is a test package", + "main.ExportedType", "ExportedType has docs", + "main.ExportedType.Bar.Baz", "Baz is another field", + ), + "", + oMain, + packagePath, + ), + ) + t.Run("symbol", func(t *testing.T) { t.Run("type", testGenerate( check(