@@ -485,30 +485,60 @@ def join(self, other, joinExprs=None, joinType=None):
485485 return DataFrame (jdf , self .sql_ctx )
486486
@ignore_unicode_prefix
def sort(self, *cols, **kwargs):
    """Returns a new :class:`DataFrame` sorted by the specified column(s).

    :param cols: list of :class:`Column` or column names to sort by.
        A single list may be passed instead of separate arguments.
    :param ascending: sort by ascending order or not, could be bool, int
        or list of bool, int (default: True). When a list is given it
        must contain exactly one flag per sort column.
    :raises ValueError: if no sort column is given, or if `ascending` is a
        list whose length differs from the number of sort columns.
    :raises TypeError: if `ascending` is neither a bool, an int nor a list.

    >>> df.sort(df.age.desc()).collect()
    [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
    >>> df.sort("age", ascending=False).collect()
    [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
    >>> df.orderBy(df.age.desc()).collect()
    [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
    >>> from pyspark.sql.functions import *
    >>> df.sort(asc("age")).collect()
    [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
    >>> df.orderBy(desc("age"), "name").collect()
    [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
    >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
    [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
    """
    # Accept both sort("a", "b") and sort(["a", "b"]).
    if len(cols) == 1 and isinstance(cols[0], list):
        cols = cols[0]
    # Validate after unwrapping so that an empty list is rejected as well.
    if not cols:
        raise ValueError("should sort by at least one column")

    jcols = [_to_java_column(c) for c in cols]
    ascending = kwargs.get('ascending', True)
    if isinstance(ascending, (bool, int)):
        # One flag applies to every column; ascending needs no change.
        if not ascending:
            jcols = [jc.desc() for jc in jcols]
    elif isinstance(ascending, list):
        # zip() stops at the shorter input, which would silently drop
        # sort columns (or flags) when the lengths differ.
        if len(ascending) != len(jcols):
            raise ValueError(
                "ascending should have one entry per sort column, "
                "but got %d entries for %d columns" % (len(ascending), len(jcols)))
        jcols = [jc if is_asc else jc.desc()
                 for is_asc, jc in zip(ascending, jcols)]
    else:
        raise TypeError("ascending can only be bool or list, but got %s" % type(ascending))

    jdf = self._jdf.sort(self._jseq(jcols))
    return DataFrame(jdf, self.sql_ctx)
509526
510527 orderBy = sort
511528
def _jseq(self, cols, converter=None):
    """Build a JVM Seq from `cols` by delegating to :func:`_to_seq`.

    The SparkContext is taken from ``self.sql_ctx._sc``; `converter`, when
    given, is forwarded unchanged.
    """
    spark_context = self.sql_ctx._sc
    return _to_seq(spark_context, cols, converter)
532+
def _jcols(self, *cols):
    """Turn Columns or column names into a JVM Seq of Columns.

    A call with exactly one positional argument that is a list uses that
    list as the columns; otherwise the positional arguments themselves are
    used. Each item is passed through `_to_java_column`.
    """
    single_list_given = len(cols) == 1 and isinstance(cols[0], list)
    items = cols[0] if single_list_given else cols
    return self._jseq(items, _to_java_column)
