Skip to content

Commit 40a5db5

Browse files
nongli authored and davies committed
[SPARK-11410] [PYSPARK] Add python bindings for repartition and sortWithinPartitions.
Author: Nong Li <[email protected]> Closes #9504 from nongli/spark-11410. (cherry picked from commit 1ab72b0) Signed-off-by: Davies Liu <[email protected]>
1 parent 02748c9 commit 40a5db5

File tree

1 file changed

+101
-16
lines changed

1 file changed

+101
-16
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 101 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -422,6 +422,67 @@ def repartition(self, numPartitions):
422422
"""
423423
return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
424424

425+
@since(1.3)
def repartition(self, numPartitions, *cols):
    """
    Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The
    resulting DataFrame is hash partitioned.

    ``numPartitions`` can be an int to specify the target number of partitions or a Column.
    If it is a Column, it will be used as the first partitioning column. If not specified,
    the default number of partitions is used.

    :param numPartitions: int target number of partitions, or a :class:`Column` / column
        name that is used as the first partitioning column.
    :param cols: additional :class:`Column` objects or column names to partition by.
    :raises TypeError: if ``numPartitions`` is neither an int, a string nor a
        :class:`Column`.

    .. versionchanged:: 1.6
       Added optional arguments to specify the partitioning columns. Also made numPartitions
       optional if partitioning columns are specified.

    >>> df.repartition(10).rdd.getNumPartitions()
    10
    >>> data = df.unionAll(df).repartition("age")
    >>> data.show()
    +---+-----+
    |age| name|
    +---+-----+
    |  2|Alice|
    |  2|Alice|
    |  5|  Bob|
    |  5|  Bob|
    +---+-----+
    >>> data = data.repartition(7, "age")
    >>> data.show()
    +---+-----+
    |age| name|
    +---+-----+
    |  5|  Bob|
    |  5|  Bob|
    |  2|Alice|
    |  2|Alice|
    +---+-----+
    >>> data.rdd.getNumPartitions()
    7
    >>> data = data.repartition("name", "age")
    >>> data.show()
    +---+-----+
    |age| name|
    +---+-----+
    |  5|  Bob|
    |  5|  Bob|
    |  2|Alice|
    |  2|Alice|
    +---+-----+
    """
    if isinstance(numPartitions, int):
        if not cols:
            # Only a partition count was given.
            jdf = self._jdf.repartition(numPartitions)
        else:
            # Partition count plus explicit partitioning columns.
            jdf = self._jdf.repartition(numPartitions, self._jcols(*cols))
    elif isinstance(numPartitions, (basestring, Column)):
        # No partition count was given: the first positional argument is itself
        # a partitioning column, so it is prepended to the remaining columns.
        jdf = self._jdf.repartition(self._jcols(numPartitions, *cols))
    else:
        # Strings are accepted above as column names, so the message names them too.
        raise TypeError(
            "numPartitions should be an int, string or Column, but got %s"
            % type(numPartitions))
    return DataFrame(jdf, self.sql_ctx)
485+
425486
@since(1.3)
426487
def distinct(self):
427488
"""Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
@@ -589,6 +650,26 @@ def join(self, other, on=None, how=None):
589650
jdf = self._jdf.join(other._jdf, on._jc, how)
590651
return DataFrame(jdf, self.sql_ctx)
591652

653+
@since(1.6)
def sortWithinPartitions(self, *cols, **kwargs):
    """Returns a new :class:`DataFrame` in which every partition is sorted by the given column(s).

    :param cols: list of :class:`Column` or column names to sort by.
    :param ascending: boolean or list of boolean (default True).
        Sort ascending vs. descending. Specify list for multiple sort orders.
        If a list is specified, length of the list must equal length of the `cols`.

    >>> df.sortWithinPartitions("age", ascending=False).show()
    +---+-----+
    |age| name|
    +---+-----+
    |  2|Alice|
    |  5|  Bob|
    +---+-----+
    """
    # The sort-order Seq is built by the shared helper and handed to the JVM DataFrame.
    return DataFrame(
        self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs)),
        self.sql_ctx)
672+
592673
@ignore_unicode_prefix
593674
@since(1.3)
594675
def sort(self, *cols, **kwargs):
@@ -613,22 +694,7 @@ def sort(self, *cols, **kwargs):
613694
>>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
614695
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
615696
"""
616-
if not cols:
617-
raise ValueError("should sort by at least one column")
618-
if len(cols) == 1 and isinstance(cols[0], list):
619-
cols = cols[0]
620-
jcols = [_to_java_column(c) for c in cols]
621-
ascending = kwargs.get('ascending', True)
622-
if isinstance(ascending, (bool, int)):
623-
if not ascending:
624-
jcols = [jc.desc() for jc in jcols]
625-
elif isinstance(ascending, list):
626-
jcols = [jc if asc else jc.desc()
627-
for asc, jc in zip(ascending, jcols)]
628-
else:
629-
raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending))
630-
631-
jdf = self._jdf.sort(self._jseq(jcols))
697+
jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
632698
return DataFrame(jdf, self.sql_ctx)
633699

634700
orderBy = sort
@@ -650,6 +716,25 @@ def _jcols(self, *cols):
650716
cols = cols[0]
651717
return self._jseq(cols, _to_java_column)
652718

719+
def _sort_cols(self, cols, kwargs):
    """ Return a JVM Seq of Columns that describes the sort order

    :param cols: sequence of :class:`Column` or column names to sort by. If it holds
        exactly one element and that element is a list, the list is used as the
        column sequence.
    :param kwargs: keyword arguments of the calling sort method. Only ``ascending``
        is read: a boolean (or int) applied to every column, or a list/tuple with one
        flag per column. Defaults to True.
    :raises ValueError: if no sort column is given, or if ``ascending`` is a list/tuple
        whose length differs from the number of sort columns.
    :raises TypeError: if ``ascending`` is neither a boolean/int nor a list/tuple.
    """
    if len(cols) == 1 and isinstance(cols[0], list):
        cols = cols[0]
    # Checked after unpacking so that an empty list argument is rejected as well.
    if not cols:
        raise ValueError("should sort by at least one column")

    ascending = kwargs.get('ascending', True)
    per_column = isinstance(ascending, (list, tuple))
    if per_column:
        # zip() below stops at the shorter input, which would silently drop sort
        # columns (or ignore flags); reject mismatched lengths instead.
        if len(ascending) != len(cols):
            raise ValueError(
                "ascending has %d element(s) but %d sort column(s) were given"
                % (len(ascending), len(cols)))
    elif not isinstance(ascending, (bool, int)):
        raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending))

    # All arguments are validated before any column is converted.
    jcols = [_to_java_column(c) for c in cols]
    if per_column:
        jcols = [jc if asc else jc.desc() for asc, jc in zip(ascending, jcols)]
    elif not ascending:
        jcols = [jc.desc() for jc in jcols]
    return self._jseq(jcols)
737+
653738
@since("1.3.1")
654739
def describe(self, *cols):
655740
"""Computes statistics for numeric columns.

0 commit comments

Comments
 (0)