diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index c97dfe1970c1a..94bd1c9b7da30 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -87,9 +87,13 @@ object ExpressionEncoder {
 
     encoders.foreach(_.assertUnresolved())
 
+    val schema = StructType(encoders.zipWithIndex.map {
+      case (e, i) =>
+        StructField(s"_${i + 1}", e.objSerializer.dataType, e.objSerializer.nullable)
+    })
+
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
 
-    val newSerializerInput = BoundReference(0, ObjectType(cls), nullable = true)
     val serializers = encoders.zipWithIndex.map { case (enc, index) =>
       val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct
       assert(boundRefs.size == 1, "object serializer should have only one bound reference but " +
@@ -97,39 +101,42 @@ object ExpressionEncoder {
 
       val originalInputObject = boundRefs.head
       val newInputObject = Invoke(
-        newSerializerInput,
+        BoundReference(0, ObjectType(cls), nullable = true),
         s"_${index + 1}",
         originalInputObject.dataType,
         returnNullable = originalInputObject.nullable)
 
       val newSerializer = enc.objSerializer.transformUp {
-        case BoundReference(0, _, _) => newInputObject
+        case b: BoundReference => newInputObject
       }
 
       Alias(newSerializer, s"_${index + 1}")()
     }
-    val newSerializer = CreateStruct(serializers)
 
-    val newDeserializerInput = GetColumnByOrdinal(0, newSerializer.dataType)
-    val deserializers = encoders.zipWithIndex.map { case (enc, index) =>
+    val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
       val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct
       assert(getColExprs.size == 1, "object deserializer should have only one " +
         s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")
 
-      val input = GetStructField(newDeserializerInput, index)
-      enc.objDeserializer.transformUp {
+      val input = GetStructField(GetColumnByOrdinal(0, schema), index)
+      val newDeserializer = enc.objDeserializer.transformUp {
         case GetColumnByOrdinal(0, _) => input
       }
+      if (schema(index).nullable) {
+        If(IsNull(input), Literal.create(null, newDeserializer.dataType), newDeserializer)
+      } else {
+        newDeserializer
+      }
     }
-    val newDeserializer = NewInstance(cls, deserializers, ObjectType(cls), propagateNull = false)
 
-    def nullSafe(input: Expression, result: Expression): Expression = {
-      If(IsNull(input), Literal.create(null, result.dataType), result)
-    }
+    val serializer = If(IsNull(BoundReference(0, ObjectType(cls), nullable = true)),
+      Literal.create(null, schema), CreateStruct(serializers))
+    val deserializer =
+      NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
 
     new ExpressionEncoder[Any](
-      nullSafe(newSerializerInput, newSerializer),
-      nullSafe(newDeserializerInput, newDeserializer),
+      serializer,
+      deserializer,
       ClassTag(cls))
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 2ce0754a5d1e7..2723122f25c17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.TestUtils.withListener
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
@@ -2147,6 +2148,50 @@ class DatasetSuite extends QueryTest
       (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
       (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13))
   }
+
+  test("SPARK-37829: DataFrame outer join") {
+    // Same as "SPARK-15441: Dataset outer join" but using DataFrames instead of Datasets
+    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF().as("left")
+    val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF().as("right")
+    val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
+
+    val leftFieldSchema = StructType(
+      Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType, nullable = false)
+      )
+    )
+    val rightFieldSchema = StructType(
+      Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType, nullable = false)
+      )
+    )
+    val expectedSchema = StructType(
+      Seq(
+        StructField(
+          "_1",
+          leftFieldSchema,
+          nullable = false
+        ),
+        // This is a left join, so the right output is nullable:
+        StructField(
+          "_2",
+          rightFieldSchema
+        )
+      )
+    )
+    assert(joined.schema === expectedSchema)
+
+    val result = joined.collect().toSet
+    val expected = Set(
+      new GenericRowWithSchema(Array("a", 1), leftFieldSchema) ->
+        null,
+      new GenericRowWithSchema(Array("b", 2), leftFieldSchema) ->
+        new GenericRowWithSchema(Array("x", 2), rightFieldSchema)
+    )
+    assert(result == expected)
+  }
 }
 
 case class Bar(a: Int)
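
Note on behavior (not part of the patch): ExpressionEncoder.tuple now applies the null check per tuple field instead of wrapping the whole deserializer, so a left outer joinWith on plain DataFrames yields null for an unmatched side rather than a Row whose fields are all null. A minimal sketch of the scenario the new test exercises, assuming an active SparkSession with spark.implicits._ imported and the ClassData case class (ClassData(a: String, b: Int)) from DatasetSuite:

    // Sketch only; mirrors the "SPARK-37829: DataFrame outer join" test above.
    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF().as("left")
    val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF().as("right")
    // joinWith keeps each side of the join as a single struct column (_1, _2).
    val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
    joined.collect()
    // With this fix the unmatched right side deserializes to null,
    // e.g. ([a,1], null) and ([b,2], [x,2]), rather than ([a,1], [null,null]).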