Skip to content

Commit fec10f0

Browse files
committed
[SPARK-9085][SQL] Remove LeafNode, UnaryNode, BinaryNode from TreeNode.
This builds on apache#7433 but also removes LeafNode/UnaryNode. These are slightly more complicated to remove. I had to change some abstract classes to traits in order for it to work. The problem with LeafNode/UnaryNode is that they are often mixed in at the end of an Expression, and then the toString function actually gets resolved to the ones defined in TreeNode, rather than in Expression. Author: Reynold Xin <[email protected]> Closes apache#7434 from rxin/remove-binary-unary-leaf-node and squashes the following commits: 9e8a4de [Reynold Xin] Generator should not be foldable. 3135a8b [Reynold Xin] SortOrder should not be foldable. 9c589cf [Reynold Xin] Fixed one more test case... 2225331 [Reynold Xin] Aggregate expressions should not be foldable. 16b5c90 [Reynold Xin] [SPARK-9085][SQL] Remove LeafNode, UnaryNode, BinaryNode from TreeNode.
1 parent 43dac2c commit fec10f0

File tree

11 files changed

+69
-66
lines changed

11 files changed

+69
-66
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ case class UnresolvedRelation(
5050
/**
5151
* Holds the name of an attribute that has yet to be resolved.
5252
*/
53-
case class UnresolvedAttribute(nameParts: Seq[String])
54-
extends Attribute with trees.LeafNode[Expression] {
53+
case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute {
5554

5655
def name: String =
5756
nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
@@ -96,7 +95,7 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E
9695
* Represents all of the input attributes to a given relational operator, for example in
9796
* "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis.
9897
*/
99-
trait Star extends NamedExpression with trees.LeafNode[Expression] {
98+
abstract class Star extends LeafExpression with NamedExpression {
10099
self: Product =>
101100

102101
override def name: String = throw new UnresolvedException(this, "name")
@@ -151,7 +150,7 @@ case class UnresolvedStar(table: Option[String]) extends Star {
151150
* @param names the names to be associated with each output of computing [[child]].
152151
*/
153152
case class MultiAlias(child: Expression, names: Seq[String])
154-
extends NamedExpression with trees.UnaryNode[Expression] {
153+
extends UnaryExpression with NamedExpression {
155154

156155
override def name: String = throw new UnresolvedException(this, "name")
157156

@@ -210,8 +209,7 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
210209
/**
211210
* Holds the expression that has yet to be aliased.
212211
*/
213-
case class UnresolvedAlias(child: Expression) extends NamedExpression
214-
with trees.UnaryNode[Expression] {
212+
case class UnresolvedAlias(child: Expression) extends UnaryExpression with NamedExpression {
215213

216214
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
217215
override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
3030
* the layout of intermediate tuples, BindReferences should be run after all such transformations.
3131
*/
3232
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
33-
extends NamedExpression with trees.LeafNode[Expression] {
33+
extends LeafExpression with NamedExpression {
3434

3535
override def toString: String = s"input[$ordinal]"
3636

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
2222
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
23-
import org.apache.spark.sql.catalyst.trees
2423
import org.apache.spark.sql.catalyst.trees.TreeNode
2524
import org.apache.spark.sql.types._
2625

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ case object Descending extends SortDirection
3030
* An expression that can be used to sort a tuple. This class extends expression primarily so that
3131
* transformations over expression will descend into its child.
3232
*/
33-
case class SortOrder(child: Expression, direction: SortDirection) extends Expression
34-
with trees.UnaryNode[Expression] {
33+
case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression {
34+
35+
/** Sort order is not foldable because we don't have an eval for it. */
36+
override def foldable: Boolean = false
3537

3638
override def dataType: DataType = child.dataType
3739
override def nullable: Boolean = child.nullable

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,20 @@ package org.apache.spark.sql.catalyst.expressions
2020
import com.clearspring.analytics.stream.cardinality.HyperLogLog
2121

2222
import org.apache.spark.sql.catalyst.InternalRow
23-
import org.apache.spark.sql.catalyst.trees
2423
import org.apache.spark.sql.catalyst.errors.TreeNodeException
2524
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2625
import org.apache.spark.sql.catalyst.util.TypeUtils
2726
import org.apache.spark.sql.types._
2827
import org.apache.spark.util.collection.OpenHashSet
2928

30-
abstract class AggregateExpression extends Expression {
29+
trait AggregateExpression extends Expression {
3130
self: Product =>
3231

32+
/**
33+
* Aggregate expressions should not be foldable.
34+
*/
35+
override def foldable: Boolean = false
36+
3337
/**
3438
* Creates a new instance that can be used to compute this aggregate expression for a group
3539
* of input rows.
@@ -60,7 +64,7 @@ case class SplitEvaluation(
6064
* An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples.
6165
* These partial evaluations can then be combined to compute the actual answer.
6266
*/
63-
abstract class PartialAggregate extends AggregateExpression {
67+
trait PartialAggregate extends AggregateExpression {
6468
self: Product =>
6569

6670
/**
@@ -74,7 +78,7 @@ abstract class PartialAggregate extends AggregateExpression {
7478
* [[AggregateExpression]] with an algorithm that will be used to compute one specific result.
7579
*/
7680
abstract class AggregateFunction
77-
extends AggregateExpression with Serializable with trees.LeafNode[Expression] {
81+
extends LeafExpression with AggregateExpression with Serializable {
7882
self: Product =>
7983

8084
/** Base should return the generic aggregate expression that this function is computing */
@@ -91,7 +95,7 @@ abstract class AggregateFunction
9195
}
9296
}
9397

94-
case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
98+
case class Min(child: Expression) extends UnaryExpression with PartialAggregate {
9599

96100
override def nullable: Boolean = true
97101
override def dataType: DataType = child.dataType
@@ -124,7 +128,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr
124128
override def eval(input: InternalRow): Any = currentMin.value
125129
}
126130

127-
case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
131+
case class Max(child: Expression) extends UnaryExpression with PartialAggregate {
128132

129133
override def nullable: Boolean = true
130134
override def dataType: DataType = child.dataType
@@ -157,7 +161,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
157161
override def eval(input: InternalRow): Any = currentMax.value
158162
}
159163

160-
case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
164+
case class Count(child: Expression) extends UnaryExpression with PartialAggregate {
161165

162166
override def nullable: Boolean = false
163167
override def dataType: LongType.type = LongType
@@ -310,7 +314,7 @@ private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] {
310314
}
311315

312316
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
313-
extends AggregateExpression with trees.UnaryNode[Expression] {
317+
extends UnaryExpression with AggregateExpression {
314318

315319
override def nullable: Boolean = false
316320
override def dataType: DataType = HyperLogLogUDT
@@ -340,7 +344,7 @@ case class ApproxCountDistinctPartitionFunction(
340344
}
341345

342346
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
343-
extends AggregateExpression with trees.UnaryNode[Expression] {
347+
extends UnaryExpression with AggregateExpression {
344348

345349
override def nullable: Boolean = false
346350
override def dataType: LongType.type = LongType
@@ -368,7 +372,7 @@ case class ApproxCountDistinctMergeFunction(
368372
}
369373

370374
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
371-
extends PartialAggregate with trees.UnaryNode[Expression] {
375+
extends UnaryExpression with PartialAggregate {
372376

373377
override def nullable: Boolean = false
374378
override def dataType: LongType.type = LongType
@@ -386,7 +390,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
386390
override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this)
387391
}
388392

389-
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
393+
case class Average(child: Expression) extends UnaryExpression with PartialAggregate {
390394

391395
override def prettyName: String = "avg"
392396

@@ -479,7 +483,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
479483
}
480484
}
481485

482-
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
486+
case class Sum(child: Expression) extends UnaryExpression with PartialAggregate {
483487

484488
override def nullable: Boolean = true
485489

@@ -606,8 +610,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression)
606610
}
607611
}
608612

609-
case class SumDistinct(child: Expression)
610-
extends PartialAggregate with trees.UnaryNode[Expression] {
613+
case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate {
611614

612615
def this() = this(null)
613616
override def nullable: Boolean = true
@@ -701,7 +704,7 @@ case class CombineSetsAndSumFunction(
701704
}
702705
}
703706

704-
case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
707+
case class First(child: Expression) extends UnaryExpression with PartialAggregate {
705708
override def nullable: Boolean = true
706709
override def dataType: DataType = child.dataType
707710
override def toString: String = s"FIRST($child)"
@@ -729,7 +732,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
729732
override def eval(input: InternalRow): Any = result
730733
}
731734

732-
case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
735+
case class Last(child: Expression) extends UnaryExpression with PartialAggregate {
733736
override def references: AttributeSet = child.references
734737
override def nullable: Boolean = true
735738
override def dataType: DataType = child.dataType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,14 @@ import org.apache.spark.sql.types._
4040
* requested. The attributes produced by this function will be automatically copied anytime rules
4141
* result in changes to the Generator or its children.
4242
*/
43-
abstract class Generator extends Expression {
44-
self: Product =>
43+
trait Generator extends Expression { self: Product =>
4544

4645
// TODO ideally we should return the type of ArrayType(StructType),
4746
// however, we don't keep the output field names in the Generator.
4847
override def dataType: DataType = throw new UnsupportedOperationException
4948

49+
override def foldable: Boolean = false
50+
5051
override def nullable: Boolean = false
5152

5253
/**
@@ -99,8 +100,9 @@ case class UserDefinedGenerator(
99100
/**
100101
* Given an input array produces a sequence of rows for each value in the array.
101102
*/
102-
case class Explode(child: Expression)
103-
extends Generator with trees.UnaryNode[Expression] {
103+
case class Explode(child: Expression) extends UnaryExpression with Generator {
104+
105+
override def children: Seq[Expression] = child :: Nil
104106

105107
override def checkInputDataTypes(): TypeCheckResult = {
106108
if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) {
@@ -127,6 +129,4 @@ case class Explode(child: Expression)
127129
else inputMap.map { case (k, v) => InternalRow(k, v) }
128130
}
129131
}
130-
131-
override def toString: String = s"explode($child)"
132132
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,13 @@ object NamedExpression {
3737
*/
3838
case class ExprId(id: Long)
3939

40-
abstract class NamedExpression extends Expression {
41-
self: Product =>
40+
/**
41+
* An [[Expression]] that is named.
42+
*/
43+
trait NamedExpression extends Expression { self: Product =>
44+
45+
/** We should never fold named expressions in order to not remove the alias. */
46+
override def foldable: Boolean = false
4247

4348
def name: String
4449
def exprId: ExprId
@@ -78,8 +83,7 @@ abstract class NamedExpression extends Expression {
7883
}
7984
}
8085

81-
abstract class Attribute extends NamedExpression {
82-
self: Product =>
86+
abstract class Attribute extends LeafExpression with NamedExpression { self: Product =>
8387

8488
override def references: AttributeSet = AttributeSet(this)
8589

@@ -110,7 +114,7 @@ case class Alias(child: Expression, name: String)(
110114
val exprId: ExprId = NamedExpression.newExprId,
111115
val qualifiers: Seq[String] = Nil,
112116
val explicitMetadata: Option[Metadata] = None)
113-
extends NamedExpression with trees.UnaryNode[Expression] {
117+
extends UnaryExpression with NamedExpression {
114118

115119
// Alias(Generator, xx) needs to be transformed into Generate(generator, ...)
116120
override lazy val resolved =
@@ -172,7 +176,8 @@ case class AttributeReference(
172176
nullable: Boolean = true,
173177
override val metadata: Metadata = Metadata.empty)(
174178
val exprId: ExprId = NamedExpression.newExprId,
175-
val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] {
179+
val qualifiers: Seq[String] = Nil)
180+
extends Attribute {
176181

177182
/**
178183
* Returns true iff the expression id is the same for both attributes.
@@ -242,7 +247,7 @@ case class AttributeReference(
242247
* A place holder used when printing expressions without debugging information such as the
243248
* expression id or the unresolved indicator.
244249
*/
245-
case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
250+
case class PrettyAttribute(name: String) extends Attribute {
246251

247252
override def toString: String = name
248253

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
342342
case l: Literal => l
343343

344344
// Fold expressions that are foldable.
345-
case e if e.foldable => Literal.create(e.eval(null), e.dataType)
345+
case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType)
346346

347347
// Fold "literal in (item1, item2, ..., literal, ...)" into true directly.
348348
case In(Literal(v, _), list) if list.exists {
@@ -361,7 +361,7 @@ object OptimizeIn extends Rule[LogicalPlan] {
361361
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
362362
case q: LogicalPlan => q transformExpressionsDown {
363363
case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
364-
val hSet = list.map(e => e.eval(null))
364+
val hSet = list.map(e => e.eval(EmptyRow))
365365
InSet(v, HashSet() ++ hSet)
366366
}
367367
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis._
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.QueryPlan
2525
import org.apache.spark.sql.catalyst.trees.TreeNode
26-
import org.apache.spark.sql.catalyst.trees
2726

2827

2928
abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
@@ -277,15 +276,21 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
277276
/**
278277
* A logical plan node with no children.
279278
*/
280-
abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
279+
abstract class LeafNode extends LogicalPlan {
281280
self: Product =>
281+
282+
override def children: Seq[LogicalPlan] = Nil
282283
}
283284

284285
/**
285286
* A logical plan node with single child.
286287
*/
287-
abstract class UnaryNode extends LogicalPlan with trees.UnaryNode[LogicalPlan] {
288+
abstract class UnaryNode extends LogicalPlan {
288289
self: Product =>
290+
291+
def child: LogicalPlan
292+
293+
override def children: Seq[LogicalPlan] = child :: Nil
289294
}
290295

291296
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -452,19 +452,3 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
452452
s"$nodeName(${args.mkString(",")})"
453453
}
454454
}
455-
456-
457-
/**
458-
* A [[TreeNode]] with no children.
459-
*/
460-
trait LeafNode[BaseType <: TreeNode[BaseType]] {
461-
def children: Seq[BaseType] = Nil
462-
}
463-
464-
/**
465-
* A [[TreeNode]] with a single [[child]].
466-
*/
467-
trait UnaryNode[BaseType <: TreeNode[BaseType]] {
468-
def child: BaseType
469-
def children: Seq[BaseType] = child :: Nil
470-
}

0 commit comments

Comments
 (0)