Skip to content

Commit 80f07fb

Browse files
committed
address comments
1 parent 81473b0 commit 80f07fb

File tree

4 files changed

+26
-11
lines changed

4 files changed

+26
-11
lines changed

python/pyspark/ml/param/_shared_params_code_gen.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,10 @@ def get$Name(self):
157157
"TypeConverters.toInt"),
158158
("parallelism", "the number of threads to use when running parallel algorithms (>= 1).",
159159
"1", "TypeConverters.toInt"),
160-
("collectSubModels", "whether to collect a list of sub-models trained during tuning",
160+
("collectSubModels", "Param for whether to collect a list of sub-models trained during " +
161+
"tuning. If set to false, then only the single best sub-model will be available after " +
162+
"fitting. If set to true, then all sub-models will be available. Warning: For large " +
163+
"models, collecting all sub-models can cause OOMs on the Spark driver.",
161164
"False", "TypeConverters.toBoolean"),
162165
("loss", "the loss function to be optimized.", None, "TypeConverters.toString")]
163166

python/pyspark/ml/param/shared.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -657,10 +657,10 @@ def getParallelism(self):
657657

658658
class HasCollectSubModels(Params):
659659
"""
660-
Mixin for param collectSubModels: whether to collect a list of sub-models trained during tuning
660+
Mixin for param collectSubModels: Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.
661661
"""
662662

663-
collectSubModels = Param(Params._dummy(), "collectSubModels", "whether to collect a list of sub-models trained during tuning", typeConverter=TypeConverters.toBoolean)
663+
collectSubModels = Param(Params._dummy(), "collectSubModels", "Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.", typeConverter=TypeConverters.toBoolean)
664664

665665
def __init__(self):
666666
super(HasCollectSubModels, self).__init__()

python/pyspark/ml/tests.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,9 +1037,9 @@ def test_expose_sub_models(self):
10371037
numFolds=numFolds, collectSubModels=True)
10381038

10391039
def checkSubModels(subModels):
1040-
assert len(subModels) == numFolds
1040+
self.assertEqual(len(subModels), numFolds)
10411041
for i in range(numFolds):
1042-
assert len(subModels[i]) == len(grid)
1042+
self.assertEqual(len(subModels[i]), len(grid))
10431043

10441044
cvModel = cv.fit(dataset)
10451045
checkSubModels(cvModel.subModels)
@@ -1050,11 +1050,13 @@ def checkSubModels(subModels):
10501050
cvModel.save(savingPathWithSubModels)
10511051
cvModel3 = CrossValidatorModel.load(savingPathWithSubModels)
10521052
checkSubModels(cvModel3.subModels)
1053+
cvModel4 = cvModel3.copy()
1054+
checkSubModels(cvModel4.subModels)
10531055

10541056
savingPathWithoutSubModels = testSubPath + "cvModel2"
10551057
cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels)
10561058
cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels)
1057-
assert cvModel2.subModels is None
1059+
self.assertEqual(cvModel2.subModels, None)
10581060

10591061
for i in range(numFolds):
10601062
for j in range(len(grid)):
@@ -1243,19 +1245,21 @@ def test_expose_sub_models(self):
12431245
tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
12441246
collectSubModels=True)
12451247
tvsModel = tvs.fit(dataset)
1246-
assert len(tvsModel.subModels) == len(grid)
1248+
self.assertEqual(len(tvsModel.subModels), len(grid))
12471249

12481250
# Test the default value for option "persistSubModel" to be "true"
12491251
testSubPath = temp_path + "/testTrainValidationSplitSubModels"
12501252
savingPathWithSubModels = testSubPath + "cvModel3"
12511253
tvsModel.save(savingPathWithSubModels)
12521254
tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels)
1253-
assert len(tvsModel3.subModels) == len(grid)
1255+
self.assertEqual(len(tvsModel3.subModels), len(grid))
1256+
tvsModel4 = tvsModel3.copy()
1257+
self.assertEqual(len(tvsModel4.subModels), len(grid))
12541258

12551259
savingPathWithoutSubModels = testSubPath + "cvModel2"
12561260
tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels)
12571261
tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels)
1258-
assert tvsModel2.subModels is None
1262+
self.assertEqual(tvsModel2.subModels, None)
12591263

12601264
for i in range(len(grid)):
12611265
self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid)

python/pyspark/ml/tuning.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,9 +354,11 @@ def _from_java(cls, java_stage):
354354
numFolds = java_stage.getNumFolds()
355355
seed = java_stage.getSeed()
356356
parallelism = java_stage.getParallelism()
357+
collectSubModels = java_stage.getCollectSubModels()
357358
# Create a new instance of this stage.
358359
py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
359-
numFolds=numFolds, seed=seed, parallelism=parallelism)
360+
numFolds=numFolds, seed=seed, parallelism=parallelism,
361+
collectSubModels=collectSubModels)
360362
py_stage._resetUid(java_stage.uid())
361363
return py_stage
362364

@@ -376,6 +378,7 @@ def _to_java(self):
376378
_java_obj.setSeed(self.getSeed())
377379
_java_obj.setNumFolds(self.getNumFolds())
378380
_java_obj.setParallelism(self.getParallelism())
381+
_java_obj.setCollectSubModels(self.getCollectSubModels())
379382

380383
return _java_obj
381384

@@ -410,6 +413,7 @@ def copy(self, extra=None):
410413
and some extra params. This copies the underlying bestModel,
411414
creates a deep copy of the embedded paramMap, and
412415
copies the embedded and extra parameters over.
416+
It does not copy the extra Params into the subModels.
413417
414418
:param extra: Extra parameters to copy to the new instance
415419
:return: Copy of this instance
@@ -628,9 +632,11 @@ def _from_java(cls, java_stage):
628632
trainRatio = java_stage.getTrainRatio()
629633
seed = java_stage.getSeed()
630634
parallelism = java_stage.getParallelism()
635+
collectSubModels = java_stage.getCollectSubModels()
631636
# Create a new instance of this stage.
632637
py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
633-
trainRatio=trainRatio, seed=seed, parallelism=parallelism)
638+
trainRatio=trainRatio, seed=seed, parallelism=parallelism,
639+
collectSubModels=collectSubModels)
634640
py_stage._resetUid(java_stage.uid())
635641
return py_stage
636642

@@ -650,6 +656,7 @@ def _to_java(self):
650656
_java_obj.setTrainRatio(self.getTrainRatio())
651657
_java_obj.setSeed(self.getSeed())
652658
_java_obj.setParallelism(self.getParallelism())
659+
_java_obj.setCollectSubModels(self.getCollectSubModels())
653660
return _java_obj
654661

655662

@@ -682,6 +689,7 @@ def copy(self, extra=None):
682689
creates a deep copy of the embedded paramMap, and
683690
copies the embedded and extra parameters over.
684691
And, this creates a shallow copy of the validationMetrics.
692+
It does not copy the extra Params into the subModels.
685693
686694
:param extra: Extra parameters to copy to the new instance
687695
:return: Copy of this instance

0 commit comments

Comments
 (0)