Skip to content

Commit 56ef24a

Browse files
committed
fix test and regenerate base
1 parent afdaa5c commit 56ef24a

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

python/pyspark/ml/param/shared.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,6 @@
1919

2020
from pyspark.ml.param import Param, Params
2121

22-
import random
23-
24-
import sys
25-
2622

2723
class HasMaxIter(Params):
2824
"""
@@ -174,8 +170,7 @@ def __init__(self):
174170
super(HasProbabilityCol, self).__init__()
175171
#: param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.
176172
self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.")
177-
if 'probability' is not None:
178-
self._setDefault(probabilityCol='probability')
173+
self._setDefault(probabilityCol='probability')
179174

180175
def setProbabilityCol(self, value):
181176
"""
@@ -366,7 +361,7 @@ def __init__(self):
366361
super(HasSeed, self).__init__()
367362
#: param for random seed
368363
self.seed = Param(self, "seed", "random seed")
369-
self._setDefault(seed=random.randint(0, sys.maxsize))
364+
self._setDefault(seed=hash(type(self).name))
370365

371366
def setSeed(self, value):
372367
"""
@@ -452,11 +447,17 @@ class DecisionTreeParams(Params):
452447

453448
def __init__(self):
454449
super(DecisionTreeParams, self).__init__()
450+
#: param for Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
455451
self.maxDepth = Param(self, "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.")
452+
#: param for Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.
456453
self.maxBins = Param(self, "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.")
454+
#: param for Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.
457455
self.minInstancesPerNode = Param(self, "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.")
456+
#: param for Minimum information gain for a split to be considered at a tree node.
458457
self.minInfoGain = Param(self, "minInfoGain", "Minimum information gain for a split to be considered at a tree node.")
458+
#: param for Maximum memory in MB allocated to histogram aggregation.
459459
self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
460+
#: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.
460461
self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")
461462

462463
def setMaxDepth(self, value):

python/pyspark/ml/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def test_hasseed(self):
187187
# Check that a specified seed is honored
188188
self.assertEqual(withSeedSpecd.getSeed(), 42)
189189
# Check that a different class has a different seed
190-
self.assertNotEqual(other.getSeed(), oSeedSpeced.getSeed())
190+
self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed())
191191

192192
if __name__ == "__main__":
193193
unittest.main()

0 commit comments

Comments
 (0)