diff --git a/pkg/sql/codegen/attribute/checker.go b/pkg/sql/codegen/attribute/checker.go new file mode 100644 index 0000000000..fa73853b0c --- /dev/null +++ b/pkg/sql/codegen/attribute/checker.go @@ -0,0 +1,101 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package attribute + +import ( + "fmt" +) + +func newFloat32(f float32) *float32 { + return &f +} + +// Float32RangeChecker is a helper function to generate range checkers on attribute. +// lower/upper indicates the lower bound and upper bound of the attribute value. +// If lower/upper is nil, it means no boundary. +// includeLower/includeUpper indicates the inclusion of the bound. +func Float32RangeChecker(lower, upper *float32, includeLower, includeUpper bool) func(interface{}) error { + + checker := func(e interface{}) error { + f, ok := e.(float32) + if !ok { + return fmt.Errorf("expected type float32, received %T", e) + } + + // NOTE(tony): nil means no boundary + if lower != nil { + if includeLower && !(*lower <= f) { + return fmt.Errorf("range check %v <= %v failed", *lower, f) + } + if !includeLower && !(*lower < f) { + return fmt.Errorf("range check %v < %v failed", *lower, f) + } + } + + // NOTE(tony): nil means no boundary + if upper != nil { + if includeUpper && !(f <= *upper) { + return fmt.Errorf("range check %v <= %v failed", f, *upper) + } + if !includeUpper && !(f < *upper) { + return fmt.Errorf("range check %v < %v failed", f, *upper) + } + } + + return nil + } + + return checker +} + +func newInt(i int) *int { + return &i +} + +// IntRangeChecker is a helper function to generate range checkers on attribute. +// lower/upper indicates the lower bound and upper bound of the attribute value. +// If lower/upper is nil, it means no boundary. +// includeLower/includeUpper indicates the inclusion of the bound. +func IntRangeChecker(lower, upper *int, includeLower, includeUpper bool) func(interface{}) error { + checker := func(e interface{}) error { + i, ok := e.(int) + if !ok { + return fmt.Errorf("expected type float32, received %T", e) + } + + // NOTE(tony): nil means no boundary + if lower != nil { + if includeLower && !(*lower <= i) { + return fmt.Errorf("range check %v <= %v failed", *lower, i) + } + if !includeLower && !(*lower < i) { + return fmt.Errorf("range check %v < %v failed", *lower, i) + } + } + + // NOTE(tony): nil means no boundary + if upper != nil { + if includeUpper && !(i <= *upper) { + return fmt.Errorf("range check %v <= %v failed", i, *upper) + } + if !includeUpper && !(i < *upper) { + return fmt.Errorf("range check %v < %v failed", i, *upper) + } + } + + return nil + } + + return checker +} diff --git a/pkg/sql/codegen/attribute/checker_test.go b/pkg/sql/codegen/attribute/checker_test.go new file mode 100644 index 0000000000..12f2e481c0 --- /dev/null +++ b/pkg/sql/codegen/attribute/checker_test.go @@ -0,0 +1,55 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package attribute + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestFloat32RangeChecker(t *testing.T) { + a := assert.New(t) + + checker := Float32RangeChecker(newFloat32(0.0), newFloat32(1.0), true, true) + a.Error(checker(float32(-1))) + a.NoError(checker(float32(0))) + a.NoError(checker(float32(0.5))) + a.NoError(checker(float32(1))) + a.Error(checker(float32(2))) + + checker2 := Float32RangeChecker(newFloat32(0.0), newFloat32(1.0), false, false) + a.Error(checker2(float32(-1))) + a.Error(checker2(float32(0))) + a.NoError(checker2(float32(0.5))) + a.Error(checker2(float32(1))) + a.Error(checker2(float32(2))) +} + +func TestIntRangeChecker(t *testing.T) { + a := assert.New(t) + + checker := IntRangeChecker(newInt(0), newInt(2), true, true) + a.Error(checker(int(-1))) + a.NoError(checker(int(0))) + a.NoError(checker(int(1))) + a.NoError(checker(int(2))) + a.Error(checker(int(3))) + + checker2 := IntRangeChecker(newInt(0), newInt(2), false, false) + a.Error(checker2(int(-1))) + a.Error(checker2(int(0))) + a.NoError(checker2(int(1))) + a.Error(checker2(int(2))) + a.Error(checker2(int(3))) +} diff --git a/pkg/sql/ir_generator.go b/pkg/sql/ir_generator.go index 210affeb2d..52b8c1b243 100644 --- a/pkg/sql/ir_generator.go +++ b/pkg/sql/ir_generator.go @@ -146,8 +146,7 @@ func inferStringValue(expr string) interface{} { return ret } if retFloat, err := strconv.ParseFloat(expr, 32); err == nil { - // always use float32 for attributes, we may never use a float64 - // value as some attribute. + // Note(typhoonzero): always use float32 for attributes, we may never use a float64. return float32(retFloat) } retString := strings.Trim(expr, "\"")