Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions python/pyspark/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
HasStandardization, HasSolver, HasWeightCol, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

Linear regression.

The learning objective is to minimize the squared error, with regularization.
Expand Down Expand Up @@ -123,6 +125,8 @@ def _create_model(self, java_model):

class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

Model fitted by LinearRegression.

.. versionadded:: 1.4.0
Expand Down Expand Up @@ -631,6 +635,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
HasSeed, JavaMLWritable, JavaMLReadable, HasVarianceCol):
"""
.. note:: Experimental

`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
learning algorithm for regression.
It supports both continuous and categorical features.
Expand Down Expand Up @@ -713,7 +719,10 @@ def _create_model(self, java_model):

@inherit_doc
class DecisionTreeModel(JavaModel):
"""Abstraction for Decision Tree models.
"""
.. note:: Experimental

Abstraction for Decision Tree models.

.. versionadded:: 1.5.0
"""
Expand All @@ -736,7 +745,10 @@ def __repr__(self):

@inherit_doc
class TreeEnsembleModels(JavaModel):
"""Represents a tree ensemble model.
"""
.. note:: Experimental

Represents a tree ensemble model.

.. versionadded:: 1.5.0
"""
Expand All @@ -754,6 +766,8 @@ def __repr__(self):
@inherit_doc
class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

Model fitted by DecisionTreeRegressor.

.. versionadded:: 1.4.0
Expand Down Expand Up @@ -786,6 +800,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
RandomForestParams, TreeRegressorParams, HasCheckpointInterval,
JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
learning algorithm for regression.
It supports both continuous and categorical features.
Expand Down Expand Up @@ -868,6 +884,8 @@ def _create_model(self, java_model):

class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

Model fitted by RandomForestRegressor.

.. versionadded:: 1.4.0
Expand All @@ -892,8 +910,10 @@ def featureImportances(self):
@inherit_doc
class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable,
JavaMLReadable):
JavaMLReadable, TreeRegressorParams):
"""
.. note:: Experimental

`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
learning algorithm for regression.
It supports both continuous and categorical features.
Expand All @@ -904,6 +924,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
>>> print(gbt.getImpurity())
variance
>>> model = gbt.fit(df)
>>> model.featureImportances
SparseVector(1, {0: 1.0})
Expand Down Expand Up @@ -940,19 +962,21 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None):
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
impurity="variance"):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None)
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
impurity="variance")
"""
super(GBTRegressor, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1,
seed=None)
seed=None, impurity="variance")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

Expand All @@ -961,12 +985,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None):
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
impurity="variance"):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None)
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
impurity="variance")
Sets params for Gradient Boosted Tree Regression.
"""
kwargs = self.setParams._input_kwargs
Expand All @@ -992,6 +1018,8 @@ def getLossType(self):

class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

Model fitted by GBTRegressor.

.. versionadded:: 1.4.0
Expand All @@ -1017,6 +1045,8 @@ def featureImportances(self):
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

Accelerated Failure Time (AFT) Model Survival Regression

Fit a parametric AFT survival regression model based on the Weibull distribution
Expand Down Expand Up @@ -1157,6 +1187,8 @@ def getQuantilesCol(self):

class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

Model fitted by AFTSurvivalRegression.

.. versionadded:: 1.6.0
Expand Down Expand Up @@ -1204,6 +1236,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
HasFitIntercept, HasMaxIter, HasTol, HasRegParam, HasWeightCol,
HasSolver, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

Generalized Linear Regression.

Fit a Generalized Linear Model specified by giving a symbolic description of the linear
Expand Down Expand Up @@ -1320,6 +1354,8 @@ def getLink(self):

class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

Model fitted by GeneralizedLinearRegression.

.. versionadded:: 2.0.0
Expand Down