Commit c2b7533
Fix Exchange and initial code gen attempt.
Parent: aa7120e
File tree: 5 files changed (+50, −28 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/Broadcast.scala
(17 additions, 3 deletions)

@@ -23,16 +23,26 @@ import org.apache.spark.broadcast
 import org.apache.spark.rdd.{EmptyRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.ThreadUtils

 /**
  * A broadcaster collects transforms and broadcasts the result of an underlying spark plan.
  *
  * TODO whole stage codegen.
  */
-case class Broadcast(f: Iterable[InternalRow] => Any, child: SparkPlan) extends UnaryNode {
+case class Broadcast(
+    f: Iterable[InternalRow] => Any,
+    child: SparkPlan)
+  extends UnaryNode with CodegenSupport {
+
   override def output: Seq[Attribute] = child.output

+  override private[sql] lazy val metrics = Map(
+    "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")
+  )
+
   val timeout: Duration = {
     val timeoutValue = sqlContext.conf.broadcastTimeout
     if (timeoutValue < 0) {
@@ -66,7 +76,11 @@ case class Broadcast(f: Iterable[InternalRow] => Any, child: SparkPlan) extends
     Await.result(future, timeout)
   }

+  override def upstream(): RDD[InternalRow] = {
+    child.asInstanceOf[CodegenSupport].upstream()
+  }

+  override def doProduce(ctx: CodegenContext): String = ""

   override protected def doPrepare(): Unit = {
     // Materialize the relation.
@@ -86,9 +100,9 @@ case class Broadcast(f: Iterable[InternalRow] => Any, child: SparkPlan) extends
 object Broadcast {
   def broadcastRelation[T](plan: SparkPlan): broadcast.Broadcast[T] = plan match {
     case builder: Broadcast => builder.broadcastRelation
-    case _ => sys.error("The given plan is not a Broadcaster")
+    case _ => sys.error("The given plan is not a Broadcast")
   }

   private[execution] val executionContext = ExecutionContext.fromExecutorService(
     ThreadUtils.newDaemonCachedThreadPool("build-broadcast", 128))
-}
+}
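
For context on the two overrides added above: in this branch's whole-stage codegen contract, a CodegenSupport operator exposes the RDD feeding the generated loop via upstream() and contributes Java source via doProduce(). Broadcast satisfies the contract trivially by delegating upstream() to its child and emitting no code. A minimal sketch of an operator following the same contract (the PassThrough operator is hypothetical, not part of this patch; the signatures mirror the diff above):

// Hypothetical pass-through operator illustrating the CodegenSupport
// contract assumed by this patch. It forwards rows and generates no code.
case class PassThrough(child: SparkPlan) extends UnaryNode with CodegenSupport {
  override def output: Seq[Attribute] = child.output

  // The generated loop iterates over the rows of this RDD.
  override def upstream(): RDD[InternalRow] =
    child.asInstanceOf[CodegenSupport].upstream()

  // Contribute no Java source: rows flow through unchanged.
  override def doProduce(ctx: CodegenContext): String = ""

  protected override def doExecute(): RDD[InternalRow] = child.execute()
}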

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
(9 additions, 5 deletions)

@@ -385,19 +385,24 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
     // Ensure that the operator's children satisfy their output distribution requirements:
     children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
       distribution match {
-        case BroadcastDistribution(f) =>
-          Broadcast(f, child)
         case _ if child.outputPartitioning.satisfies(distribution) =>
           child
+        case BroadcastDistribution(f) =>
+          Broadcast(f, child)
         case _ =>
           Exchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
       }
     }

     // If the operator has multiple children and specifies child output distributions (e.g. join),
     // then the children's output partitionings must be compatible:
+    def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match {
+      case UnspecifiedDistribution => false
+      case BroadcastDistribution(_) => false
+      case _ => true
+    }
     if (children.length > 1
-        && requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
+        && requiredChildDistributions.exists(requireCompatiblePartitioning)
        && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {

       // First check if the existing partitions of the children all match. This means they are
@@ -434,8 +439,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[

       children.zip(requiredChildDistributions).map {
         case (child, distribution) => {
-          val targetPartitioning =
-            createPartitioning(distribution, numPartitions)
+          val targetPartitioning = createPartitioning(distribution, numPartitions)
           if (child.outputPartitioning.guarantees(targetPartitioning)) {
             child
           } else {
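
The reworked guard narrows when the multi-child compatibility check fires. The old test, requiredChildDistributions.toSet != Set(UnspecifiedDistribution), triggered whenever any child required something, including a BroadcastDistribution, even though a broadcast side never needs a co-partitioning shuffle. A self-contained sketch of the difference (a toy Distribution hierarchy standing in for Catalyst's; the predicate mirrors the patch):

// Toy model of the new guard; the types are simplified stand-ins for
// Catalyst's Distribution hierarchy.
object GuardDemo extends App {
  sealed trait Distribution
  case object UnspecifiedDistribution extends Distribution
  case class ClusteredDistribution(keys: Seq[String]) extends Distribution
  case class BroadcastDistribution(side: String) extends Distribution

  def requireCompatiblePartitioning(d: Distribution): Boolean = d match {
    case UnspecifiedDistribution => false
    case BroadcastDistribution(_) => false
    case _ => true
  }

  val broadcastJoin: Seq[Distribution] =
    Seq(BroadcastDistribution("build"), UnspecifiedDistribution)
  val shuffleJoin: Seq[Distribution] =
    Seq(ClusteredDistribution(Seq("k")), ClusteredDistribution(Seq("k")))

  // Old guard fired for both plans; broadcast joins never need co-partitioning.
  println(broadcastJoin.toSet != Set[Distribution](UnspecifiedDistribution)) // true (too eager)
  println(broadcastJoin.exists(requireCompatiblePartitioning))               // false (fixed)
  println(shuffleJoin.exists(requireCompatiblePartitioning))                 // true
}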

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
(8 additions, 7 deletions)

@@ -42,15 +42,19 @@ case class BroadcastHashJoin(
     right: SparkPlan)
   extends BinaryNode with HashJoin with CodegenSupport {

+  val streamSideName = buildSide match {
+    case BuildLeft => "right"
+    case BuildRight => "left"
+  }
+
   override private[sql] lazy val metrics = Map(
-    "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
-    "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
+    "numStreamRows" -> SQLMetrics.createLongMetric(sparkContext, s"number of $streamSideName rows"),
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

   override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning

   override def requiredChildDistribution: Seq[Distribution] = buildSide match {
-    case BuildLeft => longMetric("numLeftRows")
+    case BuildLeft =>
       BroadcastDistribution(buildRelation) :: UnspecifiedDistribution :: Nil
     case BuildRight =>
       UnspecifiedDistribution :: BroadcastDistribution(buildRelation) :: Nil
@@ -61,10 +65,7 @@ case class BroadcastHashJoin(
   }

   protected override def doExecute(): RDD[InternalRow] = {
-    val numStreamedRows = buildSide match {
-      case BuildLeft => longMetric("numRightRows")
-      case BuildRight => longMetric("numLeftRows")
-    }
+    val numStreamedRows = longMetric("numStreamRows")
     val numOutputRows = longMetric("numOutputRows")

     val broadcastRelation = Broadcast.broadcastRelation[HashedRelation](buildPlan)
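
A subtlety in the requiredChildDistribution hunk: the deleted case BuildLeft => longMetric("numLeftRows") compiled even though a metric is not a Distribution, because a multi-expression case body is a block whose intermediate results are evaluated and discarded; only the trailing ... :: Nil fixed the arm's type. A self-contained illustration of that pitfall (toy longMetric, not Spark's):

// Why the stray call compiled: everything before the last expression of a
// case body is evaluated for side effects only, its value thrown away.
object DiscardedExprDemo extends App {
  def longMetric(name: String): Long = { println(s"touched metric $name"); 0L }

  val dists: List[String] = true match {
    case true =>
      longMetric("numLeftRows") // evaluated, result silently discarded
      "broadcast" :: "unspecified" :: Nil
    case false => Nil
  }
  println(dists) // List(broadcast, unspecified)
}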

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
(14 additions, 12 deletions)

@@ -40,9 +40,18 @@ case class BroadcastHashOuterJoin(
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode with HashOuterJoin {

+  private[this] def failOnWrongJoinType(jt: JoinType): Nothing = {
+    throw new IllegalArgumentException(s"HashOuterJoin should not take $jt as the JoinType")
+  }
+
+  val streamSideName = joinType match {
+    case RightOuter => "right"
+    case LeftOuter => "left"
+    case jt => failOnWrongJoinType(jt)
+  }
+
   override private[sql] lazy val metrics = Map(
-    "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
-    "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
+    "numStreamRows" -> SQLMetrics.createLongMetric(sparkContext, s"number of $streamSideName rows"),
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

   override def requiredChildDistribution: Seq[Distribution] = joinType match {
@@ -51,7 +60,7 @@ case class BroadcastHashOuterJoin(
     case LeftOuter =>
       UnspecifiedDistribution :: BroadcastDistribution(buildRelation) :: Nil
     case x =>
-      throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
+      failOnWrongJoinType(x)
   }

   private val buildRelation: Iterable[InternalRow] => HashedRelation = { input =>
@@ -61,13 +70,7 @@ case class BroadcastHashOuterJoin(
   override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning

   override def doExecute(): RDD[InternalRow] = {
-    val numStreamedRows = joinType match {
-      case RightOuter => longMetric("numRightRows")
-      case LeftOuter => longMetric("numLeftRows")
-      case x =>
-        throw new IllegalArgumentException(
-          s"HashOuterJoin should not take $x as the JoinType")
-    }
+    val numStreamedRows = longMetric("numStreamRows")
     val numOutputRows = longMetric("numOutputRows")

     val broadcastRelation = Broadcast.broadcastRelation[UnsafeHashedRelation](buildPlan)
@@ -102,8 +105,7 @@ case class BroadcastHashOuterJoin(
           })

         case x =>
-          throw new IllegalArgumentException(
-            s"BroadcastHashOuterJoin should not take $x as the JoinType")
+          failOnWrongJoinType(x)
       }
     }
   }
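
The extracted helper is declared to return Nothing, Scala's bottom type, which is what lets a single throwing method serve as the fallback arm of matches producing a String (streamSideName), a Seq[Distribution], or an iterator of rows: Nothing conforms to every expected type. A self-contained sketch (toy JoinType hierarchy, not Catalyst's):

// Nothing is a subtype of every type, so a throwing helper declared to
// return Nothing can stand in for a branch of any result type.
object BottomTypeDemo extends App {
  sealed trait JoinType
  case object LeftOuter extends JoinType
  case object RightOuter extends JoinType
  case object FullOuter extends JoinType

  def failOnWrongJoinType(jt: JoinType): Nothing =
    throw new IllegalArgumentException(s"HashOuterJoin should not take $jt as the JoinType")

  def streamSideName(joinType: JoinType): String = joinType match {
    case RightOuter => "right"
    case LeftOuter => "left"
    case jt => failOnWrongJoinType(jt) // typechecks as String: Nothing <: String
  }

  println(streamSideName(LeftOuter)) // prints "left"
}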

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
(2 additions, 1 deletion)

@@ -43,6 +43,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
     val conf = new SparkConf()
       .setMaster("local-cluster[2,1,1024]")
       .setAppName("testing")
+      .set("spark.sql.codegen.wholeStage", "false")
     val sc = new SparkContext(conf)
     sqlContext = new SQLContext(sc)
   }
@@ -62,7 +63,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
     // Comparison at the end is for broadcast left semi join
     val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
    val df3 = df1.join(broadcast(df2), joinExpression, joinType)
-    val plan = df3.queryExecution.sparkPlan
+    val plan = df3.queryExecution.executedPlan
     assert(plan.collect { case p: T => p }.size === 1)
     plan.executeCollect()
   }
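
The switch from sparkPlan to executedPlan matters because the Broadcast operator under test is inserted by the preparation rules (EnsureRequirements, in Exchange.scala above) that run between the two views: queryExecution.sparkPlan is the plan as chosen by the strategies, while queryExecution.executedPlan is that plan after preparations. A hedged sketch of the distinction, assuming DataFrames df1 and df2 and the broadcast function from this suite are in scope:

// Illustrative only: the accessors are the ones used in the diff; the
// surrounding values (df1, df2, broadcast) are assumed from the suite.
val df3 = df1.join(broadcast(df2), df1("key") === df2("key"))

val beforePrep = df3.queryExecution.sparkPlan    // Broadcast not yet inserted
val afterPrep  = df3.queryExecution.executedPlan // EnsureRequirements has run

// Only the prepared plan contains the operator this suite counts.
assert(afterPrep.collect { case b: Broadcast => b }.nonEmpty)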
