Skip to content

Commit 3ace83a

Browse files
committed
Add withNullability to Attribute and use it to change nullabilities.
1 parent df1ae53 commit 3ace83a

File tree

6 files changed

+23
-47
lines changed

6 files changed

+23
-47
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
5252
override lazy val resolved = false
5353

5454
override def newInstance = this
55+
override def withNullability(newNullability: Boolean) = this
5556
override def withQualifiers(newQualifiers: Seq[String]) = this
5657

5758
// Unresolved attributes are transient at compile time and don't get evaluated during execution.
@@ -95,6 +96,7 @@ case class Star(
9596
override lazy val resolved = false
9697

9798
override def newInstance = this
99+
override def withNullability(newNullability: Boolean) = this
98100
override def withQualifiers(newQualifiers: Seq[String]) = this
99101

100102
def expand(input: Seq[Attribute]): Seq[NamedExpression] = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,16 @@ case class BoundReference(ordinal: Int, baseReference: Attribute)
3333

3434
type EvaluatedType = Any
3535

36-
def nullable = baseReference.nullable
37-
def dataType = baseReference.dataType
38-
def exprId = baseReference.exprId
39-
def qualifiers = baseReference.qualifiers
40-
def name = baseReference.name
36+
override def nullable = baseReference.nullable
37+
override def dataType = baseReference.dataType
38+
override def exprId = baseReference.exprId
39+
override def qualifiers = baseReference.qualifiers
40+
override def name = baseReference.name
4141

42-
def newInstance = BoundReference(ordinal, baseReference.newInstance)
43-
def withQualifiers(newQualifiers: Seq[String]) =
42+
override def newInstance = BoundReference(ordinal, baseReference.newInstance)
43+
override def withNullability(newNullability: Boolean) =
44+
BoundReference(ordinal, baseReference.withNullability(newNullability))
45+
override def withQualifiers(newQualifiers: Seq[String]) =
4446
BoundReference(ordinal, baseReference.withQualifiers(newQualifiers))
4547

4648
override def toString = s"$baseReference:$ordinal"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ abstract class NamedExpression extends Expression {
5757
abstract class Attribute extends NamedExpression {
5858
self: Product =>
5959

60+
def withNullability(newNullability: Boolean): Attribute
6061
def withQualifiers(newQualifiers: Seq[String]): Attribute
6162

6263
def toAttribute = this
@@ -133,7 +134,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
133134
/**
134135
* Returns a copy of this [[AttributeReference]] with changed nullability.
135136
*/
136-
def withNullability(newNullability: Boolean) = {
137+
override def withNullability(newNullability: Boolean) = {
137138
if (nullable == newNullability) {
138139
this
139140
} else {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,7 @@ case class Generate(
5151
.map(a => generator.output.map(_.withQualifiers(a :: Nil)))
5252
.getOrElse(generator.output)
5353
if (join && outer) {
54-
output.map {
55-
case attr if !attr.resolved => attr
56-
case attr if !attr.nullable =>
57-
AttributeReference(
58-
attr.name, attr.dataType, nullable = true)(attr.exprId, attr.qualifiers)
59-
case attr => attr
60-
}
54+
output.map(_.withNullability(true))
6155
} else {
6256
output
6357
}
@@ -93,26 +87,17 @@ case class Join(
9387
condition: Option[Expression]) extends BinaryNode {
9488

9589
override def references = condition.map(_.references).getOrElse(Set.empty)
96-
override def output = {
97-
def nullabilize(output: Seq[Attribute]) = {
98-
output.map {
99-
case attr if !attr.resolved => attr
100-
case attr if !attr.nullable =>
101-
AttributeReference(
102-
attr.name, attr.dataType, nullable = true)(attr.exprId, attr.qualifiers)
103-
case attr => attr
104-
}
105-
}
10690

91+
override def output = {
10792
joinType match {
10893
case LeftSemi =>
10994
left.output
11095
case LeftOuter =>
111-
left.output ++ nullabilize(right.output)
96+
left.output ++ right.output.map(_.withNullability(true))
11297
case RightOuter =>
113-
nullabilize(left.output) ++ right.output
98+
left.output.map(_.withNullability(true)) ++ right.output
11499
case FullOuter =>
115-
nullabilize(left.output) ++ nullabilize(right.output)
100+
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
116101
case _ =>
117102
left.output ++ right.output
118103
}

sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,7 @@ case class Generate(
4141

4242
protected def generatorOutput: Seq[Attribute] = {
4343
if (join && outer) {
44-
generator.output.map {
45-
case attr if !attr.nullable =>
46-
AttributeReference(
47-
attr.name, attr.dataType, nullable = true)(attr.exprId, attr.qualifiers)
48-
case attr => attr
49-
}
44+
generator.output.map(_.withNullability(true))
5045
} else {
5146
generator.output
5247
}

sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -319,23 +319,14 @@ case class BroadcastNestedLoopJoin(
319319

320320
override def otherCopyArgs = sqlContext :: Nil
321321

322-
def output = {
323-
def nullabilize(output: Seq[Attribute]) = {
324-
output.map {
325-
case attr if !attr.nullable =>
326-
AttributeReference(
327-
attr.name, attr.dataType, nullable = true)(attr.exprId, attr.qualifiers)
328-
case attr => attr
329-
}
330-
}
331-
322+
override def output = {
332323
joinType match {
333324
case LeftOuter =>
334-
left.output ++ nullabilize(right.output)
325+
left.output ++ right.output.map(_.withNullability(true))
335326
case RightOuter =>
336-
nullabilize(left.output) ++ right.output
327+
left.output.map(_.withNullability(true)) ++ right.output
337328
case FullOuter =>
338-
nullabilize(left.output) ++ nullabilize(right.output)
329+
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
339330
case _ =>
340331
left.output ++ right.output
341332
}

0 commit comments

Comments
 (0)