@@ -15,16 +15,43 @@ package tensorflow
1515
1616import (
1717 "fmt"
18+ "regexp"
1819 "testing"
1920
2021 "github.com/go-sql-driver/mysql"
2122 "github.com/stretchr/testify/assert"
23+ pb "sqlflow.org/sqlflow/pkg/server/proto"
2224 "sqlflow.org/sqlflow/pkg/sql/codegen"
2325)
2426
2527func TestTrainCodegen (t * testing.T ) {
2628 a := assert .New (t )
29+ tir := mockTrainIR ()
30+ _ , err := Train (tir )
31+ a .NoError (err )
32+
33+ pir := mockPredIR (tir )
34+
35+ sess := & pb.Session {
36+ Token : "" ,
37+ DbConnStr : "" ,
38+ ExitOnSubmit : false ,
39+ UserId : "" ,
40+ HiveLocation : "/sqlflowtmp" ,
41+ HdfsNamenodeAddr : "192.168.1.1:8020" ,
42+ HdfsUser : "sqlflow_admin" ,
43+ HdfsPass : "sqlflow_pass" ,
44+ }
45+ code , err := Pred (pir , sess )
46+ a .NoError (err )
2747
48+ r , _ := regexp .Compile (`hdfs_user="(.*)"` )
49+ a .Equal (r .FindStringSubmatch (code )[1 ], "sqlflow_admin" )
50+ r , _ = regexp .Compile (`hdfs_pass="(.*)"` )
51+ a .Equal (r .FindStringSubmatch (code )[1 ], "sqlflow_pass" )
52+ }
53+
54+ func mockTrainIR () * codegen.TrainIR {
2855 cfg := & mysql.Config {
2956 User : "root" ,
3057 Passwd : "root" ,
@@ -42,7 +69,7 @@ func TestTrainCodegen(t *testing.T) {
4269 COLUMN sepal_length, sepal_width, petal_length, petal_width
4370 LABEL class
4471 INTO sqlflow_models.my_xgboost_model;`
45- ir := codegen.TrainIR {
72+ return & codegen.TrainIR {
4673 DataSource : fmt .Sprintf ("mysql://%s" , cfg .FormatDSN ()),
4774 Select : "select * from iris.train;" ,
4875 ValidationSelect : "select * from iris.test;" ,
@@ -59,16 +86,14 @@ func TestTrainCodegen(t *testing.T) {
5986 & codegen.NumericColumn {& codegen.FieldMeta {"petal_length" , codegen .Float , "" , []int {1 }, false , nil }},
6087 & codegen.NumericColumn {& codegen.FieldMeta {"petal_width" , codegen .Float , "" , []int {1 }, false , nil }}}},
6188 Label : & codegen.NumericColumn {& codegen.FieldMeta {"class" , codegen .Int , "" , []int {1 }, false , nil }}}
62- _ , err := Train (& ir )
63- a .NoError (err )
89+ }
6490
65- predIR := codegen.PredictIR {
66- DataSource : fmt .Sprintf ("mysql://%s" , cfg .FormatDSN ()),
91+ func mockPredIR (trainIR * codegen.TrainIR ) * codegen.PredictIR {
92+ return & codegen.PredictIR {
93+ DataSource : trainIR .DataSource ,
6794 Select : "select * from iris.test;" ,
6895 ResultTable : "iris.predict" ,
6996 Attributes : make (map [string ]interface {}),
70- TrainIR : & ir ,
97+ TrainIR : trainIR ,
7198 }
72- _ , err = Pred (& predIR )
73- a .NoError (err )
7499}
0 commit comments