@@ -28,7 +28,7 @@ import (
2828
2929 pb "sqlflow.org/sqlflow/pkg/server/proto"
3030 "sqlflow.org/sqlflow/pkg/sql/codegen/tensorflow"
31- xgb "sqlflow.org/sqlflow/pkg/sql/codegen/xgboost"
31+ "sqlflow.org/sqlflow/pkg/sql/codegen/xgboost"
3232)
3333
3434// Run executes a SQL query and returns a stream of rows or messages
@@ -413,12 +413,12 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
413413 var program bytes.Buffer
414414 if strings .HasPrefix (strings .ToUpper (tr .estimator ), `XGBOOST.` ) {
415415 // FIXME(weiguoz): Remove the condition after the codegen refactor
416- if os . Getenv ( "SQLFLOW_codegen" ) == "ir" {
416+ if enableIR () {
417417 ir , err := generateTrainIR (tr , db .String ())
418418 if err != nil {
419419 return err
420420 }
421- code , err := xgb .Train (ir )
421+ code , err := xgboost .Train (ir )
422422 if err != nil {
423423 return err
424424 }
@@ -430,7 +430,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
430430 }
431431 } else {
432432 // FIXME(typhoonzero): Remove the condition after the codegen refactor
433- if os . Getenv ( "SQLFLOW_codegen" ) == "ir" {
433+ if enableIR () {
434434 ir , err := generateTrainIR (tr , db .String ())
435435 if err != nil {
436436 return err
@@ -497,6 +497,10 @@ func loadModelMeta(pr *extendedSelect, db *DB, cwd, modelDir, modelName string)
497497 return pr , fts , nil
498498}
499499
500+ func enableIR () bool {
501+ return os .Getenv ("SQLFLOW_codegen" ) == "ir"
502+ }
503+
500504func pred (wr * PipeWriter , pr * extendedSelect , db * DB , cwd string , modelDir string , session * pb.Session ) error {
501505 pr , fts , e := loadModelMeta (pr , db , cwd , modelDir , pr .model )
502506 if e != nil {
@@ -505,11 +509,27 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
505509
506510 var buf bytes.Buffer
507511 if strings .HasPrefix (strings .ToUpper (pr .estimator ), `XGBOOST.` ) {
508- if e := genXGBoost (& buf , pr , nil , fts , db , session ); e != nil {
509- return fmt .Errorf ("genXGBoost %v" , e )
512+ if enableIR () {
513+ ir , err := generatePredictIR (pr , db .String (), cwd , modelDir )
514+ if err != nil {
515+ return err
516+ }
517+ code , err := xgboost .Pred (ir )
518+ if err != nil {
519+ return err
520+ }
521+ err = createPredictionTable (pr , db , session )
522+ if err != nil {
523+ return err
524+ }
525+ buf .WriteString (code )
526+ } else {
527+ if e := genXGBoost (& buf , pr , nil , fts , db , session ); e != nil {
528+ return fmt .Errorf ("genXGBoost %v" , e )
529+ }
510530 }
511531 } else {
512- if os . Getenv ( "SQLFLOW_codegen" ) == "ir" {
532+ if enableIR () {
513533 ir , err := generatePredictIR (pr , db .String (), cwd , modelDir )
514534 if err != nil {
515535 return err
0 commit comments