Skip to content

Commit 31d6cb1

Browse files
authored
[Intermediate Representation] Enable Hive e2e test (#1060)
* enable ir hive e2e test
* fix ut
* update
* update
* update tf codegen test
1 parent ba11601 commit 31d6cb1

File tree

11 files changed

+159
-41
lines changed

11 files changed

+159
-41
lines changed

cmd/sqlflowserver/main_test.go

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func AssertGreaterEqualAny(a *assert.Assertions, actual *any.Any, expected inter
118118
case "type.googleapis.com/google.protobuf.FloatValue":
119119
b := wrappers.FloatValue{}
120120
ptypes.UnmarshalAny(actual, &b)
121-
a.GreaterOrEqual(float32(expected.(float64)), b.Value)
121+
a.GreaterOrEqual(b.Value, float32(expected.(float64)))
122122
}
123123
}
124124

@@ -189,7 +189,10 @@ func prepareTestData(dbStr string) error {
189189
if err := testdata.Popularize(testDB.DB, testdata.IrisHiveSQL); err != nil {
190190
return err
191191
}
192-
return testdata.Popularize(testDB.DB, testdata.ChurnHiveSQL)
192+
if err = testdata.Popularize(testDB.DB, testdata.ChurnHiveSQL); err != nil {
193+
return err
194+
}
195+
return testdata.Popularize(testDB.DB, testdata.HousingSQL)
193196
case "maxcompute":
194197
submitter := os.Getenv("SQLFLOW_submitter")
195198
if submitter == "alps" {
@@ -381,6 +384,38 @@ func TestEnd2EndHive(t *testing.T) {
381384
t.Run("CaseTrainDeepWideModel", CaseTrainDeepWideModel)
382385
}
383386

387+
func TestEnd2EndHiveIR(t *testing.T) {
388+
if os.Getenv("SQLFLOW_codegen") != "ir" {
389+
t.Skip("Skipping ir test")
390+
}
391+
392+
if os.Getenv("SQLFLOW_TEST_DB") != "hive" {
393+
t.Skip("Skipping hive tests")
394+
}
395+
396+
modelDir := ""
397+
tmpDir, caCrt, caKey, err := generateTempCA()
398+
defer os.RemoveAll(tmpDir)
399+
if err != nil {
400+
t.Fatalf("failed to generate CA pair %v", err)
401+
}
402+
403+
dbConnStr = "hive://root:[email protected]:10000/iris?auth=NOSASL"
404+
go start("", modelDir, caCrt, caKey, true, unitestPort)
405+
waitPortReady(fmt.Sprintf("localhost:%d", unitestPort), 0)
406+
err = prepareTestData(dbConnStr)
407+
if err != nil {
408+
t.Fatalf("prepare test dataset failed: %v", err)
409+
}
410+
t.Run("TestShowDatabases", CaseShowDatabases)
411+
t.Run("TestSelect", CaseSelect)
412+
t.Run("TestTrainSQL", CaseTrainSQL)
413+
t.Run("CaseTrainCustomModel", CaseTrainCustomModel)
414+
t.Run("CaseTrainDeepWideModel", CaseTrainDeepWideModel)
415+
t.Run("CaseTrainXGBoostRegression", CaseTrainXGBoostRegression)
416+
t.Run("CasePredictXGBoostRegression", CasePredictXGBoostRegression)
417+
}
418+
384419
func TestEnd2EndMaxCompute(t *testing.T) {
385420
testDBDriver := os.Getenv("SQLFLOW_TEST_DB")
386421
modelDir, _ := ioutil.TempDir("/tmp", "sqlflow_ssl_")

pkg/sql/codegen/tensorflow/codegen.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"strings"
2020
"text/template"
2121

22+
pb "sqlflow.org/sqlflow/pkg/server/proto"
2223
"sqlflow.org/sqlflow/pkg/sql/codegen"
2324
)
2425

@@ -195,6 +196,7 @@ func Train(ir *codegen.TrainIR) (string, error) {
195196
ModelParams: modelParams,
196197
TrainParams: trainParams,
197198
Save: "model_save", // TODO(typhoonzero): executor.go will save the working directory, should test later.
199+
198200
}
199201
var program bytes.Buffer
200202
var trainTemplate = template.Must(template.New("Train").Funcs(template.FuncMap{
@@ -210,7 +212,7 @@ func Train(ir *codegen.TrainIR) (string, error) {
210212
}
211213

212214
// Pred generates a Python program for predict using a TensorFlow model.
213-
func Pred(ir *codegen.PredictIR) (string, error) {
215+
func Pred(ir *codegen.PredictIR, session *pb.Session) (string, error) {
214216
modelParams := make(map[string]interface{})
215217
for attrKey, attr := range ir.TrainIR.Attributes {
216218
if strings.HasPrefix(attrKey, "model.") {
@@ -249,6 +251,10 @@ func Pred(ir *codegen.PredictIR) (string, error) {
249251
Y: ir.TrainIR.Label.GetFieldMeta()[0],
250252
ModelParams: modelParams,
251253
Save: "model_save",
254+
HDFSNameNodeAddr: session.HdfsNamenodeAddr,
255+
HiveLocation: session.HiveLocation,
256+
HDFSUser: session.HdfsUser,
257+
HDFSPass: session.HdfsPass,
252258
}
253259
var program bytes.Buffer
254260
var predTemplate = template.Must(template.New("Pred").Funcs(template.FuncMap{

pkg/sql/codegen/tensorflow/codegen_test.go

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,43 @@ package tensorflow
1515

1616
import (
1717
"fmt"
18+
"regexp"
1819
"testing"
1920

2021
"github.com/go-sql-driver/mysql"
2122
"github.com/stretchr/testify/assert"
23+
pb "sqlflow.org/sqlflow/pkg/server/proto"
2224
"sqlflow.org/sqlflow/pkg/sql/codegen"
2325
)
2426

2527
func TestTrainCodegen(t *testing.T) {
2628
a := assert.New(t)
29+
tir := mockTrainIR()
30+
_, err := Train(tir)
31+
a.NoError(err)
32+
33+
pir := mockPredIR(tir)
34+
35+
sess := &pb.Session{
36+
Token: "",
37+
DbConnStr: "",
38+
ExitOnSubmit: false,
39+
UserId: "",
40+
HiveLocation: "/sqlflowtmp",
41+
HdfsNamenodeAddr: "192.168.1.1:8020",
42+
HdfsUser: "sqlflow_admin",
43+
HdfsPass: "sqlflow_pass",
44+
}
45+
code, err := Pred(pir, sess)
46+
a.NoError(err)
2747

48+
r, _ := regexp.Compile(`hdfs_user="(.*)"`)
49+
a.Equal(r.FindStringSubmatch(code)[1], "sqlflow_admin")
50+
r, _ = regexp.Compile(`hdfs_pass="(.*)"`)
51+
a.Equal(r.FindStringSubmatch(code)[1], "sqlflow_pass")
52+
}
53+
54+
func mockTrainIR() *codegen.TrainIR {
2855
cfg := &mysql.Config{
2956
User: "root",
3057
Passwd: "root",
@@ -42,7 +69,7 @@ func TestTrainCodegen(t *testing.T) {
4269
COLUMN sepal_length, sepal_width, petal_length, petal_width
4370
LABEL class
4471
INTO sqlflow_models.my_xgboost_model;`
45-
ir := codegen.TrainIR{
72+
return &codegen.TrainIR{
4673
DataSource: fmt.Sprintf("mysql://%s", cfg.FormatDSN()),
4774
Select: "select * from iris.train;",
4875
ValidationSelect: "select * from iris.test;",
@@ -59,16 +86,14 @@ func TestTrainCodegen(t *testing.T) {
5986
&codegen.NumericColumn{&codegen.FieldMeta{"petal_length", codegen.Float, "", []int{1}, false, nil}},
6087
&codegen.NumericColumn{&codegen.FieldMeta{"petal_width", codegen.Float, "", []int{1}, false, nil}}}},
6188
Label: &codegen.NumericColumn{&codegen.FieldMeta{"class", codegen.Int, "", []int{1}, false, nil}}}
62-
_, err := Train(&ir)
63-
a.NoError(err)
89+
}
6490

65-
predIR := codegen.PredictIR{
66-
DataSource: fmt.Sprintf("mysql://%s", cfg.FormatDSN()),
91+
func mockPredIR(trainIR *codegen.TrainIR) *codegen.PredictIR {
92+
return &codegen.PredictIR{
93+
DataSource: trainIR.DataSource,
6794
Select: "select * from iris.test;",
6895
ResultTable: "iris.predict",
6996
Attributes: make(map[string]interface{}),
70-
TrainIR: &ir,
97+
TrainIR: trainIR,
7198
}
72-
_, err = Pred(&predIR)
73-
a.NoError(err)
7499
}

pkg/sql/codegen/tensorflow/template_pred.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ type predFiller struct {
2727
Y *codegen.FieldMeta
2828
ModelParams map[string]interface{}
2929
Save string
30+
HDFSNameNodeAddr string
31+
HiveLocation string
32+
HDFSUser string
33+
HDFSPass string
3034
}
3135

3236
const tfPredTemplateText = `
@@ -78,5 +82,9 @@ pred(is_keras_model="{{.IsKerasModel}}" == "true",
7882
label_meta=label_meta,
7983
model_params=model_params,
8084
save="{{.Save}}",
81-
batch_size=1)
85+
batch_size=1,
86+
hdfs_namenode_addr="{{.HDFSNameNodeAddr}}",
87+
hive_location="{{.HiveLocation}}",
88+
hdfs_user="{{.HDFSUser}}",
89+
hdfs_pass="{{.HDFSPass}}")
8290
`

pkg/sql/codegen/xgboost/codegen.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ import (
1919
"fmt"
2020
"strings"
2121

22-
"sqlflow.org/sqlflow/pkg/sql/codegen/attribute"
23-
22+
pb "sqlflow.org/sqlflow/pkg/server/proto"
2423
"sqlflow.org/sqlflow/pkg/sql/codegen"
24+
"sqlflow.org/sqlflow/pkg/sql/codegen/attribute"
2525
)
2626

2727
func newFloat32(f float32) *float32 {
@@ -151,7 +151,7 @@ func Train(ir *codegen.TrainIR) (string, error) {
151151
}
152152

153153
// Pred generates a Python program for predict a xgboost model.
154-
func Pred(ir *codegen.PredictIR) (string, error) {
154+
func Pred(ir *codegen.PredictIR, session *pb.Session) (string, error) {
155155
featureFieldMeta, labelFieldMeta, err := getFieldMeta(ir.TrainIR.Features["feature_columns"], ir.TrainIR.Label)
156156
f, err := json.Marshal(featureFieldMeta)
157157
if err != nil {
@@ -163,11 +163,15 @@ func Pred(ir *codegen.PredictIR) (string, error) {
163163
}
164164

165165
r := predFiller{
166-
DataSource: ir.DataSource,
167-
PredSelect: ir.Select,
168-
FeatureMetaJSON: string(f),
169-
LabelMetaJSON: string(l),
170-
ResultTable: ir.ResultTable,
166+
DataSource: ir.DataSource,
167+
PredSelect: ir.Select,
168+
FeatureMetaJSON: string(f),
169+
LabelMetaJSON: string(l),
170+
ResultTable: ir.ResultTable,
171+
HDFSNameNodeAddr: session.HdfsNamenodeAddr,
172+
HiveLocation: session.HiveLocation,
173+
HDFSUser: session.HdfsUser,
174+
HDFSPass: session.HdfsPass,
171175
}
172176

173177
var program bytes.Buffer

pkg/sql/codegen/xgboost/codegen_test.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ package xgboost
1515

1616
import (
1717
"fmt"
18+
"regexp"
1819
"testing"
1920

2021
"github.com/go-sql-driver/mysql"
2122
"github.com/stretchr/testify/assert"
23+
pb "sqlflow.org/sqlflow/pkg/server/proto"
2224
"sqlflow.org/sqlflow/pkg/sql/codegen"
2325
)
2426

@@ -29,7 +31,23 @@ func TestTrainAndPredict(t *testing.T) {
2931
a.NoError(err)
3032

3133
pir := mockPrdcIR(tir)
32-
_, err = Pred(pir)
34+
sess := &pb.Session{
35+
Token: "",
36+
DbConnStr: "",
37+
ExitOnSubmit: false,
38+
UserId: "",
39+
HiveLocation: "/sqlflowtmp",
40+
HdfsNamenodeAddr: "192.168.1.1:8020",
41+
HdfsUser: "sqlflow_admin",
42+
HdfsPass: "sqlflow_pass",
43+
}
44+
code, err := Pred(pir, sess)
45+
46+
r, _ := regexp.Compile(`hdfs_user="(.*)"`)
47+
a.Equal(r.FindStringSubmatch(code)[1], "sqlflow_admin")
48+
r, _ = regexp.Compile(`hdfs_pass="(.*)"`)
49+
a.Equal(r.FindStringSubmatch(code)[1], "sqlflow_pass")
50+
3351
a.NoError(err)
3452
}
3553

pkg/sql/codegen/xgboost/template_pred.go

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@ import (
1818
)
1919

2020
type predFiller struct {
21-
DataSource string
22-
PredSelect string
23-
FeatureMetaJSON string
24-
LabelMetaJSON string
25-
ResultTable string
21+
DataSource string
22+
PredSelect string
23+
FeatureMetaJSON string
24+
LabelMetaJSON string
25+
ResultTable string
26+
HDFSNameNodeAddr string
27+
HiveLocation string
28+
HDFSUser string
29+
HDFSPass string
2630
}
2731

2832
const predTemplateText = `
@@ -61,13 +65,20 @@ preds = bst.predict(dpred)
6165
if len(preds.shape) == 2:
6266
# classifier result
6367
preds = np.argmax(np.array(preds), axis=1)
64-
6568
feature_file_read = open("predict.txt", "r")
6669
6770
result_column_names = feature_column_names
6871
result_column_names.append(label_name)
6972
line_no = 0
70-
with buffered_db_writer(conn.driver, conn, "{{.ResultTable}}", result_column_names, 100) as w:
73+
with buffered_db_writer(conn.driver,
74+
conn,
75+
"{{.ResultTable}}",
76+
result_column_names,
77+
100,
78+
hdfs_namenode_addr="{{.HDFSNameNodeAddr}}",
79+
hive_location="{{.HiveLocation}}",
80+
hdfs_user="{{.HDFSUser}}",
81+
hdfs_pass="{{.HDFSPass}}") as w:
7182
while True:
7283
line = feature_file_read.readline()
7384
if not line:

pkg/sql/executor.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
411411
return e
412412
}
413413
var program bytes.Buffer
414-
if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) {
414+
if isXGBoostModel(tr.estimator) {
415415
// FIXME(weiguoz): Remove the condition after the codegen refactor
416416
if enableIR() {
417417
ir, err := generateTrainIR(tr, db.String())
@@ -497,24 +497,20 @@ func loadModelMeta(pr *extendedSelect, db *DB, cwd, modelDir, modelName string)
497497
return pr, fts, nil
498498
}
499499

500-
func enableIR() bool {
501-
return os.Getenv("SQLFLOW_codegen") == "ir"
502-
}
503-
504500
func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir string, session *pb.Session) error {
505501
pr, fts, e := loadModelMeta(pr, db, cwd, modelDir, pr.model)
506502
if e != nil {
507503
return fmt.Errorf("loadModelMeta %v", e)
508504
}
509505

510506
var buf bytes.Buffer
511-
if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) {
507+
if isXGBoostModel(pr.estimator) {
512508
if enableIR() {
513509
ir, err := generatePredictIR(pr, db.String(), cwd, modelDir)
514510
if err != nil {
515511
return err
516512
}
517-
code, err := xgboost.Pred(ir)
513+
code, err := xgboost.Pred(ir, session)
518514
if err != nil {
519515
return err
520516
}
@@ -534,7 +530,7 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
534530
if err != nil {
535531
return err
536532
}
537-
code, err := tensorflow.Pred(ir)
533+
code, err := tensorflow.Pred(ir, session)
538534
if err != nil {
539535
return err
540536
}
@@ -646,6 +642,15 @@ func createPredictionTable(predParsed *extendedSelect, db *DB, session *pb.Sessi
646642
return nil
647643
}
648644

645+
// -------------------------- utilities --------------------------------------
646+
func isXGBoostModel(estimator string) bool {
647+
return strings.HasPrefix(strings.ToUpper(estimator), `XGBOOST.`)
648+
}
649+
650+
func enableIR() bool {
651+
return os.Getenv("SQLFLOW_codegen") == "ir"
652+
}
653+
649654
func parseTableColumn(s string) (string, string, error) {
650655
pos := strings.LastIndex(s, ".")
651656
if pos == -1 || pos == len(s)-1 {

python/sqlflow_submitter/db.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,14 @@ def connect_with_data_source(driver_dsn):
8383
port=int(port))
8484
elif driver == "hive":
8585
from impala.dbapi import connect
86-
user, passwd, host, port, database, auth, session = parseHiveDSN(dsn)
86+
user, passwd, host, port, database, auth, session_cfg = parseHiveDSN(dsn)
8787
conn = connect(user=user,
8888
password=passwd,
8989
database=database,
9090
host=host,
9191
port=int(port),
9292
auth_mechanism=auth)
93-
conn.session = session
93+
conn.session_cfg = session_cfg
9494
elif driver == "maxcompute":
9595
from sqlflow_submitter.maxcompute import MaxCompute
9696
user, passwd, address, database = parseMaxComputeDSN(dsn)

0 commit comments

Comments (0)