StructType.scala
@@ -370,8 +370,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
*
* 1. If A and B have the same name and data type, they are merged to a field C with the same name
* and data type. C is nullable if and only if either A or B is nullable.
- * 2. If A doesn't exist in `that`, it's included in the result schema.
- * 3. If B doesn't exist in `this`, it's also included in the result schema.
+ * 2. If A doesn't exist in `that`, it's included in the result schema as a nullable field.
+ * 3. If B doesn't exist in `this`, it's also included in the result schema as a nullable field.
Contributor:
good catch!

* 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be
* thrown.
*/
@@ -473,7 +473,7 @@ object StructType extends AbstractDataType {
nullable = leftNullable || rightNullable)
}
.orElse {
-          Some(leftField)
+          Some(leftField.copy(nullable = true))
}
.foreach(newFields += _)
}
@@ -482,7 +482,7 @@ object StructType extends AbstractDataType {
rightFields
.filterNot(f => leftMapped.get(f.name).nonEmpty)
.foreach { f =>
-        newFields += f
+        newFields += f.copy(nullable = true)
}

StructType(newFields)
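For concreteness, here is a minimal standalone sketch of the nullability semantics documented in the merge rules above. It is not Spark's actual merge implementation (which also handles nested structs, decimal precision, and case sensitivity, and reports conflicts differently); the name mergeSketch and the exception type are illustrative only.

  import org.apache.spark.sql.types._

  def mergeSketch(left: StructType, right: StructType): StructType = {
    val rightByName = right.fields.map(f => f.name -> f).toMap
    val merged = left.fields.map { lf =>
      rightByName.get(lf.name) match {
        // Rule 1: same name and type => merged; nullable if either side is.
        case Some(rf) if rf.dataType == lf.dataType =>
          lf.copy(nullable = lf.nullable || rf.nullable)
        // Rule 4: same name but conflicting type => error.
        case Some(rf) =>
          throw new IllegalArgumentException(
            s"Conflicting types for field ${lf.name}: ${lf.dataType} vs ${rf.dataType}")
        // Rule 2: present only on the left => kept, but forced nullable.
        case None => lf.copy(nullable = true)
      }
    }
    // Rule 3: present only on the right => appended, also forced nullable.
    val leftNames = left.fieldNames.toSet
    val rightOnly = right.fields
      .filterNot(f => leftNames.contains(f.name))
      .map(_.copy(nullable = true))
    StructType(merged ++ rightOnly)
  }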
VectorizedParquetRecordReader.java
@@ -90,6 +90,11 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBase<Object> {
*/
private ColumnarBatch columnarBatch;

+  /**
+   * The schema that corresponds to columnarBatch.
+   */
+  private StructType columnarBatchSchema;

/**
* If true, this class returns batches instead of rows.
*/
@@ -178,6 +183,7 @@ public void initBatch(MemoryMode memMode, StructType partitionColumns,
}
}

+    columnarBatchSchema = batchSchema;
columnarBatch = ColumnarBatch.allocate(batchSchema, memMode);
if (partitionColumns != null) {
int partitionIdx = sparkSchema.fields().length;
@@ -228,6 +234,12 @@ public boolean nextBatch() throws IOException {
for (int i = 0; i < columnReaders.length; ++i) {
if (columnReaders[i] == null) continue;
columnReaders[i].readBatch(num, columnarBatch.column(i));
+      StructField field = columnarBatchSchema.fields()[i];
+      if (columnarBatch.column(i).anyNullsSet() && !field.nullable()) {
+        throw new UnsupportedOperationException(
+          "Should not contain null for non-nullable " + field.dataType() +
+          " schema at column index " + i);
+      }
}
rowsReturned += num;
columnarBatch.setNumRows(num);
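The new check surfaces to callers roughly as follows. A hedged spark-shell sketch, assuming an active `spark` session and a writable scratch path (the path below is an assumption, not part of this patch):

  import org.apache.spark.sql.types._
  import spark.implicits._

  val path = "/tmp/spark-19950-demo"  // assumed scratch location
  Seq[java.lang.Integer](1, null, 3).toDF("v").write.mode("overwrite").parquet(path)

  val nonNullSchema = StructType(Seq(StructField("v", IntegerType, nullable = false)))
  // With this patch, the read no longer silently returns a null in a field the
  // user declared non-nullable; it fails with
  // UnsupportedOperationException("Should not contain null for non-nullable ...").
  spark.read.schema(nonNullSchema).parquet(path).collect()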
DataSource.scala
@@ -374,7 +374,8 @@ case class DataSource(
HadoopFsRelation(
fileCatalog,
partitionSchema = partitionSchema,
-        dataSchema = dataSchema.asNullable,
+        dataSchema =
+          if (format.isInstanceOf[ParquetFileFormat]) dataSchema else dataSchema.asNullable,
bucketSpec = bucketSpec,
format,
caseInsensitiveOptions)(sparkSession)
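In other words, only Parquet now keeps user-specified nullability; every other file format still widens the relation's schema with asNullable. A hedged sketch of the observable difference (both paths are assumed to exist and contain null-free data):

  import org.apache.spark.sql.types._

  val schema = StructType(Seq(StructField("v", IntegerType, nullable = false)))

  // Parquet: nullable = false survives into the relation's schema.
  val parquetDF = spark.read.schema(schema).parquet("/tmp/demo-parquet")
  assert(!parquetDF.schema("v").nullable)

  // JSON (or any non-Parquet format): still relaxed to nullable = true.
  val jsonDF = spark.read.schema(schema).json("/tmp/demo-json")
  assert(jsonDF.schema("v").nullable)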
ParquetRowConverter.scala
@@ -190,6 +190,11 @@ private[parquet] class ParquetRowConverter(
var i = 0
while (i < currentRow.numFields) {
fieldConverters(i).updater.end()
+      if (currentRow.isNullAt(i) && !catalystType(i).nullable) {
+        throw new UnsupportedOperationException(
+          "Should not contain null for non-nullable " + catalystType(i).dataType +
+          " schema at column index " + i)
+      }
i += 1
}
updater.set(currentRow)
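ParquetRowConverter runs only on the non-vectorized read path, which is why the test added below toggles the vectorized reader. The equivalent session-level setting, sketched here for reference:

  import org.apache.spark.sql.internal.SQLConf

  // Route Parquet reads through ParquetRowConverter instead of
  // VectorizedParquetRecordReader, so this null check is the one that fires.
  spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "false")
  // equivalent: spark.conf.set("spark.sql.parquet.enableVectorizedReader", "false")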
PlannerSuite.scala
@@ -159,7 +159,7 @@ class PlannerSuite extends SharedSQLContext {

withTempView("testPushed") {
val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan
assert(exp.toString.contains("PushedFilters: [IsNotNull(key), EqualTo(key,15)]"))
assert(exp.toString.contains("PushedFilters: [EqualTo(key,15)]"))
}
}
}
DDLSuite.scala
@@ -404,7 +404,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
pathToNonPartitionedTable,
userSpecifiedSchema = None,
userSpecifiedPartitionCols = partitionCols,
-      expectedSchema = new StructType().add("num", IntegerType).add("str", StringType),
+      expectedSchema = new StructType().add("num", IntegerType, false).add("str", StringType),
expectedPartitionCols = Seq.empty[String])
}
}
ParquetIOSuite.scala
@@ -458,8 +458,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
readParquetFile(path.toString) { df =>
assertResult(df.schema) {
StructType(
StructField("a", BooleanType, nullable = true) ::
StructField("b", IntegerType, nullable = true) ::
StructField("a", BooleanType, nullable = false) ::
StructField("b", IntegerType, nullable = false) ::
Nil)
}
}
DataFrameReaderWriterSuite.scala
@@ -21,12 +21,14 @@ import java.io.File
import java.util.Locale
import java.util.concurrent.ConcurrentLinkedQueue

+import org.apache.commons.lang3.exception.ExceptionUtils
import org.scalatest.BeforeAndAfter

import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -123,6 +125,7 @@ class MessageCapturingCommitProtocol(jobId: String, path: String)
}
}

+case class PointStr(x: String, y: String)

class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
import testImplicits._
@@ -680,6 +683,63 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
}
}

+  private def readAndWriteWithSchema(
+      schema: StructType,
+      df: DataFrame,
+      result: Array[Row],
+      dfNull: DataFrame): Unit = {
+    val fmt = "parquet"
+    withTempDir { dir =>
+      val path = new File(dir, "nonnull").getCanonicalPath
+      df.write.format(fmt).save(path)
+      val dfRead = spark.read.format(fmt).schema(schema).load(path)
+      checkAnswer(dfRead, result)
+      assert(dfRead.schema == schema)
+
+      val pathNull = new File(dir, "null").getCanonicalPath
+      dfNull.write.format(fmt).save(pathNull)
+      val e = intercept[Exception] {
+        spark.read.format(fmt).schema(schema).load(pathNull).collect()
+      }
+      assert(ExceptionUtils.getRootCause(e).isInstanceOf[UnsupportedOperationException] &&
+        e.getMessage.contains("Should not contain null for non-nullable"))
+    }
+  }
+
+  test("SPARK-19950: loadWithSchema") {
+    Seq("true", "false").foreach { vectorized =>
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
+        val dataInt = Seq(1, 2, 3)
+        val dfInt = sparkContext.parallelize(dataInt, 1).toDF("v")
+        val resultInt = dataInt.map(e => Row(e)).toArray
+        val schemaInt = StructType(Seq(StructField("v", IntegerType, false)))
+        val dfIntNull = sparkContext.parallelize(Seq[java.lang.Integer](1, null, 3), 1).toDF("v")
+        readAndWriteWithSchema(schemaInt, dfInt, resultInt, dfIntNull)
+
+        val dataDouble = Seq(1.1D, 2.2D, 3.3D)
+        val dfDouble = sparkContext.parallelize(dataDouble, 1).toDF("v")
+        val resultDouble = dataDouble.map(e => Row(e)).toArray
+        val schemaDouble = StructType(Seq(StructField("v", DoubleType, false)))
+        val dfDoubleNull = sparkContext.parallelize(Seq[java.lang.Double](1.1D, null, 3.3D), 1)
+          .toDF("v")
+        readAndWriteWithSchema(schemaDouble, dfDouble, resultDouble, dfDoubleNull)
+
+        val dataString = Seq("a", "b", "cd")
+        val dfString = sparkContext.parallelize(dataString, 1).toDF("v")
+        val resultString = dataString.map(e => Row(e)).toArray
+        val schemaString = StructType(Seq(StructField("v", StringType, false)))
+        val dfStringNull = sparkContext.parallelize(Seq("a", null, "cd"), 1).toDF("v")
+        readAndWriteWithSchema(schemaString, dfString, resultString, dfStringNull)
+
+        val dataCaseClass = Seq(PointStr("a", "b"), PointStr("c", "d"))
+        val dfCaseClass = sparkContext.parallelize(dataCaseClass, 1).toDF()
+        val resultCaseClass = dataCaseClass.map(e => Row(e.x, e.y)).toArray
+        val schemaCaseClass = StructType(
+          Seq(StructField("x", StringType, false), StructField("y", StringType, false)))
+        val dfCaseClassNull = sparkContext.parallelize(
+          Seq(PointStr("a", "b"), PointStr("c", null)), 1).toDF()
+        readAndWriteWithSchema(schemaCaseClass, dfCaseClass, resultCaseClass, dfCaseClassNull)
+      }
+    }
+  }

test("SPARK-20431: Specify a schema by using a DDL-formatted string") {
spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)
testRead(spark.read.schema(userSchemaString).text(), Seq.empty, userSchema)
MetastoreDataSourcesSuite.scala
@@ -1374,7 +1374,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
checkAnswer(spark.table("old"), Row(1, "a"))

val expectedSchema = StructType(Seq(
StructField("i", IntegerType, nullable = true),
StructField("i", IntegerType, nullable = false),
StructField("j", StringType, nullable = true)))
assert(table("old").schema === expectedSchema)
}
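What the updated expectation encodes, as a minimal sketch (assuming a Hive-enabled session with `spark.implicits._` imported): a Scala Int column is non-nullable in the DataFrame schema, and after this patch that flag survives saveAsTable plus a reload through the metastore.

  import spark.implicits._

  // `i` is non-nullable because it comes from a Scala Int.
  Seq((1, "a")).toDF("i", "j").write.format("parquet").saveAsTable("old")
  assert(!spark.table("old").schema("i").nullable)  // was nullable = true before this patch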