diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index faa165c298d0..8f7583c48fca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -97,22 +97,29 @@ object ExpressionEncoder {
     }
     val newSerializer = CreateStruct(serializers)
 
+    def nullSafe(input: Expression, result: Expression): Expression = {
+      If(IsNull(input), Literal.create(null, result.dataType), result)
+    }
+
     val newDeserializerInput = GetColumnByOrdinal(0, newSerializer.dataType)
-    val deserializers = encoders.zipWithIndex.map { case (enc, index) =>
+    val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
       val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct
       assert(getColExprs.size == 1, "object deserializer should have only one " +
         s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")
 
       val input = GetStructField(newDeserializerInput, index)
-      enc.objDeserializer.transformUp {
+      val childDeserializer = enc.objDeserializer.transformUp {
         case GetColumnByOrdinal(0, _) => input
       }
-    }
-    val newDeserializer = NewInstance(cls, deserializers, ObjectType(cls), propagateNull = false)
 
-    def nullSafe(input: Expression, result: Expression): Expression = {
-      If(IsNull(input), Literal.create(null, result.dataType), result)
+      if (enc.objSerializer.nullable) {
+        nullSafe(input, childDeserializer)
+      } else {
+        childDeserializer
+      }
     }
+    val newDeserializer =
+      NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
 
     new ExpressionEncoder[Any](
       nullSafe(newSerializerInput, newSerializer),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 4aca7c8a5a66..75cee4078197 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.internal.config.MAX_RESULT_SIZE
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
@@ -2429,6 +2430,50 @@ class DatasetSuite extends QueryTest
       assert(parquetFiles.size === 10)
     }
   }
+
+  test("SPARK-37829: DataFrame outer join") {
+    // Same as "SPARK-15441: Dataset outer join" but using DataFrames instead of Datasets
+    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF().as("left")
+    val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF().as("right")
+    val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
+
+    val leftFieldSchema = StructType(
+      Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType, nullable = false)
+      )
+    )
+    val rightFieldSchema = StructType(
+      Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType, nullable = false)
+      )
+    )
+    val expectedSchema = StructType(
+      Seq(
+        StructField(
+          "_1",
+          leftFieldSchema,
+          nullable = false
+        ),
+        // This is a left join, so the right output is nullable:
+        StructField(
+          "_2",
+          rightFieldSchema
+        )
+      )
+    )
+    assert(joined.schema === expectedSchema)
+
+    val result = joined.collect().toSet
+    val expected = Set(
+      new GenericRowWithSchema(Array("a", 1), leftFieldSchema) ->
+        null,
+      new GenericRowWithSchema(Array("b", 2), leftFieldSchema) ->
+        new GenericRowWithSchema(Array("x", 2), rightFieldSchema)
+    )
+    assert(result == expected)
+  }
 }
 
 class DatasetLargeResultCollectingSuite extends QueryTest