@@ -307,7 +307,7 @@ func TestEnd2EndMySQLIR(t *testing.T) {
307307 t .Run ("CaseTrainRegression" , CaseTrainRegression )
308308 t .Run ("CaseTrainXGBoostRegressionIR" , CaseTrainXGBoostRegression )
309309 t .Run ("CasePredictXGBoostRegressionIR" , CasePredictXGBoostRegression )
310- t .Run ("CaseAnalyzeXGBoostModel" , CaseAnalyzeXGBoostModel )
310+ t .Run ("CaseAnalyzeXGBoostModel" , CaseTrainAndAnalyzeXGBoostModel )
311311}
312312
313313func TestEnd2EndHive (t * testing.T ) {
@@ -1089,10 +1089,22 @@ INTO sqlflow_models.my_xgb_regression_model;
10891089 ParseRow (stream )
10901090}
10911091
1092- // CaseAnalyzeXGBoostModel is used to test analyze a xgboost model
1093- func CaseAnalyzeXGBoostModel (t * testing.T ) {
1092+ // CaseTrainAndAnalyzeXGBoostModel is used to test training a xgboost model,
1093+ // then analyze it
1094+ func CaseTrainAndAnalyzeXGBoostModel (t * testing.T ) {
10941095 a := assert .New (t )
1095- stmt := `
1096+ trainStmt := `
1097+ SELECT *
1098+ FROM housing.train
1099+ TRAIN xgboost.gbtree
1100+ WITH
1101+ objective="reg:squarederror",
1102+ train.num_boost_round = 30
1103+ COLUMN f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12,f13
1104+ LABEL target
1105+ INTO sqlflow_models.my_xgb_regression_model;
1106+ `
1107+ analyzeStmt := `
10961108SELECT *
10971109FROM housing.train
10981110ANALYZE sqlflow_models.my_xgb_regression_model
@@ -1110,7 +1122,12 @@ USING TreeExplainer;
11101122 ctx , cancel := context .WithTimeout (context .Background (), 300 * time .Second )
11111123 defer cancel ()
11121124
1113- stream , err := cli .Run (ctx , sqlRequest (stmt ))
1125+ stream , err := cli .Run (ctx , sqlRequest (trainStmt ))
1126+ if err != nil {
1127+ a .Fail ("Check if the server started successfully. %v" , err )
1128+ }
1129+ ParseRow (stream )
1130+ stream , err = cli .Run (ctx , sqlRequest (analyzeStmt ))
11141131 if err != nil {
11151132 a .Fail ("Check if the server started successfully. %v" , err )
11161133 }
0 commit comments