Skip to content

Commit af50695

Browse files
authored
[Intermediate Representation] Generate AnalyzeIR (#997)
* Fix the lint complaint "GOLANG: composite literal uses unkeyed fields" by using keyed struct literals * Add the Explainer field to AnalyzeIR * Add a test for AnalyzeIR * Add a unit test for generateAnalyzeIR * Remove TO before TRAIN/ANALYZE in the unit tests
1 parent e60790f commit af50695

File tree

4 files changed

+116
-23
lines changed

4 files changed

+116
-23
lines changed

pkg/sql/codegen/intermediate_representation.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ type AnalyzeIR struct {
117117
// "select ... analyze ... with analyze.plot_type = "bar"",
118118
// the Attributes will be {"analyze.plot_type": "bar"}
119119
Attributes map[string]interface{}
120+
// Explainer is the type of explainer to use, for example "TreeExplainer".
121+
Explainer string
120122
// TrainIR is the TrainIR used for generating the training job of the corresponding model
121-
TrainIR TrainIR
123+
TrainIR *TrainIR
122124
}

pkg/sql/codegen_alps.go

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,9 @@ func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *tra
188188

189189
// TODO(joyyoj) read feature mapping table's name from table attributes.
190190
// TODO(joyyoj) pr may contain a partition.
191-
fmap := columns.FeatureMap{pr.tables[0] + "_feature_map", ""}
191+
fmap := columns.FeatureMap{
192+
Table: pr.tables[0] + "_feature_map",
193+
Partition: ""}
192194
var meta metadata
193195
fields := make([]string, 0)
194196
if db != nil {
@@ -670,22 +672,22 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c
670672
shape[0] = len(fields)
671673
if userSpec, ok := refColumns[ct.Name()]; ok {
672674
output[ct.Name()] = &columns.ColumnSpec{
673-
ct.Name(),
674-
false,
675-
shape,
676-
userSpec.DType,
677-
userSpec.Delimiter,
678-
nil,
679-
*meta.featureMap}
675+
ColumnName: ct.Name(),
676+
IsSparse: false,
677+
Shape: shape,
678+
DType: userSpec.DType,
679+
Delimiter: userSpec.Delimiter,
680+
Vocabulary: nil,
681+
FeatureMap: *meta.featureMap}
680682
} else {
681683
output[ct.Name()] = &columns.ColumnSpec{
682-
ct.Name(),
683-
false,
684-
shape,
685-
"float",
686-
",",
687-
nil,
688-
*meta.featureMap}
684+
ColumnName: ct.Name(),
685+
IsSparse: false,
686+
Shape: shape,
687+
DType: "float",
688+
Delimiter: ",",
689+
Vocabulary: nil,
690+
FeatureMap: *meta.featureMap}
689691
}
690692
}
691693
}
@@ -732,7 +734,14 @@ func (meta *metadata) getSparseColumnInfo() (map[string]*columns.ColumnSpec, err
732734
column, present := output[*name]
733735
if !present {
734736
shape := make([]int, 0, 1000)
735-
column := &columns.ColumnSpec{*name, true, shape, "int64", "", nil, *meta.featureMap}
737+
column := &columns.ColumnSpec{
738+
ColumnName: *name,
739+
IsSparse: true,
740+
Shape: shape,
741+
DType: "int64",
742+
Delimiter: "",
743+
Vocabulary: nil,
744+
FeatureMap: *meta.featureMap}
736745
column.DType = "int64"
737746
output[*name] = column
738747
}

pkg/sql/ir_generator.go

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,20 +72,27 @@ func generateTrainIR(slct *extendedSelect, connStr string) (*codegen.TrainIR, er
7272
}, nil
7373
}
7474

