package tools import ( "bytes" "code.squareroundforest.org/arpio/docreflect/generate" "encoding/base64" "errors" "fmt" "hash/fnv" "io" "os" "os/exec" "path" "strings" ) type ExecOptions struct { NoCache bool ClearCache bool CacheDir string } func execc(stdin io.Reader, stdout, stderr io.Writer, command string, args []string, env []string) error { c := strings.Split(command, " ") cmd := exec.Command(c[0], append(c[1:], args...)...) cmd.Env = append(os.Environ(), env...) cmd.Stdin = stdin cmd.Stdout = stdout cmd.Stderr = stderr return cmd.Run() } func execCommandDir(out io.Writer, commandDir string, env ...string) error { stderr := bytes.NewBuffer(nil) if err := execc(nil, out, stderr, "go run", []string{commandDir}, env); err != nil { io.Copy(os.Stderr, stderr) return err } return nil } func execInternal(command string, args ...string) error { stdout := bytes.NewBuffer(nil) stderr := bytes.NewBuffer(nil) if err := execc(nil, stdout, stderr, command, args, nil); err != nil { io.Copy(os.Stderr, stdout) io.Copy(os.Stderr, stderr) return err } return nil } func execTransparent(command string, args ...string) error { return execc(os.Stdin, os.Stdout, os.Stderr, command, args, nil) } func Docreflect(out io.Writer, packageName string, gopaths ...string) error { return generate.GenerateRegistry(out, packageName, gopaths...) } func Man(out io.Writer, commandDir string) error { return execCommandDir(out, commandDir, "wandgenerate=man") } func Markdown(out io.Writer, commandDir string) error { return execCommandDir(out, commandDir, "wandgenerate=markdown") } func splitFunction(function string) (pkg string, expression string, err error) { parts := strings.Split(function, "/") gopath := parts[:len(parts)-1] sparts := strings.Split(parts[len(parts)-1], ".") if len(sparts) == 1 && len(gopath) > 1 { err = errors.New("function cannot be identified") return } if len(sparts) == 1 { expression = sparts[0] } else { expression = parts[len(parts)-1] pkg = strings.Join(append(gopath, sparts[0]), "/") } return } func functionHash(function string) (string, error) { h := fnv.New128() h.Write([]byte(function)) buf := bytes.NewBuffer(nil) b64 := base64.NewEncoder(base64.RawURLEncoding, buf) if _, err := b64.Write(h.Sum(nil)); err != nil { return "", fmt.Errorf("failed to encode function: %w", err) } if err := b64.Close(); err != nil { return "", fmt.Errorf("failed to complete encoding of function: %w", err) } return strings.TrimPrefix(buf.String(), "_"), nil } func findGomod(wd string) (string, bool) { gomodDir := wd for { gomodPath := path.Join(gomodDir, "go.mod") f, err := os.Stat(gomodPath) if err == nil && !f.IsDir() { return gomodPath, true } if gomodDir == "/" { return "", false } gomodDir = path.Dir(gomodDir) } } func copyGomod(mn, dst, src string) error { srcf, err := os.Open(src) if err != nil { return fmt.Errorf("failed to open file: %s; %w", src, err) } defer srcf.Close() dstf, err := os.Create(dst) if err != nil { return fmt.Errorf("failed to create file: %s; %w", dst, err) } defer dstf.Close() b, err := io.ReadAll(srcf) if err != nil { return fmt.Errorf("failed to read go.mod file %s: %w", src, err) } s := string(b) ss := strings.Split(s, "\n") for i := range ss { if strings.HasPrefix(ss[i], "module ") { ss[i] = fmt.Sprintf("module %s", mn) break } } if _, err := dstf.Write([]byte(strings.Join(ss, "\n"))); err != nil { return fmt.Errorf("failed to write go.mod file %s: %w", dst, err) } return nil } func printFile(fn string, pkg, expression string) error { f, err := os.Create(fn) if err != nil { return err } defer f.Close() fprintf := func(format string, args ...any) { if err != nil { return } _, err = fmt.Fprintf(f, format, args...) } fprintf("package main\n") if pkg != "" { fprintf("import \"%s\"\n", pkg) } fprintf("import \"code.squareroundforest.org/arpio/wand\"\n") fprintf("func main() {\n") fprintf("wand.Exec(%s)\n", expression) fprintf("}") return err } func Exec(o ExecOptions, function string, args ...string) error { pkg, expression, err := splitFunction(function) if err != nil { return err } functionHash, err := functionHash(function) if err != nil { return err } cacheDir := o.CacheDir if cacheDir == "" { cacheDir = path.Join(os.Getenv("HOME"), ".wand") } functionDir := path.Join(cacheDir, functionHash) if o.NoCache { functionDir = path.Join(cacheDir, "tmp", functionHash) } if o.NoCache || o.ClearCache { if err := os.RemoveAll(functionDir); err != nil { return fmt.Errorf("failed to clean cache: %w", err) } } if err := os.MkdirAll(functionDir, os.ModePerm); err != nil { return fmt.Errorf("failed to ensure cache directory: %w", err) } wd, err := os.Getwd() if err != nil { return fmt.Errorf("error identifying current directory: %w", err) } goGet := func(pkg string) error { if err := execInternal("go get", pkg); err != nil { return fmt.Errorf("failed to get go module: %w", err) } return nil } if err := os.Chdir(functionDir); err != nil { return fmt.Errorf("failed to switch to temporary directory: %w", err) } defer os.Chdir(wd) gomodPath, hasGomod := findGomod(wd) if hasGomod { if err := copyGomod(functionHash, path.Join(functionDir, "go.mod"), gomodPath); err != nil { return err } } else { if err := execInternal("go mod init", functionHash); err != nil { return fmt.Errorf("failed to initialize temporary module: %w", err) } } // non-robust way of avoiding importing standard library packages: if strings.Contains(pkg, ".") { if err := goGet(pkg); err != nil { return err } } if err := goGet("code.squareroundforest.org/arpio/wand"); err != nil { return err } goFile := path.Join(functionDir, fmt.Sprintf("%s.go", functionHash)) if _, err := os.Stat(goFile); err != nil { if err := printFile(goFile, pkg, expression); err != nil { return fmt.Errorf("failed to create temporary go file: %w", err) } } if err := execTransparent("go run", append([]string{functionDir}, args...)...); err != nil { return err } if o.NoCache { if err := os.RemoveAll(functionDir); err != nil { return fmt.Errorf("failed to clean cache: %w", err) } } return nil }