Skip to content

Commit f91555e

Browse files
committed
add test for analyzeIR
1 parent f6f37f1 commit f91555e

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

pkg/sql/ir_generator.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,21 @@ func generateTrainIR(slct *extendedSelect, connStr string) (*codegen.TrainIR, er
7272
}, nil
7373
}
7474

75+
func generateAnalyzeIR(slct *extendedSelect, connStr string) (*codegen.AnalyzeIR, error) {
76+
attrs, err := generateAttributeIR(&slct.analyzeAttrs)
77+
if err != nil {
78+
return nil, err
79+
}
80+
return &codegen.AnalyzeIR{
81+
DataSource: connStr,
82+
Select: slct.standardSelect.String(),
83+
Attributes: attrs,
84+
Explainer: slct.explainer,
85+
// TrainIR is the TrainIR used for generating the training job of the corresponding model
86+
// TrainIR TrainIR
87+
}, nil
88+
}
89+
7590
func generateAttributeIR(attrs *attrs) (map[string]interface{}, error) {
7691
ret := make(map[string]interface{})
7792
for k, v := range *attrs {

pkg/sql/ir_generator_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,25 @@ INTO mymodel;`
148148
a.Equal(10000, catCol.FieldMeta.Shape[0])
149149
a.Equal(",", catCol.FieldMeta.Delimiter)
150150
}
151+
152+
func TestGenerateAnalyzeIR(t *testing.T) {
153+
a := assert.New(t)
154+
stmt := `
155+
SELECT *
156+
FROM iris.train
157+
ANALYZE sqlflow_models.my_xgboost_model
158+
WITH
159+
shap_summary.plot_type="bar",
160+
shap_summary.alpha=1,
161+
shap_summary.sort=True
162+
USING TreeExplainer;
163+
`
164+
pr, e := newParser().Parse(stmt)
165+
a.NoError(e)
166+
167+
connStr := "mysql://root:root@tcp(localhost)"
168+
ir, e := generateAnalyzeIR(pr, connStr)
169+
a.NoError(e)
170+
a.Equal(ir.Explainer, "TreeExplainer")
171+
a.Equal(ir.DataSource, connStr)
172+
}

0 commit comments

Comments
 (0)