From d0ab13b4dd29cbd6c1f0989979fcf88d7744321c Mon Sep 17 00:00:00 2001 From: w7u Date: Wed, 23 Oct 2019 16:26:06 +0800 Subject: [PATCH 1/4] add ut for ir.analyze --- cmd/sqlflowserver/main_test.go | 31 ++++++++++++++++++++- pkg/sql/codegen/xgboost/codegen_analyze.go | 5 ++-- pkg/sql/codegen/xgboost/template_analyze.go | 4 +-- pkg/sql/executor.go | 30 ++++++++++++++++---- 4 files changed, 57 insertions(+), 13 deletions(-) diff --git a/cmd/sqlflowserver/main_test.go b/cmd/sqlflowserver/main_test.go index 5b1cee3210..c254e74203 100644 --- a/cmd/sqlflowserver/main_test.go +++ b/cmd/sqlflowserver/main_test.go @@ -307,6 +307,7 @@ func TestEnd2EndMySQLIR(t *testing.T) { t.Run("CaseTrainRegression", CaseTrainRegression) t.Run("CaseTrainXGBoostRegressionIR", CaseTrainXGBoostRegression) t.Run("CasePredictXGBoostRegressionIR", CasePredictXGBoostRegression) + t.Run("CaseAnalyzeXGBoostModel", CaseAnalyzeXGBoostModel) } func TestEnd2EndHive(t *testing.T) { @@ -912,7 +913,7 @@ func CaseTrainALPSRemoteModel(t *testing.T) { FROM %s.sparse_column_test LIMIT 100 TRAIN models.estimator.dnn_classifier.DNNClassifier -WITH +WITH model.n_classes = 2, model.hidden_units = [10, 20], train.batch_size = 10, engine.ps_num=0, engine.worker_num=0, engine.type=local, gitlab.project = "Alps/sqlflow-models", gitlab.source_root = python, @@ -1088,6 +1089,34 @@ INTO sqlflow_models.my_xgb_regression_model; ParseRow(stream) } +// CaseAnalyzeXGBoostModel is used to test analyze a xgboost model +func CaseAnalyzeXGBoostModel(t *testing.T) { + a := assert.New(t) + stmt := ` +SELECT * +FROM housing.train +ANALYZE sqlflow_models.my_xgb_regression_model +WITH + shap_summary.plot_type="bar", + shap_summary.alpha=1, + shap_summary.sort=True +USING TreeExplainer; + ` + conn, err := createRPCConn() + a.NoError(err) + defer conn.Close() + cli := pb.NewSQLFlowClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second) + defer cancel() + + stream, err := cli.Run(ctx, sqlRequest(stmt)) + if err != nil { + a.Fail("Check if the server started successfully. %v", err) + } + ParseRow(stream) +} + func CasePredictXGBoostRegression(t *testing.T) { a := assert.New(t) conn, err := createRPCConn() diff --git a/pkg/sql/codegen/xgboost/codegen_analyze.go b/pkg/sql/codegen/xgboost/codegen_analyze.go index b1b14ae296..a0664b225f 100644 --- a/pkg/sql/codegen/xgboost/codegen_analyze.go +++ b/pkg/sql/codegen/xgboost/codegen_analyze.go @@ -23,11 +23,11 @@ import ( ) const ( - shapSummaryAttributes = "shap_summary" + shapSummaryAttributes = "shap_summary." ) // Analyze generates a Python program to analyze a trained model. -func Analyze(ir *codegen.AnalyzeIR, modelPath string) (string, error) { +func Analyze(ir *codegen.AnalyzeIR) (string, error) { if ir.Explainer != "TreeExplainer" { return "", fmt.Errorf("unsupported explainer %s", ir.Explainer) } @@ -50,7 +50,6 @@ func Analyze(ir *codegen.AnalyzeIR, modelPath string) (string, error) { ShapSummaryParames: summaryAttrs, FieldMetaJSON: string(fm), Label: y.Name, - ModelFile: modelPath, } var analysis bytes.Buffer if err := analyzeTemplate.Execute(&analysis, fr); err != nil { diff --git a/pkg/sql/codegen/xgboost/template_analyze.go b/pkg/sql/codegen/xgboost/template_analyze.go index 8b1a261af2..51ed414978 100644 --- a/pkg/sql/codegen/xgboost/template_analyze.go +++ b/pkg/sql/codegen/xgboost/template_analyze.go @@ -23,7 +23,6 @@ type analyzeFiller struct { ShapSummaryParames map[string]interface{} FieldMetaJSON string Label string - ModelFile string } const analyzeTemplateText = ` @@ -42,7 +41,6 @@ feature_column_name = sorted([k["name"] for k in feature_field_meta]) feature_spec = {k['name']: k for k in feature_field_meta} conn = connect_with_data_source('''{{.DataSource}}''') label_name = "{{.Label}}" -model_path = "{{.ModelFile}}" summaryAttrs = {} {{ range $k, $v := .ShapSummaryParames }} @@ -62,7 +60,7 @@ def analyzer_dataset(): X,y = analyzer_dataset() bst = xgboost.Booster() -bst.load_model(fname=model_path) +bst.load_model("my_model") explainer = shap.TreeExplainer(bst) shap_values = explainer.shap_values(X) shap.summary_plot(shap_values, X, show=False, **summaryAttrs) diff --git a/pkg/sql/executor.go b/pkg/sql/executor.go index 7de5033c16..7ec14e62e6 100644 --- a/pkg/sql/executor.go +++ b/pkg/sql/executor.go @@ -561,14 +561,32 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin } func analyze(wr *PipeWriter, pr *extendedSelect, db *DB, cwd, modelDir string) error { - program, err := genAnalyzer(pr, db, cwd, modelDir) - if err != nil { - return err - } cmd := exec.Command("python", "-u") cmd.Dir = cwd - cmd.Stdin = program - if _, err = cmd.CombinedOutput(); err != nil { + if enableIR() { + ir, err := generateAnalyzeIR(pr, db.String(), cwd, modelDir) + if err != nil { + return err + } + if !strings.HasPrefix(strings.ToUpper(ir.TrainIR.Estimator), `XGBOOST.`) { + return fmt.Errorf("unsupported model%s", ir.TrainIR.Estimator) + } + code, err := xgboost.Analyze(ir) + if err != nil { + return err + } + fmt.Fprintln(os.Stdout, code) + var program bytes.Buffer + program.WriteString(code) + cmd.Stdin = &program + } else { + prog, err := genAnalyzer(pr, db, cwd, modelDir) + if err != nil { + return err + } + cmd.Stdin = prog + } + if _, err := cmd.CombinedOutput(); err != nil { return err } From 406c768a194344d973b08d5f081b746912944339 Mon Sep 17 00:00:00 2001 From: w7u Date: Wed, 23 Oct 2019 17:22:26 +0800 Subject: [PATCH 2/4] fix type --- pkg/sql/codegen/xgboost/codegen_analyze.go | 9 +++++---- pkg/sql/codegen/xgboost/template_analyze.go | 9 +++------ pkg/sql/executor.go | 3 +-- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/pkg/sql/codegen/xgboost/codegen_analyze.go b/pkg/sql/codegen/xgboost/codegen_analyze.go index a0664b225f..c627afc2c9 100644 --- a/pkg/sql/codegen/xgboost/codegen_analyze.go +++ b/pkg/sql/codegen/xgboost/codegen_analyze.go @@ -31,7 +31,8 @@ func Analyze(ir *codegen.AnalyzeIR) (string, error) { if ir.Explainer != "TreeExplainer" { return "", fmt.Errorf("unsupported explainer %s", ir.Explainer) } - summaryAttrs, err := resolveParams(ir.Attributes, shapSummaryAttributes) + summaryAttrs := resolveParams(ir.Attributes, shapSummaryAttributes) + jsonSummary, err := json.Marshal(summaryAttrs) if err != nil { return "", err } @@ -47,7 +48,7 @@ func Analyze(ir *codegen.AnalyzeIR) (string, error) { fr := &analyzeFiller{ DataSource: ir.DataSource, DatasetSQL: ir.Select, - ShapSummaryParames: summaryAttrs, + ShapSummaryParames: string(jsonSummary), FieldMetaJSON: string(fm), Label: y.Name, } @@ -58,12 +59,12 @@ func Analyze(ir *codegen.AnalyzeIR) (string, error) { return analysis.String(), nil } -func resolveParams(attrs map[string]interface{}, group string) (map[string]interface{}, error) { +func resolveParams(attrs map[string]interface{}, group string) map[string]interface{} { sp := make(map[string]interface{}) for k, v := range attrs { if strings.HasPrefix(k, group) { sp[k[len(group):]] = v } } - return sp, nil + return sp } diff --git a/pkg/sql/codegen/xgboost/template_analyze.go b/pkg/sql/codegen/xgboost/template_analyze.go index 51ed414978..0b5da388e9 100644 --- a/pkg/sql/codegen/xgboost/template_analyze.go +++ b/pkg/sql/codegen/xgboost/template_analyze.go @@ -20,7 +20,7 @@ import ( type analyzeFiller struct { DataSource string DatasetSQL string - ShapSummaryParames map[string]interface{} + ShapSummaryParames string FieldMetaJSON string Label string } @@ -28,7 +28,7 @@ type analyzeFiller struct { const analyzeTemplateText = ` import xgboost import shap -import json +import json import matplotlib.pyplot as plt import pandas as pd @@ -42,10 +42,7 @@ feature_spec = {k['name']: k for k in feature_field_meta} conn = connect_with_data_source('''{{.DataSource}}''') label_name = "{{.Label}}" -summaryAttrs = {} -{{ range $k, $v := .ShapSummaryParames }} -summaryAttrs["{{$k}}"] = {{$v}} -{{end}} +summaryAttrs = json.loads('''{{.ShapSummaryParames}}''') def analyzer_dataset(): stream = db_generator(conn.driver, conn, """{{.DatasetSQL}}""", feature_column_name, label_name, feature_spec) diff --git a/pkg/sql/executor.go b/pkg/sql/executor.go index 7ec14e62e6..c3303c636d 100644 --- a/pkg/sql/executor.go +++ b/pkg/sql/executor.go @@ -569,13 +569,12 @@ func analyze(wr *PipeWriter, pr *extendedSelect, db *DB, cwd, modelDir string) e return err } if !strings.HasPrefix(strings.ToUpper(ir.TrainIR.Estimator), `XGBOOST.`) { - return fmt.Errorf("unsupported model%s", ir.TrainIR.Estimator) + return fmt.Errorf("unsupported model %s", ir.TrainIR.Estimator) } code, err := xgboost.Analyze(ir) if err != nil { return err } - fmt.Fprintln(os.Stdout, code) var program bytes.Buffer program.WriteString(code) cmd.Stdin = &program From 11ed1635366ef52f329872fe19d1e7703ab06b5c Mon Sep 17 00:00:00 2001 From: w7u Date: Wed, 23 Oct 2019 19:03:49 +0800 Subject: [PATCH 3/4] fix test --- pkg/sql/codegen/xgboost/codegen_analyze_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/sql/codegen/xgboost/codegen_analyze_test.go b/pkg/sql/codegen/xgboost/codegen_analyze_test.go index 809c40e120..a33c06123f 100644 --- a/pkg/sql/codegen/xgboost/codegen_analyze_test.go +++ b/pkg/sql/codegen/xgboost/codegen_analyze_test.go @@ -33,6 +33,6 @@ func TestAnalyze(t *testing.T) { }, TrainIR: tir, } - _, err := Analyze(air, "") + _, err := Analyze(air) a.NoError(err) } From 672d4ef75fcaa955ddab1fc5f03278d51c77a4b4 Mon Sep 17 00:00:00 2001 From: w7u Date: Thu, 24 Oct 2019 16:29:28 +0800 Subject: [PATCH 4/4] ut --- cmd/sqlflowserver/main_test.go | 27 ++++++++++++++++++---- pkg/sql/codegen/xgboost/codegen_analyze.go | 4 ++-- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/cmd/sqlflowserver/main_test.go b/cmd/sqlflowserver/main_test.go index c254e74203..22b92b7352 100644 --- a/cmd/sqlflowserver/main_test.go +++ b/cmd/sqlflowserver/main_test.go @@ -307,7 +307,7 @@ func TestEnd2EndMySQLIR(t *testing.T) { t.Run("CaseTrainRegression", CaseTrainRegression) t.Run("CaseTrainXGBoostRegressionIR", CaseTrainXGBoostRegression) t.Run("CasePredictXGBoostRegressionIR", CasePredictXGBoostRegression) - t.Run("CaseAnalyzeXGBoostModel", CaseAnalyzeXGBoostModel) + t.Run("CaseAnalyzeXGBoostModel", CaseTrainAndAnalyzeXGBoostModel) } func TestEnd2EndHive(t *testing.T) { @@ -1089,10 +1089,22 @@ INTO sqlflow_models.my_xgb_regression_model; ParseRow(stream) } -// CaseAnalyzeXGBoostModel is used to test analyze a xgboost model -func CaseAnalyzeXGBoostModel(t *testing.T) { +// CaseTrainAndAnalyzeXGBoostModel is used to test training a xgboost model, +// then analyze it +func CaseTrainAndAnalyzeXGBoostModel(t *testing.T) { a := assert.New(t) - stmt := ` + trainStmt := ` +SELECT * +FROM housing.train +TRAIN xgboost.gbtree +WITH + objective="reg:squarederror", + train.num_boost_round = 30 + COLUMN f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12,f13 +LABEL target +INTO sqlflow_models.my_xgb_regression_model; + ` + analyzeStmt := ` SELECT * FROM housing.train ANALYZE sqlflow_models.my_xgb_regression_model @@ -1110,7 +1122,12 @@ USING TreeExplainer; ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second) defer cancel() - stream, err := cli.Run(ctx, sqlRequest(stmt)) + stream, err := cli.Run(ctx, sqlRequest(trainStmt)) + if err != nil { + a.Fail("Check if the server started successfully. %v", err) + } + ParseRow(stream) + stream, err = cli.Run(ctx, sqlRequest(analyzeStmt)) if err != nil { a.Fail("Check if the server started successfully. %v", err) } diff --git a/pkg/sql/codegen/xgboost/codegen_analyze.go b/pkg/sql/codegen/xgboost/codegen_analyze.go index c627afc2c9..7fd14247c1 100644 --- a/pkg/sql/codegen/xgboost/codegen_analyze.go +++ b/pkg/sql/codegen/xgboost/codegen_analyze.go @@ -23,7 +23,7 @@ import ( ) const ( - shapSummaryAttributes = "shap_summary." + shapSummaryAttrPrefix = "shap_summary." ) // Analyze generates a Python program to analyze a trained model. @@ -31,7 +31,7 @@ func Analyze(ir *codegen.AnalyzeIR) (string, error) { if ir.Explainer != "TreeExplainer" { return "", fmt.Errorf("unsupported explainer %s", ir.Explainer) } - summaryAttrs := resolveParams(ir.Attributes, shapSummaryAttributes) + summaryAttrs := resolveParams(ir.Attributes, shapSummaryAttrPrefix) jsonSummary, err := json.Marshal(summaryAttrs) if err != nil { return "", err