diff --git a/cmd/sqlflowserver/main_test.go b/cmd/sqlflowserver/main_test.go index 029260d020..5b1cee3210 100644 --- a/cmd/sqlflowserver/main_test.go +++ b/cmd/sqlflowserver/main_test.go @@ -306,6 +306,7 @@ func TestEnd2EndMySQLIR(t *testing.T) { t.Run("CaseSQLByPassLeftJoin", CaseSQLByPassLeftJoin) t.Run("CaseTrainRegression", CaseTrainRegression) t.Run("CaseTrainXGBoostRegressionIR", CaseTrainXGBoostRegression) + t.Run("CasePredictXGBoostRegressionIR", CasePredictXGBoostRegression) } func TestEnd2EndHive(t *testing.T) { diff --git a/pkg/sql/codegen/tensorflow/codegen.go b/pkg/sql/codegen/tensorflow/codegen.go index e1d8cff56d..b1fa6abb6e 100644 --- a/pkg/sql/codegen/tensorflow/codegen.go +++ b/pkg/sql/codegen/tensorflow/codegen.go @@ -238,20 +238,10 @@ func Pred(ir *codegen.PredictIR) (string, error) { } isKeras, estimatorStr := isKerasModel(ir.TrainIR.Estimator) - resultTableParts := strings.Split(ir.ResultTable, ".") - resultTable := "" - if len(resultTableParts) == 3 { - resultTable = strings.Join(resultTableParts[0:2], ".") - } else if len(resultTableParts) == 2 || len(resultTableParts) == 1 { - resultTable = ir.ResultTable - } else { - return "", fmt.Errorf("error result table format, should be db.table.class_col or db.table or table") - } - filler := predFiller{ DataSource: ir.DataSource, Select: ir.Select, - ResultTable: resultTable, + ResultTable: ir.ResultTable, Estimator: estimatorStr, IsKerasModel: isKeras, FieldMetas: fieldMetas, diff --git a/pkg/sql/codegen/xgboost/codegen.go b/pkg/sql/codegen/xgboost/codegen.go index d1f8b38f70..33fa9d3859 100644 --- a/pkg/sql/codegen/xgboost/codegen.go +++ b/pkg/sql/codegen/xgboost/codegen.go @@ -19,8 +19,9 @@ import ( "fmt" "strings" - "sqlflow.org/sqlflow/pkg/sql/codegen" "sqlflow.org/sqlflow/pkg/sql/codegen/attribute" + + "sqlflow.org/sqlflow/pkg/sql/codegen" ) func newFloat32(f float32) *float32 { @@ -148,3 +149,31 @@ func Train(ir *codegen.TrainIR) (string, error) { return program.String(), nil } + +// Pred 
generates a Python program for predict a xgboost model. +func Pred(ir *codegen.PredictIR) (string, error) { + featureFieldMeta, labelFieldMeta, err := getFieldMeta(ir.TrainIR.Features["feature_columns"], ir.TrainIR.Label) + f, err := json.Marshal(featureFieldMeta) + if err != nil { + return "", err + } + l, err := json.Marshal(labelFieldMeta) + if err != nil { + return "", err + } + + r := predFiller{ + DataSource: ir.DataSource, + PredSelect: ir.Select, + FeatureMetaJSON: string(f), + LabelMetaJSON: string(l), + ResultTable: ir.ResultTable, + } + + var program bytes.Buffer + + if err := predTemplate.Execute(&program, r); err != nil { + return "", err + } + return program.String(), nil +} diff --git a/pkg/sql/codegen/xgboost/codegen_test.go b/pkg/sql/codegen/xgboost/codegen_test.go index c5adaa1b04..33b00d28c6 100644 --- a/pkg/sql/codegen/xgboost/codegen_test.go +++ b/pkg/sql/codegen/xgboost/codegen_test.go @@ -22,13 +22,25 @@ import ( "sqlflow.org/sqlflow/pkg/sql/codegen" ) -func TestTrain(t *testing.T) { +func TestTrainAndPredict(t *testing.T) { a := assert.New(t) tir := mockTrainIR() _, err := Train(tir) a.NoError(err) + + pir := mockPrdcIR(tir) + _, err = Pred(pir) + a.NoError(err) } +func mockPrdcIR(trainIR *codegen.TrainIR) *codegen.PredictIR { + return &codegen.PredictIR{ + DataSource: trainIR.DataSource, + Select: "select * from iris.test;", + ResultTable: "iris.predict", + TrainIR: trainIR, + } +} func mockTrainIR() *codegen.TrainIR { cfg := &mysql.Config{ User: "root", diff --git a/pkg/sql/codegen/xgboost/template_pred.go b/pkg/sql/codegen/xgboost/template_pred.go new file mode 100644 index 0000000000..ce28fc2efd --- /dev/null +++ b/pkg/sql/codegen/xgboost/template_pred.go @@ -0,0 +1,82 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
// predFiller carries the values substituted into predTemplateText.
type predFiller struct {
	// DataSource is passed to connect_with_data_source in the generated program.
	DataSource string
	// PredSelect is the SQL statement whose rows are fed to the model.
	PredSelect string
	// FeatureMetaJSON is a JSON document describing the feature fields; the
	// generated program reads a "name" entry from each element.
	FeatureMetaJSON string
	// LabelMetaJSON is a JSON document describing the label field; the
	// generated program reads its "name" entry.
	LabelMetaJSON string
	// ResultTable is the table the predictions are written to.
	ResultTable string
}

// predTemplateText is the text/template source of the Python prediction
// program. The program dumps the selected rows to predict.txt, loads the
// model file "my_model", predicts, and writes the feature values plus the
// predicted label into ResultTable.
const predTemplateText = `
import json
import xgboost as xgb
import numpy as np
from sqlflow_submitter.db import connect_with_data_source, db_generator, buffered_db_writer

feature_field_meta = json.loads('''{{.FeatureMetaJSON}}''')
label_field_meta = json.loads('''{{.LabelMetaJSON}}''')

feature_column_names = [k["name"] for k in feature_field_meta]
label_name = label_field_meta["name"]

feature_specs = {k['name']: k for k in feature_field_meta}

conn = connect_with_data_source('''{{.DataSource}}''')

def xgb_dataset(fn, dataset_sql):
    gen = db_generator(conn.driver, conn, dataset_sql, feature_column_names, "", feature_specs)
    with open(fn, 'w') as f:
        for item in gen():
            features, label = item
            row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)]
            f.write("\t".join(row_data) + "\n")
    # TODO(yancey1989): generate group and weight text file if necessary
    return xgb.DMatrix(fn)

dpred = xgb_dataset('predict.txt', """{{.PredSelect}}""")

bst = xgb.Booster({'nthread': 4})  # init model
bst.load_model("my_model")  # load the trained model
preds = bst.predict(dpred)

# TODO(Yancey1989): using the train parameters to decide regression model or classifier model
if len(preds.shape) == 2:
    # classifier result
    preds = np.argmax(np.array(preds), axis=1)

# Build a new list so that appending the label does not mutate feature_column_names.
result_column_names = feature_column_names + [label_name]

# Both the feature file and the DB writer are closed when the blocks exit.
with open("predict.txt", "r") as feature_file_read:
    with buffered_db_writer(conn.driver, conn, "{{.ResultTable}}", result_column_names, 100) as w:
        for line_no, line in enumerate(feature_file_read):
            # Skip the leading label column and keep the value of each "index:value" pair.
            row = [i.split(":")[1] for i in line.rstrip("\n").split("\t")[1:]]
            row.append(str(preds[line_no]))
            w.write(row)
print("Done predicting. Predict table : {{.ResultTable}}")
`

// predTemplate is the parsed form of predTemplateText; parsing happens at
// package initialization and panics if the template text is malformed.
var predTemplate = template.Must(template.New("Pred").Parse(predTemplateText))
// enableIR reports whether the environment variable SQLFLOW_codegen is set
// to exactly "ir"; the comparison is case-sensitive.
func enableIR() bool {
	mode, found := os.LookupEnv("SQLFLOW_codegen")
	if !found {
		return false
	}
	return mode == "ir"
}
// parseResultTable extracts the result table name from the target of an INTO
// clause. The target is split on "." and handled by its number of parts:
//
//	db.table.class_col -> db.table  # the trailing column name is cut off
//	db.table           -> db.table  # the specified database is used
//	table              -> table     # the default database is used
//
// A target with more than three dot-separated parts is rejected with an
// error, in which case the returned string is empty.
func parseResultTable(intoStatement string) (string, error) {
	parts := strings.Split(intoStatement, ".")
	switch len(parts) {
	case 1, 2:
		// Already a bare table or db.table; use it unchanged.
		return intoStatement, nil
	case 3:
		// Keep only db.table and drop the column component.
		return strings.Join(parts[:2], "."), nil
	default:
		return "", fmt.Errorf("error result table format, should be db.table.class_col or db.table or table")
	}
}