Skip to content

Commit b8d751a

Browse files
authored
[Intermediate Representation] Add e2e test for IR analyze (#1061)
* add ut for ir.analyze * fix type * fix test
1 parent 31d6cb1 commit b8d751a

File tree

5 files changed

+82
-24
lines changed

5 files changed

+82
-24
lines changed

cmd/sqlflowserver/main_test.go

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ func TestEnd2EndMySQLIR(t *testing.T) {
326326
t.Run("CaseTrainRegression", CaseTrainRegression)
327327
t.Run("CaseTrainXGBoostRegressionIR", CaseTrainXGBoostRegression)
328328
t.Run("CasePredictXGBoostRegressionIR", CasePredictXGBoostRegression)
329+
t.Run("CaseAnalyzeXGBoostModel", CaseTrainAndAnalyzeXGBoostModel)
329330
}
330331

331332
func CaseTrainTextClassificationIR(t *testing.T) {
@@ -852,7 +853,7 @@ func CaseTrainALPSRemoteModel(t *testing.T) {
852853
FROM %s.sparse_column_test
853854
LIMIT 100
854855
TRAIN models.estimator.dnn_classifier.DNNClassifier
855-
WITH
856+
WITH
856857
model.n_classes = 2, model.hidden_units = [10, 20], train.batch_size = 10, engine.ps_num=0, engine.worker_num=0, engine.type=local,
857858
gitlab.project = "Alps/sqlflow-models",
858859
gitlab.source_root = python,
@@ -979,6 +980,51 @@ INTO sqlflow_models.my_xgb_regression_model;
979980
}
980981
}
981982

983+
// CaseTrainAndAnalyzeXGBoostModel is used to test training a xgboost model,
984+
// then analyze it
985+
func CaseTrainAndAnalyzeXGBoostModel(t *testing.T) {
986+
a := assert.New(t)
987+
trainStmt := `
988+
SELECT *
989+
FROM housing.train
990+
TRAIN xgboost.gbtree
991+
WITH
992+
objective="reg:squarederror",
993+
train.num_boost_round = 30
994+
COLUMN f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12,f13
995+
LABEL target
996+
INTO sqlflow_models.my_xgb_regression_model;
997+
`
998+
analyzeStmt := `
999+
SELECT *
1000+
FROM housing.train
1001+
ANALYZE sqlflow_models.my_xgb_regression_model
1002+
WITH
1003+
shap_summary.plot_type="bar",
1004+
shap_summary.alpha=1,
1005+
shap_summary.sort=True
1006+
USING TreeExplainer;
1007+
`
1008+
conn, err := createRPCConn()
1009+
a.NoError(err)
1010+
defer conn.Close()
1011+
cli := pb.NewSQLFlowClient(conn)
1012+
1013+
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
1014+
defer cancel()
1015+
1016+
stream, err := cli.Run(ctx, sqlRequest(trainStmt))
1017+
if err != nil {
1018+
a.Fail("Check if the server started successfully. %v", err)
1019+
}
1020+
ParseRow(stream)
1021+
stream, err = cli.Run(ctx, sqlRequest(analyzeStmt))
1022+
if err != nil {
1023+
a.Fail("Check if the server started successfully. %v", err)
1024+
}
1025+
ParseRow(stream)
1026+
}
1027+
9821028
func CasePredictXGBoostRegression(t *testing.T) {
9831029
a := assert.New(t)
9841030
predSQL := fmt.Sprintf(`SELECT *

pkg/sql/codegen/xgboost/codegen_analyze.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,16 @@ import (
2323
)
2424

2525
const (
26-
shapSummaryAttributes = "shap_summary"
26+
shapSummaryAttrPrefix = "shap_summary."
2727
)
2828

2929
// Analyze generates a Python program to analyze a trained model.
30-
func Analyze(ir *codegen.AnalyzeIR, modelPath string) (string, error) {
30+
func Analyze(ir *codegen.AnalyzeIR) (string, error) {
3131
if ir.Explainer != "TreeExplainer" {
3232
return "", fmt.Errorf("unsupported explainer %s", ir.Explainer)
3333
}
34-
summaryAttrs, err := resolveParams(ir.Attributes, shapSummaryAttributes)
34+
summaryAttrs := resolveParams(ir.Attributes, shapSummaryAttrPrefix)
35+
jsonSummary, err := json.Marshal(summaryAttrs)
3536
if err != nil {
3637
return "", err
3738
}
@@ -47,10 +48,9 @@ func Analyze(ir *codegen.AnalyzeIR, modelPath string) (string, error) {
4748
fr := &analyzeFiller{
4849
DataSource: ir.DataSource,
4950
DatasetSQL: ir.Select,
50-
ShapSummaryParames: summaryAttrs,
51+
ShapSummaryParames: string(jsonSummary),
5152
FieldMetaJSON: string(fm),
5253
Label: y.Name,
53-
ModelFile: modelPath,
5454
}
5555
var analysis bytes.Buffer
5656
if err := analyzeTemplate.Execute(&analysis, fr); err != nil {
@@ -59,12 +59,12 @@ func Analyze(ir *codegen.AnalyzeIR, modelPath string) (string, error) {
5959
return analysis.String(), nil
6060
}
6161

62-
func resolveParams(attrs map[string]interface{}, group string) (map[string]interface{}, error) {
62+
func resolveParams(attrs map[string]interface{}, group string) map[string]interface{} {
6363
sp := make(map[string]interface{})
6464
for k, v := range attrs {
6565
if strings.HasPrefix(k, group) {
6666
sp[k[len(group):]] = v
6767
}
6868
}
69-
return sp, nil
69+
return sp
7070
}

pkg/sql/codegen/xgboost/codegen_analyze_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,6 @@ func TestAnalyze(t *testing.T) {
3333
},
3434
TrainIR: tir,
3535
}
36-
_, err := Analyze(air, "")
36+
_, err := Analyze(air)
3737
a.NoError(err)
3838
}

pkg/sql/codegen/xgboost/template_analyze.go

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,15 @@ import (
2020
type analyzeFiller struct {
2121
DataSource string
2222
DatasetSQL string
23-
ShapSummaryParames map[string]interface{}
23+
ShapSummaryParames string
2424
FieldMetaJSON string
2525
Label string
26-
ModelFile string
2726
}
2827

2928
const analyzeTemplateText = `
3029
import xgboost
3130
import shap
32-
import json
31+
import json
3332
import matplotlib.pyplot as plt
3433
import pandas as pd
3534
@@ -42,12 +41,8 @@ feature_column_name = sorted([k["name"] for k in feature_field_meta])
4241
feature_spec = {k['name']: k for k in feature_field_meta}
4342
conn = connect_with_data_source('''{{.DataSource}}''')
4443
label_name = "{{.Label}}"
45-
model_path = "{{.ModelFile}}"
4644
47-
summaryAttrs = {}
48-
{{ range $k, $v := .ShapSummaryParames }}
49-
summaryAttrs["{{$k}}"] = {{$v}}
50-
{{end}}
45+
summaryAttrs = json.loads('''{{.ShapSummaryParames}}''')
5146
5247
def analyzer_dataset():
5348
stream = db_generator(conn.driver, conn, """{{.DatasetSQL}}""", feature_column_name, label_name, feature_spec)
@@ -62,7 +57,7 @@ def analyzer_dataset():
6257
6358
X,y = analyzer_dataset()
6459
bst = xgboost.Booster()
65-
bst.load_model(fname=model_path)
60+
bst.load_model("my_model")
6661
explainer = shap.TreeExplainer(bst)
6762
shap_values = explainer.shap_values(X)
6863
shap.summary_plot(shap_values, X, show=False, **summaryAttrs)

pkg/sql/executor.go

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -557,14 +557,31 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
557557
}
558558

559559
func analyze(wr *PipeWriter, pr *extendedSelect, db *DB, cwd, modelDir string) error {
560-
program, err := genAnalyzer(pr, db, cwd, modelDir)
561-
if err != nil {
562-
return err
563-
}
564560
cmd := exec.Command("python", "-u")
565561
cmd.Dir = cwd
566-
cmd.Stdin = program
567-
if _, err = cmd.CombinedOutput(); err != nil {
562+
if enableIR() {
563+
ir, err := generateAnalyzeIR(pr, db.String(), cwd, modelDir)
564+
if err != nil {
565+
return err
566+
}
567+
if !strings.HasPrefix(strings.ToUpper(ir.TrainIR.Estimator), `XGBOOST.`) {
568+
return fmt.Errorf("unsupported model %s", ir.TrainIR.Estimator)
569+
}
570+
code, err := xgboost.Analyze(ir)
571+
if err != nil {
572+
return err
573+
}
574+
var program bytes.Buffer
575+
program.WriteString(code)
576+
cmd.Stdin = &program
577+
} else {
578+
prog, err := genAnalyzer(pr, db, cwd, modelDir)
579+
if err != nil {
580+
return err
581+
}
582+
cmd.Stdin = prog
583+
}
584+
if _, err := cmd.CombinedOutput(); err != nil {
568585
return err
569586
}
570587

0 commit comments

Comments
 (0)