541+
512542 def describe (self , * cols ):
513543 """Computes statistics for numeric columns.
514544
@@ -523,9 +553,7 @@ def describe(self, *cols):
523553 min 2
524554 max 5
525555 """
526- cols = ListConverter ().convert (cols ,
527- self .sql_ctx ._sc ._gateway ._gateway_client )
528- jdf = self ._jdf .describe (self .sql_ctx ._sc ._jvm .PythonUtils .toSeq (cols ))
556+ jdf = self ._jdf .describe (self ._jseq (cols ))
529557 return DataFrame (jdf , self .sql_ctx )
530558
531559 @ignore_unicode_prefix
@@ -607,9 +635,7 @@ def select(self, *cols):
607635 >>> df.select(df.name, (df.age + 10).alias('age')).collect()
608636 [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
609637 """
610- jcols = ListConverter ().convert ([_to_java_column (c ) for c in cols ],
611- self ._sc ._gateway ._gateway_client )
612- jdf = self ._jdf .select (self .sql_ctx ._sc ._jvm .PythonUtils .toSeq (jcols ))
638+ jdf = self ._jdf .select (self ._jcols (* cols ))
613639 return DataFrame (jdf , self .sql_ctx )
614640
615641 def selectExpr (self , * expr ):
@@ -620,8 +646,9 @@ def selectExpr(self, *expr):
620646 >>> df.selectExpr("age * 2", "abs(age)").collect()
621647 [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
622648 """
623- jexpr = ListConverter ().convert (expr , self ._sc ._gateway ._gateway_client )
624- jdf = self ._jdf .selectExpr (self ._sc ._jvm .PythonUtils .toSeq (jexpr ))
649+ if len (expr ) == 1 and isinstance (expr [0 ], list ):
650+ expr = expr [0 ]
651+ jdf = self ._jdf .selectExpr (self ._jseq (expr ))
625652 return DataFrame (jdf , self .sql_ctx )
626653
627654 @ignore_unicode_prefix
@@ -659,6 +686,8 @@ def groupBy(self, *cols):
659686 so we can run aggregation on them. See :class:`GroupedData`
660687 for all the available aggregate functions.
661688
689+ :func:`groupby` is an alias for :func:`groupBy`.
690+
662691 :param cols: list of columns to group by.
663692 Each element should be a column name (string) or an expression (:class:`Column`).
664693
@@ -668,12 +697,14 @@ def groupBy(self, *cols):
668697 [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
669698 >>> df.groupBy(df.name).avg().collect()
670699 [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
700+ >>> df.groupBy(['name', df.age]).count().collect()
701+ [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
671702 """
672- jcols = ListConverter ().convert ([_to_java_column (c ) for c in cols ],
673- self ._sc ._gateway ._gateway_client )
674- jdf = self ._jdf .groupBy (self .sql_ctx ._sc ._jvm .PythonUtils .toSeq (jcols ))
703+ jdf = self ._jdf .groupBy (self ._jcols (* cols ))
675704 return GroupedData (jdf , self .sql_ctx )
676705
706+ groupby = groupBy
707+
677708 def agg (self , * exprs ):
678709 """ Aggregate on the entire :class:`DataFrame` without groups
679710 (shorthand for ``df.groupBy.agg()``).
@@ -744,9 +775,7 @@ def dropna(self, how='any', thresh=None, subset=None):
744775 if thresh is None :
745776 thresh = len (subset ) if how == 'any' else 1
746777
747- cols = ListConverter ().convert (subset , self .sql_ctx ._sc ._gateway ._gateway_client )
748- cols = self .sql_ctx ._sc ._jvm .PythonUtils .toSeq (cols )
749- return DataFrame (self ._jdf .na ().drop (thresh , cols ), self .sql_ctx )
778+ return DataFrame (self ._jdf .na ().drop (thresh , self ._jseq (subset )), self .sql_ctx )
750779
751780 def fillna (self , value , subset = None ):
752781 """Replace null values, alias for ``na.fill()``.
@@ -799,9 +828,7 @@ def fillna(self, value, subset=None):
799828 elif not isinstance (subset , (list , tuple )):
800829 raise ValueError ("subset should be a list or tuple of column names" )
801830
802- cols = ListConverter ().convert (subset , self .sql_ctx ._sc ._gateway ._gateway_client )
803- cols = self .sql_ctx ._sc ._jvm .PythonUtils .toSeq (cols )
804- return DataFrame (self ._jdf .na ().fill (value , cols ), self .sql_ctx )
831+ return DataFrame (self ._jdf .na ().fill (value , self ._jseq (subset )), self .sql_ctx )
805832
806833 @ignore_unicode_prefix
807834 def withColumn (self , colName , col ):
@@ -862,10 +889,8 @@ def _api(self):
862889
863890def df_varargs_api (f ):
864891 def _api (self , * args ):
865- jargs = ListConverter ().convert (args ,
866- self .sql_ctx ._sc ._gateway ._gateway_client )
867892 name = f .__name__
868- jdf = getattr (self ._jdf , name )(self .sql_ctx ._sc . _jvm . PythonUtils . toSeq ( jargs ))
893+ jdf = getattr (self ._jdf , name )(_to_seq ( self .sql_ctx ._sc , args ))
869894 return DataFrame (jdf , self .sql_ctx )
870895 _api .__name__ = f .__name__
871896 _api .__doc__ = f .__doc__
@@ -912,9 +937,8 @@ def agg(self, *exprs):
912937 else :
913938 # Columns
914939 assert all (isinstance (c , Column ) for c in exprs ), "all exprs should be Column"
915- jcols = ListConverter ().convert ([c ._jc for c in exprs [1 :]],
916- self .sql_ctx ._sc ._gateway ._gateway_client )
917- jdf = self ._jdf .agg (exprs [0 ]._jc , self .sql_ctx ._sc ._jvm .PythonUtils .toSeq (jcols ))
940+ jdf = self ._jdf .agg (exprs [0 ]._jc ,
941+ _to_seq (self .sql_ctx ._sc , [c ._jc for c in exprs [1 :]]))
918942 return DataFrame (jdf , self .sql_ctx )
919943
920944 @dfapi
@@ -1006,6 +1030,19 @@ def _to_java_column(col):
10061030 return jcol
10071031
10081032
def _to_seq(sc, cols, converter=None):
    """
    Turn a Python list of Column (or names) into a JVM Seq.

    When a truthy `converter` is supplied it is applied to every item of
    `cols` first. The resulting list is converted through the gateway
    client of `sc` and then handed to the JVM-side ``PythonUtils.toSeq``.
    """
    items = list(map(converter, cols)) if converter else cols
    gateway_client = sc._gateway._gateway_client
    java_list = ListConverter().convert(items, gateway_client)
    return sc._jvm.PythonUtils.toSeq(java_list)
1044+
1045+
10091046def _unary_op (name , doc = "unary operator" ):
10101047 """ Create a method for given unary operator """
10111048 def _ (self ):
@@ -1177,8 +1214,7 @@ def inSet(self, *cols):
11771214 cols = cols [0 ]
11781215 cols = [c ._jc if isinstance (c , Column ) else _create_column_from_literal (c ) for c in cols ]
11791216 sc = SparkContext ._active_spark_context
1180- jcols = ListConverter ().convert (cols , sc ._gateway ._gateway_client )
1181- jc = getattr (self ._jc , "in" )(sc ._jvm .PythonUtils .toSeq (jcols ))
1217+ jc = getattr (self ._jc , "in" )(_to_seq (sc , cols ))
11821218 return Column (jc )
11831219
11841220 # order
0 commit comments