Skip to content

Commit 1b2ff9f

Browse files
committed
Merge remote-tracking branch 'upstream/master' into SPARK-2375
2 parents 10794eb + 9d5ecf8 commit 1b2ff9f

File tree

8 files changed

+80
-29
lines changed

8 files changed

+80
-29
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
5252
override lazy val resolved = false
5353

5454
override def newInstance = this
55+
override def withNullability(newNullability: Boolean) = this
5556
override def withQualifiers(newQualifiers: Seq[String]) = this
5657

5758
// Unresolved attributes are transient at compile time and don't get evaluated during execution.
@@ -95,6 +96,7 @@ case class Star(
9596
override lazy val resolved = false
9697

9798
override def newInstance = this
99+
override def withNullability(newNullability: Boolean) = this
98100
override def withQualifiers(newQualifiers: Seq[String]) = this
99101

100102
def expand(input: Seq[Attribute]): Seq[NamedExpression] = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,16 @@ case class BoundReference(ordinal: Int, baseReference: Attribute)
3333

3434
type EvaluatedType = Any
3535

36-
def nullable = baseReference.nullable
37-
def dataType = baseReference.dataType
38-
def exprId = baseReference.exprId
39-
def qualifiers = baseReference.qualifiers
40-
def name = baseReference.name
36+
override def nullable = baseReference.nullable
37+
override def dataType = baseReference.dataType
38+
override def exprId = baseReference.exprId
39+
override def qualifiers = baseReference.qualifiers
40+
override def name = baseReference.name
4141

42-
def newInstance = BoundReference(ordinal, baseReference.newInstance)
43-
def withQualifiers(newQualifiers: Seq[String]) =
42+
override def newInstance = BoundReference(ordinal, baseReference.newInstance)
43+
override def withNullability(newNullability: Boolean) =
44+
BoundReference(ordinal, baseReference.withNullability(newNullability))
45+
override def withQualifiers(newQualifiers: Seq[String]) =
4446
BoundReference(ordinal, baseReference.withQualifiers(newQualifiers))
4547

4648
override def toString = s"$baseReference:$ordinal"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ abstract class NamedExpression extends Expression {
5757
abstract class Attribute extends NamedExpression {
5858
self: Product =>
5959

60+
def withNullability(newNullability: Boolean): Attribute
6061
def withQualifiers(newQualifiers: Seq[String]): Attribute
6162

6263
def toAttribute = this
@@ -133,7 +134,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
133134
/**
134135
* Returns a copy of this [[AttributeReference]] with changed nullability.
135136
*/
136-
def withNullability(newNullability: Boolean) = {
137+
override def withNullability(newNullability: Boolean) = {
137138
if (nullable == newNullability) {
138139
this
139140
} else {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
5252
* - Inserting Projections beneath the following operators:
5353
* - Aggregate
5454
* - Project <- Join
55+
* - LeftSemiJoin
5556
* - Collapse adjacent projections, performing alias substitution.
5657
*/
5758
object ColumnPruning extends Rule[LogicalPlan] {
@@ -62,19 +63,22 @@ object ColumnPruning extends Rule[LogicalPlan] {
6263

6364
// Eliminate unneeded attributes from either side of a Join.
6465
case Project(projectList, Join(left, right, joinType, condition)) =>
65-
// Collect the list of off references required either above or to evaluate the condition.
66+
// Collect the list of all references required either above or to evaluate the condition.
6667
val allReferences: Set[Attribute] =
6768
projectList.flatMap(_.references).toSet ++ condition.map(_.references).getOrElse(Set.empty)
6869

6970
/** Applies a projection only when the child is producing unnecessary attributes */
70-
def prunedChild(c: LogicalPlan) =
71-
if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
72-
Project(allReferences.filter(c.outputSet.contains).toSeq, c)
73-
} else {
74-
c
75-
}
71+
def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences)
7672

77-
Project(projectList, Join(prunedChild(left), prunedChild(right), joinType, condition))
73+
Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
74+
75+
// Eliminate unneeded attributes from right side of a LeftSemiJoin.
76+
case Join(left, right, LeftSemi, condition) =>
77+
// Collect the list of all references required to evaluate the condition.
78+
val allReferences: Set[Attribute] =
79+
condition.map(_.references).getOrElse(Set.empty)
80+
81+
Join(left, prunedChild(right, allReferences), LeftSemi, condition)
7882

7983
// Combine adjacent Projects.
8084
case Project(projectList1, Project(projectList2, child)) =>
@@ -97,6 +101,14 @@ object ColumnPruning extends Rule[LogicalPlan] {
97101
// Eliminate no-op Projects
98102
case Project(projectList, child) if child.output == projectList => child
99103
}
104+
105+
/** Applies a projection only when the child is producing unnecessary attributes */
106+
private def prunedChild(c: LogicalPlan, allReferences: Set[Attribute]) =
107+
if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
108+
Project(allReferences.filter(c.outputSet.contains).toSeq, c)
109+
} else {
110+
c
111+
}
100112
}
101113

102114
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.plans.logical
1919

2020
import org.apache.spark.sql.catalyst.expressions._
21-
import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType}
21+
import org.apache.spark.sql.catalyst.plans._
2222
import org.apache.spark.sql.catalyst.types._
2323

2424
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
@@ -46,10 +46,16 @@ case class Generate(
4646
child: LogicalPlan)
4747
extends UnaryNode {
4848

49-
protected def generatorOutput: Seq[Attribute] =
50-
alias
49+
protected def generatorOutput: Seq[Attribute] = {
50+
val output = alias
5151
.map(a => generator.output.map(_.withQualifiers(a :: Nil)))
5252
.getOrElse(generator.output)
53+
if (join && outer) {
54+
output.map(_.withNullability(true))
55+
} else {
56+
output
57+
}
58+
}
5359

5460
override def output =
5561
if (join) child.output ++ generatorOutput else generatorOutput
@@ -81,11 +87,20 @@ case class Join(
8187
condition: Option[Expression]) extends BinaryNode {
8288

8389
override def references = condition.map(_.references).getOrElse(Set.empty)
84-
override def output = joinType match {
85-
case LeftSemi =>
86-
left.output
87-
case _ =>
88-
left.output ++ right.output
90+
91+
override def output = {
92+
joinType match {
93+
case LeftSemi =>
94+
left.output
95+
case LeftOuter =>
96+
left.output ++ right.output.map(_.withNullability(true))
97+
case RightOuter =>
98+
left.output.map(_.withNullability(true)) ++ right.output
99+
case FullOuter =>
100+
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
101+
case _ =>
102+
left.output ++ right.output
103+
}
89104
}
90105
}
91106

sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ case class Aggregate(
8383
case a: AggregateExpression =>
8484
ComputedAggregate(
8585
a,
86-
BindReferences.bindReference(a, childOutput).asInstanceOf[AggregateExpression],
87-
AttributeReference(s"aggResult:$a", a.dataType, nullable = true)())
86+
BindReferences.bindReference(a, childOutput),
87+
AttributeReference(s"aggResult:$a", a.dataType, a.nullable)())
8888
}
8989
}.toArray
9090

sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.execution
1919

2020
import org.apache.spark.annotation.DeveloperApi
21-
import org.apache.spark.sql.catalyst.expressions.{Generator, JoinedRow, Literal, Projection}
21+
import org.apache.spark.sql.catalyst.expressions._
2222

2323
/**
2424
* :: DeveloperApi ::
@@ -39,8 +39,16 @@ case class Generate(
3939
child: SparkPlan)
4040
extends UnaryNode {
4141

42+
protected def generatorOutput: Seq[Attribute] = {
43+
if (join && outer) {
44+
generator.output.map(_.withNullability(true))
45+
} else {
46+
generator.output
47+
}
48+
}
49+
4250
override def output =
43-
if (join) child.output ++ generator.output else generator.output
51+
if (join) child.output ++ generatorOutput else generatorOutput
4452

4553
override def execute() = {
4654
if (join) {

sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,18 @@ case class BroadcastNestedLoopJoin(
319319

320320
override def otherCopyArgs = sqlContext :: Nil
321321

322-
def output = left.output ++ right.output
322+
override def output = {
323+
joinType match {
324+
case LeftOuter =>
325+
left.output ++ right.output.map(_.withNullability(true))
326+
case RightOuter =>
327+
left.output.map(_.withNullability(true)) ++ right.output
328+
case FullOuter =>
329+
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
330+
case _ =>
331+
left.output ++ right.output
332+
}
333+
}
323334

324335
/** The Streamed Relation */
325336
def left = streamed

0 commit comments

Comments
 (0)