From e359497c94725cda8cd82c47b3103fa061adbdda Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 3 Sep 2019 13:49:47 +0800 Subject: [PATCH 1/6] initialize xgboost codegen --- sql/codegen_ant_xgboost.go | 2 +- sql/codegen_xgboost.go | 116 +++++++++++++++++++++++++++++++++ sql/codegen_xgboost_test.go | 25 +++++++ sql/executor.go | 15 +++-- sql/executor_test.go | 10 +++ sql/expression_resolver_xgb.go | 64 ++++++++++++++++++ sql/template_xgboost.go | 82 +++++++++++++++++++++++ 7 files changed, 309 insertions(+), 5 deletions(-) create mode 100644 sql/codegen_xgboost.go create mode 100644 sql/codegen_xgboost_test.go create mode 100644 sql/expression_resolver_xgb.go create mode 100644 sql/template_xgboost.go diff --git a/sql/codegen_ant_xgboost.go b/sql/codegen_ant_xgboost.go index 52c57a516e..cb10777ab6 100644 --- a/sql/codegen_ant_xgboost.go +++ b/sql/codegen_ant_xgboost.go @@ -790,7 +790,7 @@ func xgCreatePredictionTable(pr *extendedSelect, r *antXGBoostFiller, db *DB) er return nil } -func genXG(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error { +func genAntXGboost(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error { r, e := newAntXGBoostFiller(pr, ds, db) if e != nil { return e diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go new file mode 100644 index 0000000000..879a7ea68c --- /dev/null +++ b/sql/codegen_xgboost.go @@ -0,0 +1,116 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "fmt" + "io" + "text/template" +) + +type xgbTrainConfig struct { + NumBoostRound int `json:"num_boost_round,omitempty"` + Maximize bool `json:"maximize,omitempty"` +} + +type xgbFiller struct { + IsTrain bool + TrainingDatasetSQL string + ValidationDatasetSQL string + TrainCfg *xgbTrainConfig + Features []*featureMeta + Label *featureMeta + ParamsCfgJSON string + TrainCfgJSON string + *connectionConfig +} + +func fillXGBTrainCfg(rt *resolvedXGBTrainClause) (*xgbTrainConfig, error) { + // TODO(Yancey1989): fill all the training control parameters + c := &xgbTrainConfig{ + NumBoostRound: rt.NumBoostRound, + Maximize: rt.Maximize, + } + return c, nil +} + +func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (*xgbFiller, error) { + rt, err := resolveXGBTrainClause(&pr.trainClause) + training, validation := trainingAndValidationDataset(pr, ds) + if err != nil { + return nil, err + } + + trainCfg, err := fillXGBTrainCfg(rt) + if err != nil { + return nil, err + } + + r := &xgbFiller{ + IsTrain: pr.train, + TrainCfg: trainCfg, + TrainingDatasetSQL: training, + ValidationDatasetSQL: validation, + } + // TODO(Yancey1989): fill the train_args and parameters by WITH statment + r.TrainCfgJSON = "" + r.ParamsCfgJSON = "" + + if r.connectionConfig, err = newConnectionConfig(db); err != nil { + return nil, err + } + + for _, columns := range pr.columns { + feaCols, colSpecs, err := resolveTrainColumns(&columns) + if err != nil { + return nil, err + } + if len(colSpecs) != 0 { + return nil, fmt.Errorf("newFiller doesn't support DENSE/SPARSE") + } + for _, col := range feaCols { + fm := &featureMeta{ + FeatureName: col.GetKey(), + Dtype: col.GetDtype(), + Delimiter: col.GetDelimiter(), + InputShape: col.GetInputShape(), + IsSparse: false, + } + r.Features = append(r.Features, fm) + } + } + r.Label = &featureMeta{ + FeatureName: pr.label, + Dtype: "int32", + Delimiter: ",", + InputShape: "[1]", + IsSparse: false, + } + + return r, nil +} + +func genXGBoost(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error { + r, e := newXGBFiller(pr, ds, fts, db) + if e != nil { + return e + } + if pr.train { + fmt.Println(r.TrainCfgJSON) + return xgbTrainTemplate.Execute(w, r) + } + return fmt.Errorf("xgboost prediction codegen has not been implemented") +} + +var xgbTrainTemplate = template.Must(template.New("codegenXGBTrain").Parse(xgbTrainTemplateText)) diff --git a/sql/codegen_xgboost_test.go b/sql/codegen_xgboost_test.go new file mode 100644 index 0000000000..b1aa40351c --- /dev/null +++ b/sql/codegen_xgboost_test.go @@ -0,0 +1,25 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +const testXGBoostTrainSelectIris = ` +SELECT * +FROM iris.train +TRAIN xgb.multi.softprob +WITH + train.num_boost_round = 30 +COLUMN sepal_length, sepal_width, petal_length, petal_width +LABEL class +INTO sqlflow_models.my_xgboost_model; +` diff --git a/sql/executor.go b/sql/executor.go index 17187c56dd..ea9435eeb8 100644 --- a/sql/executor.go +++ b/sql/executor.go @@ -387,8 +387,15 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri var program bytes.Buffer if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) { // TODO(sperlingxx): write a separate train pipeline for ant-xgboost to support remote mode - if e := genXG(&program, tr, ds, fts, db); e != nil { - return fmt.Errorf("genXG %v", e) + if e := genAntXGboost(&program, tr, ds, fts, db); e != nil { + return fmt.Errorf("genAntXGBoost %v", e) + } + } else if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGB.`) { + // FIXME(Yancey1989): it's a temporary solution, just for the unit test, we perfer to distinguish + // xgboost and ant-xgboost with env SQLFLOW_WITH_ANTXGBOOST, + // issue: https://github.com/sql-machine-learning/sqlflow/issues/758 + if e := genXGBoost(&program, tr, ds, fts, db); e != nil { + return fmt.Errorf("GenXGBoost %v", e) } } else { if e := genTF(&program, tr, ds, fts, db); e != nil { @@ -453,8 +460,8 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin var buf bytes.Buffer if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) { // TODO(sperlingxx): write a separate pred pipeline for ant-xgboost to support remote mode - if e := genXG(&buf, pr, nil, fts, db); e != nil { - return fmt.Errorf("genXG %v", e) + if e := genAntXGboost(&buf, pr, nil, fts, db); e != nil { + return fmt.Errorf("genAntXGBoost %v", e) } } else { if e := genTF(&buf, pr, nil, fts, db); e != nil { diff --git a/sql/executor_test.go b/sql/executor_test.go index bf65b9d548..a360be58e8 100644 --- a/sql/executor_test.go +++ b/sql/executor_test.go @@ -88,6 +88,16 @@ func TestExecutorTrainAnalyzePredictAntXGBoost(t *testing.T) { }) } +func TestExecutorTrainXGBoost(t *testing.T) { + a := assert.New(t) + modelDir := "" + a.NotPanics(func() { + stream := runExtendedSQL(testXGBoostTrainSelectIris, testDB, modelDir, nil) + a.True(goodStream(stream.ReadAll())) + + }) +} + func TestExecutorTrainAndPredictDNN(t *testing.T) { a := assert.New(t) modelDir := "" diff --git a/sql/expression_resolver_xgb.go b/sql/expression_resolver_xgb.go new file mode 100644 index 0000000000..d44cbcd41e --- /dev/null +++ b/sql/expression_resolver_xgb.go @@ -0,0 +1,64 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "fmt" + "strconv" +) + +type resolvedXGBTrainClause struct { + NumBoostRound int + Maximize bool + ParamsAttr map[string]*attribute +} + +func resolveXGBTrainClause(tc *trainClause) (*resolvedXGBTrainClause, error) { + attrs, err := resolveAttribute(&tc.trainAttrs) + if err != nil { + return nil, err + } + getIntAttr := func(key string, defaultValue int) int { + if p, ok := attrs[key]; ok { + strVal, _ := p.Value.(string) + intVal, err := strconv.Atoi(trimQuotes(strVal)) + defer delete(attrs, p.FullName) + if err == nil { + return intVal + } + fmt.Printf("ignore invalid %s=%s, default is %d", key, p.Value, defaultValue) + } + return defaultValue + } + getBoolAttr := func(key string, defaultValue bool, optional bool) bool { + if p, ok := attrs[key]; ok { + strVal, _ := p.Value.(string) + boolVal, err := strconv.ParseBool(trimQuotes(strVal)) + if !optional { + defer delete(attrs, p.FullName) + } + if err == nil { + return boolVal + } else if !optional { + fmt.Printf("ignore invalid %s=%s, default is %v", key, p.Value, defaultValue) + } + } + return defaultValue + } + return &resolvedXGBTrainClause{ + NumBoostRound: getIntAttr("train.num_boost_round", 10), + Maximize: getBoolAttr("train.maximize", false, true), + ParamsAttr: filter(attrs, "params", true), + }, nil +} diff --git a/sql/template_xgboost.go b/sql/template_xgboost.go new file mode 100644 index 0000000000..8cdf256ccf --- /dev/null +++ b/sql/template_xgboost.go @@ -0,0 +1,82 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +const xgbTrainTemplateText = ` +import xgboost as xgb +from sqlflow_submitter.db import connect, db_generator + +driver="{{.Driver}}" + +{{if ne .Database ""}} +database="{{.Database}}" +{{else}} +database="" +{{end}} + +session_cfg = {} +{{ range $k, $v := .Session }} +session_cfg["{{$k}}"] = "{{$v}}" +{{end}} + +{{if ne .TrainCfgJSON ""}} +train_args = {{.TrainCfgJSON}} +{{else}} +train_args = {} +{{end}} + +{{if ne .ParamsCfgJSON ""}} +params = {{.ParamsCfgJSON}} +{{else}} +params = {} +{{end}} + +feature_column_names = [{{range .Features}} +"{{.FeatureName}}", +{{end}}] + +{{/* Convert go side featureSpec to python dict for input_fn */}} +feature_specs = dict() +{{ range $value := .Features }} +feature_specs["{{$value.FeatureName}}"] = { + "feature_name": "{{$value.FeatureName}}", + "dtype": "{{$value.Dtype}}", + "delimiter": "{{$value.Delimiter}}", + "shape": {{$value.InputShape}}, + "is_sparse": "{{$value.IsSparse}}" == "true" +} +{{end}} + + + +conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}") + +def xgb_dataset(fn, dataset_sql): + gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "{{.Label.FeatureName}}", feature_specs) + with open(fn, 'w') as f: + for item in gen(): + features, label = item + row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)] + f.write("\t".join(row_data) + "\n") + # TODO(yancey1989): genearte group and weight text file if necessary + return xgb.DMatrix(fn) + +dtrain = xgb_dataset('train.txt', "{{.TrainingDatasetSQL}}") +dtest = xgb_dataset('test.txt', "{{.ValidationDatasetSQL}}") + +//TODO(Yancey1989): specify the eval metrics by WITH statement in SQL +train_args["evals"] = [(dtest, "auc")] +bst = xgb.train(params, dtrain, **train_args) +bst.save_model() +` From 50e703161daf2f0808aacba5c1fab0bb1fad9522 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 3 Sep 2019 13:50:38 +0800 Subject: [PATCH 2/6] initialize xgboost codegen --- sql/template_xgboost.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/template_xgboost.go b/sql/template_xgboost.go index 8cdf256ccf..bf1d4fc7c5 100644 --- a/sql/template_xgboost.go +++ b/sql/template_xgboost.go @@ -75,7 +75,7 @@ def xgb_dataset(fn, dataset_sql): dtrain = xgb_dataset('train.txt', "{{.TrainingDatasetSQL}}") dtest = xgb_dataset('test.txt', "{{.ValidationDatasetSQL}}") -//TODO(Yancey1989): specify the eval metrics by WITH statement in SQL +#TODO(Yancey1989): specify the eval metrics by WITH statement in SQL train_args["evals"] = [(dtest, "auc")] bst = xgb.train(params, dtrain, **train_args) bst.save_model() From e83530b431cc7ce64b5cc038632f250399c3c95b Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 3 Sep 2019 14:24:34 +0800 Subject: [PATCH 3/6] init xgboost codegen --- sql/codegen_xgboost.go | 4 +++- sql/expression_resolver_xgb.go | 2 +- sql/template_xgboost.go | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go index 879a7ea68c..5fec5673f2 100644 --- a/sql/codegen_xgboost.go +++ b/sql/codegen_xgboost.go @@ -31,6 +31,7 @@ type xgbFiller struct { TrainCfg *xgbTrainConfig Features []*featureMeta Label *featureMeta + Save string ParamsCfgJSON string TrainCfgJSON string *connectionConfig @@ -62,6 +63,7 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db TrainCfg: trainCfg, TrainingDatasetSQL: training, ValidationDatasetSQL: validation, + Save: pr.save, } // TODO(Yancey1989): fill the train_args and parameters by WITH statment r.TrainCfgJSON = "" @@ -77,7 +79,7 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db return nil, err } if len(colSpecs) != 0 { - return nil, fmt.Errorf("newFiller doesn't support DENSE/SPARSE") + return nil, fmt.Errorf("newXGBoostFiller doesn't support DENSE/SPARSE") } for _, col := range feaCols { fm := &featureMeta{ diff --git a/sql/expression_resolver_xgb.go b/sql/expression_resolver_xgb.go index d44cbcd41e..28b102eb9f 100644 --- a/sql/expression_resolver_xgb.go +++ b/sql/expression_resolver_xgb.go @@ -56,9 +56,9 @@ func resolveXGBTrainClause(tc *trainClause) (*resolvedXGBTrainClause, error) { } return defaultValue } + return &resolvedXGBTrainClause{ NumBoostRound: getIntAttr("train.num_boost_round", 10), Maximize: getBoolAttr("train.maximize", false, true), - ParamsAttr: filter(attrs, "params", true), }, nil } diff --git a/sql/template_xgboost.go b/sql/template_xgboost.go index bf1d4fc7c5..7b3776f900 100644 --- a/sql/template_xgboost.go +++ b/sql/template_xgboost.go @@ -78,5 +78,5 @@ dtest = xgb_dataset('test.txt', "{{.ValidationDatasetSQL}}") #TODO(Yancey1989): specify the eval metrics by WITH statement in SQL train_args["evals"] = [(dtest, "auc")] bst = xgb.train(params, dtrain, **train_args) -bst.save_model() +bst.save_model("{{.Save}}") ` From 545645ebcb3feb4111de43285a82f0d309600427 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 3 Sep 2019 14:50:56 +0800 Subject: [PATCH 4/6] fix typo --- sql/codegen_ant_xgboost.go | 2 +- sql/executor.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/codegen_ant_xgboost.go b/sql/codegen_ant_xgboost.go index da202f0377..dcc2302da9 100644 --- a/sql/codegen_ant_xgboost.go +++ b/sql/codegen_ant_xgboost.go @@ -795,7 +795,7 @@ func xgCreatePredictionTable(pr *extendedSelect, r *antXGBoostFiller, db *DB) er return nil } -func genAntXGboost(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error { +func genAntXGBoost(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error { r, e := newAntXGBoostFiller(pr, ds, db) if e != nil { return e diff --git a/sql/executor.go b/sql/executor.go index ea9435eeb8..c7c7a60011 100644 --- a/sql/executor.go +++ b/sql/executor.go @@ -387,7 +387,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri var program bytes.Buffer if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) { // TODO(sperlingxx): write a separate train pipeline for ant-xgboost to support remote mode - if e := genAntXGboost(&program, tr, ds, fts, db); e != nil { + if e := genAntXGBoost(&program, tr, ds, fts, db); e != nil { return fmt.Errorf("genAntXGBoost %v", e) } } else if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGB.`) { @@ -460,7 +460,7 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin var buf bytes.Buffer if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) { // TODO(sperlingxx): write a separate pred pipeline for ant-xgboost to support remote mode - if e := genAntXGboost(&buf, pr, nil, fts, db); e != nil { + if e := genAntXGBoost(&buf, pr, nil, fts, db); e != nil { return fmt.Errorf("genAntXGBoost %v", e) } } else { From 0b1d9a3a94ffbf02bb5df7a8874eecb72873a498 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 3 Sep 2019 19:53:03 +0800 Subject: [PATCH 5/6] remove unused code --- sql/codegen_xgboost.go | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go index 5fec5673f2..acf03e11bd 100644 --- a/sql/codegen_xgboost.go +++ b/sql/codegen_xgboost.go @@ -109,7 +109,6 @@ func genXGBoost(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fie return e } if pr.train { - fmt.Println(r.TrainCfgJSON) return xgbTrainTemplate.Execute(w, r) } return fmt.Errorf("xgboost prediction codegen has not been implemented") From 60ca03022ba26383461a42ea700acd734af100bc Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Wed, 4 Sep 2019 13:48:04 +0800 Subject: [PATCH 6/6] remove xgb resolver --- sql/codegen_xgboost.go | 21 +---------- sql/expression_resolver_xgb.go | 64 ---------------------------------- 2 files changed, 1 insertion(+), 84 deletions(-) delete mode 100644 sql/expression_resolver_xgb.go diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go index acf03e11bd..3822ae7e67 100644 --- a/sql/codegen_xgboost.go +++ b/sql/codegen_xgboost.go @@ -37,30 +37,11 @@ type xgbFiller struct { *connectionConfig } -func fillXGBTrainCfg(rt *resolvedXGBTrainClause) (*xgbTrainConfig, error) { - // TODO(Yancey1989): fill all the training control parameters - c := &xgbTrainConfig{ - NumBoostRound: rt.NumBoostRound, - Maximize: rt.Maximize, - } - return c, nil -} - func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (*xgbFiller, error) { - rt, err := resolveXGBTrainClause(&pr.trainClause) + var err error training, validation := trainingAndValidationDataset(pr, ds) - if err != nil { - return nil, err - } - - trainCfg, err := fillXGBTrainCfg(rt) - if err != nil { - return nil, err - } - r := &xgbFiller{ IsTrain: pr.train, - TrainCfg: trainCfg, TrainingDatasetSQL: training, ValidationDatasetSQL: validation, Save: pr.save, diff --git a/sql/expression_resolver_xgb.go b/sql/expression_resolver_xgb.go deleted file mode 100644 index 28b102eb9f..0000000000 --- a/sql/expression_resolver_xgb.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "fmt" - "strconv" -) - -type resolvedXGBTrainClause struct { - NumBoostRound int - Maximize bool - ParamsAttr map[string]*attribute -} - -func resolveXGBTrainClause(tc *trainClause) (*resolvedXGBTrainClause, error) { - attrs, err := resolveAttribute(&tc.trainAttrs) - if err != nil { - return nil, err - } - getIntAttr := func(key string, defaultValue int) int { - if p, ok := attrs[key]; ok { - strVal, _ := p.Value.(string) - intVal, err := strconv.Atoi(trimQuotes(strVal)) - defer delete(attrs, p.FullName) - if err == nil { - return intVal - } - fmt.Printf("ignore invalid %s=%s, default is %d", key, p.Value, defaultValue) - } - return defaultValue - } - getBoolAttr := func(key string, defaultValue bool, optional bool) bool { - if p, ok := attrs[key]; ok { - strVal, _ := p.Value.(string) - boolVal, err := strconv.ParseBool(trimQuotes(strVal)) - if !optional { - defer delete(attrs, p.FullName) - } - if err == nil { - return boolVal - } else if !optional { - fmt.Printf("ignore invalid %s=%s, default is %v", key, p.Value, defaultValue) - } - } - return defaultValue - } - - return &resolvedXGBTrainClause{ - NumBoostRound: getIntAttr("train.num_boost_round", 10), - Maximize: getBoolAttr("train.maximize", false, true), - }, nil -}