Skip to content

Commit f17364b

Browse files
Use pkg/sql/codegen/attribute in XGBoost TrainIR (#1005)
* Add attribute value range checker helper function * Use pkg/sql/codegen/attribute in XGBoost TrainIR * clean up
1 parent 32fb9df commit f17364b

File tree

3 files changed

+29
-47
lines changed

3 files changed

+29
-47
lines changed

doc/design/design_xgboost_on_sqlflow.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ SELECT * FROM train_table
1313
TO TRAIN xgboost.gbtree
1414
WITH
1515
objective=multi:softmax,
16-
train.num_round=2,
16+
train.num_boost_round=2,
1717
max_depth=2,
1818
eta=1
1919
LABEL class

pkg/sql/codegen/xgboost/codegen.go

Lines changed: 26 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -17,42 +17,33 @@ import (
1717
"bytes"
1818
"encoding/json"
1919
"fmt"
20+
"sqlflow.org/sqlflow/pkg/sql/codegen/attribute"
2021
"strings"
2122

2223
"sqlflow.org/sqlflow/pkg/sql/codegen"
2324
)
2425

25-
var attributeChecker = map[string]func(interface{}) error{
26-
"eta": func(x interface{}) error {
27-
switch x.(type) {
28-
case float32, float64:
29-
return nil
30-
default:
31-
return fmt.Errorf("eta should be of type float, received %T", x)
32-
}
33-
},
34-
"num_class": func(x interface{}) error {
35-
switch x.(type) {
36-
case int, int32, int64:
37-
return nil
38-
default:
39-
return fmt.Errorf("num_class should be of type int, received %T", x)
40-
}
41-
},
42-
"train.num_boost_round": func(x interface{}) error {
43-
switch x.(type) {
44-
case int, int32, int64:
45-
return nil
46-
default:
47-
return fmt.Errorf("train.num_boost_round should be of type int, received %T", x)
48-
}
49-
},
50-
"objective": func(x interface{}) error {
51-
if _, ok := x.(string); !ok {
52-
return fmt.Errorf("objective should be of type string, received %T", x)
53-
}
54-
return nil
55-
},
26+
func newFloat32(f float32) *float32 {
27+
return &f
28+
}
29+
30+
func newInt(i int) *int {
31+
return &i
32+
}
33+
34+
// TODO(tony): complete model parameter and training parameter list
35+
// model parameter list: https://xgboost.readthedocs.io/en/latest/parameter.html#general-parameters
36+
// training parameter list: https://github.com/dmlc/xgboost/blob/b61d53447203ca7a321d72f6bdd3f553a3aa06c4/python-package/xgboost/training.py#L115-L117
37+
var attributeDictionary = attribute.Dictionary{
38+
"eta": {attribute.Float, `[default=0.3, alias: learning_rate]
39+
Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features, and eta shrinks the feature weights to make the boosting process more conservative.
40+
range: [0,1]`, attribute.Float32RangeChecker(newFloat32(0), newFloat32(1), true, true)},
41+
"num_class": {attribute.Int, `Number of classes.
42+
range: [1, Infinity]`, attribute.IntRangeChecker(newInt(0), nil, false, false)},
43+
"objective": {attribute.String, `Learning objective`, nil},
44+
"train.num_boost_round": {attribute.Int, `[default=10]
45+
The number of rounds for boosting.
46+
range: [1, Infinity]`, attribute.IntRangeChecker(newInt(0), nil, false, false)},
5647
}
5748

5849
func resolveModelType(estimator string) (string, error) {
@@ -69,22 +60,13 @@ func resolveModelType(estimator string) (string, error) {
6960
}
7061

7162
func parseAttribute(attrs map[string]interface{}) (map[string]map[string]interface{}, error) {
72-
attrNames := map[string]bool{}
63+
if err := attributeDictionary.Validate(attrs); err != nil {
64+
return nil, err
65+
}
7366

7467
params := map[string]map[string]interface{}{"": {}, "train.": {}}
75-
paramPrefix := []string{"train.", ""} // use slice to assure traverse order
68+
paramPrefix := []string{"train.", ""} // use slice to assure traverse order, this is necessary because all string starts with ""
7669
for key, attr := range attrs {
77-
if _, ok := attrNames[key]; ok {
78-
return nil, fmt.Errorf("duplicated attribute %s", key)
79-
}
80-
attrNames[key] = true
81-
checker, ok := attributeChecker[key]
82-
if !ok {
83-
return nil, fmt.Errorf("unrecognized attribute %v", key)
84-
}
85-
if err := checker(attr); err != nil {
86-
return nil, err
87-
}
8870
for _, pp := range paramPrefix {
8971
if strings.HasPrefix(key, pp) {
9072
params[pp][key[len(pp):]] = attr

pkg/sql/codegen/xgboost/codegen_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ func TestTrain(t *testing.T) {
4949
ValidationSelect: "select * from iris.test;",
5050
Estimator: "xgboost.gbtree",
5151
Attributes: map[string]interface{}{
52-
"train.num_boost_round": 30,
52+
"train.num_boost_round": 10,
5353
"objective": "multi:softprob",
54-
"eta": 3.1,
54+
"eta": float32(0.1),
5555
"num_class": 3},
5656
Features: map[string][]codegen.FeatureColumn{
5757
"feature_columns": {

0 commit comments

Comments
 (0)