Skip to content

Commit 3e8324c

Browse files
committed
add ut for generateAnalyzeIR
1 parent 1f8dad7 commit 3e8324c

File tree

3 files changed

+61
-16
lines changed

3 files changed

+61
-16
lines changed

pkg/sql/codegen/intermediate_representation.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,5 +120,5 @@ type AnalyzeIR struct {
120120
// SQLFlow supports TreeExplainer so far.
121121
Explainer string
122122
// TrainIR is the TrainIR used for generating the training job of the corresponding model
123-
TrainIR TrainIR
123+
TrainIR *TrainIR
124124
}

pkg/sql/ir_generator.go

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,20 +72,27 @@ func generateTrainIR(slct *extendedSelect, connStr string) (*codegen.TrainIR, er
7272
}, nil
7373
}
7474

75-
func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDir string) (*codegen.PredictIR, error) {
76-
attrMap, err := generateAttributeIR(&slct.predAttrs)
75+
func generateTrainIRByModel(slct *extendedSelect, connStr, cwd, modelDir string) (*codegen.TrainIR, error) {
76+
db, err := open(connStr)
7777
if err != nil {
7878
return nil, err
7979
}
80-
db, err := open(connStr)
80+
defer db.Close()
81+
82+
slctWithTrain, _, err := loadModelMeta(slct, db, cwd, modelDir, slct.trainedModel)
8183
if err != nil {
8284
return nil, err
8385
}
84-
slctWithTrain, _, err := loadModelMeta(slct, db, cwd, modelDir, slct.model)
86+
return generateTrainIR(slctWithTrain, connStr)
87+
}
88+
89+
func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDir string) (*codegen.PredictIR, error) {
90+
attrMap, err := generateAttributeIR(&slct.predAttrs)
8591
if err != nil {
8692
return nil, err
8793
}
88-
trainir, err := generateTrainIR(slctWithTrain, connStr)
94+
95+
trainIR, err := generateTrainIRByModel(slct, connStr, cwd, modelDir)
8996
if err != nil {
9097
return nil, err
9198
}
@@ -94,22 +101,26 @@ func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDi
94101
Select: slct.standardSelect.String(),
95102
ResultTable: slct.into,
96103
Attributes: attrMap,
97-
TrainIR: trainir,
104+
TrainIR: trainIR,
98105
}, nil
99106
}
100107

101-
func generateAnalyzeIR(slct *extendedSelect, connStr string) (*codegen.AnalyzeIR, error) {
108+
func generateAnalyzeIR(slct *extendedSelect, connStr, cwd, modelDir string) (*codegen.AnalyzeIR, error) {
102109
attrs, err := generateAttributeIR(&slct.analyzeAttrs)
103110
if err != nil {
104111
return nil, err
105112
}
113+
114+
trainIR, err := generateTrainIRByModel(slct, connStr, cwd, modelDir)
115+
if err != nil {
116+
return nil, err
117+
}
106118
return &codegen.AnalyzeIR{
107119
DataSource: connStr,
108120
Select: slct.standardSelect.String(),
109121
Attributes: attrs,
110122
Explainer: slct.explainer,
111-
// TrainIR is the TrainIR used for generating the training job of the corresponding model
112-
// TrainIR TrainIR
123+
TrainIR: trainIR,
113124
}, nil
114125
}
115126

pkg/sql/ir_generator_test.go

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,35 @@ INTO sqlflow_models.mymodel;`, testDB, modelDir, nil)
195195
}
196196

197197
func TestGenerateAnalyzeIR(t *testing.T) {
198+
if getEnv("SQLFLOW_TEST_DB", "mysql") != "mysql" {
199+
t.Skip(fmt.Sprintf("%s: skip test", getEnv("SQLFLOW_TEST_DB", "mysql")))
200+
}
198201
a := assert.New(t)
199-
stmt := `
202+
203+
modelDir, e := ioutil.TempDir("/tmp", "sqlflow_models")
204+
a.Nil(e)
205+
defer os.RemoveAll(modelDir)
206+
stream := runExtendedSQL(`
207+
SELECT *
208+
FROM iris.train
209+
TO TRAIN xgboost.gbtree
210+
WITH
211+
objective="multi:softprob",
212+
train.num_boost_round = 30,
213+
eta = 3.1,
214+
num_class = 3
215+
COLUMN sepal_length, sepal_width, petal_length, petal_width
216+
LABEL class
217+
INTO sqlflow_models.my_xgboost_model;
218+
`, testDB, modelDir, nil)
219+
a.True(goodStream(stream.ReadAll()))
220+
221+
// Test generate PredicrIR
222+
cwd, e := ioutil.TempDir("/tmp", "sqlflow")
223+
a.Nil(e)
224+
defer os.RemoveAll(cwd)
225+
226+
pr, e := newParser().Parse(`
200227
SELECT *
201228
FROM iris.train
202229
ANALYZE sqlflow_models.my_xgboost_model
@@ -205,13 +232,20 @@ func TestGenerateAnalyzeIR(t *testing.T) {
205232
shap_summary.alpha=1,
206233
shap_summary.sort=True
207234
USING TreeExplainer;
208-
`
209-
pr, e := newParser().Parse(stmt)
235+
`)
210236
a.NoError(e)
211237

212-
connStr := "mysql://root:root@tcp(localhost)"
213-
ir, e := generateAnalyzeIR(pr, connStr)
238+
connStr := "mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0"
239+
ir, e := generateAnalyzeIR(pr, connStr, cwd, modelDir)
214240
a.NoError(e)
215-
a.Equal(ir.Explainer, "TreeExplainer")
216241
a.Equal(ir.DataSource, connStr)
242+
a.Equal(ir.Explainer, "TreeExplainer")
243+
a.Equal(len(ir.Attributes), 3)
244+
a.Equal(ir.Attributes["shap_summary.sort"], "True")
245+
a.Equal(ir.Attributes["shap_summary.plot_type"], "bar")
246+
a.Equal(ir.Attributes["shap_summary.alpha"], 1)
247+
248+
nc, ok := ir.TrainIR.Features["feature_columns"][0].(*codegen.NumericColumn)
249+
a.True(ok)
250+
a.Equal("sepal_length", nc.FieldMeta.Name)
217251
}

0 commit comments

Comments
 (0)