diff --git a/sql/codegen_ant_xgboost.go b/sql/codegen_ant_xgboost.go index 4eee99804c..dcc2302da9 100644 --- a/sql/codegen_ant_xgboost.go +++ b/sql/codegen_ant_xgboost.go @@ -795,7 +795,7 @@ func xgCreatePredictionTable(pr *extendedSelect, r *antXGBoostFiller, db *DB) er return nil } -func genXG(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error { +func genAntXGBoost(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error { r, e := newAntXGBoostFiller(pr, ds, db) if e != nil { return e diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go new file mode 100644 index 0000000000..3822ae7e67 --- /dev/null +++ b/sql/codegen_xgboost.go @@ -0,0 +1,98 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "fmt" + "io" + "text/template" +) + +type xgbTrainConfig struct { + NumBoostRound int `json:"num_boost_round,omitempty"` + Maximize bool `json:"maximize,omitempty"` +} + +type xgbFiller struct { + IsTrain bool + TrainingDatasetSQL string + ValidationDatasetSQL string + TrainCfg *xgbTrainConfig + Features []*featureMeta + Label *featureMeta + Save string + ParamsCfgJSON string + TrainCfgJSON string + *connectionConfig +} + +func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (*xgbFiller, error) { + var err error + training, validation := trainingAndValidationDataset(pr, ds) + r := &xgbFiller{ + IsTrain: pr.train, + TrainingDatasetSQL: training, + ValidationDatasetSQL: validation, + Save: pr.save, + } + // TODO(Yancey1989): fill the train_args and parameters by WITH statment + r.TrainCfgJSON = "" + r.ParamsCfgJSON = "" + + if r.connectionConfig, err = newConnectionConfig(db); err != nil { + return nil, err + } + + for _, columns := range pr.columns { + feaCols, colSpecs, err := resolveTrainColumns(&columns) + if err != nil { + return nil, err + } + if len(colSpecs) != 0 { + return nil, fmt.Errorf("newXGBoostFiller doesn't support DENSE/SPARSE") + } + for _, col := range feaCols { + fm := &featureMeta{ + FeatureName: col.GetKey(), + Dtype: col.GetDtype(), + Delimiter: col.GetDelimiter(), + InputShape: col.GetInputShape(), + IsSparse: false, + } + r.Features = append(r.Features, fm) + } + } + r.Label = &featureMeta{ + FeatureName: pr.label, + Dtype: "int32", + Delimiter: ",", + InputShape: "[1]", + IsSparse: false, + } + + return r, nil +} + +func genXGBoost(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error { + r, e := newXGBFiller(pr, ds, fts, db) + if e != nil { + return e + } + if pr.train { + return xgbTrainTemplate.Execute(w, r) + } + return fmt.Errorf("xgboost prediction codegen has not been implemented") +} + +var xgbTrainTemplate = template.Must(template.New("codegenXGBTrain").Parse(xgbTrainTemplateText)) diff --git a/sql/codegen_xgboost_test.go b/sql/codegen_xgboost_test.go new file mode 100644 index 0000000000..b1aa40351c --- /dev/null +++ b/sql/codegen_xgboost_test.go @@ -0,0 +1,25 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +const testXGBoostTrainSelectIris = ` +SELECT * +FROM iris.train +TRAIN xgb.multi.softprob +WITH + train.num_boost_round = 30 +COLUMN sepal_length, sepal_width, petal_length, petal_width +LABEL class +INTO sqlflow_models.my_xgboost_model; +` diff --git a/sql/executor.go b/sql/executor.go index 17187c56dd..c7c7a60011 100644 --- a/sql/executor.go +++ b/sql/executor.go @@ -387,8 +387,15 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri var program bytes.Buffer if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) { // TODO(sperlingxx): write a separate train pipeline for ant-xgboost to support remote mode - if e := genXG(&program, tr, ds, fts, db); e != nil { - return fmt.Errorf("genXG %v", e) + if e := genAntXGBoost(&program, tr, ds, fts, db); e != nil { + return fmt.Errorf("genAntXGBoost %v", e) + } + } else if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGB.`) { + // FIXME(Yancey1989): it's a temporary solution, just for the unit test, we perfer to distinguish + // xgboost and ant-xgboost with env SQLFLOW_WITH_ANTXGBOOST, + // issue: https://github.com/sql-machine-learning/sqlflow/issues/758 + if e := genXGBoost(&program, tr, ds, fts, db); e != nil { + return fmt.Errorf("GenXGBoost %v", e) } } else { if e := genTF(&program, tr, ds, fts, db); e != nil { @@ -453,8 +460,8 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin var buf bytes.Buffer if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) { // TODO(sperlingxx): write a separate pred pipeline for ant-xgboost to support remote mode - if e := genXG(&buf, pr, nil, fts, db); e != nil { - return fmt.Errorf("genXG %v", e) + if e := genAntXGBoost(&buf, pr, nil, fts, db); e != nil { + return fmt.Errorf("genAntXGBoost %v", e) } } else { if e := genTF(&buf, pr, nil, fts, db); e != nil { diff --git a/sql/executor_test.go b/sql/executor_test.go index bf65b9d548..a360be58e8 100644 --- a/sql/executor_test.go +++ b/sql/executor_test.go @@ -88,6 +88,16 @@ func TestExecutorTrainAnalyzePredictAntXGBoost(t *testing.T) { }) } +func TestExecutorTrainXGBoost(t *testing.T) { + a := assert.New(t) + modelDir := "" + a.NotPanics(func() { + stream := runExtendedSQL(testXGBoostTrainSelectIris, testDB, modelDir, nil) + a.True(goodStream(stream.ReadAll())) + + }) +} + func TestExecutorTrainAndPredictDNN(t *testing.T) { a := assert.New(t) modelDir := "" diff --git a/sql/template_xgboost.go b/sql/template_xgboost.go new file mode 100644 index 0000000000..7b3776f900 --- /dev/null +++ b/sql/template_xgboost.go @@ -0,0 +1,82 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +const xgbTrainTemplateText = ` +import xgboost as xgb +from sqlflow_submitter.db import connect, db_generator + +driver="{{.Driver}}" + +{{if ne .Database ""}} +database="{{.Database}}" +{{else}} +database="" +{{end}} + +session_cfg = {} +{{ range $k, $v := .Session }} +session_cfg["{{$k}}"] = "{{$v}}" +{{end}} + +{{if ne .TrainCfgJSON ""}} +train_args = {{.TrainCfgJSON}} +{{else}} +train_args = {} +{{end}} + +{{if ne .ParamsCfgJSON ""}} +params = {{.ParamsCfgJSON}} +{{else}} +params = {} +{{end}} + +feature_column_names = [{{range .Features}} +"{{.FeatureName}}", +{{end}}] + +{{/* Convert go side featureSpec to python dict for input_fn */}} +feature_specs = dict() +{{ range $value := .Features }} +feature_specs["{{$value.FeatureName}}"] = { + "feature_name": "{{$value.FeatureName}}", + "dtype": "{{$value.Dtype}}", + "delimiter": "{{$value.Delimiter}}", + "shape": {{$value.InputShape}}, + "is_sparse": "{{$value.IsSparse}}" == "true" +} +{{end}} + + + +conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}") + +def xgb_dataset(fn, dataset_sql): + gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "{{.Label.FeatureName}}", feature_specs) + with open(fn, 'w') as f: + for item in gen(): + features, label = item + row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)] + f.write("\t".join(row_data) + "\n") + # TODO(yancey1989): genearte group and weight text file if necessary + return xgb.DMatrix(fn) + +dtrain = xgb_dataset('train.txt', "{{.TrainingDatasetSQL}}") +dtest = xgb_dataset('test.txt', "{{.ValidationDatasetSQL}}") + +#TODO(Yancey1989): specify the eval metrics by WITH statement in SQL +train_args["evals"] = [(dtest, "auc")] +bst = xgb.train(params, dtrain, **train_args) +bst.save_model("{{.Save}}") +`