Commit d8adefe

Davies Liu authored and marmbrus committed
[SPARK-5859] [PySpark] [SQL] fix DataFrame Python API
1. add explain()
2. add isLocal()
3. do not call show() in __repr__
4. add foreach() and foreachPartition()
5. add distinct()
6. fix functions.col()/column()/lit()
7. fix unit tests in sql/functions.py
8. fix unicode in showString()

Author: Davies Liu <[email protected]>

Closes #4645 from davies/df6 and squashes the following commits:

6b46a2c [Davies Liu] fix DataFrame Python API
1 parent c74b07f commit d8adefe

File tree

2 files changed: +59 -18 lines


python/pyspark/sql/dataframe.py

Lines changed: 54 additions & 11 deletions
@@ -238,6 +238,22 @@ def printSchema(self):
         """
         print (self._jdf.schema().treeString())
 
+    def explain(self, extended=False):
+        """
+        Prints the plans (logical and physical) to the console for
+        debugging purpose.
+
+        If extended is False, only prints the physical plan.
+        """
+        self._jdf.explain(extended)
+
+    def isLocal(self):
+        """
+        Returns True if the `collect` and `take` methods can be run locally
+        (without any Spark executors).
+        """
+        return self._jdf.isLocal()
+
     def show(self):
         """
         Print the first 20 rows.
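For context, a minimal sketch of calling the two new methods, assuming the sc and df doctest globals that _test() sets up for this module (Python 2 syntax, matching the surrounding code):

    # Prints only the physical plan (extended defaults to False).
    df.explain()

    # Prints the logical plans as well.
    df.explain(True)

    # True when collect()/take() can run without any Spark executors.
    print df.isLocal()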
@@ -247,14 +263,12 @@ def show(self):
         2 Alice
         5 Bob
         >>> df
-        age name
-        2   Alice
-        5   Bob
+        DataFrame[age: int, name: string]
         """
-        print (self)
+        print self._jdf.showString().encode('utf8', 'ignore')
 
     def __repr__(self):
-        return self._jdf.showString()
+        return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
 
     def count(self):
         """Return the number of elements in this RDD.
@@ -336,13 +350,40 @@ def mapPartitions(self, f, preservesPartitioning=False):
         """
         Return a new RDD by applying a function to each partition.
 
+        It's a shorthand for df.rdd.mapPartitions()
+
         >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
         >>> def f(iterator): yield 1
         >>> rdd.mapPartitions(f).sum()
         4
         """
         return self.rdd.mapPartitions(f, preservesPartitioning)
 
+    def foreach(self, f):
+        """
+        Applies a function to all rows of this DataFrame.
+
+        It's a shorthand for df.rdd.foreach()
+
+        >>> def f(person):
+        ...     print person.name
+        >>> df.foreach(f)
+        """
+        return self.rdd.foreach(f)
+
+    def foreachPartition(self, f):
+        """
+        Applies a function to each partition of this DataFrame.
+
+        It's a shorthand for df.rdd.foreachPartition()
+
+        >>> def f(people):
+        ...     for person in people:
+        ...         print person.name
+        >>> df.foreachPartition(f)
+        """
+        return self.rdd.foreachPartition(f)
+
     def cache(self):
         """ Persist with the default storage level (C{MEMORY_ONLY_SER}).
         """
@@ -377,8 +418,13 @@ def repartition(self, numPartitions):
         """ Return a new :class:`DataFrame` that has exactly `numPartitions`
         partitions.
         """
-        rdd = self._jdf.repartition(numPartitions, None)
-        return DataFrame(rdd, self.sql_ctx)
+        return DataFrame(self._jdf.repartition(numPartitions, None), self.sql_ctx)
+
+    def distinct(self):
+        """
+        Return a new :class:`DataFrame` containing the distinct rows in this DataFrame.
+        """
+        return DataFrame(self._jdf.distinct(), self.sql_ctx)
 
     def sample(self, withReplacement, fraction, seed=None):
         """
@@ -957,10 +1003,7 @@ def cast(self, dataType):
         return Column(jc, self.sql_ctx)
 
     def __repr__(self):
-        if self._jdf.isComputable():
-            return self._jdf.samples()
-        else:
-            return 'Column<%s>' % self._jdf.toString()
+        return 'Column<%s>' % self._jdf.toString().encode('utf8')
 
     def toPandas(self):
         """

python/pyspark/sql/functions.py

Lines changed: 5 additions & 7 deletions
@@ -37,7 +37,7 @@ def _create_function(name, doc=""):
     """ Create a function for aggregator by name"""
     def _(col):
         sc = SparkContext._active_spark_context
-        jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
+        jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
         return Column(jc)
     _.__name__ = name
     _.__doc__ = doc
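Why the change matters: _to_java_column() turns a plain Python string into a column reference, which is right for col()/column() but wrong for lit(), where a string argument must stay a literal value. A sketch of the intended behavior, assuming the doctest df (the commented results show the expected shape, not verified output):

    from pyspark.sql.functions import col, lit

    # 'name' is resolved as a column reference.
    df.select(col('name')).collect()
    # [Row(name=u'Alice'), Row(name=u'Bob')]

    # 'tag' stays a literal string instead of being resolved as a column name.
    df.select(lit('tag').alias('t')).collect()
    # [Row(t=u'tag'), Row(t=u'tag')]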
@@ -140,6 +140,7 @@ def __call__(self, *cols):
 def udf(f, returnType=StringType()):
     """Create a user defined function (UDF)
 
+    >>> from pyspark.sql.types import IntegerType
     >>> slen = udf(lambda s: len(s), IntegerType())
     >>> df.select(slen(df.name).alias('slen')).collect()
     [Row(slen=5), Row(slen=3)]
@@ -151,17 +152,14 @@ def _test():
     import doctest
     from pyspark.context import SparkContext
     from pyspark.sql import Row, SQLContext
-    import pyspark.sql.dataframe
-    globs = pyspark.sql.dataframe.__dict__.copy()
+    import pyspark.sql.functions
+    globs = pyspark.sql.functions.__dict__.copy()
     sc = SparkContext('local[4]', 'PythonTest')
     globs['sc'] = sc
     globs['sqlCtx'] = SQLContext(sc)
     globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
-    globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
-    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
-                                   Row(name='Bob', age=5, height=85)]).toDF()
     (failure_count, test_count) = doctest.testmod(
-        pyspark.sql.dataframe, globs=globs,
+        pyspark.sql.functions, globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
     globs['sc'].stop()
     if failure_count:
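This hunk fixes the real bug behind item 7 of the commit message: _test() was seeding its globals from pyspark.sql.dataframe and running doctest.testmod against that module, so the doctests in this file never actually ran. With the usual PySpark module guard (assumed here, following the convention in sibling modules), executing the file now tests this module directly:

    # Assumed module tail, following the usual PySpark convention:
    if __name__ == "__main__":
        _test()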
