From dbe6a5e38f6e52ac055610d8dc6dbff078f2e27f Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 30 Aug 2019 19:25:51 +0800 Subject: [PATCH 1/3] refine codegen and codegen-prepare for analyze --- sql/codegen.go | 3 ++ sql/codegen_xgboost.go | 15 ++++++++ sql/executor.go | 85 ++++++++++++++++++++++++++++-------------- 3 files changed, 75 insertions(+), 28 deletions(-) diff --git a/sql/codegen.go b/sql/codegen.go index 17e500dcd5..5a053b3803 100644 --- a/sql/codegen.go +++ b/sql/codegen.go @@ -271,6 +271,9 @@ func genTF(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTyp if pr.train { return tfTrainTemplate.Execute(w, r) } + if e := createPredictionTable(pr, db); e != nil { + return fmt.Errorf("failed to create prediction table: %v", e) + } return tfPredTemplate.Execute(w, r) } diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go index a986ba3426..72a6e7a794 100644 --- a/sql/codegen_xgboost.go +++ b/sql/codegen_xgboost.go @@ -17,6 +17,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "os" "strconv" "strings" @@ -765,6 +766,20 @@ func xgCreatePredictionTable(pr *extendedSelect, r *xgboostFiller, db *DB) error return nil } +func genXG(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error { + r, e := newXGBoostFiller(pr, ds, fts, db) + if e != nil { + return e + } + if pr.train { + return xgTemplate.Execute(w, r) + } + if e := xgCreatePredictionTable(pr, r, db); e != nil { + return fmt.Errorf("failed to create prediction table: %v", e) + } + return xgTemplate.Execute(w, r) +} + var xgTemplate = template.Must(template.New("codegenXG").Parse(xgTemplateText)) const xgTemplateText = ` diff --git a/sql/executor.go b/sql/executor.go index f0894be187..fe01ded72b 100644 --- a/sql/executor.go +++ b/sql/executor.go @@ -366,6 +366,26 @@ func (cw *logChanWriter) Close() { } } +func buildFiller(es *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (filler interface{}, e error) { + // trainAndValDataset only work in train mode + var dataset *trainAndValDataset = nil + if es.train { + dataset = ds + } + if strings.HasPrefix(strings.ToUpper(es.estimator), `XGBOOST.`) { + filler, e = newXGBoostFiller(es, dataset, fts, db) + if e != nil { + e = fmt.Errorf("failed to build XGBoostFiller: %v", e) + } + } else { + filler, e = newFiller(es, dataset, fts, db) + if e != nil { + e = fmt.Errorf("failed to build TensorFlowFiller: %v", e) + } + } + return filler, e +} + func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir string, slct string, ds *trainAndValDataset) error { fts, e := verify(tr, db) if e != nil { @@ -374,12 +394,8 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri var program bytes.Buffer if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) { - // FIXME(sperlingxx): write a separate train pipeline for xgboost to support remote mode - filler, e := newXGBoostFiller(tr, ds, fts, db) - if e != nil { - return fmt.Errorf("genXG %v", e) - } - if e := xgTemplate.Execute(&program, filler); e != nil { + // TODO(sperlingxx): write a separate train pipeline for xgboost to support remote mode + if e := genXG(&program, tr, ds, fts, db); e != nil { return fmt.Errorf("genXG %v", e) } } else { @@ -404,7 +420,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri return m.save(db, tr.save) } -func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir string) error { +func loadModelMeta(pr *extendedSelect, db *DB, cwd string, modelDir string) (*extendedSelect, fieldTypes, error) { var m *model var e error if modelDir != "" { @@ -413,43 +429,42 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin m, e = load(db, pr.model, cwd) } if e != nil { - return fmt.Errorf("load %v", e) + return nil, nil, fmt.Errorf("load %v", e) } // Parse the training SELECT statement used to train // the model for the prediction. tr, e := newParser().Parse(m.TrainSelect) if e != nil { - return fmt.Errorf("parse: TrainSelect %v raise %v", m.TrainSelect, e) + return nil, nil, fmt.Errorf("parse: TrainSelect %v raise %v", m.TrainSelect, e) } if e := verifyColumnNameAndType(tr, pr, db); e != nil { - return fmt.Errorf("verifyColumnNameAndType: %v", e) + return nil, nil, fmt.Errorf("verifyColumnNameAndType: %v", e) } pr.trainClause = tr.trainClause fts, e := verify(pr, db) if e != nil { - return fmt.Errorf("verify: %v", e) + return nil, nil, fmt.Errorf("verify: %v", e) + } + + return pr, fts, nil +} + +func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir string) error { + pr, fts, e := loadModelMeta(pr, db, cwd, modelDir) + if e != nil { + return fmt.Errorf("loadModelMeta %v", e) } var buf bytes.Buffer - if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) { - // FIXME(sperlingxx): write a separate pred pipeline for xgboost to support remote mode - filler, e := newXGBoostFiller(pr, nil, fts, db) - if e != nil { - return fmt.Errorf("genXG %v", e) - } - if e := xgCreatePredictionTable(pr, filler, db); e != nil { - return fmt.Errorf("genXG %v", e) - } - if e := xgTemplate.Execute(&buf, filler); e != nil { + if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) { + // TODO(sperlingxx): write a separate pred pipeline for xgboost to support remote mode + if e := genXG(&buf, pr, nil, fts, db); e != nil { return fmt.Errorf("genXG %v", e) } } else { - if e := createPredictionTable(tr, pr, db); e != nil { - return fmt.Errorf("createPredictionTable: %v", e) - } if e := genTF(&buf, pr, nil, fts, db); e != nil { return fmt.Errorf("genTF %v", e) } @@ -466,6 +481,20 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin } func analyze(wr *PipeWriter, es *extendedSelect, db *DB, cwd string, modelDir string) error { + //pr, fts, e := loadModelMeta(es, db, cwd, modelDir) + //if e != nil { + // return fmt.Errorf("loadModelMeta %v", e) + //} + //filler, e := buildFiller(pr, nil, fts, db) + //if e != nil { + // return e + //} + //switch filler.(type) { + //case *xgboostFiller: + // + //default: + // + //} cmd := exec.Command("python", "-u") cmd.Dir = cwd cmd.Stdin = strings.NewReader(analyzeTemplateText) @@ -493,7 +522,7 @@ func analyze(wr *PipeWriter, es *extendedSelect, db *DB, cwd string, modelDir st // Create prediction table with appropriate column type. // If prediction table already exists, it will be overwritten. -func createPredictionTable(trainParsed, predParsed *extendedSelect, db *DB) error { +func createPredictionTable(predParsed *extendedSelect, db *DB) error { tableName, columnName, e := parseTableColumn(predParsed.into) if e != nil { return fmt.Errorf("invalid predParsed.into, %v", e) @@ -504,14 +533,14 @@ func createPredictionTable(trainParsed, predParsed *extendedSelect, db *DB) erro return fmt.Errorf("failed executing %s: %q", dropStmt, e) } - fts, e := verify(trainParsed, db) + fts, e := verify(predParsed, db) if e != nil { return e } var b bytes.Buffer fmt.Fprintf(&b, "create table %s (", tableName) - for _, c := range trainParsed.columns["feature_columns"] { + for _, c := range predParsed.columns["feature_columns"] { name, err := getExpressionFieldName(c) if err != nil { return err @@ -526,7 +555,7 @@ func createPredictionTable(trainParsed, predParsed *extendedSelect, db *DB) erro } fmt.Fprintf(&b, "%s %s, ", name, stype) } - typ, _ := fts.get(trainParsed.label) + typ, _ := fts.get(predParsed.label) stype, e := universalizeColumnType(db.driverName, typ) if e != nil { return e From 1d6be444f6de8d2330aa54036ab7f14557534e72 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 30 Aug 2019 19:26:44 +0800 Subject: [PATCH 2/3] fix --- sql/executor_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/executor_test.go b/sql/executor_test.go index 91598f73ae..1dc22a2db9 100644 --- a/sql/executor_test.go +++ b/sql/executor_test.go @@ -181,7 +181,8 @@ func TestCreatePredictionTable(t *testing.T) { a.NoError(e) predParsed, e := newParser().Parse(testPredictSelectIris) a.NoError(e) - a.NoError(createPredictionTable(trainParsed, predParsed, testDB)) + predParsed.trainClause = trainParsed.trainClause + a.NoError(createPredictionTable(predParsed, testDB)) } func TestIsQuery(t *testing.T) { From 1f15f36476632e5ae69e3800e98c0bb343111a35 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 30 Aug 2019 20:26:50 +0800 Subject: [PATCH 3/3] fix golint --- sql/executor.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/executor.go b/sql/executor.go index fe01ded72b..fdaab0f2e3 100644 --- a/sql/executor.go +++ b/sql/executor.go @@ -368,7 +368,7 @@ func (cw *logChanWriter) Close() { func buildFiller(es *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (filler interface{}, e error) { // trainAndValDataset only work in train mode - var dataset *trainAndValDataset = nil + var dataset *trainAndValDataset if es.train { dataset = ds }