Skip to content
4 changes: 3 additions & 1 deletion pkg/sql/codegen/intermediate_representation.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ type AnalyzeIR struct {
// "select ... analyze ... with analyze.plot_type = "bar"",
// the Attributes will be {"analyze.plot_type": "bar"}
Attributes map[string]interface{}
// Explainer types. For example TreeExplainer.
Explainer string
// TrainIR is the TrainIR used for generating the training job of the corresponding model
TrainIR TrainIR
TrainIR *TrainIR
}
41 changes: 25 additions & 16 deletions pkg/sql/codegen_alps.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *tra

// TODO(joyyoj) read feature mapping table's name from table attributes.
// TODO(joyyoj) pr may contains partition.
fmap := columns.FeatureMap{pr.tables[0] + "_feature_map", ""}
fmap := columns.FeatureMap{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why we need to change these files. Have you used a code formatter?

Table: pr.tables[0] + "_feature_map",
Partition: ""}
var meta metadata
fields := make([]string, 0)
if db != nil {
Expand Down Expand Up @@ -670,22 +672,22 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c
shape[0] = len(fields)
if userSpec, ok := refColumns[ct.Name()]; ok {
output[ct.Name()] = &columns.ColumnSpec{
ct.Name(),
false,
shape,
userSpec.DType,
userSpec.Delimiter,
nil,
*meta.featureMap}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a linter complaining we need to explicitly name the field name? If so, please file the change in a separate PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The go vet complains: [govet] xxx(struct) composite literal uses unkeyed fields [E] while I'm editing in another file.
I found that there was not much code to be changed, so I changed them conveniently.

ColumnName: ct.Name(),
IsSparse: false,
Shape: shape,
DType: userSpec.DType,
Delimiter: userSpec.Delimiter,
Vocabulary: nil,
FeatureMap: *meta.featureMap}
} else {
output[ct.Name()] = &columns.ColumnSpec{
ct.Name(),
false,
shape,
"float",
",",
nil,
*meta.featureMap}
ColumnName: ct.Name(),
IsSparse: false,
Shape: shape,
DType: "float",
Delimiter: ",",
Vocabulary: nil,
FeatureMap: *meta.featureMap}
}
}
}
Expand Down Expand Up @@ -732,7 +734,14 @@ func (meta *metadata) getSparseColumnInfo() (map[string]*columns.ColumnSpec, err
column, present := output[*name]
if !present {
shape := make([]int, 0, 1000)
column := &columns.ColumnSpec{*name, true, shape, "int64", "", nil, *meta.featureMap}
column := &columns.ColumnSpec{
ColumnName: *name,
IsSparse: true,
Shape: shape,
DType: "int64",
Delimiter: "",
Vocabulary: nil,
FeatureMap: *meta.featureMap}
column.DType = "int64"
output[*name] = column
}
Expand Down
38 changes: 32 additions & 6 deletions pkg/sql/ir_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,27 @@ func generateTrainIR(slct *extendedSelect, connStr string) (*codegen.TrainIR, er
}, nil
}

func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDir string) (*codegen.PredictIR, error) {
attrMap, err := generateAttributeIR(&slct.predAttrs)
func generateTrainIRByModel(slct *extendedSelect, connStr, cwd, modelDir, model string) (*codegen.TrainIR, error) {
db, err := open(connStr)
if err != nil {
return nil, err
}
db, err := open(connStr)
defer db.Close()

slctWithTrain, _, err := loadModelMeta(slct, db, cwd, modelDir, model)
if err != nil {
return nil, err
}
slctWithTrain, _, err := loadModelMeta(slct, db, cwd, modelDir, slct.model)
return generateTrainIR(slctWithTrain, connStr)
}

func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDir string) (*codegen.PredictIR, error) {
attrMap, err := generateAttributeIR(&slct.predAttrs)
if err != nil {
return nil, err
}
trainir, err := generateTrainIR(slctWithTrain, connStr)

trainIR, err := generateTrainIRByModel(slct, connStr, cwd, modelDir, slct.model)
if err != nil {
return nil, err
}
Expand All @@ -94,7 +101,26 @@ func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDi
Select: slct.standardSelect.String(),
ResultTable: slct.into,
Attributes: attrMap,
TrainIR: trainir,
TrainIR: trainIR,
}, nil
}

func generateAnalyzeIR(slct *extendedSelect, connStr, cwd, modelDir string) (*codegen.AnalyzeIR, error) {
attrs, err := generateAttributeIR(&slct.analyzeAttrs)
if err != nil {
return nil, err
}

trainIR, err := generateTrainIRByModel(slct, connStr, cwd, modelDir, slct.trainedModel)
if err != nil {
return nil, err
}
return &codegen.AnalyzeIR{
DataSource: connStr,
Select: slct.standardSelect.String(),
Attributes: attrs,
Explainer: slct.explainer,
TrainIR: trainIR,
}, nil
}

Expand Down
56 changes: 56 additions & 0 deletions pkg/sql/ir_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,62 @@ INTO sqlflow_models.mymodel;`, testDB, modelDir, nil)
a.Equal("sepal_length", nc.FieldMeta.Name)
}

func TestGenerateAnalyzeIR(t *testing.T) {
if getEnv("SQLFLOW_TEST_DB", "mysql") != "mysql" {
t.Skip(fmt.Sprintf("%s: skip test", getEnv("SQLFLOW_TEST_DB", "mysql")))
}
a := assert.New(t)

modelDir, e := ioutil.TempDir("/tmp", "sqlflow_models")
a.Nil(e)
defer os.RemoveAll(modelDir)
stream := runExtendedSQL(`
SELECT *
FROM iris.train
TRAIN xgboost.gbtree
WITH
objective="multi:softprob",
train.num_boost_round = 30,
eta = 3.1,
num_class = 3
COLUMN sepal_length, sepal_width, petal_length, petal_width
LABEL class
INTO sqlflow_models.my_xgboost_model;
`, testDB, modelDir, nil)
a.True(goodStream(stream.ReadAll()))

// Test generate AnalyzeIR
cwd, e := ioutil.TempDir("/tmp", "sqlflow")
a.Nil(e)
defer os.RemoveAll(cwd)

pr, e := newParser().Parse(`
SELECT *
FROM iris.train
ANALYZE sqlflow_models.my_xgboost_model
WITH
shap_summary.plot_type="bar",
shap_summary.alpha=1,
shap_summary.sort=True
USING TreeExplainer;
`)
a.NoError(e)

connStr := "mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0"
ir, e := generateAnalyzeIR(pr, connStr, cwd, modelDir)
a.NoError(e)
a.Equal(ir.DataSource, connStr)
a.Equal(ir.Explainer, "TreeExplainer")
a.Equal(len(ir.Attributes), 3)
a.Equal(ir.Attributes["shap_summary.sort"], true)
a.Equal(ir.Attributes["shap_summary.plot_type"], "bar")
a.Equal(ir.Attributes["shap_summary.alpha"], 1)

nc, ok := ir.TrainIR.Features["feature_columns"][0].(*codegen.NumericColumn)
a.True(ok)
a.Equal("sepal_length", nc.FieldMeta.Name)
}

func TestInferStringValue(t *testing.T) {
a := assert.New(t)
for _, s := range []string{"true", "TRUE", "True"} {
Expand Down