From 844fe5c87fff19042e3abf84be590ce9deb0eac7 Mon Sep 17 00:00:00 2001 From: Louiszr Date: Sun, 23 Aug 2020 23:01:40 +0100 Subject: [PATCH] Fixed copy() to copy models instead of list --- python/pyspark/ml/tests/test_tuning.py | 8 ++++---- python/pyspark/ml/tuning.py | 5 ++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py index 0937e4707eab1..c9163627fdd54 100644 --- a/python/pyspark/ml/tests/test_tuning.py +++ b/python/pyspark/ml/tests/test_tuning.py @@ -127,9 +127,9 @@ def test_copy(self): 'foo', "Changing the original avgMetrics should not affect the copied model" ) - cvModel.subModels[0] = 'foo' + cvModel.subModels[0][0].getInducedError = lambda: 'foo' self.assertNotEqual( - cvModelCopied.subModels[0], + cvModelCopied.subModels[0][0].getInducedError(), 'foo', "Changing the original subModels should not affect the copied model" ) @@ -852,9 +852,9 @@ def test_copy(self): 'foo', "Changing the original validationMetrics should not affect the copied model" ) - tvsModel.subModels[0] = 'foo' + tvsModel.subModels[0].getInducedError = lambda: 'foo' self.assertNotEqual( - tvsModelCopied.subModels[0], + tvsModelCopied.subModels[0].getInducedError(), 'foo', "Changing the original subModels should not affect the copied model" ) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 466491d046692..6a0c85089e114 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -535,7 +535,10 @@ def copy(self, extra=None): extra = dict() bestModel = self.bestModel.copy(extra) avgMetrics = list(self.avgMetrics) - subModels = [model.copy() for model in self.subModels] + subModels = [ + [sub_model.copy() for sub_model in fold_sub_models] + for fold_sub_models in self.subModels + ] return self._copyValues(CrossValidatorModel(bestModel, avgMetrics, subModels), extra=extra) @since("2.3.0")