diff --git a/cmd/sqlflowserver/main_test.go b/cmd/sqlflowserver/main_test.go index 5b1cee3210..22b92b7352 100644 --- a/cmd/sqlflowserver/main_test.go +++ b/cmd/sqlflowserver/main_test.go @@ -307,6 +307,7 @@ func TestEnd2EndMySQLIR(t *testing.T) { t.Run("CaseTrainRegression", CaseTrainRegression) t.Run("CaseTrainXGBoostRegressionIR", CaseTrainXGBoostRegression) t.Run("CasePredictXGBoostRegressionIR", CasePredictXGBoostRegression) + t.Run("CaseAnalyzeXGBoostModel", CaseTrainAndAnalyzeXGBoostModel) } func TestEnd2EndHive(t *testing.T) { @@ -912,7 +913,7 @@ func CaseTrainALPSRemoteModel(t *testing.T) { FROM %s.sparse_column_test LIMIT 100 TRAIN models.estimator.dnn_classifier.DNNClassifier -WITH +WITH model.n_classes = 2, model.hidden_units = [10, 20], train.batch_size = 10, engine.ps_num=0, engine.worker_num=0, engine.type=local, gitlab.project = "Alps/sqlflow-models", gitlab.source_root = python, @@ -1088,6 +1089,51 @@ INTO sqlflow_models.my_xgb_regression_model; ParseRow(stream) } +// CaseTrainAndAnalyzeXGBoostModel is used to test training a xgboost model, +// then analyze it +func CaseTrainAndAnalyzeXGBoostModel(t *testing.T) { + a := assert.New(t) + trainStmt := ` +SELECT * +FROM housing.train +TRAIN xgboost.gbtree +WITH + objective="reg:squarederror", + train.num_boost_round = 30 + COLUMN f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12,f13 +LABEL target +INTO sqlflow_models.my_xgb_regression_model; + ` + analyzeStmt := ` +SELECT * +FROM housing.train +ANALYZE sqlflow_models.my_xgb_regression_model +WITH + shap_summary.plot_type="bar", + shap_summary.alpha=1, + shap_summary.sort=True +USING TreeExplainer; + ` + conn, err := createRPCConn() + a.NoError(err) + defer conn.Close() + cli := pb.NewSQLFlowClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second) + defer cancel() + + stream, err := cli.Run(ctx, sqlRequest(trainStmt)) + if err != nil { + a.Fail("Check if the server started successfully. %v", err) + } + ParseRow(stream) + stream, err = cli.Run(ctx, sqlRequest(analyzeStmt)) + if err != nil { + a.Fail("Check if the server started successfully. %v", err) + } + ParseRow(stream) +} + func CasePredictXGBoostRegression(t *testing.T) { a := assert.New(t) conn, err := createRPCConn() diff --git a/pkg/sql/codegen/xgboost/codegen_analyze.go b/pkg/sql/codegen/xgboost/codegen_analyze.go index b1b14ae296..7fd14247c1 100644 --- a/pkg/sql/codegen/xgboost/codegen_analyze.go +++ b/pkg/sql/codegen/xgboost/codegen_analyze.go @@ -23,15 +23,16 @@ import ( ) const ( - shapSummaryAttributes = "shap_summary" + shapSummaryAttrPrefix = "shap_summary." ) // Analyze generates a Python program to analyze a trained model. -func Analyze(ir *codegen.AnalyzeIR, modelPath string) (string, error) { +func Analyze(ir *codegen.AnalyzeIR) (string, error) { if ir.Explainer != "TreeExplainer" { return "", fmt.Errorf("unsupported explainer %s", ir.Explainer) } - summaryAttrs, err := resolveParams(ir.Attributes, shapSummaryAttributes) + summaryAttrs := resolveParams(ir.Attributes, shapSummaryAttrPrefix) + jsonSummary, err := json.Marshal(summaryAttrs) if err != nil { return "", err } @@ -47,10 +48,9 @@ func Analyze(ir *codegen.AnalyzeIR, modelPath string) (string, error) { fr := &analyzeFiller{ DataSource: ir.DataSource, DatasetSQL: ir.Select, - ShapSummaryParames: summaryAttrs, + ShapSummaryParames: string(jsonSummary), FieldMetaJSON: string(fm), Label: y.Name, - ModelFile: modelPath, } var analysis bytes.Buffer if err := analyzeTemplate.Execute(&analysis, fr); err != nil { @@ -59,12 +59,12 @@ func Analyze(ir *codegen.AnalyzeIR, modelPath string) (string, error) { return analysis.String(), nil } -func resolveParams(attrs map[string]interface{}, group string) (map[string]interface{}, error) { +func resolveParams(attrs map[string]interface{}, group string) map[string]interface{} { sp := make(map[string]interface{}) for k, v := range attrs { if strings.HasPrefix(k, group) { sp[k[len(group):]] = v } } - return sp, nil + return sp } diff --git a/pkg/sql/codegen/xgboost/codegen_analyze_test.go b/pkg/sql/codegen/xgboost/codegen_analyze_test.go index 809c40e120..a33c06123f 100644 --- a/pkg/sql/codegen/xgboost/codegen_analyze_test.go +++ b/pkg/sql/codegen/xgboost/codegen_analyze_test.go @@ -33,6 +33,6 @@ func TestAnalyze(t *testing.T) { }, TrainIR: tir, } - _, err := Analyze(air, "") + _, err := Analyze(air) a.NoError(err) } diff --git a/pkg/sql/codegen/xgboost/template_analyze.go b/pkg/sql/codegen/xgboost/template_analyze.go index 8b1a261af2..0b5da388e9 100644 --- a/pkg/sql/codegen/xgboost/template_analyze.go +++ b/pkg/sql/codegen/xgboost/template_analyze.go @@ -20,16 +20,15 @@ import ( type analyzeFiller struct { DataSource string DatasetSQL string - ShapSummaryParames map[string]interface{} + ShapSummaryParames string FieldMetaJSON string Label string - ModelFile string } const analyzeTemplateText = ` import xgboost import shap -import json +import json import matplotlib.pyplot as plt import pandas as pd @@ -42,12 +41,8 @@ feature_column_name = sorted([k["name"] for k in feature_field_meta]) feature_spec = {k['name']: k for k in feature_field_meta} conn = connect_with_data_source('''{{.DataSource}}''') label_name = "{{.Label}}" -model_path = "{{.ModelFile}}" -summaryAttrs = {} -{{ range $k, $v := .ShapSummaryParames }} -summaryAttrs["{{$k}}"] = {{$v}} -{{end}} +summaryAttrs = json.loads('''{{.ShapSummaryParames}}''') def analyzer_dataset(): stream = db_generator(conn.driver, conn, """{{.DatasetSQL}}""", feature_column_name, label_name, feature_spec) @@ -62,7 +57,7 @@ def analyzer_dataset(): X,y = analyzer_dataset() bst = xgboost.Booster() -bst.load_model(fname=model_path) +bst.load_model("my_model") explainer = shap.TreeExplainer(bst) shap_values = explainer.shap_values(X) shap.summary_plot(shap_values, X, show=False, **summaryAttrs) diff --git a/pkg/sql/executor.go b/pkg/sql/executor.go index 7de5033c16..c3303c636d 100644 --- a/pkg/sql/executor.go +++ b/pkg/sql/executor.go @@ -561,14 +561,31 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin } func analyze(wr *PipeWriter, pr *extendedSelect, db *DB, cwd, modelDir string) error { - program, err := genAnalyzer(pr, db, cwd, modelDir) - if err != nil { - return err - } cmd := exec.Command("python", "-u") cmd.Dir = cwd - cmd.Stdin = program - if _, err = cmd.CombinedOutput(); err != nil { + if enableIR() { + ir, err := generateAnalyzeIR(pr, db.String(), cwd, modelDir) + if err != nil { + return err + } + if !strings.HasPrefix(strings.ToUpper(ir.TrainIR.Estimator), `XGBOOST.`) { + return fmt.Errorf("unsupported model %s", ir.TrainIR.Estimator) + } + code, err := xgboost.Analyze(ir) + if err != nil { + return err + } + var program bytes.Buffer + program.WriteString(code) + cmd.Stdin = &program + } else { + prog, err := genAnalyzer(pr, db, cwd, modelDir) + if err != nil { + return err + } + cmd.Stdin = prog + } + if _, err := cmd.CombinedOutput(); err != nil { return err }