// Package generate provides a generator to generate go code from go docs that registers doc entries // for use with the docreflect package. package generate import ( "fmt" "go/ast" "go/build" "go/doc" "go/parser" "go/token" "golang.org/x/mod/modfile" "io" "io/fs" "os" "path" "runtime" "slices" "sort" "strconv" "strings" ) type options struct { wd string goroot string gomod string modules map[string]string } func getGoroot() string { gr := os.Getenv("GOROOT") if gr != "" { return gr } return runtime.GOROOT() } func modCache() string { mc := os.Getenv("GOMODCACHE") if mc != "" { return mc } gp := os.Getenv("GOPATH") if gp == "" { gp = path.Join(os.Getenv("HOME"), "go") } mc = path.Join(gp, "pkg/mod") return mc } func findGoMod(dir string) (string, bool) { p := path.Join(dir, "go.mod") f, err := os.Stat(p) if err == nil && !f.IsDir() { return p, true } if dir == "/" { return "", false } return findGoMod(path.Dir(dir)) } func readGomod(wd string) (string, map[string]string) { p, ok := findGoMod(wd) if !ok { return "", nil } d, err := os.ReadFile(p) if err != nil { return "", nil } m, err := modfile.Parse(p, d, nil) if err != nil { return "", nil } var dirs map[string]string mc := modCache() for _, dep := range m.Require { p := dep.Mod.String() dirs = set(dirs, p, path.Join(mc, p)) } name := m.Module.Mod.String() dirs[name] = path.Dir(p) return name, dirs } func initOptions() options { wd, _ := os.Getwd() gr := getGoroot() gomod, modules := readGomod(wd) return options{ wd: wd, goroot: gr, gomod: gomod, modules: modules, } } func cleanPaths(gopath []string) []string { var c []string for _, pi := range gopath { pi = path.Clean(pi) c = append(c, pi) } return c } func splitGopath(p string) (string, string) { parts := strings.Split(p, "/") last := len(parts) - 1 parts, lastPart := parts[:last], parts[last] symbolParts := strings.Split(lastPart, ".") pkg := symbolParts[0] return strings.Join(append(parts, pkg), "/"), strings.Join(symbolParts[1:], ".") } func packagePaths(p []string) []string { var ( pp []string m map[string]bool ) for _, pi := range p { ppi, _ := splitGopath(pi) if m[ppi] { continue } pp = append(pp, ppi) m = set(m, ppi, true) } return pp } func collectGoDirs(o options) []string { var dirs []string if o.goroot != "" { dirs = append(dirs, path.Join(o.goroot, "src")) dirs = append(dirs, path.Join(o.goroot, "src", "cmd")) } return dirs } func importPackages(o options, godirs, paths []string) ([]*build.Package, error) { var pkgs []*build.Package for _, p := range paths { var found bool for mod, modDir := range o.modules { if !strings.HasPrefix(p, strings.Split(mod, "@")[0]) { continue } pkg, err := build.Import(p, modDir, build.ImportComment) if err != nil || pkg == nil { continue } pkgs = append(pkgs, pkg) found = true break } if found { continue } for _, d := range godirs { pkg, err := build.Import(p, d, build.ImportComment) if err != nil || pkg == nil { continue } pkgs = append(pkgs, pkg) found = true break } if !found { return nil, fmt.Errorf("failed to import package for %s", p) } } return pkgs, nil } func parserInclude(pkg *build.Package) func(fs.FileInfo) bool { return func(file fs.FileInfo) bool { for _, fn := range pkg.GoFiles { if fn == file.Name() { return true } } return false } } func parsePackages(pkgs []*build.Package) (map[string][]*ast.Package, error) { var ppkgs map[string][]*ast.Package for _, p := range pkgs { fset := token.NewFileSet() pm, err := parser.ParseDir(fset, p.Dir, parserInclude(p), parser.ParseComments) if err != nil { return nil, fmt.Errorf("failed to parse package %s: %w", p.Name, err) } for _, pp := range pm { if pp == nil { continue } ppkgs = set(ppkgs, p.ImportPath, append(ppkgs[p.ImportPath], pp)) } } return ppkgs, nil } func fixDocPackage(p doc.Package) doc.Package { for _, t := range p.Types { p.Consts = append(p.Consts, t.Consts...) p.Vars = append(p.Vars, t.Vars...) p.Funcs = append(p.Funcs, t.Funcs...) } return p } func set[K comparable, V any](m map[K]V, key K, value V) map[K]V { if m == nil { m = make(map[K]V) } m[key] = value return m } func merge[K comparable, V any](m ...map[K]V) map[K]V { if len(m) == 1 { return m[0] } var mm map[K]V for key, value := range m[0] { mm = set(mm, key, value) } mr := merge(m[1:]...) for key, value := range mr { mm = set(mm, key, value) } return mm } func unpack(e ast.Expr) ast.Expr { p, ok := e.(*ast.StarExpr) if !ok { return e } return unpack(p.X) } func funcParams(f *doc.Func) string { if f.Decl == nil || f.Decl.Type == nil || f.Decl.Type.Params == nil { return "func()" } var paramNames []string for _, p := range f.Decl.Type.Params.List { for _, n := range p.Names { if n == nil { continue } paramNames = append(paramNames, n.Name) } } return fmt.Sprintf("func(%s)", strings.Join(paramNames, ", ")) } func funcDocs(f *doc.Func) string { return fmt.Sprintf("%s\n%s", f.Doc, funcParams(f)) } func findFieldDocs(str *ast.StructType, fieldPath []string) (string, bool) { if len(fieldPath) == 0 || str.Fields == nil { return "", false } for _, f := range str.Fields.List { var found bool for _, fn := range f.Names { if fn != nil && fn.Name == fieldPath[0] { found = true break } } if !found { continue } if len(fieldPath) == 1 { if f.Doc == nil { return "", true } return f.Doc.Text(), true } te := unpack(f.Type) fstr, ok := te.(*ast.StructType) if !ok { return "", false } return findFieldDocs(fstr, fieldPath[1:]) } return "", false } func structFieldDocs(t *doc.Type, fieldPath []string) (string, bool) { if t.Decl == nil || len(t.Decl.Specs) != 1 { return "", false } ts, ok := t.Decl.Specs[0].(*ast.TypeSpec) if !ok { return "", false } te := unpack(ts.Type) str, ok := te.(*ast.StructType) if !ok { return "", false } return findFieldDocs(str, fieldPath) } func typeMethodDocs(t *doc.Type, name string) (string, bool) { for _, m := range t.Methods { if m.Name != name { continue } return funcDocs(m), true } return "", false } func symbolDocs(pkg *doc.Package, gopath string) (string, bool) { _, s := splitGopath(gopath) symbol := strings.Split(s, ".") if len(symbol) == 1 { for _, c := range pkg.Consts { if c == nil || !slices.Contains(c.Names, symbol[0]) { continue } return c.Doc, true } for _, v := range pkg.Vars { if v == nil || !slices.Contains(v.Names, symbol[0]) { continue } return v.Doc, true } for _, f := range pkg.Funcs { if f == nil || f.Name != symbol[0] { continue } return funcDocs(f), true } } for _, t := range pkg.Types { if t == nil || t.Name != symbol[0] { continue } if len(symbol) == 1 { return t.Doc, true } if d, ok := structFieldDocs(t, symbol[1:]); ok { return d, true } if len(symbol) != 2 { return "", false } if d, ok := typeMethodDocs(t, symbol[1]); ok { return d, ok } } return "", false } func symbolPath(packagePath string, name ...string) string { return fmt.Sprintf("%s.%s", packagePath, strings.Join(name, ".")) } func valueDocs(packagePath string, v []*doc.Value) map[string]string { var d map[string]string for _, vi := range v { if vi != nil { for _, n := range vi.Names { d = set(d, symbolPath(packagePath, n), vi.Doc) } } } return d } func takeFieldDocs(packagePath string, prefix []string, f *ast.Field) map[string]string { var docs map[string]string for _, fn := range f.Names { if fn == nil { continue } docs = set(docs, symbolPath(packagePath, append(prefix, fn.Name)...), f.Doc.Text()) te := unpack(f.Type) fst, ok := te.(*ast.StructType) if !ok || fst.Fields == nil { continue } for _, fi := range fst.Fields.List { if fi == nil { continue } docs = merge(docs, takeFieldDocs(packagePath, append(prefix, fn.Name), fi)) } } return docs } func structFieldsDocs(importPath string, t *doc.Type) map[string]string { if t.Decl == nil || len(t.Decl.Specs) != 1 { return nil } ts, ok := t.Decl.Specs[0].(*ast.TypeSpec) if !ok { return nil } te := unpack(ts.Type) str, ok := te.(*ast.StructType) if !ok || str.Fields == nil { return nil } var d map[string]string for _, f := range str.Fields.List { if f == nil { continue } d = merge(d, takeFieldDocs(importPath, []string{t.Name}, f)) } return d } func methodDocs(importPath string, t *doc.Type) map[string]string { var d map[string]string for _, m := range t.Methods { if m != nil { d = set(d, symbolPath(importPath, t.Name, m.Name), funcDocs(m)) } } return d } func typeDocs(importPath string, types []*doc.Type) map[string]string { var d map[string]string for _, t := range types { if t == nil { continue } d = set(d, symbolPath(importPath, t.Name), t.Doc) d = merge(d, structFieldsDocs(importPath, t)) d = merge(d, methodDocs(importPath, t)) } return d } func packageFuncDocs(importPath string, funcs []*doc.Func) map[string]string { var d map[string]string for _, f := range funcs { if f != nil { d = set(d, symbolPath(importPath, f.Name), funcDocs(f)) } } return d } func packageDocs(pkg *doc.Package) map[string]string { return merge( map[string]string{pkg.ImportPath: pkg.Doc}, valueDocs(pkg.ImportPath, pkg.Consts), valueDocs(pkg.ImportPath, pkg.Vars), typeDocs(pkg.ImportPath, pkg.Types), packageFuncDocs(pkg.ImportPath, pkg.Funcs), ) } func takeDocs(pkgs map[string][]*ast.Package, gopaths []string) (map[string]string, error) { var dm map[string]string for _, gp := range gopaths { pp, _ := splitGopath(gp) isPackage := pp == gp for _, pkg := range pkgs[pp] { dpkg := doc.New(pkg, pp, doc.AllDecls|doc.PreserveAST) *dpkg = fixDocPackage(*dpkg) if isPackage { dm = merge(dm, packageDocs(dpkg)) continue } sd, ok := symbolDocs(dpkg, gp) if !ok { return nil, fmt.Errorf("symbol not found: %s", gp) } dm = set(dm, gp, sd) } } return dm, nil } func generate(o options, gopaths ...string) (map[string]string, error) { gopaths = cleanPaths(gopaths) ppaths := packagePaths(gopaths) dirs := collectGoDirs(o) pkgs, err := importPackages(o, dirs, ppaths) if err != nil { return nil, err } ppkgs, err := parsePackages(pkgs) if err != nil { return nil, err } return takeDocs(ppkgs, gopaths) } func format(w io.Writer, pname string, docs map[string]string) error { var err error printf := func(f string, a ...any) { if err != nil { return } _, err = fmt.Fprintf(w, f, a...) } println := func(a ...any) { if err != nil { return } _, err = fmt.Fprintln(w, a...) } printf("package %s\n", pname) println("import \"code.squareroundforest.org/arpio/docreflect\"") println("func init() {") var paths []string for path := range docs { paths = append(paths, path) } sort.Strings(paths) for _, path := range paths { doc := docs[path] path, doc = strconv.Quote(path), strconv.Quote(doc) printf("docreflect.Register(%s, %s)\n", path, doc) } printf("}") return err } // GenerateRegistry generates a Go code file to the output, including a package init function that // will register the documentation of the declarations specified by their gopath. // // The gopath argument accepts any number of package, package level symbol, of struct field paths. // It is recommended to use package paths unless special circumstances. // // Some important gotchas to keep in mind, GenerateRegistry does not resolve type references like // type aliases, or type definitions based on named types, and it doesn't follow import paths. func GenerateRegistry(w io.Writer, outputPackageName string, gopath ...string) error { o := initOptions() d, err := generate(o, gopath...) if err != nil { return err } if err := format(w, outputPackageName, d); err != nil { return err } return nil }