@@ -1227,23 +1227,39 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash):

         # Transferring O(n) objects to Java is too expensive.  Instead, we'll
         # form the hash buckets in Python, transferring O(numPartitions) objects
-        # to Java.  Each object is a (splitNumber, [objects]) pair.
+        # to Java.  Each object is a (splitNumber, [objects]) pair.
+        # In order to avoid too huge objects, the objects are grouped into chunks.
         outputSerializer = self.ctx._unbatched_serializer

-        limit = _parse_memory(self.ctx._conf.get("spark.python.worker.memory")
-                              or "512m")
+        limit = (_parse_memory(self.ctx._conf.get("spark.python.worker.memory")
+                               or "512m") / 2)

         def add_shuffle_key(split, iterator):

             buckets = defaultdict(list)
-            c, batch = 0, 1000
+            c, batch = 0, min(10 * numPartitions, 1000)
+
             for (k, v) in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
                 c += 1
-                if c % batch == 0 and get_used_memory() > limit:
+
+                # check used memory and avg size of chunk of objects
+                if (c % 1000 == 0 and get_used_memory() > limit
+                        or c > batch):
+                    n, size = len(buckets), 0
                     for split in buckets.keys():
                         yield pack_long(split)
-                        yield outputSerializer.dumps(buckets[split])
+                        d = outputSerializer.dumps(buckets[split])
                         del buckets[split]
+                        yield d
+                        size += len(d)
+
+                    avg = (size / n) >> 20
+                    # let 1M < avg < 10M
+                    if avg < 1:
+                        batch *= 1.5
+                    elif avg > 10:
+                        batch = max(batch / 1.5, 1)
+                    c = 0

             for (split, items) in buckets.iteritems():
                 yield pack_long(split)
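
To make the pattern easier to follow outside the diff, here is a minimal, self-contained sketch of the same idea: bucket key-value pairs by partition, spill serialized chunks when the memory probe or the per-spill record budget trips, and steer the budget so the average serialized chunk stays between 1 MB and 10 MB. The names `chunked_shuffle`, `serialize`, `used_memory_mb` and `limit_mb` are hypothetical stand-ins for illustration, not PySpark's actual serializer, `get_used_memory()` helper, or configuration; the sketch yields plain `(split, chunk)` tuples instead of the `pack_long()`/bytes stream PySpark writes.

```python
import pickle
from collections import defaultdict


def chunked_shuffle(iterator, num_partitions, partition_func=hash,
                    serialize=pickle.dumps,
                    used_memory_mb=lambda: 0, limit_mb=256):
    """Illustrative sketch of the adaptive chunking in add_shuffle_key."""
    buckets = defaultdict(list)
    count, batch = 0, min(10 * num_partitions, 1000)

    for k, v in iterator:
        buckets[partition_func(k) % num_partitions].append((k, v))
        count += 1

        # Spill whenever memory is over the limit or the record budget is full.
        if (count % 1000 == 0 and used_memory_mb() > limit_mb) or count > batch:
            n, size = len(buckets), 0
            for split in list(buckets):
                chunk = serialize(buckets.pop(split))
                size += len(chunk)
                yield split, chunk

            # Keep the average serialized chunk between 1 MB and 10 MB by
            # growing or shrinking the per-spill record budget.
            avg_mb = (size // n) >> 20
            if avg_mb < 1:
                batch = int(batch * 1.5)
            elif avg_mb > 10:
                batch = max(int(batch / 1.5), 1)
            count = 0

    # Flush whatever is still buffered at the end of the iterator.
    for split, items in buckets.items():
        yield split, serialize(items)
```

The structure mirrors the diff's two cost controls: the memory probe runs only every 1000 records so the check stays cheap (the diff additionally halves the configured `spark.python.worker.memory` before comparing), while the record budget bounds chunk size between spills so each transfer to Java is large enough to amortize the round trip but not so large that a single chunk dominates worker memory.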