@@ -366,6 +366,26 @@ func (cw *logChanWriter) Close() {
366366 }
367367}
368368
369+ func buildFiller (es * extendedSelect , ds * trainAndValDataset , fts fieldTypes , db * DB ) (filler interface {}, e error ) {
370+ // trainAndValDataset only work in train mode
371+ var dataset * trainAndValDataset
372+ if es .train {
373+ dataset = ds
374+ }
375+ if strings .HasPrefix (strings .ToUpper (es .estimator ), `XGBOOST.` ) {
376+ filler , e = newXGBoostFiller (es , dataset , fts , db )
377+ if e != nil {
378+ e = fmt .Errorf ("failed to build XGBoostFiller: %v" , e )
379+ }
380+ } else {
381+ filler , e = newFiller (es , dataset , fts , db )
382+ if e != nil {
383+ e = fmt .Errorf ("failed to build TensorFlowFiller: %v" , e )
384+ }
385+ }
386+ return filler , e
387+ }
388+
369389func train (wr * PipeWriter , tr * extendedSelect , db * DB , cwd string , modelDir string , slct string , ds * trainAndValDataset ) error {
370390 fts , e := verify (tr , db )
371391 if e != nil {
@@ -374,12 +394,8 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
374394
375395 var program bytes.Buffer
376396 if strings .HasPrefix (strings .ToUpper (tr .estimator ), `XGBOOST.` ) {
377- // FIXME(sperlingxx): write a separate train pipeline for xgboost to support remote mode
378- filler , e := newXGBoostFiller (tr , ds , fts , db )
379- if e != nil {
380- return fmt .Errorf ("genXG %v" , e )
381- }
382- if e := xgTemplate .Execute (& program , filler ); e != nil {
397+ // TODO(sperlingxx): write a separate train pipeline for xgboost to support remote mode
398+ if e := genXG (& program , tr , ds , fts , db ); e != nil {
383399 return fmt .Errorf ("genXG %v" , e )
384400 }
385401 } else {
@@ -404,7 +420,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
404420 return m .save (db , tr .save )
405421}
406422
407- func pred ( wr * PipeWriter , pr * extendedSelect , db * DB , cwd string , modelDir string ) error {
423+ func loadModelMeta ( pr * extendedSelect , db * DB , cwd string , modelDir string ) ( * extendedSelect , fieldTypes , error ) {
408424 var m * model
409425 var e error
410426 if modelDir != "" {
@@ -413,43 +429,42 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
413429 m , e = load (db , pr .model , cwd )
414430 }
415431 if e != nil {
416- return fmt .Errorf ("load %v" , e )
432+ return nil , nil , fmt .Errorf ("load %v" , e )
417433 }
418434
419435 // Parse the training SELECT statement used to train
420436 // the model for the prediction.
421437 tr , e := newParser ().Parse (m .TrainSelect )
422438 if e != nil {
423- return fmt .Errorf ("parse: TrainSelect %v raise %v" , m .TrainSelect , e )
439+ return nil , nil , fmt .Errorf ("parse: TrainSelect %v raise %v" , m .TrainSelect , e )
424440 }
425441
426442 if e := verifyColumnNameAndType (tr , pr , db ); e != nil {
427- return fmt .Errorf ("verifyColumnNameAndType: %v" , e )
443+ return nil , nil , fmt .Errorf ("verifyColumnNameAndType: %v" , e )
428444 }
429445
430446 pr .trainClause = tr .trainClause
431447 fts , e := verify (pr , db )
432448 if e != nil {
433- return fmt .Errorf ("verify: %v" , e )
449+ return nil , nil , fmt .Errorf ("verify: %v" , e )
450+ }
451+
452+ return pr , fts , nil
453+ }
454+
455+ func pred (wr * PipeWriter , pr * extendedSelect , db * DB , cwd string , modelDir string ) error {
456+ pr , fts , e := loadModelMeta (pr , db , cwd , modelDir )
457+ if e != nil {
458+ return fmt .Errorf ("loadModelMeta %v" , e )
434459 }
435460
436461 var buf bytes.Buffer
437- if strings .HasPrefix (strings .ToUpper (tr .estimator ), `XGBOOST.` ) {
438- // FIXME(sperlingxx): write a separate pred pipeline for xgboost to support remote mode
439- filler , e := newXGBoostFiller (pr , nil , fts , db )
440- if e != nil {
441- return fmt .Errorf ("genXG %v" , e )
442- }
443- if e := xgCreatePredictionTable (pr , filler , db ); e != nil {
444- return fmt .Errorf ("genXG %v" , e )
445- }
446- if e := xgTemplate .Execute (& buf , filler ); e != nil {
462+ if strings .HasPrefix (strings .ToUpper (pr .estimator ), `XGBOOST.` ) {
463+ // TODO(sperlingxx): write a separate pred pipeline for xgboost to support remote mode
464+ if e := genXG (& buf , pr , nil , fts , db ); e != nil {
447465 return fmt .Errorf ("genXG %v" , e )
448466 }
449467 } else {
450- if e := createPredictionTable (tr , pr , db ); e != nil {
451- return fmt .Errorf ("createPredictionTable: %v" , e )
452- }
453468 if e := genTF (& buf , pr , nil , fts , db ); e != nil {
454469 return fmt .Errorf ("genTF %v" , e )
455470 }
@@ -466,6 +481,20 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
466481}
467482
468483func analyze (wr * PipeWriter , es * extendedSelect , db * DB , cwd string , modelDir string ) error {
484+ //pr, fts, e := loadModelMeta(es, db, cwd, modelDir)
485+ //if e != nil {
486+ // return fmt.Errorf("loadModelMeta %v", e)
487+ //}
488+ //filler, e := buildFiller(pr, nil, fts, db)
489+ //if e != nil {
490+ // return e
491+ //}
492+ //switch filler.(type) {
493+ //case *xgboostFiller:
494+ //
495+ //default:
496+ //
497+ //}
469498 cmd := exec .Command ("python" , "-u" )
470499 cmd .Dir = cwd
471500 cmd .Stdin = strings .NewReader (analyzeTemplateText )
@@ -493,7 +522,7 @@ func analyze(wr *PipeWriter, es *extendedSelect, db *DB, cwd string, modelDir st
493522
494523// Create prediction table with appropriate column type.
495524// If prediction table already exists, it will be overwritten.
496- func createPredictionTable (trainParsed , predParsed * extendedSelect , db * DB ) error {
525+ func createPredictionTable (predParsed * extendedSelect , db * DB ) error {
497526 tableName , columnName , e := parseTableColumn (predParsed .into )
498527 if e != nil {
499528 return fmt .Errorf ("invalid predParsed.into, %v" , e )
@@ -504,14 +533,14 @@ func createPredictionTable(trainParsed, predParsed *extendedSelect, db *DB) erro
504533 return fmt .Errorf ("failed executing %s: %q" , dropStmt , e )
505534 }
506535
507- fts , e := verify (trainParsed , db )
536+ fts , e := verify (predParsed , db )
508537 if e != nil {
509538 return e
510539 }
511540
512541 var b bytes.Buffer
513542 fmt .Fprintf (& b , "create table %s (" , tableName )
514- for _ , c := range trainParsed .columns ["feature_columns" ] {
543+ for _ , c := range predParsed .columns ["feature_columns" ] {
515544 name , err := getExpressionFieldName (c )
516545 if err != nil {
517546 return err
@@ -526,7 +555,7 @@ func createPredictionTable(trainParsed, predParsed *extendedSelect, db *DB) erro
526555 }
527556 fmt .Fprintf (& b , "%s %s, " , name , stype )
528557 }
529- typ , _ := fts .get (trainParsed .label )
558+ typ , _ := fts .get (predParsed .label )
530559 stype , e := universalizeColumnType (db .driverName , typ )
531560 if e != nil {
532561 return e
0 commit comments