23 changes: 15 additions & 8 deletions python/pyspark/sql/readwriter.py
@@ -96,14 +96,18 @@ def schema(self, schema):
By specifying the schema here, the underlying data source can skip the schema
inference step, and thus speed up data loading.

-:param schema: a :class:`pyspark.sql.types.StructType` object
+:param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string
Member:
Could you give users an example here?

Member Author:
ok

+    (For example ``col0 INT, col1 DOUBLE``).
"""
 from pyspark.sql import SparkSession
-if not isinstance(schema, StructType):
-    raise TypeError("schema should be StructType")
 spark = SparkSession.builder.getOrCreate()
-jschema = spark._jsparkSession.parseDataType(schema.json())
-self._jreader = self._jreader.schema(jschema)
+if isinstance(schema, StructType):
+    jschema = spark._jsparkSession.parseDataType(schema.json())
+    self._jreader = self._jreader.schema(jschema)
+elif isinstance(schema, basestring):
+    self._jreader = self._jreader.schema(schema)
+else:
+    raise TypeError("schema should be StructType or string")
 return self

@since(1.5)
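With the hunk above, `DataFrameReader.schema` in PySpark accepts either form. A minimal usage sketch, assuming an active `SparkSession` and a placeholder JSON path (neither comes from this patch):

# Hedged sketch: both readers below carry the same user-specified schema.
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, DoubleType

spark = SparkSession.builder.getOrCreate()
struct_schema = StructType([StructField("col0", IntegerType()), StructField("col1", DoubleType())])

df1 = spark.read.schema(struct_schema).json("/tmp/example.json")            # existing StructType form
df2 = spark.read.schema("col0 INT, col1 DOUBLE").json("/tmp/example.json")  # new DDL-string form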
@@ -137,7 +141,8 @@ def load(self, path=None, format=None, schema=None, **options):

:param path: optional string or a list of string for file-system backed data sources.
:param format: optional string for format of the data source. Default to 'parquet'.
-:param schema: optional :class:`pyspark.sql.types.StructType` for the input schema.
+:param schema: optional :class:`pyspark.sql.types.StructType` for the input schema
+    or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
:param options: all other string options

>>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True,
@@ -181,7 +186,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,

:param path: string represents path to the JSON dataset, or a list of paths,
or RDD of Strings storing JSON objects.
-:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema.
+:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema or
+    a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
:param primitivesAsString: infers all primitive values as a string type. If None is set,
it uses the default value, ``false``.
:param prefersDecimal: infers all floating-point values as a decimal type. If the values
@@ -324,7 +330,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
``inferSchema`` option or specify the schema explicitly using ``schema``.

:param path: string, or list of strings, for input path(s).
-:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema.
+:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
+    or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
:param sep: sets the single character as a separator for each field and value.
If None is set, it uses the default value, ``,``.
:param encoding: decodes the CSV files by the given encoding type. If None is set,
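As the `schema()` docstring at the top of this file notes, supplying a schema up front lets the data source skip the inference step. A short CSV sketch of that trade-off, with placeholder paths and column names, again assuming an active `SparkSession` named `spark`:

# Schema inference makes an extra pass over the CSV data before the actual read.
df_inferred = spark.read.option("inferSchema", "true").option("header", "true").csv("/tmp/example.csv")

# An explicit schema (a StructType or, with this patch, a DDL string) skips that pass.
df_declared = spark.read.schema("col0 INT, col1 DOUBLE").csv("/tmp/example.csv")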
12 changes: 12 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -67,6 +67,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
this
}

/**
* Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can
* infer the input schema automatically from data. By specifying the schema here, the underlying
* data source can skip the schema inference step, and thus speed up data loading.
*
* @since 2.3.0
*/
def schema(schemaString: String): DataFrameReader = {
this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString))
Member:
This change will make PySpark API inconsistent with the Scala API

Member Author:
Sorry, but I probably missed your point. What API inconsistency are you pointing out here?
I just made the Python APIs the same as the Scala ones, like:

--- python
>>> from pyspark.sql.types import *
>>> fields = [StructField('a', IntegerType(), True), StructField('b', StringType(), True), StructField('c', DoubleType(), True)]
>>> schema = StructType(fields)
>>> spark.read.schema(schema).csv("/Users/maropu/Desktop/test.csv").show()
+---+----+---+
|  a|   b|  c|
+---+----+---+
|  1| aaa|0.3|
+---+----+---+

>>> spark.read.schema("a INT, b STRING, c DOUBLE").csv("/Users/maropu/Desktop/test.csv").show()
+---+----+---+
|  a|   b|  c|
+---+----+---+
|  1| aaa|0.3|
+---+----+---+

--- scala
scala> import org.apache.spark.sql.types._
scala> val fields = StructField("a", IntegerType) :: StructField("b", StringType) :: StructField("c", DoubleType) :: Nil
scala> val schema = StructType(fields)
scala> spark.read.schema(schema).csv("/Users/maropu/Desktop/test.csv").show
+---+----+---+
|  a|   b|  c|
+---+----+---+
|  1| aaa|0.3|
+---+----+---+

scala> spark.read.schema("a INT, b STRING, c DOUBLE").csv("/Users/maropu/Desktop/test.csv").show
+---+----+---+
|  a|   b|  c|
+---+----+---+
|  1| aaa|0.3|
+---+----+---+

Member:
Sorry, I misread the Python code.

this
}

/**
* Adds an input option for the underlying data source.
*
@@ -128,6 +128,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
import testImplicits._

private val userSchema = new StructType().add("s", StringType)
private val userSchemaString = "s STRING"
private val textSchema = new StructType().add("value", StringType)
private val data = Seq("1", "2", "3")
private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath
@@ -678,4 +679,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
assert(e.contains("User specified schema not supported with `table`"))
}
}

test("SPARK-20431: Specify a schema by using a DDL-formatted string") {
spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)
testRead(spark.read.schema(userSchemaString).text(), Seq.empty, userSchema)
testRead(spark.read.schema(userSchemaString).text(dir), data, userSchema)
testRead(spark.read.schema(userSchemaString).text(dir, dir), data ++ data, userSchema)
testRead(spark.read.schema(userSchemaString).text(Seq(dir, dir): _*), data ++ data, userSchema)
}
}
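For completeness, a rough PySpark counterpart to the new Scala test, given purely as an illustration (the directory is a placeholder, `spark` is an active `SparkSession`, and this snippet is not part of the patch):

# Illustrative only: exercise DataFrameReader.schema with a DDL-formatted string from Python.
data_dir = "/tmp/spark-20431-text"  # placeholder output directory
spark.createDataFrame([("1",), ("2",), ("3",)], ["s"]).write.mode("overwrite").text(data_dir)

df = spark.read.schema("s STRING").text(data_dir)
assert df.schema.simpleString() == "struct<s:string>"
assert sorted(r.s for r in df.collect()) == ["1", "2", "3"]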