// Source: docreflect/generate/generate.go
// (listing metadata: 562 lines, 10 KiB, Go)

package generate
import (
"fmt"
"go/ast"
"go/build"
"go/doc"
"go/parser"
"go/token"
"golang.org/x/mod/modfile"
"io"
"io/fs"
"os"
"path"
"runtime"
"slices"
"sort"
"strconv"
"strings"
)
// options carries the environment-derived settings used to locate the
// sources of the packages whose documentation is generated.
type options struct {
	// wd, goroot and gomod hold, respectively: the working directory, the Go
	// installation root, and the main module's path ("" when no go.mod was
	// found).
	wd, goroot, gomod string

	// modules maps module keys ("path@version" for requirements, the bare
	// path for the main module) to the directory holding that module's
	// sources.
	modules map[string]string
}
// getGoroot reports the Go installation root. A non-empty GOROOT environment
// variable takes precedence; otherwise runtime.GOROOT is consulted.
func getGoroot() string {
	if root := os.Getenv("GOROOT"); root != "" {
		return root
	}
	return runtime.GOROOT()
}
// modCache returns the module download cache directory, mirroring the go
// command's defaults (see `go help environment`):
//   - $GOMODCACHE when it is set and non-empty;
//   - otherwise pkg/mod under the first entry of $GOPATH;
//   - where GOPATH itself falls back to $HOME/go when unset or when its first
//     entry is empty.
func modCache() string {
	if mc := os.Getenv("GOMODCACHE"); mc != "" {
		return mc
	}
	// GOPATH may be a list of directories; only its first entry hosts the
	// module cache, so joining the whole list would yield a bogus path.
	gp, _, _ := strings.Cut(os.Getenv("GOPATH"), string(os.PathListSeparator))
	if gp == "" {
		gp = path.Join(os.Getenv("HOME"), "go")
	}
	return path.Join(gp, "pkg", "mod")
}
// findGoMod searches dir and then each of its ancestors, nearest first, for a
// regular file named go.mod. It returns that file's path and true, or "" and
// false when the top of the path is reached without finding one.
func findGoMod(dir string) (string, bool) {
	for {
		p := path.Join(dir, "go.mod")
		if fi, err := os.Stat(p); err == nil && !fi.IsDir() {
			return p, true
		}
		// path.Dir reaches a fixed point at the top of the path: "/" for
		// absolute slash paths, "." for relative, empty or non-slash (e.g.
		// Windows) paths. Stopping there, rather than only at "/", guarantees
		// termination for every input.
		parent := path.Dir(dir)
		if parent == dir {
			return "", false
		}
		dir = parent
	}
}
// readGomod locates the go.mod governing wd and maps the modules it mentions
// onto source directories.
//
// It returns the main module's path and a map whose keys are the
// module.Version strings of the requirements ("path@version") plus the main
// module's path, and whose values are directories: the module-cache location
// of each requirement, and the directory containing go.mod for the main
// module. Discovery is best-effort: when go.mod cannot be found, read or
// parsed it returns ("", nil). A go.mod without a module directive yields an
// empty main-module path and only the requirement entries.
func readGomod(wd string) (string, map[string]string) {
	p, ok := findGoMod(wd)
	if !ok {
		return "", nil
	}
	data, err := os.ReadFile(p)
	if err != nil {
		return "", nil
	}
	m, err := modfile.Parse(p, data, nil)
	if err != nil {
		return "", nil
	}
	mc := modCache()
	dirs := make(map[string]string, len(m.Require)+1)
	for _, dep := range m.Require {
		key := dep.Mod.String()
		// The map key keeps the module's real spelling. The on-disk cache
		// directory case-encodes upper-case letters, so it must be escaped.
		dirs[key] = path.Join(mc, escapeModCachePath(key))
	}
	// modfile.Parse accepts a file without a module directive, leaving
	// m.Module nil; dereferencing it unguarded would panic.
	if m.Module == nil {
		return "", dirs
	}
	name := m.Module.Mod.String()
	dirs[name] = path.Dir(p)
	return name, dirs
}

// escapeModCachePath applies the module cache's case-encoding to s: every
// upper-case ASCII letter becomes '!' followed by its lower-case form (for
// example "github.com/BurntSushi/toml@v1.0.0" becomes
// "github.com/!burnt!sushi/toml@v1.0.0"). All other bytes are kept as-is.
func escapeModCachePath(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		if 'A' <= c && c <= 'Z' {
			b.WriteByte('!')
			c += 'a' - 'A'
		}
		b.WriteByte(c)
	}
	return b.String()
}
// initOptions captures the current process environment:
//   - the working directory;
//   - the Go root (via getGoroot);
//   - the main module and requirements discovered from the go.mod above the
//     working directory (via readGomod).
//
// If the working directory cannot be determined, module discovery is skipped
// and only the Go root is set. Without an anchor directory, searching for
// go.mod is meaningless.
func initOptions() options {
	o := options{goroot: getGoroot()}
	wd, err := os.Getwd()
	if err != nil {
		return o
	}
	o.wd = wd
	o.gomod, o.modules = readGomod(wd)
	return o
}
// cleanPaths returns a new slice holding path.Clean of each element of
// gopath, in order. An empty input yields nil.
func cleanPaths(gopath []string) []string {
	if len(gopath) == 0 {
		return nil
	}
	cleaned := make([]string, len(gopath))
	for i, p := range gopath {
		cleaned[i] = path.Clean(p)
	}
	return cleaned
}
// splitGopath separates a gopath into its package import path and symbol.
//
// Everything up to the last '/' belongs to the package. In the final path
// element, the text before the first '.' completes the package path and the
// text after it (possibly containing further dots, e.g. "Type.Field") is the
// symbol. A gopath without a symbol yields an empty second result.
func splitGopath(p string) (string, string) {
	dir, last := "", p
	if i := strings.LastIndex(p, "/"); i >= 0 {
		dir, last = p[:i+1], p[i+1:]
	}
	pkg, symbol, _ := strings.Cut(last, ".")
	return dir + pkg, symbol
}
// packagePaths reduces gopaths to the distinct package import paths they
// refer to, preserving the order in which each package first appears.
func packagePaths(p []string) []string {
	seen := make(map[string]struct{}, len(p))
	var out []string
	for _, gp := range p {
		pkg, _ := splitGopath(gp)
		if _, dup := seen[pkg]; dup {
			continue
		}
		seen[pkg] = struct{}{}
		out = append(out, pkg)
	}
	return out
}
// collectGoDirs lists the Go-root directories used as fallback search roots
// when importing packages: GOROOT/src followed by GOROOT/src/cmd. It returns
// nil when o carries no Go root.
func collectGoDirs(o options) []string {
	if o.goroot == "" {
		return nil
	}
	src := path.Join(o.goroot, "src")
	return []string{src, path.Join(src, "cmd")}
}
// importPackages resolves every import path in paths to its build.Package.
//
// For each path, candidate source directories are tried in this order:
//   - the directories of the modules in o.modules that contain the path
//     (the path equals the module path or lies beneath it at a '/'
//     boundary), most specific (longest module path) first;
//   - each directory in godirs, in order.
//
// The first candidate from which build.Import succeeds wins. An error naming
// the first unresolvable path is returned. It wraps the last import error
// seen for that path, when there was one.
func importPackages(o options, godirs, paths []string) ([]*build.Package, error) {
	mods := sortedModuleKeys(o.modules)
	var pkgs []*build.Package
	for _, p := range paths {
		var candidates []string
		for _, mod := range mods {
			if containsImportPath(moduleKeyPath(mod), p) {
				candidates = append(candidates, o.modules[mod])
			}
		}
		candidates = append(candidates, godirs...)

		var (
			found   *build.Package
			lastErr error
		)
		for _, dir := range candidates {
			pkg, err := build.Import(p, dir, build.ImportComment)
			if err != nil {
				lastErr = err
				continue
			}
			if pkg != nil {
				found = pkg
				break
			}
		}
		if found == nil {
			if lastErr != nil {
				return nil, fmt.Errorf("failed to import package for %s: %w", p, lastErr)
			}
			return nil, fmt.Errorf("failed to import package for %s", p)
		}
		pkgs = append(pkgs, found)
	}
	return pkgs, nil
}

// sortedModuleKeys returns the keys of modules ordered by decreasing module
// path length, ties broken lexically. Ranging over the map directly would make
// resolution order random; this order is deterministic and lets nested modules
// take precedence over the modules that contain them.
func sortedModuleKeys(modules map[string]string) []string {
	keys := make([]string, 0, len(modules))
	for k := range modules {
		keys = append(keys, k)
	}
	sort.Slice(keys, func(i, j int) bool {
		pi, pj := moduleKeyPath(keys[i]), moduleKeyPath(keys[j])
		if len(pi) != len(pj) {
			return len(pi) > len(pj)
		}
		return keys[i] < keys[j]
	})
	return keys
}

// moduleKeyPath strips an optional "@version" suffix from a module key.
func moduleKeyPath(key string) string {
	p, _, _ := strings.Cut(key, "@")
	return p
}

// containsImportPath reports whether importPath belongs to the module whose
// path is modPath: it is either the module root or a package below it. A plain
// prefix test would wrongly match "example.com/foo" against
// "example.com/foobar".
func containsImportPath(modPath, importPath string) bool {
	return importPath == modPath || strings.HasPrefix(importPath, modPath+"/")
}
// parserInclude builds a file filter for parser.ParseDir. The filter admits
// exactly those files whose base name appears in pkg.GoFiles. pkg.GoFiles is
// consulted each time the filter runs.
func parserInclude(pkg *build.Package) func(fs.FileInfo) bool {
	return func(fi fs.FileInfo) bool {
		name := fi.Name()
		for i := range pkg.GoFiles {
			if pkg.GoFiles[i] == name {
				return true
			}
		}
		return false
	}
}
// parsePackages parses the sources of each build package, keeping comments.
// Only the files admitted by parserInclude are parsed, each package with its
// own file set. The result maps a package's import path to the non-nil
// *ast.Package values parser.ParseDir produced for it. The first parse error
// aborts the whole operation.
func parsePackages(pkgs []*build.Package) (map[string][]*ast.Package, error) {
	parsed := make(map[string][]*ast.Package)
	for _, bp := range pkgs {
		byName, err := parser.ParseDir(token.NewFileSet(), bp.Dir, parserInclude(bp), parser.ParseComments)
		if err != nil {
			return nil, fmt.Errorf("failed to parse package %s: %w", bp.Name, err)
		}
		for _, ap := range byName {
			if ap == nil {
				continue
			}
			parsed[bp.ImportPath] = append(parsed[bp.ImportPath], ap)
		}
	}
	return parsed, nil
}
// fixDocPackage returns p with the constants, variables and functions that
// go/doc grouped under individual types also appended to the package-level
// Consts, Vars and Funcs lists, in type order. This lets lookups by bare name
// find them. The per-type lists are left as they were.
func fixDocPackage(p doc.Package) doc.Package {
	for i := range p.Types {
		typ := p.Types[i]
		p.Consts = append(p.Consts, typ.Consts...)
		p.Vars = append(p.Vars, typ.Vars...)
		p.Funcs = append(p.Funcs, typ.Funcs...)
	}
	return p
}
// merge returns a new map holding the entries of all maps in m. When a key
// occurs in several maps, the value from the last of them wins. Nil maps
// count as empty, and calling merge with no arguments yields an empty,
// non-nil map. The inputs are never modified or returned directly.
//
// The single pass avoids the recursive version's panic on zero arguments
// (m[0]) and its repeated re-copying of intermediate maps.
func merge(m ...map[string]string) map[string]string {
	size := 0
	for _, mi := range m {
		size += len(mi)
	}
	merged := make(map[string]string, size)
	for _, mi := range m {
		for k, v := range mi {
			merged[k] = v
		}
	}
	return merged
}
// funcParams renders f's parameter names as "func(a, b, ...)". Types are
// omitted, and unnamed parameters contribute no entry. The result is
// "func()" when f has no declaration, no function type, or no parameters.
func funcParams(f *doc.Func) string {
	var names []string
	if decl := f.Decl; decl != nil && decl.Type != nil && decl.Type.Params != nil {
		for _, field := range decl.Type.Params.List {
			for _, id := range field.Names {
				if id == nil {
					continue
				}
				names = append(names, id.Name)
			}
		}
	}
	return fmt.Sprintf("func(%s)", strings.Join(names, ", "))
}
// funcDocs renders the registered documentation of a function or method: its
// doc comment, a newline, then the parameter list produced by funcParams.
func funcDocs(f *doc.Func) string {
	return fmt.Sprint(f.Doc, "\n", funcParams(f))
}
// symbolDocs looks up the documentation of one declaration inside pkg.
//
// gopath is "<import path>.<symbol>", where <symbol> is one of:
//   - Name: a package-level const, var, func or type;
//   - Type.Field[.Field...]: a field of a struct type, descending through
//     fields whose types are inline struct literals;
//   - Type.Method: a method of a type of any kind, not only of struct types.
//
// Functions and methods are rendered by funcDocs. The boolean reports whether
// the symbol was found.
func symbolDocs(pkg *doc.Package, gopath string) (string, bool) {
	_, s := splitGopath(gopath)
	symbol := strings.Split(s, ".")
	name := symbol[0]
	if len(symbol) == 1 {
		for _, c := range pkg.Consts {
			if c != nil && slices.Contains(c.Names, name) {
				return c.Doc, true
			}
		}
		for _, v := range pkg.Vars {
			if v != nil && slices.Contains(v.Names, name) {
				return v.Doc, true
			}
		}
		for _, f := range pkg.Funcs {
			if f != nil && f.Name == name {
				return funcDocs(f), true
			}
		}
	}
	for _, t := range pkg.Types {
		if t == nil || t.Name != name {
			continue
		}
		if len(symbol) == 1 {
			return t.Doc, true
		}
		if str := declaredStruct(t); str != nil {
			if d, ok := structFieldDoc(str, symbol[1:]); ok {
				return d, true
			}
		}
		// Methods are looked up even when t is not a struct (e.g. a defined
		// integer type with a String method). Only a single element after the
		// type name can name a method.
		if len(symbol) != 2 {
			return "", false
		}
		for _, m := range t.Methods {
			if m != nil && m.Name == symbol[1] {
				return funcDocs(m), true
			}
		}
	}
	return "", false
}

// declaredStruct returns the struct type literal that t is declared as. It
// returns nil when t has no single-spec declaration or is not a struct type.
func declaredStruct(t *doc.Type) *ast.StructType {
	if t.Decl == nil || len(t.Decl.Specs) != 1 {
		return nil
	}
	ts, ok := t.Decl.Specs[0].(*ast.TypeSpec)
	if !ok {
		return nil
	}
	str, _ := ts.Type.(*ast.StructType)
	return str
}

// structFieldDoc follows fieldPath through str. Every element except the last
// must name a field whose type is an inline struct literal. It returns the doc
// comment of the final field and whether the whole path resolved.
func structFieldDoc(str *ast.StructType, fieldPath []string) (string, bool) {
	for len(fieldPath) > 0 {
		f := namedField(str, fieldPath[0])
		if f == nil {
			return "", false
		}
		if len(fieldPath) == 1 {
			// CommentGroup.Text returns "" for a nil group.
			return f.Doc.Text(), true
		}
		next, ok := f.Type.(*ast.StructType)
		if !ok {
			return "", false
		}
		str, fieldPath = next, fieldPath[1:]
	}
	return "", false
}

// namedField returns the field of str declared under the given name, or nil
// when there is none. Embedded fields, which have no names, never match.
func namedField(str *ast.StructType, name string) *ast.Field {
	if str == nil || str.Fields == nil {
		return nil
	}
	for _, f := range str.Fields.List {
		if f == nil {
			continue
		}
		for _, id := range f.Names {
			if id != nil && id.Name == name {
				return f
			}
		}
	}
	return nil
}
// symbolPath joins an import path and a chain of identifiers into a gopath
// key. For example, symbolPath("example.com/p", "T", "F") is
// "example.com/p.T.F". With no identifiers the result ends in a bare ".".
func symbolPath(packagePath string, name ...string) string {
	return fmt.Sprint(packagePath, ".", strings.Join(name, "."))
}
// valueDocs maps the gopath of every name declared by the const or var
// groups in v to the doc comment of its group. Names declared together share
// that documentation. Nil groups are skipped.
func valueDocs(packagePath string, v []*doc.Value) map[string]string {
	docs := make(map[string]string)
	for _, group := range v {
		if group == nil {
			continue
		}
		for _, n := range group.Names {
			docs[symbolPath(packagePath, n)] = group.Doc
		}
	}
	return docs
}
// takeFieldDocs records into docs the doc comment of each name declared by
// field f. Each entry is keyed by the gopath built from packagePath, prefix
// and the field name.
//
// When f's type is an inline struct literal, the fields of that struct are
// recorded recursively beneath each of f's names. Embedded (unnamed) fields
// contribute nothing.
func takeFieldDocs(docs map[string]string, packagePath string, prefix []string, f *ast.Field) {
	nested, isStruct := f.Type.(*ast.StructType)
	for _, ident := range f.Names {
		if ident == nil {
			continue
		}
		// The full slice expression forces append to allocate, so sibling
		// fields never share a backing array.
		fieldPath := append(prefix[:len(prefix):len(prefix)], ident.Name)
		docs[symbolPath(packagePath, fieldPath...)] = f.Doc.Text()
		if !isStruct || nested.Fields == nil {
			continue
		}
		for _, sub := range nested.Fields.List {
			if sub != nil {
				takeFieldDocs(docs, packagePath, fieldPath, sub)
			}
		}
	}
}
func packageDocs(pkg *doc.Package) map[string]string {
d := make(map[string]string)
d[pkg.ImportPath] = pkg.Doc
d = merge(d, valueDocs(pkg.ImportPath, pkg.Consts))
d = merge(d, valueDocs(pkg.ImportPath, pkg.Vars))
for _, t := range pkg.Types {
if t != nil {
d[symbolPath(pkg.ImportPath, t.Name)] = t.Doc
}
if t.Decl != nil && len(t.Decl.Specs) == 1 {
if ts, ok := t.Decl.Specs[0].(*ast.TypeSpec); ok {
if str, ok := ts.Type.(*ast.StructType); ok && str.Fields != nil {
for _, f := range str.Fields.List {
if f == nil {
continue
}
takeFieldDocs(d, pkg.ImportPath, []string{t.Name}, f)
}
}
}
}
for _, m := range t.Methods {
if m != nil {
doc := funcDocs(m)
d[symbolPath(pkg.ImportPath, t.Name, m.Name)] = doc
}
}
}
for _, f := range pkg.Funcs {
if f != nil {
doc := funcDocs(f)
d[symbolPath(pkg.ImportPath, f.Name)] = doc
}
}
return d
}
// takeDocs extracts documentation for each requested gopath from the parsed
// packages, which are keyed by import path.
//
// A gopath that names a whole package contributes every entry produced by
// packageDocs. Any other gopath must name a symbol that symbolDocs finds in
// its package, and contributes a single entry under that gopath. An error is
// returned for the first symbol that cannot be found, including a symbol whose
// package has no parsed sources; such symbols were previously silently
// dropped.
func takeDocs(pkgs map[string][]*ast.Package, gopaths []string) (map[string]string, error) {
	dm := make(map[string]string)
	for _, gp := range gopaths {
		pp, _ := splitGopath(gp)
		isPackage := pp == gp
		found := false
		for _, pkg := range pkgs[pp] {
			dpkg := doc.New(pkg, pp, doc.AllDecls|doc.PreserveAST)
			*dpkg = fixDocPackage(*dpkg)
			if isPackage {
				// Insert in place rather than via merge, which would copy the
				// accumulated map once per package.
				for k, v := range packageDocs(dpkg) {
					dm[k] = v
				}
				continue
			}
			sd, ok := symbolDocs(dpkg, gp)
			if !ok {
				return nil, fmt.Errorf("symbol not found: %s", gp)
			}
			dm[gp] = sd
			found = true
		}
		if !isPackage && !found {
			return nil, fmt.Errorf("symbol not found: %s", gp)
		}
	}
	return dm, nil
}
// generate runs the whole pipeline for gopaths:
//  1. clean the paths;
//  2. import the packages they reference from the modules and Go-root
//     directories described by o;
//  3. parse those packages;
//  4. extract the requested documentation, keyed by gopath.
func generate(o options, gopaths ...string) (map[string]string, error) {
	cleaned := cleanPaths(gopaths)
	pkgs, err := importPackages(o, collectGoDirs(o), packagePaths(cleaned))
	if err != nil {
		return nil, err
	}
	parsed, err := parsePackages(pkgs)
	if err != nil {
		return nil, err
	}
	return takeDocs(parsed, cleaned)
}
// format writes to w a Go source file declaring package pname. Its init
// function calls docreflect.Register once per entry of docs, in sorted key
// order, with both key and value rendered by strconv.Quote. Each line is
// written separately. After the first write error no further writes are
// attempted, and that error is returned.
func format(w io.Writer, pname string, docs map[string]string) error {
	keys := make([]string, 0, len(docs))
	for k := range docs {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var werr error
	emitf := func(layout string, args ...any) {
		if werr == nil {
			_, werr = fmt.Fprintf(w, layout, args...)
		}
	}
	emitf("package %s\n", pname)
	emitf("%s\n", "import \"code.squareroundforest.org/arpio/docreflect\"")
	emitf("%s\n", "func init() {")
	for _, k := range keys {
		emitf("docreflect.Register(%s, %s)\n", strconv.Quote(k), strconv.Quote(docs[k]))
	}
	emitf("%s\n", "}")
	return werr
}
// GenerateRegistry generates a Go code file to the output, including a package init function that
// will register the documentation of the declarations specified by their gopath.
func GenerateRegistry(w io.Writer, outputPackageName string, gopath ...string) error {
o := initOptions()
d, err := generate(o, gopath...)
if err != nil {
return err
}
if err := format(w, outputPackageName, d); err != nil {
return err
}
return nil
}
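// TODO: support documentation lookup for these declaration forms:
//   - type aliases (type foo = bar)
//   - type definitions based on another named type (type foo bar)
//   - grouped type declarations (type ( ... ))
//   - struct fields of pointer type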