package tools

import (
	"bufio"
	"encoding/base64"
	"errors"
	"fmt"
	"hash/fnv"
	"io"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"unicode"
)

// selfPkg is the import path of the wand package. Every generated program
// imports it so that it can call wand.Exec.
const selfPkg = "code.squareroundforest.org/arpio/wand"

// ExecOptions configures how Exec builds and runs a wand expression.
type ExecOptions struct {
	// NoCache builds the program in a temporary directory (tmp/<key> below
	// the cache directory), which is removed again after the run.
	NoCache bool

	// ClearCache discards any cached build for the expression before running.
	ClearCache bool

	// CacheDir is the root of the build cache. When empty, .wand in the
	// user's home directory is used. Relative paths are resolved against the
	// working directory at the time of the call.
	CacheDir string

	// Import lists additional packages for the generated program, each in
	// the form [alias=]importpath[@version].
	Import []string

	// ReplaceModule lists arguments passed to `go mod edit -replace`, applied
	// in order.
	ReplaceModule []string
}

// commandReader returns a function that reads one command at a time from in
// and splits it into arguments.
//
// Splitting follows a simplified shell syntax: unquoted whitespace separates
// arguments, an unquoted newline ends the command, a backslash escapes the
// next rune, text between single quotes is taken literally, and text between
// double quotes is taken literally except that a backslash escapes the next
// rune. A backslash followed by a newline is a line continuation and is
// dropped. A closing quote ends the current argument, and an empty pair of
// quotes yields an empty argument.
//
// When the input is exhausted, the function returns the final (possibly
// empty) command with a nil error, and io.EOF on every later call. An
// unterminated quote at the end of the input is reported as an error.
func commandReader(in io.Reader) func() ([]string, error) {
	var (
		yieldErr      error
		currentArg    []rune
		args          []string
		escapeOne     bool
		escapePartial bool
		escapeFull    bool
	)

	// flush moves the pending unquoted argument, if any, into args.
	flush := func() {
		if len(currentArg) > 0 {
			args = append(args, string(currentArg))
			currentArg = nil
		}
	}

	// closeQuote ends a quoted argument. Unlike flush it keeps empty
	// arguments, so that '' and "" produce an empty string.
	closeQuote := func() {
		args = append(args, string(currentArg))
		currentArg = nil
	}

	buf := bufio.NewReader(in)
	return func() ([]string, error) {
		if yieldErr != nil {
			return nil, yieldErr
		}

		for {
			r, size, err := buf.ReadRune()
			if errors.Is(err, io.EOF) {
				if escapeFull || escapePartial {
					yieldErr = errors.New("unexpected end of input: unterminated quote")
					return nil, yieldErr
				}

				flush()
				yield := args
				args = nil
				yieldErr = err
				return yield, nil
			}

			if err != nil {
				yieldErr = fmt.Errorf("failed reading from input: %w", err)
				return nil, yieldErr
			}

			// ReadRune reports an invalid encoding as U+FFFD with size 1. A
			// genuine U+FFFD in the input is longer and is kept as a character.
			if r == unicode.ReplacementChar && size == 1 {
				if len(currentArg) > 0 {
					yieldErr = errors.New("broken unicode stream")
					return nil, yieldErr
				}

				continue
			}

			if escapeFull {
				if r == '\'' {
					escapeFull = false
					closeQuote()
					continue
				}

				currentArg = append(currentArg, r)
				continue
			}

			// Checked before escapePartial: a backslash inside double quotes
			// sets escapeOne while escapePartial stays set.
			if escapeOne {
				escapeOne = false
				if r != '\n' { // backslash-newline is a line continuation
					currentArg = append(currentArg, r)
				}

				continue
			}

			if escapePartial {
				switch r {
				case '\\':
					escapeOne = true
				case '"':
					escapePartial = false
					closeQuote()
				default:
					currentArg = append(currentArg, r)
				}

				continue
			}

			switch {
			case r == '\n':
				flush()
				yield := args
				args = nil
				return yield, nil
			case unicode.IsSpace(r):
				flush()
			case r == '\\':
				escapeOne = true
			case r == '"':
				escapePartial = true
			case r == '\'':
				escapeFull = true
			default:
				currentArg = append(currentArg, r)
			}
		}
	}
}

