42 | 42 | from pyspark.rddsampler import RDDSampler |
43 | 43 | from pyspark.storagelevel import StorageLevel |
44 | 44 | from pyspark.resultiterable import ResultIterable |
45 | | -from pyspark.shuffle import MapMerger, ExternalHashMapMerger |
| 45 | +from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ |
| 46 | + get_used_memory |
46 | 47 |
47 | 48 | from py4j.java_collections import ListConverter, MapConverter |
48 | 49 |
@@ -171,18 +172,20 @@ def _replaceRoot(self, value): |
171 | 172 |
172 | 173 | def _parse_memory(s): |
173 | 174 | """ |
174 | | - It returns a number in MB |
| 175 | + Parse a memory string in the format supported by Java (e.g. 1g, 200m) and |
| 176 | + return the value in MB |
175 | 177 |
176 | 178 | >>> _parse_memory("256m") |
177 | 179 | 256 |
178 | 180 | >>> _parse_memory("2g") |
179 | 181 | 2048 |
180 | 182 | """ |
181 | | - units = {'g': 1024, 'm': 1, 't': 1<<20, 'k':1.0/1024} |
| 183 | + units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024} |
182 | 184 | if s[-1] not in units: |
183 | 185 | raise ValueError("invalid format: " + s) |
184 | 186 | return int(float(s[:-1]) * units[s[-1].lower()]) |
185 | 187 |
|
| 188 | + |
186 | 189 | class RDD(object): |
187 | 190 |
188 | 191 | """ |
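
For reference, a few extra examples of what _parse_memory accepts, evaluated against the function as shown in this hunk (illustrations only, not doctests added by the diff); note that the membership check runs before lowercasing, so only lowercase unit suffixes pass:

    >>> _parse_memory("1t")
    1048576
    >>> _parse_memory("1024k")
    1
    >>> _parse_memory("2G")
    Traceback (most recent call last):
        ...
    ValueError: invalid format: 2G
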
@@ -1198,15 +1201,25 @@ def partitionBy(self, numPartitions, partitionFunc=None): |
1198 | 1201 | # to Java. Each object is a (splitNumber, [objects]) pair. |
1199 | 1202 | outputSerializer = self.ctx._unbatched_serializer |
1200 | 1203 |
| 1204 | + limit = _parse_memory(self.ctx._conf.get("spark.python.worker.memory") |
| 1205 | + or "512m") |
1201 | 1206 | def add_shuffle_key(split, iterator): |
1202 | 1207 |
1203 | 1208 | buckets = defaultdict(list) |
1204 | | - |
| 1209 | + c, batch = 0, 1000 |
1205 | 1210 | for (k, v) in iterator: |
1206 | 1211 | buckets[partitionFunc(k) % numPartitions].append((k, v)) |
| 1212 | + c += 1 |
| 1213 | + if c % batch == 0 and get_used_memory() > limit: |
| 1214 | + for split in buckets.keys(): |
| 1215 | + yield pack_long(split) |
| 1216 | + yield outputSerializer.dumps(buckets[split]) |
| 1217 | + del buckets[split] |
| 1218 | + |
1207 | 1219 | for (split, items) in buckets.iteritems(): |
1208 | 1220 | yield pack_long(split) |
1209 | 1221 | yield outputSerializer.dumps(items) |
| 1222 | + |
1210 | 1223 | keyed = PipelinedRDD(self, add_shuffle_key) |
1211 | 1224 | keyed._bypass_serializer = True |
1212 | 1225 | with _JavaStackTrace(self.context) as st: |
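
To make the new early-flush logic in add_shuffle_key easier to follow, here is a minimal standalone sketch of the same pattern: bucket records by partition, check memory only once per batch of records, and flush every bucket as soon as the soft limit is exceeded. The get_used_memory stand-in below (resident set size via the resource module, assumed to be reported in KB as on Linux) and the plain hash-based partitioner are simplifying assumptions, not the pyspark implementations.

    import resource
    from collections import defaultdict

    def get_used_memory():
        # Simplified stand-in for pyspark.shuffle.get_used_memory:
        # peak RSS in MB (ru_maxrss is reported in KB on Linux).
        return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss >> 10

    def bucketed(iterator, num_partitions, limit_mb, batch=1000):
        """Yield (split, [records]) pairs, flushing early under memory pressure."""
        buckets = defaultdict(list)
        for c, (k, v) in enumerate(iterator, 1):
            buckets[hash(k) % num_partitions].append((k, v))
            # Checking memory on every record would be expensive,
            # so only look once per `batch` records.
            if c % batch == 0 and get_used_memory() > limit_mb:
                for split in list(buckets):
                    yield split, buckets.pop(split)
        # Emit whatever remains at the end of the partition.
        for split, items in buckets.items():
            yield split, items

The only visible effect of an early flush is that a bucket may be emitted as several smaller batches instead of one large one, which the downstream grouping handles the same way.
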
@@ -1251,27 +1264,26 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, |
1251 | 1264 | if numPartitions is None: |
1252 | 1265 | numPartitions = self._defaultReducePartitions() |
1253 | 1266 |
| 1267 | + serializer = self.ctx.serializer |
| 1268 | + spill = (self.ctx._conf.get("spark.shuffle.spill") or 'True').lower() == 'true' |
| 1269 | + memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory") or "512m") |
| 1270 | + agg = Aggregator(createCombiner, mergeValue, mergeCombiners) |
| 1271 | + |
1254 | 1272 | def combineLocally(iterator): |
1255 | | - combiners = {} |
1256 | | - for x in iterator: |
1257 | | - (k, v) = x |
1258 | | - if k not in combiners: |
1259 | | - combiners[k] = createCombiner(v) |
1260 | | - else: |
1261 | | - combiners[k] = mergeValue(combiners[k], v) |
1262 | | - return combiners.iteritems() |
| 1273 | + merger = ExternalMerger(agg, memory, serializer) \ |
| 1274 | + if spill else InMemoryMerger(agg) |
| 1275 | + merger.combine(iterator) |
| 1276 | + return merger.iteritems() |
| 1277 | + |
1263 | 1278 | locally_combined = self.mapPartitions(combineLocally) |
1264 | 1279 | shuffled = locally_combined.partitionBy(numPartitions) |
1265 | 1280 |
1266 | | - serializer = self.ctx.serializer |
1267 | | - spill = ((self.ctx._conf.get("spark.shuffle.spill") or 'True').lower() |
1268 | | - in ('true', '1', 'yes')) |
1269 | | - memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory") or "512m") |
1270 | 1281 | def _mergeCombiners(iterator): |
1271 | | - merger = ExternalHashMapMerger(mergeCombiners, memory, serializer)\ |
1272 | | - if spill else MapMerger(mergeCombiners) |
| 1282 | + merger = ExternalMerger(agg, memory, serializer) \ |
| 1283 | + if spill else InMemoryMerger(agg) |
1273 | 1284 | merger.merge(iterator) |
1274 | 1285 | return merger.iteritems() |
| 1286 | + |
1275 | 1287 | return shuffled.mapPartitions(_mergeCombiners) |
1276 | 1288 |
1277 | 1289 | def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): |
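
The removed dict-based combineLocally survives conceptually as the non-spilling merger. Below is a rough sketch of the contract the two mergers are assumed to satisfy here: combine() folds raw values on the map side, merge() folds already-built combiners after the shuffle, and iteritems() drains the result. It mirrors the constructor and method names used in this hunk, but it is an illustration of the expected semantics, not the pyspark.shuffle implementation, and it ignores the spilling ExternalMerger entirely.

    class Aggregator(object):
        # Same constructor shape as used above:
        # Aggregator(createCombiner, mergeValue, mergeCombiners).
        def __init__(self, createCombiner, mergeValue, mergeCombiners):
            self.createCombiner = createCombiner
            self.mergeValue = mergeValue
            self.mergeCombiners = mergeCombiners

    class InMemoryMerger(object):
        def __init__(self, aggregator):
            self.agg = aggregator
            self.data = {}

        def combine(self, iterator):
            # Map side: fold raw values into combiners, as the removed
            # dict-based combineLocally did.
            d = self.data
            create, merge = self.agg.createCombiner, self.agg.mergeValue
            for k, v in iterator:
                d[k] = merge(d[k], v) if k in d else create(v)

        def merge(self, iterator):
            # Reduce side: merge combiners coming back from the shuffle.
            d, comb = self.data, self.agg.mergeCombiners
            for k, v in iterator:
                d[k] = comb(d[k], v) if k in d else v

        def iteritems(self):
            return iter(self.data.items())

    # Word-count style usage of the assumed contract:
    agg = Aggregator(lambda v: v, lambda c, v: c + v, lambda c1, c2: c1 + c2)
    m = InMemoryMerger(agg)
    m.combine([("a", 1), ("b", 1), ("a", 1)])
    assert dict(m.iteritems()) == {"a": 2, "b": 1}
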