Commit d2ad5c5 (parent: be2cd6b)

Refactor: put SQLContext into SparkPlan. Fix ordering and other test cases.

14 files changed, 104 additions and 95 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
Lines changed: 21 additions & 4 deletions

@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions.codegen
 
+import com.typesafe.scalalogging.slf4j.Logging
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types.{StringType, NumericType}
 
 /**
  * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of
  * [[Expression Expressions]].
  */
-object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] {
+object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging {
   import scala.reflect.runtime.{universe => ru}
   import scala.reflect.runtime.universe._
@@ -40,6 +42,22 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] {
       val evalA = expressionEvaluator(order.child)
       val evalB = expressionEvaluator(order.child)
 
+      val compare = order.child.dataType match {
+        case _: NumericType =>
+          q"""
+          val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
+          if (comp != 0) {
+            return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"}
+          }
+          """
+        case StringType =>
+          if (order.direction == Ascending) {
+            q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})"""
+          } else {
+            q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})"""
+          }
+      }
+
       q"""
         i = $a
         ..${evalA.code}
@@ -52,9 +70,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] {
         } else if (${evalB.nullTerm}) {
           return ${if (order.direction == Ascending) q"1" else q"-1"}
         } else {
-          i = a
-          val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
-          if (comp != 0) return comp.toInt
+          $compare
        }
      """
    }
@@ -76,6 +92,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] {
      }
      new $orderingName()
    """
+    logger.debug(s"Generated Ordering: $code")
    toolBox.eval(code).asInstanceOf[Ordering[Row]]
  }
}
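
The change specializes the generated comparator by key type: numeric keys compare by subtraction, string keys via `compare`, and descending order is produced by negating the result or swapping the operands, with nulls ordering first under ascending sort. Below is a minimal interpreted sketch of those semantics in plain Scala, without quasiquotes; all names (`Direction`, `compareColumn`) are illustrative, not part of the commit:

// Interpreted sketch of the comparison semantics the generated bytecode
// implements. All names here are hypothetical, for illustration only.
object OrderingSemanticsSketch {
  sealed trait Direction
  case object Asc extends Direction
  case object Desc extends Direction

  def compareColumn(a: Any, b: Any, dir: Direction): Int = (a, b) match {
    // Nulls sort first under Asc, mirroring the generated null checks.
    case (null, null) => 0
    case (null, _)    => if (dir == Asc) -1 else 1
    case (_, null)    => if (dir == Asc) 1 else -1
    // Numeric keys: subtract, truncate to Int, negate for Desc. Like the
    // generated code, this shape can mis-sort extreme values where the
    // subtraction overflows; it illustrates the shape, not a robust compare.
    case (x: Long, y: Long) =>
      val comp = x - y
      if (dir == Asc) comp.toInt else -comp.toInt
    // String keys: lexicographic compare, operands swapped for Desc.
    case (x: String, y: String) =>
      if (dir == Asc) x.compare(y) else y.compare(x)
    case _ => sys.error("unsupported key type in this sketch")
  }

  def main(args: Array[String]): Unit = {
    println(compareColumn(1L, 2L, Asc))    // negative: 1 sorts before 2
    println(compareColumn("a", "b", Desc)) // positive: "a" sorts after "b"
    println(compareColumn(null, 2L, Asc))  // -1: nulls first when ascending
  }
}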

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Lines changed: 5 additions & 13 deletions

@@ -304,18 +304,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   @transient
   protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
     val batches =
-      Batch("Add exchange", Once, AddExchange(self)) ::
-      Batch("CodeGen", Once, TurnOnCodeGen) :: Nil
-  }
-
-  protected object TurnOnCodeGen extends Rule[SparkPlan] {
-    def apply(plan: SparkPlan): SparkPlan = {
-      if (self.codegenEnabled) {
-        plan.foreach(p => println(p.simpleString))
-        plan.foreach(_._codegenEnabled = true)
-      }
-      plan
-    }
+      Batch("Add exchange", Once, AddExchange(self)) :: Nil
   }
 
   /**
@@ -330,7 +319,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
     lazy val analyzed = analyzer(logical)
     lazy val optimizedPlan = optimizer(analyzed)
     // TODO: Don't just pick the first one...
-    lazy val sparkPlan = planner(optimizedPlan).next()
+    lazy val sparkPlan = {
+      SparkPlan.currentContext.set(self)
+      planner(optimizedPlan).next()
+    }
     // executedPlan should not be used to initialize any SparkPlan. It should be
     // only used for execution.
     lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
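
The key move of the refactor is visible here: rather than threading `self` through every physical operator's constructor, the planner publishes the active context in `SparkPlan.currentContext` (a `ThreadLocal`) just before constructing the physical plan, and each operator picks it up at construction time. A minimal self-contained sketch of that handoff pattern (the `Context`, `Plan`, and `Leaf` names are hypothetical, for illustration only):

// Minimal sketch of the construction-time ThreadLocal handoff performed by
// SparkPlan.currentContext. Context, Plan, and Leaf are hypothetical names.
object ThreadLocalHandoffSketch {
  class Context(val codegenEnabled: Boolean)

  object Plan {
    val currentContext = new ThreadLocal[Context]()
  }

  abstract class Plan extends Serializable {
    // Read once, at construction time, on the planning thread. Because the
    // field is @transient, a deserialized copy sees null, so anything needed
    // at execution time must be copied into a regular val (see below).
    @transient protected val context: Context = Plan.currentContext.get()

    // Captured eagerly; survives serialization even though context does not.
    val codegenEnabled: Boolean =
      if (context != null) context.codegenEnabled else false
  }

  case class Leaf() extends Plan

  def main(args: Array[String]): Unit = {
    // Planner thread: publish the context, then construct the operator tree.
    Plan.currentContext.set(new Context(codegenEnabled = true))
    println(Leaf().codegenEnabled) // true: picked up at construction

    // A plan built on a thread that never called set() sees no context.
    new Thread(new Runnable {
      def run(): Unit = println(Leaf().codegenEnabled) // false
    }).start()
  }
}

This works because Catalyst builds the whole operator tree on the planning thread; a plan constructed off that thread (or before `set`) sees a null context, which is why `codegenEnabled` in SparkPlan falls back to `false`.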

sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
Lines changed: 1 addition & 3 deletions

@@ -42,7 +42,7 @@ case class Aggregate(
     partial: Boolean,
     groupingExpressions: Seq[Expression],
     aggregateExpressions: Seq[NamedExpression],
-    child: SparkPlan)(@transient sqlContext: SQLContext)
+    child: SparkPlan)
   extends UnaryNode {
 
   override def requiredChildDistribution =
@@ -56,8 +56,6 @@ case class Aggregate(
     }
   }
 
-  override def otherCopyArgs = sqlContext :: Nil
-
   // HACK: Generators don't correctly preserve their output through serialization so we grab
   // our child's output attributes statically here.
   private[this] val childOutput = child.output
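
Dropping the second argument list is what makes `otherCopyArgs` unnecessary: Catalyst tree transformations rebuild operators through their constructors, and parameters outside the first (case-class) list had to be re-supplied by hand via `otherCopyArgs`. A hedged sketch of the before/after shape (the `Ctx`, `LimitBefore`, and `LimitAfter` names are hypothetical):

// Sketch of why the curried (sqlContext) argument list was awkward.
object CurriedCopySketch {
  class Ctx

  // Before: the context lived in a second argument list, invisible to the
  // case-class machinery, so every rebuild had to re-thread it by hand
  // (which is what otherCopyArgs did for Catalyst's copy mechanism).
  case class LimitBefore(limit: Int)(@transient val ctx: Ctx) {
    def otherCopyArgs: Seq[AnyRef] = ctx :: Nil
  }

  // After: a single argument list; copy() and pattern matching just work.
  case class LimitAfter(limit: Int)

  def main(args: Array[String]): Unit = {
    val before = LimitBefore(10)(new Ctx)
    println(before.copy(limit = 20)(before.ctx)) // second list re-supplied by hand
    println(LimitAfter(10).copy(limit = 20))     // no extra plumbing
  }
}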

sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
Lines changed: 5 additions & 3 deletions

@@ -51,9 +51,11 @@ case class Generate(
     if (join) child.output ++ generatorOutput else generatorOutput
 
   /** Codegenned rows are not serializable... */
-  override def codegenEnabled = false
+  override val codegenEnabled = false
 
   override def execute() = {
+    val boundGenerator = BindReferences.bindReference(generator, child.output)
+
     if (join) {
       child.execute().mapPartitions { iter =>
         val nullValues = Seq.fill(generator.output.size)(Literal(null))
@@ -66,7 +68,7 @@ case class Generate(
         val joinedRow = new JoinedRow
 
         iter.flatMap { row =>
-          val outputRows = generator.eval(row)
+          val outputRows = boundGenerator.eval(row)
           if (outer && outputRows.isEmpty) {
             outerProjection(row) :: Nil
           } else {
@@ -75,7 +77,7 @@ case class Generate(
          }
        }
      } else {
-      child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row)))
+      child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row)))
    }
  }
}
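
Binding the generator once against `child.output`, instead of evaluating the unbound expression for every row, resolves each attribute reference to an ordinal into the input row up front. A small self-contained sketch of what reference binding buys; the `Expr`/`Attribute`/`BoundReference` miniature below is illustrative, not Catalyst's actual classes:

// Sketch of expression binding: resolve a named column reference to a row
// ordinal once, before the per-row loop. All names are hypothetical.
object BindSketch {
  type Row = IndexedSeq[Any]

  sealed trait Expr { def eval(row: Row): Any }
  case class Attribute(name: String) extends Expr {
    def eval(row: Row): Any =
      sys.error(s"unresolved attribute $name") // unbound: cannot evaluate
  }
  case class BoundReference(ordinal: Int) extends Expr {
    def eval(row: Row): Any = row(ordinal)     // bound: direct index lookup
  }

  // One-time binding against the child's output schema.
  def bindReference(e: Expr, schema: Seq[String]): Expr = e match {
    case Attribute(n) => BoundReference(schema.indexOf(n))
    case other        => other
  }

  def main(args: Array[String]): Unit = {
    val bound = bindReference(Attribute("b"), schema = Seq("a", "b"))
    println(bound.eval(Vector(1, 2))) // 2: resolved once, evaluated per row
  }
}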

sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
Lines changed: 1 addition & 8 deletions

@@ -46,11 +46,9 @@ case class GeneratedAggregate(
     partial: Boolean,
     groupingExpressions: Seq[Expression],
     aggregateExpressions: Seq[NamedExpression],
-    child: SparkPlan)(@transient sqlContext: SQLContext)
+    child: SparkPlan)
   extends UnaryNode {
 
-  println(s"new $codegenEnabled")
-
   override def requiredChildDistribution =
     if (partial) {
       UnspecifiedDistribution :: Nil
@@ -62,12 +60,9 @@ case class GeneratedAggregate(
     }
   }
 
-  override def otherCopyArgs = sqlContext :: Nil
-
   override def output = aggregateExpressions.map(_.toAttribute)
 
   override def execute() = {
-    println(s"codegen: $codegenEnabled")
     val aggregatesToCompute = aggregateExpressions.flatMap { a =>
       a.collect { case agg: AggregateExpression => agg }
     }
@@ -160,7 +155,6 @@ case class GeneratedAggregate(
     // TODO: Codegening anything other than the updateProjection is probably overkill.
     val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
     var currentRow: Row = null
-    println(codegenEnabled)
 
     while (iter.hasNext) {
       currentRow = iter.next()
@@ -172,7 +166,6 @@ case class GeneratedAggregate(
     } else {
       val buffers = new java.util.HashMap[Row, MutableRow]()
 
-      println(codegenEnabled)
       var currentRow: Row = null
       while (iter.hasNext) {
         currentRow = iter.next()

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
Lines changed: 35 additions & 9 deletions

@@ -18,8 +18,9 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SQLContext, Logging, Row}
+import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
@@ -28,17 +29,35 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.BaseRelation
 import org.apache.spark.sql.catalyst.plans.physical._
 
+
+object SparkPlan {
+  protected[sql] val currentContext = new ThreadLocal[SQLContext]()
+}
+
 /**
  * :: DeveloperApi ::
  */
 @DeveloperApi
-abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
+abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
   self: Product =>
 
-  def codegenEnabled = _codegenEnabled
+  /**
+   * A handle to the SQL Context that was used to create this plan. Since many operators need
+   * access to the sqlContext for RDD operations or configuration this field is automatically
+   * populated by the query planning infrastructure.
+   */
+  @transient
+  protected val sqlContext = SparkPlan.currentContext.get()
 
-  /** Will be set to true during planning if code generation should be used for this operator. */
-  private[sql] var _codegenEnabled = false
+  protected def sparkContext = sqlContext.sparkContext
+
+  def logger = log
+
+  val codegenEnabled: Boolean = if (sqlContext != null) {
+    sqlContext.codegenEnabled
+  } else {
+    false
+  }
 
   // TODO: Move to `DistributedPlan`
   /** Specifies how data is partitioned across different nodes in the cluster. */
@@ -57,16 +76,22 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
    */
   def executeCollect(): Array[Row] = execute().map(_.copy()).collect()
 
-  def newProjection(expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection =
+  protected def newProjection(
+      expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
+    log.debug(
+      s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
     if (codegenEnabled) {
       GenerateProjection(expressions, inputSchema)
     } else {
      new InterpretedProjection(expressions, inputSchema)
    }
+  }
 
-  def newMutableProjection(
+  protected def newMutableProjection(
      expressions: Seq[Expression],
      inputSchema: Seq[Attribute]): () => MutableProjection = {
+    log.debug(
+      s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
    if (codegenEnabled) {
      GenerateMutableProjection(expressions, inputSchema)
    } else {
@@ -75,15 +100,16 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
    }
  }
 
 
-  def newPredicate(expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
+  protected def newPredicate(
      expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
    if (codegenEnabled) {
      GeneratePredicate(expression, inputSchema)
    } else {
      InterpretedPredicate(expression, inputSchema)
    }
  }
 
-  def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = {
+  protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = {
    if (codegenEnabled) {
      GenerateOrdering(order, inputSchema)
    } else {
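
Note the interplay above: `SparkPlan` is now `Serializable`, but `sqlContext` is `@transient`, so the context does not travel to executors; anything needed at execution time, like `codegenEnabled`, is captured into a plain `val` on the driver. A small sketch demonstrating that behavior with plain Java serialization (the `Config` and `Operator` names are hypothetical):

import java.io._

// Sketch of why codegenEnabled became a val: the @transient context is
// dropped during serialization, but a val captured from it beforehand
// survives. Config and Operator are hypothetical names.
object TransientCaptureSketch {
  class Config(val codegenEnabled: Boolean) // deliberately not Serializable

  class Operator(@transient val config: Config) extends Serializable {
    // Captured at construction; ships with the serialized operator.
    val codegenEnabled: Boolean =
      if (config != null) config.codegenEnabled else false
  }

  def roundTrip[T](obj: T): T = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(obj)
    oos.close()
    val ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
    ois.readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    val op = roundTrip(new Operator(new Config(codegenEnabled = true)))
    println(op.config)         // null: the transient handle did not travel
    println(op.codegenEnabled) // true: the captured setting did
  }
}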

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
Lines changed: 13 additions & 13 deletions

@@ -39,7 +39,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // no predicate can be evaluated by matching hash keys
       case logical.Join(left, right, LeftSemi, condition) =>
         execution.LeftSemiJoinBNL(
-          planLater(left), planLater(right), condition)(sqlContext) :: Nil
+          planLater(left), planLater(right), condition) :: Nil
       case _ => Nil
     }
   }
@@ -58,7 +58,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         condition: Option[Expression],
         side: BuildSide) = {
       val broadcastHashJoin = execution.BroadcastHashJoin(
-        leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext)
+        leftKeys, rightKeys, side, planLater(left), planLater(right))
       condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
     }
 
@@ -118,7 +118,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             partial = true,
             groupingExpressions,
             partialComputation,
-            planLater(child))(sqlContext))(sqlContext) :: Nil
+            planLater(child))) :: Nil
 
       // Cases where some aggregate can not be codegened
       case PartialAggregation(
@@ -135,7 +135,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             partial = true,
             groupingExpressions,
             partialComputation,
-            planLater(child))(sqlContext))(sqlContext) :: Nil
+            planLater(child))) :: Nil
 
       case _ => Nil
     }
@@ -153,7 +153,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Join(left, right, joinType, condition) =>
         execution.BroadcastNestedLoopJoin(
-          planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil
+          planLater(left), planLater(right), joinType, condition) :: Nil
       case _ => Nil
     }
   }
@@ -175,7 +175,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object TakeOrdered extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
-        execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil
+        execution.TakeOrdered(limit, order, planLater(child)) :: Nil
       case _ => Nil
     }
   }
@@ -187,9 +187,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       val relation =
         ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
       // Note: overwrite=false because otherwise the metadata we just created will be deleted
-      InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil
+      InsertIntoParquetTable(relation, planLater(child), overwrite=false) :: Nil
     case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
-      InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil
+      InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil
     case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
       val prunePushedDownFilters =
         if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
@@ -218,7 +218,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           projectList,
           filters,
           prunePushedDownFilters,
-          ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil
+          ParquetTableScan(_, relation, filters)) :: Nil
 
       case _ => Nil
     }
@@ -243,7 +243,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Distinct(child) =>
         execution.Aggregate(
-          partial = false, child.output, child.output, planLater(child))(sqlContext) :: Nil
+          partial = false, child.output, child.output, planLater(child)) :: Nil
       case logical.Sort(sortExprs, child) =>
         // This sort is a global sort. Its requiredDistribution will be an OrderedDistribution.
         execution.Sort(sortExprs, global = true, planLater(child)) :: Nil
@@ -256,17 +256,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
       case logical.Aggregate(group, agg, child) =>
-        execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil
+        execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
       case logical.Sample(fraction, withReplacement, seed, child) =>
         execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
         ExistingRdd(
           output,
           ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>
-        execution.Limit(limit, planLater(child))(sqlContext) :: Nil
+        execution.Limit(limit, planLater(child)) :: Nil
       case Unions(unionChildren) =>
-        execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil
+        execution.Union(unionChildren.map(planLater)) :: Nil
       case logical.Except(left, right) =>
         execution.Except(planLater(left), planLater(right)) :: Nil
      case logical.Intersect(left, right) =>
