-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-6957] [SPARK-6958] [SQL] improve API compatibility to pandas #5544
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -485,30 +485,60 @@ def join(self, other, joinExprs=None, joinType=None): | |
| return DataFrame(jdf, self.sql_ctx) | ||
|
|
||
| @ignore_unicode_prefix | ||
| def sort(self, *cols): | ||
| def sort(self, *cols, **kwargs): | ||
| """Returns a new :class:`DataFrame` sorted by the specified column(s). | ||
|
|
||
| :param cols: list of :class:`Column` to sort by. | ||
| :param cols: list of :class:`Column` or column names to sort by. | ||
| :param ascending: sort by ascending order or not, could be bool, int | ||
| or list of bool, int (default: True). | ||
|
|
||
| >>> df.sort(df.age.desc()).collect() | ||
| [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] | ||
| >>> df.sort("age", ascending=False).collect() | ||
| [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] | ||
| >>> df.orderBy(df.age.desc()).collect() | ||
| [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] | ||
| >>> from pyspark.sql.functions import * | ||
| >>> df.sort(asc("age")).collect() | ||
| [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] | ||
| >>> df.orderBy(desc("age"), "name").collect() | ||
| [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] | ||
| >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect() | ||
| [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] | ||
| """ | ||
| if not cols: | ||
| raise ValueError("should sort by at least one column") | ||
| jcols = ListConverter().convert([_to_java_column(c) for c in cols], | ||
| self._sc._gateway._gateway_client) | ||
| jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) | ||
| if len(cols) == 1 and isinstance(cols[0], list): | ||
| cols = cols[0] | ||
| jcols = [_to_java_column(c) for c in cols] | ||
| ascending = kwargs.get('ascending', True) | ||
| if isinstance(ascending, (bool, int)): | ||
| if not ascending: | ||
| jcols = [jc.desc() for jc in jcols] | ||
| elif isinstance(ascending, list): | ||
| jcols = [jc if asc else jc.desc() | ||
| for asc, jc in zip(ascending, jcols)] | ||
| else: | ||
| raise TypeError("ascending can only be bool or list, but got %s" % type(ascending)) | ||
|
|
||
| jdf = self._jdf.sort(self._jseq(jcols)) | ||
| return DataFrame(jdf, self.sql_ctx) | ||
|
|
||
| orderBy = sort | ||
|
|
||
| def _jseq(self, cols, converter=None): | ||
| """Return a JVM Seq of Columns from a list of Column or names""" | ||
| return _to_seq(self.sql_ctx._sc, cols, converter) | ||
|
|
||
| def _jcols(self, *cols): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This function needs a docstring too. |
||
| """Return a JVM Seq of Columns from a list of Column or column names | ||
|
|
||
| If `cols` has only one list in it, cols[0] will be used as the list. | ||
| """ | ||
| if len(cols) == 1 and isinstance(cols[0], list): | ||
| cols = cols[0] | ||
| return self._jseq(cols, _to_java_column) | ||
|
|
||
| def describe(self, *cols): | ||
| """Computes statistics for numeric columns. | ||
|
|
||
|
|
@@ -523,9 +553,7 @@ def describe(self, *cols): | |
| min 2 | ||
| max 5 | ||
| """ | ||
| cols = ListConverter().convert(cols, | ||
| self.sql_ctx._sc._gateway._gateway_client) | ||
| jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)) | ||
| jdf = self._jdf.describe(self._jseq(cols)) | ||
| return DataFrame(jdf, self.sql_ctx) | ||
|
|
||
| @ignore_unicode_prefix | ||
|
|
@@ -600,9 +628,7 @@ def select(self, *cols): | |
| >>> df.select(df.name, (df.age + 10).alias('age')).collect() | ||
| [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] | ||
| """ | ||
| jcols = ListConverter().convert([_to_java_column(c) for c in cols], | ||
| self._sc._gateway._gateway_client) | ||
| jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) | ||
| jdf = self._jdf.select(self._jcols(*cols)) | ||
| return DataFrame(jdf, self.sql_ctx) | ||
|
|
||
| def selectExpr(self, *expr): | ||
|
|
@@ -613,8 +639,9 @@ def selectExpr(self, *expr): | |
| >>> df.selectExpr("age * 2", "abs(age)").collect() | ||
| [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] | ||
| """ | ||
| jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client) | ||
| jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) | ||
| if len(expr) == 1 and isinstance(expr[0], list): | ||
| expr = expr[0] | ||
| jdf = self._jdf.selectExpr(self._jseq(expr)) | ||
| return DataFrame(jdf, self.sql_ctx) | ||
|
|
||
| @ignore_unicode_prefix | ||
|
|
@@ -652,6 +679,8 @@ def groupBy(self, *cols): | |
| so we can run aggregation on them. See :class:`GroupedData` | ||
| for all the available aggregate functions. | ||
|
|
||
| :func:`groupby` is an alias for :func:`groupBy`. | ||
|
|
||
| :param cols: list of columns to group by. | ||
| Each element should be a column name (string) or an expression (:class:`Column`). | ||
|
|
||
|
|
@@ -661,12 +690,14 @@ def groupBy(self, *cols): | |
| [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] | ||
| >>> df.groupBy(df.name).avg().collect() | ||
| [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] | ||
| >>> df.groupBy(['name', df.age]).count().collect() | ||
| [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] | ||
| """ | ||
| jcols = ListConverter().convert([_to_java_column(c) for c in cols], | ||
| self._sc._gateway._gateway_client) | ||
| jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) | ||
| jdf = self._jdf.groupBy(self._jcols(*cols)) | ||
| return GroupedData(jdf, self.sql_ctx) | ||
|
|
||
| groupby = groupBy | ||
|
|
||
| def agg(self, *exprs): | ||
| """ Aggregate on the entire :class:`DataFrame` without groups | ||
| (shorthand for ``df.groupBy.agg()``). | ||
|
|
@@ -737,9 +768,7 @@ def dropna(self, how='any', thresh=None, subset=None): | |
| if thresh is None: | ||
| thresh = len(subset) if how == 'any' else 1 | ||
|
|
||
| cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) | ||
| cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) | ||
| return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx) | ||
| return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx) | ||
|
|
||
| def fillna(self, value, subset=None): | ||
| """Replace null values, alias for ``na.fill()``. | ||
|
|
@@ -792,9 +821,7 @@ def fillna(self, value, subset=None): | |
| elif not isinstance(subset, (list, tuple)): | ||
| raise ValueError("subset should be a list or tuple of column names") | ||
|
|
||
| cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) | ||
| cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) | ||
| return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx) | ||
| return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) | ||
|
|
||
| @ignore_unicode_prefix | ||
| def withColumn(self, colName, col): | ||
|
|
@@ -855,10 +882,8 @@ def _api(self): | |
|
|
||
| def df_varargs_api(f): | ||
| def _api(self, *args): | ||
| jargs = ListConverter().convert(args, | ||
| self.sql_ctx._sc._gateway._gateway_client) | ||
| name = f.__name__ | ||
| jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs)) | ||
| jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) | ||
| return DataFrame(jdf, self.sql_ctx) | ||
| _api.__name__ = f.__name__ | ||
| _api.__doc__ = f.__doc__ | ||
|
|
@@ -905,9 +930,8 @@ def agg(self, *exprs): | |
| else: | ||
| # Columns | ||
| assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" | ||
| jcols = ListConverter().convert([c._jc for c in exprs[1:]], | ||
| self.sql_ctx._sc._gateway._gateway_client) | ||
| jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) | ||
| jdf = self._jdf.agg(exprs[0]._jc, | ||
| _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) | ||
| return DataFrame(jdf, self.sql_ctx) | ||
|
|
||
| @dfapi | ||
|
|
@@ -999,6 +1023,19 @@ def _to_java_column(col): | |
| return jcol | ||
|
|
||
|
|
||
| def _to_seq(sc, cols, converter=None): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please add a docstring for this as well.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
| """ | ||
| Convert a list of Column (or names) into a JVM Seq of Column. | ||
|
|
||
| An optional `converter` could be used to convert items in `cols` | ||
| into JVM Column objects. | ||
| """ | ||
| if converter: | ||
| cols = [converter(c) for c in cols] | ||
| jcols = ListConverter().convert(cols, sc._gateway._gateway_client) | ||
| return sc._jvm.PythonUtils.toSeq(jcols) | ||
|
|
||
|
|
||
| def _unary_op(name, doc="unary operator"): | ||
| """ Create a method for given unary operator """ | ||
| def _(self): | ||
|
|
@@ -1138,8 +1175,7 @@ def inSet(self, *cols): | |
| cols = cols[0] | ||
| cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] | ||
| sc = SparkContext._active_spark_context | ||
| jcols = ListConverter().convert(cols, sc._gateway._gateway_client) | ||
| jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols)) | ||
| jc = getattr(self._jc, "in")(_to_seq(sc, cols)) | ||
| return Column(jc) | ||
|
|
||
| # order | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a docstring here to explain what this does?