Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/sqlflowserver/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ func TestEnd2EndMySQLIR(t *testing.T) {
t.Run("CaseSQLByPassLeftJoin", CaseSQLByPassLeftJoin)
t.Run("CaseTrainRegression", CaseTrainRegression)
t.Run("CaseTrainXGBoostRegressionIR", CaseTrainXGBoostRegression)
t.Run("CasePredictXGBoostRegressionIR", CasePredictXGBoostRegression)
}

func TestEnd2EndHive(t *testing.T) {
Expand Down
12 changes: 1 addition & 11 deletions pkg/sql/codegen/tensorflow/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,20 +238,10 @@ func Pred(ir *codegen.PredictIR) (string, error) {
}
isKeras, estimatorStr := isKerasModel(ir.TrainIR.Estimator)

resultTableParts := strings.Split(ir.ResultTable, ".")
resultTable := ""
if len(resultTableParts) == 3 {
resultTable = strings.Join(resultTableParts[0:2], ".")
} else if len(resultTableParts) == 2 || len(resultTableParts) == 1 {
resultTable = ir.ResultTable
} else {
return "", fmt.Errorf("error result table format, should be db.table.class_col or db.table or table")
}

filler := predFiller{
DataSource: ir.DataSource,
Select: ir.Select,
ResultTable: resultTable,
ResultTable: ir.ResultTable,
Estimator: estimatorStr,
IsKerasModel: isKeras,
FieldMetas: fieldMetas,
Expand Down
31 changes: 30 additions & 1 deletion pkg/sql/codegen/xgboost/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ import (
"fmt"
"strings"

"sqlflow.org/sqlflow/pkg/sql/codegen"
"sqlflow.org/sqlflow/pkg/sql/codegen/attribute"

"sqlflow.org/sqlflow/pkg/sql/codegen"
)

func newFloat32(f float32) *float32 {
Expand Down Expand Up @@ -148,3 +149,31 @@ func Train(ir *codegen.TrainIR) (string, error) {

return program.String(), nil
}

// Pred generates a Python program that runs prediction with a trained
// XGBoost model.
//
// The feature and label field metadata are resolved from ir.TrainIR through
// getFieldMeta, serialized to JSON and rendered into predTemplate together
// with the data source, the prediction SELECT statement and the result table
// taken from ir.
//
// It returns the rendered program text. A non-nil error is returned when the
// field metadata cannot be resolved, when it cannot be marshalled to JSON, or
// when the template fails to execute; the returned string is empty then.
func Pred(ir *codegen.PredictIR) (string, error) {
	featureFieldMeta, labelFieldMeta, err := getFieldMeta(ir.TrainIR.Features["feature_columns"], ir.TrainIR.Label)
	if err != nil {
		// The metadata values are not meaningful when getFieldMeta fails, so
		// stop here instead of letting json.Marshal overwrite the error.
		return "", err
	}

	f, err := json.Marshal(featureFieldMeta)
	if err != nil {
		return "", err
	}
	l, err := json.Marshal(labelFieldMeta)
	if err != nil {
		return "", err
	}

	r := predFiller{
		DataSource:      ir.DataSource,
		PredSelect:      ir.Select,
		FeatureMetaJSON: string(f),
		LabelMetaJSON:   string(l),
		ResultTable:     ir.ResultTable,
	}

	var program bytes.Buffer
	if err := predTemplate.Execute(&program, r); err != nil {
		return "", err
	}
	return program.String(), nil
}
14 changes: 13 additions & 1 deletion pkg/sql/codegen/xgboost/codegen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,25 @@ import (
"sqlflow.org/sqlflow/pkg/sql/codegen"
)

func TestTrain(t *testing.T) {
// TestTrainAndPredict checks that both Train and Pred generate a program
// without error for the mocked train IR and the predict IR derived from it.
func TestTrainAndPredict(t *testing.T) {
	a := assert.New(t)

	trainIR := mockTrainIR()
	_, trainErr := Train(trainIR)
	a.NoError(trainErr)

	_, predErr := Pred(mockPrdcIR(trainIR))
	a.NoError(predErr)
}

// mockPrdcIR builds a PredictIR fixture on top of trainIR: it reuses the
// data source of trainIR, selects from iris.test and writes to iris.predict.
func mockPrdcIR(trainIR *codegen.TrainIR) *codegen.PredictIR {
	pir := new(codegen.PredictIR)
	pir.DataSource = trainIR.DataSource
	pir.Select = "select * from iris.test;"
	pir.ResultTable = "iris.predict"
	pir.TrainIR = trainIR
	return pir
}
func mockTrainIR() *codegen.TrainIR {
cfg := &mysql.Config{
User: "root",
Expand Down
82 changes: 82 additions & 0 deletions pkg/sql/codegen/xgboost/template_pred.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2019 The SQLFlow Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package xgboost

import (
"text/template"
)

// predFiller holds the values substituted into the placeholders of
// predTemplateText.
type predFiller struct {
	// DataSource is rendered as the argument of connect_with_data_source.
	DataSource string
	// PredSelect is the SQL statement passed to xgb_dataset to build the
	// prediction input.
	PredSelect string
	// FeatureMetaJSON and LabelMetaJSON are JSON documents that the generated
	// program parses with json.loads.
	FeatureMetaJSON, LabelMetaJSON string
	// ResultTable is the table name handed to buffered_db_writer and printed
	// in the final status message.
	ResultTable string
}

// predTemplateText is the text/template source of the Python prediction
// program. Its placeholders ({{.DataSource}}, {{.PredSelect}},
// {{.FeatureMetaJSON}}, {{.LabelMetaJSON}}, {{.ResultTable}}) are the fields
// of predFiller.
//
// The generated program dumps the selected rows to predict.txt, loads the
// model file "my_model" into an xgboost Booster, predicts on the dumped data
// and writes the feature values plus the predicted label to the result table.
//
// NOTE(review): text/template inserts the values verbatim inside Python
// string literals; confirm callers never pass values containing the
// surrounding quote sequences.
const predTemplateText = `
import json
import xgboost as xgb
import numpy as np
from sqlflow_submitter.db import connect_with_data_source, db_generator, buffered_db_writer

feature_field_meta = json.loads('''{{.FeatureMetaJSON}}''')
label_field_meta = json.loads('''{{.LabelMetaJSON}}''')

feature_column_names = [k["name"] for k in feature_field_meta]
label_name = label_field_meta["name"]

feature_specs = {k['name']: k for k in feature_field_meta}

conn = connect_with_data_source('''{{.DataSource}}''')

def xgb_dataset(fn, dataset_sql):
gen = db_generator(conn.driver, conn, dataset_sql, feature_column_names, "", feature_specs)
with open(fn, 'w') as f:
for item in gen():
features, label = item
row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)]
f.write("\t".join(row_data) + "\n")
# TODO(yancey1989): genearte group and weight text file if necessary
return xgb.DMatrix(fn)

dpred = xgb_dataset('predict.txt', """{{.PredSelect}}""")

bst = xgb.Booster({'nthread': 4}) # init model
bst.load_model("my_model") # load data
preds = bst.predict(dpred)

# TODO(Yancey1989): using the train parameters to decide regressoin model or classifier model
if len(preds.shape) == 2:
# classifier result
preds = np.argmax(np.array(preds), axis=1)

feature_file_read = open("predict.txt", "r")

result_column_names = feature_column_names
result_column_names.append(label_name)
line_no = 0
with buffered_db_writer(conn.driver, conn, "{{.ResultTable}}", result_column_names, 100) as w:
while True:
line = feature_file_read.readline()
if not line:
break
row = [i.split(":")[1] for i in line.replace("\n", "").split("\t")[1:]]
row.append(str(preds[line_no]))
w.write(row)
line_no += 1
print("Done predicting. Predict table : {{.ResultTable}}")
`

// predTemplate is predTemplateText parsed once at package initialization;
// template.Must panics if the template text does not parse.
var predTemplate = template.Must(
	template.New("Pred").Parse(predTemplateText),
)
34 changes: 27 additions & 7 deletions pkg/sql/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (

pb "sqlflow.org/sqlflow/pkg/server/proto"
"sqlflow.org/sqlflow/pkg/sql/codegen/tensorflow"
xgb "sqlflow.org/sqlflow/pkg/sql/codegen/xgboost"
"sqlflow.org/sqlflow/pkg/sql/codegen/xgboost"
)

// Run executes a SQL query and returns a stream of rows or messages
Expand Down Expand Up @@ -413,12 +413,12 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
var program bytes.Buffer
if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) {
// FIXME(weiguoz): Remove the condition after the codegen refactor
if os.Getenv("SQLFLOW_codegen") == "ir" {
if enableIR() {
ir, err := generateTrainIR(tr, db.String())
if err != nil {
return err
}
code, err := xgb.Train(ir)
code, err := xgboost.Train(ir)
if err != nil {
return err
}
Expand All @@ -430,7 +430,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
}
} else {
// FIXME(typhoonzero): Remove the condition after the codegen refactor
if os.Getenv("SQLFLOW_codegen") == "ir" {
if enableIR() {
ir, err := generateTrainIR(tr, db.String())
if err != nil {
return err
Expand Down Expand Up @@ -497,6 +497,10 @@ func loadModelMeta(pr *extendedSelect, db *DB, cwd, modelDir, modelName string)
return pr, fts, nil
}

// enableIR reports whether the SQLFLOW_codegen environment variable is set
// to exactly "ir", which selects the IR-based code generators.
func enableIR() bool {
	mode := os.Getenv("SQLFLOW_codegen")
	return mode == "ir"
}

func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir string, session *pb.Session) error {
pr, fts, e := loadModelMeta(pr, db, cwd, modelDir, pr.model)
if e != nil {
Expand All @@ -505,11 +509,27 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin

var buf bytes.Buffer
if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) {
if e := genXGBoost(&buf, pr, nil, fts, db, session); e != nil {
return fmt.Errorf("genXGBoost %v", e)
if enableIR() {
ir, err := generatePredictIR(pr, db.String(), cwd, modelDir)
if err != nil {
return err
}
code, err := xgboost.Pred(ir)
if err != nil {
return err
}
err = createPredictionTable(pr, db, session)
if err != nil {
return err
}
buf.WriteString(code)
} else {
if e := genXGBoost(&buf, pr, nil, fts, db, session); e != nil {
return fmt.Errorf("genXGBoost %v", e)
}
}
} else {
if os.Getenv("SQLFLOW_codegen") == "ir" {
if enableIR() {
ir, err := generatePredictIR(pr, db.String(), cwd, modelDir)
if err != nil {
return err
Expand Down
24 changes: 22 additions & 2 deletions pkg/sql/ir_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,16 @@ func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDi
if err != nil {
return nil, err
}
fmt.Printf("select %s, result table: %s\n", slct.standardSelect.String(), slct.into)

resultTable, err := parseResultTable(slct.into)
if err != nil {
return nil, err
}

return &codegen.PredictIR{
DataSource: connStr,
Select: slct.standardSelect.String(),
ResultTable: slct.into,
ResultTable: resultTable,
Attributes: attrMap,
TrainIR: trainIR,
}, nil
Expand Down Expand Up @@ -558,3 +562,19 @@ func parseShape(e *expr) ([]int, error) {
}
return shape, nil
}

// parseResultTable extracts the table name from the target of an INTO
// clause. The target is split on "." and handled by its number of parts:
//
//	db.table.class_col -> db.table  (the trailing column name is cut)
//	db.table           -> db.table  (using the specified db)
//	table              -> table     (using the default db)
//
// Any other number of parts yields an error and an empty table name.
func parseResultTable(intoStatement string) (string, error) {
	parts := strings.Split(intoStatement, ".")
	switch len(parts) {
	case 1, 2:
		return intoStatement, nil
	case 3:
		return parts[0] + "." + parts[1], nil
	default:
		return "", fmt.Errorf("error result table format, should be db.table.class_col or db.table or table")
	}
}
2 changes: 1 addition & 1 deletion pkg/sql/ir_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ INTO sqlflow_models.mymodel;`, testDB, modelDir, nil)
a.NoError(err)

a.Equal(connStr, predIR.DataSource)
a.Equal("iris.predict.class", predIR.ResultTable)
a.Equal("iris.predict", predIR.ResultTable)
a.Equal("class", predIR.TrainIR.Label.GetFieldMeta()[0].Name)
a.Equal("DNNClassifier", predIR.TrainIR.Estimator)
nc, ok := predIR.TrainIR.Features["feature_columns"][0].(*codegen.NumericColumn)
Expand Down
1 change: 0 additions & 1 deletion pkg/sql/template_xgboost.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ feature_file_read = open("predict.txt", "r")

result_column_names = feature_column_names
result_column_names.append("{{.Y.FeatureName}}")

line_no = 0
with buffered_db_writer(driver, conn, "{{.TableName}}", result_column_names, 100, hdfs_namenode_addr="{{.HDFSNameNodeAddr}}", hive_location="{{.HiveLocation}}", hdfs_user="{{.HDFSUser}}", hdfs_pass="{{.HDFSPass}}") as w:
while True:
Expand Down