diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index 056d435eecd23..6ed822dc70d68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -179,23 +179,38 @@ private[sql] object SparkSqlSerializer2 {
 
   /**
    * Check if rows with the given schema can be serialized with ShuffleSerializer.
+   * Right now, we do not support a schema that has complex types or UDTs, or one whose
+   * fields are all of NullType.
    */
   def support(schema: Array[DataType]): Boolean = {
     if (schema == null) return true
 
+    var allNullTypes = true
     var i = 0
     while (i < schema.length) {
       schema(i) match {
-        case udt: UserDefinedType[_] => return false
-        case array: ArrayType => return false
-        case map: MapType => return false
-        case struct: StructType => return false
+        case NullType => // Do nothing
+        case udt: UserDefinedType[_] =>
+          allNullTypes = false
+          return false
+        case array: ArrayType =>
+          allNullTypes = false
+          return false
+        case map: MapType =>
+          allNullTypes = false
+          return false
+        case struct: StructType =>
+          allNullTypes = false
+          return false
         case _ =>
+          allNullTypes = false
       }
       i += 1
     }
 
-    return true
+    // If the types of the fields are all NullTypes, we return false.
+    // Otherwise, we return true.
+    return !allNullTypes
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
index 8631e247c6c05..71f6b26bcd01a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
@@ -42,7 +42,6 @@ class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite {
   }
 
   checkSupported(null, isSupported = true)
-  checkSupported(NullType, isSupported = true)
   checkSupported(BooleanType, isSupported = true)
   checkSupported(ByteType, isSupported = true)
   checkSupported(ShortType, isSupported = true)
@@ -57,6 +56,8 @@ class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite {
   checkSupported(DecimalType(10, 5), isSupported = true)
   checkSupported(DecimalType.Unlimited, isSupported = true)
 
+  // If NullType is the only data type in the schema, we do not support it.
+  checkSupported(NullType, isSupported = false)
   // For now, ArrayType, MapType, and StructType are not supported.
   checkSupported(ArrayType(DoubleType, true), isSupported = false)
   checkSupported(ArrayType(StringType, false), isSupported = false)
@@ -170,6 +171,23 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
     val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer])
   }
+
+  test("types of fields are all NullTypes") {
+    // Test range partitioning code path.
+    val nulls = ctx.sql(s"SELECT null as a, null as b, null as c")
+    val df = nulls.unionAll(nulls).sort("a")
+    checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer])
+    checkAnswer(
+      df,
+      Row(null, null, null) :: Row(null, null, null) :: Nil)
+
+    // Test hash partitioning code path.
+    val oneRow = ctx.sql(s"SELECT DISTINCT null, null, null FROM shuffle")
+    checkSerializer(oneRow.queryExecution.executedPlan, classOf[SparkSqlSerializer])
+    checkAnswer(
+      oneRow,
+      Row(null, null, null))
+  }
 }

 /** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
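For review context, here is a minimal standalone sketch of the predicate the first hunk implements. This is illustration only, not part of the patch: the object name SupportCheckSketch and its sample assertions are hypothetical, and the collection-based style mirrors the patched while-loop without the allNullTypes bookkeeping, while returning the same results.

    import org.apache.spark.sql.types._

    object SupportCheckSketch {
      // Same contract as the patched SparkSqlSerializer2.support(): a null schema
      // is fine, complex types and UDTs are rejected, and a schema whose fields
      // are all NullType is rejected as well.
      def support(schema: Array[DataType]): Boolean = {
        if (schema == null) return true
        val noComplexTypes = schema.forall {
          case _: UserDefinedType[_] | _: ArrayType | _: MapType | _: StructType => false
          case _ => true
        }
        // At least one non-NullType field must remain for Serializer2 to be used.
        noComplexTypes && schema.exists(_ != NullType)
      }

      def main(args: Array[String]): Unit = {
        assert(support(null))                                  // no schema: supported
        assert(support(Array(IntegerType, NullType)))          // mixed types: supported
        assert(!support(Array(NullType, NullType, NullType)))  // all NullTypes: unsupported
        assert(!support(Array(ArrayType(DoubleType, true))))   // complex type: unsupported
      }
    }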
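And a sketch of exercising the range-partitioning fallback end to end, assuming the Spark 1.4-era API used elsewhere in this patch (NullTypeFallbackDemo is a hypothetical name; the query is the same one the new test runs):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    object NullTypeFallbackDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("NullTypeFallbackDemo").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)

        // All three columns are NullType, so support() returns false and the
        // sort-induced exchange uses SparkSqlSerializer rather than
        // SparkSqlSerializer2, which is what checkSerializer asserts above.
        val nulls = sqlContext.sql("SELECT null AS a, null AS b, null AS c")
        nulls.unionAll(nulls).sort("a").show()

        sc.stop()
      }
    }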