121 lines
2.3 KiB
Go
121 lines
2.3 KiB
Go
package wand
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
type testCase struct {
|
|
impl any
|
|
stdin string
|
|
conf string
|
|
mergeConfTyped []Config
|
|
mergeConf []string
|
|
env string
|
|
command string
|
|
contains bool
|
|
}
|
|
|
|
func testExec(test testCase, err string, expect ...string) func(*testing.T) {
|
|
return func(t *testing.T) {
|
|
var exitCode int
|
|
exit := func(code int) { exitCode = code }
|
|
|
|
var stdin io.Reader
|
|
if test.stdin == "" {
|
|
stdin = bytes.NewBuffer(nil)
|
|
} else {
|
|
stdin = bytes.NewBufferString(test.stdin)
|
|
}
|
|
|
|
stdout := bytes.NewBuffer(nil)
|
|
stderr := bytes.NewBuffer(nil)
|
|
cmd := wrap(test.impl)
|
|
e := strings.Split(test.env, ";")
|
|
a := strings.Split(test.command, " ")
|
|
|
|
zeroOrOne := func(b ...bool) bool {
|
|
var one bool
|
|
for _, bi := range b {
|
|
if bi && one {
|
|
return false
|
|
}
|
|
|
|
if bi {
|
|
one = true
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
if !zeroOrOne(test.conf != "", len(test.mergeConf) > 0, len(test.mergeConfTyped) > 0) {
|
|
t.Fatal("test error: conflicting test config")
|
|
}
|
|
|
|
var conf Config
|
|
if test.conf != "" {
|
|
conf = Config{test: test.conf}
|
|
} else if len(test.mergeConf) > 0 {
|
|
var c []Config
|
|
for _, cs := range test.mergeConf {
|
|
c = append(c, Config{test: cs})
|
|
}
|
|
|
|
conf = MergeConfig(c...)
|
|
} else if len(test.mergeConfTyped) > 0 {
|
|
conf = MergeConfig(test.mergeConfTyped...)
|
|
}
|
|
|
|
exec(stdin, stdout, stderr, exit, cmd, conf, e, a)
|
|
if exitCode != 0 && err == "" {
|
|
t.Fatal("non-zero exit code:", stderr.String())
|
|
}
|
|
|
|
if err != "" && exitCode == 0 {
|
|
t.Fatal("failed to fail")
|
|
}
|
|
|
|
if err != "" && !strings.Contains(stderr.String(), err) {
|
|
t.Fatal("expected error not received:", stderr.String())
|
|
}
|
|
|
|
if exitCode != 0 {
|
|
return
|
|
}
|
|
|
|
var expstr []string
|
|
for _, e := range expect {
|
|
expstr = append(expstr, fmt.Sprint(e))
|
|
}
|
|
|
|
output := stdout.String()
|
|
if output[len(output)-1] != '\n' {
|
|
output = output + "\n"
|
|
}
|
|
|
|
checkOutput := func() bool {
|
|
return output == strings.Join(expstr, "\n")+"\n"
|
|
}
|
|
|
|
if test.contains {
|
|
checkOutput = func() bool {
|
|
for _, e := range expstr {
|
|
if !strings.Contains(output, e) {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
}
|
|
|
|
if !checkOutput() {
|
|
t.Fatal("unexpected output:", output)
|
|
}
|
|
}
|
|
}
|