From 8447a6d76fd57fcaad2b6b0c64441fd269070a49 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 21 Apr 2017 21:14:00 +0900 Subject: [PATCH 1/4] Specify a schema by using a DDL-formatted string --- python/pyspark/sql/readwriter.py | 19 +++++++++++-------- .../apache/spark/sql/DataFrameReader.scala | 12 ++++++++++++ .../sql/test/DataFrameReaderWriterSuite.scala | 9 +++++++++ 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 960fb882cf90..e9954539fa1e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -96,14 +96,17 @@ def schema(self, schema): By specifying the schema here, the underlying data source can skip the schema inference step, and thus speed up data loading. - :param schema: a :class:`pyspark.sql.types.StructType` object + :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string """ from pyspark.sql import SparkSession - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") spark = SparkSession.builder.getOrCreate() - jschema = spark._jsparkSession.parseDataType(schema.json()) - self._jreader = self._jreader.schema(jschema) + if isinstance(schema, StructType): + jschema = spark._jsparkSession.parseDataType(schema.json()) + self._jreader = self._jreader.schema(jschema) + elif isinstance(schema, basestring): + self._jreader = self._jreader.schema(schema) + else: + raise TypeError("schema should be StructType") return self @since(1.5) @@ -137,7 +140,7 @@ def load(self, path=None, format=None, schema=None, **options): :param path: optional string or a list of string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. - :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema or a DDL-formatted string. :param options: all other string options >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True, @@ -181,7 +184,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param path: string represents path to the JSON dataset, or a list of paths, or RDD of Strings storing JSON objects. - :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema or a DDL-formatted string. :param primitivesAsString: infers all primitive values as a string type. If None is set, it uses the default value, ``false``. :param prefersDecimal: infers all floating-point values as a decimal type. If the values @@ -324,7 +327,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``inferSchema`` option or specify the schema explicitly using ``schema``. :param path: string, or list of strings, for input path(s). - :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema or a DDL-formatted string. :param sep: sets the single character as a separator for each field and value. If None is set, it uses the default value, ``,``. :param encoding: decodes the CSV files by the given encoding type. If None is set, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index c1b32917415a..0f96e82cedf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -67,6 +67,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { this } + /** + * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can + * infer the input schema automatically from data. By specifying the schema here, the underlying + * data source can skip the schema inference step, and thus speed up data loading. + * + * @since 2.3.0 + */ + def schema(schemaString: String): DataFrameReader = { + this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString)) + this + } + /** * Adds an input option for the underlying data source. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index fb15e7def6db..306aecb5bbc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -128,6 +128,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be import testImplicits._ private val userSchema = new StructType().add("s", StringType) + private val userSchemaString = "s STRING" private val textSchema = new StructType().add("value", StringType) private val data = Seq("1", "2", "3") private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath @@ -678,4 +679,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be assert(e.contains("User specified schema not supported with `table`")) } } + + test("SPARK-20431: Specify a schema by using a DDL-formatted string") { + spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir) + testRead(spark.read.schema(userSchemaString).text(), Seq.empty, userSchema) + testRead(spark.read.schema(userSchemaString).text(dir), data, userSchema) + testRead(spark.read.schema(userSchemaString).text(dir, dir), data ++ data, userSchema) + testRead(spark.read.schema(userSchemaString).text(Seq(dir, dir): _*), data ++ data, userSchema) + } } From a1a2e35afdf094150a0a2d0668e3bd69dd445fdc Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 21 Apr 2017 23:52:43 +0900 Subject: [PATCH 2/4] Fix syntax errors --- python/pyspark/sql/readwriter.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index e9954539fa1e..542528933db6 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -140,7 +140,8 @@ def load(self, path=None, format=None, schema=None, **options): :param path: optional string or a list of string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. - :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema or a DDL-formatted string. + :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema + or a DDL-formatted string. :param options: all other string options >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True, @@ -184,7 +185,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param path: string represents path to the JSON dataset, or a list of paths, or RDD of Strings storing JSON objects. - :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema or a DDL-formatted string. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema or + a DDL-formatted string. :param primitivesAsString: infers all primitive values as a string type. If None is set, it uses the default value, ``false``. :param prefersDecimal: infers all floating-point values as a decimal type. If the values @@ -327,7 +329,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``inferSchema`` option or specify the schema explicitly using ``schema``. :param path: string, or list of strings, for input path(s). - :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema or a DDL-formatted string. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema + or a DDL-formatted string. :param sep: sets the single character as a separator for each field and value. If None is set, it uses the default value, ``,``. :param encoding: decodes the CSV files by the given encoding type. If None is set, From 5fe5e39cb0466df96313e9ff0015ddc5fc8a957e Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 22 Apr 2017 18:04:46 +0900 Subject: [PATCH 3/4] Apply comments --- python/pyspark/sql/readwriter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 542528933db6..cface8e3dcf2 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -106,7 +106,7 @@ def schema(self, schema): elif isinstance(schema, basestring): self._jreader = self._jreader.schema(schema) else: - raise TypeError("schema should be StructType") + raise TypeError("schema should be StructType or string") return self @since(1.5) From 46994fb2f22135f19ab615265cbe9da8f24aaa15 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 11 May 2017 12:56:27 +0900 Subject: [PATCH 4/4] Add an example --- python/pyspark/sql/readwriter.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index cface8e3dcf2..1b9844520255 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -97,6 +97,7 @@ def schema(self, schema): inference step, and thus speed up data loading. :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string + (For example ``col0 INT, col1 DOUBLE``). """ from pyspark.sql import SparkSession spark = SparkSession.builder.getOrCreate() @@ -141,7 +142,7 @@ def load(self, path=None, format=None, schema=None, **options): :param path: optional string or a list of string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema - or a DDL-formatted string. + or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param options: all other string options >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True, @@ -186,7 +187,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param path: string represents path to the JSON dataset, or a list of paths, or RDD of Strings storing JSON objects. :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema or - a DDL-formatted string. + a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param primitivesAsString: infers all primitive values as a string type. If None is set, it uses the default value, ``false``. :param prefersDecimal: infers all floating-point values as a decimal type. If the values @@ -330,7 +331,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param path: string, or list of strings, for input path(s). :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema - or a DDL-formatted string. + or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param sep: sets the single character as a separator for each field and value. If None is set, it uses the default value, ``,``. :param encoding: decodes the CSV files by the given encoding type. If None is set,