Skip to content

Commit 05875ba

Browse files
committed
Avoid calling encoder.shift repeatedly in update, merge, and eval by caching the shifted encoders, and skip the shift entirely when delta is zero.
1 parent 62b7f30 commit 05875ba

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,13 @@ case class ExpressionEncoder[T](
339339
* Returns a new encoder with input columns shifted by `delta` ordinals
340340
*/
341341
def shift(delta: Int): ExpressionEncoder[T] = {
342-
copy(deserializer = deserializer transform {
343-
case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
344-
})
342+
if (delta == 0) {
343+
this
344+
} else {
345+
copy(deserializer = deserializer transform {
346+
case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
347+
})
348+
}
345349
}
346350

347351
protected val attrs = serializer.flatMap(_.collect {

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ case class TypedAggregateExpression(
8585
.resolve(aggBufferAttributes, OuterScopes.outerScopes)
8686
.bind(aggBufferAttributes)
8787

88+
val bEncoderForMutableAggBuffer = bEncoder.shift(mutableAggBufferOffset)
89+
val bEncoderForInputAggBuffer = bEncoder.shift(inputAggBufferOffset)
90+
8891
// Note: although this simply copies aggBufferAttributes, this common code can not be placed
8992
// in the superclass because that will lead to initialization ordering issues.
9093
override val inputAggBufferAttributes: Seq[AttributeReference] =
@@ -118,24 +121,24 @@ case class TypedAggregateExpression(
118121

119122
override def update(buffer: MutableRow, input: InternalRow): Unit = {
120123
val inputA = boundA.fromRow(input)
121-
val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
124+
val currentB = bEncoderForMutableAggBuffer.fromRow(buffer)
122125
val merged = aggregator.reduce(currentB, inputA)
123126
val returned = bEncoder.toRow(merged)
124127

125128
updateBuffer(buffer, returned)
126129
}
127130

128131
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
129-
val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1)
130-
val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2)
132+
val b1 = bEncoderForMutableAggBuffer.fromRow(buffer1)
133+
val b2 = bEncoderForInputAggBuffer.fromRow(buffer2)
131134
val merged = aggregator.merge(b1, b2)
132135
val returned = bEncoder.toRow(merged)
133136

134137
updateBuffer(buffer1, returned)
135138
}
136139

137140
override def eval(buffer: InternalRow): Any = {
138-
val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
141+
val b = bEncoderForMutableAggBuffer.fromRow(buffer)
139142
val result = cEncoder.toRow(aggregator.finish(b))
140143
dataType match {
141144
case _: StructType => result

0 commit comments

Comments
 (0)