Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 46 additions & 35 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pyspark.ml import Estimator, Model
from pyspark.ml.param.shared import *
from pyspark.ml.regression import DecisionTreeModel, DecisionTreeRegressionModel, \
RandomForestParams, TreeEnsembleModel, TreeEnsembleParams
GBTParams, HasVarianceImpurity, RandomForestParams, TreeEnsembleModel, TreeEnsembleParams
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.wrapper import JavaWrapper
Expand Down Expand Up @@ -895,15 +895,6 @@ def getImpurity(self):
return self.getOrDefault(self.impurity)


class GBTParams(TreeEnsembleParams):
"""
Private class to track supported GBT params.

.. versionadded:: 1.4.0
"""
supportedLossTypes = ["logistic"]


@inherit_doc
class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
Expand Down Expand Up @@ -1174,9 +1165,31 @@ def trees(self):
return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]


class GBTClassifierParams(GBTParams, HasVarianceImpurity):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should extend TreeClassifierParams

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BryanCutler Thanks for your review.
It seems that #22986 recently added the trait HasVarianceImpurity and changed the declaration to
private[ml] trait GBTClassifierParams extends GBTParams with HasVarianceImpurity

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, I see. let me take another look..

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, you're correct, this is fine

"""
Private class to track supported GBTClassifier params.

.. versionadded:: 3.0.0
"""

supportedLossTypes = ["logistic"]

lossType = Param(Params._dummy(), "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
"Supported options: " + ", ".join(supportedLossTypes),
typeConverter=TypeConverters.toString)

@since("1.4.0")
def getLossType(self):
"""
Gets the value of lossType or its default value.
"""
return self.getOrDefault(self.lossType)


@inherit_doc
class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable,
class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
GBTClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable,
JavaMLReadable):
"""
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
Expand Down Expand Up @@ -1242,40 +1255,36 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
[0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
>>> model.numClasses
2
>>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
>>> gbt.getValidationIndicatorCol()
'validationIndicator'
>>> gbt.getValidationTol()
0.01

.. versionadded:: 1.4.0
"""

lossType = Param(Params._dummy(), "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
"Supported options: " + ", ".join(GBTParams.supportedLossTypes),
typeConverter=TypeConverters.toString)

stepSize = Param(Params._dummy(), "stepSize",
"Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
"the contribution of each estimator.",
typeConverter=TypeConverters.toFloat)

@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
featureSubsetStrategy="all"):
maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
Copy link
Member

@BryanCutler BryanCutler Dec 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not the correct default impurity — [edit] the default value has been changed in Scala, so this is correct

featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
featureSubsetStrategy="all")
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
validationIndicatorCol=None)
"""
super(GBTClassifier, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.GBTClassifier", self.uid)
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0,
featureSubsetStrategy="all")
impurity="variance", featureSubsetStrategy="all", validationTol=0.01)
kwargs = self._input_kwargs
self.setParams(**kwargs)

Expand All @@ -1285,13 +1294,15 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
featureSubsetStrategy="all"):
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
validationIndicatorCol=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
featureSubsetStrategy="all")
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
validationIndicatorCol=None)
Sets params for Gradient Boosted Tree Classification.
"""
kwargs = self._input_kwargs
Expand All @@ -1307,20 +1318,20 @@ def setLossType(self, value):
"""
return self._set(lossType=value)

@since("1.4.0")
def getLossType(self):
    """Gets the value of lossType or its default value."""
    # Delegates to Params.getOrDefault, so the declared default is
    # returned when the user has not set lossType explicitly.
    configured = self.getOrDefault(self.lossType)
    return configured

@since("2.4.0")
def setFeatureSubsetStrategy(self, value):
    """Sets the value of :py:attr:`featureSubsetStrategy`."""
    # Fluent setter: _set returns self so calls can be chained.
    return self._set(featureSubsetStrategy=value)

@since("3.0.0")
def setValidationIndicatorCol(self, value):
    """Sets the value of :py:attr:`validationIndicatorCol`."""
    # Fluent setter: _set returns self so calls can be chained.
    return self._set(validationIndicatorCol=value)


class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
JavaMLReadable):
Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/ml/param/_shared_params_code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ def get$Name(self):
"False", "TypeConverters.toBoolean"),
("loss", "the loss function to be optimized.", None, "TypeConverters.toString"),
("distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.",
"'euclidean'", "TypeConverters.toString")]
"'euclidean'", "TypeConverters.toString"),
("validationIndicatorCol", "name of the column that indicates whether each row is for " +
"training or for validation. False indicates training; true indicates validation.",
None, "TypeConverters.toString")]

code = []
for name, doc, defaultValueStr, typeConverter in shared:
Expand Down
71 changes: 47 additions & 24 deletions python/pyspark/ml/param/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,53 @@ def getLoss(self):
return self.getOrDefault(self.loss)


class HasDistanceMeasure(Params):
    """
    Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'.
    """

    # NOTE: this file is produced by _shared_params_code_gen.py; keep the
    # Param description string in sync with the generator's entry.
    distanceMeasure = Param(
        Params._dummy(), "distanceMeasure",
        "the distance measure. Supported options: 'euclidean' and 'cosine'.",
        typeConverter=TypeConverters.toString)

    def __init__(self):
        super(HasDistanceMeasure, self).__init__()
        # Default mirrors the generator's declared default of 'euclidean'.
        self._setDefault(distanceMeasure='euclidean')

    def getDistanceMeasure(self):
        """Gets the value of distanceMeasure or its default value."""
        return self.getOrDefault(self.distanceMeasure)

    def setDistanceMeasure(self, value):
        """Sets the value of :py:attr:`distanceMeasure`."""
        return self._set(distanceMeasure=value)


class HasValidationIndicatorCol(Params):
    """
    Mixin for param validationIndicatorCol: name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.
    """

    # NOTE: this file is produced by _shared_params_code_gen.py; keep the
    # Param description string in sync with the generator's entry.
    validationIndicatorCol = Param(
        Params._dummy(), "validationIndicatorCol",
        "name of the column that indicates whether each row is for training or for "
        "validation. False indicates training; true indicates validation.",
        typeConverter=TypeConverters.toString)

    def __init__(self):
        # No default is set: the param is unset until the user provides one.
        super(HasValidationIndicatorCol, self).__init__()

    def getValidationIndicatorCol(self):
        """Gets the value of validationIndicatorCol or its default value."""
        return self.getOrDefault(self.validationIndicatorCol)

    def setValidationIndicatorCol(self, value):
        """Sets the value of :py:attr:`validationIndicatorCol`."""
        return self._set(validationIndicatorCol=value)


class DecisionTreeParams(Params):
"""
Mixin for Decision Tree parameters.
Expand Down Expand Up @@ -790,27 +837,3 @@ def getCacheNodeIds(self):
"""
return self.getOrDefault(self.cacheNodeIds)


class HasDistanceMeasure(Params):
    """
    Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'.
    """

    # Generated by _shared_params_code_gen.py; the description string must
    # match the generator's entry exactly.
    distanceMeasure = Param(
        Params._dummy(),
        "distanceMeasure",
        "the distance measure. Supported options: 'euclidean' and 'cosine'.",
        typeConverter=TypeConverters.toString)

    def __init__(self):
        super(HasDistanceMeasure, self).__init__()
        self._setDefault(distanceMeasure='euclidean')

    def setDistanceMeasure(self, value):
        """Sets the value of :py:attr:`distanceMeasure`."""
        return self._set(distanceMeasure=value)

    def getDistanceMeasure(self):
        """Gets the value of distanceMeasure or its default value."""
        return self.getOrDefault(self.distanceMeasure)

Loading