diff --git a/scan.go b/scan.go index e42bbda..0456cfc 100644 --- a/scan.go +++ b/scan.go @@ -3,6 +3,7 @@ package bind import ( "errors" "fmt" + "math" "reflect" "strconv" "strings" @@ -53,27 +54,71 @@ func parseUint(s string, byteSize int) (uint64, error) { return intParse(strconv.ParseUint, s, byteSize) } +func floatToTime(f float64) time.Time { + s := int64(f) + ns := int64((f - float64(s)) * 1_000_000_000) + return time.Unix(s, ns) +} + func scanConvert(t reflect.Type, v any) (any, bool) { if isAny(t) { return v, true } r := reflect.ValueOf(v) - if !r.CanConvert(t) { + if r.CanConvert(t) { + return r.Convert(t).Interface(), true + } + + if t != reflect.TypeFor[time.Time]() { return nil, false } - return r.Convert(t).Interface(), true + fi := reflect.TypeFor[float64]() + if !r.CanConvert(fi) { + return nil, false + } + + f := r.Convert(fi).Interface().(float64) + return floatToTime(f), true } func parseTime(s string) (any, error) { + const errMsg = "failed to parse time" for _, l := range timeLayouts { if t, err := time.Parse(l, s); err == nil { return t, nil } } - return time.Time{}, errors.New("failed to parse time") + ss := strings.Split(s, ".") + if len(ss) > 2 { + return nil, errors.New(errMsg) + } + + if len(ss) == 1 { + i, err := parseInt(ss[0], 8) + if err != nil { + return nil, errors.New(errMsg) + } + + return time.Unix(i, 0), nil + } + + sec, err := parseInt(ss[0], 8) + if err != nil { + return nil, errors.New(errMsg) + } + + dec := strings.TrimLeft(ss[1], "0") + i, err := parseInt(dec, 8) + if err != nil { + return nil, errors.New(errMsg) + } + + zeros := len(ss[1]) - len(dec) + nanosec := int64(float64(i) * math.Pow10(9-zeros-len(dec))) + return time.Unix(sec, nanosec), nil } func scanString(t reflect.Type, s string) (any, bool) { @@ -117,7 +162,6 @@ func scanString(t reflect.Type, s string) (any, bool) { } func scan(t reflect.Type, v any) (any, bool) { - // time and duration support if vv, ok := scanConvert(t, v); ok { return vv, true } diff --git a/scan_test.go b/scan_test.go index d52af5d..0696f53 100644 --- a/scan_test.go +++ b/scan_test.go @@ -2,6 +2,7 @@ package bind_test import ( "code.squareroundforest.org/arpio/bind" + "fmt" "testing" "time" ) @@ -145,6 +146,71 @@ func TestScan(t *testing.T) { } }) + t.Run("time from int", func(t *testing.T) { + const timestamp = 1757097011 + var tim time.Time + if !bind.BindScalar(&tim, timestamp) || !tim.Equal(time.Unix(timestamp, 0)) { + t.Fatal() + } + }) + + t.Run("time from float", func(t *testing.T) { + const ( + seconds = 1757097011 + nanoseconds = 33000 + timestamp = float64(1757097011.000033) + ) + + var tim time.Time + if !bind.BindScalar(&tim, timestamp) || tim.Unix() != seconds { + t.Fatal(tim, tim.Unix(), tim.UnixNano()) + } + }) + + t.Run("time from int string", func(t *testing.T) { + const seconds = 1757097011 + var tim time.Time + if !bind.BindScalar(&tim, fmt.Sprint(seconds)) || tim.Unix() != seconds { + t.Fatal() + } + }) + + t.Run("time from float string", func(t *testing.T) { + const ( + seconds = 1757097011 + nanoseconds = 33000 + s = "1757097011.000033" + ) + + var tim time.Time + if !bind.BindScalar(&tim, s) || tim.Unix() != seconds || tim.Nanosecond() != nanoseconds { + t.Fatal(tim.Unix(), tim.Nanosecond()) + } + }) + + t.Run("time from invalid float string", func(t *testing.T) { + t.Run("multiple decimal points", func(t *testing.T) { + var tim time.Time + if bind.BindScalar(&tim, "24.235.23") { + t.Fatal() + } + }) + + t.Run("invalid round part", func(t *testing.T) { + var tim time.Time + if bind.BindScalar(&tim, "89s.01") { + t.Fatal() + } + }) + + t.Run("invalid fractional part", func(t *testing.T) { + var tim time.Time + if bind.BindScalar(&tim, "89.01s") { + t.Fatal() + } + }) + }) + t.Run("time invalid", func(t *testing.T) { var tim time.Time if bind.BindScalar(&tim, "foo") {