// hash derives the cache key for an expression and the options that affect
// how it is built. Imports are hashed order-insensitively. Replace directives
// are hashed in order, because `go mod edit -replace` applies them in
// sequence. The caller's slices are not modified.
//
// The key is unpadded base64url with leading '_' and '-' removed. The go
// tool ignores source files whose names start with '_', and a leading '-'
// would be taken for a flag when the key is passed to `go mod init`.
func hash(expression string, imports []string, replaceModules ...string) (string, error) {
	h := fnv.New128()

	// Length-prefix every field so that distinct inputs cannot produce the
	// same byte stream (e.g. expression "ab" vs. expression "a" + import "b").
	writeField := func(s string) {
		fmt.Fprintf(h, "%d:%s", len(s), s)
	}

	writeField(expression)

	// Sort a copy: the caller's slice (ExecOptions.Import) must stay intact.
	sorted := append([]string(nil), imports...)
	sort.Strings(sorted)
	fmt.Fprintf(h, "imports:%d;", len(sorted))
	for _, i := range sorted {
		writeField(i)
	}

	if len(replaceModules) > 0 {
		fmt.Fprintf(h, "replace:%d;", len(replaceModules))
		for _, r := range replaceModules {
			writeField(r)
		}
	}

	key := strings.TrimLeft(base64.RawURLEncoding.EncodeToString(h.Sum(nil)), "_-")
	if key == "" {
		// Practically unreachable, but an empty key would make the build
		// directory equal to the cache root, which ClearCache/NoCache remove.
		return "", errors.New("failed to encode expression: empty cache key")
	}

	return key, nil
}

// splitImportSpec splits an import spec of the form [alias=]importpath[@version]
// into the alias (empty when absent), the spec without the alias (as passed to
// `go get`), and the bare import path (as used in an import declaration).
func splitImportSpec(spec string) (alias, getSpec, importPath string) {
	alias, getSpec, found := strings.Cut(spec, "=")
	if !found {
		alias, getSpec = "", spec
	}

	importPath, _, _ = strings.Cut(getSpec, "@")
	return alias, getSpec, importPath
}

// printGoFile writes to fn a main package that imports the given packages
// and selfPkg and calls wand.Exec(expression).
//
// Each import has the form [alias=]importpath[@version]. The version suffix is
// dropped in the import declaration. selfPkg is added only when it is not
// already imported under the name wand, which would be a duplicate
// declaration. If writing fails, the partial file is removed so that a later
// run does not reuse it.
func printGoFile(fn string, expression string, imports []string) (err error) {
	f, err := os.Create(fn)
	if err != nil {
		return err
	}

	defer func() {
		if cerr := f.Close(); err == nil {
			err = cerr
		}

		if err != nil {
			os.Remove(fn) // best effort; the write error is what gets reported
		}
	}()

	w := bufio.NewWriter(f)
	fprintf := func(format string, args ...any) {
		if err != nil {
			return
		}

		_, err = fmt.Fprintf(w, format, args...)
	}

	fprintf("package main\n")
	hasSelf := false
	for _, i := range imports {
		alias, _, importPath := splitImportSpec(i)
		if importPath == selfPkg && (alias == "" || alias == "wand") {
			hasSelf = true
		}

		if alias == "" {
			fprintf("import %q\n", importPath)
		} else {
			fprintf("import %s %q\n", alias, importPath)
		}
	}

	if !hasSelf {
		fprintf("import %q\n", selfPkg)
	}

	fprintf("func main() {\n")
	fprintf("wand.Exec(%s)\n", expression)
	fprintf("}\n")
	if err != nil {
		return err
	}

	return w.Flush()
}

// initWandModule initializes the Go module in the current working directory.
// It runs `go mod init` with the given module name and `go get` for every
// import in o. It also runs `go get selfPkg` unless selfPkg is among the
// imports, with or without a version. Finally it runs `go mod edit -replace`
// for every entry of o.ReplaceModule.
func initWandModule(module string, o ExecOptions) error {
	if err := execInternal("go mod init", module); err != nil {
		return fmt.Errorf("failed to initialize temporary module: %w", err)
	}

	var hasSelf bool
	for _, spec := range o.Import {
		_, getSpec, importPath := splitImportSpec(spec)
		if err := execInternal("go get", getSpec); err != nil {
			return fmt.Errorf("failed to get go module: %w", err)
		}

		// Compare without the version, so that a pinned selfPkg is not
		// overridden by the unversioned `go get` below.
		if importPath == selfPkg {
			hasSelf = true
		}
	}

	if !hasSelf {
		if err := execInternal("go get", selfPkg); err != nil {
			return fmt.Errorf("failed to get go module: %w", err)
		}
	}

	for _, replace := range o.ReplaceModule {
		if err := execInternal("go mod edit -replace", replace); err != nil {
			return fmt.Errorf("failed to replace go module: %w", err)
		}
	}

	return nil
}

