
Commit 969e8b3

MechCoder authored and mengxr committed
[SPARK-9828] [PYSPARK] Mutable values should not be default arguments
Author: MechCoder <[email protected]>

Closes #8110 from MechCoder/spark-9828.

(cherry picked from commit ffa05c8)
Signed-off-by: Xiangrui Meng <[email protected]>
1 parent db71ea4 commit 969e8b3

File tree

8 files changed: +50 -21 lines changed
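Why this change matters: in Python, a default argument value is evaluated once, when the def statement is executed, not on each call. A mutable default such as {} or [] is therefore a single object shared by every call that omits that argument, so state from one call leaks into the next. A minimal sketch of the pitfall and of the None-sentinel idiom this commit applies throughout (function names are illustrative, not from Spark):

def bad_append(item, acc=[]):
    # The same list object is reused across calls.
    acc.append(item)
    return acc

bad_append(1)   # [1]
bad_append(2)   # [1, 2] -- state leaked from the first call

def good_append(item, acc=None):
    # Use an immutable sentinel and create a fresh list per call.
    if acc is None:
        acc = []
    acc.append(item)
    return acc

good_append(1)  # [1]
good_append(2)  # [2]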

python/pyspark/ml/evaluation.py

Lines changed: 3 additions & 1 deletion
@@ -45,7 +45,7 @@ def _evaluate(self, dataset):
         """
         raise NotImplementedError()
 
-    def evaluate(self, dataset, params={}):
+    def evaluate(self, dataset, params=None):
         """
         Evaluates the output with optional parameters.
@@ -55,6 +55,8 @@ def evaluate(self, dataset, params={}):
             params
         :return: metric
         """
+        if params is None:
+            params = dict()
         if isinstance(params, dict):
             if params:
                 return self.copy(params)._evaluate(dataset)

python/pyspark/ml/param/__init__.py

Lines changed: 17 additions & 9 deletions
@@ -60,14 +60,16 @@ class Params(Identifiable):
 
     __metaclass__ = ABCMeta
 
-    #: internal param map for user-supplied values param map
-    _paramMap = {}
+    def __init__(self):
+        super(Params, self).__init__()
+        #: internal param map for user-supplied values param map
+        self._paramMap = {}
 
-    #: internal param map for default values
-    _defaultParamMap = {}
+        #: internal param map for default values
+        self._defaultParamMap = {}
 
-    #: value returned by :py:func:`params`
-    _params = None
+        #: value returned by :py:func:`params`
+        self._params = None
 
     @property
     def params(self):
@@ -155,7 +157,7 @@ def getOrDefault(self, param):
         else:
             return self._defaultParamMap[param]
 
-    def extractParamMap(self, extra={}):
+    def extractParamMap(self, extra=None):
         """
         Extracts the embedded default param values and user-supplied
         values, and then merges them with extra values from input into
@@ -165,12 +167,14 @@ def extractParamMap(self, extra={}):
         :param extra: extra param values
         :return: merged param map
         """
+        if extra is None:
+            extra = dict()
         paramMap = self._defaultParamMap.copy()
         paramMap.update(self._paramMap)
         paramMap.update(extra)
         return paramMap
 
-    def copy(self, extra={}):
+    def copy(self, extra=None):
         """
         Creates a copy of this instance with the same uid and some
         extra params. The default implementation creates a
@@ -181,6 +185,8 @@ def copy(self, extra={}):
         :param extra: Extra parameters to copy to the new instance
         :return: Copy of this instance
         """
+        if extra is None:
+            extra = dict()
         that = copy.copy(self)
         that._paramMap = self.extractParamMap(extra)
         return that
@@ -233,14 +239,16 @@ def _setDefault(self, **kwargs):
             self._defaultParamMap[getattr(self, param)] = value
         return self
 
-    def _copyValues(self, to, extra={}):
+    def _copyValues(self, to, extra=None):
         """
         Copies param values from this instance to another instance for
         params shared by them.
         :param to: the target instance
         :param extra: extra params to be copied
         :return: the target instance with param values copied
         """
+        if extra is None:
+            extra = dict()
         paramMap = self.extractParamMap(extra)
         for p in self.params:
             if p in paramMap and to.hasParam(p.name):
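The Params change above fixes a related sharing bug: a mutable value assigned at class level is one object shared by all instances, whereas assignment in __init__ gives each instance its own. A hypothetical sketch of the difference (class names are illustrative, not from Spark):

class Shared(object):
    cache = {}              # one dict, shared by every instance

class PerInstance(object):
    def __init__(self):
        self.cache = {}     # a fresh dict per instance

a, b = Shared(), Shared()
a.cache["k"] = 1
print(b.cache)              # {'k': 1} -- b sees a's entry

c, d = PerInstance(), PerInstance()
c.cache["k"] = 1
print(d.cache)              # {} -- d is unaffected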

python/pyspark/ml/pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -141,7 +141,7 @@ class Pipeline(Estimator):
     @keyword_only
     def __init__(self, stages=None):
         """
-        __init__(self, stages=[])
+        __init__(self, stages=None)
         """
         if stages is None:
             stages = []
@@ -170,7 +170,7 @@ def getStages(self):
     @keyword_only
     def setParams(self, stages=None):
         """
-        setParams(self, stages=[])
+        setParams(self, stages=None)
         Sets params for Pipeline.
         """
         if stages is None:

python/pyspark/ml/tuning.py

Lines changed: 6 additions & 2 deletions
@@ -227,7 +227,9 @@ def _fit(self, dataset):
         bestModel = est.fit(dataset, epm[bestIndex])
         return CrossValidatorModel(bestModel)
 
-    def copy(self, extra={}):
+    def copy(self, extra=None):
+        if extra is None:
+            extra = dict()
         newCV = Params.copy(self, extra)
         if self.isSet(self.estimator):
             newCV.setEstimator(self.getEstimator().copy(extra))
@@ -250,7 +252,7 @@ def __init__(self, bestModel):
     def _transform(self, dataset):
         return self.bestModel.transform(dataset)
 
-    def copy(self, extra={}):
+    def copy(self, extra=None):
         """
         Creates a copy of this instance with a randomly generated uid
         and some extra params. This copies the underlying bestModel,
@@ -259,6 +261,8 @@ def copy(self, extra={}):
         :param extra: Extra parameters to copy to the new instance
         :return: Copy of this instance
         """
+        if extra is None:
+            extra = dict()
         return CrossValidatorModel(self.bestModel.copy(extra))

python/pyspark/rdd.py

Lines changed: 4 additions & 1 deletion
@@ -699,13 +699,16 @@ def groupBy(self, f, numPartitions=None):
         return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
 
     @ignore_unicode_prefix
-    def pipe(self, command, env={}):
+    def pipe(self, command, env=None):
         """
         Return an RDD created by piping elements to a forked external process.
 
         >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
         [u'1', u'2', u'', u'3']
         """
+        if env is None:
+            env = dict()
+
         def func(iterator):
             pipe = Popen(
                 shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)

python/pyspark/sql/readwriter.py

Lines changed: 6 additions & 2 deletions
@@ -168,7 +168,7 @@ def parquet(self, *path):
 
     @since(1.4)
     def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None,
-             predicates=None, properties={}):
+             predicates=None, properties=None):
         """
         Construct a :class:`DataFrame` representing the database table accessible
         via JDBC URL `url` named `table` and connection `properties`.
@@ -194,6 +194,8 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar
             should be included.
         :return: a DataFrame
         """
+        if properties is None:
+            properties = dict()
         jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
         for k in properties:
             jprop.setProperty(k, properties[k])
@@ -385,7 +387,7 @@ def parquet(self, path, mode=None, partitionBy=()):
         self._jwrite.parquet(path)
 
     @since(1.4)
-    def jdbc(self, url, table, mode=None, properties={}):
+    def jdbc(self, url, table, mode=None, properties=None):
         """Saves the content of the :class:`DataFrame` to a external database table via JDBC.
 
         .. note:: Don't create too many partitions in parallel on a large cluster;\
@@ -403,6 +405,8 @@ def jdbc(self, url, table, mode=None, properties={}):
             arbitrary string tag/value. Normally at least a
             "user" and "password" property should be included.
         """
+        if properties is None:
+            properties = dict()
         jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
         for k in properties:
             jprop.setProperty(k, properties[k])

python/pyspark/statcounter.py

Lines changed: 3 additions & 1 deletion
@@ -30,7 +30,9 @@
 
 class StatCounter(object):
 
-    def __init__(self, values=[]):
+    def __init__(self, values=None):
+        if values is None:
+            values = list()
         self.n = 0     # Running count of our values
         self.mu = 0.0  # Running mean of our values
         self.m2 = 0.0  # Running variance numerator (sum of (x - mean)^2)
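Note that only mutable defaults need this treatment; an immutable default cannot accumulate state, which is why the partitionBy=() default visible in the readwriter.py hunk above is left alone while values=[] is not. A short illustration (function names are hypothetical):

def safe(xs=()):
    # A tuple default is immutable; copy into a fresh list to modify.
    return list(xs) + [1]

def unsafe(xs=[]):
    # The shared default list is mutated in place.
    xs.append(1)
    return xs

safe(); safe()      # each call returns [1]
unsafe(); unsafe()  # the second call returns [1, 1]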

python/pyspark/streaming/kafka.py

Lines changed: 9 additions & 3 deletions
@@ -33,7 +33,7 @@ def utf8_decoder(s):
 class KafkaUtils(object):
 
     @staticmethod
-    def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
+    def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None,
                      storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2,
                      keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
         """
@@ -50,6 +50,8 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
         :param valueDecoder: A function used to decode value (default is utf8_decoder)
         :return: A DStream object
         """
+        if kafkaParams is None:
+            kafkaParams = dict()
         kafkaParams.update({
             "zookeeper.connect": zkQuorum,
             "group.id": groupId,
@@ -75,7 +77,7 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
         return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
 
     @staticmethod
-    def createDirectStream(ssc, topics, kafkaParams, fromOffsets={},
+    def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None,
                            keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
         """
         .. note:: Experimental
@@ -103,6 +105,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={},
         :param valueDecoder: A function used to decode value (default is utf8_decoder).
         :return: A DStream object
         """
+        if fromOffsets is None:
+            fromOffsets = dict()
         if not isinstance(topics, list):
             raise TypeError("topics should be list")
         if not isinstance(kafkaParams, dict):
@@ -126,7 +130,7 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={},
         return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
 
     @staticmethod
-    def createRDD(sc, kafkaParams, offsetRanges, leaders={},
+    def createRDD(sc, kafkaParams, offsetRanges, leaders=None,
                   keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
         """
         .. note:: Experimental
@@ -142,6 +146,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={},
         :param valueDecoder: A function used to decode value (default is utf8_decoder)
         :return: A RDD object
         """
+        if leaders is None:
+            leaders = dict()
         if not isinstance(kafkaParams, dict):
             raise TypeError("kafkaParams should be dict")
         if not isinstance(offsetRanges, list):
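In KafkaUtils the hazard was concrete rather than theoretical: createStream immediately calls kafkaParams.update(...), so with kafkaParams={} as the default, the connection settings from one call were written into the shared default dict and carried over to every later call that omitted the argument. A reduced, hypothetical sketch of that failure mode (not the real Spark code):

def create_stream(group_id, params={}):
    # Mutates the caller's dict -- and, when omitted, the shared default.
    params.update({"group.id": group_id})
    return params

conf_a = create_stream("grp-a")
conf_b = create_stream("grp-b")
print(conf_a is conf_b)       # True: both calls share one dict
print(conf_a["group.id"])     # 'grp-b' -- the first config was clobbered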
