From 27b088817b89efa4e52325e346546a26144f49fc Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Mon, 14 Oct 2019 14:44:54 -0700 Subject: [PATCH 1/3] Add attribute value range checker helper function --- pkg/sql/codegen/attribute/checker.go | 101 ++++++++++++++++++++++ pkg/sql/codegen/attribute/checker_test.go | 55 ++++++++++++ pkg/sql/ir_generator.go | 3 +- 3 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 pkg/sql/codegen/attribute/checker.go create mode 100644 pkg/sql/codegen/attribute/checker_test.go diff --git a/pkg/sql/codegen/attribute/checker.go b/pkg/sql/codegen/attribute/checker.go new file mode 100644 index 0000000000..fa73853b0c --- /dev/null +++ b/pkg/sql/codegen/attribute/checker.go @@ -0,0 +1,101 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package attribute + +import ( + "fmt" +) + +func newFloat32(f float32) *float32 { + return &f +} + +// Float32RangeChecker is a helper function to generate range checkers on attribute. +// lower/upper indicates the lower bound and upper bound of the attribute value. +// If lower/upper is nil, it means no boundary. +// includeLower/includeUpper indicates the inclusion of the bound. +func Float32RangeChecker(lower, upper *float32, includeLower, includeUpper bool) func(interface{}) error { + + checker := func(e interface{}) error { + f, ok := e.(float32) + if !ok { + return fmt.Errorf("expected type float32, received %T", e) + } + + // NOTE(tony): nil means no boundary + if lower != nil { + if includeLower && !(*lower <= f) { + return fmt.Errorf("range check %v <= %v failed", *lower, f) + } + if !includeLower && !(*lower < f) { + return fmt.Errorf("range check %v < %v failed", *lower, f) + } + } + + // NOTE(tony): nil means no boundary + if upper != nil { + if includeUpper && !(f <= *upper) { + return fmt.Errorf("range check %v <= %v failed", f, *upper) + } + if !includeUpper && !(f < *upper) { + return fmt.Errorf("range check %v < %v failed", f, *upper) + } + } + + return nil + } + + return checker +} + +func newInt(i int) *int { + return &i +} + +// IntRangeChecker is a helper function to generate range checkers on attribute. +// lower/upper indicates the lower bound and upper bound of the attribute value. +// If lower/upper is nil, it means no boundary. +// includeLower/includeUpper indicates the inclusion of the bound. +func IntRangeChecker(lower, upper *int, includeLower, includeUpper bool) func(interface{}) error { + checker := func(e interface{}) error { + i, ok := e.(int) + if !ok { + return fmt.Errorf("expected type float32, received %T", e) + } + + // NOTE(tony): nil means no boundary + if lower != nil { + if includeLower && !(*lower <= i) { + return fmt.Errorf("range check %v <= %v failed", *lower, i) + } + if !includeLower && !(*lower < i) { + return fmt.Errorf("range check %v < %v failed", *lower, i) + } + } + + // NOTE(tony): nil means no boundary + if upper != nil { + if includeUpper && !(i <= *upper) { + return fmt.Errorf("range check %v <= %v failed", i, *upper) + } + if !includeUpper && !(i < *upper) { + return fmt.Errorf("range check %v < %v failed", i, *upper) + } + } + + return nil + } + + return checker +} diff --git a/pkg/sql/codegen/attribute/checker_test.go b/pkg/sql/codegen/attribute/checker_test.go new file mode 100644 index 0000000000..12f2e481c0 --- /dev/null +++ b/pkg/sql/codegen/attribute/checker_test.go @@ -0,0 +1,55 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package attribute + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestFloat32RangeChecker(t *testing.T) { + a := assert.New(t) + + checker := Float32RangeChecker(newFloat32(0.0), newFloat32(1.0), true, true) + a.Error(checker(float32(-1))) + a.NoError(checker(float32(0))) + a.NoError(checker(float32(0.5))) + a.NoError(checker(float32(1))) + a.Error(checker(float32(2))) + + checker2 := Float32RangeChecker(newFloat32(0.0), newFloat32(1.0), false, false) + a.Error(checker2(float32(-1))) + a.Error(checker2(float32(0))) + a.NoError(checker2(float32(0.5))) + a.Error(checker2(float32(1))) + a.Error(checker2(float32(2))) +} + +func TestIntRangeChecker(t *testing.T) { + a := assert.New(t) + + checker := IntRangeChecker(newInt(0), newInt(2), true, true) + a.Error(checker(int(-1))) + a.NoError(checker(int(0))) + a.NoError(checker(int(1))) + a.NoError(checker(int(2))) + a.Error(checker(int(3))) + + checker2 := IntRangeChecker(newInt(0), newInt(2), false, false) + a.Error(checker2(int(-1))) + a.Error(checker2(int(0))) + a.NoError(checker2(int(1))) + a.Error(checker2(int(2))) + a.Error(checker2(int(3))) +} diff --git a/pkg/sql/ir_generator.go b/pkg/sql/ir_generator.go index 210affeb2d..52b8c1b243 100644 --- a/pkg/sql/ir_generator.go +++ b/pkg/sql/ir_generator.go @@ -146,8 +146,7 @@ func inferStringValue(expr string) interface{} { return ret } if retFloat, err := strconv.ParseFloat(expr, 32); err == nil { - // always use float32 for attributes, we may never use a float64 - // value as some attribute. + // Note(typhoonzero): always use float32 for attributes, we may never use a float64. return float32(retFloat) } retString := strings.Trim(expr, "\"") From 6b66e7a7f0e0178e98c84a818207165fb1f0a3fc Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Mon, 14 Oct 2019 16:00:07 -0700 Subject: [PATCH 2/3] Use pkg/sql/codegen/attribute in XGBoost TrainIR --- doc/design/design_xgboost_on_sqlflow.md | 2 +- pkg/sql/codegen/xgboost/codegen.go | 70 +++++++++---------------- pkg/sql/codegen/xgboost/codegen_test.go | 7 +-- 3 files changed, 31 insertions(+), 48 deletions(-) diff --git a/doc/design/design_xgboost_on_sqlflow.md b/doc/design/design_xgboost_on_sqlflow.md index 94ca1766a7..27fc570572 100644 --- a/doc/design/design_xgboost_on_sqlflow.md +++ b/doc/design/design_xgboost_on_sqlflow.md @@ -13,7 +13,7 @@ SELECT * FROM train_table TRAIN xgboost.gbtree WITH objective=multi:softmax, - train.num_round=2, + train.num_boost_round=2, max_depth=2, eta=1 LABEL class diff --git a/pkg/sql/codegen/xgboost/codegen.go b/pkg/sql/codegen/xgboost/codegen.go index cf2cd0ae46..b3e2a121d7 100644 --- a/pkg/sql/codegen/xgboost/codegen.go +++ b/pkg/sql/codegen/xgboost/codegen.go @@ -17,42 +17,33 @@ import ( "bytes" "encoding/json" "fmt" + "sqlflow.org/sqlflow/pkg/sql/codegen/attribute" "strings" "sqlflow.org/sqlflow/pkg/sql/codegen" ) -var attributeChecker = map[string]func(interface{}) error{ - "eta": func(x interface{}) error { - switch x.(type) { - case float32, float64: - return nil - default: - return fmt.Errorf("eta should be of type float, received %T", x) - } - }, - "num_class": func(x interface{}) error { - switch x.(type) { - case int, int32, int64: - return nil - default: - return fmt.Errorf("num_class should be of type int, received %T", x) - } - }, - "train.num_boost_round": func(x interface{}) error { - switch x.(type) { - case int, int32, int64: - return nil - default: - return fmt.Errorf("train.num_boost_round should be of type int, received %T", x) - } - }, - "objective": func(x interface{}) error { - if _, ok := x.(string); !ok { - return fmt.Errorf("objective should be of type string, received %T", x) - } - return nil - }, +func newFloat32(f float32) *float32 { + return &f +} + +func newInt(i int) *int { + return &i +} + +// TODO(tony): complete model parameter and training parameter list +// model parameter list: https://xgboost.readthedocs.io/en/latest/parameter.html#general-parameters +// training parameter list: https://github.com/dmlc/xgboost/blob/b61d53447203ca7a321d72f6bdd3f553a3aa06c4/python-package/xgboost/training.py#L115-L117 +var attributeDictionary = attribute.Dictionary{ + "eta": {attribute.Float, `[default=0.3, alias: learning_rate] +Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features, and eta shrinks the feature weights to make the boosting process more conservative. +range: [0,1]`, attribute.Float32RangeChecker(newFloat32(0), newFloat32(1), true, true)}, + "num_class": {attribute.Int, `Number of classes. +range: [1, Infinity]`, attribute.IntRangeChecker(newInt(0), nil, false, false)}, + "objective": {attribute.String, `Learning objective`, nil}, + "train.num_boost_round": {attribute.Int, `[default=10] +The number of rounds for boosting. +range: [1, Infinity]`, attribute.IntRangeChecker(newInt(0), nil, false, false)}, } func resolveModelType(estimator string) (string, error) { @@ -69,22 +60,13 @@ func resolveModelType(estimator string) (string, error) { } func parseAttribute(attrs map[string]interface{}) (map[string]map[string]interface{}, error) { - attrNames := map[string]bool{} + if err := attributeDictionary.Validate(attrs); err != nil { + return nil, err + } params := map[string]map[string]interface{}{"": {}, "train.": {}} - paramPrefix := []string{"train.", ""} // use slice to assure traverse order + paramPrefix := []string{"train.", ""} // use slice to assure traverse order, this is necessary because all string starts with "" for key, attr := range attrs { - if _, ok := attrNames[key]; ok { - return nil, fmt.Errorf("duplicated attribute %s", key) - } - attrNames[key] = true - checker, ok := attributeChecker[key] - if !ok { - return nil, fmt.Errorf("unrecognized attribute %v", key) - } - if err := checker(attr); err != nil { - return nil, err - } for _, pp := range paramPrefix { if strings.HasPrefix(key, pp) { params[pp][key[len(pp):]] = attr diff --git a/pkg/sql/codegen/xgboost/codegen_test.go b/pkg/sql/codegen/xgboost/codegen_test.go index fedb3a8dff..8ab0d37b6a 100644 --- a/pkg/sql/codegen/xgboost/codegen_test.go +++ b/pkg/sql/codegen/xgboost/codegen_test.go @@ -49,9 +49,9 @@ func TestTrain(t *testing.T) { ValidationSelect: "select * from iris.test;", Estimator: "xgboost.gbtree", Attributes: map[string]interface{}{ - "train.num_boost_round": 30, + "train.num_boost_round": 10, "objective": "multi:softprob", - "eta": 3.1, + "eta": float32(0.1), "num_class": 3}, Features: map[string][]codegen.FeatureColumn{ "feature_columns": { @@ -60,6 +60,7 @@ func TestTrain(t *testing.T) { &codegen.NumericColumn{&codegen.FieldMeta{"petal_length", codegen.Float, "", []int{1}, false, nil}}, &codegen.NumericColumn{&codegen.FieldMeta{"petal_width", codegen.Float, "", []int{1}, false, nil}}}}, Label: &codegen.NumericColumn{&codegen.FieldMeta{"class", codegen.Int, "", []int{1}, false, nil}}} - _, err := Train(ir) + program, err := Train(ir) + fmt.Println(program) a.NoError(err) } From 37939cb543c39b32e10268c300c54dacdd2ee203 Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Tue, 15 Oct 2019 11:01:47 -0700 Subject: [PATCH 3/3] clean up --- pkg/sql/codegen/xgboost/codegen_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/sql/codegen/xgboost/codegen_test.go b/pkg/sql/codegen/xgboost/codegen_test.go index 8ab0d37b6a..99184781ce 100644 --- a/pkg/sql/codegen/xgboost/codegen_test.go +++ b/pkg/sql/codegen/xgboost/codegen_test.go @@ -60,7 +60,6 @@ func TestTrain(t *testing.T) { &codegen.NumericColumn{&codegen.FieldMeta{"petal_length", codegen.Float, "", []int{1}, false, nil}}, &codegen.NumericColumn{&codegen.FieldMeta{"petal_width", codegen.Float, "", []int{1}, false, nil}}}}, Label: &codegen.NumericColumn{&codegen.FieldMeta{"class", codegen.Int, "", []int{1}, false, nil}}} - program, err := Train(ir) - fmt.Println(program) + _, err := Train(ir) a.NoError(err) }