From 0e16b2df3f7b18d9a584e73d74d6c23c1e9c0588 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?=
Date: Mon, 9 Sep 2019 11:10:23 +0800
Subject: [PATCH] move xgboost objective to attr

---
 doc/xgboost_on_sqlflow_design.md | 16 ++++++++--------
 sql/codegen_xgboost.go           | 22 ++++++++++++++--------
 sql/codegen_xgboost_test.go      |  3 ++-
 3 files changed, 24 insertions(+), 17 deletions(-)

diff --git a/doc/xgboost_on_sqlflow_design.md b/doc/xgboost_on_sqlflow_design.md
index d9f5841170..9f519746b5 100644
--- a/doc/xgboost_on_sqlflow_design.md
+++ b/doc/xgboost_on_sqlflow_design.md
@@ -10,8 +10,9 @@ To explain the benefit of integrating XGBoost with SQLFlow, let us start with an
 ``` sql
 SELECT *
 FROM train_table
-TRAIN xgboost.multi.softmax
+TRAIN xgboost.gbtree
 WITH
+    objective="multi:softmax",
     train.num_round=2,
     max_depth=2,
     eta=1
@@ -29,10 +30,9 @@ USING my_xgb_model;
-The the above examples,
+In the above examples,
 
 - `my_xgb_model` names the trained model.
-- `xgboost.multi.softmax` is the model spec, where
-  - the prefix `xgboost.` tells the model is a XGBoost one, but not a Tensorflow model, and
-  - `multi.softmax` names an [XGBoost learning task](https://xgboost.readthedocs.io/en/latest/parameter.html#learning-task-parameters).
-- In the `WITH` clause,
-  - keys with the prefix `train.` identifies parameters of XGBoost API [`xgboost.train`](https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.train), and
-  - keys without any prefix identifies [XGBoost Parameters](https://xgboost.readthedocs.io/en/latest/parameter.html) except the `objective` parameter, which was specified by the identifier after the keyword `TRAIN`, as explained above.
+- `xgboost.gbtree` is the model name. To use a different model provided by XGBoost, write `xgboost.gblinear` or `xgboost.dart`; see the [general parameters](https://xgboost.readthedocs.io/en/latest/parameter.html#general-parameters) for details.
+- In the `WITH` clause,
+  - `objective` names an [XGBoost learning task](https://xgboost.readthedocs.io/en/latest/parameter.html#learning-task-parameters),
+  - keys with the prefix `train.` identify parameters of the XGBoost API [`xgboost.train`](https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.train), and
+  - keys without any prefix identify [XGBoost Parameters](https://xgboost.readthedocs.io/en/latest/parameter.html).
 
diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go
index ec3305d1a0..ad76e47889 100644
--- a/sql/codegen_xgboost.go
+++ b/sql/codegen_xgboost.go
@@ -76,12 +76,15 @@ func resolveParamsCfg(attrs map[string]*attribute) (map[string]interface{}, erro
 	return params, nil
 }
 
-func resolveObjective(pr *extendedSelect) (string, error) {
+func resolveModelName(pr *extendedSelect) (string, error) {
 	estimatorParts := strings.Split(pr.estimator, ".")
-	if len(estimatorParts) != 3 {
-		return "", fmt.Errorf("XGBoost Estimator should be xgboost.first_part.second_part, current: %s", pr.estimator)
+	if len(estimatorParts) != 2 {
+		return "", fmt.Errorf("XGBoost estimator should be xgboost.model_name, current: %s", pr.estimator)
 	}
-	return strings.Join(estimatorParts[1:], ":"), nil
+	if strings.ToUpper(estimatorParts[1]) != "GBTREE" {
+		return "", fmt.Errorf("model name %s is not supported yet", estimatorParts[1])
+	}
+	return estimatorParts[1], nil
 }
 
 func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFiller, error) {
@@ -109,18 +112,21 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFille
 	}
 
 	if isTrain {
-		// resolve the attribute keys without any prefix as the XGBoost Paremeters
+		objective := getStringAttr(attrs, "objective", "reg:squarederror")
+		// resolve the attribute keys without any prefix as the XGBoost Parameters
 		params, err := resolveParamsCfg(attrs)
 		if err != nil {
 			return nil, err
 		}
+		params["objective"] = objective
 
-		// fill learning target
-		objective, err := resolveObjective(pr)
+		// get the model name, which could be gbtree, gblinear, or dart.
+		// TODO(typhoonzero): only gbtree is supported for now; use the model name
+		// to generate different training code.
+		_, err = resolveModelName(pr)
 		if err != nil {
 			return nil, err
 		}
-		params["objective"] = objective
 
 		paramsJSON, err := json.Marshal(params)
 		if err != nil {
diff --git a/sql/codegen_xgboost_test.go b/sql/codegen_xgboost_test.go
index 2036bdebc0..4ae958ff38 100644
--- a/sql/codegen_xgboost_test.go
+++ b/sql/codegen_xgboost_test.go
@@ -23,8 +23,9 @@ import (
 const testXGBoostTrainSelectIris = `
 SELECT *
 FROM iris.train
-TRAIN xgb.multi.softprob
+TRAIN xgb.gbtree
 WITH
+	objective="multi:softprob",
 	train.num_boost_round = 30,
 	eta = 3.1,
 	num_class = 3
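One way to resolve the TODO above without per-model code generation: `gbtree`, `gblinear`, and `dart` are exactly the legal values of XGBoost's `booster` [general parameter](https://xgboost.readthedocs.io/en/latest/parameter.html#general-parameters), so the resolved model name could be validated against that set and passed through as `params["booster"]` instead of being discarded. The standalone sketch below illustrates the idea; `resolveBooster` is a hypothetical helper, not part of this patch or the SQLFlow codebase.

```go
package main

import (
	"fmt"
	"strings"
)

// resolveBooster is a hypothetical variant of resolveModelName: instead of
// rejecting everything except gbtree, it accepts any of the booster values
// XGBoost actually ships and returns the name so that the caller can set
// params["booster"] directly.
func resolveBooster(estimator string) (string, error) {
	parts := strings.Split(estimator, ".")
	if len(parts) != 2 {
		return "", fmt.Errorf("XGBoost estimator should be xgboost.model_name, current: %s", estimator)
	}
	switch name := strings.ToLower(parts[1]); name {
	case "gbtree", "gblinear", "dart":
		return name, nil
	default:
		return "", fmt.Errorf("model name %s is not supported yet", parts[1])
	}
}

func main() {
	// The old three-part form is rejected, matching the new two-part syntax.
	for _, e := range []string{"xgboost.gbtree", "xgboost.dart", "xgboost.multi.softmax"} {
		name, err := resolveBooster(e)
		fmt.Println(e, "->", name, err)
	}
}
```

With something like this in place, `newXGBFiller` could keep a single code path for all three model names, since the booster choice is just another entry in the params map handed to `xgboost.train`.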