@@ -316,6 +316,34 @@ def sample(self, withReplacement, fraction, seed=None):
316316 assert fraction >= 0.0 , "Negative fraction value: %s" % fraction
317317 return self .mapPartitionsWithIndex (RDDSampler (withReplacement , fraction , seed ).func , True )
318318
def randomSplit(self, weights, seed=None):
    """
    Split this RDD at random into several RDDs according to the given weights.

    The split itself is delegated to the JVM-side ``randomSplit`` of the
    underlying Java RDD; this method only prepares the arguments and wraps
    the results.

    :param weights: weights for splits, will be normalized if they don't sum to 1
    :param seed: random seed; when None, the JVM overload that takes no
                 seed is called instead
    :return: split RDDs in a list

    >>> rdd = sc.parallelize(range(10), 1)
    >>> rdd1, rdd2, rdd3 = rdd.randomSplit([0.4, 0.6, 1.0], 11)
    >>> rdd1.collect()
    [3, 6]
    >>> rdd2.collect()
    [0, 5, 7]
    >>> rdd3.collect()
    [1, 2, 4, 8, 9]
    """
    # Re-serialize with a batch size of 1 before handing the data to the JVM
    # (presumably so elements are split individually -- confirm).
    serializer = BatchedSerializer(PickleSerializer(), 1)
    reserialized = self._reserialize(serializer)

    # Coerce every weight to float, then turn the Python list into a Java
    # double[] in two steps: Python list -> Java List -> double[].
    doubles = list(map(float, weights))
    java_list = ListConverter().convert(doubles, self.ctx._gateway._gateway_client)
    java_array = self.ctx._jvm.PythonRDD.listToArrayDouble(java_list)

    # Pick the JVM overload depending on whether a seed was supplied.
    java_rdd = reserialized._jrdd
    if seed is not None:
        java_splits = java_rdd.randomSplit(java_array, seed)
    else:
        java_splits = java_rdd.randomSplit(java_array)

    # Wrap each resulting Java RDD in a Python RDD that uses the same serializer.
    splits = []
    for java_split in java_splits:
        splits.append(RDD(java_split, self.ctx, serializer))
    return splits
346+
319347 # this is ported from scala/spark/RDD.scala
320348 def takeSample (self , withReplacement , num , seed = None ):
321349 """
0 commit comments