From 94d99ae46c9bc7375242c23bd99b1ad442a84efb Mon Sep 17 00:00:00 2001
From: --global
Date: Wed, 19 Apr 2023 22:05:04 +0800
Subject: [PATCH] [SPARK-37829][SQL] Dataframe.joinWith outer-join should
 return a null value for unmatched row
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What changes were proposed in this pull request?

When doing an outer join with `joinWith` on DataFrames, unmatched rows are returned as Row objects with null fields instead of a single null value. This is not the expected behavior; it is a regression introduced in [this commit](https://github.com/apache/spark/commit/cd92f25be5a221e0d4618925f7bc9dfd3bb8cb59).

This pull request fixes the regression. Note that it is not a full rollback of that commit: the removed "schema" variable is not added back.

```
case class ClassData(a: String, b: Int)
val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF
val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF

left.joinWith(right, left("b") === right("b"), "left_outer").collect
```

```
Wrong results (current behavior): Array(([a,1],[null,null]), ([b,2],[x,2]))
Correct results:                  Array(([a,1],null),        ([b,2],[x,2]))
```

### Why are the changes needed?

To address the regression described above: it changed the behavior of the DataFrame `joinWith` API between versions 2.4.8 and 3.0.0+, and can cause data correctness issues for users who rely on the old behavior when running on Spark 3.0.0+.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added a unit test (the same test as in the previous [closed pull request](https://github.com/apache/spark/pull/35140), credit to Clément de Groc).

Ran the sql/core and sql/catalyst modules locally with `./build/mvn clean package -pl sql/core,sql/catalyst`.

Closes #40755 from kings129/encoder_bug_fix.
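
As an illustration only (not part of the patch): a minimal spark-shell sketch of how callers can rely on the corrected behavior, reusing `ClassData`, `left`, and `right` exactly as defined in the reproduction above. With this fix the unmatched right side is a plain `null`, so a simple null check replaces probing for all-null fields.

```
// Sketch only: assumes a spark-shell session where ClassData, left and right
// are defined exactly as in the reproduction above.
val joined = left.joinWith(right, left("b") === right("b"), "left_outer")
val rows = joined.collect()   // Array[(Row, Row)]

// With the fix, the right side of an unmatched row is a single null value,
// so a plain null check works; before the fix it was a Row of null fields.
val unmatched = rows.filter { case (_, r) => r == null }
// unmatched: Array(([a,1],null))

val matchedPairs = rows.collect { case (l, r) if r != null => (l.getInt(1), r.getInt(1)) }
// matchedPairs: Array((2,2))
```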
Authored-by: --global
Signed-off-by: Wenchen Fan
---
 .../catalyst/encoders/ExpressionEncoder.scala | 19 +++++---
 .../org/apache/spark/sql/DatasetSuite.scala   | 45 +++++++++++++++++++
 2 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index c97dfe1970c1a..e6477e48fe967 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -110,22 +110,29 @@ object ExpressionEncoder {
     }
     val newSerializer = CreateStruct(serializers)
 
+    def nullSafe(input: Expression, result: Expression): Expression = {
+      If(IsNull(input), Literal.create(null, result.dataType), result)
+    }
+
     val newDeserializerInput = GetColumnByOrdinal(0, newSerializer.dataType)
-    val deserializers = encoders.zipWithIndex.map { case (enc, index) =>
+    val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
       val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct
       assert(getColExprs.size == 1, "object deserializer should have only one " +
         s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")
 
       val input = GetStructField(newDeserializerInput, index)
-      enc.objDeserializer.transformUp {
+      val childDeserializer = enc.objDeserializer.transformUp {
         case GetColumnByOrdinal(0, _) => input
       }
-    }
-    val newDeserializer = NewInstance(cls, deserializers, ObjectType(cls), propagateNull = false)
 
-    def nullSafe(input: Expression, result: Expression): Expression = {
-      If(IsNull(input), Literal.create(null, result.dataType), result)
+      if (enc.objSerializer.nullable) {
+        nullSafe(input, childDeserializer)
+      } else {
+        childDeserializer
+      }
     }
+    val newDeserializer =
+      NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
 
     new ExpressionEncoder[Any](
       nullSafe(newSerializerInput, newSerializer),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index f5e736621ebbe..43322b6dc9725 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.TestUtils.withListener
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
@@ -2152,6 +2153,50 @@ class DatasetSuite extends QueryTest
       assert(parquetFiles.size === 10)
     }
   }
+
+  test("SPARK-37829: DataFrame outer join") {
+    // Same as "SPARK-15441: Dataset outer join" but using DataFrames instead of Datasets
+    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF().as("left")
+    val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF().as("right")
+    val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
+
+    val leftFieldSchema = StructType(
+      Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType, nullable = false)
+      )
+    )
+    val rightFieldSchema = StructType(
+      Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType, nullable = false)
+      )
+    )
+    val expectedSchema = StructType(
+      Seq(
+        StructField(
+          "_1",
+          leftFieldSchema,
+          nullable = false
+        ),
+        // This is a left join, so the right output is nullable:
+        StructField(
+          "_2",
+          rightFieldSchema
+        )
+      )
+    )
+    assert(joined.schema === expectedSchema)
+
+    val result = joined.collect().toSet
+    val expected = Set(
+      new GenericRowWithSchema(Array("a", 1), leftFieldSchema) ->
+        null,
+      new GenericRowWithSchema(Array("b", 2), leftFieldSchema) ->
+        new GenericRowWithSchema(Array("x", 2), rightFieldSchema)
+    )
+    assert(result == expected)
+  }
 }
 
 case class Bar(a: Int)
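
For context (illustrative only, not part of the patch): the test above mirrors the existing "SPARK-15441: Dataset outer join" test. Below is a minimal sketch of the typed Dataset counterpart, assuming a spark-shell session with `ClassData` defined as in the PR description and `spark.implicits._` imported; the names `leftDs` and `rightDs` are introduced here for illustration. Typed Datasets already returned a plain null for the unmatched side, and this change makes the untyped DataFrame path consistent with it.

```
// Sketch only: typed Dataset variant of the same join, assuming ClassData is
// defined as in the PR description and spark.implicits._ is in scope.
val leftDs = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
val rightDs = Seq(ClassData("x", 2), ClassData("y", 3)).toDS()

// Dataset.joinWith already yields null for the unmatched right side
// (SPARK-15441); after this patch the DataFrame path behaves the same way.
leftDs.joinWith(rightDs, leftDs("b") === rightDs("b"), "left_outer").collect()
// Array((ClassData(a,1),null), (ClassData(b,2),ClassData(x,2)))
```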