Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sql/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ func genTF(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTyp
if pr.train {
return tfTrainTemplate.Execute(w, r)
}
if e := createPredictionTable(pr, db); e != nil {
return fmt.Errorf("failed to create prediction table: %v", e)
}
return tfPredTemplate.Execute(w, r)
}

Expand Down
15 changes: 15 additions & 0 deletions sql/codegen_xgboost.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"strconv"
"strings"
Expand Down Expand Up @@ -765,6 +766,20 @@ func xgCreatePredictionTable(pr *extendedSelect, r *xgboostFiller, db *DB) error
return nil
}

// genXG writes the XGBoost program for pr to w by executing xgTemplate with
// the filler produced by newXGBoostFiller.
//
// When pr.train is false, xgCreatePredictionTable is called before the
// template is executed; a failure there is reported as
// "failed to create prediction table: ...". An error from newXGBoostFiller is
// returned unchanged, and the result of xgTemplate.Execute is returned as is.
func genXG(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error {
	xgf, err := newXGBoostFiller(pr, ds, fts, db)
	if err != nil {
		return err
	}
	// Only the non-train path needs the prediction table set up first; both
	// paths then render the very same template with the very same filler.
	if !pr.train {
		if err = xgCreatePredictionTable(pr, xgf, db); err != nil {
			return fmt.Errorf("failed to create prediction table: %v", err)
		}
	}
	return xgTemplate.Execute(w, xgf)
}

var xgTemplate = template.Must(template.New("codegenXG").Parse(xgTemplateText))

const xgTemplateText = `
Expand Down
85 changes: 57 additions & 28 deletions sql/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,26 @@ func (cw *logChanWriter) Close() {
}
}

// buildFiller constructs the template filler that matches es.estimator.
//
// Estimators whose name starts with "XGBOOST." (case-insensitive) are built
// with newXGBoostFiller; everything else is built with newFiller. The
// train/validation dataset ds is forwarded to the constructor only when
// es.train is set; otherwise a nil dataset is passed.
//
// On failure the returned filler is a literal nil and the error is prefixed
// with the kind of filler that could not be built. Returning a literal nil
// (instead of the constructor's result boxed in interface{}) keeps callers
// from receiving a non-nil interface that wraps a nil pointer.
func buildFiller(es *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (filler interface{}, e error) {
	// The train/validation dataset is only used in train mode.
	var dataset *trainAndValDataset
	if es.train {
		dataset = ds
	}

	if strings.HasPrefix(strings.ToUpper(es.estimator), `XGBOOST.`) {
		xgf, err := newXGBoostFiller(es, dataset, fts, db)
		if err != nil {
			return nil, fmt.Errorf("failed to build XGBoostFiller: %v", err)
		}
		return xgf, nil
	}

	tff, err := newFiller(es, dataset, fts, db)
	if err != nil {
		return nil, fmt.Errorf("failed to build TensorFlowFiller: %v", err)
	}
	return tff, nil
}

func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir string, slct string, ds *trainAndValDataset) error {
fts, e := verify(tr, db)
if e != nil {
Expand All @@ -374,12 +394,8 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri

var program bytes.Buffer
if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) {
// FIXME(sperlingxx): write a separate train pipeline for xgboost to support remote mode
filler, e := newXGBoostFiller(tr, ds, fts, db)
if e != nil {
return fmt.Errorf("genXG %v", e)
}
if e := xgTemplate.Execute(&program, filler); e != nil {
// TODO(sperlingxx): write a separate train pipeline for xgboost to support remote mode
if e := genXG(&program, tr, ds, fts, db); e != nil {
return fmt.Errorf("genXG %v", e)
}
} else {
Expand All @@ -404,7 +420,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
return m.save(db, tr.save)
}

func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir string) error {
func loadModelMeta(pr *extendedSelect, db *DB, cwd string, modelDir string) (*extendedSelect, fieldTypes, error) {
var m *model
var e error
if modelDir != "" {
Expand All @@ -413,43 +429,42 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
m, e = load(db, pr.model, cwd)
}
if e != nil {
return fmt.Errorf("load %v", e)
return nil, nil, fmt.Errorf("load %v", e)
}

// Parse the training SELECT statement used to train
// the model for the prediction.
tr, e := newParser().Parse(m.TrainSelect)
if e != nil {
return fmt.Errorf("parse: TrainSelect %v raise %v", m.TrainSelect, e)
return nil, nil, fmt.Errorf("parse: TrainSelect %v raise %v", m.TrainSelect, e)
}

if e := verifyColumnNameAndType(tr, pr, db); e != nil {
return fmt.Errorf("verifyColumnNameAndType: %v", e)
return nil, nil, fmt.Errorf("verifyColumnNameAndType: %v", e)
}

pr.trainClause = tr.trainClause
fts, e := verify(pr, db)
if e != nil {
return fmt.Errorf("verify: %v", e)
return nil, nil, fmt.Errorf("verify: %v", e)
}

return pr, fts, nil
}

func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir string) error {
pr, fts, e := loadModelMeta(pr, db, cwd, modelDir)
if e != nil {
return fmt.Errorf("loadModelMeta %v", e)
}

var buf bytes.Buffer
if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) {
// FIXME(sperlingxx): write a separate pred pipeline for xgboost to support remote mode
filler, e := newXGBoostFiller(pr, nil, fts, db)
if e != nil {
return fmt.Errorf("genXG %v", e)
}
if e := xgCreatePredictionTable(pr, filler, db); e != nil {
return fmt.Errorf("genXG %v", e)
}
if e := xgTemplate.Execute(&buf, filler); e != nil {
if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) {
// TODO(sperlingxx): write a separate pred pipeline for xgboost to support remote mode
if e := genXG(&buf, pr, nil, fts, db); e != nil {
return fmt.Errorf("genXG %v", e)
}
} else {
if e := createPredictionTable(tr, pr, db); e != nil {
return fmt.Errorf("createPredictionTable: %v", e)
}
if e := genTF(&buf, pr, nil, fts, db); e != nil {
return fmt.Errorf("genTF %v", e)
}
Expand All @@ -466,6 +481,20 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
}

func analyze(wr *PipeWriter, es *extendedSelect, db *DB, cwd string, modelDir string) error {
//pr, fts, e := loadModelMeta(es, db, cwd, modelDir)
//if e != nil {
// return fmt.Errorf("loadModelMeta %v", e)
//}
//filler, e := buildFiller(pr, nil, fts, db)
//if e != nil {
// return e
//}
//switch filler.(type) {
//case *xgboostFiller:
//
//default:
//
//}
cmd := exec.Command("python", "-u")
cmd.Dir = cwd
cmd.Stdin = strings.NewReader(analyzeTemplateText)
Expand Down Expand Up @@ -493,7 +522,7 @@ func analyze(wr *PipeWriter, es *extendedSelect, db *DB, cwd string, modelDir st

// Create prediction table with appropriate column type.
// If prediction table already exists, it will be overwritten.
func createPredictionTable(trainParsed, predParsed *extendedSelect, db *DB) error {
func createPredictionTable(predParsed *extendedSelect, db *DB) error {
tableName, columnName, e := parseTableColumn(predParsed.into)
if e != nil {
return fmt.Errorf("invalid predParsed.into, %v", e)
Expand All @@ -504,14 +533,14 @@ func createPredictionTable(trainParsed, predParsed *extendedSelect, db *DB) erro
return fmt.Errorf("failed executing %s: %q", dropStmt, e)
}

fts, e := verify(trainParsed, db)
fts, e := verify(predParsed, db)
if e != nil {
return e
}

var b bytes.Buffer
fmt.Fprintf(&b, "create table %s (", tableName)
for _, c := range trainParsed.columns["feature_columns"] {
for _, c := range predParsed.columns["feature_columns"] {
name, err := getExpressionFieldName(c)
if err != nil {
return err
Expand All @@ -526,7 +555,7 @@ func createPredictionTable(trainParsed, predParsed *extendedSelect, db *DB) erro
}
fmt.Fprintf(&b, "%s %s, ", name, stype)
}
typ, _ := fts.get(trainParsed.label)
typ, _ := fts.get(predParsed.label)
stype, e := universalizeColumnType(db.driverName, typ)
if e != nil {
return e
Expand Down
3 changes: 2 additions & 1 deletion sql/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ func TestCreatePredictionTable(t *testing.T) {
a.NoError(e)
predParsed, e := newParser().Parse(testPredictSelectIris)
a.NoError(e)
a.NoError(createPredictionTable(trainParsed, predParsed, testDB))
predParsed.trainClause = trainParsed.trainClause
a.NoError(createPredictionTable(predParsed, testDB))
}

func TestIsQuery(t *testing.T) {
Expand Down