diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index dc9a49e69aa5a..43206fb426897 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -22,6 +22,10 @@ license: |
 * Table of contents
 {:toc}
 
+## Upgrading from Spark SQL 3.2 to 3.3
+
+  - Since Spark 3.3, Spark turns a non-nullable schema into a nullable one for the APIs `DataFrameReader.schema(schema: StructType).json(jsonDataset: Dataset[String])` and `DataFrameReader.schema(schema: StructType).csv(csvDataset: Dataset[String])` when the user-specified schema contains non-nullable fields.
+
 ## Upgrading from Spark SQL 3.1 to 3.2
 
   - Since Spark 3.2, ADD FILE/JAR/ARCHIVE commands require each path to be enclosed by `"` or `'` if the path contains whitespaces.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 9d62318489828..b13a3525f0713 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -449,7 +449,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
-    val schema = userSpecifiedSchema.getOrElse {
+    val schema = userSpecifiedSchema.map(_.asNullable).getOrElse {
       TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions)
     }
 
@@ -521,7 +521,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       None
     }
 
-    val schema = userSpecifiedSchema.getOrElse {
+    val schema = userSpecifiedSchema.map(_.asNullable).getOrElse {
       TextInputCSVDataSource.inferFromDataset(
         sparkSession,
         csvDataset,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index fd25a79619d24..b58911fb946ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -34,7 +34,7 @@ import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.GzipCodec
 
 import org.apache.spark.{SparkConf, SparkException, TestUtils}
-import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row}
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Encoders, QueryTest, Row}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.CommonFileDataSourceSuite
 import org.apache.spark.sql.internal.SQLConf
@@ -2463,6 +2463,26 @@ abstract class CSVSuite
         .option("ignoreTrailingWhiteSpace", "true").load(path.getAbsolutePath).count() == 1)
     }
   }
+
+  test("SPARK-35912: turn non-nullable schema into a nullable schema") {
+    val inputCSVString = """1,"""
+
+    val schema = StructType(Seq(
+      StructField("c1", IntegerType, nullable = false),
+      StructField("c2", IntegerType, nullable = false)))
+    val expected = schema.asNullable
+
+    Seq("DROPMALFORMED", "FAILFAST", "PERMISSIVE").foreach { mode =>
+      val csv = spark.createDataset(
+        spark.sparkContext.parallelize(inputCSVString :: Nil))(Encoders.STRING)
+      val df = spark.read
+        .option("mode", mode)
+        .schema(schema)
+        .csv(csv)
+      assert(df.schema == expected)
+      checkAnswer(df, Row(1, null) :: Nil)
+    }
+  }
 }
 
 class CSVv1Suite extends CSVSuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index dab1255eeab32..31df5dfabd699 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2919,6 +2919,31 @@ abstract class JsonSuite
       }
     }
   }
+
+  test("SPARK-35912: turn non-nullable schema into a nullable schema") {
+    // JSON field is missing.
+    val missingFieldInput = """{"c1": 1}"""
+    // JSON field is null.
+    val nullValueInput = """{"c1": 1, "c2": null}"""
+
+    val schema = StructType(Seq(
+      StructField("c1", IntegerType, nullable = false),
+      StructField("c2", IntegerType, nullable = false)))
+    val expected = schema.asNullable
+
+    Seq(missingFieldInput, nullValueInput).foreach { jsonString =>
+      Seq("DROPMALFORMED", "FAILFAST", "PERMISSIVE").foreach { mode =>
+        val json = spark.createDataset(
+          spark.sparkContext.parallelize(jsonString :: Nil))(Encoders.STRING)
+        val df = spark.read
+          .option("mode", mode)
+          .schema(schema)
+          .json(json)
+        assert(df.schema == expected)
+        checkAnswer(df, Row(1, null) :: Nil)
+      }
+    }
+  }
 }
 
 class JsonV1Suite extends JsonSuite {
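
Taken together, the change means a user-supplied non-nullable schema no longer propagates `nullable = false` through these two Dataset-based readers: the reader applies `asNullable` up front, because a missing or null field in the text input can only be materialized as null. A minimal sketch of the observable behavior, assuming Spark 3.3+ with a local session (the object name, master setting, and sample record are illustrative, not part of this patch):

```scala
import org.apache.spark.sql.{Encoders, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Hypothetical driver illustrating the behavior change; nothing here is part of the patch.
object Spark35912Sketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("SPARK-35912").getOrCreate()

    // User-specified schema that claims both columns are non-nullable.
    val schema = StructType(Seq(
      StructField("c1", IntegerType, nullable = false),
      StructField("c2", IntegerType, nullable = false)))

    // "c2" is missing from the record, so it can only be materialized as null.
    val json = spark.createDataset(Seq("""{"c1": 1}"""))(Encoders.STRING)

    val df = spark.read.schema(schema).json(json)

    // Since Spark 3.3 the reader turns the schema nullable up front, so the
    // DataFrame reports nullable = true for both columns instead of exposing
    // a null value behind a supposedly non-nullable field.
    df.printSchema()
    df.show() // one row: c1 = 1, c2 = null

    spark.stop()
  }
}
```

The same holds for the CSV path via `.csv(csvDataset)`: an input line like `1,` yields `Row(1, null)` under each of the three parse modes, which is exactly what the tests above assert.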