1
0

fix function param arity

This commit is contained in:
Arpad Ryszka 2025-09-04 15:27:30 +02:00
parent d26ecc663a
commit afd2706372
4 changed files with 30 additions and 28 deletions

View File

@ -30,7 +30,6 @@ Generated with https://code.squareroundforest.org/arpio/docreflect
type options struct {
wd string
goroot string
gomod string
modules map[string]string
}
@ -72,20 +71,20 @@ func findGoMod(dir string) (string, bool) {
return findGoMod(path.Dir(dir))
}
func readGomod(wd string) (string, map[string]string) {
func readGomod(wd string) map[string]string {
p, ok := findGoMod(wd)
if !ok {
return "", nil
return nil
}
d, err := os.ReadFile(p)
if err != nil {
return "", nil
return nil
}
m, err := modfile.Parse(p, d, nil)
if err != nil {
return "", nil
return nil
}
var dirs map[string]string
@ -97,17 +96,16 @@ func readGomod(wd string) (string, map[string]string) {
name := m.Module.Mod.String()
dirs[name] = path.Dir(p)
return name, dirs
return dirs
}
func initOptions() options {
wd, _ := os.Getwd()
gr := getGoroot()
gomod, modules := readGomod(wd)
modules := readGomod(wd)
return options{
wd: wd,
goroot: gr,
gomod: gomod,
modules: modules,
}
}

View File

@ -77,7 +77,6 @@ func TestGenerate(t *testing.T) {
o := options{
wd: wd,
goroot: runtime.GOROOT(),
gomod: "code.squareroundforest.org/arpio/docreflect",
modules: map[string]string{
"golang.org/x/mod@v0.27.0": path.Join(h, "go", "pkg/mod", "golang.org/x/mod@v0.27.0"),
"code.squareroundforest.org/arpio/notation@v0.0.0-20241225183158-af3bd591a174": path.Join(
@ -370,10 +369,6 @@ func TestGenerate(t *testing.T) {
t.Fatal("goroot")
}
if o.gomod != "code.squareroundforest.org/arpio/docreflect" {
t.Fatal("gomod")
}
for _, module := range []string{
"code.squareroundforest.org/arpio/notation",
"golang.org/x/mod",

32
lib.go
View File

@ -20,7 +20,7 @@ type pointer[T any] interface {
var (
registry = make(map[string]string)
funcParamsExp = regexp.MustCompile("^func[(]([a-zA-Z_][a-zA-Z0-9_]+(, [a-zA-Z_][a-zA-Z0-9_]+)*)?[)]$")
funcParamsExp = regexp.MustCompile("^func[(]([a-zA-Z_][a-zA-Z0-9_]*(, [a-zA-Z_][a-zA-Z0-9_]*)*)?[)]$")
)
func unpack[T pointer[T]](p T) T {
@ -66,6 +66,7 @@ func functionPath(v reflect.Value) string {
}
func splitDocs(d string) (string, string) {
d = strings.TrimSpace(d)
parts := strings.Split(d, "\n")
last := len(parts) - 1
lastPart := parts[last]
@ -109,7 +110,6 @@ func Docs(gopath string) string {
// Function returns the documentation for a package level function.
func Function(v reflect.Value) string {
v = unpack(v)
if v.Kind() != reflect.Func {
return ""
}
@ -119,13 +119,17 @@ func Function(v reflect.Value) string {
// FunctionParams returns the list of the parameter names of a package level function.
func FunctionParams(v reflect.Value) []string {
v = unpack(v)
if v.Kind() != reflect.Func {
return nil
}
d := docs(functionPath(v))
return functionParams(d)
params := functionParams(d)
if len(params) != v.Type().NumIn() {
return make([]string, v.Type().NumIn())
}
return params
}
// Type returns the docuemntation for a package level type.
@ -134,7 +138,6 @@ func Type(t reflect.Type) string {
return ""
}
t = unpack(t)
p := fmt.Sprintf("%s.%s", t.PkgPath(), t.Name())
return Docs(p)
}
@ -216,11 +219,6 @@ func Method(t reflect.Type, name string) string {
return ""
}
t = unpack(t)
if t.Kind() != reflect.Struct {
return ""
}
p := fmt.Sprintf("%s.%s.%s", t.PkgPath(), t.Name(), name)
return Docs(p)
}
@ -231,8 +229,18 @@ func MethodParams(t reflect.Type, name string) []string {
return nil
}
t = unpack(t)
m, ok := t.MethodByName(name)
if !ok {
return nil
}
p := fmt.Sprintf("%s.%s.%s", t.PkgPath(), t.Name(), name)
d := docs(p)
return functionParams(d)
params := functionParams(d)
expectedParams := m.Type.NumIn() - 1
if len(params) != expectedParams {
return make([]string, expectedParams)
}
return params
}

View File

@ -3,6 +3,7 @@ package docreflect_test
import (
"code.squareroundforest.org/arpio/docreflect"
"code.squareroundforest.org/arpio/docreflect/internal/tests/src/testpackage"
"code.squareroundforest.org/arpio/notation"
"reflect"
"strings"
"testing"
@ -137,16 +138,16 @@ func Test(t *testing.T) {
typ := reflect.TypeOf(s)
d := docreflect.Method(typ, "Method")
if !strings.Contains(d, "Method is a method of ExportedType") {
t.Fatal()
t.Fatal("docs", d)
}
p := docreflect.MethodParams(typ, "Method")
if len(p) != 3 {
t.Fatal()
t.Fatal("length", notation.Sprint(p))
}
if p[0] != "p1" || p[1] != "p2" || p[2] != "p3" {
t.Fatal()
t.Fatal("values", notation.Sprint(p))
}
})