From 106ba3089c9b449449bdc24401bab97759ca7ad9 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 22 Oct 2019 16:15:09 +0800 Subject: [PATCH 1/6] refact xgboost predict using ir --- cmd/sqlflowserver/main_test.go | 1 + pkg/sql/codegen/xgboost/codegen.go | 31 ++++++++- pkg/sql/codegen/xgboost/codegen_test.go | 11 +++- pkg/sql/codegen/xgboost/template_pred.go | 83 ++++++++++++++++++++++++ pkg/sql/executor.go | 34 ++++++++-- 5 files changed, 153 insertions(+), 7 deletions(-) create mode 100644 pkg/sql/codegen/xgboost/template_pred.go diff --git a/cmd/sqlflowserver/main_test.go b/cmd/sqlflowserver/main_test.go index 029260d020..54581d2b98 100644 --- a/cmd/sqlflowserver/main_test.go +++ b/cmd/sqlflowserver/main_test.go @@ -306,6 +306,7 @@ func TestEnd2EndMySQLIR(t *testing.T) { t.Run("CaseSQLByPassLeftJoin", CaseSQLByPassLeftJoin) t.Run("CaseTrainRegression", CaseTrainRegression) t.Run("CaseTrainXGBoostRegressionIR", CaseTrainXGBoostRegression) + t.Run("CasePredXGBoostRegressionIR", CasePredictXGBoostRegression) } func TestEnd2EndHive(t *testing.T) { diff --git a/pkg/sql/codegen/xgboost/codegen.go b/pkg/sql/codegen/xgboost/codegen.go index b3e2a121d7..5d6dcf1d86 100644 --- a/pkg/sql/codegen/xgboost/codegen.go +++ b/pkg/sql/codegen/xgboost/codegen.go @@ -17,9 +17,10 @@ import ( "bytes" "encoding/json" "fmt" - "sqlflow.org/sqlflow/pkg/sql/codegen/attribute" "strings" + "sqlflow.org/sqlflow/pkg/sql/codegen/attribute" + "sqlflow.org/sqlflow/pkg/sql/codegen" ) @@ -148,3 +149,31 @@ func Train(ir *codegen.TrainIR) (string, error) { return program.String(), nil } + +// Pred generates a Python program for predict a xgboost model. +func Pred(ir *codegen.PredictIR) (string, error) { + featureFieldMeta, labelFieldMeta, err := getFieldMeta(ir.TrainIR.Features["feature_columns"], ir.TrainIR.Label) + f, err := json.Marshal(featureFieldMeta) + if err != nil { + return "", err + } + l, err := json.Marshal(labelFieldMeta) + if err != nil { + return "", err + } + + r := predFiller{ + DataSource: ir.DataSource, + PredSelect: ir.Select, + FeatureMetaJSON: string(f), + LabelMetaJSON: string(l), + } + + var program bytes.Buffer + + if err := predTemplate.Execute(&program, r); err != nil { + return "", nil + } + fmt.Println(program.String()) + return program.String(), nil +} diff --git a/pkg/sql/codegen/xgboost/codegen_test.go b/pkg/sql/codegen/xgboost/codegen_test.go index 99184781ce..4a5bcb9754 100644 --- a/pkg/sql/codegen/xgboost/codegen_test.go +++ b/pkg/sql/codegen/xgboost/codegen_test.go @@ -22,7 +22,7 @@ import ( "sqlflow.org/sqlflow/pkg/sql/codegen" ) -func TestTrain(t *testing.T) { +func TestTrainAndPredict(t *testing.T) { a := assert.New(t) cfg := &mysql.Config{ @@ -62,4 +62,13 @@ func TestTrain(t *testing.T) { Label: &codegen.NumericColumn{&codegen.FieldMeta{"class", codegen.Int, "", []int{1}, false, nil}}} _, err := Train(ir) a.NoError(err) + + predIR := codegen.PredictIR{ + DataSource: fmt.Sprintf("mysql://%s", cfg.FormatDSN()), + Select: "select * from iris.test;", + ResultTable: "iris.predict", + TrainIR: ir, + } + _, err = Pred(&predIR) + a.NoError(err) } diff --git a/pkg/sql/codegen/xgboost/template_pred.go b/pkg/sql/codegen/xgboost/template_pred.go new file mode 100644 index 0000000000..6496b18241 --- /dev/null +++ b/pkg/sql/codegen/xgboost/template_pred.go @@ -0,0 +1,83 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xgboost + +import ( + "text/template" +) + +type predFiller struct { + DataSource string + PredSelect string + FeatureMetaJSON string + LabelMetaJSON string + HDFSNameNodeAddr string + HiveLocation string +} + +const predTemplateText = ` +import xgboost as xgb +import numpy as np +from sqlflow_submitter.db import connect, db_generator, buffered_db_writer + +feature_field_meta = json.loads('''{{.FieldMetaJSON}}''') +label_field_meta = json.loads('''{{.LabelJSON}}''') + +feature_column_name = sorted([k["name"] for k in feature_field_meta]) +label_name = label_field_meta["name"] + +feature_spec = {k['name']: k for k in feature_field_meta} + +conn = connect_with_data_source('''{{.DataSource}}''') + +def xgb_dataset(fn, dataset_sql): + gen = db_generator(driver, conn, dataset_sql, feature_column_names, "", feature_specs) + with open(fn, 'w') as f: + for item in gen(): + features, label = item + row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)] + f.write("\t".join(row_data) + "\n") + # TODO(yancey1989): genearte group and weight text file if necessary + return xgb.DMatrix(fn) + +dpred = xgb_dataset('predict.txt', """{{.PredSelect}}""") + +bst = xgb.Booster({'nthread': 4}) # init model +bst.load_model("{{.Save}}") # load data +preds = bst.predict(dpred) + +# TODO(Yancey1989): using the train parameters to decide regressoin model or classifier model +if len(preds.shape) == 2: + # classifier result + preds = np.argmax(np.array(preds), axis=1) + +feature_file_read = open("predict.txt", "r") + +result_column_names = feature_column_names +result_column_names.append(label_name) + +line_no = 0 +with buffered_db_writer(conn.driver, conn, "{{.ResultTable}}", result_column_names, 100) as w: + while True: + line = feature_file_read.readline() + if not line: + break + row = [i.split(":")[1] for i in line.replace("\n", "").split("\t")[1:]] + row.append(preds[line_no]) + w.write(row) + line_no += 1 +print("Done predicting. Predict table : {{.ResultName}}") +` + +var predTemplate = template.Must(template.New("Pred").Parse(predTemplateText)) diff --git a/pkg/sql/executor.go b/pkg/sql/executor.go index 5a72be8e45..d9e05d6a8e 100644 --- a/pkg/sql/executor.go +++ b/pkg/sql/executor.go @@ -28,6 +28,7 @@ import ( pb "sqlflow.org/sqlflow/pkg/server/proto" "sqlflow.org/sqlflow/pkg/sql/codegen/tensorflow" + "sqlflow.org/sqlflow/pkg/sql/codegen/xgboost" xgb "sqlflow.org/sqlflow/pkg/sql/codegen/xgboost" ) @@ -413,7 +414,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri var program bytes.Buffer if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) { // FIXME(weiguoz): Remove the condition after the codegen refactor - if os.Getenv("SQLFLOW_codegen") == "ir" { + if enableIR() { ir, err := generateTrainIR(tr, db.String()) if err != nil { return err @@ -430,7 +431,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri } } else { // FIXME(typhoonzero): Remove the condition after the codegen refactor - if os.Getenv("SQLFLOW_codegen") == "ir" { + if enableIR() { ir, err := generateTrainIR(tr, db.String()) if err != nil { return err @@ -497,6 +498,13 @@ func loadModelMeta(pr *extendedSelect, db *DB, cwd, modelDir, modelName string) return pr, fts, nil } +func enableIR() bool { + if os.Getenv("SQLFLOW_codegen") == "ir" { + return true + } + return false +} + func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir string, session *pb.Session) error { pr, fts, e := loadModelMeta(pr, db, cwd, modelDir, pr.model) if e != nil { @@ -505,11 +513,27 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin var buf bytes.Buffer if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) { - if e := genXGBoost(&buf, pr, nil, fts, db, session); e != nil { - return fmt.Errorf("genXGBoost %v", e) + if enableIR() { + ir, err := generatePredictIR(pr, db.String(), cwd, modelDir) + if err != nil { + return err + } + code, err := xgboost.Pred(ir) + if err != nil { + return err + } + err = createPredictionTable(pr, db, session) + if err != nil { + return err + } + buf.WriteString(code) + } else { + if e := genXGBoost(&buf, pr, nil, fts, db, session); e != nil { + return fmt.Errorf("genXGBoost %v", e) + } } } else { - if os.Getenv("SQLFLOW_codegen") == "ir" { + if enableIR() { ir, err := generatePredictIR(pr, db.String(), cwd, modelDir) if err != nil { return err From e72157214c98d87100112aee456e889ec11787b5 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 22 Oct 2019 16:20:39 +0800 Subject: [PATCH 2/6] update --- pkg/sql/codegen/xgboost/codegen.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/sql/codegen/xgboost/codegen.go b/pkg/sql/codegen/xgboost/codegen.go index 5d6dcf1d86..f47e458105 100644 --- a/pkg/sql/codegen/xgboost/codegen.go +++ b/pkg/sql/codegen/xgboost/codegen.go @@ -174,6 +174,5 @@ func Pred(ir *codegen.PredictIR) (string, error) { if err := predTemplate.Execute(&program, r); err != nil { return "", nil } - fmt.Println(program.String()) return program.String(), nil } From 93cc26ff178f2d3861f4926f4988474db364312b Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 22 Oct 2019 20:41:30 +0800 Subject: [PATCH 3/6] fix ut --- cmd/sqlflowserver/main_test.go | 2 +- pkg/sql/codegen/xgboost/codegen.go | 3 +- pkg/sql/codegen/xgboost/template_pred.go | 31 +++++++++---------- .../{template.go => template_train.go} | 0 pkg/sql/executor.go | 5 +-- pkg/sql/ir_generator.go | 20 +++++++++++- pkg/sql/template_xgboost.go | 3 +- 7 files changed, 40 insertions(+), 24 deletions(-) rename pkg/sql/codegen/xgboost/{template.go => template_train.go} (100%) diff --git a/cmd/sqlflowserver/main_test.go b/cmd/sqlflowserver/main_test.go index 54581d2b98..5b1cee3210 100644 --- a/cmd/sqlflowserver/main_test.go +++ b/cmd/sqlflowserver/main_test.go @@ -306,7 +306,7 @@ func TestEnd2EndMySQLIR(t *testing.T) { t.Run("CaseSQLByPassLeftJoin", CaseSQLByPassLeftJoin) t.Run("CaseTrainRegression", CaseTrainRegression) t.Run("CaseTrainXGBoostRegressionIR", CaseTrainXGBoostRegression) - t.Run("CasePredXGBoostRegressionIR", CasePredictXGBoostRegression) + t.Run("CasePredictXGBoostRegressionIR", CasePredictXGBoostRegression) } func TestEnd2EndHive(t *testing.T) { diff --git a/pkg/sql/codegen/xgboost/codegen.go b/pkg/sql/codegen/xgboost/codegen.go index f47e458105..33fa9d3859 100644 --- a/pkg/sql/codegen/xgboost/codegen.go +++ b/pkg/sql/codegen/xgboost/codegen.go @@ -167,12 +167,13 @@ func Pred(ir *codegen.PredictIR) (string, error) { PredSelect: ir.Select, FeatureMetaJSON: string(f), LabelMetaJSON: string(l), + ResultTable: ir.ResultTable, } var program bytes.Buffer if err := predTemplate.Execute(&program, r); err != nil { - return "", nil + return "", err } return program.String(), nil } diff --git a/pkg/sql/codegen/xgboost/template_pred.go b/pkg/sql/codegen/xgboost/template_pred.go index 6496b18241..ce28fc2efd 100644 --- a/pkg/sql/codegen/xgboost/template_pred.go +++ b/pkg/sql/codegen/xgboost/template_pred.go @@ -18,31 +18,31 @@ import ( ) type predFiller struct { - DataSource string - PredSelect string - FeatureMetaJSON string - LabelMetaJSON string - HDFSNameNodeAddr string - HiveLocation string + DataSource string + PredSelect string + FeatureMetaJSON string + LabelMetaJSON string + ResultTable string } const predTemplateText = ` +import json import xgboost as xgb import numpy as np -from sqlflow_submitter.db import connect, db_generator, buffered_db_writer +from sqlflow_submitter.db import connect_with_data_source, db_generator, buffered_db_writer -feature_field_meta = json.loads('''{{.FieldMetaJSON}}''') -label_field_meta = json.loads('''{{.LabelJSON}}''') +feature_field_meta = json.loads('''{{.FeatureMetaJSON}}''') +label_field_meta = json.loads('''{{.LabelMetaJSON}}''') -feature_column_name = sorted([k["name"] for k in feature_field_meta]) +feature_column_names = [k["name"] for k in feature_field_meta] label_name = label_field_meta["name"] -feature_spec = {k['name']: k for k in feature_field_meta} +feature_specs = {k['name']: k for k in feature_field_meta} conn = connect_with_data_source('''{{.DataSource}}''') def xgb_dataset(fn, dataset_sql): - gen = db_generator(driver, conn, dataset_sql, feature_column_names, "", feature_specs) + gen = db_generator(conn.driver, conn, dataset_sql, feature_column_names, "", feature_specs) with open(fn, 'w') as f: for item in gen(): features, label = item @@ -54,7 +54,7 @@ def xgb_dataset(fn, dataset_sql): dpred = xgb_dataset('predict.txt', """{{.PredSelect}}""") bst = xgb.Booster({'nthread': 4}) # init model -bst.load_model("{{.Save}}") # load data +bst.load_model("my_model") # load data preds = bst.predict(dpred) # TODO(Yancey1989): using the train parameters to decide regressoin model or classifier model @@ -66,7 +66,6 @@ feature_file_read = open("predict.txt", "r") result_column_names = feature_column_names result_column_names.append(label_name) - line_no = 0 with buffered_db_writer(conn.driver, conn, "{{.ResultTable}}", result_column_names, 100) as w: while True: @@ -74,10 +73,10 @@ with buffered_db_writer(conn.driver, conn, "{{.ResultTable}}", result_column_nam if not line: break row = [i.split(":")[1] for i in line.replace("\n", "").split("\t")[1:]] - row.append(preds[line_no]) + row.append(str(preds[line_no])) w.write(row) line_no += 1 -print("Done predicting. Predict table : {{.ResultName}}") +print("Done predicting. Predict table : {{.ResultTable}}") ` var predTemplate = template.Must(template.New("Pred").Parse(predTemplateText)) diff --git a/pkg/sql/codegen/xgboost/template.go b/pkg/sql/codegen/xgboost/template_train.go similarity index 100% rename from pkg/sql/codegen/xgboost/template.go rename to pkg/sql/codegen/xgboost/template_train.go diff --git a/pkg/sql/executor.go b/pkg/sql/executor.go index d9e05d6a8e..f4086510eb 100644 --- a/pkg/sql/executor.go +++ b/pkg/sql/executor.go @@ -499,10 +499,7 @@ func loadModelMeta(pr *extendedSelect, db *DB, cwd, modelDir, modelName string) } func enableIR() bool { - if os.Getenv("SQLFLOW_codegen") == "ir" { - return true - } - return false + return os.Getenv("SQLFLOW_codegen") == "ir" } func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir string, session *pb.Session) error { diff --git a/pkg/sql/ir_generator.go b/pkg/sql/ir_generator.go index a01f4b4173..bc68ee71ce 100644 --- a/pkg/sql/ir_generator.go +++ b/pkg/sql/ir_generator.go @@ -98,10 +98,15 @@ func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDi } fmt.Printf("select %s, result table: %s\n", slct.standardSelect.String(), slct.into) + resultTable, err := parseResultTable(slct.into) + if err != nil { + return nil, err + } + return &codegen.PredictIR{ DataSource: connStr, Select: slct.standardSelect.String(), - ResultTable: slct.into, + ResultTable: resultTable, Attributes: attrMap, TrainIR: trainIR, }, nil @@ -558,3 +563,16 @@ func parseShape(e *expr) ([]int, error) { } return shape, nil } + +func parseResultTable(intoStatement string) (string, error) { + resultTableParts := strings.Split(intoStatement, ".") + resultTable := "" + if len(resultTableParts) == 3 { + resultTable = strings.Join(resultTableParts[0:2], ".") + } else if len(resultTableParts) == 2 || len(resultTableParts) == 1 { + resultTable = intoStatement + } else { + return "", fmt.Errorf("error result table format, should be db.table.class_col or db.table or table") + } + return resultTable, nil +} diff --git a/pkg/sql/template_xgboost.go b/pkg/sql/template_xgboost.go index 1550439847..83a78e2bcb 100644 --- a/pkg/sql/template_xgboost.go +++ b/pkg/sql/template_xgboost.go @@ -144,7 +144,7 @@ feature_file_read = open("predict.txt", "r") result_column_names = feature_column_names result_column_names.append("{{.Y.FeatureName}}") - +print(result_column_names) line_no = 0 with buffered_db_writer(driver, conn, "{{.TableName}}", result_column_names, 100, hdfs_namenode_addr="{{.HDFSNameNodeAddr}}", hive_location="{{.HiveLocation}}") as w: while True: @@ -153,6 +153,7 @@ with buffered_db_writer(driver, conn, "{{.TableName}}", result_column_names, 100 break row = [i.split(":")[1] for i in line.replace("\n", "").split("\t")[1:]] row.append(preds[line_no]) + print(row) w.write(row) line_no += 1 print("Done predicting. Predict table : {{.TableName}}") From 6e7ab885c00ed52ee86e846d78775af6cd300084 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 22 Oct 2019 20:44:23 +0800 Subject: [PATCH 4/6] update --- pkg/sql/template_xgboost.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/pkg/sql/template_xgboost.go b/pkg/sql/template_xgboost.go index 83a78e2bcb..b0d34acf9d 100644 --- a/pkg/sql/template_xgboost.go +++ b/pkg/sql/template_xgboost.go @@ -144,7 +144,6 @@ feature_file_read = open("predict.txt", "r") result_column_names = feature_column_names result_column_names.append("{{.Y.FeatureName}}") -print(result_column_names) line_no = 0 with buffered_db_writer(driver, conn, "{{.TableName}}", result_column_names, 100, hdfs_namenode_addr="{{.HDFSNameNodeAddr}}", hive_location="{{.HiveLocation}}") as w: while True: @@ -153,7 +152,6 @@ with buffered_db_writer(driver, conn, "{{.TableName}}", result_column_names, 100 break row = [i.split(":")[1] for i in line.replace("\n", "").split("\t")[1:]] row.append(preds[line_no]) - print(row) w.write(row) line_no += 1 print("Done predicting. Predict table : {{.TableName}}") From b90839c3d17ea8b57a6303ad8c969d8cdb02f880 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 22 Oct 2019 22:32:52 +0800 Subject: [PATCH 5/6] fix ut --- pkg/sql/ir_generator_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/sql/ir_generator_test.go b/pkg/sql/ir_generator_test.go index 9e2517ba39..1021ff03c7 100644 --- a/pkg/sql/ir_generator_test.go +++ b/pkg/sql/ir_generator_test.go @@ -186,7 +186,7 @@ INTO sqlflow_models.mymodel;`, testDB, modelDir, nil) a.NoError(err) a.Equal(connStr, predIR.DataSource) - a.Equal("iris.predict.class", predIR.ResultTable) + a.Equal("iris.predict", predIR.ResultTable) a.Equal("class", predIR.TrainIR.Label.GetFieldMeta()[0].Name) a.Equal("DNNClassifier", predIR.TrainIR.Estimator) nc, ok := predIR.TrainIR.Features["feature_columns"][0].(*codegen.NumericColumn) From 0977ec4ec9e776e0e80423ddcd33b636b223eda0 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Wed, 23 Oct 2019 10:47:58 +0800 Subject: [PATCH 6/6] update --- pkg/sql/codegen/tensorflow/codegen.go | 12 +----------- pkg/sql/codegen/xgboost/codegen.go | 1 - pkg/sql/executor.go | 3 +-- pkg/sql/ir_generator.go | 12 +++++++----- 4 files changed, 9 insertions(+), 19 deletions(-) diff --git a/pkg/sql/codegen/tensorflow/codegen.go b/pkg/sql/codegen/tensorflow/codegen.go index e1d8cff56d..b1fa6abb6e 100644 --- a/pkg/sql/codegen/tensorflow/codegen.go +++ b/pkg/sql/codegen/tensorflow/codegen.go @@ -238,20 +238,10 @@ func Pred(ir *codegen.PredictIR) (string, error) { } isKeras, estimatorStr := isKerasModel(ir.TrainIR.Estimator) - resultTableParts := strings.Split(ir.ResultTable, ".") - resultTable := "" - if len(resultTableParts) == 3 { - resultTable = strings.Join(resultTableParts[0:2], ".") - } else if len(resultTableParts) == 2 || len(resultTableParts) == 1 { - resultTable = ir.ResultTable - } else { - return "", fmt.Errorf("error result table format, should be db.table.class_col or db.table or table") - } - filler := predFiller{ DataSource: ir.DataSource, Select: ir.Select, - ResultTable: resultTable, + ResultTable: ir.ResultTable, Estimator: estimatorStr, IsKerasModel: isKeras, FieldMetas: fieldMetas, diff --git a/pkg/sql/codegen/xgboost/codegen.go b/pkg/sql/codegen/xgboost/codegen.go index 3fd10fbc74..33fa9d3859 100644 --- a/pkg/sql/codegen/xgboost/codegen.go +++ b/pkg/sql/codegen/xgboost/codegen.go @@ -22,7 +22,6 @@ import ( "sqlflow.org/sqlflow/pkg/sql/codegen/attribute" "sqlflow.org/sqlflow/pkg/sql/codegen" - "sqlflow.org/sqlflow/pkg/sql/codegen/attribute" ) func newFloat32(f float32) *float32 { diff --git a/pkg/sql/executor.go b/pkg/sql/executor.go index f4086510eb..7de5033c16 100644 --- a/pkg/sql/executor.go +++ b/pkg/sql/executor.go @@ -29,7 +29,6 @@ import ( pb "sqlflow.org/sqlflow/pkg/server/proto" "sqlflow.org/sqlflow/pkg/sql/codegen/tensorflow" "sqlflow.org/sqlflow/pkg/sql/codegen/xgboost" - xgb "sqlflow.org/sqlflow/pkg/sql/codegen/xgboost" ) // Run executes a SQL query and returns a stream of rows or messages @@ -419,7 +418,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri if err != nil { return err } - code, err := xgb.Train(ir) + code, err := xgboost.Train(ir) if err != nil { return err } diff --git a/pkg/sql/ir_generator.go b/pkg/sql/ir_generator.go index bc68ee71ce..787e638b85 100644 --- a/pkg/sql/ir_generator.go +++ b/pkg/sql/ir_generator.go @@ -96,7 +96,6 @@ func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDi if err != nil { return nil, err } - fmt.Printf("select %s, result table: %s\n", slct.standardSelect.String(), slct.into) resultTable, err := parseResultTable(slct.into) if err != nil { @@ -564,15 +563,18 @@ func parseShape(e *expr) ([]int, error) { return shape, nil } +// parseResultTable parse out the table name from the INTO statment +// as the following 3 cases: +// db.table.class_col -> db.table # cut the column name +// db.table -> db.table # using the specified db +// table -> table # using the default db func parseResultTable(intoStatement string) (string, error) { resultTableParts := strings.Split(intoStatement, ".") - resultTable := "" if len(resultTableParts) == 3 { - resultTable = strings.Join(resultTableParts[0:2], ".") + return strings.Join(resultTableParts[0:2], "."), nil } else if len(resultTableParts) == 2 || len(resultTableParts) == 1 { - resultTable = intoStatement + return intoStatement, nil } else { return "", fmt.Errorf("error result table format, should be db.table.class_col or db.table or table") } - return resultTable, nil }