
Commit 129f2f4

sethah authored and jkbradley committed
[SPARK-14104][PYSPARK][ML] All Python param setters should use the _set method
## What changes were proposed in this pull request?

Param setters in Python previously accessed `_paramMap` directly to update values. The `_set` method now implements type checking, so it should be used to update all parameters. This PR eliminates all direct accesses to `_paramMap` besides the one in the `_set` method, so that type checking always happens.

Additional changes:

* [SPARK-13068](#11663) missed adding type converters in evaluation.py, so those are added here.
* An incorrect `toBoolean` type converter was used for the StringIndexer `handleInvalid` param in a previous PR. This is fixed here.

## How was this patch tested?

Existing unit tests verify that parameters are still set properly. No new functionality is added in this PR.

Author: sethah <[email protected]>

Closes #11939 from sethah/SPARK-14104.
1 parent d6ae7d4 commit 129f2f4
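For context, here is a minimal, self-contained sketch (toy classes, not the actual pyspark.ml implementation) of why routing setters through `_set` matters: `_set` can run each `Param`'s type converter before storing the value, whereas writing into `_paramMap` directly bypasses that check. The `Param`, `Params`, `toFloat`, and `ToyClassifier` names below are simplified stand-ins for the real PySpark counterparts.

```python
# Simplified sketch of the _set pattern (not the real pyspark.ml code).
# A Param carries an optional type converter; _set applies it before storing
# the value, so bad types are rejected at set time instead of failing later.

class Param(object):
    def __init__(self, name, doc, typeConverter=None):
        self.name = name
        self.doc = doc
        self.typeConverter = typeConverter if typeConverter is not None else (lambda v: v)


def toFloat(value):
    # Toy stand-in for TypeConverters.toFloat.
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return float(value)
    raise TypeError("Could not convert %r to float" % (value,))


class Params(object):
    def __init__(self):
        self._paramMap = {}

    def _set(self, **kwargs):
        # The single place where _paramMap is written; runs each converter.
        for name, value in kwargs.items():
            param = getattr(self, name)
            self._paramMap[param] = param.typeConverter(value)
        return self


class ToyClassifier(Params):
    threshold = Param("threshold", "prediction threshold in [0, 1]", typeConverter=toFloat)

    def setThreshold(self, value):
        # Old style:  self._paramMap[self.threshold] = value   (no type check)
        # New style:  route through _set so the converter runs.
        return self._set(threshold=value)


if __name__ == "__main__":
    clf = ToyClassifier().setThreshold(1)            # int is converted to 1.0
    print(clf._paramMap[ToyClassifier.threshold])    # 1.0
    try:
        clf.setThreshold("high")                     # rejected by the converter
    except TypeError as exc:
        print(exc)
```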


12 files changed: 110 additions, 91 deletions


python/pyspark/ml/classification.py

Lines changed: 10 additions & 12 deletions
@@ -142,9 +142,8 @@ def setThreshold(self, value):
         Sets the value of :py:attr:`threshold`.
         Clears value of :py:attr:`thresholds` if it has been set.
         """
-        self._paramMap[self.threshold] = value
-        if self.isSet(self.thresholds):
-            del self._paramMap[self.thresholds]
+        self._set(threshold=value)
+        self._clear(self.thresholds)
         return self
 
     @since("1.4.0")
@@ -169,9 +168,8 @@ def setThresholds(self, value):
         Sets the value of :py:attr:`thresholds`.
         Clears value of :py:attr:`threshold` if it has been set.
         """
-        self._paramMap[self.thresholds] = value
-        if self.isSet(self.threshold):
-            del self._paramMap[self.threshold]
+        self._set(thresholds=value)
+        self._clear(self.threshold)
         return self
 
     @since("1.5.0")
@@ -471,7 +469,7 @@ def setImpurity(self, value):
         """
         Sets the value of :py:attr:`impurity`.
         """
-        self._paramMap[self.impurity] = value
+        self._set(impurity=value)
         return self
 
     @since("1.6.0")
@@ -833,7 +831,7 @@ def setLossType(self, value):
         """
         Sets the value of :py:attr:`lossType`.
         """
-        self._paramMap[self.lossType] = value
+        self._set(lossType=value)
         return self
 
     @since("1.4.0")
@@ -963,7 +961,7 @@ def setSmoothing(self, value):
         """
         Sets the value of :py:attr:`smoothing`.
         """
-        self._paramMap[self.smoothing] = value
+        self._set(smoothing=value)
         return self
 
     @since("1.5.0")
@@ -978,7 +976,7 @@ def setModelType(self, value):
         """
         Sets the value of :py:attr:`modelType`.
         """
-        self._paramMap[self.modelType] = value
+        self._set(modelType=value)
         return self
 
     @since("1.5.0")
@@ -1108,7 +1106,7 @@ def setLayers(self, value):
         """
         Sets the value of :py:attr:`layers`.
         """
-        self._paramMap[self.layers] = value
+        self._set(layers=value)
         return self
 
     @since("1.6.0")
@@ -1123,7 +1121,7 @@ def setBlockSize(self, value):
         """
         Sets the value of :py:attr:`blockSize`.
         """
-        self._paramMap[self.blockSize] = value
+        self._set(blockSize=value)
         return self
 
     @since("1.6.0")

python/pyspark/ml/clustering.py

Lines changed: 5 additions & 5 deletions
@@ -130,7 +130,7 @@ def setK(self, value):
         """
         Sets the value of :py:attr:`k`.
         """
-        self._paramMap[self.k] = value
+        self._set(k=value)
         return self
 
     @since("1.5.0")
@@ -145,7 +145,7 @@ def setInitMode(self, value):
         """
         Sets the value of :py:attr:`initMode`.
         """
-        self._paramMap[self.initMode] = value
+        self._set(initMode=value)
         return self
 
     @since("1.5.0")
@@ -160,7 +160,7 @@ def setInitSteps(self, value):
         """
         Sets the value of :py:attr:`initSteps`.
         """
-        self._paramMap[self.initSteps] = value
+        self._set(initSteps=value)
         return self
 
     @since("1.5.0")
@@ -280,7 +280,7 @@ def setK(self, value):
         """
         Sets the value of :py:attr:`k`.
         """
-        self._paramMap[self.k] = value
+        self._set(k=value)
         return self
 
     @since("2.0.0")
@@ -295,7 +295,7 @@ def setMinDivisibleClusterSize(self, value):
         """
         Sets the value of :py:attr:`minDivisibleClusterSize`.
         """
-        self._paramMap[self.minDivisibleClusterSize] = value
+        self._set(minDivisibleClusterSize=value)
         return self
 
     @since("2.0.0")

python/pyspark/ml/evaluation.py

Lines changed: 10 additions & 7 deletions
@@ -19,7 +19,7 @@
 
 from pyspark import since
 from pyspark.ml.wrapper import JavaParams
-from pyspark.ml.param import Param, Params
+from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
 from pyspark.ml.util import keyword_only
 from pyspark.mllib.common import inherit_doc
@@ -125,7 +125,8 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
     """
 
     metricName = Param(Params._dummy(), "metricName",
-                       "metric name in evaluation (areaUnderROC|areaUnderPR)")
+                       "metric name in evaluation (areaUnderROC|areaUnderPR)",
+                       typeConverter=TypeConverters.toString)
 
     @keyword_only
     def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
@@ -147,7 +148,7 @@ def setMetricName(self, value):
         """
         Sets the value of :py:attr:`metricName`.
         """
-        self._paramMap[self.metricName] = value
+        self._set(metricName=value)
         return self
 
     @since("1.4.0")
@@ -194,7 +195,8 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
     # when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`),
     # we take and output the negative of this metric.
     metricName = Param(Params._dummy(), "metricName",
-                       "metric name in evaluation (mse|rmse|r2|mae)")
+                       "metric name in evaluation (mse|rmse|r2|mae)",
+                       typeConverter=TypeConverters.toString)
 
     @keyword_only
     def __init__(self, predictionCol="prediction", labelCol="label",
@@ -216,7 +218,7 @@ def setMetricName(self, value):
         """
         Sets the value of :py:attr:`metricName`.
         """
-        self._paramMap[self.metricName] = value
+        self._set(metricName=value)
         return self
 
     @since("1.4.0")
@@ -260,7 +262,8 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
     """
     metricName = Param(Params._dummy(), "metricName",
                        "metric name in evaluation "
-                       "(f1|precision|recall|weightedPrecision|weightedRecall)")
+                       "(f1|precision|recall|weightedPrecision|weightedRecall)",
+                       typeConverter=TypeConverters.toString)
 
     @keyword_only
     def __init__(self, predictionCol="prediction", labelCol="label",
@@ -282,7 +285,7 @@ def setMetricName(self, value):
         """
         Sets the value of :py:attr:`metricName`.
         """
-        self._paramMap[self.metricName] = value
+        self._set(metricName=value)
         return self
 
     @since("1.5.0")
