From e253be5e2154cf824f4d5429335f7f571a8dc621 Mon Sep 17 00:00:00 2001 From: =^_^= Date: Tue, 2 Aug 2016 00:47:22 -0700 Subject: [PATCH 1/3] fixed CrossValidator.avgMetrics calculation --- python/pyspark/ml/tuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 7f967e5463dcf..9436eae99011b 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -234,7 +234,7 @@ def _fit(self, dataset): model = est.fit(train, epm[j]) # TODO: duplicate evaluator to take extra params from input metric = eva.evaluate(model.transform(validation, epm[j])) - metrics[j] += metric + metrics[j] += metric/nFolds if eva.isLargerBetter(): bestIndex = np.argmax(metrics) From 0a4b0a9406415df9af0896207289531d50c0add6 Mon Sep 17 00:00:00 2001 From: =^_^= Date: Tue, 2 Aug 2016 23:03:03 -0700 Subject: [PATCH 2/3] added regression test in doc string --- python/pyspark/ml/tuning.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 9436eae99011b..80968b0819f9e 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -166,8 +166,11 @@ class CrossValidator(Estimator, ValidatorParams): >>> evaluator = BinaryClassificationEvaluator() >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) >>> cvModel = cv.fit(dataset) + >>> cvModel.avgMetrics[0] + 0.5 >>> evaluator.evaluate(cvModel.transform(dataset)) 0.8333... + >>> .. 
versionadded:: 1.4.0 """ From c426a98daf5f1bad9e2093dfe0d25bd6ead63f61 Mon Sep 17 00:00:00 2001 From: =^_^= Date: Tue, 2 Aug 2016 23:38:35 -0700 Subject: [PATCH 3/3] Update tuning.py: remove stray doctest prompt from CrossValidator docstring --- python/pyspark/ml/tuning.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 80968b0819f9e..2dcc99cef8aa2 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -170,7 +170,6 @@ class CrossValidator(Estimator, ValidatorParams): 0.5 >>> evaluator.evaluate(cvModel.transform(dataset)) 0.8333... - >>> .. versionadded:: 1.4.0 """