Merged
10 changes: 5 additions & 5 deletions doc/xgboost_on_sqlflow_design.md
@@ -10,8 +10,9 @@ To explain the benefit of integrating XGBoost with SQLFlow, let us start with an

``` sql
SELECT * FROM train_table
-TRAIN xgboost.multi.softmax
+TRAIN xgboost.gbtree
WITH
+objective=multi:softmax,
train.num_round=2,
max_depth=2,
eta=1
@@ -29,10 +30,9 @@ USING my_xgb_model;

In the above examples,
- `my_xgb_model` names the trained model.
-- `xgboost.multi.softmax` is the model spec, where
-- the prefix `xgboost.` tells the model is a XGBoost one, but not a Tensorflow model, and
-- `multi.softmax` names an [XGBoost learning task](https://xgboost.readthedocs.io/en/latest/parameter.html#learning-task-parameters).
-- In the `WITH` clause,
+- `xgboost.gbtree` is the model name; to use a different model provided by XGBoost, use `xgboost.gblinear` or `xgboost.dart`. See [here](https://xgboost.readthedocs.io/en/latest/parameter.html#general-parameters) for details.
+- In the `WITH` clause,
+- `objective` names an [XGBoost learning task](https://xgboost.readthedocs.io/en/latest/parameter.html#learning-task-parameters),
- keys with the prefix `train.` identify parameters of the XGBoost API [`xgboost.train`](https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.train), and
- keys without any prefix identify [XGBoost Parameters](https://xgboost.readthedocs.io/en/latest/parameter.html), as sketched below.
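
To make the mapping concrete, here is a minimal Python sketch of the training call that the `TRAIN` statement above could translate to. It is only an illustration, not the code SQLFlow generates: the `DMatrix` data source, the `num_class` value, and the mapping of `train.num_round` to the `num_boost_round` argument are assumptions.

``` python
import xgboost as xgb

# Keys without a prefix become XGBoost Parameters.
params = {
    "objective": "multi:softmax",
    "max_depth": 2,
    "eta": 1,
    "num_class": 3,  # assumed here; multi:softmax requires the number of classes
}

# Training data comes from the SELECT result; the file name is illustrative.
dtrain = xgb.DMatrix("train_table.libsvm")

# Keys with the `train.` prefix become arguments of xgboost.train;
# train.num_round=2 is assumed to map to num_boost_round.
booster = xgb.train(params, dtrain, num_boost_round=2)
booster.save_model("my_xgb_model")
```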

20 changes: 13 additions & 7 deletions sql/codegen_xgboost.go
@@ -76,12 +76,15 @@ func resolveParamsCfg(attrs map[string]*attribute) (map[string]interface{}, erro
return params, nil
}

-func resolveObjective(pr *extendedSelect) (string, error) {
+func resolveModelName(pr *extendedSelect) (string, error) {
estimatorParts := strings.Split(pr.estimator, ".")
-if len(estimatorParts) != 3 {
-return "", fmt.Errorf("XGBoost Estimator should be xgboost.first_part.second_part, current: %s", pr.estimator)
+if len(estimatorParts) != 2 {
+return "", fmt.Errorf("XGBoost Estimator should be xgboost.modelname, current: %s", pr.estimator)
}
-return strings.Join(estimatorParts[1:], ":"), nil
+if strings.ToUpper(estimatorParts[1]) != "GBTREE" {
+return "", fmt.Errorf("model name %s is not supported yet", estimatorParts[1])
+}
+return estimatorParts[1], nil
}

func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFiller, error) {
Expand Down Expand Up @@ -109,18 +112,21 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFille
}

if isTrain {
+objective := getStringAttr(attrs, "objective", "gbtree")
// resolve the attribute keys without any prefix as the XGBoost Parameters
params, err := resolveParamsCfg(attrs)
if err != nil {
return nil, err
}
+params["objective"] = objective

-// fill learning target
-objective, err := resolveObjective(pr)
+// get model name, could be gbtree, gblinear or dart.
+// TODO(typhoonzero): only gbtree is supported here, use model name to generate
+// different training code.
+_, err = resolveModelName(pr)
if err != nil {
return nil, err
}
-params["objective"] = objective

paramsJSON, err := json.Marshal(params)
if err != nil {
3 changes: 2 additions & 1 deletion sql/codegen_xgboost_test.go
@@ -23,8 +23,9 @@ import (
const testXGBoostTrainSelectIris = `
SELECT *
FROM iris.train
-TRAIN xgboost.multi.softprob
+TRAIN xgboost.gbtree
WITH
objective="multi:softprob",
train.num_boost_round = 30,
eta = 3.1,
num_class = 3