package tools import ( "bufio" "bytes" "encoding/base64" "errors" "fmt" "hash/fnv" "io" "os" "path" "sort" "strings" "unicode" ) type ExecOptions struct { NoCache bool ClearCache bool CacheDir string Import []string InlineImport []string } func commandReader(in io.Reader) func() ([]string, error) { var ( yieldErr error currentArg []rune args []string escapeOne, escapePartial, escapeFull bool ) buf := bufio.NewReader(in) return func() ([]string, error) { if yieldErr != nil { return nil, yieldErr } for { r, _, err := buf.ReadRune() if errors.Is(err, io.EOF) { if len(currentArg) > 0 { args = append(args, string(currentArg)) currentArg = nil } yield := args args = nil yieldErr = err return yield, nil } if err != nil { yieldErr = fmt.Errorf("failed reading from input: %w", err) return nil, yieldErr } if r == unicode.ReplacementChar { if len(currentArg) > 0 { yieldErr = errors.New("broken unicode stream") return nil, yieldErr } continue } if escapeFull { if r == '\'' { escapeFull = false args = append(args, string(currentArg)) currentArg = nil continue } currentArg = append(currentArg, r) continue } if escapeOne { escapeOne = false currentArg = append(currentArg, r) continue } if escapePartial { if escapeOne { escapeOne = false currentArg = append(currentArg, r) continue } if r == '\\' { escapeOne = true continue } if r == '"' { escapePartial = false args = append(args, string(currentArg)) currentArg = nil continue } currentArg = append(currentArg, r) continue } if r == '\n' { if len(currentArg) > 0 { args = append(args, string(currentArg)) currentArg = nil } yield := args args = nil return yield, nil } if unicode.IsSpace(r) { if len(currentArg) > 0 { args = append(args, string(currentArg)) currentArg = nil } continue } switch r { case '\\': escapeOne = true case '"': escapePartial = true case '\'': escapeFull = true default: currentArg = append(currentArg, r) } } } } func hash(expression string, imports, inlineImports []string) (string, error) { h := fnv.New128() h.Write([]byte(expression)) allImports := append(imports, inlineImports...) sort.Strings(allImports) for _, i := range allImports { h.Write([]byte(i)) } buf := bytes.NewBuffer(nil) b64 := base64.NewEncoder(base64.RawURLEncoding, buf) if _, err := b64.Write(h.Sum(nil)); err != nil { return "", fmt.Errorf("failed to encode expression: %w", err) } if err := b64.Close(); err != nil { return "", fmt.Errorf("failed to complete encoding of expression: %w", err) } return strings.TrimPrefix(buf.String(), "_"), nil } func printGoFile(fn string, expression string, imports []string, inlineImports []string) error { f, err := os.Create(fn) if err != nil { return err } defer f.Close() fprintf := func(format string, args ...any) { if err != nil { return } _, err = fmt.Fprintf(f, format, args...) } fprintf("package main\n") for _, i := range imports { fprintf("import \"%s\"\n", strings.Split(i, "@")[0]) } for _, i := range inlineImports { fprintf("import . \"%s\"\n", strings.Split(i, "@")[0]) } fprintf("import \"code.squareroundforest.org/arpio/wand\"\n") fprintf("func main() {\n") fprintf("wand.Exec(%s)\n", expression) fprintf("}") return err } func execWand(o ExecOptions, args []string) error { expression, args := args[0], args[1:] commandHash, err := hash(expression, o.Import, o.InlineImport) if err != nil { return err } cacheDir := o.CacheDir if cacheDir == "" { cacheDir = path.Join(os.Getenv("HOME"), ".wand") } commandDir := path.Join(cacheDir, commandHash) if o.NoCache { commandDir = path.Join(cacheDir, "tmp", commandHash) } if o.NoCache || o.ClearCache { if err := os.RemoveAll(commandDir); err != nil { return fmt.Errorf("failed to clear cache: %w", err) } } if err := os.MkdirAll(commandDir, os.ModePerm); err != nil { return fmt.Errorf("failed to ensure cache directory: %w", err) } wd, err := os.Getwd() if err != nil { return fmt.Errorf("error identifying current directory: %w", err) } goGet := func(pkg string) error { if err := execInternal("go get", pkg); err != nil { return fmt.Errorf("failed to get go module: %w", err) } return nil } if err := os.Chdir(commandDir); err != nil { return fmt.Errorf("failed to switch to temporary directory: %w", err) } defer os.Chdir(wd) gomodPath := path.Join(commandDir, "go.mod") if _, err := os.Stat(gomodPath); err != nil { if err := execInternal("go mod init", commandHash); err != nil { return fmt.Errorf("failed to initialize temporary module: %w", err) } for _, pkg := range o.Import { if err := goGet(pkg); err != nil { return err } } for _, pkg := range o.InlineImport { if err := goGet(pkg); err != nil { return err } } if err := goGet("code.squareroundforest.org/arpio/wand"); err != nil { return err } } goFile := path.Join(commandDir, fmt.Sprintf("%s.go", commandHash)) if _, err := os.Stat(goFile); err != nil { if err := printGoFile(goFile, expression, o.Import, o.InlineImport); err != nil { return fmt.Errorf("failed to create temporary go file: %w", err) } } if err := execTransparent("go run", append([]string{commandDir}, args...)...); err != nil { return err } if o.NoCache { if err := os.RemoveAll(commandDir); err != nil { return fmt.Errorf("failed to clean cache: %w", err) } } return nil } func readExec(o ExecOptions, stdin io.Reader) error { readCommand := commandReader(stdin) for { args, err := readCommand() if errors.Is(err, io.EOF) { return nil } if err != nil { return err } if len(args) == 0 { continue } if err := execWand(o, args); err != nil { fmt.Fprintln(os.Stderr, err) } } } func Exec(o ExecOptions, stdin io.Reader, args ...string) error { if len(args) == 0 { return readExec(o, stdin) } return execWand(o, args) }