Commit ce320cb

[SPARK-8060] Improve DataFrame Python test coverage and documentation.
Author: Reynold Xin <[email protected]>

Closes apache#6601 from rxin/python-read-write-test-and-doc and squashes the following commits:

baa8ad5 [Reynold Xin] Code review feedback.
f081d47 [Reynold Xin] More documentation updates.
c9902fa [Reynold Xin] [SPARK-8060] Improve DataFrame Python reader/writer interface doc and testing.
1 parent 452eb82 commit ce320cb

20 files changed (+180, -227 lines)

.rat-excludes

Lines changed: 1 addition & 0 deletions
@@ -82,3 +82,4 @@ local-1426633911242/*
 local-1430917381534/*
 DESCRIPTION
 NAMESPACE
+test_support/*

python/pyspark/sql/__init__.py

Lines changed: 12 additions & 1 deletion
@@ -45,11 +45,20 @@
 
 
 def since(version):
+    """
+    A decorator that annotates a function to append the version of Spark the function was added.
+    """
+    import re
+    indent_p = re.compile(r'\n( +)')
+
     def deco(f):
-        f.__doc__ = f.__doc__.rstrip() + "\n\n.. versionadded:: %s" % version
+        indents = indent_p.findall(f.__doc__)
+        indent = ' ' * (min(len(m) for m in indents) if indents else 0)
+        f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
         return f
     return deco
 
+
 from pyspark.sql.types import Row
 from pyspark.sql.context import SQLContext, HiveContext
 from pyspark.sql.column import Column
@@ -58,7 +67,9 @@ def deco(f):
 from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
 from pyspark.sql.window import Window, WindowSpec
 
+
 __all__ = [
     'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
     'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
+    'DataFrameReader', 'DataFrameWriter'
 ]
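
The change above makes the since decorator indentation-aware: it scans the docstring for indented lines, takes the smallest indent, and prefixes the appended ".. versionadded::" directive with it, so Sphinx keeps the directive aligned with the docstring body. A minimal standalone sketch of that logic (range_example is a made-up function for illustration, not Spark code):

    import re

    def since(version):
        indent_p = re.compile(r'\n( +)')

        def deco(f):
            # Find the smallest indentation used by any indented docstring
            # line and reuse it for the appended directive.
            indents = indent_p.findall(f.__doc__)
            indent = ' ' * (min(len(m) for m in indents) if indents else 0)
            f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
            return f
        return deco

    @since(1.4)
    def range_example(start, end):
        """Create a range.

        :param start: the start value
        """

    # The directive is appended indented four spaces, matching the body:
    print(range_example.__doc__)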

python/pyspark/sql/context.py

Lines changed: 37 additions & 52 deletions
@@ -124,7 +124,10 @@ def getConf(self, key, defaultValue):
     @property
     @since("1.3.1")
     def udf(self):
-        """Returns a :class:`UDFRegistration` for UDF registration."""
+        """Returns a :class:`UDFRegistration` for UDF registration.
+
+        :return: :class:`UDFRegistration`
+        """
         return UDFRegistration(self)
 
     @since(1.4)
@@ -138,7 +141,7 @@ def range(self, start, end, step=1, numPartitions=None):
         :param end: the end value (exclusive)
         :param step: the incremental step (default: 1)
         :param numPartitions: the number of partitions of the DataFrame
-        :return: A new DataFrame
+        :return: :class:`DataFrame`
 
         >>> sqlContext.range(1, 7, 2).collect()
         [Row(id=1), Row(id=3), Row(id=5)]
@@ -195,8 +198,8 @@ def _inferSchema(self, rdd, samplingRatio=None):
             raise ValueError("The first row in RDD is empty, "
                              "can not infer schema")
         if type(first) is dict:
-            warnings.warn("Using RDD of dict to inferSchema is deprecated,"
-                          "please use pyspark.sql.Row instead")
+            warnings.warn("Using RDD of dict to inferSchema is deprecated. "
+                          "Use pyspark.sql.Row instead")
 
         if samplingRatio is None:
             schema = _infer_schema(first)
@@ -219,7 +222,7 @@ def inferSchema(self, rdd, samplingRatio=None):
         """
         .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead.
         """
-        warnings.warn("inferSchema is deprecated, please use createDataFrame instead")
+        warnings.warn("inferSchema is deprecated, please use createDataFrame instead.")
 
         if isinstance(rdd, DataFrame):
             raise TypeError("Cannot apply schema to DataFrame")
@@ -262,6 +265,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
             :class:`list`, or :class:`pandas.DataFrame`.
         :param schema: a :class:`StructType` or list of column names. default None.
         :param samplingRatio: the sample ratio of rows used for inferring
+        :return: :class:`DataFrame`
 
         >>> l = [('Alice', 1)]
         >>> sqlContext.createDataFrame(l).collect()
@@ -359,58 +363,31 @@ def registerDataFrameAsTable(self, df, tableName):
         else:
             raise ValueError("Can only register DataFrame as table")
 
-    @since(1.0)
     def parquetFile(self, *paths):
         """Loads a Parquet file, returning the result as a :class:`DataFrame`.
 
-        >>> import tempfile, shutil
-        >>> parquetFile = tempfile.mkdtemp()
-        >>> shutil.rmtree(parquetFile)
-        >>> df.saveAsParquetFile(parquetFile)
-        >>> df2 = sqlContext.parquetFile(parquetFile)
-        >>> sorted(df.collect()) == sorted(df2.collect())
-        True
+        .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead.
+
+        >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes
+        [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
         """
+        warnings.warn("parquetFile is deprecated. Use read.parquet() instead.")
         gateway = self._sc._gateway
         jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths))
         for i in range(0, len(paths)):
             jpaths[i] = paths[i]
         jdf = self._ssql_ctx.parquetFile(jpaths)
         return DataFrame(jdf, self)
 
-    @since(1.0)
     def jsonFile(self, path, schema=None, samplingRatio=1.0):
         """Loads a text file storing one JSON object per line as a :class:`DataFrame`.
 
-        If the schema is provided, applies the given schema to this JSON dataset.
-        Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema.
-
-        >>> import tempfile, shutil
-        >>> jsonFile = tempfile.mkdtemp()
-        >>> shutil.rmtree(jsonFile)
-        >>> with open(jsonFile, 'w') as f:
-        ...     f.writelines(jsonStrings)
-        >>> df1 = sqlContext.jsonFile(jsonFile)
-        >>> df1.printSchema()
-        root
-         |-- field1: long (nullable = true)
-         |-- field2: string (nullable = true)
-         |-- field3: struct (nullable = true)
-         |    |-- field4: long (nullable = true)
+        .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead.
 
-        >>> from pyspark.sql.types import *
-        >>> schema = StructType([
-        ...     StructField("field2", StringType()),
-        ...     StructField("field3",
-        ...         StructType([StructField("field5", ArrayType(IntegerType()))]))])
-        >>> df2 = sqlContext.jsonFile(jsonFile, schema)
-        >>> df2.printSchema()
-        root
-         |-- field2: string (nullable = true)
-         |-- field3: struct (nullable = true)
-         |    |-- field5: array (nullable = true)
-         |    |    |-- element: integer (containsNull = true)
+        >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes
+        [('age', 'bigint'), ('name', 'string')]
         """
+        warnings.warn("jsonFile is deprecated. Use read.json() instead.")
         if schema is None:
             df = self._ssql_ctx.jsonFile(path, samplingRatio)
         else:
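
The deprecation notes above point to the new DataFrameReader interface exposed as sqlContext.read. For comparison, the equivalent new-style reads, assuming a live sqlContext and the same test_support data the updated doctests use:

    # Replacements for the deprecated parquetFile()/jsonFile().
    df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
    people = sqlContext.read.json('python/test_support/sql/people.json')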
@@ -462,21 +439,16 @@ def func(iterator):
         df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
         return DataFrame(df, self)
 
-    @since(1.3)
     def load(self, path=None, source=None, schema=None, **options):
         """Returns the dataset in a data source as a :class:`DataFrame`.
 
-        The data source is specified by the ``source`` and a set of ``options``.
-        If ``source`` is not specified, the default data source configured by
-        ``spark.sql.sources.default`` will be used.
-
-        Optionally, a schema can be provided as the schema of the returned DataFrame.
+        .. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead.
         """
+        warnings.warn("load is deprecated. Use read.load() instead.")
         return self.read.load(path, source, schema, **options)
 
     @since(1.3)
-    def createExternalTable(self, tableName, path=None, source=None,
-                            schema=None, **options):
+    def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
         """Creates an external table based on the dataset in a data source.
 
         It returns the DataFrame associated with the external table.
@@ -487,6 +459,8 @@ def createExternalTable(self, tableName, path=None, source=None,
 
         Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
         created external table.
+
+        :return: :class:`DataFrame`
         """
         if path is not None:
             options["path"] = path
@@ -508,6 +482,8 @@
     def sql(self, sqlQuery):
         """Returns a :class:`DataFrame` representing the result of the given query.
 
+        :return: :class:`DataFrame`
+
         >>> sqlContext.registerDataFrameAsTable(df, "table1")
         >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1")
         >>> df2.collect()
@@ -519,6 +495,8 @@ def sql(self, sqlQuery):
     def table(self, tableName):
         """Returns the specified table as a :class:`DataFrame`.
 
+        :return: :class:`DataFrame`
+
         >>> sqlContext.registerDataFrameAsTable(df, "table1")
         >>> df2 = sqlContext.table("table1")
         >>> sorted(df.collect()) == sorted(df2.collect())
@@ -536,6 +514,9 @@ def tables(self, dbName=None):
         The returned DataFrame has two columns: ``tableName`` and ``isTemporary``
         (a column with :class:`BooleanType` indicating if a table is a temporary one or not).
 
+        :param dbName: string, name of the database to use.
+        :return: :class:`DataFrame`
+
         >>> sqlContext.registerDataFrameAsTable(df, "table1")
         >>> df2 = sqlContext.tables()
         >>> df2.filter("tableName = 'table1'").first()
@@ -550,7 +531,8 @@ def tables(self, dbName=None):
     def tableNames(self, dbName=None):
         """Returns a list of names of tables in the database ``dbName``.
 
-        If ``dbName`` is not specified, the current database will be used.
+        :param dbName: string, name of the database to use. Default to the current database.
+        :return: list of table names, in string
 
         >>> sqlContext.registerDataFrameAsTable(df, "table1")
         >>> "table1" in sqlContext.tableNames()
@@ -585,8 +567,7 @@ def read(self):
         Returns a :class:`DataFrameReader` that can be used to read data
         in as a :class:`DataFrame`.
 
-        >>> sqlContext.read
-        <pyspark.sql.readwriter.DataFrameReader object at ...>
+        :return: :class:`DataFrameReader`
         """
         return DataFrameReader(self)
 
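
Because read returns a DataFrameReader, the named shortcuts (read.parquet, read.json) can also be spelled as a generic format/load chain; again a sketch assuming a live sqlContext:

    # format() picks the data source by name; load() materializes the DataFrame.
    people = sqlContext.read.format('json').load('python/test_support/sql/people.json')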
@@ -644,10 +625,14 @@ def register(self, name, f, returnType=StringType()):
 
 
 def _test():
+    import os
     import doctest
     from pyspark.context import SparkContext
     from pyspark.sql import Row, SQLContext
     import pyspark.sql.context
+
+    os.chdir(os.environ["SPARK_HOME"])
+
     globs = pyspark.sql.context.__dict__.copy()
     sc = SparkContext('local[4]', 'PythonTest')
     globs['sc'] = sc
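
The os.chdir() addition is what lets the doctests above use repo-relative paths such as 'python/test_support/sql/people.json'. In isolation (the SPARK_HOME value is whatever checkout you run against):

    import os

    # _test() assumes SPARK_HOME points at a Spark source checkout,
    os.chdir(os.environ["SPARK_HOME"])
    # so relative test-data paths resolve from the repository root.
    assert os.path.exists('python/test_support/sql/people.json')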
