From ce6295bf51d4171c6556ceb56cb31402085ef4b9 Mon Sep 17 00:00:00 2001
From: Clement de Groc
Date: Fri, 7 Jan 2022 13:27:27 +0100
Subject: [PATCH 1/2] [SPARK-37829][SQL][TESTS] Add a test demonstrating
 joinWith outer join issue

---
 .../org/apache/spark/sql/DatasetSuite.scala   | 44 +++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 2ce0754a5d1e..1b48dd0c0122 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.TestUtils.withListener
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
@@ -2147,6 +2148,49 @@ class DatasetSuite extends QueryTest
       (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
       (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13))
   }
+
+  test("SPARK-37829: DataFrame outer join") {
+    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF().as("left")
+    val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF().as("right")
+    val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
+
+    val leftFieldSchema = StructType(
+      Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType, nullable = false)
+      )
+    )
+    val rightFieldSchema = StructType(
+      Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType, nullable = false)
+      )
+    )
+    val expectedSchema = StructType(
+      Seq(
+        StructField(
+          "_1",
+          leftFieldSchema,
+          nullable = false
+        ),
+        // This is a left join, so the right output is nullable:
+        StructField(
+          "_2",
+          rightFieldSchema
+        )
+      )
+    )
+    assert(joined.schema === expectedSchema)
+
+    val result = joined.collect().toSet
+    val expected = Set(
+      new GenericRowWithSchema(Array("a", 1), leftFieldSchema) ->
+        null,
+      new GenericRowWithSchema(Array("b", 2), leftFieldSchema) ->
+        new GenericRowWithSchema(Array("x", 2), rightFieldSchema)
+    )
+    assert(result == expected)
+  }
 }
 
 case class Bar(a: Int)

From 3831f769b588f4deb49fe8a4fdad78f35cc995d5 Mon Sep 17 00:00:00 2001
From: Clement de Groc
Date: Mon, 17 Jan 2022 11:49:00 +0100
Subject: [PATCH 2/2] [SPARK-37829][SQL] Add if(isnull ...) check for
 DataFrame.joinWith

Wrap tuple field deserializers in null checks when joinWith is called on
DataFrames, as top-level rows are not nullable and won't propagate null
values.
---
 .../sql/catalyst/encoders/RowEncoder.scala    | 22 +++++++++++++++++++
 .../scala/org/apache/spark/sql/Dataset.scala  | 16 ++++++++++++--
 2 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index d34d9531c3f3..36e31fcc9352 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -77,6 +77,28 @@ object RowEncoder {
       ClassTag(cls))
   }
 
+  /**
+   * Returns an ExpressionEncoder allowing null top-level rows.
+   * @param exprEnc an ExpressionEncoder[Row].
+   * @return an ExpressionEncoder[Row] whose deserializer supports null values.
+   *
+   * @see SPARK-37829
+   */
+  private[sql] def nullSafe(exprEnc: ExpressionEncoder[Row]): ExpressionEncoder[Row] = {
+    val newDeserializerInput = GetColumnByOrdinal(0, exprEnc.objSerializer.dataType)
+    val newDeserializer: Expression = if (exprEnc.objSerializer.nullable) {
+      If(
+        IsNull(newDeserializerInput),
+        Literal.create(null, exprEnc.objDeserializer.dataType),
+        exprEnc.objDeserializer)
+    } else {
+      exprEnc.objDeserializer
+    }
+    exprEnc.copy(
+      objDeserializer = newDeserializer
+    )
+  }
+
   private def serializerFor(
       inputObject: Expression,
       inputType: DataType): Expression = inputType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 9dd38d850e32..6de24e6b987e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1173,8 +1173,20 @@ class Dataset[T] private[sql](
       joined = resolveSelfJoinCondition(joined)
     }
 
-    implicit val tuple2Encoder: Encoder[(T, U)] =
-      ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
+    // SPARK-37829: an outer join requires null semantics to represent missing keys.
+    // As we might be running on DataFrames, we need a custom encoder that will properly
+    // handle null top-level Rows.
+    def nullSafe[V](exprEnc: ExpressionEncoder[V]): ExpressionEncoder[V] = {
+      if (exprEnc.clsTag.runtimeClass != classOf[Row]) {
+        exprEnc
+      } else {
+        RowEncoder.nullSafe(exprEnc.asInstanceOf[ExpressionEncoder[Row]])
+          .asInstanceOf[ExpressionEncoder[V]]
+      }
+    }
+    implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(
+      nullSafe(this.exprEnc), nullSafe(other.exprEnc)
+    )
 
     val leftResultExpr = {
       if (!this.exprEnc.isSerializedAsStructForTopLevel) {
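
---
For context, a minimal spark-shell sketch of the behaviour the new test pins
down. It assumes a running session bound to spark with implicits imported; the
column names and sample data are illustrative, not taken from the patches.

    import spark.implicits._

    val left = Seq(("a", 1), ("b", 2)).toDF("a", "b").as("left")
    val right = Seq(("x", 2), ("y", 3)).toDF("a", "b").as("right")

    // joinWith keeps each side whole, yielding a Dataset of (Row, Row) pairs
    // rather than flattening the columns like a plain join would.
    val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
    joined.collect().foreach(println)

    // With these patches applied, the unmatched left row pairs with null:
    //   ([a,1],null)
    //   ([b,2],[x,2])
    // Per SPARK-37829, the unmatched side previously deserialized to a Row of
    // null fields, e.g. ([a,1],[null,null]), because the top-level Row
    // deserializer had no null check.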