Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,28 @@ object RowEncoder {
ClassTag(cls))
}

/**
* Returns a copy of the given ExpressionEncoder whose deserializer tolerates null
* top-level rows.
*
* When `exprEnc.objSerializer` is nullable, the original `objDeserializer` is wrapped in an
* `If` expression that yields a null literal (of the deserializer's data type) whenever
* `GetColumnByOrdinal(0, ...)` is null, and falls back to the original deserializer
* otherwise. When the serializer is not nullable, the deserializer is kept as is.
*
* @param exprEnc an ExpressionEncoder[Row].
* @return an ExpressionEncoder[Row] whose deserializer supports null values.
*
* @see SPARK-37829
*/
private[sql] def nullSafe(exprEnc: ExpressionEncoder[Row]): ExpressionEncoder[Row] = {
// Reference to ordinal 0, typed with the data type produced by the serializer.
val newDeserializerInput = GetColumnByOrdinal(0, exprEnc.objSerializer.dataType)
// NOTE(review): the guard reads the serializer's nullability while the null test below is
// applied to the deserializer's input column; the link between the two is not evident
// from this method alone -- confirm.
val newDeserializer: Expression = if (exprEnc.objSerializer.nullable) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry the code here is a bit confusing. We check exprEnc.objSerializer.nullable and then we construct IsNull(newDeserializerInput)? What's their connection?

If(
IsNull(newDeserializerInput),
Literal.create(null, exprEnc.objDeserializer.dataType),
exprEnc.objDeserializer)
} else {
exprEnc.objDeserializer
}
exprEnc.copy(
objDeserializer = newDeserializer
)
}

private def serializerFor(
inputObject: Expression,
inputType: DataType): Expression = inputType match {
Expand Down
16 changes: 14 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1173,8 +1173,20 @@ class Dataset[T] private[sql](
joined = resolveSelfJoinCondition(joined)
}

implicit val tuple2Encoder: Encoder[(T, U)] =
ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
// SPARK-37829: an outer join relies on null to represent the side that has no match.
// As we might be running on DataFrames, we need a custom encoder that will properly
// handle null top-level Rows.
//
// Encoders whose class tag is not Row are returned unchanged. Row encoders are routed
// through RowEncoder.nullSafe; the two casts are needed because V is only known to be
// Row from the runtime class check, not from the static type.
def nullSafe[V](exprEnc: ExpressionEncoder[V]): ExpressionEncoder[V] = {
if (exprEnc.clsTag.runtimeClass != classOf[Row]) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks a bit ugly.

I've tried simply wrapping CreateExternalRow with a null check and a number of tests started failing as they were assuming top-level rows couldn't be null.

Are they UT or end-to-end tests? If they are UT, we can simply update the tests because we have changed the assumption.

exprEnc
} else {
RowEncoder.nullSafe(exprEnc.asInstanceOf[ExpressionEncoder[Row]])
.asInstanceOf[ExpressionEncoder[V]]
}
}
implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(
nullSafe(this.exprEnc), nullSafe(other.exprEnc)
)

val leftResultExpr = {
if (!this.exprEnc.isSerializedAsStructForTopLevel) {
Expand Down
44 changes: 44 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.TestUtils.withListener
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
Expand Down Expand Up @@ -2147,6 +2148,49 @@ class DatasetSuite extends QueryTest
(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13))
}

test("SPARK-37829: DataFrame outer join") {
  // Both join sides are expected to carry the same two-field struct:
  // a nullable string "a" and a non-nullable integer "b".
  val sideFields = Seq(
    StructField("a", StringType),
    StructField("b", IntegerType, nullable = false))
  val leftRowSchema = StructType(sideFields)
  val rightRowSchema = StructType(sideFields)

  val leftDf = Seq(ClassData("a", 1), ClassData("b", 2)).toDF().as("left")
  val rightDf = Seq(ClassData("x", 2), ClassData("y", 3)).toDF().as("right")
  val joinedPairs = leftDf.joinWith(rightDf, $"left.b" === $"right.b", "left")

  // This is a left join, so only the right-hand output ("_2") is nullable.
  val expectedPairSchema = StructType(Seq(
    StructField("_1", leftRowSchema, nullable = false),
    StructField("_2", rightRowSchema)))
  assert(joinedPairs.schema === expectedPairSchema)

  // Left row with b = 1 has no match on the right; left row with b = 2 pairs with "x".
  val unmatchedLeft = new GenericRowWithSchema(Array("a", 1), leftRowSchema)
  val matchedLeft = new GenericRowWithSchema(Array("b", 2), leftRowSchema)
  val matchedRight = new GenericRowWithSchema(Array("x", 2), rightRowSchema)

  val actualPairs = joinedPairs.collect().toSet
  val expectedPairs = Set(
    unmatchedLeft -> null,
    matchedLeft -> matchedRight)
  assert(actualPairs == expectedPairs)
}
}

case class Bar(a: Int)
Expand Down