From b2ea6aa502118bc63dbef5d477f81f5519e622a1 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 15 Dec 2015 01:43:08 +0800 Subject: [PATCH 1/9] Makes BoundReference respect nullability --- .../catalyst/expressions/BoundAttribute.scala | 27 ++++- .../codegen/GenerateProjection.scala | 4 +- .../apache/spark/sql/types/StructType.scala | 8 +- .../org/apache/spark/sql/DatasetSuite.scala | 106 ++++++++++++++++++ 4 files changed, 135 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index ff1f28ddbbf3..4728277cd4ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -69,10 +69,29 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) - s""" - boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); - """ + + if (nullable) { + s""" + boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); + $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + """ + } else { + s""" + boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); + $javaType ${ev.value}; + if (!${ev.isNull}) { + ${ev.value} = ($value); + } else { + throw new RuntimeException( + "Null value appeared in non-nullable field: " + + "ordinal=$ordinal, dataType=${dataType.simpleString}. " + + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + + "please try to use scala.Option[_] or other nullable types " + + "(e.g. java.lang.Integer instead of int/scala.Int)." + ); + } + """ + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index f229f2000d8e..9355a9f37b6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -42,7 +42,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) - // Make Mutablility optional... + // Make Mutability optional... 
protected def create(expressions: Seq[Expression]): Projection = { val ctx = newCodeGenContext() val columns = expressions.zipWithIndex.map { @@ -65,7 +65,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { """ }.mkString("\n") - val getCases = (0 until expressions.size).map { i => + val getCases = expressions.indices.map { i => s"case $i: return c$i;" }.mkString("\n") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 9778df271ddd..2b6a200dcbfc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -358,10 +358,10 @@ object StructType extends AbstractDataType { case leftField @ StructField(leftName, leftType, leftNullable, _) => rightMapped.get(leftName) .map { case rightField @ StructField(_, rightType, rightNullable, _) => - leftField.copy( - dataType = merge(leftType, rightType), - nullable = leftNullable || rightNullable) - } + leftField.copy( + dataType = merge(leftType, rightType), + nullable = leftNullable || rightNullable) + } .orElse(Some(leftField)) .foreach(newFields += _) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8f8db318261d..2d2a3b253880 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -23,6 +23,7 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ class DatasetSuite extends QueryTest with SharedSQLContext { @@ -489,12 +490,117 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage) } + + def testNonNullable[T: Encoder](name: String, schema: StructType, rows: Row*): Unit = { + test(s"non-nullable field - $name") { + val rowRDD = sqlContext.sparkContext.parallelize(rows) + val ds = sqlContext.createDataFrame(rowRDD, schema).as[T] + val message = intercept[RuntimeException](ds.collect()).getMessage + assert(message.contains("Null value appeared in non-nullable field")) + } + } + + testNonNullable[ClassData]( + "scala.Int", + StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", IntegerType, nullable = false) + )), + Row("hello", 1: Integer), + Row("world", null) + ) + + testNonNullable[NestedClassData]( + "struct", + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", IntegerType, nullable = false) + )), nullable = false) + )), + Row(Row("hello", 1: Integer)), + Row(null) + ) + + ignore("non-nullable field in nested struct") { + testNonNullable[NestedClassData]( + "non-nullable field in nested struct", + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", IntegerType, nullable = false) + )), nullable = false) + )), + Row(Row("hello", 1: Integer)), + Row(Row("hello", null)) + ) + } + + testNonNullable[NestedNonNullableArray]( + "array", + StructType(Seq( + StructField("a", ArrayType(StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", IntegerType, nullable = false) + ))), nullable = false) + )), + Row(Seq(Row("hello", 
1))), + Row(null) + ) + + ignore("non-nullable element in nested array") { + testNonNullable[NestedNonNullableArray]( + "non-nullable element in nested array", + StructType(Seq( + StructField("a", ArrayType(StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", IntegerType, nullable = false) + )), containsNull = false), nullable = false) + )), + Row(Seq(Row("hello", 1), null)) + ) + } + + testNonNullable[NestedNonNullableMap]( + "map", + StructType(Seq( + StructField("a", MapType( + IntegerType, + StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", IntegerType, nullable = false) + )) + ), nullable = false) + )), + Row(Map(1 -> Row("hello", 1))), + Row(null) + ) + + ignore("non-nullable value in nested map") { + testNonNullable[NestedNonNullableMap]( + "non-nullable value in nested map", + StructType(Seq( + StructField("a", MapType( + IntegerType, + StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", IntegerType, nullable = false) + )) + ), nullable = false) + )), + Row(Map(1 -> Row("hello", 1), 2 -> null)) + ) + } } case class ClassData(a: String, b: Int) case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) +case class NestedClassData(a: ClassData) +case class NestedNonNullableArray(a: Array[ClassData]) +case class NestedNonNullableMap(a: scala.collection.Map[Int, ClassData]) + /** * A class used to test serialization using encoders. This class throws exceptions when using * Java serialization -- so the only way it can be "serialized" is through our encoders. From 7b1600ce8268c517b7c2f9246af6da6b1973b7ef Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 15 Dec 2015 18:41:12 +0800 Subject: [PATCH 2/9] Makes CentralMomentAgg nullable --- .../sql/catalyst/expressions/aggregate/CentralMomentAgg.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index d07d4c338cdf..30f602227b17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -53,7 +53,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def children: Seq[Expression] = Seq(child) - override def nullable: Boolean = false + override def nullable: Boolean = true override def dataType: DataType = DoubleType From b84eb6a52d7e036ce183d547720bce39d7fc0fcd Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 15 Dec 2015 19:15:17 +0800 Subject: [PATCH 3/9] Fixes SPARK-12336 --- .../org/apache/spark/sql/DataFrame.scala | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 497bd4826677..a4e9083a3b60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -498,28 +498,28 @@ class DataFrame private[sql]( def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. 
- val joined = sqlContext.executePlan( + val innerJoined = sqlContext.executePlan( Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] // Project only one of the join columns. - val joinedCols = usingColumns.map(col => withPlan(joined.right).resolve(col)) + val joinedCols = AttributeSet(usingColumns.map(col => withPlan(innerJoined.right).resolve(col))) + val condition = usingColumns.map { col => catalyst.expressions.EqualTo( - withPlan(joined.left).resolve(col), - withPlan(joined.right).resolve(col)) + withPlan(innerJoined.left).resolve(col), + withPlan(innerJoined.right).resolve(col)) }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) => catalyst.expressions.And(cond, eqTo) } withPlan { - Project( - joined.output.filterNot(joinedCols.contains(_)), - Join( - joined.left, - joined.right, - joinType = JoinType(joinType), - condition) - ) + val joined = Join( + innerJoined.left, + innerJoined.right, + joinType = JoinType(joinType), + condition) + + Project(joined.output.filterNot(joinedCols.contains(_)), joined) } } From 02ad8703ff3ce36851d4ce8d027d24a3eece6e78 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 15 Dec 2015 19:21:06 +0800 Subject: [PATCH 4/9] Comments for SPARK-12336 --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index a4e9083a3b60..1cf21b22ea46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -502,6 +502,10 @@ class DataFrame private[sql]( Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] // Project only one of the join columns. + // + // SPARK-12336: For outer joins, attributes of at least one child plan output will be forced to + // be nullable. An `AttributeSet` is necessary so that we are not affected by different + // nullability values. val joinedCols = AttributeSet(usingColumns.map(col => withPlan(innerJoined.right).resolve(col))) val condition = usingColumns.map { col => From aa968e663a8ce618a68035be7c7c5b2294ed5360 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 15 Dec 2015 20:48:43 +0800 Subject: [PATCH 5/9] Addresses comments --- .../org/apache/spark/sql/DataFrame.scala | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 1cf21b22ea46..053bb5da12aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -498,32 +498,29 @@ class DataFrame private[sql]( def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. - val innerJoined = sqlContext.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] + val joined = sqlContext.executePlan( + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None) + ).analyzed.asInstanceOf[Join] // Project only one of the join columns. // // SPARK-12336: For outer joins, attributes of at least one child plan output will be forced to // be nullable. 
An `AttributeSet` is necessary so that we are not affected by different // nullability values. - val joinedCols = AttributeSet(usingColumns.map(col => withPlan(innerJoined.right).resolve(col))) + val joinedCols = AttributeSet(usingColumns.map(col => withPlan(joined.right).resolve(col))) val condition = usingColumns.map { col => catalyst.expressions.EqualTo( - withPlan(innerJoined.left).resolve(col), - withPlan(innerJoined.right).resolve(col)) + withPlan(joined.left).resolve(col), + withPlan(joined.right).resolve(col)) }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) => catalyst.expressions.And(cond, eqTo) } withPlan { - val joined = Join( - innerJoined.left, - innerJoined.right, - joinType = JoinType(joinType), - condition) - - Project(joined.output.filterNot(joinedCols.contains(_)), joined) + Project( + joined.output.filterNot(joinedCols.contains(_)), + joined.copy(condition = condition)) } } From d84478ef201fc60564330a56c10353262547473f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 15 Dec 2015 21:00:15 +0800 Subject: [PATCH 6/9] Fixes SPARK-12341 --- .../scala/org/apache/spark/sql/execution/datasources/ddl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index e7deeff13dc4..e759c011e75d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -42,7 +42,7 @@ case class DescribeCommand( new MetadataBuilder().putString("comment", "name of the column").build())(), AttributeReference("data_type", StringType, nullable = false, new MetadataBuilder().putString("comment", "data type of the column").build())(), - AttributeReference("comment", StringType, nullable = false, + AttributeReference("comment", StringType, nullable = true, new MetadataBuilder().putString("comment", "comment of the column").build())() ) } From 05c36e5cd1ec2c82a5d9fdf2c5750bc5a591b2d7 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 15 Dec 2015 21:10:44 +0800 Subject: [PATCH 7/9] Fixes SPARK-12342 --- .../apache/spark/sql/catalyst/expressions/aggregate/Corr.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 00d7436b710d..d25f3335ffd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -42,7 +41,7 @@ case class Corr( override def children: Seq[Expression] = Seq(left, right) - override def nullable: Boolean = false + override def nullable: Boolean = true override def dataType: DataType = DoubleType From d540b86c307666ff9bb73e8ab436fbd366243183 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 16 Dec 2015 16:57:43 +0800 Subject: [PATCH 8/9] Reverts BoundReference changes --- 
.../catalyst/expressions/BoundAttribute.scala | 27 +---- .../org/apache/spark/sql/DatasetSuite.scala | 106 ------------------ 2 files changed, 4 insertions(+), 129 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4728277cd4ae..ff1f28ddbbf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -69,29 +69,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) - - if (nullable) { - s""" - boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); - """ - } else { - s""" - boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - $javaType ${ev.value}; - if (!${ev.isNull}) { - ${ev.value} = ($value); - } else { - throw new RuntimeException( - "Null value appeared in non-nullable field: " + - "ordinal=$ordinal, dataType=${dataType.simpleString}. " + - "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + - "please try to use scala.Option[_] or other nullable types " + - "(e.g. java.lang.Integer instead of int/scala.Int)." - ); - } - """ - } + s""" + boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); + $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + """ } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 2d2a3b253880..8f8db318261d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -23,7 +23,6 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ class DatasetSuite extends QueryTest with SharedSQLContext { @@ -490,117 +489,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage) } - - def testNonNullable[T: Encoder](name: String, schema: StructType, rows: Row*): Unit = { - test(s"non-nullable field - $name") { - val rowRDD = sqlContext.sparkContext.parallelize(rows) - val ds = sqlContext.createDataFrame(rowRDD, schema).as[T] - val message = intercept[RuntimeException](ds.collect()).getMessage - assert(message.contains("Null value appeared in non-nullable field")) - } - } - - testNonNullable[ClassData]( - "scala.Int", - StructType(Seq( - StructField("a", StringType, nullable = true), - StructField("b", IntegerType, nullable = false) - )), - Row("hello", 1: Integer), - Row("world", null) - ) - - testNonNullable[NestedClassData]( - "struct", - StructType(Seq( - StructField("a", StructType(Seq( - StructField("a", StringType, nullable = true), - StructField("b", IntegerType, nullable = false) - )), nullable = false) - )), - Row(Row("hello", 1: Integer)), - Row(null) - ) - - ignore("non-nullable field in nested struct") { - testNonNullable[NestedClassData]( - "non-nullable field in nested struct", - StructType(Seq( - 
StructField("a", StructType(Seq( - StructField("a", StringType, nullable = true), - StructField("b", IntegerType, nullable = false) - )), nullable = false) - )), - Row(Row("hello", 1: Integer)), - Row(Row("hello", null)) - ) - } - - testNonNullable[NestedNonNullableArray]( - "array", - StructType(Seq( - StructField("a", ArrayType(StructType(Seq( - StructField("a", StringType, nullable = true), - StructField("b", IntegerType, nullable = false) - ))), nullable = false) - )), - Row(Seq(Row("hello", 1))), - Row(null) - ) - - ignore("non-nullable element in nested array") { - testNonNullable[NestedNonNullableArray]( - "non-nullable element in nested array", - StructType(Seq( - StructField("a", ArrayType(StructType(Seq( - StructField("a", StringType, nullable = true), - StructField("b", IntegerType, nullable = false) - )), containsNull = false), nullable = false) - )), - Row(Seq(Row("hello", 1), null)) - ) - } - - testNonNullable[NestedNonNullableMap]( - "map", - StructType(Seq( - StructField("a", MapType( - IntegerType, - StructType(Seq( - StructField("a", StringType, nullable = true), - StructField("b", IntegerType, nullable = false) - )) - ), nullable = false) - )), - Row(Map(1 -> Row("hello", 1))), - Row(null) - ) - - ignore("non-nullable value in nested map") { - testNonNullable[NestedNonNullableMap]( - "non-nullable value in nested map", - StructType(Seq( - StructField("a", MapType( - IntegerType, - StructType(Seq( - StructField("a", StringType, nullable = true), - StructField("b", IntegerType, nullable = false) - )) - ), nullable = false) - )), - Row(Map(1 -> Row("hello", 1), 2 -> null)) - ) - } } case class ClassData(a: String, b: Int) case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) -case class NestedClassData(a: ClassData) -case class NestedNonNullableArray(a: Array[ClassData]) -case class NestedNonNullableMap(a: scala.collection.Map[Int, ClassData]) - /** * A class used to test serialization using encoders. This class throws exceptions when using * Java serialization -- so the only way it can be "serialized" is through our encoders. 
From 7e35e377303d038c8fd26c72fd062e7498ea49ea Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 16 Dec 2015 17:44:34 +0800 Subject: [PATCH 9/9] Test case for SPARK-12336 --- .../org/apache/spark/sql/DataFrameJoinSuite.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index c70397f9853a..8a4f5d84a36d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -43,15 +43,19 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { } test("join - join using multiple columns and specifying join type") { - val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") + val df1 = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + val join1 = df1.join(df2, Seq("int", "str"), "left") + assert(join1.schema.map(_.nullable) === Seq(false, false, true, true)) checkAnswer( - df.join(df2, Seq("int", "str"), "left"), + join1, Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil) + val join2 = df1.join(df2, Seq("int", "str"), "right") + assert(join2.schema.map(_.nullable) === Seq(true, true, true, false)) checkAnswer( - df.join(df2, Seq("int", "str"), "right"), + join2, Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil) }
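
Patch 9 pins down the output nullability that the SPARK-12336 fix (patches 3 through 5) produces for joins declared via usingColumns. The projected output is the left plan's columns followed by the right plan's columns minus the using columns. Wrapping the resolved using columns in an AttributeSet matters because an outer join forces one side's attributes to become nullable, and AttributeSet membership is decided by expression id alone, so the filterNot still drops the duplicated columns even when nullability differs. A short usage sketch of the resulting contract, assuming the suite's toDF implicits are in scope:

    // Int columns from tuples are non-nullable; String columns are nullable.
    val df1 = Seq((1, 2, "1"), (2, 3, "2")).toDF("int", "int2", "str")
    val df2 = Seq((1, 2, "2"), (2, 3, "3")).toDF("int", "int2", "str")

    // Output layout: df1's (int, int2, str), then df2's remaining int2.
    val leftJoined = df1.join(df2, Seq("int", "str"), "left")
    // Left outer join: attributes resolved from the right plan turn nullable.
    assert(leftJoined.schema.map(_.nullable) === Seq(false, false, true, true))

    val rightJoined = df1.join(df2, Seq("int", "str"), "right")
    // Right outer join is symmetric: the left plan's attributes turn nullable.
    assert(rightJoined.schema.map(_.nullable) === Seq(true, true, true, false))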