diff --git a/doc/design/design_xgboost_on_sqlflow.md b/doc/design/design_xgboost_on_sqlflow.md index 94ca1766a7..27fc570572 100644 --- a/doc/design/design_xgboost_on_sqlflow.md +++ b/doc/design/design_xgboost_on_sqlflow.md @@ -13,7 +13,7 @@ SELECT * FROM train_table TRAIN xgboost.gbtree WITH objective=multi:softmax, - train.num_round=2, + train.num_boost_round=2, max_depth=2, eta=1 LABEL class diff --git a/pkg/sql/codegen/xgboost/codegen.go b/pkg/sql/codegen/xgboost/codegen.go index cf2cd0ae46..b3e2a121d7 100644 --- a/pkg/sql/codegen/xgboost/codegen.go +++ b/pkg/sql/codegen/xgboost/codegen.go @@ -17,42 +17,33 @@ import ( "bytes" "encoding/json" "fmt" + "sqlflow.org/sqlflow/pkg/sql/codegen/attribute" "strings" "sqlflow.org/sqlflow/pkg/sql/codegen" ) -var attributeChecker = map[string]func(interface{}) error{ - "eta": func(x interface{}) error { - switch x.(type) { - case float32, float64: - return nil - default: - return fmt.Errorf("eta should be of type float, received %T", x) - } - }, - "num_class": func(x interface{}) error { - switch x.(type) { - case int, int32, int64: - return nil - default: - return fmt.Errorf("num_class should be of type int, received %T", x) - } - }, - "train.num_boost_round": func(x interface{}) error { - switch x.(type) { - case int, int32, int64: - return nil - default: - return fmt.Errorf("train.num_boost_round should be of type int, received %T", x) - } - }, - "objective": func(x interface{}) error { - if _, ok := x.(string); !ok { - return fmt.Errorf("objective should be of type string, received %T", x) - } - return nil - }, +func newFloat32(f float32) *float32 { + return &f +} + +func newInt(i int) *int { + return &i +} + +// TODO(tony): complete model parameter and training parameter list +// model parameter list: https://xgboost.readthedocs.io/en/latest/parameter.html#general-parameters +// training parameter list: https://github.com/dmlc/xgboost/blob/b61d53447203ca7a321d72f6bdd3f553a3aa06c4/python-package/xgboost/training.py#L115-L117 +var attributeDictionary = attribute.Dictionary{ + "eta": {attribute.Float, `[default=0.3, alias: learning_rate] +Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features, and eta shrinks the feature weights to make the boosting process more conservative. +range: [0,1]`, attribute.Float32RangeChecker(newFloat32(0), newFloat32(1), true, true)}, + "num_class": {attribute.Int, `Number of classes. +range: [1, Infinity]`, attribute.IntRangeChecker(newInt(0), nil, false, false)}, + "objective": {attribute.String, `Learning objective`, nil}, + "train.num_boost_round": {attribute.Int, `[default=10] +The number of rounds for boosting. +range: [1, Infinity]`, attribute.IntRangeChecker(newInt(0), nil, false, false)}, } func resolveModelType(estimator string) (string, error) { @@ -69,22 +60,13 @@ func resolveModelType(estimator string) (string, error) { } func parseAttribute(attrs map[string]interface{}) (map[string]map[string]interface{}, error) { - attrNames := map[string]bool{} + if err := attributeDictionary.Validate(attrs); err != nil { + return nil, err + } params := map[string]map[string]interface{}{"": {}, "train.": {}} - paramPrefix := []string{"train.", ""} // use slice to assure traverse order + paramPrefix := []string{"train.", ""} // use slice to assure traverse order, this is necessary because all string starts with "" for key, attr := range attrs { - if _, ok := attrNames[key]; ok { - return nil, fmt.Errorf("duplicated attribute %s", key) - } - attrNames[key] = true - checker, ok := attributeChecker[key] - if !ok { - return nil, fmt.Errorf("unrecognized attribute %v", key) - } - if err := checker(attr); err != nil { - return nil, err - } for _, pp := range paramPrefix { if strings.HasPrefix(key, pp) { params[pp][key[len(pp):]] = attr diff --git a/pkg/sql/codegen/xgboost/codegen_test.go b/pkg/sql/codegen/xgboost/codegen_test.go index fedb3a8dff..99184781ce 100644 --- a/pkg/sql/codegen/xgboost/codegen_test.go +++ b/pkg/sql/codegen/xgboost/codegen_test.go @@ -49,9 +49,9 @@ func TestTrain(t *testing.T) { ValidationSelect: "select * from iris.test;", Estimator: "xgboost.gbtree", Attributes: map[string]interface{}{ - "train.num_boost_round": 30, + "train.num_boost_round": 10, "objective": "multi:softprob", - "eta": 3.1, + "eta": float32(0.1), "num_class": 3}, Features: map[string][]codegen.FeatureColumn{ "feature_columns": {