Commit fbd6787

Handle cases missed before

1 parent 777c5b4 commit fbd6787

6 files changed: +28 −30 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala

Lines changed: 2 additions & 1 deletion
@@ -168,9 +168,10 @@ case class ExpandExec(
     // Part 2: switch/case statements
     val cases = projections.zipWithIndex.map { case (exprs, row) =>
       var updateCode = ""
+      val attributeSeq: AttributeSeq = child.output
       for (col <- exprs.indices) {
         if (!sameOutput(col)) {
-          val ev = BindReferences.bindReference(exprs(col), child.output).genCode(ctx)
+          val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx)
           updateCode +=
             s"""
                |${ev.code}
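
Note on the pattern: `child.output` is a `Seq[Attribute]`, and each `bindReference` call implicitly wraps it in an `AttributeSeq`; hoisting the conversion out of the loop builds that wrapper (and, presumably, its internal lookup state) once instead of once per column. A minimal sketch of the idea with simplified, hypothetical stand-in types (`Attr`, `AttrSeq`, `bindSlow`/`bindFast` are illustrations, not Spark's classes):

// Hypothetical stand-ins for Attribute / AttributeSeq, showing the shape of the fix.
final case class Attr(name: String)

final class AttrSeq(attrs: Seq[Attr]) {
  // The cost worth amortizing: a lookup index built once per AttrSeq instance.
  private val ordinal: Map[String, Int] =
    attrs.zipWithIndex.map { case (a, i) => a.name -> i }.toMap
  def indexOf(name: String): Int = ordinal.getOrElse(name, -1)
}

object BindSketch {
  // Before: every iteration wrapped the attribute list again, rebuilding the index.
  def bindSlow(names: Seq[String], attrs: Seq[Attr]): Seq[Int] =
    names.map(n => new AttrSeq(attrs).indexOf(n))

  // After: wrap once and reuse, which is what the hoisted
  // `val attributeSeq: AttributeSeq = child.output` achieves.
  def bindFast(names: Seq[String], attrs: Seq[Attr]): Seq[Int] = {
    val indexed = new AttrSeq(attrs)
    names.map(indexed.indexOf)
  }
}

For n expressions over m attributes this turns O(n·m) index-building work into O(m) plus n cheap lookups.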

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala

Lines changed: 2 additions & 1 deletion
@@ -77,6 +77,7 @@ abstract class AggregationIterator(
     val expressionsLength = expressions.length
     val functions = new Array[AggregateFunction](expressionsLength)
     var i = 0
+    val inputAttributeSeq: AttributeSeq = inputAttributes
     while (i < expressionsLength) {
       val func = expressions(i).aggregateFunction
       val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
@@ -86,7 +87,7 @@ abstract class AggregationIterator(
           // this function is Partial or Complete because we will call eval of this
           // function's children in the update method of this aggregate function.
           // Those eval calls require BoundReferences to work.
-          BindReferences.bindReference(func, inputAttributes)
+          BindReferences.bindReference(func, inputAttributeSeq)
         case _ =>
           // We only need to set inputBufferOffset for aggregate functions with mode
           // PartialMerge and Final.

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 13 additions & 20 deletions
@@ -200,15 +200,12 @@ case class HashAggregateExec(
     val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
       // evaluate aggregate results
       ctx.currentVars = bufVars
-      val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-      }
+      val aggResults = bindReferences(functions.map(_.evaluateExpression),
+        aggregateBufferAttributes).map(_.genCode(ctx))
       val evaluateAggResults = evaluateVariables(aggResults)
       // evaluate result expressions
       ctx.currentVars = aggResults
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, aggregateAttributes).genCode(ctx)
-      }
+      val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
       (resultVars, s"""
         |$evaluateAggResults
         |${evaluateVariables(resultVars)}
@@ -457,16 +454,14 @@ case class HashAggregateExec(
     val evaluateBufferVars = evaluateVariables(bufferVars)
     // evaluate the aggregation result
     ctx.currentVars = bufferVars
-    val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
-      BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-    }
+    val aggResults = bindReferences(declFunctions.map(_.evaluateExpression),
+      aggregateBufferAttributes).map(_.genCode(ctx))
     val evaluateAggResults = evaluateVariables(aggResults)
     // generate the final result
    ctx.currentVars = keyVars ++ aggResults
     val inputAttrs = groupingAttributes ++ aggregateAttributes
-    val resultVars = resultExpressions.map { e =>
-      BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-    }
+    val resultVars = bindReferences[Expression](resultExpressions,
+      inputAttrs).map(_.genCode(ctx))
     s"""
     $evaluateKeyVars
     $evaluateBufferVars
@@ -495,9 +490,8 @@ case class HashAggregateExec(
 
     ctx.currentVars = keyVars ++ resultBufferVars
     val inputAttrs = resultExpressions.map(_.toAttribute)
-    val resultVars = resultExpressions.map { e =>
-      BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-    }
+    val resultVars = bindReferences[Expression](resultExpressions,
+      inputAttrs).map(_.genCode(ctx))
     s"""
     $evaluateKeyVars
     $evaluateResultBufferVars
@@ -507,9 +501,8 @@ case class HashAggregateExec(
     // generate result based on grouping key
     ctx.INPUT_ROW = keyTerm
     ctx.currentVars = null
-    val eval = resultExpressions.map{ e =>
-      BindReferences.bindReference(e, groupingAttributes).genCode(ctx)
-    }
+    val eval = bindReferences[Expression](resultExpressions,
+      groupingAttributes).map(_.genCode(ctx))
     consume(ctx, eval)
   }
   ctx.addNewFunction(funcName,
@@ -731,9 +724,9 @@ case class HashAggregateExec(
   private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // create grouping key
     val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
-      ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      ctx, bindReferences[Expression](groupingExpressions, child.output))
     val fastRowKeys = ctx.generateExpressions(
-      groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      bindReferences[Expression](groupingExpressions, child.output))
     val unsafeRowKeys = unsafeRowKeyCode.value
     val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
     val fastRowBuffer = ctx.freshName("fastAggBuffer")
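
Every hunk in this file replaces a `.map { e => BindReferences.bindReference(e, ...) }` loop with a single `bindReferences` call. The helper's body is not part of this diff; a plausible sketch, assuming it is nothing more than the existing per-expression binder mapped over the batch, with the input typed as `AttributeSeq`:

import org.apache.spark.sql.catalyst.expressions.{AttributeSeq, BindReferences, Expression}

// Assumed shape of the imported BindReferences.bindReferences helper: same
// binding logic as before, but because `input` is already an AttributeSeq,
// the Seq[Attribute] -> AttributeSeq conversion happens once per call site
// rather than once per expression.
object BindReferencesSketch {
  def bindReferences[A <: Expression](
      expressions: Seq[A],
      input: AttributeSeq): Seq[A] = {
    expressions.map(BindReferences.bindReference(_, input))
  }
}

If that reading is right, the explicit `[Expression]` type argument at some call sites presumably widens the result type, since binding can replace a top-level `NamedExpression` with a plain `BoundReference`.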

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 2 additions & 1 deletion
@@ -24,6 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
 import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -56,7 +57,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output))
+    val exprs = bindReferences[Expression](projectList, child.output)
     val resultVars = exprs.map(_.genCode(ctx))
     // Evaluation of non-deterministic expressions can't be deferred.
     val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala

Lines changed: 4 additions & 3 deletions
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
@@ -145,9 +146,9 @@ object FileFormatWriter extends Logging {
       // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
       // the physical plan may have different attribute ids due to optimizer removing some
       // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
-      val orderingExpr = requiredOrdering
-        .map(SortOrder(_, Ascending))
-        .map(BindReferences.bindReference(_, outputSpec.outputColumns))
+      val orderingExpr = bindReferences(
+        requiredOrdering.map(SortOrder(_, Ascending)),
+        outputSpec.outputColumns)
       SortExec(
         orderingExpr,
         global = false,

sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala

Lines changed: 5 additions & 4 deletions
@@ -21,6 +21,7 @@ import java.util
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
 
@@ -89,9 +90,9 @@ private[window] final class OffsetWindowFunctionFrame(
   private[this] val projection = {
     // Collect the expressions and bind them.
     val inputAttrs = inputSchema.map(_.withNullability(true))
-    val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
-      BindReferences.bindReference(e.input, inputAttrs)
-    }
+    val boundExpressions = bindReferences(
+      Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map(_.input),
+      inputAttrs)
 
     // Create the projection.
     newMutableProjection(boundExpressions, Nil).target(target)
@@ -100,7 +101,7 @@ private[window] final class OffsetWindowFunctionFrame(
   /** Create the projection used when the offset row DOES NOT exists. */
   private[this] val fillDefaultValue = {
     // Collect the expressions and bind them.
-    val inputAttrs = inputSchema.map(_.withNullability(true))
+    val inputAttrs: AttributeSeq = inputSchema.map(_.withNullability(true))
     val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
       if (e.default == null || e.default.foldable && e.default.eval() == null) {
         // The default value is null.
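
The final hunk is the one place a per-expression loop remains, since each expression needs its own default-value handling; there the commit settles for ascribing `inputAttrs` as `AttributeSeq`, so the implicit conversion runs once at the definition rather than on every `bindReference` call inside the map. A toy demonstration of why the ascription alone suffices (simplified, hypothetical types, not Spark's):

import scala.language.implicitConversions

object AscriptionSketch {
  final case class Attr(name: String)
  final class AttrSeq(attrs: Seq[Attr]) {
    println("index built")  // makes each conversion's cost visible
    private val ordinal: Map[String, Int] = attrs.map(_.name).zipWithIndex.toMap
    def indexOf(name: String): Int = ordinal.getOrElse(name, -1)
  }
  implicit def seqToAttrSeq(attrs: Seq[Attr]): AttrSeq = new AttrSeq(attrs)

  def bind(name: String, input: AttrSeq): Int = input.indexOf(name)

  def main(args: Array[String]): Unit = {
    val raw = Seq(Attr("a"), Attr("b"))
    val plain: Seq[Attr] = raw
    Seq("a", "b").foreach(bind(_, plain))     // converts per call: prints "index built" twice
    val ascribed: AttrSeq = raw               // converts once, right here
    Seq("a", "b").foreach(bind(_, ascribed))  // no further conversions
  }
}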
