19 changes: 12 additions & 7 deletions python/pyspark/sql/readwriter.py
@@ -109,7 +109,7 @@ def options(self, **options):
def load(self, path=None, format=None, schema=None, **options):
"""Loads data from a data source and returns it as a :class`DataFrame`.

:param path: optional string for file-system backed data sources.
:param path: optional string or a list of strings for file-system backed data sources.
:param format: optional string for format of the data source. Defaults to 'parquet'.
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
@@ -118,6 +118,7 @@ def load(self, path=None, format=None, schema=None, **options):
... opt2=1, opt3='str')
>>> df.dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]

>>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json',
... 'python/test_support/sql/people1.json'])
>>> df.dtypes
@@ -130,10 +131,8 @@ def load(self, path=None, format=None, schema=None, **options):
self.options(**options)
if path is not None:
if type(path) == list:
paths = path
gateway = self._sqlContext._sc._gateway
jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths)
return self._df(self._jreader.load(jpaths))
return self._df(
self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
else:
return self._df(self._jreader.load(path))
else:
@@ -175,6 +174,8 @@ def json(self, path, schema=None):
self.schema(schema)
if isinstance(path, basestring):
return self._df(self._jreader.json(path))
elif type(path) == list:
return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
elif isinstance(path, RDD):
return self._df(self._jreader.json(path._jrdd))
else:
@@ -205,16 +206,20 @@ def parquet(self, *paths):

@ignore_unicode_prefix
@since(1.6)
def text(self, path):
def text(self, paths):
"""Loads a text file and returns a [[DataFrame]] with a single string column named "text".

Each line in the text file is a new row in the resulting DataFrame.

:param paths: string, or list of strings, for input path(s).

>>> df = sqlContext.read.text('python/test_support/sql/text-test.txt')
>>> df.collect()
[Row(value=u'hello'), Row(value=u'this')]
"""
return self._df(self._jreader.text(path))
if isinstance(paths, basestring):
paths = [paths]
return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths)))

@since(1.5)
def orc(self, path):
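On the JVM side, the list branch above hands the `java.util.List` of paths that Py4J produces to `PythonUtils.toSeq`, which converts it into the Scala `Seq` that the varargs `DataFrameReader.load` expects. A minimal sketch of that conversion, assuming the helper in `org.apache.spark.api.python.PythonUtils` behaves like this simplified version:

// Simplified sketch of the conversion the Python wrapper relies on; the real
// helper lives in org.apache.spark.api.python.PythonUtils (assumption: same
// shape as below).
import java.util.{List => JList}

import scala.collection.JavaConverters._

object PythonUtilsSketch {
  // Py4J hands the Python list over as a java.util.List; the Scala varargs
  // method load(paths: String*) needs a Seq[String], so we convert.
  def toSeq[T](vs: JList[T]): Seq[T] = vs.asScala.toSeq
}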
36 changes: 29 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -24,17 +24,17 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.StringUtils

import org.apache.spark.{Logging, Partition}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.SqlParser
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JSONRelation}
import org.apache.spark.sql.execution.datasources.json.JSONRelation
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.types.StructType
import org.apache.spark.{Logging, Partition}
import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier}

/**
* :: Experimental ::
@@ -104,6 +104,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*
* @since 1.4.0
*/
// TODO: Remove this one in Spark 2.0.
def load(path: String): DataFrame = {
option("path", path).load()
}
@@ -130,7 +131,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*
* @since 1.6.0
*/
def load(paths: Array[String]): DataFrame = {
@scala.annotation.varargs
def load(paths: String*): DataFrame = {
option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load()
}

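Because every path has to travel through the single string-valued `paths` option, the overload above escapes commas (and the backslash escape character) inside each path before joining them with commas. A small self-contained sketch of that packing step, assuming hadoop-common on the classpath; the sample paths are illustrative:

// Packing step from load(paths: String*) above: escape ',' and '\' inside
// each path, then join with ',' so one option string can carry them all.
import org.apache.hadoop.util.StringUtils

object PackPaths {
  def main(args: Array[String]): Unit = {
    val paths = Seq("data/a,b.json", "data/c.json") // illustrative paths
    val packed = paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")
    println(packed) // prints: data/a\,b.json,data/c.json
  }
}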
@@ -236,11 +238,30 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers
* (e.g. 00012)</li>
*
* @param path input path
* @since 1.4.0
*/
// TODO: Remove this one in Spark 2.0.
def json(path: String): DataFrame = format("json").load(path)

/**
* Loads a JSON file (one object per line) and returns the result as a [[DataFrame]].
*
* This function goes through the input once to determine the input schema. If you know the
* schema in advance, use the version that specifies the schema to avoid the extra scan.
*
* You can set the following JSON-specific options to deal with non-standard JSON files:
* <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
* <li>`allowComments` (default `false`): ignores Java/C++ style comment in JSON records</li>
* <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
* <li>`allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes
* </li>
* <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers
* (e.g. 00012)</li>
*
* @since 1.6.0
*/
def json(paths: String*): DataFrame = format("json").load(paths : _*)

/**
* Loads a `JavaRDD[String]` storing JSON objects (one object per record) and
* returns the result as a [[DataFrame]].
@@ -328,10 +349,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* sqlContext.read().text("/path/to/spark/README.md")
* }}}
*
* @param path input path
* @param paths input paths
* @since 1.6.0
*/
def text(path: String): DataFrame = format("text").load(path)
@scala.annotation.varargs
def text(paths: String*): DataFrame = format("text").load(paths : _*)

///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
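Taken together, the new overloads let callers pass several paths directly to `load`, `json`, and `text`. A hedged usage sketch against the 1.6-era API; the SparkContext setup and file names are illustrative, not from the patch:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object MultiPathReadExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("multi-path-read"))
    val sqlContext = new SQLContext(sc)

    // Generic varargs load: all paths are read by the same data source.
    val df = sqlContext.read.format("json").load("dir1/part.json", "dir2/part.json")

    // json() and text() gain the same multi-path forms.
    val people = sqlContext.read.json("dir1/part.json", "dir2/part.json")
    val lines = sqlContext.read.text("docs/a.txt", "docs/b.txt")

    sc.stop()
  }
}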
@@ -298,4 +298,27 @@ public void pivot() {
Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01);
Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
}

public void testGenericLoad() {
Review thread on this line:

Contributor: There is no @Test here, is it intentional?

Contributor (Author): Oh, we should add it. Can you add it in your next PR?

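// NOTE: @Test is missing on testGenericLoad (flagged in the review above),
// so JUnit will not run this method as written.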
DataFrame df1 = context.read().format("text").load(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
Assert.assertEquals(4L, df1.count());

DataFrame df2 = context.read().format("text").load(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
Assert.assertEquals(5L, df2.count());
}

@Test
public void testTextLoad() {
DataFrame df1 = context.read().text(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
Assert.assertEquals(4L, df1.count());

DataFrame df2 = context.read().text(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
Assert.assertEquals(5L, df2.count());
}
}
1 change: 1 addition & 0 deletions sql/core/src/test/resources/text-suite2.txt
@@ -0,0 +1 @@
This is another file for testing multi path loading.
@@ -897,7 +897,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val dir2 = new File(dir, "dir2").getCanonicalPath
df2.write.format("json").save(dir2)

checkAnswer(sqlContext.read.format("json").load(Array(dir1, dir2)),
checkAnswer(sqlContext.read.format("json").load(dir1, dir2),
Row(1, 22) :: Row(2, 23) :: Nil)

checkAnswer(sqlContext.read.format("json").load(dir1),
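The one-line test change above follows from the signature change: with `load` now varargs, an explicit `Array[String]` argument no longer matches, so callers either list the paths or expand a collection. A short sketch, assuming a 1.6-era `SQLContext` as in the suite:

import org.apache.spark.sql.{DataFrame, SQLContext}

// load(paths: String*) replaces load(paths: Array[String]); a collection of
// paths must now be expanded explicitly with ": _*".
def readBoth(sqlContext: SQLContext, dir1: String, dir2: String): (DataFrame, DataFrame) = {
  val direct = sqlContext.read.format("json").load(dir1, dir2) // list paths directly
  val expanded = sqlContext.read.format("json").load(Array(dir1, dir2): _*) // expand a collection
  (direct, expanded)
}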