Skip to content

Commit 672d4ef

Browse files
committed
ut
1 parent 11ed163 commit 672d4ef

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

cmd/sqlflowserver/main_test.go

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ func TestEnd2EndMySQLIR(t *testing.T) {
307307
t.Run("CaseTrainRegression", CaseTrainRegression)
308308
t.Run("CaseTrainXGBoostRegressionIR", CaseTrainXGBoostRegression)
309309
t.Run("CasePredictXGBoostRegressionIR", CasePredictXGBoostRegression)
310-
t.Run("CaseAnalyzeXGBoostModel", CaseAnalyzeXGBoostModel)
310+
t.Run("CaseAnalyzeXGBoostModel", CaseTrainAndAnalyzeXGBoostModel)
311311
}
312312

313313
func TestEnd2EndHive(t *testing.T) {
@@ -1089,10 +1089,22 @@ INTO sqlflow_models.my_xgb_regression_model;
10891089
ParseRow(stream)
10901090
}
10911091

1092-
// CaseAnalyzeXGBoostModel is used to test analyze a xgboost model
1093-
func CaseAnalyzeXGBoostModel(t *testing.T) {
1092+
// CaseTrainAndAnalyzeXGBoostModel is used to test training a xgboost model,
1093+
// then analyze it
1094+
func CaseTrainAndAnalyzeXGBoostModel(t *testing.T) {
10941095
a := assert.New(t)
1095-
stmt := `
1096+
trainStmt := `
1097+
SELECT *
1098+
FROM housing.train
1099+
TRAIN xgboost.gbtree
1100+
WITH
1101+
objective="reg:squarederror",
1102+
train.num_boost_round = 30
1103+
COLUMN f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12,f13
1104+
LABEL target
1105+
INTO sqlflow_models.my_xgb_regression_model;
1106+
`
1107+
analyzeStmt := `
10961108
SELECT *
10971109
FROM housing.train
10981110
ANALYZE sqlflow_models.my_xgb_regression_model
@@ -1110,7 +1122,12 @@ USING TreeExplainer;
11101122
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
11111123
defer cancel()
11121124

1113-
stream, err := cli.Run(ctx, sqlRequest(stmt))
1125+
stream, err := cli.Run(ctx, sqlRequest(trainStmt))
1126+
if err != nil {
1127+
a.Fail("Check if the server started successfully. %v", err)
1128+
}
1129+
ParseRow(stream)
1130+
stream, err = cli.Run(ctx, sqlRequest(analyzeStmt))
11141131
if err != nil {
11151132
a.Fail("Check if the server started successfully. %v", err)
11161133
}

pkg/sql/codegen/xgboost/codegen_analyze.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ import (
2323
)
2424

2525
const (
26-
shapSummaryAttributes = "shap_summary."
26+
shapSummaryAttrPrefix = "shap_summary."
2727
)
2828

2929
// Analyze generates a Python program to analyze a trained model.
3030
func Analyze(ir *codegen.AnalyzeIR) (string, error) {
3131
if ir.Explainer != "TreeExplainer" {
3232
return "", fmt.Errorf("unsupported explainer %s", ir.Explainer)
3333
}
34-
summaryAttrs := resolveParams(ir.Attributes, shapSummaryAttributes)
34+
summaryAttrs := resolveParams(ir.Attributes, shapSummaryAttrPrefix)
3535
jsonSummary, err := json.Marshal(summaryAttrs)
3636
if err != nil {
3737
return "", err

0 commit comments

Comments
 (0)