27 changes: 27 additions & 0 deletions python/pyspark/ml/tests.py
@@ -461,6 +461,31 @@ def _fit(self, dataset):

class CrossValidatorTests(PySparkTestCase):

def test_copy(self):
sqlContext = SQLContext(self.sc)
dataset = sqlContext.createDataFrame([
(10, 10.0),
(50, 50.0),
(100, 100.0),
(500, 500.0)] * 10,
["feature", "label"])

iee = InducedErrorEstimator()
evaluator = RegressionEvaluator(metricName="rmse")

grid = (ParamGridBuilder()
.addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
.build())
cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
cvCopied = cv.copy()
self.assertEqual(cv.getEstimator().uid, cvCopied.getEstimator().uid)

cvModel = cv.fit(dataset)
cvModelCopied = cvModel.copy()
for index in range(len(cvModel.avgMetrics)):
    self.assertTrue(abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index])
                    < 0.0001)

Member: Did you try assertEqual and find it did not work? Why do we need approximate equality here?

Contributor Author: I tried assertEqual before. Under Python 2 this test case loses precision and fails with assertEqual; under Python 3 it passes.

Member: Interesting. OK, thanks for checking!
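As an aside on the hand-rolled tolerance check above: unittest's own assertAlmostEqual accepts a delta argument that expresses the same bound. A minimal sketch of that alternative (not what the PR uses; the 1e-4 tolerance mirrors the explicit check):

```python
# Equivalent tolerance check using unittest's built-in helper:
# fails only if the two metrics differ by more than 1e-4.
for index in range(len(cvModel.avgMetrics)):
    self.assertAlmostEqual(cvModel.avgMetrics[index],
                           cvModelCopied.avgMetrics[index],
                           delta=0.0001)
```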

def test_fit_minimize_metric(self):
sqlContext = SQLContext(self.sc)
dataset = sqlContext.createDataFrame([
@@ -534,6 +559,8 @@ def test_save_load(self):
cvModel.save(cvModelPath)
loadedModel = CrossValidatorModel.load(cvModelPath)
self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
for index in range(len(loadedModel.avgMetrics)):
    self.assertTrue(abs(loadedModel.avgMetrics[index] - cvModel.avgMetrics[index]) < 0.0001)

Contributor: Minor possible suggestion: there are some other places in the doctests where we use numpy's assert_almost_equal; it seems like that might simplify things here a bit if you wanted to.

Contributor Author: @holdenk Thanks for the suggestion. I used assert_almost_equal here.
Member: Same here; why approximate equality?

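For reference, the numpy helper suggested above compares whole sequences in one call. A minimal sketch of that variant (assuming numpy is importable in the test module):

```python
from numpy.testing import assert_almost_equal

# Compares the two metric lists element-wise to 4 decimal places and
# raises AssertionError with a readable diff on mismatch, replacing
# the explicit index loop.
assert_almost_equal(loadedModel.avgMetrics, cvModel.avgMetrics, decimal=4)
```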

class TrainValidationSplitTests(PySparkTestCase):
18 changes: 10 additions & 8 deletions python/pyspark/ml/tuning.py
@@ -248,7 +248,7 @@ def _fit(self, dataset):
h = 1.0 / nFolds
randCol = self.uid + "_rand"
df = dataset.select("*", rand(seed).alias(randCol))
- metrics = np.zeros(numModels)
+ metrics = [0.0] * numModels
for i in range(nFolds):
validateLB = i * h
validateUB = (i + 1) * h
@@ -266,7 +266,7 @@
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
- return self._copyValues(CrossValidatorModel(bestModel))
+ return self._copyValues(CrossValidatorModel(bestModel, metrics))

@since("1.4.0")
def copy(self, extra=None):
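With the fold-averaged metrics now forwarded into CrossValidatorModel, callers can inspect the per-ParamMap averages after fitting; avgMetrics[i] corresponds to estimatorParamMaps[i]. A minimal usage sketch (the `train` DataFrame with "features"/"label" columns is an assumption, not part of this PR):

```python
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

lr = LinearRegression()
grid = ParamGridBuilder().addGrid(lr.regParam, [0.1, 0.01]).build()
evaluator = RegressionEvaluator(metricName="rmse")
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                    evaluator=evaluator, numFolds=3)
cvModel = cv.fit(train)  # `train` is an assumed DataFrame

# Each entry is the evaluator metric for the matching ParamMap,
# averaged over the cross-validation folds.
for params, metric in zip(grid, cvModel.avgMetrics):
    print(params, metric)
```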
@@ -346,10 +346,11 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
.. versionadded:: 1.4.0
"""

- def __init__(self, bestModel):
+ def __init__(self, bestModel, avgMetrics=[]):
super(CrossValidatorModel, self).__init__()
#: best model from cross validation
self.bestModel = bestModel
+ self.avgMetrics = avgMetrics

def _transform(self, dataset):
return self.bestModel.transform(dataset)
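One thing worth noting about the new signature (not raised in this review): `avgMetrics=[]` is a mutable default, evaluated once at function definition, so every model constructed without the argument shares the same list object. A defensive variant, shown only as a sketch and not as this PR's code, uses a None sentinel:

```python
def __init__(self, bestModel, avgMetrics=None):
    super(CrossValidatorModel, self).__init__()
    #: best model from cross validation
    self.bestModel = bestModel
    # The None sentinel avoids sharing one list across instances
    # that might mutate their metrics in place.
    self.avgMetrics = avgMetrics if avgMetrics is not None else []
```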
@@ -367,7 +368,9 @@ def copy(self, extra=None):
"""
if extra is None:
extra = dict()
- return CrossValidatorModel(self.bestModel.copy(extra))
+ bestModel = self.bestModel.copy(extra)
+ avgMetrics = self.avgMetrics
+ return CrossValidatorModel(bestModel, avgMetrics)

@since("2.0.0")
def write(self):
@@ -394,9 +397,10 @@ def _from_java(cls, java_stage):

# Load information from java_stage to the instance.
bestModel = JavaParams._from_java(java_stage.bestModel())
+ avgMetrics = list(java_stage.avgMetrics())
estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
# Create a new instance of this stage.
- py_stage = cls(bestModel=bestModel)\
+ py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics)\
.setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator)
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -408,12 +412,10 @@ def _to_java(self):
:return: Java object equivalent to this instance.
"""

- sc = SparkContext._active_spark_context

_java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
                                     self.uid,
                                     self.bestModel._to_java(),
- _py2java(sc, []))
+ self.avgMetrics)
estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()

_java_obj.set("evaluator", evaluator)