Skip to content

Commit 5850977

Browse files
sethahmengxr
authored and committed
[SPARK-14107][PYSPARK][ML] Add seed as named argument to GBTs in pyspark
## What changes were proposed in this pull request?
GBTs in pyspark previously had seed parameters, but they could not be passed as keyword arguments through the class constructor. This patch adds seed as a keyword argument and also sets a default value.
## How was this patch tested?
Doc tests were updated to pass a random seed through the GBTClassifier and GBTRegressor constructors.
Author: sethah <[email protected]>
Closes apache#11944 from sethah/SPARK-14107.
1 parent fdd460f commit 5850977

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

python/pyspark/ml/classification.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -520,7 +520,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
520520
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
521521
>>> si_model = stringIndexer.fit(df)
522522
>>> td = si_model.transform(df)
523-
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed")
523+
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42)
524524
>>> model = gbt.fit(td)
525525
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
526526
True
@@ -543,19 +543,19 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
543543
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
544544
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
545545
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
546-
maxIter=20, stepSize=0.1):
546+
maxIter=20, stepSize=0.1, seed=None):
547547
"""
548548
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
549549
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
550550
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
551-
lossType="logistic", maxIter=20, stepSize=0.1)
551+
lossType="logistic", maxIter=20, stepSize=0.1, seed=None)
552552
"""
553553
super(GBTClassifier, self).__init__()
554554
self._java_obj = self._new_java_obj(
555555
"org.apache.spark.ml.classification.GBTClassifier", self.uid)
556556
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
557557
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
558-
lossType="logistic", maxIter=20, stepSize=0.1)
558+
lossType="logistic", maxIter=20, stepSize=0.1, seed=None)
559559
kwargs = self.__init__._input_kwargs
560560
self.setParams(**kwargs)
561561

@@ -564,12 +564,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
564564
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
565565
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
566566
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
567-
lossType="logistic", maxIter=20, stepSize=0.1):
567+
lossType="logistic", maxIter=20, stepSize=0.1, seed=None):
568568
"""
569569
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
570570
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
571571
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
572-
lossType="logistic", maxIter=20, stepSize=0.1)
572+
lossType="logistic", maxIter=20, stepSize=0.1, seed=None)
573573
Sets params for Gradient Boosted Tree Classification.
574574
"""
575575
kwargs = self.setParams._input_kwargs

python/pyspark/ml/regression.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -641,7 +641,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
641641
>>> df = sqlContext.createDataFrame([
642642
... (1.0, Vectors.dense(1.0)),
643643
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
644-
>>> gbt = GBTRegressor(maxIter=5, maxDepth=2)
644+
>>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
645645
>>> model = gbt.fit(df)
646646
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
647647
True
@@ -664,18 +664,19 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
664664
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
665665
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
666666
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
667-
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1):
667+
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None):
668668
"""
669669
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
670670
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
671671
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
672-
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1)
672+
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None)
673673
"""
674674
super(GBTRegressor, self).__init__()
675675
self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
676676
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
677677
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
678-
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1)
678+
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1,
679+
seed=None)
679680
kwargs = self.__init__._input_kwargs
680681
self.setParams(**kwargs)
681682

@@ -684,12 +685,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
684685
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
685686
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
686687
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
687-
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1):
688+
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None):
688689
"""
689690
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
690691
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
691692
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
692-
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1)
693+
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None)
693694
Sets params for Gradient Boosted Tree Regression.
694695
"""
695696
kwargs = self.setParams._input_kwargs

0 commit comments

Comments
 (0)