wand/tools/execwand.go
2025-09-01 04:10:35 +02:00

309 lines
6.1 KiB
Go

package tools
import (
"bufio"
"bytes"
"encoding/base64"
"errors"
"fmt"
"hash/fnv"
"io"
"os"
"path"
"sort"
"strings"
"unicode"
)
// ExecOptions configures how Exec builds and runs wand expressions.
type ExecOptions struct {
	// NoCache builds the command in a throwaway directory (<cache>/tmp/<hash>)
	// instead of the persistent cache entry.
	// ClearCache discards any previously cached build of the command before
	// building it again.
	NoCache, ClearCache bool

	// CacheDir is the root directory of the build cache. When empty, .wand in
	// the user's home directory is used.
	CacheDir string

	// Import lists packages imported by the generated program. They are
	// fetched with go get when the command's module is created.
	// InlineImport lists packages that are dot-imported instead, so their
	// exported names can be used unqualified in the expression.
	Import, InlineImport []string
}
// commandReader returns a function that reads one command per call from in.
//
// A command is the list of whitespace-separated arguments up to an unquoted
// newline. Quoting rules:
//   - a backslash makes the next rune literal, also inside "..."
//   - '...' is taken literally up to the closing quote
//   - "..." is literal except for backslash escapes
//   - a closing quote ends the current argument, so '' or "" yields an empty
//     argument
//
// Quoted arguments may span lines. Blank lines are skipped, so a returned
// command is never empty. After the last command, io.EOF is returned. Input
// ending inside a quote or right after a backslash, invalid UTF-8 within an
// argument, and read failures are reported as errors. Once an error
// (including io.EOF) has been returned, every later call returns it again.
func commandReader(in io.Reader) func() ([]string, error) {
	var (
		yieldErr                             error
		currentArg                           []rune
		args                                 []string
		escapeOne, escapePartial, escapeFull bool
	)

	// closeArg finishes the unquoted argument being collected, if any.
	closeArg := func() {
		if len(currentArg) > 0 {
			args = append(args, string(currentArg))
			currentArg = nil
		}
	}

	buf := bufio.NewReader(in)
	return func() ([]string, error) {
		if yieldErr != nil {
			return nil, yieldErr
		}

		for {
			r, size, err := buf.ReadRune()
			if errors.Is(err, io.EOF) {
				if escapeOne || escapePartial || escapeFull {
					yieldErr = errors.New("unexpected end of input: unterminated quote or escape")
					return nil, yieldErr
				}

				closeArg()
				yieldErr = err
				if len(args) == 0 {
					return nil, yieldErr
				}

				yield := args
				args = nil
				return yield, nil
			}

			if err != nil {
				yieldErr = fmt.Errorf("failed reading from input: %w", err)
				return nil, yieldErr
			}

			// ReadRune reports invalid UTF-8 as U+FFFD with size 1; a properly
			// encoded U+FFFD has size 3 and is ordinary input
			if r == unicode.ReplacementChar && size == 1 {
				if len(currentArg) > 0 {
					yieldErr = errors.New("broken unicode stream")
					return nil, yieldErr
				}

				continue
			}

			switch {
			case escapeFull:
				// inside '...': everything is literal up to the closing quote
				if r == '\'' {
					escapeFull = false
					args = append(args, string(currentArg))
					currentArg = nil
				} else {
					currentArg = append(currentArg, r)
				}
			case escapeOne:
				// the rune following a backslash is literal, also inside "..."
				escapeOne = false
				currentArg = append(currentArg, r)
			case escapePartial:
				switch r {
				case '\\':
					escapeOne = true
				case '"':
					escapePartial = false
					args = append(args, string(currentArg))
					currentArg = nil
				default:
					currentArg = append(currentArg, r)
				}
			case r == '\n':
				closeArg()
				if len(args) == 0 {
					// skip blank lines rather than yielding an empty command
					continue
				}

				yield := args
				args = nil
				return yield, nil
			case unicode.IsSpace(r):
				closeArg()
			case r == '\\':
				escapeOne = true
			case r == '"':
				escapePartial = true
			case r == '\'':
				escapeFull = true
			default:
				currentArg = append(currentArg, r)
			}
		}
	}
}
// hash derives the cache key of a command from its expression and imports.
//
// Every field is length-prefixed, and plain imports and inline (dot) imports
// are hashed as separate groups. This prevents collisions between different
// splits of the same bytes, and between plain and dot imports of the same
// package, which generate different programs. The order of packages within a
// group does not affect the result. The caller's slices are not modified.
//
// The key is the FNV-128 sum encoded as unpadded URL-safe base64. It is used
// as the cache directory name, the module name and the go file name.
func hash(expression string, imports, inlineImports []string) (string, error) {
	h := fnv.New128()

	// netstring-style framing keeps field boundaries unambiguous; writes to a
	// hash.Hash never fail
	writeField := func(s string) {
		fmt.Fprintf(h, "%d:%s,", len(s), s)
	}

	writeField(expression)
	for _, group := range [][]string{imports, inlineImports} {
		// sort a copy: sorting the argument would reorder the caller's slice
		sorted := append([]string(nil), group...)
		sort.Strings(sorted)
		fmt.Fprintf(h, "%d;", len(sorted))
		for _, pkg := range sorted {
			writeField(pkg)
		}
	}

	buf := bytes.NewBuffer(nil)
	b64 := base64.NewEncoder(base64.RawURLEncoding, buf)
	if _, err := b64.Write(h.Sum(nil)); err != nil {
		return "", fmt.Errorf("failed to encode expression: %w", err)
	}

	if err := b64.Close(); err != nil {
		return "", fmt.Errorf("failed to complete encoding of expression: %w", err)
	}

	// the go tool ignores files whose names start with "_", and a leading "-"
	// could be mistaken for a flag when the key is passed to a command
	return strings.TrimLeft(buf.String(), "_-"), nil
}
func printGoFile(fn string, expression string, imports []string, inlineImports []string) error {
f, err := os.Create(fn)
if err != nil {
return err
}
defer f.Close()
fprintf := func(format string, args ...any) {
if err != nil {
return
}
_, err = fmt.Fprintf(f, format, args...)
}
fprintf("package main\n")
for _, i := range imports {
fprintf("import \"%s\"\n", i)
}
for _, i := range inlineImports {
fprintf("import . \"%s\"\n", i)
}
fprintf("import \"code.squareroundforest.org/arpio/wand\"\n")
fprintf("func main() {\n")
fprintf("wand.Exec(%s)\n", expression)
fprintf("}")
return err
}
// execWand builds, or reuses from the cache, a small Go program that calls
// wand.Exec with the expression args[0], and runs it with args[1:].
//
// The program lives in its own module, in a directory named after a hash of
// the expression and the imports, under o.CacheDir. When o.CacheDir is empty,
// .wand in the user's home directory is used; a relative o.CacheDir is taken
// relative to the current working directory.
//
// With o.NoCache, a throwaway directory under <cache>/tmp is used and removed
// afterwards, also when a step fails. With o.ClearCache, a previously cached
// build of the command is discarded first.
//
// The process working directory is switched to the command directory while
// the module is prepared and run, and restored on return.
func execWand(o ExecOptions, args []string) (err error) {
	if len(args) == 0 {
		return errors.New("missing expression")
	}

	expression, args := args[0], args[1:]
	commandHash, err := hash(expression, o.Import, o.InlineImport)
	if err != nil {
		return err
	}

	wd, err := os.Getwd()
	if err != nil {
		return fmt.Errorf("error identifying current directory: %w", err)
	}

	cacheDir := o.CacheDir
	if cacheDir == "" {
		home, homeErr := os.UserHomeDir()
		if homeErr != nil {
			return fmt.Errorf("failed to determine cache directory: %w", homeErr)
		}

		cacheDir = path.Join(home, ".wand")
	}

	// the paths below are used after switching into the command directory,
	// so a relative cache directory has to be anchored to the original one
	if !path.IsAbs(cacheDir) {
		cacheDir = path.Join(wd, cacheDir)
	}

	commandDir := path.Join(cacheDir, commandHash)
	if o.NoCache {
		commandDir = path.Join(cacheDir, "tmp", commandHash)
	}

	if o.NoCache || o.ClearCache {
		if err := os.RemoveAll(commandDir); err != nil {
			return fmt.Errorf("failed to clear cache: %w", err)
		}
	}

	if err := os.MkdirAll(commandDir, os.ModePerm); err != nil {
		return fmt.Errorf("failed to ensure cache directory: %w", err)
	}

	if o.NoCache {
		// registered before the Chdir below, so it runs after the original
		// working directory has been restored, and also when a step fails
		defer func() {
			if rerr := os.RemoveAll(commandDir); rerr != nil && err == nil {
				err = fmt.Errorf("failed to clean cache: %w", rerr)
			}
		}()
	}

	if err := os.Chdir(commandDir); err != nil {
		return fmt.Errorf("failed to switch to temporary directory: %w", err)
	}

	defer func() {
		if cerr := os.Chdir(wd); cerr != nil && err == nil {
			err = fmt.Errorf("failed to restore working directory: %w", cerr)
		}
	}()

	// an existing go.mod marks the module as initialized
	gomodPath := path.Join(commandDir, "go.mod")
	if _, statErr := os.Stat(gomodPath); statErr != nil {
		if err := execInternal("go mod init", commandHash); err != nil {
			return fmt.Errorf("failed to initialize temporary module: %w", err)
		}

		var pkgs []string
		pkgs = append(pkgs, o.Import...)
		pkgs = append(pkgs, o.InlineImport...)
		pkgs = append(pkgs, "code.squareroundforest.org/arpio/wand")
		for _, pkg := range pkgs {
			if err := execInternal("go get", pkg); err != nil {
				// drop go.mod (best effort) so the next run retries the setup
				// instead of reusing an incomplete module
				os.Remove(gomodPath)
				return fmt.Errorf("failed to get go module: %w", err)
			}
		}
	}

	goFile := path.Join(commandDir, fmt.Sprintf("%s.go", commandHash))
	if _, statErr := os.Stat(goFile); statErr != nil {
		if err := printGoFile(goFile, expression, o.Import, o.InlineImport); err != nil {
			// best effort: never leave a partial source file in the cache
			os.Remove(goFile)
			return fmt.Errorf("failed to create temporary go file: %w", err)
		}
	}

	return execTransparent("go run", append([]string{commandDir}, args...)...)
}
// readExec reads commands from stdin and executes each of them with
// execWand. Commands are newline-separated, following commandReader's
// quoting rules.
//
// A failing command is reported on stderr and does not stop processing.
// readExec returns nil at the end of the input, and returns any read or
// parse error from the command reader.
func readExec(o ExecOptions, stdin io.Reader) error {
	readCommand := commandReader(stdin)
	for {
		args, err := readCommand()
		if errors.Is(err, io.EOF) {
			return nil
		}

		if err != nil {
			return err
		}

		// execWand needs at least the expression; skip empty commands, which
		// the reader can produce for blank lines or a trailing newline
		if len(args) == 0 {
			continue
		}

		if err := execWand(o, args); err != nil {
			fmt.Fprintln(os.Stderr, err)
		}
	}
}
// Exec runs wand expressions as generated Go programs.
//
// When args is non-empty, args[0] is the expression and the remaining
// arguments are passed to the generated program; stdin is not read in that
// case. Without args, commands are read from stdin and executed one after
// another.
func Exec(o ExecOptions, stdin io.Reader, args ...string) error {
	if len(args) > 0 {
		return execWand(o, args)
	}

	return readExec(o, stdin)
}