wand/tools/execwand.go

327 lines
6.5 KiB
Go

package tools
import (
"bufio"
"bytes"
"encoding/base64"
"errors"
"fmt"
"hash/fnv"
"io"
"os"
"path"
"sort"
"strings"
"unicode"
)
// selfPkg is the import path of the wand package itself. Every generated
// command program depends on it, because its main function calls wand.Exec.
const (
	selfPkg = "code.squareroundforest.org/arpio/wand"
)
// commandReader returns a function that reads the next command from in and
// splits it into arguments, following a subset of POSIX shell quoting rules:
//
//   - unquoted whitespace separates arguments, and an unquoted newline ends
//     the command;
//   - outside single quotes, a backslash takes the next rune literally;
//   - between single quotes, every rune is taken literally;
//   - between double quotes, every rune is taken literally except for
//     backslash escapes.
//
// Adjacent quoted and unquoted parts form a single argument, so a'b'"c" reads
// as abc, and an empty pair of quotes ('' or "") yields an empty argument.
//
// Each call returns the arguments of the next command, which is empty for a
// blank line. At the end of the input, the pending command is returned once
// with a nil error, and every later call returns io.EOF. A failed read, or an
// invalid UTF-8 sequence within a non-empty argument, produces an error that is
// returned by that call and by every later one.
func commandReader(in io.Reader) func() ([]string, error) {
	var (
		yieldErr                             error
		currentArg                           []rune
		args                                 []string
		escapeOne, escapePartial, escapeFull bool

		// quoted records that the current argument contains a quoted part,
		// which makes it an argument even when it is empty.
		quoted bool
	)

	// flushArg completes the current argument, if there is one.
	flushArg := func() {
		if len(currentArg) > 0 || quoted {
			args = append(args, string(currentArg))
		}

		currentArg = nil
		quoted = false
	}

	// takeArgs completes the current command and hands over its arguments.
	takeArgs := func() []string {
		flushArg()
		yield := args
		args = nil
		return yield
	}

	buf := bufio.NewReader(in)
	return func() ([]string, error) {
		if yieldErr != nil {
			return nil, yieldErr
		}

		for {
			r, _, err := buf.ReadRune()
			if errors.Is(err, io.EOF) {
				yieldErr = err
				return takeArgs(), nil
			}

			if err != nil {
				yieldErr = fmt.Errorf("failed reading from input: %w", err)
				return nil, yieldErr
			}

			// ReadRune reports invalid UTF-8 as the replacement character (a
			// literal U+FFFD in the input is treated the same way). It is
			// dropped before an argument starts, but fails inside one.
			if r == unicode.ReplacementChar {
				if len(currentArg) > 0 {
					yieldErr = errors.New("broken unicode stream")
					return nil, yieldErr
				}

				continue
			}

			switch {
			case escapeFull:
				// Closing the quote does not end the argument: text that
				// directly follows it still belongs to the same argument.
				if r == '\'' {
					escapeFull = false
				} else {
					currentArg = append(currentArg, r)
				}
			case escapeOne:
				escapeOne = false
				currentArg = append(currentArg, r)
			case escapePartial:
				switch r {
				case '\\':
					escapeOne = true
				case '"':
					escapePartial = false
				default:
					currentArg = append(currentArg, r)
				}
			case r == '\n':
				return takeArgs(), nil
			case unicode.IsSpace(r):
				flushArg()
			case r == '\\':
				escapeOne = true
			case r == '"':
				escapePartial = true
				quoted = true
			case r == '\'':
				escapeFull = true
				quoted = true
			default:
				currentArg = append(currentArg, r)
			}
		}
	}
}
// hash derives a stable cache key from a wand expression and the list of its
// imports. The key is used as a directory name, as a module name and as a Go
// file name.
//
// The order of the imports does not matter: they are hashed in sorted order,
// without reordering the caller's slice.
//
// The 128-bit FNV hash is encoded as unpadded URL-safe base64. Leading '_'
// and '-' characters are trimmed: the go tool ignores source files whose name
// starts with '_', and a leading '-' would look like a flag on a command line.
func hash(expression string, imports []string) (string, error) {
	// Sort a copy: the caller's slice must not be reordered as a side effect.
	sorted := append([]string(nil), imports...)
	sort.Strings(sorted)

	h := fnv.New128()
	h.Write([]byte(expression))
	for _, i := range sorted {
		// The zero byte separates the parts, so that moving characters
		// between the expression and an import, or between two imports,
		// results in a different key.
		h.Write([]byte{0})
		h.Write([]byte(i))
	}

	var buf bytes.Buffer
	b64 := base64.NewEncoder(base64.RawURLEncoding, &buf)
	if _, err := b64.Write(h.Sum(nil)); err != nil {
		return "", fmt.Errorf("failed to encode expression: %w", err)
	}

	if err := b64.Close(); err != nil {
		return "", fmt.Errorf("failed to complete encoding of expression: %w", err)
	}

	return strings.TrimLeft(buf.String(), "_-"), nil
}
// printGoFile writes the Go source of a wand command to the file fn. The
// source is a main package that imports the requested packages and the wand
// package, and whose main function calls wand.Exec with expression as its
// argument.
//
// Each import has the form [alias=]path[@version]. The alias, when present,
// names the import. The version suffix is not part of the import path and is
// dropped. The wand package is imported implicitly, unless the imports already
// contain it without an alias, because importing the same path twice under the
// same name does not compile.
//
// Errors from creating, writing or closing the file are returned.
func printGoFile(fn string, expression string, imports []string) (err error) {
	f, err := os.Create(fn)
	if err != nil {
		return err
	}

	// A failed Close can mean that written content never reached the file, so
	// it is reported, unless an earlier error is already being returned.
	defer func() {
		if cerr := f.Close(); err == nil {
			err = cerr
		}
	}()

	// fprintf stops writing after the first error, which is then returned.
	fprintf := func(format string, args ...any) {
		if err != nil {
			return
		}

		_, err = fmt.Fprintf(f, format, args...)
	}

	fprintf("package main\n")
	var hasSelf bool
	for _, i := range imports {
		alias, p, aliased := strings.Cut(i, "=")
		if !aliased {
			p = alias
		}

		p, _, _ = strings.Cut(p, "@")
		if aliased {
			fprintf("import %s %q\n", alias, p)
			continue
		}

		hasSelf = hasSelf || p == selfPkg
		fprintf("import %q\n", p)
	}

	if !hasSelf {
		fprintf("import %q\n", selfPkg)
	}

	fprintf("func main() {\n")
	fprintf("wand.Exec(%s)\n", expression)
	fprintf("}\n")
	return err
}
// execWand builds and runs a Go program that calls wand.Exec with the Go
// expression args[0], passing args[1:] as the program's arguments. It returns
// the error of the program run, or of the preparation steps before it.
//
// The program lives in its own module under the cache directory: o.CacheDir,
// or $HOME/.wand when that is empty. The module directory is named after a
// hash of the expression, the imports and the module replacements. The module
// is set up on first use:
//
//   - go mod init;
//   - go get for every import, and for the wand package itself unless one of
//     the imports already names it;
//   - go mod edit -replace for every replacement.
//
// A later call with the same inputs reuses the module. When a setup step
// fails, the partially set up files are removed, so that the next call starts
// over.
//
// o.ClearCache discards an existing module before use. o.NoCache builds in a
// temporary directory under the cache, which is removed afterwards, also when
// a step fails.
//
// While the go commands run, the working directory is the module directory; it
// is restored before returning. os.Chdir affects the whole process, so this
// must not run concurrently with other code that depends on the working
// directory.
func execWand(o ExecOptions, stdin io.Reader, stdout, stderr io.Writer, args []string) (err error) {
	if len(args) == 0 {
		return errors.New("missing wand expression")
	}

	expression, args := args[0], args[1:]

	// The module replacements end up in go.mod, so they are part of the cache
	// key too; otherwise changing them would silently reuse a module that was
	// set up without them. Building a new slice also guarantees that hashing
	// cannot reorder o.Import.
	keyParts := make([]string, 0, len(o.Import)+len(o.ReplaceModule))
	keyParts = append(keyParts, o.Import...)
	for _, replace := range o.ReplaceModule {
		keyParts = append(keyParts, "-replace "+replace)
	}

	commandHash, err := hash(expression, keyParts)
	if err != nil {
		return err
	}

	cacheDir := o.CacheDir
	if cacheDir == "" {
		cacheDir = path.Join(os.Getenv("HOME"), ".wand")
	}

	commandDir := path.Join(cacheDir, commandHash)
	if o.NoCache {
		commandDir = path.Join(cacheDir, "tmp", commandHash)
	}

	if o.NoCache || o.ClearCache {
		if err := os.RemoveAll(commandDir); err != nil {
			return fmt.Errorf("failed to clear cache: %w", err)
		}
	}

	if err := os.MkdirAll(commandDir, os.ModePerm); err != nil {
		return fmt.Errorf("failed to ensure cache directory: %w", err)
	}

	if o.NoCache {
		// Remove the temporary module also when a later step fails. Deferred
		// calls run in reverse order, so this runs after the working
		// directory has been restored. An earlier error takes precedence.
		defer func() {
			if rerr := os.RemoveAll(commandDir); rerr != nil && err == nil {
				err = fmt.Errorf("failed to clean cache: %w", rerr)
			}
		}()
	}

	wd, err := os.Getwd()
	if err != nil {
		return fmt.Errorf("error identifying current directory: %w", err)
	}

	goGet := func(pkg string) error {
		if err := execInternal("go get", pkg); err != nil {
			return fmt.Errorf("failed to get go module: %w", err)
		}

		return nil
	}

	goReplace := func(replace string) error {
		if err := execInternal("go mod edit -replace", replace); err != nil {
			return fmt.Errorf("failed to replace go module: %w", err)
		}

		return nil
	}

	// initModule runs the go commands that set up a new module in the
	// current working directory.
	initModule := func() error {
		if err := execInternal("go mod init", commandHash); err != nil {
			return fmt.Errorf("failed to initialize temporary module: %w", err)
		}

		var hasSelf bool
		for _, pkg := range o.Import {
			// Imports have the form [alias=]path[@version]; go get takes
			// path[@version].
			p := pkg
			if _, after, aliased := strings.Cut(pkg, "="); aliased {
				p = after
			}

			if err := goGet(p); err != nil {
				return err
			}

			if modPath, _, _ := strings.Cut(p, "@"); modPath == selfPkg {
				hasSelf = true
			}
		}

		if !hasSelf {
			if err := goGet(selfPkg); err != nil {
				return err
			}
		}

		for _, replace := range o.ReplaceModule {
			if err := goReplace(replace); err != nil {
				return err
			}
		}

		return nil
	}

	if err := os.Chdir(commandDir); err != nil {
		return fmt.Errorf("failed to switch to temporary directory: %w", err)
	}

	// Restoring the working directory is best effort: there is nothing better
	// to fall back to, and the result of the command is what the caller needs.
	defer os.Chdir(wd)

	// commandDir is relative when o.CacheDir is relative or HOME is unset. Now
	// that it is the working directory, replace it with its absolute form, so
	// that the paths derived from it below stay valid.
	if commandDir, err = os.Getwd(); err != nil {
		return fmt.Errorf("error identifying current directory: %w", err)
	}

	gomodPath := path.Join(commandDir, "go.mod")
	if _, err := os.Stat(gomodPath); err != nil {
		if err := initModule(); err != nil {
			// A partially set up module would be found and reused as it is
			// by the next call, so its files are removed to allow a clean
			// retry. The removal is best effort: the setup error is the one
			// to report.
			os.Remove(gomodPath)
			os.Remove(path.Join(commandDir, "go.sum"))
			return err
		}
	}

	goFile := path.Join(commandDir, fmt.Sprintf("%s.go", commandHash))
	if _, err := os.Stat(goFile); err != nil {
		if err := printGoFile(goFile, expression, o.Import); err != nil {
			// Likewise, a partially written file must not be reused.
			os.Remove(goFile)
			return fmt.Errorf("failed to create temporary go file: %w", err)
		}
	}

	return execc(stdin, stdout, stderr, "go run", append([]string{commandDir}, args...), nil)
}
// readExec reads commands from stdin, split into arguments by commandReader,
// and executes each of them with execWand. Empty commands are skipped.
//
// A failing command does not stop the loop: its error is printed to stderr
// and the next command is read. readExec returns nil when the input is
// exhausted, or the error of a failed read.
//
// NOTE(review): stdin is both the command source (read through
// commandReader's buffer) and the stdin handed to execWand. Confirm that the
// executed commands are not expected to read from it.
func readExec(o ExecOptions, stdin io.Reader, stdout, stderr io.Writer) error {
	next := commandReader(stdin)
	for {
		command, err := next()
		switch {
		case errors.Is(err, io.EOF):
			return nil
		case err != nil:
			return err
		case len(command) == 0:
			continue
		}

		if execErr := execWand(o, stdin, stdout, stderr, command); execErr != nil {
			fmt.Fprintln(stderr, execErr)
		}
	}
}
// execInput executes the single wand command given in args: the expression
// first, then the arguments passed to the command. When args is empty, it
// executes every command read from stdin instead.
func execInput(o ExecOptions, stdin io.Reader, stdout, stderr io.Writer, args ...string) error {
	if len(args) > 0 {
		return execWand(o, stdin, stdout, stderr, args)
	}

	return readExec(o, stdin, stdout, stderr)
}