Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion cmd/sqlflowserver/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ func TestEnd2EndMySQLIR(t *testing.T) {
t.Run("CaseTrainRegression", CaseTrainRegression)
t.Run("CaseTrainXGBoostRegressionIR", CaseTrainXGBoostRegression)
t.Run("CasePredictXGBoostRegressionIR", CasePredictXGBoostRegression)
t.Run("CaseAnalyzeXGBoostModel", CaseTrainAndAnalyzeXGBoostModel)
}

func TestEnd2EndHive(t *testing.T) {
Expand Down Expand Up @@ -912,7 +913,7 @@ func CaseTrainALPSRemoteModel(t *testing.T) {
FROM %s.sparse_column_test
LIMIT 100
TRAIN models.estimator.dnn_classifier.DNNClassifier
WITH
WITH
model.n_classes = 2, model.hidden_units = [10, 20], train.batch_size = 10, engine.ps_num=0, engine.worker_num=0, engine.type=local,
gitlab.project = "Alps/sqlflow-models",
gitlab.source_root = python,
Expand Down Expand Up @@ -1088,6 +1089,51 @@ INTO sqlflow_models.my_xgb_regression_model;
ParseRow(stream)
}

// CaseTrainAndAnalyzeXGBoostModel is used to test training a xgboost model,
// then analyze it
func CaseTrainAndAnalyzeXGBoostModel(t *testing.T) {
a := assert.New(t)
trainStmt := `
SELECT *
FROM housing.train
TRAIN xgboost.gbtree
WITH
objective="reg:squarederror",
train.num_boost_round = 30
COLUMN f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12,f13
LABEL target
INTO sqlflow_models.my_xgb_regression_model;
`
analyzeStmt := `
SELECT *
FROM housing.train
ANALYZE sqlflow_models.my_xgb_regression_model
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please note that go test may run test cases unordered, the model may be not generated when running this test, it's better to generate this model first before running ANALYZE.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ref introduces a go test run parallel with other tests if and only if the t.Parallel() was called.

Copy link
Collaborator

@typhoonzero typhoonzero Oct 24, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not about parallelism, it's about the order, see https://stackoverflow.com/questions/31201858/how-to-run-golang-tests-sequentially, it says:

You can't / shouldn't rely on test execution order. The order in which tests are executed is not defined, and with the use of testing flags it is possible to exclude tests from running, so you have no guarantee that they will run at all.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In our e2e test, the subtests are ordered by group

t.Run("CaseTrainXGBoostRegressionIR", CaseTrainXGBoostRegression)
t.Run("CasePredictXGBoostRegressionIR", CasePredictXGBoostRegression)
t.Run("CaseAnalyzeXGBoostModel", CaseTrainAndAnalyzeXGBoostModel)

But we would run go test -run CaseTrainAndAnalyzeXGBoostModel independently some times, it needs an existed model, so I add the several codes for that.

WITH
shap_summary.plot_type="bar",
shap_summary.alpha=1,
shap_summary.sort=True
USING TreeExplainer;
`
conn, err := createRPCConn()
a.NoError(err)
defer conn.Close()
cli := pb.NewSQLFlowClient(conn)

ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
defer cancel()

stream, err := cli.Run(ctx, sqlRequest(trainStmt))
if err != nil {
a.Fail("Check if the server started successfully. %v", err)
}
ParseRow(stream)
stream, err = cli.Run(ctx, sqlRequest(analyzeStmt))
if err != nil {
a.Fail("Check if the server started successfully. %v", err)
}
ParseRow(stream)
}

func CasePredictXGBoostRegression(t *testing.T) {
a := assert.New(t)
conn, err := createRPCConn()
Expand Down
14 changes: 7 additions & 7 deletions pkg/sql/codegen/xgboost/codegen_analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@ import (
)

const (
shapSummaryAttributes = "shap_summary"
shapSummaryAttrPrefix = "shap_summary."
)

// Analyze generates a Python program to analyze a trained model.
func Analyze(ir *codegen.AnalyzeIR, modelPath string) (string, error) {
func Analyze(ir *codegen.AnalyzeIR) (string, error) {
if ir.Explainer != "TreeExplainer" {
return "", fmt.Errorf("unsupported explainer %s", ir.Explainer)
}
summaryAttrs, err := resolveParams(ir.Attributes, shapSummaryAttributes)
summaryAttrs := resolveParams(ir.Attributes, shapSummaryAttrPrefix)
jsonSummary, err := json.Marshal(summaryAttrs)
if err != nil {
return "", err
}
Expand All @@ -47,10 +48,9 @@ func Analyze(ir *codegen.AnalyzeIR, modelPath string) (string, error) {
fr := &analyzeFiller{
DataSource: ir.DataSource,
DatasetSQL: ir.Select,
ShapSummaryParames: summaryAttrs,
ShapSummaryParames: string(jsonSummary),
FieldMetaJSON: string(fm),
Label: y.Name,
ModelFile: modelPath,
}
var analysis bytes.Buffer
if err := analyzeTemplate.Execute(&analysis, fr); err != nil {
Expand All @@ -59,12 +59,12 @@ func Analyze(ir *codegen.AnalyzeIR, modelPath string) (string, error) {
return analysis.String(), nil
}

func resolveParams(attrs map[string]interface{}, group string) (map[string]interface{}, error) {
func resolveParams(attrs map[string]interface{}, group string) map[string]interface{} {
sp := make(map[string]interface{})
for k, v := range attrs {
if strings.HasPrefix(k, group) {
sp[k[len(group):]] = v
}
}
return sp, nil
return sp
}
2 changes: 1 addition & 1 deletion pkg/sql/codegen/xgboost/codegen_analyze_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ func TestAnalyze(t *testing.T) {
},
TrainIR: tir,
}
_, err := Analyze(air, "")
_, err := Analyze(air)
a.NoError(err)
}
13 changes: 4 additions & 9 deletions pkg/sql/codegen/xgboost/template_analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@ import (
type analyzeFiller struct {
DataSource string
DatasetSQL string
ShapSummaryParames map[string]interface{}
ShapSummaryParames string
FieldMetaJSON string
Label string
ModelFile string
}

const analyzeTemplateText = `
import xgboost
import shap
import json
import json
import matplotlib.pyplot as plt
import pandas as pd

Expand All @@ -42,12 +41,8 @@ feature_column_name = sorted([k["name"] for k in feature_field_meta])
feature_spec = {k['name']: k for k in feature_field_meta}
conn = connect_with_data_source('''{{.DataSource}}''')
label_name = "{{.Label}}"
model_path = "{{.ModelFile}}"

summaryAttrs = {}
{{ range $k, $v := .ShapSummaryParames }}
summaryAttrs["{{$k}}"] = {{$v}}
{{end}}
summaryAttrs = json.loads('''{{.ShapSummaryParames}}''')

def analyzer_dataset():
stream = db_generator(conn.driver, conn, """{{.DatasetSQL}}""", feature_column_name, label_name, feature_spec)
Expand All @@ -62,7 +57,7 @@ def analyzer_dataset():

X,y = analyzer_dataset()
bst = xgboost.Booster()
bst.load_model(fname=model_path)
bst.load_model("my_model")
explainer = shap.TreeExplainer(bst)
shap_values = explainer.shap_values(X)
shap.summary_plot(shap_values, X, show=False, **summaryAttrs)
Expand Down
29 changes: 23 additions & 6 deletions pkg/sql/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,14 +561,31 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
}

func analyze(wr *PipeWriter, pr *extendedSelect, db *DB, cwd, modelDir string) error {
program, err := genAnalyzer(pr, db, cwd, modelDir)
if err != nil {
return err
}
cmd := exec.Command("python", "-u")
cmd.Dir = cwd
cmd.Stdin = program
if _, err = cmd.CombinedOutput(); err != nil {
if enableIR() {
ir, err := generateAnalyzeIR(pr, db.String(), cwd, modelDir)
if err != nil {
return err
}
if !strings.HasPrefix(strings.ToUpper(ir.TrainIR.Estimator), `XGBOOST.`) {
return fmt.Errorf("unsupported model %s", ir.TrainIR.Estimator)
}
code, err := xgboost.Analyze(ir)
if err != nil {
return err
}
var program bytes.Buffer
program.WriteString(code)
cmd.Stdin = &program
} else {
prog, err := genAnalyzer(pr, db, cwd, modelDir)
if err != nil {
return err
}
cmd.Stdin = prog
}
if _, err := cmd.CombinedOutput(); err != nil {
return err
}

Expand Down