@@ -422,6 +422,67 @@ def repartition(self, numPartitions):
422422 """
423423 return DataFrame (self ._jdf .repartition (numPartitions ), self .sql_ctx )
424424
@since(1.3)
def repartition(self, numPartitions, *cols):
    """
    Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The
    resulting DataFrame is hash partitioned.

    The first argument, ``numPartitions``, is either an int giving the target number of
    partitions, or a Column (or column name). In the latter case it is treated as the first
    partitioning column and the default number of partitions is used.

    .. versionchanged:: 1.6
       Added optional arguments to specify the partitioning columns. Also made numPartitions
       optional if partitioning columns are specified.

    >>> df.repartition(10).rdd.getNumPartitions()
    10
    >>> data = df.unionAll(df).repartition("age")
    >>> data.show()
    +---+-----+
    |age| name|
    +---+-----+
    |  2|Alice|
    |  2|Alice|
    |  5|  Bob|
    |  5|  Bob|
    +---+-----+
    >>> data = data.repartition(7, "age")
    >>> data.show()
    +---+-----+
    |age| name|
    +---+-----+
    |  5|  Bob|
    |  5|  Bob|
    |  2|Alice|
    |  2|Alice|
    +---+-----+
    >>> data.rdd.getNumPartitions()
    7
    >>> data = data.repartition("name", "age")
    >>> data.show()
    +---+-----+
    |age| name|
    +---+-----+
    |  5|  Bob|
    |  5|  Bob|
    |  2|Alice|
    |  2|Alice|
    +---+-----+
    """
    if isinstance(numPartitions, int):
        # Explicit partition count, optionally followed by partitioning columns.
        if cols:
            jdf = self._jdf.repartition(numPartitions, self._jcols(*cols))
        else:
            jdf = self._jdf.repartition(numPartitions)
        return DataFrame(jdf, self.sql_ctx)

    if isinstance(numPartitions, (basestring, Column)):
        # No count was given: the first positional argument is itself a
        # partitioning column, so put it in front of the remaining ones.
        partition_exprs = (numPartitions, ) + cols
        jdf = self._jdf.repartition(self._jcols(*partition_exprs))
        return DataFrame(jdf, self.sql_ctx)

    raise TypeError("numPartitions should be an int or Column")
485+
425486 @since (1.3 )
426487 def distinct (self ):
427488 """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
@@ -589,6 +650,26 @@ def join(self, other, on=None, how=None):
589650 jdf = self ._jdf .join (other ._jdf , on ._jc , how )
590651 return DataFrame (jdf , self .sql_ctx )
591652
@since(1.6)
def sortWithinPartitions(self, *cols, **kwargs):
    """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s).

    :param cols: list of :class:`Column` or column names to sort by.
    :param ascending: boolean or list of boolean (default True).
        Sort ascending vs. descending. Specify list for multiple sort orders.
        If a list is specified, length of the list must equal length of the `cols`.

    >>> df.sortWithinPartitions("age", ascending=False).show()
    +---+-----+
    |age| name|
    +---+-----+
    |  2|Alice|
    |  5|  Bob|
    +---+-----+
    """
    # Column/ordering handling is delegated to the _sort_cols helper.
    sort_order = self._sort_cols(cols, kwargs)
    sorted_jdf = self._jdf.sortWithinPartitions(sort_order)
    return DataFrame(sorted_jdf, self.sql_ctx)
672+
592673 @ignore_unicode_prefix
593674 @since (1.3 )
594675 def sort (self , * cols , ** kwargs ):
@@ -613,22 +694,7 @@ def sort(self, *cols, **kwargs):
613694 >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
614695 [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
615696 """
616- if not cols :
617- raise ValueError ("should sort by at least one column" )
618- if len (cols ) == 1 and isinstance (cols [0 ], list ):
619- cols = cols [0 ]
620- jcols = [_to_java_column (c ) for c in cols ]
621- ascending = kwargs .get ('ascending' , True )
622- if isinstance (ascending , (bool , int )):
623- if not ascending :
624- jcols = [jc .desc () for jc in jcols ]
625- elif isinstance (ascending , list ):
626- jcols = [jc if asc else jc .desc ()
627- for asc , jc in zip (ascending , jcols )]
628- else :
629- raise TypeError ("ascending can only be boolean or list, but got %s" % type (ascending ))
630-
631- jdf = self ._jdf .sort (self ._jseq (jcols ))
697+ jdf = self ._jdf .sort (self ._sort_cols (cols , kwargs ))
632698 return DataFrame (jdf , self .sql_ctx )
633699
634700 orderBy = sort
@@ -650,6 +716,25 @@ def _jcols(self, *cols):
650716 cols = cols [0 ]
651717 return self ._jseq (cols , _to_java_column )
652718
719+ def _sort_cols (self , cols , kwargs ):
720+ """ Return a JVM Seq of Columns that describes the sort order
721+ """
722+ if not cols :
723+ raise ValueError ("should sort by at least one column" )
724+ if len (cols ) == 1 and isinstance (cols [0 ], list ):
725+ cols = cols [0 ]
726+ jcols = [_to_java_column (c ) for c in cols ]
727+ ascending = kwargs .get ('ascending' , True )
728+ if isinstance (ascending , (bool , int )):
729+ if not ascending :
730+ jcols = [jc .desc () for jc in jcols ]
731+ elif isinstance (ascending , list ):
732+ jcols = [jc if asc else jc .desc ()
733+ for asc , jc in zip (ascending , jcols )]
734+ else :
735+ raise TypeError ("ascending can only be boolean or list, but got %s" % type (ascending ))
736+ return self ._jseq (jcols )
737+
653738 @since ("1.3.1" )
654739 def describe (self , * cols ):
655740 """Computes statistics for numeric columns.
0 commit comments