75-
func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDir string) (*codegen.PredictIR, error) {
76-
attrMap, err := generateAttributeIR(&slct.predAttrs)
75+
func generateTrainIRByModel(slct *extendedSelect, connStr, cwd, modelDir, model string) (*codegen.TrainIR, error) {
76+
db, err := open(connStr)
7777
if err != nil {
7878
return nil, err
7979
}
80-
db, err := open(connStr)
80+
defer db.Close()
81+
82+
slctWithTrain, _, err := loadModelMeta(slct, db, cwd, modelDir, model)
8183
if err != nil {
8284
return nil, err
8385
}
84-
slctWithTrain, _, err := loadModelMeta(slct, db, cwd, modelDir, slct.model)
86+
return generateTrainIR(slctWithTrain, connStr)
87+
}
88+
89+
func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDir string) (*codegen.PredictIR, error) {
90+
attrMap, err := generateAttributeIR(&slct.predAttrs)
8591
if err != nil {
8692
return nil, err
8793
}
88-
trainir, err := generateTrainIR(slctWithTrain, connStr)
94+
95+
trainIR, err := generateTrainIRByModel(slct, connStr, cwd, modelDir, slct.model)
8996
if err != nil {
9097
return nil, err
9198
}
@@ -94,7 +101,26 @@ func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDi
94101
Select: slct.standardSelect.String(),
95102
ResultTable: slct.into,
96103
Attributes: attrMap,
97-
TrainIR: trainir,
104+
TrainIR: trainIR,
105+
}, nil
106+
}
107+
108+
func generateAnalyzeIR(slct *extendedSelect, connStr, cwd, modelDir string) (*codegen.AnalyzeIR, error) {
109+
attrs, err := generateAttributeIR(&slct.analyzeAttrs)
110+
if err != nil {
111+
return nil, err
112+
}
113+
114+
trainIR, err := generateTrainIRByModel(slct, connStr, cwd, modelDir, slct.trainedModel)
115+
if err != nil {
116+
return nil, err
117+
}
118+
return &codegen.AnalyzeIR{
119+
DataSource: connStr,
120+
Select: slct.standardSelect.String(),
121+
Attributes: attrs,
122+
Explainer: slct.explainer,
123+
TrainIR: trainIR,
98124
}, nil
99125
}
100126

pkg/sql/ir_generator_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,62 @@ INTO sqlflow_models.mymodel;`, testDB, modelDir, nil)
194194
a.Equal("sepal_length", nc.FieldMeta.Name)
195195
}
196196

197+
func TestGenerateAnalyzeIR(t *testing.T) {
198+
if getEnv("SQLFLOW_TEST_DB", "mysql") != "mysql" {
199+
t.Skip(fmt.Sprintf("%s: skip test", getEnv("SQLFLOW_TEST_DB", "mysql")))
200+
}
201+
a := assert.New(t)
202+
203+
modelDir, e := ioutil.TempDir("/tmp", "sqlflow_models")
204+
a.Nil(e)
205+
defer os.RemoveAll(modelDir)
206+
stream := runExtendedSQL(`
207+
SELECT *
208+
FROM iris.train
209+
TRAIN xgboost.gbtree
210+
WITH
211+
objective="multi:softprob",
212+
train.num_boost_round = 30,
213+
eta = 3.1,
214+
num_class = 3
215+
COLUMN sepal_length, sepal_width, petal_length, petal_width
216+
LABEL class
217+
INTO sqlflow_models.my_xgboost_model;
218+
`, testDB, modelDir, nil)
219+
a.True(goodStream(stream.ReadAll()))
220+
221+
// Test generating the AnalyzeIR from an ANALYZE statement.
222+
cwd, e := ioutil.TempDir("/tmp", "sqlflow")
223+
a.Nil(e)
224+
defer os.RemoveAll(cwd)
225+
226+
pr, e := newParser().Parse(`
227+
SELECT *
228+
FROM iris.train
229+
ANALYZE sqlflow_models.my_xgboost_model
230+
WITH
231+
shap_summary.plot_type="bar",
232+
shap_summary.alpha=1,
233+
shap_summary.sort=True
234+
USING TreeExplainer;
235+
`)
236+
a.NoError(e)
237+
238+
connStr := "mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0"
239+
ir, e := generateAnalyzeIR(pr, connStr, cwd, modelDir)
240+
a.NoError(e)
241+
a.Equal(ir.DataSource, connStr)
242+
a.Equal(ir.Explainer, "TreeExplainer")
243+
a.Equal(len(ir.Attributes), 3)
244+
a.Equal(ir.Attributes["shap_summary.sort"], true)
245+
a.Equal(ir.Attributes["shap_summary.plot_type"], "bar")
246+
a.Equal(ir.Attributes["shap_summary.alpha"], 1)
247+
248+
nc, ok := ir.TrainIR.Features["feature_columns"][0].(*codegen.NumericColumn)
249+
a.True(ok)
250+
a.Equal("sepal_length", nc.FieldMeta.Name)
251+
}
252+
197253
func TestInferStringValue(t *testing.T) {
198254
a := assert.New(t)
199255
for _, s := range []string{"true", "TRUE", "True"} {

0 commit comments

Comments
 (0)