
Commit 1866be3

Yancey0623 authored and weiguoz committed
[Intermediate Representation] xgboost predict using IR (#1049)
* refactor xgboost predict using IR
* fix unit tests
1 parent 6b4e926 commit 1866be3

File tree: 10 files changed (+177 additions, -24 deletions)

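For orientation, the new predict path for XGBOOST.* estimators works in three steps: build a PredictIR from the parsed statement, render it into a Python program with the XGBoost code generator, and create the result table before the program runs. A condensed sketch of that flow, lifted from the pkg/sql/executor.go hunk further below (surrounding function and the legacy fallback omitted):

// Inside pred() in pkg/sql/executor.go, when enableIR() reports true:
ir, err := generatePredictIR(pr, db.String(), cwd, modelDir)
if err != nil {
    return err
}
code, err := xgboost.Pred(ir) // render the Python predict program
if err != nil {
    return err
}
err = createPredictionTable(pr, db, session)
if err != nil {
    return err
}
buf.WriteString(code) // handed off for execution afterwards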

cmd/sqlflowserver/main_test.go

Lines changed: 1 addition & 0 deletions
@@ -306,6 +306,7 @@ func TestEnd2EndMySQLIR(t *testing.T) {
     t.Run("CaseSQLByPassLeftJoin", CaseSQLByPassLeftJoin)
     t.Run("CaseTrainRegression", CaseTrainRegression)
     t.Run("CaseTrainXGBoostRegressionIR", CaseTrainXGBoostRegression)
+    t.Run("CasePredictXGBoostRegressionIR", CasePredictXGBoostRegression)
 }
 
 func TestEnd2EndHive(t *testing.T) {

pkg/sql/codegen/tensorflow/codegen.go

Lines changed: 1 addition & 11 deletions
@@ -238,20 +238,10 @@ func Pred(ir *codegen.PredictIR) (string, error) {
     }
     isKeras, estimatorStr := isKerasModel(ir.TrainIR.Estimator)
 
-    resultTableParts := strings.Split(ir.ResultTable, ".")
-    resultTable := ""
-    if len(resultTableParts) == 3 {
-        resultTable = strings.Join(resultTableParts[0:2], ".")
-    } else if len(resultTableParts) == 2 || len(resultTableParts) == 1 {
-        resultTable = ir.ResultTable
-    } else {
-        return "", fmt.Errorf("error result table format, should be db.table.class_col or db.table or table")
-    }
-
     filler := predFiller{
         DataSource:   ir.DataSource,
         Select:       ir.Select,
-        ResultTable:  resultTable,
+        ResultTable:  ir.ResultTable,
         Estimator:    estimatorStr,
         IsKerasModel: isKeras,
         FieldMetas:   fieldMetas,

pkg/sql/codegen/xgboost/codegen.go

Lines changed: 30 additions & 1 deletion
@@ -19,8 +19,9 @@ import (
     "fmt"
     "strings"
 
-    "sqlflow.org/sqlflow/pkg/sql/codegen"
     "sqlflow.org/sqlflow/pkg/sql/codegen/attribute"
+
+    "sqlflow.org/sqlflow/pkg/sql/codegen"
 )
 
 func newFloat32(f float32) *float32 {
@@ -148,3 +149,31 @@ func Train(ir *codegen.TrainIR) (string, error) {
 
     return program.String(), nil
 }
+
+// Pred generates a Python program for predict a xgboost model.
+func Pred(ir *codegen.PredictIR) (string, error) {
+    featureFieldMeta, labelFieldMeta, err := getFieldMeta(ir.TrainIR.Features["feature_columns"], ir.TrainIR.Label)
+    f, err := json.Marshal(featureFieldMeta)
+    if err != nil {
+        return "", err
+    }
+    l, err := json.Marshal(labelFieldMeta)
+    if err != nil {
+        return "", err
+    }
+
+    r := predFiller{
+        DataSource:      ir.DataSource,
+        PredSelect:      ir.Select,
+        FeatureMetaJSON: string(f),
+        LabelMetaJSON:   string(l),
+        ResultTable:     ir.ResultTable,
+    }
+
+    var program bytes.Buffer
+
+    if err := predTemplate.Execute(&program, r); err != nil {
+        return "", err
+    }
+    return program.String(), nil
+}

pkg/sql/codegen/xgboost/codegen_test.go

Lines changed: 13 additions & 1 deletion
@@ -22,13 +22,25 @@ import (
     "sqlflow.org/sqlflow/pkg/sql/codegen"
 )
 
-func TestTrain(t *testing.T) {
+func TestTrainAndPredict(t *testing.T) {
     a := assert.New(t)
     tir := mockTrainIR()
     _, err := Train(tir)
     a.NoError(err)
+
+    pir := mockPrdcIR(tir)
+    _, err = Pred(pir)
+    a.NoError(err)
 }
 
+func mockPrdcIR(trainIR *codegen.TrainIR) *codegen.PredictIR {
+    return &codegen.PredictIR{
+        DataSource:  trainIR.DataSource,
+        Select:      "select * from iris.test;",
+        ResultTable: "iris.predict",
+        TrainIR:     trainIR,
+    }
+}
 func mockTrainIR() *codegen.TrainIR {
     cfg := &mysql.Config{
         User: "root",
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package xgboost
+
+import (
+    "text/template"
+)
+
+type predFiller struct {
+    DataSource      string
+    PredSelect      string
+    FeatureMetaJSON string
+    LabelMetaJSON   string
+    ResultTable     string
+}
+
+const predTemplateText = `
+import json
+import xgboost as xgb
+import numpy as np
+from sqlflow_submitter.db import connect_with_data_source, db_generator, buffered_db_writer
+
+feature_field_meta = json.loads('''{{.FeatureMetaJSON}}''')
+label_field_meta = json.loads('''{{.LabelMetaJSON}}''')
+
+feature_column_names = [k["name"] for k in feature_field_meta]
+label_name = label_field_meta["name"]
+
+feature_specs = {k['name']: k for k in feature_field_meta}
+
+conn = connect_with_data_source('''{{.DataSource}}''')
+
+def xgb_dataset(fn, dataset_sql):
+    gen = db_generator(conn.driver, conn, dataset_sql, feature_column_names, "", feature_specs)
+    with open(fn, 'w') as f:
+        for item in gen():
+            features, label = item
+            row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)]
+            f.write("\t".join(row_data) + "\n")
+    # TODO(yancey1989): genearte group and weight text file if necessary
+    return xgb.DMatrix(fn)
+
+dpred = xgb_dataset('predict.txt', """{{.PredSelect}}""")
+
+bst = xgb.Booster({'nthread': 4})  # init model
+bst.load_model("my_model")  # load data
+preds = bst.predict(dpred)
+
+# TODO(Yancey1989): using the train parameters to decide regressoin model or classifier model
+if len(preds.shape) == 2:
+    # classifier result
+    preds = np.argmax(np.array(preds), axis=1)
+
+feature_file_read = open("predict.txt", "r")
+
+result_column_names = feature_column_names
+result_column_names.append(label_name)
+line_no = 0
+with buffered_db_writer(conn.driver, conn, "{{.ResultTable}}", result_column_names, 100) as w:
+    while True:
+        line = feature_file_read.readline()
+        if not line:
+            break
+        row = [i.split(":")[1] for i in line.replace("\n", "").split("\t")[1:]]
+        row.append(str(preds[line_no]))
+        w.write(row)
+        line_no += 1
+print("Done predicting. Predict table : {{.ResultTable}}")
+`
+
+var predTemplate = template.Must(template.New("Pred").Parse(predTemplateText))
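As a quick sanity check of the template above, rendering predTemplate with a hand-filled predFiller produces the Python predict program, which is essentially what Pred in codegen.go does. A hypothetical snippet (it would have to live in package xgboost since predFiller and predTemplate are unexported; the DSN, SQL, and JSON values are made up for illustration; bytes and fmt imports assumed):

var program bytes.Buffer
filler := predFiller{
    DataSource:      "mysql://root:root@tcp(127.0.0.1:3306)/",              // made-up DSN
    PredSelect:      "SELECT * FROM iris.test;",                            // made-up query
    FeatureMetaJSON: `[{"name": "sepal_length"}, {"name": "sepal_width"}]`, // made-up metadata
    LabelMetaJSON:   `{"name": "class"}`,
    ResultTable:     "iris.predict",
}
if err := predTemplate.Execute(&program, filler); err != nil {
    panic(err)
}
fmt.Println(program.String()) // prints the generated Python program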
File renamed without changes.

pkg/sql/executor.go

Lines changed: 27 additions & 7 deletions
@@ -28,7 +28,7 @@ import (
 
     pb "sqlflow.org/sqlflow/pkg/server/proto"
     "sqlflow.org/sqlflow/pkg/sql/codegen/tensorflow"
-    xgb "sqlflow.org/sqlflow/pkg/sql/codegen/xgboost"
+    "sqlflow.org/sqlflow/pkg/sql/codegen/xgboost"
 )
 
 // Run executes a SQL query and returns a stream of rows or messages
@@ -413,12 +413,12 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
     var program bytes.Buffer
     if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) {
         // FIXME(weiguoz): Remove the condition after the codegen refactor
-        if os.Getenv("SQLFLOW_codegen") == "ir" {
+        if enableIR() {
             ir, err := generateTrainIR(tr, db.String())
             if err != nil {
                 return err
             }
-            code, err := xgb.Train(ir)
+            code, err := xgboost.Train(ir)
             if err != nil {
                 return err
             }
@@ -430,7 +430,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
         }
     } else {
         // FIXME(typhoonzero): Remove the condition after the codegen refactor
-        if os.Getenv("SQLFLOW_codegen") == "ir" {
+        if enableIR() {
             ir, err := generateTrainIR(tr, db.String())
             if err != nil {
                 return err
@@ -497,6 +497,10 @@ func loadModelMeta(pr *extendedSelect, db *DB, cwd, modelDir, modelName string)
     return pr, fts, nil
 }
 
+func enableIR() bool {
+    return os.Getenv("SQLFLOW_codegen") == "ir"
+}
+
 func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir string, session *pb.Session) error {
     pr, fts, e := loadModelMeta(pr, db, cwd, modelDir, pr.model)
     if e != nil {
@@ -505,11 +509,27 @@
 
     var buf bytes.Buffer
     if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) {
-        if e := genXGBoost(&buf, pr, nil, fts, db, session); e != nil {
-            return fmt.Errorf("genXGBoost %v", e)
+        if enableIR() {
+            ir, err := generatePredictIR(pr, db.String(), cwd, modelDir)
+            if err != nil {
+                return err
+            }
+            code, err := xgboost.Pred(ir)
+            if err != nil {
+                return err
+            }
+            err = createPredictionTable(pr, db, session)
+            if err != nil {
+                return err
+            }
+            buf.WriteString(code)
+        } else {
+            if e := genXGBoost(&buf, pr, nil, fts, db, session); e != nil {
+                return fmt.Errorf("genXGBoost %v", e)
+            }
         }
     } else {
-        if os.Getenv("SQLFLOW_codegen") == "ir" {
+        if enableIR() {
             ir, err := generatePredictIR(pr, db.String(), cwd, modelDir)
             if err != nil {
                 return err
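The IR code path in executor.go is gated on an environment variable via the new enableIR helper. A minimal, self-contained sketch of the same gating pattern (illustrative only; it simply mirrors the helper added in the diff):

package main

import (
    "fmt"
    "os"
)

// enableIR mirrors the helper added in pkg/sql/executor.go: the IR-based
// codegen is used only when SQLFLOW_codegen is set to "ir".
func enableIR() bool {
    return os.Getenv("SQLFLOW_codegen") == "ir"
}

func main() {
    os.Setenv("SQLFLOW_codegen", "ir") // e.g. exported before starting sqlflowserver
    if enableIR() {
        fmt.Println("IR-based XGBoost codegen enabled")
    } else {
        fmt.Println("falling back to the legacy template codegen")
    }
}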

pkg/sql/ir_generator.go

Lines changed: 22 additions & 2 deletions
@@ -96,12 +96,16 @@ func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDi
     if err != nil {
         return nil, err
     }
-    fmt.Printf("select %s, result table: %s\n", slct.standardSelect.String(), slct.into)
+
+    resultTable, err := parseResultTable(slct.into)
+    if err != nil {
+        return nil, err
+    }
 
     return &codegen.PredictIR{
         DataSource:  connStr,
         Select:      slct.standardSelect.String(),
-        ResultTable: slct.into,
+        ResultTable: resultTable,
         Attributes:  attrMap,
         TrainIR:     trainIR,
     }, nil
@@ -558,3 +562,19 @@ func parseShape(e *expr) ([]int, error) {
     }
     return shape, nil
 }
+
+// parseResultTable parse out the table name from the INTO statment
+// as the following 3 cases:
+// db.table.class_col -> db.table # cut the column name
+// db.table -> db.table # using the specified db
+// table -> table # using the default db
+func parseResultTable(intoStatement string) (string, error) {
+    resultTableParts := strings.Split(intoStatement, ".")
+    if len(resultTableParts) == 3 {
+        return strings.Join(resultTableParts[0:2], "."), nil
+    } else if len(resultTableParts) == 2 || len(resultTableParts) == 1 {
+        return intoStatement, nil
+    } else {
+        return "", fmt.Errorf("error result table format, should be db.table.class_col or db.table or table")
+    }
+}
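The three INTO formats accepted by parseResultTable can be pinned down with a small table-driven test. This is a hypothetical sketch, not part of the commit, exercising the function added in the hunk above:

package sql

import "testing"

// Hypothetical table-driven test for parseResultTable (not in this commit).
func TestParseResultTableFormats(t *testing.T) {
    cases := map[string]string{
        "iris.predict.class": "iris.predict", // db.table.class_col: column name is cut
        "iris.predict":       "iris.predict", // db.table: use the specified db
        "predict":            "predict",      // table: use the default db
    }
    for in, want := range cases {
        got, err := parseResultTable(in)
        if err != nil {
            t.Fatalf("parseResultTable(%q) returned error: %v", in, err)
        }
        if got != want {
            t.Errorf("parseResultTable(%q) = %q, want %q", in, got, want)
        }
    }
}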

pkg/sql/ir_generator_test.go

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ INTO sqlflow_models.mymodel;`, testDB, modelDir, nil)
     a.NoError(err)
 
     a.Equal(connStr, predIR.DataSource)
-    a.Equal("iris.predict.class", predIR.ResultTable)
+    a.Equal("iris.predict", predIR.ResultTable)
     a.Equal("class", predIR.TrainIR.Label.GetFieldMeta()[0].Name)
     a.Equal("DNNClassifier", predIR.TrainIR.Estimator)
     nc, ok := predIR.TrainIR.Features["feature_columns"][0].(*codegen.NumericColumn)

pkg/sql/template_xgboost.go

Lines changed: 0 additions & 1 deletion
@@ -144,7 +144,6 @@ feature_file_read = open("predict.txt", "r")
 
 result_column_names = feature_column_names
 result_column_names.append("{{.Y.FeatureName}}")
-
 line_no = 0
 with buffered_db_writer(driver, conn, "{{.TableName}}", result_column_names, 100, hdfs_namenode_addr="{{.HDFSNameNodeAddr}}", hive_location="{{.HiveLocation}}", hdfs_user="{{.HDFSUser}}", hdfs_pass="{{.HDFSPass}}") as w:
     while True:
