diff --git a/pkg/sql/codegen/intermediate_representation.go b/pkg/sql/codegen/intermediate_representation.go index 9a481d07bb..a33c493637 100644 --- a/pkg/sql/codegen/intermediate_representation.go +++ b/pkg/sql/codegen/intermediate_representation.go @@ -117,6 +117,8 @@ type AnalyzeIR struct { // "select ... analyze ... with analyze.plot_type = "bar"", // the Attributes will be {"analyze.plot_type": "bar"} Attributes map[string]interface{} + // Explainer types. For example TreeExplainer. + Explainer string // TrainIR is the TrainIR used for generating the training job of the corresponding model - TrainIR TrainIR + TrainIR *TrainIR } diff --git a/pkg/sql/codegen_alps.go b/pkg/sql/codegen_alps.go index 237c4b067a..4979dae8a9 100644 --- a/pkg/sql/codegen_alps.go +++ b/pkg/sql/codegen_alps.go @@ -188,7 +188,9 @@ func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *tra // TODO(joyyoj) read feature mapping table's name from table attributes. // TODO(joyyoj) pr may contains partition. - fmap := columns.FeatureMap{pr.tables[0] + "_feature_map", ""} + fmap := columns.FeatureMap{ + Table: pr.tables[0] + "_feature_map", + Partition: ""} var meta metadata fields := make([]string, 0) if db != nil { @@ -670,22 +672,22 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c shape[0] = len(fields) if userSpec, ok := refColumns[ct.Name()]; ok { output[ct.Name()] = &columns.ColumnSpec{ - ct.Name(), - false, - shape, - userSpec.DType, - userSpec.Delimiter, - nil, - *meta.featureMap} + ColumnName: ct.Name(), + IsSparse: false, + Shape: shape, + DType: userSpec.DType, + Delimiter: userSpec.Delimiter, + Vocabulary: nil, + FeatureMap: *meta.featureMap} } else { output[ct.Name()] = &columns.ColumnSpec{ - ct.Name(), - false, - shape, - "float", - ",", - nil, - *meta.featureMap} + ColumnName: ct.Name(), + IsSparse: false, + Shape: shape, + DType: "float", + Delimiter: ",", + Vocabulary: nil, + FeatureMap: *meta.featureMap} } } } @@ -732,7 +734,14 @@ func (meta *metadata) getSparseColumnInfo() (map[string]*columns.ColumnSpec, err column, present := output[*name] if !present { shape := make([]int, 0, 1000) - column := &columns.ColumnSpec{*name, true, shape, "int64", "", nil, *meta.featureMap} + column := &columns.ColumnSpec{ + ColumnName: *name, + IsSparse: true, + Shape: shape, + DType: "int64", + Delimiter: "", + Vocabulary: nil, + FeatureMap: *meta.featureMap} column.DType = "int64" output[*name] = column } diff --git a/pkg/sql/ir_generator.go b/pkg/sql/ir_generator.go index 5bfa22db01..25e1ff2394 100644 --- a/pkg/sql/ir_generator.go +++ b/pkg/sql/ir_generator.go @@ -72,20 +72,27 @@ func generateTrainIR(slct *extendedSelect, connStr string) (*codegen.TrainIR, er }, nil } -func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDir string) (*codegen.PredictIR, error) { - attrMap, err := generateAttributeIR(&slct.predAttrs) +func generateTrainIRByModel(slct *extendedSelect, connStr, cwd, modelDir, model string) (*codegen.TrainIR, error) { + db, err := open(connStr) if err != nil { return nil, err } - db, err := open(connStr) + defer db.Close() + + slctWithTrain, _, err := loadModelMeta(slct, db, cwd, modelDir, model) if err != nil { return nil, err } - slctWithTrain, _, err := loadModelMeta(slct, db, cwd, modelDir, slct.model) + return generateTrainIR(slctWithTrain, connStr) +} + +func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDir string) (*codegen.PredictIR, error) { + attrMap, err := generateAttributeIR(&slct.predAttrs) if err != nil { return nil, err } - trainir, err := generateTrainIR(slctWithTrain, connStr) + + trainIR, err := generateTrainIRByModel(slct, connStr, cwd, modelDir, slct.model) if err != nil { return nil, err } @@ -94,7 +101,26 @@ func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDi Select: slct.standardSelect.String(), ResultTable: slct.into, Attributes: attrMap, - TrainIR: trainir, + TrainIR: trainIR, + }, nil +} + +func generateAnalyzeIR(slct *extendedSelect, connStr, cwd, modelDir string) (*codegen.AnalyzeIR, error) { + attrs, err := generateAttributeIR(&slct.analyzeAttrs) + if err != nil { + return nil, err + } + + trainIR, err := generateTrainIRByModel(slct, connStr, cwd, modelDir, slct.trainedModel) + if err != nil { + return nil, err + } + return &codegen.AnalyzeIR{ + DataSource: connStr, + Select: slct.standardSelect.String(), + Attributes: attrs, + Explainer: slct.explainer, + TrainIR: trainIR, }, nil } diff --git a/pkg/sql/ir_generator_test.go b/pkg/sql/ir_generator_test.go index bb6b5b5ee1..9e2517ba39 100644 --- a/pkg/sql/ir_generator_test.go +++ b/pkg/sql/ir_generator_test.go @@ -194,6 +194,62 @@ INTO sqlflow_models.mymodel;`, testDB, modelDir, nil) a.Equal("sepal_length", nc.FieldMeta.Name) } +func TestGenerateAnalyzeIR(t *testing.T) { + if getEnv("SQLFLOW_TEST_DB", "mysql") != "mysql" { + t.Skip(fmt.Sprintf("%s: skip test", getEnv("SQLFLOW_TEST_DB", "mysql"))) + } + a := assert.New(t) + + modelDir, e := ioutil.TempDir("/tmp", "sqlflow_models") + a.Nil(e) + defer os.RemoveAll(modelDir) + stream := runExtendedSQL(` + SELECT * + FROM iris.train + TRAIN xgboost.gbtree + WITH + objective="multi:softprob", + train.num_boost_round = 30, + eta = 3.1, + num_class = 3 + COLUMN sepal_length, sepal_width, petal_length, petal_width + LABEL class + INTO sqlflow_models.my_xgboost_model; + `, testDB, modelDir, nil) + a.True(goodStream(stream.ReadAll())) + + // Test generate AnalyzeIR + cwd, e := ioutil.TempDir("/tmp", "sqlflow") + a.Nil(e) + defer os.RemoveAll(cwd) + + pr, e := newParser().Parse(` + SELECT * + FROM iris.train + ANALYZE sqlflow_models.my_xgboost_model + WITH + shap_summary.plot_type="bar", + shap_summary.alpha=1, + shap_summary.sort=True + USING TreeExplainer; + `) + a.NoError(e) + + connStr := "mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0" + ir, e := generateAnalyzeIR(pr, connStr, cwd, modelDir) + a.NoError(e) + a.Equal(ir.DataSource, connStr) + a.Equal(ir.Explainer, "TreeExplainer") + a.Equal(len(ir.Attributes), 3) + a.Equal(ir.Attributes["shap_summary.sort"], true) + a.Equal(ir.Attributes["shap_summary.plot_type"], "bar") + a.Equal(ir.Attributes["shap_summary.alpha"], 1) + + nc, ok := ir.TrainIR.Features["feature_columns"][0].(*codegen.NumericColumn) + a.True(ok) + a.Equal("sepal_length", nc.FieldMeta.Name) +} + func TestInferStringValue(t *testing.T) { a := assert.New(t) for _, s := range []string{"true", "TRUE", "True"} {