// execWand builds a program that calls wand.Exec(args[0]), or reuses one from
// the cache, and runs it via `go run` with args[1:] and the given stdio streams.
//
// The build directory is keyed by the expression, Import and ReplaceModule. It
// is removed after the run when o.NoCache is set, whatever the outcome.
//
// NOTE: the module setup runs with the process working directory switched to
// the build directory (os.Chdir), so execWand must not be called concurrently.
func execWand(o ExecOptions, stdin io.Reader, stdout, stderr io.Writer, args []string) (err error) {
	if len(args) == 0 {
		return errors.New("missing expression")
	}

	expression, args := args[0], args[1:]
	commandHash, err := hash(expression, o.Import, o.ReplaceModule...)
	if err != nil {
		return err
	}

	cacheDir := o.CacheDir
	if cacheDir == "" {
		home, err := os.UserHomeDir()
		if err != nil {
			return fmt.Errorf("failed to find cache directory: %w", err)
		}

		cacheDir = filepath.Join(home, ".wand")
	}

	// Work with absolute paths: the working directory changes below, after
	// which a relative path would point somewhere else.
	cacheDir, err = filepath.Abs(cacheDir)
	if err != nil {
		return fmt.Errorf("failed to resolve cache directory: %w", err)
	}

	commandDir := filepath.Join(cacheDir, commandHash)
	if o.NoCache {
		commandDir = filepath.Join(cacheDir, "tmp", commandHash)
	}

	if o.NoCache || o.ClearCache {
		if err := os.RemoveAll(commandDir); err != nil {
			return fmt.Errorf("failed to clear cache: %w", err)
		}
	}

	if err := os.MkdirAll(commandDir, os.ModePerm); err != nil {
		return fmt.Errorf("failed to ensure cache directory: %w", err)
	}

	if o.NoCache {
		// Registered before the Chdir below, so it runs after the working
		// directory has been restored. A cleanup failure is reported only
		// when nothing else failed.
		defer func() {
			if rerr := os.RemoveAll(commandDir); rerr != nil && err == nil {
				err = fmt.Errorf("failed to clean cache: %w", rerr)
			}
		}()
	}

	wd, err := os.Getwd()
	if err != nil {
		return fmt.Errorf("error identifying current directory: %w", err)
	}

	if err := os.Chdir(commandDir); err != nil {
		return fmt.Errorf("failed to switch to temporary directory: %w", err)
	}

	defer func() {
		if cerr := os.Chdir(wd); cerr != nil && err == nil {
			err = fmt.Errorf("failed to restore working directory: %w", cerr)
		}
	}()

	gomodPath := filepath.Join(commandDir, "go.mod")
	if _, err := os.Stat(gomodPath); err != nil {
		if err := initWandModule(commandHash, o); err != nil {
			// Drop the half-initialized module so that the next run starts
			// over instead of reusing an incomplete go.mod (best effort).
			os.Remove(gomodPath)
			return err
		}
	}

	goFile := filepath.Join(commandDir, commandHash+".go")
	if _, err := os.Stat(goFile); err != nil {
		if err := printGoFile(goFile, expression, o.Import); err != nil {
			return fmt.Errorf("failed to create temporary go file: %w", err)
		}
	}

	return execc(stdin, stdout, stderr, "go run", append([]string{commandDir}, args...), nil)
}

// readExec reads commands from stdin (see commandReader) and executes each
// non-empty one with execWand until the input is exhausted. A failing command
// is reported on stderr and does not stop processing. Only errors from
// reading the input are returned.
//
// NOTE(review): stdin is also handed to each executed command. A command that
// reads stdin may consume input meant for later commands, and it cannot see
// input already buffered by the command reader. Confirm this is intended.
func readExec(o ExecOptions, stdin io.Reader, stdout, stderr io.Writer) error {
	readCommand := commandReader(stdin)
	for {
		args, err := readCommand()
		if errors.Is(err, io.EOF) {
			return nil
		}

		if err != nil {
			return err
		}

		if len(args) == 0 {
			continue
		}

		if err := execWand(o, stdin, stdout, stderr, args); err != nil {
			fmt.Fprintln(stderr, err)
		}
	}
}

// execInput executes args as a single command when given. Otherwise it reads
// commands from stdin.
func execInput(o ExecOptions, stdin io.Reader, stdout, stderr io.Writer, args ...string) error {
	if len(args) == 0 {
		return readExec(o, stdin, stdout, stderr)
	}

	return execWand(o, stdin, stdout, stderr, args)
}

// Exec builds and runs a program that calls wand.Exec with a Go expression.
// args[0] is the expression and the remaining args are passed to the
// program. When args is empty, commands are read from stdin, one per line,
// using shell-like quoting. In that mode a failing command is reported on
// os.Stderr and processing continues. Build output and errors of the program
// itself go to os.Stderr.
func Exec(o ExecOptions, stdin io.Reader, stdout io.Writer, args ...string) error {
	return execInput(o, stdin, stdout, os.Stderr, args...)
}