From 22ee3b715712be2bb924a59127a5a688374c4590 Mon Sep 17 00:00:00 2001 From: w7u Date: Mon, 9 Sep 2019 20:52:44 +0800 Subject: [PATCH] support summary_plot(args) --- sql/codegen_analyze.go | 49 +++++++++++++++++++++++++++-------------- sql/template_analyze.go | 10 +++++---- 2 files changed, 38 insertions(+), 21 deletions(-) diff --git a/sql/codegen_analyze.go b/sql/codegen_analyze.go index 9e0ed0b009..de8b5a2d37 100644 --- a/sql/codegen_analyze.go +++ b/sql/codegen_analyze.go @@ -19,16 +19,21 @@ import ( "strings" ) +const ( + shapSummaryAttributePrefix = "shap_summary" +) + type analyzeFiller struct { *connectionConfig - X []*FeatureMeta - Label string - AnalyzeDatasetSQL string - PlotType string - ModelFile string // path/to/model_file + X []*FeatureMeta + Label string + AnalyzeDatasetSQL string + PlotType string + ShapSummaryParames map[string]interface{} + ModelFile string // path/to/model_file } -func newAnalyzeFiller(pr *extendedSelect, db *DB, fms []*FeatureMeta, label, modelPath, plotType string) (*analyzeFiller, error) { +func newAnalyzeFiller(pr *extendedSelect, db *DB, fms []*FeatureMeta, label, modelPath string, summaryAttrs map[string]interface{}) (*analyzeFiller, error) { conn, err := newConnectionConfig(db) if err != nil { return nil, err @@ -39,9 +44,9 @@ func newAnalyzeFiller(pr *extendedSelect, db *DB, fms []*FeatureMeta, label, mod Label: label, // TODO(weiguo): test if it needs TrimSuffix(SQL, ";") on hive, // or we trim it in pr(*extendedSelect) - AnalyzeDatasetSQL: pr.standardSelect.String(), - ModelFile: modelPath, - PlotType: plotType, + AnalyzeDatasetSQL: pr.standardSelect.String(), + ModelFile: modelPath, + ShapSummaryParames: summaryAttrs, }, nil } @@ -71,13 +76,19 @@ func readXGBFeatures(pr *extendedSelect, db *DB) ([]*FeatureMeta, string, error) return xs, fr.Y.FeatureName, nil } -func readPlotType(pr *extendedSelect) string { - v, ok := pr.analyzeAttrs["shap.plot_type"] - if !ok { - // using shap default value - return `""` +func resolveAnalyzeSummaryParames(atts *attrs) (map[string]interface{}, error) { + parames, err := resolveAttribute(atts) + if err != nil { + return nil, err } - return v.val + + summaryAttrs := make(map[string]interface{}) + for _, v := range parames { + if v.Prefix == shapSummaryAttributePrefix { + summaryAttrs[v.Name] = v.Value + } + } + return summaryAttrs, nil } func genAnalyzer(pr *extendedSelect, db *DB, cwd, modelDir string) (*bytes.Buffer, error) { @@ -89,13 +100,17 @@ func genAnalyzer(pr *extendedSelect, db *DB, cwd, modelDir string) (*bytes.Buffe return nil, fmt.Errorf("analyzer: model[%s] not supported", pr.estimator) } // We untar the XGBoost.{pr.trainedModel}.tar.gz and get three files. - plotType := readPlotType(pr) + summaryAttrs, err := resolveAnalyzeSummaryParames(&pr.analyzeAttrs) + if err != nil { + return nil, err + } + xs, label, err := readXGBFeatures(pr, db) if err != nil { return nil, err } - fr, err := newAnalyzeFiller(pr, db, xs, label, pr.trainedModel, plotType) + fr, err := newAnalyzeFiller(pr, db, xs, label, pr.trainedModel, summaryAttrs) if err != nil { return nil, fmt.Errorf("create analyze filler failed: %v", err) } diff --git a/sql/template_analyze.go b/sql/template_analyze.go index 5abfa78245..9edbd50b92 100644 --- a/sql/template_analyze.go +++ b/sql/template_analyze.go @@ -66,9 +66,11 @@ def analyzer_dataset(): # 2. load the model model_path = "{{.ModelFile}}" -ptype = {{.PlotType}} -if len(ptype) == 0: - ptype = None + +summaryAttrs = {} +{{ range $k, $v := .ShapSummaryParames }} +summaryAttrs["{{$k}}"] = {{$v}} +{{end}} X,y = analyzer_dataset() @@ -77,7 +79,7 @@ bst.load_model(fname=model_path) explainer = shap.TreeExplainer(bst) shap_values = explainer.shap_values(X) -shap.summary_plot(shap_values, X, plot_type=ptype) +shap.summary_plot(shap_values, X, **summaryAttrs) plt.savefig('summary') `