Skip to content

Commit 5207a00

Browse files
holdenk authored and Nick Pentreath committed
[SPARK-15281][PYSPARK][ML][TRIVIAL] Add impurity param to GBTRegressor & add experimental inside of regression.py
## What changes were proposed in this pull request? Add impurity param to GBTRegressor and mark the models & regressors in regression.py as experimental to match Scaladoc. ## How was this patch tested? Added default value to init, tested with unit/doc tests. Author: Holden Karau <[email protected]> Closes #13071 from holdenk/SPARK-15281-GBTRegressor-impurity.
1 parent 4699144 commit 5207a00

File tree

1 file changed

+44
-8
lines changed

1 file changed

+44
-8
lines changed

python/pyspark/ml/regression.py

Lines changed: 44 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
4040
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
4141
HasStandardization, HasSolver, HasWeightCol, JavaMLWritable, JavaMLReadable):
4242
"""
43+
.. note:: Experimental
44+
4345
Linear regression.
4446
4547
The learning objective is to minimize the squared error, with regularization.
@@ -123,6 +125,8 @@ def _create_model(self, java_model):
123125

124126
class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
125127
"""
128+
.. note:: Experimental
129+
126130
Model fitted by LinearRegression.
127131
128132
.. versionadded:: 1.4.0
@@ -631,6 +635,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
631635
DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
632636
HasSeed, JavaMLWritable, JavaMLReadable, HasVarianceCol):
633637
"""
638+
.. note:: Experimental
639+
634640
`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
635641
learning algorithm for regression.
636642
It supports both continuous and categorical features.
@@ -713,7 +719,10 @@ def _create_model(self, java_model):
713719

714720
@inherit_doc
715721
class DecisionTreeModel(JavaModel):
716-
"""Abstraction for Decision Tree models.
722+
"""
723+
.. note:: Experimental
724+
725+
Abstraction for Decision Tree models.
717726
718727
.. versionadded:: 1.5.0
719728
"""
@@ -736,7 +745,10 @@ def __repr__(self):
736745

737746
@inherit_doc
738747
class TreeEnsembleModels(JavaModel):
739-
"""Represents a tree ensemble model.
748+
"""
749+
.. note:: Experimental
750+
751+
Represents a tree ensemble model.
740752
741753
.. versionadded:: 1.5.0
742754
"""
@@ -754,6 +766,8 @@ def __repr__(self):
754766
@inherit_doc
755767
class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
756768
"""
769+
.. note:: Experimental
770+
757771
Model fitted by DecisionTreeRegressor.
758772
759773
.. versionadded:: 1.4.0
@@ -786,6 +800,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
786800
RandomForestParams, TreeRegressorParams, HasCheckpointInterval,
787801
JavaMLWritable, JavaMLReadable):
788802
"""
803+
.. note:: Experimental
804+
789805
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
790806
learning algorithm for regression.
791807
It supports both continuous and categorical features.
@@ -868,6 +884,8 @@ def _create_model(self, java_model):
868884

869885
class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
870886
"""
887+
.. note:: Experimental
888+
871889
Model fitted by RandomForestRegressor.
872890
873891
.. versionadded:: 1.4.0
@@ -892,8 +910,10 @@ def featureImportances(self):
892910
@inherit_doc
893911
class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
894912
GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable,
895-
JavaMLReadable):
913+
JavaMLReadable, TreeRegressorParams):
896914
"""
915+
.. note:: Experimental
916+
897917
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
898918
learning algorithm for regression.
899919
It supports both continuous and categorical features.
@@ -904,6 +924,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
904924
... (1.0, Vectors.dense(1.0)),
905925
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
906926
>>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
927+
>>> print(gbt.getImpurity())
928+
variance
907929
>>> model = gbt.fit(df)
908930
>>> model.featureImportances
909931
SparseVector(1, {0: 1.0})
@@ -940,19 +962,21 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
940962
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
941963
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
942964
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
943-
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None):
965+
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
966+
impurity="variance"):
944967
"""
945968
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
946969
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
947970
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
948-
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None)
971+
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
972+
impurity="variance")
949973
"""
950974
super(GBTRegressor, self).__init__()
951975
self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
952976
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
953977
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
954978
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1,
955-
seed=None)
979+
seed=None, impurity="variance")
956980
kwargs = self.__init__._input_kwargs
957981
self.setParams(**kwargs)
958982

@@ -961,12 +985,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
961985
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
962986
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
963987
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
964-
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None):
988+
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
989+
impurity="variance"):
965990
"""
966991
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
967992
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
968993
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
969-
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None)
994+
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
995+
impurity="variance")
970996
Sets params for Gradient Boosted Tree Regression.
971997
"""
972998
kwargs = self.setParams._input_kwargs
@@ -992,6 +1018,8 @@ def getLossType(self):
9921018

9931019
class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
9941020
"""
1021+
.. note:: Experimental
1022+
9951023
Model fitted by GBTRegressor.
9961024
9971025
.. versionadded:: 1.4.0
@@ -1017,6 +1045,8 @@ def featureImportances(self):
10171045
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
10181046
HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable):
10191047
"""
1048+
.. note:: Experimental
1049+
10201050
Accelerated Failure Time (AFT) Model Survival Regression
10211051
10221052
Fit a parametric AFT survival regression model based on the Weibull distribution
@@ -1157,6 +1187,8 @@ def getQuantilesCol(self):
11571187

11581188
class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
11591189
"""
1190+
.. note:: Experimental
1191+
11601192
Model fitted by AFTSurvivalRegression.
11611193
11621194
.. versionadded:: 1.6.0
@@ -1204,6 +1236,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
12041236
HasFitIntercept, HasMaxIter, HasTol, HasRegParam, HasWeightCol,
12051237
HasSolver, JavaMLWritable, JavaMLReadable):
12061238
"""
1239+
.. note:: Experimental
1240+
12071241
Generalized Linear Regression.
12081242
12091243
Fit a Generalized Linear Model specified by giving a symbolic description of the linear
@@ -1320,6 +1354,8 @@ def getLink(self):
13201354

13211355
class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
13221356
"""
1357+
.. note:: Experimental
1358+
13231359
Model fitted by GeneralizedLinearRegression.
13241360
13251361
.. versionadded:: 2.0.0

0 commit comments

Comments
 (0)