diff --git a/ranges/daterange.go b/ranges/daterange.go new file mode 100644 index 00000000..382242f8 --- /dev/null +++ b/ranges/daterange.go @@ -0,0 +1,104 @@ +package ranges + +import ( + "database/sql/driver" + "errors" + "fmt" + "time" + + "github.com/lib/pq" +) + +func isTimeZero(t time.Time) bool { + return t.Hour() == 0 && t.Minute() == 0 && t.Second() == 0 && t.Nanosecond() == 0 +} + +// DateRange represents a range between two dates where the lower is inclusive +// and the upper exclusive. +type DateRange struct { + Lower time.Time + Upper time.Time +} + +// Scan implements the sql.Scanner interface +func (r *DateRange) Scan(val interface{}) error { + var ( + err error + minb, maxb []byte + ) + + if val == nil { + return errors.New("cannot scan NULL into *DateRange") + } + minb, maxb, err = readDiscreteTimeRange(val.([]byte)) + if err != nil { + return errors.New("could not scan date range: " + err.Error()) + } + + if len(minb) == 0 { + r.Lower = time.Time{} + } else { + r.Lower, err = pq.ParseTimestamp(nil, string(minb)) + if err != nil { + return errors.New("could not parse lower date:" + err.Error()) + } + if !isTimeZero(r.Lower) { + return errors.New("time component of lower date is not zero") + } + } + + if len(maxb) == 0 { + r.Upper = time.Time{} + } else { + r.Upper, err = pq.ParseTimestamp(nil, string(maxb)) + if err != nil { + return errors.New("could not parse upper date:" + err.Error()) + } + if !isTimeZero(r.Upper) { + return errors.New("time component of upper date is not zero") + } + } + + return nil +} + +// IsLowerInfinity returns whether the lower value is negative infinity +func (r DateRange) IsLowerInfinity() bool { + return r.Lower.IsZero() +} + +// IsUpperInfinity returns whether the upper value is positive infinity +func (r DateRange) IsUpperInfinity() bool { + return r.Upper.IsZero() +} + +// Value implements the driver.Value interface +func (r DateRange) Value() (driver.Value, error) { + if !isTimeZero(r.Lower) { + return nil, errors.New("time component of lower date is not zero") + } + if !isTimeZero(r.Upper) { + return nil, errors.New("time component of upper date is not zero") + } + if r.Lower.After(r.Upper) { + return nil, errors.New("lower date is after upper date") + } + return []byte(r.String()), nil +} + +// Returns the date range as a string where the dates are formatted according +// to ISO8601 +func (r DateRange) String() string { + var ( + open = '(' + lower, upper string + ) + if !r.Lower.IsZero() { + lower = r.Lower.Format("2006-01-02") + open = '[' + } + if !r.Upper.IsZero() { + upper = r.Upper.Format("2006-01-02") + } + return fmt.Sprintf("%c%s,%s)", open, lower, upper) +} diff --git a/ranges/daterange_test.go b/ranges/daterange_test.go new file mode 100644 index 00000000..c1665789 --- /dev/null +++ b/ranges/daterange_test.go @@ -0,0 +1,63 @@ +package ranges + +import ( + "testing" + "time" +) + +func TestDateRangeScan(t *testing.T) { + test := func(input string, lowers, uppers string) { + r := DateRange{} + if err := r.Scan([]byte(input)); err != nil { + t.Fatalf("unexpected error: " + err.Error()) + } + lower, _ := time.Parse("2006-01-02", lowers) + upper, _ := time.Parse("2006-01-02", uppers) + if !r.Lower.Equal(lower) { + t.Errorf("expected lower date '%v', got '%v'", lower, r.Lower) + } + if !r.Upper.Equal(upper) { + t.Errorf("expected upper date '%v', got '%v'", upper, r.Upper) + } + } + + test("[2000-01-01,2017-05-09)", "2000-01-01", "2017-05-09") + test("[2000-01-01,)", "2000-01-01", "0001-01-01") + test("[,2000-01-01)", "0001-01-01", "2000-01-01") +} + +func TestDateRangeString(t *testing.T) { + test := func(lowers, uppers string, expect string) { + var lower, upper time.Time + if lowers != "" { + lower, _ = time.Parse("2006-01-02", lowers) + } + if uppers != "" { + upper, _ = time.Parse("2006-01-02", uppers) + } + if s := (DateRange{lower, upper}).String(); s != expect { + t.Errorf("expected '%s', got '%s'", expect, s) + } + } + + test("2001-06-02", "2007-05-04", "[2001-06-02,2007-05-04)") + test("2001-06-02", "", "[2001-06-02,)") + test("", "2001-06-02", "(,2001-06-02)") + test("", "", "(,)") +} + +func TestDateRangeValueError(t *testing.T) { + expectError := func(lowers, uppers string) { + lower, _ := time.Parse("2006-01-02 15:04:05", lowers) + upper, _ := time.Parse("2006-01-02 15:04:05", uppers) + r := DateRange{lower, upper} + if _, err := r.Value(); err == nil { + t.Errorf("expected an error for '%s' but did not get one", r.String()) + } + } + + expectError("2001-01-02 00:00:00", "2001-01-01 00:00:00") + expectError("2001-02-01 00:00:00", "2001-01-01 00:00:00") + expectError("2001-02-01 12:00:03", "2001-01-01 00:00:00") + expectError("2001-02-01 00:00:00", "2001-01-01 13:00:00") +} diff --git a/ranges/float64range.go b/ranges/float64range.go new file mode 100644 index 00000000..763e049c --- /dev/null +++ b/ranges/float64range.go @@ -0,0 +1,61 @@ +package ranges + +import ( + "database/sql/driver" + "fmt" + "strconv" +) + +// Float64Range represents a range between two float64 values +type Float64Range struct { + Lower float64 + LowerInclusive bool + Upper float64 + UpperInclusive bool +} + +// Scan implements the sql.Scanner interface +func (r *Float64Range) Scan(val interface{}) error { + if val == nil { + r.Lower = 0 + r.LowerInclusive = false + r.Upper = 0 + r.UpperInclusive = false + return nil + } + lowerIn, upperIn, lower, upper, err := readRange(val.([]byte)) + if err != nil { + return err + } + r.Lower, err = strconv.ParseFloat(string(lower), 64) + if err != nil { + return err + } + r.Upper, err = strconv.ParseFloat(string(upper), 64) + if err != nil { + return err + } + r.LowerInclusive = lowerIn + r.UpperInclusive = upperIn + return nil +} + +// Value implements the driver.Valuer interface +func (r Float64Range) Value() (driver.Value, error) { + return []byte(r.String()), nil +} + +// String returns a string representation of this range +func (r Float64Range) String() string { + var ( + open = "(" + close = ")" + ) + if r.LowerInclusive { + open = "[" + } + if r.UpperInclusive { + close = "]" + } + return fmt.Sprintf("%s%f,%f%s", open, r.Lower, r.Upper, close) +} diff --git a/ranges/float64range_test.go b/ranges/float64range_test.go new file mode 100644 index 00000000..edbcf559 --- /dev/null +++ b/ranges/float64range_test.go @@ -0,0 +1,18 @@ +package ranges + +import ( + "testing" +) + +func TestFloat64RangeString(t *testing.T) { + test := func(lower, upper float64, lowerIn, upperIn bool, expect string) { + s := Float64Range{lower, lowerIn, upper, upperIn}.String() + if s != expect { + t.Errorf("expected '%s', got '%s'", expect, s) + } + } + + test(-1.0, 2.1, false, true, "(-1.000000,2.100000]") + test(9.99, 0.01, true, true, "[9.990000,0.010000]") + test(80.0, 90.0, false, false, "(80.000000,90.000000)") +} diff --git a/ranges/int32range.go b/ranges/int32range.go new file mode 100644 index 00000000..a78e9422 --- /dev/null +++ b/ranges/int32range.go @@ -0,0 +1,41 @@ +package ranges + +import ( + "database/sql/driver" + "errors" + "fmt" +) + +// Int32Range represents a range between two int32 values. The lower value is +// inclusive and the upper is exclusive. +type Int32Range struct { + Lower int32 + Upper int32 +} + +// Scan implements the sql.Scanner interface +func (r *Int32Range) Scan(val interface{}) error { + if val == nil { + return errors.New("cannot scan NULL into *Int32Range") + } + l, u, err := parseIntRange(val.([]byte), 32) + if err != nil { + return err + } + r.Lower = int32(l) + r.Upper = int32(u) + return nil +} + +// Value implements the driver.Valuer interface +func (r Int32Range) Value() (driver.Value, error) { + if r.Lower > r.Upper { + return nil, errors.New("lower value is greater than the upper value") + } + return []byte(r.String()), nil +} + +// String returns a string representation of this range +func (r Int32Range) String() string { + return fmt.Sprintf("[%d,%d)", r.Lower, r.Upper) +} diff --git a/ranges/int32range_test.go b/ranges/int32range_test.go new file mode 100644 index 00000000..635a4946 --- /dev/null +++ b/ranges/int32range_test.go @@ -0,0 +1,32 @@ +package ranges + +import ( + "testing" +) + +func TestInt32RangeString(t *testing.T) { + test := func(lower, upper int32, expect string) { + s := Int32Range{lower, upper}.String() + if s != expect { + t.Errorf("expected '%s', got '%s'", expect, s) + } + } + + test(0, 2, "[0,2)") + test(0, 0, "[0,0)") + test(-2, 8, "[-2,8)") + test(8, -2, "[8,-2)") +} + +func TestInt32RangeValue(t *testing.T) { + expectError := func(lower, upper int32) { + r := Int32Range{lower, upper} + if _, err := r.Value(); err == nil { + t.Errorf("expected an error for '%s' but did not get one", r.String()) + } + } + + expectError(2, 0) + expectError(8, -4) + expectError(-8, -9) +} diff --git a/ranges/int64range.go b/ranges/int64range.go new file mode 100644 index 00000000..996e3c77 --- /dev/null +++ b/ranges/int64range.go @@ -0,0 +1,41 @@ +package ranges + +import ( + "database/sql/driver" + "errors" + "fmt" +) + +// Int64Range represents a range between two int64 values. The lower value is +// inclusive and the upper is exclusive. +type Int64Range struct { + Lower int64 + Upper int64 +} + +// Scan implements the sql.Scanner interface +func (r *Int64Range) Scan(val interface{}) error { + if val == nil { + return errors.New("cannot scan NULL into *Int64Range") + } + l, u, err := parseIntRange(val.([]byte), 64) + if err != nil { + return err + } + r.Lower = l + r.Upper = u + return nil +} + +// Value implements the driver.Valuer interface +func (r Int64Range) Value() (driver.Value, error) { + if r.Lower > r.Upper { + return nil, errors.New("lower value is greater than the upper value") + } + return []byte(r.String()), nil +} + +// String returns a string representation of this range +func (r Int64Range) String() string { + return fmt.Sprintf("[%d,%d)", r.Lower, r.Upper) +} diff --git a/ranges/int64range_test.go b/ranges/int64range_test.go new file mode 100644 index 00000000..389ff07f --- /dev/null +++ b/ranges/int64range_test.go @@ -0,0 +1,32 @@ +package ranges + +import ( + "testing" +) + +func TestInt64RangeString(t *testing.T) { + test := func(lower, upper int64, expect string) { + s := Int64Range{lower, upper}.String() + if s != expect { + t.Errorf("expected '%s', got '%s'", expect, s) + } + } + + test(0, 2, "[0,2)") + test(0, 0, "[0,0)") + test(-2, 8, "[-2,8)") + test(8, -2, "[8,-2)") +} + +func TestInt64RangeValue(t *testing.T) { + expectError := func(lower, upper int64) { + r := Int64Range{lower, upper} + if _, err := r.Value(); err == nil { + t.Errorf("expected an error for '%s' but did not get one", r.String()) + } + } + + expectError(2, 0) + expectError(8, -4) + expectError(-8, -9) +} diff --git a/ranges/intrange.go b/ranges/intrange.go new file mode 100644 index 00000000..5a5f93f0 --- /dev/null +++ b/ranges/intrange.go @@ -0,0 +1,25 @@ +package ranges + +import ( + "errors" + "strconv" +) + +func parseIntRange(buf []byte, bitSize int) (int64, int64, error) { + lowerb, upperb, err := readDiscreteRange(buf) + if err != nil { + return 0, 0, err + } + lower, err := strconv.ParseInt(string(lowerb), 10, bitSize) + if err != nil { + return 0, 0, err + } + upper, err := strconv.ParseInt(string(upperb), 10, bitSize) + if err != nil { + return 0, 0, err + } + if lower > upper { + return 0, 0, errors.New("lower value is greater than the upper value") + } + return lower, upper, nil +} diff --git a/ranges/parsing.go b/ranges/parsing.go new file mode 100644 index 00000000..b42f86b7 --- /dev/null +++ b/ranges/parsing.go @@ -0,0 +1,147 @@ +package ranges + +import ( + "errors" + "fmt" +) + +func readNumber(buf []byte, pos int) ([]byte, int, error) { + var ( + s []byte + b byte + canEnd = false + inMantissa bool + ) + for pos < len(buf) { + b = buf[pos] + if b == '-' && len(s) == 0 { + s = append(s, b) + canEnd = false + } else if b >= 48 && b <= 57 { + s = append(s, b) + canEnd = true + } else if b == '.' && !inMantissa { + s = append(s, b) + canEnd = false + } else { + break + } + pos++ + } + if !canEnd { + return s, pos, fmt.Errorf("unexpected character '%c' at position %d", b, pos) + } + return s, pos, nil +} + +func readByte(buf []byte, pos int, expect byte) (int, error) { + if pos >= len(buf) { + return pos, fmt.Errorf("unexpected end of input at position %d", pos) + } + if buf[pos] != expect { + return pos, fmt.Errorf("unexpected character '%c' at position %d", buf[pos], pos) + } + return pos + 1, nil +} + +func readRangeBound(buf []byte, pos int, incl, excl byte) (bool, int, error) { + if pos >= len(buf) { + return false, 0, fmt.Errorf("unexpected end of input at position %d", pos) + } + switch buf[pos] { + case incl: + return true, pos + 1, nil + case excl: + return false, pos + 1, nil + default: + return false, pos, fmt.Errorf("unexpected character '%c' at position %d", buf[pos], pos) + } +} + +func readRange(buf []byte) (minIncl bool, maxIncl bool, min []byte, max []byte, err error) { + var pos int + minIncl, pos, err = readRangeBound(buf, pos, '[', '(') + if err != nil { + return + } + min, pos, err = readNumber(buf, pos) + if err != nil { + return + } + pos, err = readByte(buf, pos, ',') + if err != nil { + return + } + max, pos, err = readNumber(buf, pos) + if err != nil { + return + } + maxIncl, pos, err = readRangeBound(buf, pos, ']', ')') + if err != nil { + return + } + return +} + +func readUntilTerminator(buf []byte, pos int, term byte) ([]byte, int, error) { + var s []byte + for pos < len(buf) && buf[pos] != term { + s = append(s, buf[pos]) + pos++ + } + return s, pos, nil +} + +func readDiscreteRange(buf []byte) (min []byte, max []byte, err error) { + var pos int + pos, err = readByte(buf, pos, '[') + if err != nil { + return + } + min, pos, err = readNumber(buf, pos) + if err != nil { + return + } + pos, err = readByte(buf, pos, ',') + if err != nil { + return + } + max, pos, err = readNumber(buf, pos) + if err != nil { + return + } + pos, err = readByte(buf, pos, ')') + if err != nil { + return + } + return +} + +func readDiscreteTimeRange(buf []byte) (min []byte, max []byte, err error) { + var pos int + minIn, pos, err := readRangeBound(buf, pos, '[', '(') + if err != nil { + return + } + min, pos, err = readUntilTerminator(buf, pos, ',') + if err != nil { + return + } + if !minIn && len(min) != 0 { + err = errors.New("lower value is marked as exclusive but does not have an empty value") + return + } + pos, err = readByte(buf, pos, ',') + if err != nil { + return + } + max, pos, err = readUntilTerminator(buf, pos, ')') + if err != nil { + return + } + pos, err = readByte(buf, pos, ')') + if err != nil { + return + } + return +} diff --git a/ranges/parsing_test.go b/ranges/parsing_test.go new file mode 100644 index 00000000..de7f76f3 --- /dev/null +++ b/ranges/parsing_test.go @@ -0,0 +1,39 @@ +package ranges + +import ( + "testing" +) + +func TestReadRange(t *testing.T) { + cases := []struct { + Input string + MinIn bool + MaxIn bool + Min string + Max string + }{ + {"[-1.23,98.0]", true, true, "-1.23", "98.0"}, + {"(1,2]", false, true, "1", "2"}, + {"[0,0.0]", true, true, "0", "0.0"}, + {"(1.29,-0.5)", false, false, "1.29", "-0.5"}, + } + + for _, tc := range cases { + minIn, maxIn, min, max, err := readRange([]byte(tc.Input)) + if err != nil { + t.Fatalf("unexpected error: " + err.Error()) + } + if minIn != tc.MinIn { + t.Fatalf("expected min to be inclusive=%t, got %t", tc.MinIn, minIn) + } + if maxIn != tc.MaxIn { + t.Fatalf("expected max to be inclusive=%t, got %t", tc.MaxIn, maxIn) + } + if string(min) != tc.Min { + t.Fatalf("expected min to be '%s', got '%s'", tc.Min, min) + } + if string(max) != tc.Max { + t.Fatalf("expected max to be '%s', got '%s'", tc.Max, max) + } + } +}