diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index e8f0d7ec77035..2e75f0c8a1827 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -109,7 +109,7 @@ def options(self, **options): def load(self, path=None, format=None, schema=None, **options): """Loads data from a data source and returns it as a :class`DataFrame`. - :param path: optional string for file-system backed data sources. + :param path: optional string or a list of string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. :param schema: optional :class:`StructType` for the input schema. :param options: all other string options @@ -118,6 +118,7 @@ def load(self, path=None, format=None, schema=None, **options): ... opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json', ... 'python/test_support/sql/people1.json']) >>> df.dtypes @@ -130,10 +131,8 @@ def load(self, path=None, format=None, schema=None, **options): self.options(**options) if path is not None: if type(path) == list: - paths = path - gateway = self._sqlContext._sc._gateway - jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths) - return self._df(self._jreader.load(jpaths)) + return self._df( + self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) else: return self._df(self._jreader.load(path)) else: @@ -175,6 +174,8 @@ def json(self, path, schema=None): self.schema(schema) if isinstance(path, basestring): return self._df(self._jreader.json(path)) + elif type(path) == list: + return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) elif isinstance(path, RDD): return self._df(self._jreader.json(path._jrdd)) else: @@ -205,16 +206,20 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) - def text(self, path): + def text(self, paths): """Loads a text file and returns a [[DataFrame]] with a single string column named "text". Each line in the text file is a new row in the resulting DataFrame. + :param paths: string, or list of strings, for input path(s). + >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt') >>> df.collect() [Row(value=u'hello'), Row(value=u'this')] """ - return self._df(self._jreader.text(path)) + if isinstance(paths, basestring): + paths = [paths] + return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) @since(1.5) def orc(self, path): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index dcb3737b70fbf..3ed1e55adec6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -24,17 +24,17 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.util.StringUtils +import org.apache.spark.{Logging, Partition} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.SqlParser import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JSONRelation} +import org.apache.spark.sql.execution.datasources.json.JSONRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, Partition} -import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} /** * :: Experimental :: @@ -104,6 +104,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.4.0 */ + // TODO: Remove this one in Spark 2.0. def load(path: String): DataFrame = { option("path", path).load() } @@ -130,7 +131,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.6.0 */ - def load(paths: Array[String]): DataFrame = { + @scala.annotation.varargs + def load(paths: String*): DataFrame = { option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load() } @@ -236,11 +238,30 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { *