 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.streaming.util import rddToFileName, RDDFunction
+from pyspark.rdd import portable_hash, _parse_memory
+from pyspark.shuffle import get_used_memory
 from pyspark.traceback_utils import SCCallSiteSync

 from py4j.java_collections import ListConverter, MapConverter
@@ -40,6 +41,7 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
         self._jrdd_deserializer = jrdd_deserializer
         self.is_cached = False
         self.is_checkpointed = False
+        self._partitionFunc = None

     def context(self):
         """
@@ -161,32 +163,71 @@ def _mergeCombiners(iterator):

         return shuffled.mapPartitions(_mergeCombiners)

-    def partitionBy(self, numPartitions, partitionFunc=None):
+    def partitionBy(self, numPartitions, partitionFunc=portable_hash):
         """
        Return a copy of the DStream partitioned using the specified partitioner.
         """
         if numPartitions is None:
             numPartitions = self.ctx._defaultReducePartitions()

-        if partitionFunc is None:
-            partitionFunc = lambda x: 0 if x is None else hash(x)
-
         # Transferring O(n) objects to Java is too expensive. Instead, we'll
         # form the hash buckets in Python, transferring O(numPartitions) objects
         # to Java. Each object is a (splitNumber, [objects]) pair.
+
         outputSerializer = self.ctx._unbatched_serializer
+        # Previous, non-batched implementation, kept here for reference:
+        #
+        # def add_shuffle_key(split, iterator):
+        #     buckets = defaultdict(list)
+        #
+        #     for (k, v) in iterator:
+        #         buckets[partitionFunc(k) % numPartitions].append((k, v))
+        #     for (split, items) in buckets.iteritems():
+        #         yield pack_long(split)
+        #         yield outputSerializer.dumps(items)
+        # keyed = PipelinedDStream(self, add_shuffle_key)
+
+        limit = (_parse_memory(self.ctx._conf.get(
+            "spark.python.worker.memory", "512m")) / 2)

         def add_shuffle_key(split, iterator):
+
             buckets = defaultdict(list)
+            c, batch = 0, min(10 * numPartitions, 1000)

-            for (k, v) in iterator:
+            for k, v in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
-            for (split, items) in buckets.iteritems():
+                c += 1
+
+                # check used memory and avg size of chunk of objects
+                if (c % 1000 == 0 and get_used_memory() > limit
+                        or c > batch):
+                    n, size = len(buckets), 0
+                    for split in buckets.keys():
+                        yield pack_long(split)
+                        d = outputSerializer.dumps(buckets[split])
+                        del buckets[split]
+                        yield d
+                        size += len(d)
+
+                    avg = (size / n) >> 20
+                    # let 1M < avg < 10M
+                    if avg < 1:
+                        batch *= 1.5
+                    elif avg > 10:
+                        batch = max(batch / 1.5, 1)
+                    c = 0
+
+            for split, items in buckets.iteritems():
                 yield pack_long(split)
                 yield outputSerializer.dumps(items)
-        keyed = PipelinedDStream(self, add_shuffle_key)
+
+        keyed = self._mapPartitionsWithIndex(add_shuffle_key)
         keyed._bypass_serializer = True
-        with SCCallSiteSync(self.context) as css:
+        with SCCallSiteSync(self.ctx) as css:
             partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                           id(partitionFunc))
             jdstream = self.ctx._jvm.PythonPairwiseDStream(keyed._jdstream.dstream(),
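
The new add_shuffle_key buffers records per partition and spills serialized chunks early: it flushes once it has buffered batch records, or every 1000 records when used worker memory exceeds half of spark.python.worker.memory, and then retunes batch so the average serialized chunk stays between 1 MB and 10 MB. The following standalone sketch illustrates that spill-and-retune loop outside Spark; the names bucket_and_spill, partition_func, and memory_ok are illustrative stand-ins, not part of this patch.

import pickle
from collections import defaultdict


def bucket_and_spill(records, num_partitions, partition_func=hash,
                     memory_ok=lambda: True):
    """Group (key, value) records into per-partition buckets and yield
    (partition_id, serialized_chunk) pairs, spilling early so the buckets
    never grow without bound."""
    buckets = defaultdict(list)
    count, batch = 0, min(10 * num_partitions, 1000)

    for key, value in records:
        buckets[partition_func(key) % num_partitions].append((key, value))
        count += 1

        # Spill when `batch` records are buffered, or every 1000 records
        # when the memory check reports pressure.
        if (count % 1000 == 0 and not memory_ok()) or count > batch:
            n, size = len(buckets), 0
            for split in list(buckets):
                chunk = pickle.dumps(buckets.pop(split))
                size += len(chunk)
                yield split, chunk

            # Retune the batch so the average chunk lands between 1 and 10 MB.
            avg_mb = (size // n) >> 20
            if avg_mb < 1:
                batch = int(batch * 1.5)
            elif avg_mb > 10:
                batch = max(int(batch / 1.5), 1)
            count = 0

    # Flush whatever is left at the end of the iterator.
    for split, items in buckets.items():
        yield split, pickle.dumps(items)


if __name__ == "__main__":
    data = ((i % 7, "x" * 100) for i in range(5000))
    for split, chunk in bucket_and_spill(data, num_partitions=4):
        print(split, len(chunk))

In the patch itself the same loop runs inside a _mapPartitionsWithIndex pass, with get_used_memory() providing the memory check.
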
@@ -428,6 +469,10 @@ def get_output(rdd, time):


 class PipelinedDStream(DStream):
+    """
+    Since PipelinedDStream mirrors PipelinedRDD, any change to
+    PipelinedRDD should be applied here in the same way.
+    """
     def __init__(self, prev, func, preservesPartitioning=False):
         if not isinstance(prev, PipelinedDStream) or not prev._is_pipelinable():
             # This transformation is the first in its stage:
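
As the docstring notes, PipelinedDStream mirrors PipelinedRDD: both exist so that a chain of narrow transformations is composed into a single (split, iterator) function and executed as one Python task per partition. Below is a minimal sketch of that composition under the same calling convention; pipeline, double_values, and keep_even are illustrative names, not the actual implementation.

def pipeline(prev_func, new_func):
    """Compose two partition-level functions into one, so a chain like
    map(...).filter(...) runs as a single Python task per partition."""
    def pipelined(split, iterator):
        return new_func(split, prev_func(split, iterator))
    return pipelined


# Illustrative partition-level functions in the (split, iterator) shape
# that PipelinedRDD / PipelinedDStream use internally.
def double_values(split, iterator):
    return (x * 2 for x in iterator)


def keep_even(split, iterator):
    return (x for x in iterator if x % 2 == 0)


if __name__ == "__main__":
    func = pipeline(double_values, keep_even)
    print(list(func(0, iter(range(5)))))  # [0, 2, 4, 6, 8]

This is also why the patch propagates _partitionFunc only when preservesPartitioning is true: in that case the composed function still runs on the parent's partitioning.
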
@@ -453,19 +498,22 @@ def pipeline_func(split, iterator):
         self._jdstream_val = None
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
+        self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None

     @property
     def _jdstream(self):
         if self._jdstream_val:
             return self._jdstream_val
         if self._bypass_serializer:
-            serializer = NoOpSerializer()
-        else:
-            serializer = self.ctx.serializer
-
-        command = (self.func, self._prev_jrdd_deserializer, serializer)
-        ser = CompressedSerializer(CloudPickleSerializer())
+            self._jrdd_deserializer = NoOpSerializer()
+        command = (self.func, self._prev_jrdd_deserializer,
+                   self._jrdd_deserializer)
+        # the serialized command will be compressed by broadcast
+        ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
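
The reworked _jdstream pickles the task command with CloudPickleSerializer and, when the pickle exceeds 1 MB, registers it as a broadcast variable so each task ships only a small handle instead of the full closure. A rough sketch of that pattern with a registry-backed stand-in for Spark's Broadcast follows; BroadcastHandle, load_handle, and prepare_command are illustrative names, not PySpark APIs.

import pickle

ONE_MB = 1 << 20
_REGISTRY = {}          # stands in for the cluster-side broadcast store


def load_handle(bid):
    """Rebuild a handle from its id on the 'worker' side."""
    handle = object.__new__(BroadcastHandle)
    handle.bid = bid
    return handle


class BroadcastHandle(object):
    """Stand-in for a Spark Broadcast: the payload stays in the registry,
    and pickling the handle serializes only the small id."""

    def __init__(self, value):
        self.bid = len(_REGISTRY)
        _REGISTRY[self.bid] = value

    def __reduce__(self):
        return (load_handle, (self.bid,))


def prepare_command(command):
    """Serialize a task command; if the pickle exceeds 1 MB, register it as
    a broadcast and ship only the (tiny) pickled handle with each task."""
    pickled = pickle.dumps(command)
    if len(pickled) > ONE_MB:
        pickled = pickle.dumps(BroadcastHandle(pickled))
    return pickled


if __name__ == "__main__":
    print(len(prepare_command(("map", [1, 2, 3]))))            # small, sent as-is
    print(len(prepare_command(("map", list(range(500000))))))  # large, handle only

Spark's real Broadcast behaves similarly when pickled: only its id travels with the task, and workers fetch the payload separately.
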