@@ -326,6 +326,7 @@ func TestEnd2EndMySQLIR(t *testing.T) {
326326 t .Run ("CaseTrainRegression" , CaseTrainRegression )
327327 t .Run ("CaseTrainXGBoostRegressionIR" , CaseTrainXGBoostRegression )
328328 t .Run ("CasePredictXGBoostRegressionIR" , CasePredictXGBoostRegression )
329+ t .Run ("CaseAnalyzeXGBoostModel" , CaseTrainAndAnalyzeXGBoostModel )
329330}
330331
331332func CaseTrainTextClassificationIR (t * testing.T ) {
@@ -852,7 +853,7 @@ func CaseTrainALPSRemoteModel(t *testing.T) {
852853FROM %s.sparse_column_test
853854LIMIT 100
854855TRAIN models.estimator.dnn_classifier.DNNClassifier
855- WITH
856+ WITH
856857 model.n_classes = 2, model.hidden_units = [10, 20], train.batch_size = 10, engine.ps_num=0, engine.worker_num=0, engine.type=local,
857858 gitlab.project = "Alps/sqlflow-models",
858859 gitlab.source_root = python,
@@ -979,6 +980,51 @@ INTO sqlflow_models.my_xgb_regression_model;
979980 }
980981}
981982
983+ // CaseTrainAndAnalyzeXGBoostModel is used to test training a xgboost model,
984+ // then analyze it
985+ func CaseTrainAndAnalyzeXGBoostModel (t * testing.T ) {
986+ a := assert .New (t )
987+ trainStmt := `
988+ SELECT *
989+ FROM housing.train
990+ TRAIN xgboost.gbtree
991+ WITH
992+ objective="reg:squarederror",
993+ train.num_boost_round = 30
994+ COLUMN f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12,f13
995+ LABEL target
996+ INTO sqlflow_models.my_xgb_regression_model;
997+ `
998+ analyzeStmt := `
999+ SELECT *
1000+ FROM housing.train
1001+ ANALYZE sqlflow_models.my_xgb_regression_model
1002+ WITH
1003+ shap_summary.plot_type="bar",
1004+ shap_summary.alpha=1,
1005+ shap_summary.sort=True
1006+ USING TreeExplainer;
1007+ `
1008+ conn , err := createRPCConn ()
1009+ a .NoError (err )
1010+ defer conn .Close ()
1011+ cli := pb .NewSQLFlowClient (conn )
1012+
1013+ ctx , cancel := context .WithTimeout (context .Background (), 300 * time .Second )
1014+ defer cancel ()
1015+
1016+ stream , err := cli .Run (ctx , sqlRequest (trainStmt ))
1017+ if err != nil {
1018+ a .Fail ("Check if the server started successfully. %v" , err )
1019+ }
1020+ ParseRow (stream )
1021+ stream , err = cli .Run (ctx , sqlRequest (analyzeStmt ))
1022+ if err != nil {
1023+ a .Fail ("Check if the server started successfully. %v" , err )
1024+ }
1025+ ParseRow (stream )
1026+ }
1027+
9821028func CasePredictXGBoostRegression (t * testing.T ) {
9831029 a := assert .New (t )
9841030 predSQL := fmt .Sprintf (`SELECT *
0 commit comments