sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -125,12 +125,13 @@ object ExpressionEncoder {
         }
       } else {
         val input = BoundReference(index, enc.schema, nullable = true)
-        enc.deserializer.transformUp {
+        val deserialized = enc.deserializer.transformUp {
           case UnresolvedAttribute(nameParts) =>
             assert(nameParts.length == 1)
             UnresolvedExtractValue(input, Literal(nameParts.head))
           case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
         }
+        If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized)
       }
     }

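In effect, each tuple element's deserializer is now wrapped in a null guard on its entire input struct. A plain-Scala analogy of the new behavior (a hedged sketch; `deserializeElement` and the `Option` input are illustrative stand-ins, not the generated Catalyst code):

```scala
// Stand-in for the added If(IsNull(input), Literal.create(null, ...), deserialized)
// guard: a null input row for one tuple element now yields a null object,
// instead of an object assembled from the fields of a null row.
case class ClassData(a: String, b: Int)

def deserializeElement(row: Option[(String, Int)]): ClassData = row match {
  case None         => null                // whole side is null, e.g. an outer-join miss
  case Some((a, b)) => ClassData(a, b)     // previous behavior: always build the object
}
```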
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.objects

 import java.lang.reflect.Modifier
 
-import scala.annotation.tailrec
 import scala.language.existentials
 import scala.reflect.ClassTag

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala (49 additions, 18 deletions)
@@ -746,31 +746,62 @@ class Dataset[T] private[sql](
    */
   @Experimental
   def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
-    val left = this.logicalPlan
-    val right = other.logicalPlan
-
-    val joined = sparkSession.sessionState.executePlan(Join(left, right, joinType =
-      JoinType(joinType), Some(condition.expr)))
-    val leftOutput = joined.analyzed.output.take(left.output.length)
-    val rightOutput = joined.analyzed.output.takeRight(right.output.length)
-
-    val leftData = this.unresolvedTEncoder match {
-      case e if e.flat => Alias(leftOutput.head, "_1")()
-      case _ => Alias(CreateStruct(leftOutput), "_1")()
-    }
-    val rightData = other.unresolvedTEncoder match {
-      case e if e.flat => Alias(rightOutput.head, "_2")()
-      case _ => Alias(CreateStruct(rightOutput), "_2")()
-    }
+    // Create a Join node and resolve it first, to get the join condition resolved, self-joins
+    // resolved, etc.
+    val joined = sparkSession.sessionState.executePlan(
+      Join(
+        this.logicalPlan,
+        other.logicalPlan,
+        JoinType(joinType),
+        Some(condition.expr))).analyzed.asInstanceOf[Join]
+
+    // For both join sides, combine all outputs into a single column and alias it with "_1" or
+    // "_2", to match the schema for the encoder of the join result.
+    // Note that we do this before joining them, to enable the join operator to return null for
+    // one side, in cases like outer-join.
+    val left = {
+      val combined = if (this.unresolvedTEncoder.flat) {
+        assert(joined.left.output.length == 1)
+        Alias(joined.left.output.head, "_1")()
+      } else {
+        Alias(CreateStruct(joined.left.output), "_1")()
+      }
+      Project(combined :: Nil, joined.left)
+    }
+
+    val right = {
+      val combined = if (other.unresolvedTEncoder.flat) {
+        assert(joined.right.output.length == 1)
+        Alias(joined.right.output.head, "_2")()
+      } else {
+        Alias(CreateStruct(joined.right.output), "_2")()
+      }
+      Project(combined :: Nil, joined.right)
+    }
+
+    // Rewrite the join condition so that each attribute points to the correct column/field, now
+    // that we have combined the outputs of each join side.
+    val conditionExpr = joined.condition.get transformUp {
+      case a: Attribute if joined.left.outputSet.contains(a) =>
+        if (this.unresolvedTEncoder.flat) {
+          left.output.head
+        } else {
+          val index = joined.left.output.indexWhere(_.exprId == a.exprId)
+          GetStructField(left.output.head, index)
+        }
+      case a: Attribute if joined.right.outputSet.contains(a) =>
+        if (other.unresolvedTEncoder.flat) {
+          right.output.head
+        } else {
+          val index = joined.right.output.indexWhere(_.exprId == a.exprId)
+          GetStructField(right.output.head, index)
+        }
+    }
 
     implicit val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
 
-    withTypedPlan {
-      Project(
-        leftData :: rightData :: Nil,
-        joined.analyzed)
-    }
+    withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr)))
   }
 
   /**
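The rewrite above is easiest to read as a change in plan shape. A hedged sketch in comments (`leftPlan`, `rightPlan`, and `cond` are illustrative names, not identifiers from the patch):

```scala
// Before: join first, then pull struct columns out of the joined output. On an
// outer join, the unmatched side surfaces as null *fields* inside a non-null object:
//
//   Project(struct(leftOutput) AS _1 :: struct(rightOutput) AS _2 :: Nil,
//     Join(leftPlan, rightPlan, cond))
//
// After: collapse each side into a single struct column first, then join, so the
// join operator itself can emit a null struct -- and hence a null object -- for
// the unmatched side:
//
//   Join(
//     Project(struct(leftOutput) AS _1 :: Nil, leftPlan),
//     Project(struct(rightOutput) AS _2 :: Nil, rightPlan),
//     cond, with each attribute a rewritten to GetStructField(_1 or _2, index of a))
```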
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala (8 additions, 15 deletions)
@@ -253,21 +253,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       (1, 1), (2, 2))
   }
 
-  test("joinWith, expression condition, outer join") {
-    val nullInteger = null.asInstanceOf[Integer]
-    val nullString = null.asInstanceOf[String]
-    val ds1 = Seq(ClassNullableData("a", 1),
-      ClassNullableData("c", 3)).toDS()
-    val ds2 = Seq(("a", new Integer(1)),
-      ("b", new Integer(2))).toDS()
-
-    checkDataset(
-      ds1.joinWith(ds2, $"_1" === $"a", "outer"),
-      (ClassNullableData("a", 1), ("a", new Integer(1))),
-      (ClassNullableData("c", 3), (nullString, nullInteger)),
-      (ClassNullableData(nullString, nullInteger), ("b", new Integer(2))))
-  }
-
   test("joinWith tuple with primitive, expression") {
     val ds1 = Seq(1, 1, 2).toDS()
     val ds2 = Seq(("a", 1), ("b", 2)).toDS()
@@ -783,6 +768,14 @@
       ds.filter(_.b > 1).collect().toSeq
     }
   }
+
+  test("SPARK-15441: Dataset outer join") {
+    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left")
+    val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right")
+    val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
+    val result = joined.collect().toSet
+    assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) -> ClassData("x", 2)))
+  }
 }
 
 case class Generic[T](id: T, value: Double)
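The removed outer-join test above pinned the old behavior: an unmatched side came back as an object with null fields, e.g. `(nullString, nullInteger)`. The new test asserts a null object instead. A spark-shell sketch of the fixed behavior (assumes `spark.implicits._` is in scope and `ClassData` is defined as in the suite; the exact collected rendering may differ):

```scala
case class ClassData(a: String, b: Int)

val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left")
val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right")

left.joinWith(right, $"left.b" === $"right.b", "left").collect()
// Expected contents, per the test above:
//   (ClassData(a,1), null)              <- left row with no match pairs with a null object
//   (ClassData(b,2), ClassData(x,2))
```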