
Commit d92568a

MechCoder authored and mengxr committed
[SPARK-9828] [PYSPARK] Mutable values should not be default arguments
Author: MechCoder <[email protected]>

Closes #8110 from MechCoder/spark-9828.

(cherry picked from commit ffa05c8)
Signed-off-by: Xiangrui Meng <[email protected]>
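Every change below applies the same Python idiom: a mutable default argument ({} or []) is evaluated once, when the def statement runs, so a single object is silently shared across all calls that rely on the default. The fix is a None sentinel plus a fresh object created inside the body. A minimal standalone sketch of the pitfall and the fix (the append_* functions are hypothetical illustrations, not Spark code):

def append_bad(item, bucket=[]):
    # The [] above is created once at definition time and shared by
    # every call that omits `bucket`.
    bucket.append(item)
    return bucket

def append_good(item, bucket=None):
    # None sentinel: a fresh list is created on each call, the pattern
    # this commit applies throughout PySpark.
    if bucket is None:
        bucket = list()
    bucket.append(item)
    return bucket

print(append_bad("a"))   # ['a']
print(append_bad("b"))   # ['a', 'b']  <- state leaked from the first call
print(append_good("a"))  # ['a']
print(append_good("b"))  # ['b']       <- calls stay independent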
1 parent: b284213

File tree: 8 files changed (+50, -21 lines)


python/pyspark/ml/evaluation.py
Lines changed: 3 additions & 1 deletion

@@ -46,7 +46,7 @@ def _evaluate(self, dataset):
         """
         raise NotImplementedError()

-    def evaluate(self, dataset, params={}):
+    def evaluate(self, dataset, params=None):
         """
         Evaluates the output with optional parameters.
@@ -56,6 +56,8 @@ def evaluate(self, dataset, params={}):
         params
         :return: metric
         """
+        if params is None:
+            params = dict()
         if isinstance(params, dict):
             if params:
                 return self.copy(params)._evaluate(dataset)

python/pyspark/ml/param/__init__.py
Lines changed: 17 additions & 9 deletions

@@ -60,14 +60,16 @@ class Params(Identifiable):

     __metaclass__ = ABCMeta

-    #: internal param map for user-supplied values param map
-    _paramMap = {}
+    def __init__(self):
+        super(Params, self).__init__()
+        #: internal param map for user-supplied values param map
+        self._paramMap = {}

-    #: internal param map for default values
-    _defaultParamMap = {}
+        #: internal param map for default values
+        self._defaultParamMap = {}

-    #: value returned by :py:func:`params`
-    _params = None
+        #: value returned by :py:func:`params`
+        self._params = None

     @property
     def params(self):
@@ -155,7 +157,7 @@ def getOrDefault(self, param):
         else:
             return self._defaultParamMap[param]

-    def extractParamMap(self, extra={}):
+    def extractParamMap(self, extra=None):
         """
         Extracts the embedded default param values and user-supplied
         values, and then merges them with extra values from input into
@@ -165,12 +167,14 @@ def extractParamMap(self, extra={}):
         :param extra: extra param values
         :return: merged param map
         """
+        if extra is None:
+            extra = dict()
         paramMap = self._defaultParamMap.copy()
         paramMap.update(self._paramMap)
         paramMap.update(extra)
         return paramMap

-    def copy(self, extra={}):
+    def copy(self, extra=None):
         """
         Creates a copy of this instance with the same uid and some
         extra params. The default implementation creates a
@@ -181,6 +185,8 @@ def copy(self, extra={}):
         :param extra: Extra parameters to copy to the new instance
         :return: Copy of this instance
         """
+        if extra is None:
+            extra = dict()
         that = copy.copy(self)
         that._paramMap = self.extractParamMap(extra)
         return that
@@ -233,14 +239,16 @@ def _setDefault(self, **kwargs):
             self._defaultParamMap[getattr(self, param)] = value
         return self

-    def _copyValues(self, to, extra={}):
+    def _copyValues(self, to, extra=None):
         """
         Copies param values from this instance to another instance for
         params shared by them.
         :param to: the target instance
         :param extra: extra params to be copied
         :return: the target instance with param values copied
         """
+        if extra is None:
+            extra = dict()
         paramMap = self.extractParamMap(extra)
         for p in self.params:
             if p in paramMap and to.hasParam(p.name):
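The Params change above also fixes a second, related sharing bug: `_paramMap = {}` declared at class level is a single dict attached to the class object, so every `Params` instance (and subclass instance) mutates the same map. Moving the assignment into `__init__` gives each instance its own. A hedged sketch of the difference, with hypothetical `SharedState`/`PerInstanceState` classes standing in for `Params`:

class SharedState(object):
    cache = {}  # one dict, owned by the class, visible to every instance

class PerInstanceState(object):
    def __init__(self):
        self.cache = {}  # a fresh dict per instance, as Params now does

a, b = SharedState(), SharedState()
a.cache["k"] = 1
print(b.cache)   # {'k': 1}  <- mutation leaked across instances

c, d = PerInstanceState(), PerInstanceState()
c.cache["k"] = 1
print(d.cache)   # {}        <- instances are isolated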

python/pyspark/ml/pipeline.py
Lines changed: 2 additions & 2 deletions

@@ -141,7 +141,7 @@ class Pipeline(Estimator):
     @keyword_only
     def __init__(self, stages=None):
         """
-        __init__(self, stages=[])
+        __init__(self, stages=None)
         """
         if stages is None:
             stages = []
@@ -170,7 +170,7 @@ def getStages(self):
     @keyword_only
     def setParams(self, stages=None):
         """
-        setParams(self, stages=[])
+        setParams(self, stages=None)
         Sets params for Pipeline.
         """
         if stages is None:

python/pyspark/ml/tuning.py
Lines changed: 6 additions & 2 deletions

@@ -227,7 +227,9 @@ def _fit(self, dataset):
         bestModel = est.fit(dataset, epm[bestIndex])
         return CrossValidatorModel(bestModel)

-    def copy(self, extra={}):
+    def copy(self, extra=None):
+        if extra is None:
+            extra = dict()
         newCV = Params.copy(self, extra)
         if self.isSet(self.estimator):
             newCV.setEstimator(self.getEstimator().copy(extra))
@@ -250,7 +252,7 @@ def __init__(self, bestModel):
     def _transform(self, dataset):
         return self.bestModel.transform(dataset)

-    def copy(self, extra={}):
+    def copy(self, extra=None):
         """
         Creates a copy of this instance with a randomly generated uid
         and some extra params. This copies the underlying bestModel,
@@ -259,6 +261,8 @@ def copy(self, extra={}):
         :param extra: Extra parameters to copy to the new instance
         :return: Copy of this instance
         """
+        if extra is None:
+            extra = dict()
         return CrossValidatorModel(self.bestModel.copy(extra))

python/pyspark/rdd.py
Lines changed: 4 additions & 1 deletion

@@ -700,7 +700,7 @@ def groupBy(self, f, numPartitions=None):
         return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)

     @ignore_unicode_prefix
-    def pipe(self, command, env={}, checkCode=False):
+    def pipe(self, command, env=None, checkCode=False):
         """
         Return an RDD created by piping elements to a forked external process.
@@ -709,6 +709,9 @@ def pipe(self, command, env={}, checkCode=False):

         :param checkCode: whether or not to check the return value of the shell command.
         """
+        if env is None:
+            env = dict()
+
         def func(iterator):
             pipe = Popen(
                 shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)

python/pyspark/sql/readwriter.py
Lines changed: 6 additions & 2 deletions

@@ -182,7 +182,7 @@ def orc(self, path):

     @since(1.4)
     def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None,
-             predicates=None, properties={}):
+             predicates=None, properties=None):
         """
         Construct a :class:`DataFrame` representing the database table accessible
         via JDBC URL `url` named `table` and connection `properties`.
@@ -208,6 +208,8 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None,
         should be included.
         :return: a DataFrame
         """
+        if properties is None:
+            properties = dict()
         jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
         for k in properties:
             jprop.setProperty(k, properties[k])
@@ -427,7 +429,7 @@ def orc(self, path, mode=None, partitionBy=None):
         self._jwrite.orc(path)

     @since(1.4)
-    def jdbc(self, url, table, mode=None, properties={}):
+    def jdbc(self, url, table, mode=None, properties=None):
         """Saves the content of the :class:`DataFrame` to a external database table via JDBC.

         .. note:: Don't create too many partitions in parallel on a large cluster;\
@@ -445,6 +447,8 @@ def jdbc(self, url, table, mode=None, properties={}):
         arbitrary string tag/value. Normally at least a
         "user" and "password" property should be included.
         """
+        if properties is None:
+            properties = dict()
         jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
         for k in properties:
             jprop.setProperty(k, properties[k])

python/pyspark/statcounter.py
Lines changed: 3 additions & 1 deletion

@@ -30,7 +30,9 @@

 class StatCounter(object):

-    def __init__(self, values=[]):
+    def __init__(self, values=None):
+        if values is None:
+            values = list()
         self.n = 0     # Running count of our values
         self.mu = 0.0  # Running mean of our values
         self.m2 = 0.0  # Running variance numerator (sum of (x - mean)^2)

python/pyspark/streaming/kafka.py
Lines changed: 9 additions & 3 deletions

@@ -35,7 +35,7 @@ def utf8_decoder(s):
 class KafkaUtils(object):

     @staticmethod
-    def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
+    def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None,
                      storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2,
                      keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
         """
@@ -52,6 +52,8 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
         :param valueDecoder: A function used to decode value (default is utf8_decoder)
         :return: A DStream object
         """
+        if kafkaParams is None:
+            kafkaParams = dict()
         kafkaParams.update({
             "zookeeper.connect": zkQuorum,
             "group.id": groupId,
@@ -77,7 +79,7 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
         return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))

     @staticmethod
-    def createDirectStream(ssc, topics, kafkaParams, fromOffsets={},
+    def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None,
                            keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
         """
         .. note:: Experimental
@@ -105,6 +107,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={},
         :param valueDecoder: A function used to decode value (default is utf8_decoder).
         :return: A DStream object
         """
+        if fromOffsets is None:
+            fromOffsets = dict()
         if not isinstance(topics, list):
             raise TypeError("topics should be list")
         if not isinstance(kafkaParams, dict):
@@ -129,7 +133,7 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={},
         return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)

     @staticmethod
-    def createRDD(sc, kafkaParams, offsetRanges, leaders={},
+    def createRDD(sc, kafkaParams, offsetRanges, leaders=None,
                   keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
         """
         .. note:: Experimental
@@ -145,6 +149,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={},
         :param valueDecoder: A function used to decode value (default is utf8_decoder)
         :return: A RDD object
         """
+        if leaders is None:
+            leaders = dict()
         if not isinstance(kafkaParams, dict):
             raise TypeError("kafkaParams should be dict")
         if not isinstance(offsetRanges, list):
