python/pyspark/sql/dataframe.py (66 additions, 30 deletions)
@@ -485,30 +485,60 @@ def join(self, other, joinExprs=None, joinType=None):
return DataFrame(jdf, self.sql_ctx)

@ignore_unicode_prefix
def sort(self, *cols):
def sort(self, *cols, **kwargs):
"""Returns a new :class:`DataFrame` sorted by the specified column(s).

:param cols: list of :class:`Column` to sort by.
:param cols: list of :class:`Column` or column names to sort by.
:param ascending: whether to sort in ascending order; a bool or int,
or a list of bool/int with one entry per column (default: True).

>>> df.sort(df.age.desc()).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> df.sort("age", ascending=False).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> df.orderBy(df.age.desc()).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> from pyspark.sql.functions import *
>>> df.sort(asc("age")).collect()
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.orderBy(desc("age"), "name").collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
"""
if not cols:
raise ValueError("should sort by at least one column")
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0]
jcols = [_to_java_column(c) for c in cols]
ascending = kwargs.get('ascending', True)
if isinstance(ascending, (bool, int)):
if not ascending:
jcols = [jc.desc() for jc in jcols]
elif isinstance(ascending, list):
jcols = [jc if asc else jc.desc()
for asc, jc in zip(ascending, jcols)]
else:
raise TypeError("ascending can only be bool or list, but got %s" % type(ascending))

jdf = self._jdf.sort(self._jseq(jcols))
return DataFrame(jdf, self.sql_ctx)

orderBy = sort
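
For reference, the new keyword handling end to end (a doctest-style sketch, not part of the diff; assumes this module's example df):

>>> df.sort("age", "name", ascending=[0, 1]).collect()  # age desc, then name asc
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> df.sort("age", ascending="yes")
Traceback (most recent call last):
    ...
TypeError: ascending can only be bool or list, but got <type 'str'>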

def _jseq(self, cols, converter=None):
"""Return a JVM Seq of Columns from a list of Column or names"""
return _to_seq(self.sql_ctx._sc, cols, converter)
Contributor: can we add some docstring here to explain what this does?


def _jcols(self, *cols):
Contributor: this function needs a docstring too

"""Return a JVM Seq of Columns from a list of Column or column names

If `cols` has only one list in it, cols[0] will be used as the list.
"""
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0]
return self._jseq(cols, _to_java_column)
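
As context for the two helpers above, a pure-Python sketch of the *cols unwrapping that `_jcols` adds on top of `_jseq` (the JVM conversion needs a live gateway, so it is stubbed out here):

def _flatten(*cols):
    # df.select("a", "b") and df.select(["a", "b"]) should behave alike,
    # so a lone list argument is unwrapped before conversion
    if len(cols) == 1 and isinstance(cols[0], list):
        cols = cols[0]
    return list(cols)

assert _flatten("a", "b") == _flatten(["a", "b"]) == ["a", "b"]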

def describe(self, *cols):
"""Computes statistics for numeric columns.

@@ -523,9 +553,7 @@ def describe(self, *cols):
min 2
max 5
"""
cols = ListConverter().convert(cols,
self.sql_ctx._sc._gateway._gateway_client)
jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols))
jdf = self._jdf.describe(self._jseq(cols))
return DataFrame(jdf, self.sql_ctx)

@ignore_unicode_prefix
@@ -600,9 +628,7 @@ def select(self, *cols):
>>> df.select(df.name, (df.age + 10).alias('age')).collect()
[Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
"""
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
jdf = self._jdf.select(self._jcols(*cols))
return DataFrame(jdf, self.sql_ctx)

def selectExpr(self, *expr):
@@ -613,8 +639,9 @@ def selectExpr(self, *expr):
>>> df.selectExpr("age * 2", "abs(age)").collect()
[Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
"""
jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
if len(expr) == 1 and isinstance(expr[0], list):
expr = expr[0]
jdf = self._jdf.selectExpr(self._jseq(expr))
return DataFrame(jdf, self.sql_ctx)

@ignore_unicode_prefix
@@ -652,6 +679,8 @@ def groupBy(self, *cols):
so we can run aggregation on them. See :class:`GroupedData`
for all the available aggregate functions.

:func:`groupby` is an alias for :func:`groupBy`.

:param cols: list of columns to group by.
Each element should be a column name (string) or an expression (:class:`Column`).

@@ -661,12 +690,14 @@
[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
>>> df.groupBy(df.name).avg().collect()
[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
>>> df.groupBy(['name', df.age]).count().collect()
[Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
"""
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
jdf = self._jdf.groupBy(self._jcols(*cols))
return GroupedData(jdf, self.sql_ctx)

groupby = groupBy

def agg(self, *exprs):
""" Aggregate on the entire :class:`DataFrame` without groups
(shorthand for ``df.groupBy.agg()``).
@@ -737,9 +768,7 @@ def dropna(self, how='any', thresh=None, subset=None):
if thresh is None:
thresh = len(subset) if how == 'any' else 1

cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx)
return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx)

def fillna(self, value, subset=None):
"""Replace null values, alias for ``na.fill()``.
@@ -792,9 +821,7 @@ def fillna(self, value, subset=None):
elif not isinstance(subset, (list, tuple)):
raise ValueError("subset should be a list or tuple of column names")

cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx)
return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)

@ignore_unicode_prefix
def withColumn(self, colName, col):
@@ -855,10 +882,8 @@ def _api(self):

def df_varargs_api(f):
def _api(self, *args):
jargs = ListConverter().convert(args,
self.sql_ctx._sc._gateway._gateway_client)
name = f.__name__
jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs))
jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
return DataFrame(jdf, self.sql_ctx)
_api.__name__ = f.__name__
_api.__doc__ = f.__doc__
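
The decorator above turns a documented stub into a method that forwards its varargs to the like-named JVM method as a single Seq; a self-contained sketch of the same pattern (all names here are hypothetical, with the JVM call replaced by a plain attribute lookup):

def varargs_api(f):
    def _api(self, *args):
        # the stub's own name picks the backend method to delegate to
        return getattr(self._backend, f.__name__)(list(args))
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api

class FakeBackend(object):
    def mean(self, cols):
        return "mean(%s)" % ", ".join(cols)

class FakeFrame(object):
    def __init__(self):
        self._backend = FakeBackend()

    @varargs_api
    def mean(self):
        """Compute the mean of the given columns (stub body never runs)."""

print(FakeFrame().mean("age", "height"))  # prints: mean(age, height)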
Expand Down Expand Up @@ -905,9 +930,8 @@ def agg(self, *exprs):
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
jcols = ListConverter().convert([c._jc for c in exprs[1:]],
self.sql_ctx._sc._gateway._gateway_client)
jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
jdf = self._jdf.agg(exprs[0]._jc,
_to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
return DataFrame(jdf, self.sql_ctx)
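
The Column branch passes the first expression positionally and the rest as a converted Seq, matching the JVM-side agg(Column, Column*) shape; a typical call looks like this (illustrative sketch, assuming the module's doctest df and a live context):

from pyspark.sql import functions as F
# one leading Column, remaining Columns packed into a Seq by _to_seq
df.groupBy(df.name).agg(F.min(df.age), F.max(df.age))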

@dfapi
@@ -999,6 +1023,19 @@ def _to_java_column(col):
return jcol


def _to_seq(sc, cols, converter=None):
Contributor: docstring for this as well
Contributor (author): done

"""
Convert a list of Column (or names) into a JVM Seq of Column.

An optional `converter` can be used to convert items in `cols`
into JVM Column objects.
"""
if converter:
cols = [converter(c) for c in cols]
jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
return sc._jvm.PythonUtils.toSeq(jcols)
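
Illustrative call sites for `_to_seq`, mirroring usages elsewhere in this diff (a sketch rather than runnable code, since it needs an active SparkContext `sc` and a live gateway):

names = _to_seq(sc, ["age", "name"])                    # strings pass through unchanged
jcols = _to_seq(sc, [df.age, "name"], _to_java_column)  # items converted to JVM Columns first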


def _unary_op(name, doc="unary operator"):
""" Create a method for given unary operator """
def _(self):
Expand Down Expand Up @@ -1138,8 +1175,7 @@ def inSet(self, *cols):
cols = cols[0]
cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
sc = SparkContext._active_spark_context
jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols))
jc = getattr(self._jc, "in")(_to_seq(sc, cols))
return Column(jc)

# order
python/pyspark/sql/functions.py (3 additions, 8 deletions)
@@ -23,13 +23,11 @@
if sys.version < "3":
from itertools import imap as map

from py4j.java_collections import ListConverter

from pyspark import SparkContext
from pyspark.rdd import _prepare_for_python_RDD
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql.types import StringType
from pyspark.sql.dataframe import Column, _to_java_column
from pyspark.sql.dataframe import Column, _to_java_column, _to_seq


__all__ = ['countDistinct', 'approxCountDistinct', 'udf']
@@ -87,8 +85,7 @@ def countDistinct(col, *cols):
[Row(c=2)]
"""
sc = SparkContext._active_spark_context
jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client)
jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols))
jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column))
return Column(jc)


@@ -138,9 +135,7 @@ def __del__(self):

def __call__(self, *cols):
sc = SparkContext._active_spark_context
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
sc._gateway._gateway_client)
jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
jc = self._judf.apply(_to_seq(sc, cols, _to_java_column))
return Column(jc)
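
For orientation, `__call__` above is what fires when a wrapped UDF is applied to Columns; typical usage (a doctest-style sketch assuming this module's example df and a live SQLContext):

>>> from pyspark.sql.types import IntegerType
>>> slen = udf(lambda s: len(s), IntegerType())
>>> df.select(slen(df.name).alias('slen')).collect()
[Row(slen=5), Row(slen=3)]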


python/pyspark/sql/tests.py (1 addition, 1 deletion)
@@ -282,7 +282,7 @@ def test_apply_schema(self):
StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
StructField("list1", ArrayType(ByteType(), False), False),
StructField("null1", DoubleType(), True)])
df = self.sqlCtx.applySchema(rdd, schema)
df = self.sqlCtx.createDataFrame(rdd, schema)
results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),