Commit 4464fa2
[SPARK-11967][SQL] Consistent use of varargs for multiple paths in DataFrameReader
This patch makes it consistent to use varargs in all DataFrameReader methods, including Parquet, JSON, text, and the generic load function. Also added a few more API tests for the Java API.

Author: Reynold Xin <[email protected]>

Closes #9945 from rxin/SPARK-11967.

(cherry picked from commit 25bbd3c)
Signed-off-by: Reynold Xin <[email protected]>
1 parent 36a99f9 commit 4464fa2
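
A minimal usage sketch of the unified varargs surface after this patch, assuming a SQLContext named sqlContext; the paths are placeholders, not files from this repo:

  // All of these reader methods now accept one or more path strings directly.
  val generic = sqlContext.read.format("json").load("/data/part1.json", "/data/part2.json")
  val json    = sqlContext.read.json("/data/part1.json", "/data/part2.json")
  val text    = sqlContext.read.text("/logs/a.txt", "/logs/b.txt")
  val parquet = sqlContext.read.parquet("/data/p1.parquet", "/data/p2.parquet")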

5 files changed: +66 −15 lines

python/pyspark/sql/readwriter.py

Lines changed: 12 additions & 7 deletions
@@ -109,7 +109,7 @@ def options(self, **options):
     def load(self, path=None, format=None, schema=None, **options):
         """Loads data from a data source and returns it as a :class`DataFrame`.

-        :param path: optional string for file-system backed data sources.
+        :param path: optional string or a list of string for file-system backed data sources.
         :param format: optional string for format of the data source. Default to 'parquet'.
         :param schema: optional :class:`StructType` for the input schema.
         :param options: all other string options
@@ -118,6 +118,7 @@ def load(self, path=None, format=None, schema=None, **options):
         ...     opt2=1, opt3='str')
         >>> df.dtypes
         [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
+
         >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json',
         ...     'python/test_support/sql/people1.json'])
         >>> df.dtypes
@@ -130,10 +131,8 @@ def load(self, path=None, format=None, schema=None, **options):
         self.options(**options)
         if path is not None:
             if type(path) == list:
-                paths = path
-                gateway = self._sqlContext._sc._gateway
-                jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths)
-                return self._df(self._jreader.load(jpaths))
+                return self._df(
+                    self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
             else:
                 return self._df(self._jreader.load(path))
         else:
@@ -175,6 +174,8 @@ def json(self, path, schema=None):
             self.schema(schema)
         if isinstance(path, basestring):
             return self._df(self._jreader.json(path))
+        elif type(path) == list:
+            return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
         elif isinstance(path, RDD):
             return self._df(self._jreader.json(path._jrdd))
         else:
@@ -205,16 +206,20 @@ def parquet(self, *paths):

     @ignore_unicode_prefix
     @since(1.6)
-    def text(self, path):
+    def text(self, paths):
         """Loads a text file and returns a [[DataFrame]] with a single string column named "text".

         Each line in the text file is a new row in the resulting DataFrame.

+        :param paths: string, or list of strings, for input path(s).
+
         >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt')
         >>> df.collect()
         [Row(value=u'hello'), Row(value=u'this')]
         """
-        return self._df(self._jreader.text(path))
+        if isinstance(paths, basestring):
+            paths = [paths]
+        return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths)))

     @since(1.5)
     def orc(self, path):

sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 29 additions & 7 deletions
@@ -24,17 +24,17 @@ import scala.collection.JavaConverters._
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.util.StringUtils

+import org.apache.spark.{Logging, Partition}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.SqlParser
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
-import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JSONRelation}
+import org.apache.spark.sql.execution.datasources.json.JSONRelation
 import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
 import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.{Logging, Partition}
-import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier}

 /**
  * :: Experimental ::
@@ -104,6 +104,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
    *
    * @since 1.4.0
    */
+  // TODO: Remove this one in Spark 2.0.
   def load(path: String): DataFrame = {
     option("path", path).load()
   }
@@ -130,7 +131,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
    *
    * @since 1.6.0
    */
-  def load(paths: Array[String]): DataFrame = {
+  @scala.annotation.varargs
+  def load(paths: String*): DataFrame = {
     option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load()
   }

@@ -236,11 +238,30 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
    * <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers
    * (e.g. 00012)</li>
    *
-   * @param path input path
    * @since 1.4.0
    */
+  // TODO: Remove this one in Spark 2.0.
   def json(path: String): DataFrame = format("json").load(path)

+  /**
+   * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]].
+   *
+   * This function goes through the input once to determine the input schema. If you know the
+   * schema in advance, use the version that specifies the schema to avoid the extra scan.
+   *
+   * You can set the following JSON-specific options to deal with non-standard JSON files:
+   * <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
+   * <li>`allowComments` (default `false`): ignores Java/C++ style comment in JSON records</li>
+   * <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
+   * <li>`allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes
+   * </li>
+   * <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers
+   * (e.g. 00012)</li>
+   *
+   * @since 1.6.0
+   */
+  def json(paths: String*): DataFrame = format("json").load(paths : _*)
+
   /**
    * Loads an `JavaRDD[String]` storing JSON objects (one object per record) and
    * returns the result as a [[DataFrame]].
@@ -328,10 +349,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
    *   sqlContext.read().text("/path/to/spark/README.md")
    * }}}
    *
-   * @param path input path
+   * @param paths input path
    * @since 1.6.0
    */
-  def text(path: String): DataFrame = format("text").load(path)
+  @scala.annotation.varargs
+  def text(paths: String*): DataFrame = format("text").load(paths : _*)

   ///////////////////////////////////////////////////////////////////////////////////////
   // Builder pattern config options
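
For context on the plumbing above: load(paths: String*) folds every path into the single "paths" option, escaping commas with Hadoop's StringUtils so that paths which themselves contain commas can later be split apart safely. A minimal sketch of that packing expression, using hypothetical paths:

  import org.apache.hadoop.util.StringUtils

  // One hypothetical path deliberately contains a comma.
  val paths = Seq("/data/2015-11-25", "/data/odd,name")
  // Same expression as in load(paths: String*): escape commas, then join on commas.
  val packed = paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")
  // packed == "/data/2015-11-25,/data/odd\,name"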

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 23 additions & 0 deletions
@@ -298,4 +298,27 @@ public void pivot() {
     Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01);
     Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
   }
+
+  public void testGenericLoad() {
+    DataFrame df1 = context.read().format("text").load(
+      Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
+    Assert.assertEquals(4L, df1.count());
+
+    DataFrame df2 = context.read().format("text").load(
+      Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
+      Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
+    Assert.assertEquals(5L, df2.count());
+  }
+
+  @Test
+  public void testTextLoad() {
+    DataFrame df1 = context.read().text(
+      Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
+    Assert.assertEquals(4L, df1.count());
+
+    DataFrame df2 = context.read().text(
+      Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
+      Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
+    Assert.assertEquals(5L, df2.count());
+  }
 }
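
These Java tests can pass several path strings directly because @scala.annotation.varargs on the Scala methods makes the compiler emit an extra Java-style String... forwarder alongside the Seq-based signature. A standalone sketch of the mechanism, with illustrative names rather than Spark code:

  object PathReader {
    // Without @varargs this is only callable from Java as countPaths(Seq<String>);
    // the annotation adds a countPaths(String...) overload for Java callers.
    @scala.annotation.varargs
    def countPaths(paths: String*): Int = paths.length
  }

  // Scala: PathReader.countPaths("a", "b")
  // Java:  PathReader.countPaths("a", "b")  // resolves to the generated varargs overload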
sql/core/src/test/resources/text-suite2.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+This is another file for testing multi path loading.

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -897,7 +897,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       val dir2 = new File(dir, "dir2").getCanonicalPath
       df2.write.format("json").save(dir2)

-      checkAnswer(sqlContext.read.format("json").load(Array(dir1, dir2)),
+      checkAnswer(sqlContext.read.format("json").load(dir1, dir2),
         Row(1, 22) :: Row(2, 23) :: Nil)

       checkAnswer(sqlContext.read.format("json").load(dir1),
