diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 54006e20a3eb6..4c4e99a7f2ff5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -370,8 +370,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
    *
    * 1. If A and B have the same name and data type, they are merged to a field C with the same name
    *    and data type. C is nullable if and only if either A or B is nullable.
-   * 2. If A doesn't exist in `that`, it's included in the result schema.
-   * 3. If B doesn't exist in `this`, it's also included in the result schema.
+   * 2. If A doesn't exist in `that`, it's included in the result schema as a nullable field.
+   * 3. If B doesn't exist in `this`, it's also included in the result schema as a nullable field.
    * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be
    *    thrown.
    */
@@ -473,7 +473,7 @@ object StructType extends AbstractDataType {
                   nullable = leftNullable || rightNullable)
               }
               .orElse {
-                Some(leftField)
+                Some(leftField.copy(nullable = true))
               }
               .foreach(newFields += _)
         }
@@ -482,7 +482,7 @@ object StructType extends AbstractDataType {
         rightFields
           .filterNot(f => leftMapped.get(f.name).nonEmpty)
           .foreach { f =>
-            newFields += f
+            newFields += f.copy(nullable = true)
           }
 
         StructType(newFields)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
index 51bdf0f0f2291..4b242ec3febf9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
@@ -90,6 +90,11 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
    */
   private ColumnarBatch columnarBatch;
 
+  /**
+   * Schema corresponding to columnarBatch.
+   */
+  private StructType columnarBatchSchema;
+
   /**
    * If true, this class returns batches instead of rows.
    */
@@ -178,6 +183,7 @@ public void initBatch(MemoryMode memMode, StructType partitionColumns,
       }
     }
 
+    columnarBatchSchema = batchSchema;
     columnarBatch = ColumnarBatch.allocate(batchSchema, memMode);
     if (partitionColumns != null) {
       int partitionIdx = sparkSchema.fields().length;
@@ -228,6 +234,12 @@ public boolean nextBatch() throws IOException {
     for (int i = 0; i < columnReaders.length; ++i) {
       if (columnReaders[i] == null) continue;
       columnReaders[i].readBatch(num, columnarBatch.column(i));
+      StructField schema = columnarBatchSchema.fields()[i];
+      if (columnarBatch.column(i).anyNullsSet() && !schema.nullable()) {
+        throw new UnsupportedOperationException(
+          "Should not contain null for non-nullable " + schema.dataType() +
+          " schema at column index " + i);
+      }
     }
     rowsReturned += num;
     columnarBatch.setNumRows(num);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 9fce29b06b9d8..842929f240007 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -374,7 +374,8 @@ case class DataSource(
       HadoopFsRelation(
         fileCatalog,
         partitionSchema = partitionSchema,
-        dataSchema = dataSchema.asNullable,
+        dataSchema =
+          if (format.isInstanceOf[ParquetFileFormat]) dataSchema else dataSchema.asNullable,
         bucketSpec = bucketSpec,
         format,
         caseInsensitiveOptions)(sparkSession)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 32e6c60cd9766..504984fc2a9e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -190,6 +190,11 @@ private[parquet] class ParquetRowConverter(
     var i = 0
     while (i < currentRow.numFields) {
       fieldConverters(i).updater.end()
+      if (currentRow.isNullAt(i) && !catalystType(i).nullable) {
+        throw new UnsupportedOperationException(
+          "Should not contain null for non-nullable " + catalystType(i).dataType +
+          " schema at column index " + i)
+      }
       i += 1
     }
     updater.set(currentRow)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index d02c8ffe33f0f..69f2aab33eae6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -159,7 +159,7 @@ class PlannerSuite extends SharedSQLContext {
 
       withTempView("testPushed") {
         val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan
-        assert(exp.toString.contains("PushedFilters: [IsNotNull(key), EqualTo(key,15)]"))
+        assert(exp.toString.contains("PushedFilters: [EqualTo(key,15)]"))
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index e4dd077715d0f..41af50744b95a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -404,7 +404,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
         pathToNonPartitionedTable,
         userSpecifiedSchema = None,
         userSpecifiedPartitionCols = partitionCols,
-        expectedSchema = new StructType().add("num", IntegerType).add("str", StringType),
+        expectedSchema = new StructType().add("num", IntegerType, false).add("str", StringType),
         expectedPartitionCols = Seq.empty[String])
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 94a2f9a00b3f3..fcfa2cecc209a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -458,8 +458,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
       readParquetFile(path.toString) { df =>
         assertResult(df.schema) {
           StructType(
-            StructField("a", BooleanType, nullable = true) ::
-            StructField("b", IntegerType, nullable = true) ::
+            StructField("a", BooleanType, nullable = false) ::
+            StructField("b", IntegerType, nullable = false) ::
             Nil)
         }
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 306aecb5bbc86..2b0d019371a84 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -21,12 +21,14 @@ import java.io.File
 import java.util.Locale
 import java.util.concurrent.ConcurrentLinkedQueue
 
+import org.apache.commons.lang3.exception.ExceptionUtils
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
 import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -123,6 +125,7 @@ class MessageCapturingCommitProtocol(jobId: String, path: String)
   }
 }
 
+case class PointStr(x: String, y: String)
 class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
   import testImplicits._
 
@@ -680,6 +683,63 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
     }
   }
 
+  private def readAndWriteWithSchema(
+      schema: StructType, df: DataFrame, result: Array[Row], dfNull: DataFrame): Unit = {
+    val fmt = "parquet"
+    withTempDir { dir =>
+      val path = new File(dir, "nonnull").getCanonicalPath
+      df.write.format(fmt).save(path)
+      val dfRead = spark.read.format(fmt).schema(schema).load(path)
+      checkAnswer(dfRead, result)
+      assert(dfRead.schema.equals(schema))
+
+      val pathNull = new File(dir, "null").getCanonicalPath
+      dfNull.write.format(fmt).save(pathNull)
+      val e = intercept[Exception] {
+        spark.read.format(fmt).schema(schema).load(pathNull).collect
+      }
+      assert(ExceptionUtils.getRootCause(e).isInstanceOf[UnsupportedOperationException] &&
+        e.getMessage.contains("Should not contain null for non-nullable"))
+    }
+  }
+
+  test("SPARK-19950: loadWithSchema") {
+    Seq("true", "false").foreach { vectorized =>
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
+        val dataInt = Seq(1, 2, 3)
+        val dfInt = sparkContext.parallelize(dataInt, 1).toDF("v")
+        val resultInt = dataInt.map(e => Row(e)).toArray
+        val schemaInt = StructType(Seq(StructField("v", IntegerType, false)))
+        val dfIntNull = sparkContext.parallelize(Seq[java.lang.Integer](1, null, 3), 1).toDF("v")
+        readAndWriteWithSchema(schemaInt, dfInt, resultInt, dfIntNull)
+
+        val dataDouble = Seq(1.1D, 2.2D, 3.3D)
+        val dfDouble = sparkContext.parallelize(dataDouble, 1).toDF("v")
+        val resultDouble = dataDouble.map(e => Row(e)).toArray
+        val schemaDouble = StructType(Seq(StructField("v", DoubleType, false)))
+        val dfDoubleNull = sparkContext.parallelize(Seq[java.lang.Double](1.1D, null, 3.3D), 1)
+          .toDF("v")
+        readAndWriteWithSchema(schemaDouble, dfDouble, resultDouble, dfDoubleNull)
+
+        val dataString = Seq("a", "b", "cd")
+        val dfString = sparkContext.parallelize(dataString, 1).toDF("v")
+        val resultString = dataString.map(e => Row(e)).toArray
+        val schemaString = StructType(Seq(StructField("v", StringType, false)))
+        val dfStringNull = sparkContext.parallelize(Seq("a", null, "cd"), 1).toDF("v")
+        readAndWriteWithSchema(schemaString, dfString, resultString, dfStringNull)
+
+        val dataCaseClass = Seq(PointStr("a", "b"), PointStr("c", "d"))
+        val dfCaseClass = sparkContext.parallelize(dataCaseClass, 1).toDF
+        val resultCaseClass = dataCaseClass.map(e => Row(e.x, e.y)).toArray
+        val schemaCaseClass = StructType(
+          Seq(StructField("x", StringType, false), StructField("y", StringType, false)))
+        val dfCaseClassNull = sparkContext.parallelize(
+          Seq(PointStr("a", "b"), PointStr("c", null)), 1).toDF
+        readAndWriteWithSchema(schemaCaseClass, dfCaseClass, resultCaseClass, dfCaseClassNull)
+      }
+    }
+  }
+
   test("SPARK-20431: Specify a schema by using a DDL-formatted string") {
     spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)
     testRead(spark.read.schema(userSchemaString).text(), Seq.empty, userSchema)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index c785aca985820..6146432317235 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -1374,7 +1374,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
       checkAnswer(spark.table("old"), Row(1, "a"))
 
       val expectedSchema = StructType(Seq(
-        StructField("i", IntegerType, nullable = true),
+        StructField("i", IntegerType, nullable = false),
         StructField("j", StringType, nullable = true)))
       assert(table("old").schema === expectedSchema)
     }
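
For context, a minimal sketch of the user-facing behavior the patch above enforces when a user-specified non-nullable schema meets Parquet data that contains nulls. It is not part of the patch: it assumes an already-created local SparkSession named `spark` (e.g. in spark-shell), and the scratch path is hypothetical. The expected error text mirrors the message added in VectorizedParquetRecordReader and ParquetRowConverter.

import org.apache.spark.sql.types.{StringType, StructField, StructType}

import spark.implicits._

// Write a Parquet file whose column "v" contains a null.
val path = "/tmp/spark-19950-demo"  // hypothetical scratch location
Seq("a", null, "cd").toDF("v").write.mode("overwrite").parquet(path)

// Read it back with a user-specified schema that declares "v" as non-nullable.
val strictSchema = StructType(Seq(StructField("v", StringType, nullable = false)))

// Before this patch the data schema was silently relaxed via asNullable; with it, both the
// vectorized and non-vectorized Parquet read paths surface (possibly wrapped) an
// UnsupportedOperationException: "Should not contain null for non-nullable StringType
// schema at column index 0" instead of returning nulls in a non-nullable column.
spark.read.schema(strictSchema).parquet(path).collect()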