1
0

fix function param arity

This commit is contained in:
Arpad Ryszka 2025-09-04 15:27:30 +02:00
parent d26ecc663a
commit afd2706372
4 changed files with 30 additions and 28 deletions

View File

@ -30,7 +30,6 @@ Generated with https://code.squareroundforest.org/arpio/docreflect
type options struct { type options struct {
wd string wd string
goroot string goroot string
gomod string
modules map[string]string modules map[string]string
} }
@ -72,20 +71,20 @@ func findGoMod(dir string) (string, bool) {
return findGoMod(path.Dir(dir)) return findGoMod(path.Dir(dir))
} }
func readGomod(wd string) (string, map[string]string) { func readGomod(wd string) map[string]string {
p, ok := findGoMod(wd) p, ok := findGoMod(wd)
if !ok { if !ok {
return "", nil return nil
} }
d, err := os.ReadFile(p) d, err := os.ReadFile(p)
if err != nil { if err != nil {
return "", nil return nil
} }
m, err := modfile.Parse(p, d, nil) m, err := modfile.Parse(p, d, nil)
if err != nil { if err != nil {
return "", nil return nil
} }
var dirs map[string]string var dirs map[string]string
@ -97,17 +96,16 @@ func readGomod(wd string) (string, map[string]string) {
name := m.Module.Mod.String() name := m.Module.Mod.String()
dirs[name] = path.Dir(p) dirs[name] = path.Dir(p)
return name, dirs return dirs
} }
func initOptions() options { func initOptions() options {
wd, _ := os.Getwd() wd, _ := os.Getwd()
gr := getGoroot() gr := getGoroot()
gomod, modules := readGomod(wd) modules := readGomod(wd)
return options{ return options{
wd: wd, wd: wd,
goroot: gr, goroot: gr,
gomod: gomod,
modules: modules, modules: modules,
} }
} }

View File

@ -77,7 +77,6 @@ func TestGenerate(t *testing.T) {
o := options{ o := options{
wd: wd, wd: wd,
goroot: runtime.GOROOT(), goroot: runtime.GOROOT(),
gomod: "code.squareroundforest.org/arpio/docreflect",
modules: map[string]string{ modules: map[string]string{
"golang.org/x/mod@v0.27.0": path.Join(h, "go", "pkg/mod", "golang.org/x/mod@v0.27.0"), "golang.org/x/mod@v0.27.0": path.Join(h, "go", "pkg/mod", "golang.org/x/mod@v0.27.0"),
"code.squareroundforest.org/arpio/notation@v0.0.0-20241225183158-af3bd591a174": path.Join( "code.squareroundforest.org/arpio/notation@v0.0.0-20241225183158-af3bd591a174": path.Join(
@ -370,10 +369,6 @@ func TestGenerate(t *testing.T) {
t.Fatal("goroot") t.Fatal("goroot")
} }
if o.gomod != "code.squareroundforest.org/arpio/docreflect" {
t.Fatal("gomod")
}
for _, module := range []string{ for _, module := range []string{
"code.squareroundforest.org/arpio/notation", "code.squareroundforest.org/arpio/notation",
"golang.org/x/mod", "golang.org/x/mod",

32
lib.go
View File

@ -20,7 +20,7 @@ type pointer[T any] interface {
var ( var (
registry = make(map[string]string) registry = make(map[string]string)
funcParamsExp = regexp.MustCompile("^func[(]([a-zA-Z_][a-zA-Z0-9_]+(, [a-zA-Z_][a-zA-Z0-9_]+)*)?[)]$") funcParamsExp = regexp.MustCompile("^func[(]([a-zA-Z_][a-zA-Z0-9_]*(, [a-zA-Z_][a-zA-Z0-9_]*)*)?[)]$")
) )
func unpack[T pointer[T]](p T) T { func unpack[T pointer[T]](p T) T {
@ -66,6 +66,7 @@ func functionPath(v reflect.Value) string {
} }
func splitDocs(d string) (string, string) { func splitDocs(d string) (string, string) {
d = strings.TrimSpace(d)
parts := strings.Split(d, "\n") parts := strings.Split(d, "\n")
last := len(parts) - 1 last := len(parts) - 1
lastPart := parts[last] lastPart := parts[last]
@ -109,7 +110,6 @@ func Docs(gopath string) string {
// Function returns the documentation for a package level function. // Function returns the documentation for a package level function.
func Function(v reflect.Value) string { func Function(v reflect.Value) string {
v = unpack(v)
if v.Kind() != reflect.Func { if v.Kind() != reflect.Func {
return "" return ""
} }
@ -119,13 +119,17 @@ func Function(v reflect.Value) string {
// FunctionParams returns the list of the parameter names of a package level function. // FunctionParams returns the list of the parameter names of a package level function.
func FunctionParams(v reflect.Value) []string { func FunctionParams(v reflect.Value) []string {
v = unpack(v)
if v.Kind() != reflect.Func { if v.Kind() != reflect.Func {
return nil return nil
} }
d := docs(functionPath(v)) d := docs(functionPath(v))
return functionParams(d) params := functionParams(d)
if len(params) != v.Type().NumIn() {
return make([]string, v.Type().NumIn())
}
return params
} }
// Type returns the docuemntation for a package level type. // Type returns the docuemntation for a package level type.
@ -134,7 +138,6 @@ func Type(t reflect.Type) string {
return "" return ""
} }
t = unpack(t)
p := fmt.Sprintf("%s.%s", t.PkgPath(), t.Name()) p := fmt.Sprintf("%s.%s", t.PkgPath(), t.Name())
return Docs(p) return Docs(p)
} }
@ -216,11 +219,6 @@ func Method(t reflect.Type, name string) string {
return "" return ""
} }
t = unpack(t)
if t.Kind() != reflect.Struct {
return ""
}
p := fmt.Sprintf("%s.%s.%s", t.PkgPath(), t.Name(), name) p := fmt.Sprintf("%s.%s.%s", t.PkgPath(), t.Name(), name)
return Docs(p) return Docs(p)
} }
@ -231,8 +229,18 @@ func MethodParams(t reflect.Type, name string) []string {
return nil return nil
} }
t = unpack(t) m, ok := t.MethodByName(name)
if !ok {
return nil
}
p := fmt.Sprintf("%s.%s.%s", t.PkgPath(), t.Name(), name) p := fmt.Sprintf("%s.%s.%s", t.PkgPath(), t.Name(), name)
d := docs(p) d := docs(p)
return functionParams(d) params := functionParams(d)
expectedParams := m.Type.NumIn() - 1
if len(params) != expectedParams {
return make([]string, expectedParams)
}
return params
} }

View File

@ -3,6 +3,7 @@ package docreflect_test
import ( import (
"code.squareroundforest.org/arpio/docreflect" "code.squareroundforest.org/arpio/docreflect"
"code.squareroundforest.org/arpio/docreflect/internal/tests/src/testpackage" "code.squareroundforest.org/arpio/docreflect/internal/tests/src/testpackage"
"code.squareroundforest.org/arpio/notation"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -137,16 +138,16 @@ func Test(t *testing.T) {
typ := reflect.TypeOf(s) typ := reflect.TypeOf(s)
d := docreflect.Method(typ, "Method") d := docreflect.Method(typ, "Method")
if !strings.Contains(d, "Method is a method of ExportedType") { if !strings.Contains(d, "Method is a method of ExportedType") {
t.Fatal() t.Fatal("docs", d)
} }
p := docreflect.MethodParams(typ, "Method") p := docreflect.MethodParams(typ, "Method")
if len(p) != 3 { if len(p) != 3 {
t.Fatal() t.Fatal("length", notation.Sprint(p))
} }
if p[0] != "p1" || p[1] != "p2" || p[2] != "p3" { if p[0] != "p1" || p[1] != "p2" || p[2] != "p3" {
t.Fatal() t.Fatal("values", notation.Sprint(p))
} }
}) })