diff --git a/docs/sql-data-sources-parquet.md b/docs/sql-data-sources-parquet.md
index 4fed3eaf83e5d..dcd2936518465 100644
--- a/docs/sql-data-sources-parquet.md
+++ b/docs/sql-data-sources-parquet.md
@@ -9,7 +9,7 @@ displayTitle: Parquet Files
 
 [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems.
 Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema
-of the original data. When writing Parquet files, all columns are automatically converted to be nullable for
+of the original data. When reading Parquet files, all columns are automatically converted to be nullable for
 compatibility reasons.
 
 ### Loading Data Programmatically
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 237872585e11d..e45ab19aadbfa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -23,6 +23,13 @@ import java.util.concurrent.ConcurrentLinkedQueue
 
 import scala.collection.JavaConverters._
 
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.parquet.hadoop.ParquetFileReader
+import org.apache.parquet.hadoop.util.HadoopInputFile
+import org.apache.parquet.schema.PrimitiveType
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
+import org.apache.parquet.schema.Type.Repetition
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.SparkContext
@@ -31,6 +38,7 @@ import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -522,11 +530,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
     Seq("json", "orc", "parquet", "csv").foreach { format =>
       val schema = StructType(
         StructField("cl1", IntegerType, nullable = false).withComment("test") ::
-          StructField("cl2", IntegerType, nullable = true) ::
-          StructField("cl3", IntegerType, nullable = true) :: Nil)
+        StructField("cl2", IntegerType, nullable = true) ::
+        StructField("cl3", IntegerType, nullable = true) :: Nil)
       val row = Row(3, null, 4)
       val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema)
 
+      // if we write and then read, the read will enforce schema to be nullable
       val tableName = "tab"
       withTable(tableName) {
         df.write.format(format).mode("overwrite").saveAsTable(tableName)
@@ -536,12 +545,41 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
           Row("cl1", "test") :: Nil)
         // Verify the schema
         val expectedFields = schema.fields.map(f => f.copy(nullable = true))
-        assert(spark.table(tableName).schema == schema.copy(fields = expectedFields))
+        assert(spark.table(tableName).schema === schema.copy(fields = expectedFields))
         }
       }
     }
   }
 
+  test("parquet - column nullability -- write only") {
+    val schema = StructType(
+      StructField("cl1", IntegerType, nullable = false) ::
+      StructField("cl2", IntegerType, nullable = true) :: Nil)
+    val row = Row(3, 4)
+    val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema)
+
+    withTempPath { dir =>
+      val path = dir.getAbsolutePath
+      df.write.mode("overwrite").parquet(path)
+      val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0)
+
+      val hadoopInputFile = HadoopInputFile.fromPath(new Path(file), new Configuration())
+      val f = ParquetFileReader.open(hadoopInputFile)
+      val parquetSchema = f.getFileMetaData.getSchema.getColumns.asScala
+        .map(_.getPrimitiveType)
+      f.close()
+
+      // the write keeps nullable info from the schema
+      val expectedParquetSchema = Seq(
+        new PrimitiveType(Repetition.REQUIRED, PrimitiveTypeName.INT32, "cl1"),
+        new PrimitiveType(Repetition.OPTIONAL, PrimitiveTypeName.INT32, "cl2")
+      )
+
+      assert (expectedParquetSchema === parquetSchema)
+    }
+
+  }
+
   test("SPARK-17230: write out results of decimal calculation") {
     val df = spark.range(99, 101)
       .selectExpr("id", "cast(id as long) * cast('1.0' as decimal(38, 18)) as num")
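
Not part of the patch: a minimal sketch of the behavior the doc fix and the new test cover, for anyone who wants to reproduce it in spark-shell (where a SparkSession is already bound to `spark`; the output path below is only an example, not something taken from the change).

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Same shape as the new test: one non-nullable column, one nullable column.
val schema = StructType(
  StructField("cl1", IntegerType, nullable = false) ::
  StructField("cl2", IntegerType, nullable = true) :: Nil)
val df = spark.createDataFrame(spark.sparkContext.parallelize(Row(3, 4) :: Nil), schema)

// Writing keeps the nullability information in the Parquet file metadata
// (REQUIRED for cl1, OPTIONAL for cl2), which is what the new "write only" test asserts.
df.write.mode("overwrite").parquet("/tmp/parquet-nullability-demo")

// Reading back, however, reports every column as nullable -- the behavior the doc fix describes.
spark.read.parquet("/tmp/parquet-nullability-demo").schema.foreach { f =>
  println(s"${f.name}: nullable = ${f.nullable}")  // prints nullable = true for both columns
}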