@@ -142,7 +142,7 @@ def test_clustering(self):
142142
143143 def test_classification (self ):
144144 from pyspark .mllib .classification import LogisticRegressionWithSGD , SVMWithSGD , NaiveBayes
145- from pyspark .mllib .tree import DecisionTree
145+ from pyspark .mllib .tree import DecisionTree , RandomForest , GradientBoostedTrees
146146 data = [
147147 LabeledPoint (0.0 , [1 , 0 , 0 ]),
148148 LabeledPoint (1.0 , [0 , 1 , 1 ]),
@@ -171,18 +171,31 @@ def test_classification(self):
171171 self .assertTrue (nb_model .predict (features [3 ]) > 0 )
172172
173173 categoricalFeaturesInfo = {0 : 3 } # feature 0 has 3 categories
174- dt_model = \
175- DecisionTree .trainClassifier (rdd , numClasses = 2 ,
176- categoricalFeaturesInfo = categoricalFeaturesInfo )
174+ dt_model = DecisionTree .trainClassifier (
175+ rdd , numClasses = 2 , categoricalFeaturesInfo = categoricalFeaturesInfo )
177176 self .assertTrue (dt_model .predict (features [0 ]) <= 0 )
178177 self .assertTrue (dt_model .predict (features [1 ]) > 0 )
179178 self .assertTrue (dt_model .predict (features [2 ]) <= 0 )
180179 self .assertTrue (dt_model .predict (features [3 ]) > 0 )
181180
181+ rf_model = RandomForest .trainClassifier (
182+ rdd , numClasses = 2 , categoricalFeaturesInfo = categoricalFeaturesInfo , numTrees = 100 )
183+ self .assertTrue (rf_model .predict (features [0 ]) <= 0 )
184+ self .assertTrue (rf_model .predict (features [1 ]) > 0 )
185+ self .assertTrue (rf_model .predict (features [2 ]) <= 0 )
186+ self .assertTrue (rf_model .predict (features [3 ]) > 0 )
187+
188+ gbt_model = GradientBoostedTrees .trainClassifier (
189+ rdd , categoricalFeaturesInfo = categoricalFeaturesInfo )
190+ self .assertTrue (gbt_model .predict (features [0 ]) <= 0 )
191+ self .assertTrue (gbt_model .predict (features [1 ]) > 0 )
192+ self .assertTrue (gbt_model .predict (features [2 ]) <= 0 )
193+ self .assertTrue (gbt_model .predict (features [3 ]) > 0 )
194+
182195 def test_regression (self ):
183196 from pyspark .mllib .regression import LinearRegressionWithSGD , LassoWithSGD , \
184197 RidgeRegressionWithSGD
185- from pyspark .mllib .tree import DecisionTree
198+ from pyspark .mllib .tree import DecisionTree , RandomForest , GradientBoostedTrees
186199 data = [
187200 LabeledPoint (- 1.0 , [0 , - 1 ]),
188201 LabeledPoint (1.0 , [0 , 1 ]),
@@ -211,13 +224,27 @@ def test_regression(self):
211224 self .assertTrue (rr_model .predict (features [3 ]) > 0 )
212225
213226 categoricalFeaturesInfo = {0 : 2 } # feature 0 has 2 categories
214- dt_model = \
215- DecisionTree . trainRegressor ( rdd , categoricalFeaturesInfo = categoricalFeaturesInfo )
227+ dt_model = DecisionTree . trainRegressor (
228+ rdd , categoricalFeaturesInfo = categoricalFeaturesInfo )
216229 self .assertTrue (dt_model .predict (features [0 ]) <= 0 )
217230 self .assertTrue (dt_model .predict (features [1 ]) > 0 )
218231 self .assertTrue (dt_model .predict (features [2 ]) <= 0 )
219232 self .assertTrue (dt_model .predict (features [3 ]) > 0 )
220233
234+ rf_model = RandomForest .trainRegressor (
235+ rdd , categoricalFeaturesInfo = categoricalFeaturesInfo , numTrees = 100 )
236+ self .assertTrue (rf_model .predict (features [0 ]) <= 0 )
237+ self .assertTrue (rf_model .predict (features [1 ]) > 0 )
238+ self .assertTrue (rf_model .predict (features [2 ]) <= 0 )
239+ self .assertTrue (rf_model .predict (features [3 ]) > 0 )
240+
241+ gbt_model = GradientBoostedTrees .trainRegressor (
242+ rdd , categoricalFeaturesInfo = categoricalFeaturesInfo )
243+ self .assertTrue (gbt_model .predict (features [0 ]) <= 0 )
244+ self .assertTrue (gbt_model .predict (features [1 ]) > 0 )
245+ self .assertTrue (gbt_model .predict (features [2 ]) <= 0 )
246+ self .assertTrue (gbt_model .predict (features [3 ]) > 0 )
247+
221248
222249class StatTests (PySparkTestCase ):
223250 # SPARK-4023
0 commit comments