Skip to content

Commit f51a036

Browse files
sperlingxxweiguoz
authored andcommitted
refine runExtendedSQL (#750)
* refine codegen and codegen-prepare for analyze
1 parent fd806cb commit f51a036

File tree

4 files changed

+77
-29
lines changed

4 files changed

+77
-29
lines changed

sql/codegen.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,9 @@ func genTF(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTyp
271271
if pr.train {
272272
return tfTrainTemplate.Execute(w, r)
273273
}
274+
if e := createPredictionTable(pr, db); e != nil {
275+
return fmt.Errorf("failed to create prediction table: %v", e)
276+
}
274277
return tfPredTemplate.Execute(w, r)
275278
}
276279

sql/codegen_xgboost.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"bytes"
1818
"encoding/json"
1919
"fmt"
20+
"io"
2021
"os"
2122
"strconv"
2223
"strings"
@@ -765,6 +766,20 @@ func xgCreatePredictionTable(pr *extendedSelect, r *xgboostFiller, db *DB) error
765766
return nil
766767
}
767768

769+
func genXG(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error {
770+
r, e := newXGBoostFiller(pr, ds, fts, db)
771+
if e != nil {
772+
return e
773+
}
774+
if pr.train {
775+
return xgTemplate.Execute(w, r)
776+
}
777+
if e := xgCreatePredictionTable(pr, r, db); e != nil {
778+
return fmt.Errorf("failed to create prediction table: %v", e)
779+
}
780+
return xgTemplate.Execute(w, r)
781+
}
782+
768783
var xgTemplate = template.Must(template.New("codegenXG").Parse(xgTemplateText))
769784

770785
const xgTemplateText = `

sql/executor.go

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,26 @@ func (cw *logChanWriter) Close() {
366366
}
367367
}
368368

369+
func buildFiller(es *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) (filler interface{}, e error) {
370+
// trainAndValDataset only work in train mode
371+
var dataset *trainAndValDataset
372+
if es.train {
373+
dataset = ds
374+
}
375+
if strings.HasPrefix(strings.ToUpper(es.estimator), `XGBOOST.`) {
376+
filler, e = newXGBoostFiller(es, dataset, fts, db)
377+
if e != nil {
378+
e = fmt.Errorf("failed to build XGBoostFiller: %v", e)
379+
}
380+
} else {
381+
filler, e = newFiller(es, dataset, fts, db)
382+
if e != nil {
383+
e = fmt.Errorf("failed to build TensorFlowFiller: %v", e)
384+
}
385+
}
386+
return filler, e
387+
}
388+
369389
func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir string, slct string, ds *trainAndValDataset) error {
370390
fts, e := verify(tr, db)
371391
if e != nil {
@@ -374,12 +394,8 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
374394

375395
var program bytes.Buffer
376396
if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) {
377-
// FIXME(sperlingxx): write a separate train pipeline for xgboost to support remote mode
378-
filler, e := newXGBoostFiller(tr, ds, fts, db)
379-
if e != nil {
380-
return fmt.Errorf("genXG %v", e)
381-
}
382-
if e := xgTemplate.Execute(&program, filler); e != nil {
397+
// TODO(sperlingxx): write a separate train pipeline for xgboost to support remote mode
398+
if e := genXG(&program, tr, ds, fts, db); e != nil {
383399
return fmt.Errorf("genXG %v", e)
384400
}
385401
} else {
@@ -404,7 +420,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
404420
return m.save(db, tr.save)
405421
}
406422

407-
func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir string) error {
423+
func loadModelMeta(pr *extendedSelect, db *DB, cwd string, modelDir string) (*extendedSelect, fieldTypes, error) {
408424
var m *model
409425
var e error
410426
if modelDir != "" {
@@ -413,43 +429,42 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
413429
m, e = load(db, pr.model, cwd)
414430
}
415431
if e != nil {
416-
return fmt.Errorf("load %v", e)
432+
return nil, nil, fmt.Errorf("load %v", e)
417433
}
418434

419435
// Parse the training SELECT statement used to train
420436
// the model for the prediction.
421437
tr, e := newParser().Parse(m.TrainSelect)
422438
if e != nil {
423-
return fmt.Errorf("parse: TrainSelect %v raise %v", m.TrainSelect, e)
439+
return nil, nil, fmt.Errorf("parse: TrainSelect %v raise %v", m.TrainSelect, e)
424440
}
425441

426442
if e := verifyColumnNameAndType(tr, pr, db); e != nil {
427-
return fmt.Errorf("verifyColumnNameAndType: %v", e)
443+
return nil, nil, fmt.Errorf("verifyColumnNameAndType: %v", e)
428444
}
429445

430446
pr.trainClause = tr.trainClause
431447
fts, e := verify(pr, db)
432448
if e != nil {
433-
return fmt.Errorf("verify: %v", e)
449+
return nil, nil, fmt.Errorf("verify: %v", e)
450+
}
451+
452+
return pr, fts, nil
453+
}
454+
455+
func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir string) error {
456+
pr, fts, e := loadModelMeta(pr, db, cwd, modelDir)
457+
if e != nil {
458+
return fmt.Errorf("loadModelMeta %v", e)
434459
}
435460

436461
var buf bytes.Buffer
437-
if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) {
438-
// FIXME(sperlingxx): write a separate pred pipeline for xgboost to support remote mode
439-
filler, e := newXGBoostFiller(pr, nil, fts, db)
440-
if e != nil {
441-
return fmt.Errorf("genXG %v", e)
442-
}
443-
if e := xgCreatePredictionTable(pr, filler, db); e != nil {
444-
return fmt.Errorf("genXG %v", e)
445-
}
446-
if e := xgTemplate.Execute(&buf, filler); e != nil {
462+
if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) {
463+
// TODO(sperlingxx): write a separate pred pipeline for xgboost to support remote mode
464+
if e := genXG(&buf, pr, nil, fts, db); e != nil {
447465
return fmt.Errorf("genXG %v", e)
448466
}
449467
} else {
450-
if e := createPredictionTable(tr, pr, db); e != nil {
451-
return fmt.Errorf("createPredictionTable: %v", e)
452-
}
453468
if e := genTF(&buf, pr, nil, fts, db); e != nil {
454469
return fmt.Errorf("genTF %v", e)
455470
}
@@ -466,6 +481,20 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
466481
}
467482

468483
func analyze(wr *PipeWriter, es *extendedSelect, db *DB, cwd string, modelDir string) error {
484+
//pr, fts, e := loadModelMeta(es, db, cwd, modelDir)
485+
//if e != nil {
486+
// return fmt.Errorf("loadModelMeta %v", e)
487+
//}
488+
//filler, e := buildFiller(pr, nil, fts, db)
489+
//if e != nil {
490+
// return e
491+
//}
492+
//switch filler.(type) {
493+
//case *xgboostFiller:
494+
//
495+
//default:
496+
//
497+
//}
469498
cmd := exec.Command("python", "-u")
470499
cmd.Dir = cwd
471500
cmd.Stdin = strings.NewReader(analyzeTemplateText)
@@ -493,7 +522,7 @@ func analyze(wr *PipeWriter, es *extendedSelect, db *DB, cwd string, modelDir st
493522

494523
// Create prediction table with appropriate column type.
495524
// If prediction table already exists, it will be overwritten.
496-
func createPredictionTable(trainParsed, predParsed *extendedSelect, db *DB) error {
525+
func createPredictionTable(predParsed *extendedSelect, db *DB) error {
497526
tableName, columnName, e := parseTableColumn(predParsed.into)
498527
if e != nil {
499528
return fmt.Errorf("invalid predParsed.into, %v", e)
@@ -504,14 +533,14 @@ func createPredictionTable(trainParsed, predParsed *extendedSelect, db *DB) erro
504533
return fmt.Errorf("failed executing %s: %q", dropStmt, e)
505534
}
506535

507-
fts, e := verify(trainParsed, db)
536+
fts, e := verify(predParsed, db)
508537
if e != nil {
509538
return e
510539
}
511540

512541
var b bytes.Buffer
513542
fmt.Fprintf(&b, "create table %s (", tableName)
514-
for _, c := range trainParsed.columns["feature_columns"] {
543+
for _, c := range predParsed.columns["feature_columns"] {
515544
name, err := getExpressionFieldName(c)
516545
if err != nil {
517546
return err
@@ -526,7 +555,7 @@ func createPredictionTable(trainParsed, predParsed *extendedSelect, db *DB) erro
526555
}
527556
fmt.Fprintf(&b, "%s %s, ", name, stype)
528557
}
529-
typ, _ := fts.get(trainParsed.label)
558+
typ, _ := fts.get(predParsed.label)
530559
stype, e := universalizeColumnType(db.driverName, typ)
531560
if e != nil {
532561
return e

sql/executor_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ func TestCreatePredictionTable(t *testing.T) {
181181
a.NoError(e)
182182
predParsed, e := newParser().Parse(testPredictSelectIris)
183183
a.NoError(e)
184-
a.NoError(createPredictionTable(trainParsed, predParsed, testDB))
184+
predParsed.trainClause = trainParsed.trainClause
185+
a.NoError(createPredictionTable(predParsed, testDB))
185186
}
186187

187188
func TestIsQuery(t *testing.T) {

0 commit comments

Comments
 (0)