Skip to content

Commit 3b43b24

Browse files
committed
bug fix.
1 parent 00eb298 commit 3b43b24

File tree

5 files changed

+68
-45
lines changed

5 files changed

+68
-45
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ case class Aggregate2Sort(
3434
child: SparkPlan)
3535
extends UnaryNode {
3636

37+
override def canProcessUnsafeRows: Boolean = true
38+
3739
override def references: AttributeSet = {
3840
val referencesInResults =
3941
AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes)
@@ -72,6 +74,7 @@ case class Aggregate2Sort(
7274
if (aggregateExpressions.length == 0) {
7375
new GroupingIterator(
7476
groupingExpressions,
77+
resultExpressions,
7578
newMutableProjection,
7679
child.output,
7780
iter)

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ private[sql] abstract class SortAggregationIterator(
242242
*/
243243
class GroupingIterator(
244244
groupingExpressions: Seq[NamedExpression],
245+
resultExpressions: Seq[NamedExpression],
245246
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
246247
inputAttributes: Seq[Attribute],
247248
inputIter: Iterator[InternalRow])
@@ -251,14 +252,18 @@ class GroupingIterator(
251252
newMutableProjection,
252253
inputAttributes,
253254
inputIter) {
255+
256+
private val resultProjection =
257+
newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))()
258+
254259
override protected def initialBufferOffset: Int = 0
255260

256261
override protected def processRow(row: InternalRow): Unit = {
257262
// Since we only do grouping, there is nothing to do at here.
258263
}
259264

260265
override protected def generateOutput(): InternalRow = {
261-
currentGroupingKey
266+
resultProjection(currentGroupingKey)
262267
}
263268
}
264269

@@ -521,7 +526,6 @@ class FinalSortAggregationIterator(
521526
nonAlgebraicAggregateFunctions(i).eval(buffer))
522527
i += 1
523528
}
524-
525529
resultProjection(joinedRow(currentGroupingKey, aggregateResult))
526530
}
527531
}

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
2727

2828
object Utils {
2929
// Right now, we do not support complex types in the grouping key schema.
30-
private def groupingKeySchemaIsSupported(aggregate: Aggregate): Boolean = {
30+
private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
3131
val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
3232
case array: ArrayType => true
3333
case map: MapType => true
@@ -39,7 +39,7 @@ object Utils {
3939
}
4040

4141
private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
42-
case p: Aggregate if groupingKeySchemaIsSupported(p) =>
42+
case p: Aggregate if supportsGroupingKeySchema(p) =>
4343
val converted = p.transformExpressionsDown {
4444
case expressions.Average(child) =>
4545
aggregate.AggregateExpression2(
@@ -125,6 +125,33 @@ object Utils {
125125
case other => None
126126
}
127127

128+
private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
129+
// If the plan cannot be converted, we will do a final round check to see if the original
130+
// logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
131+
// we need to throw an exception.
132+
val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
133+
expr.collect {
134+
case agg: AggregateExpression2 => agg.aggregateFunction
135+
}
136+
}.distinct
137+
if (aggregateFunction2s.nonEmpty) {
138+
// For functions implemented based on the new interface, prepare a list of function names.
139+
val invalidFunctions = {
140+
if (aggregateFunction2s.length > 1) {
141+
s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
142+
s"and ${aggregateFunction2s.head.nodeName} are"
143+
} else {
144+
s"${aggregateFunction2s.head.nodeName} is"
145+
}
146+
}
147+
val errorMessage =
148+
s"${invalidFunctions} implemented based on the new Aggregate Function " +
149+
s"interface and it cannot be used with functions implemented based on " +
150+
s"the old Aggregate Function interface."
151+
throw new AnalysisException(errorMessage)
152+
}
153+
}
154+
128155
def tryConvert(
129156
plan: LogicalPlan,
130157
useNewAggregation: Boolean,
@@ -134,26 +161,12 @@ object Utils {
134161
if (converted.isDefined) {
135162
converted
136163
} else {
137-
// If the plan cannot be converted, we will do a final round check to see if the original
138-
// logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
139-
// we need to throw an exception.
140-
p match {
141-
case Aggregate(_, aggregateExpressions, _) => aggregateExpressions.foreach { expr =>
142-
expr.foreach {
143-
case agg2: AggregateExpression2 =>
144-
// TODO: Make this errorMessage more user-friendly.
145-
val errorMessage =
146-
s"${agg2.aggregateFunction} is implemented based on new Aggregate Function " +
147-
s"interface and it cannot be used with old Aggregate Function implementation."
148-
throw new AnalysisException(errorMessage)
149-
case other => // OK
150-
}
151-
}
152-
case other => // OK
153-
}
154-
164+
checkInvalidAggregateFunction2(p)
155165
None
156166
}
167+
case p: Aggregate =>
168+
checkInvalidAggregateFunction2(p)
169+
None
157170
case other => None
158171
}
159172

sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,4 +275,6 @@ case class ScalaUDAF(
275275
override def toString: String = {
276276
s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
277277
}
278+
279+
override def nodeName: String = udaf.getClass.getSimpleName
278280
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -154,36 +154,36 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
154154
checkAnswer(
155155
sqlContext.sql(
156156
"""
157-
|SELECT DISTINCT key, value1
157+
|SELECT DISTINCT value1, key
158158
|FROM agg2
159159
""".stripMargin),
160-
Row(1, 10) ::
161-
Row(null, -60) ::
162-
Row(1, 30) ::
163-
Row(2, 1) ::
164-
Row(null, -10) ::
165-
Row(2, -1) ::
166-
Row(2, null) ::
167-
Row(null, 100) ::
168-
Row(3, null) ::
160+
Row(10, 1) ::
161+
Row(-60, null) ::
162+
Row(30, 1) ::
163+
Row(1, 2) ::
164+
Row(-10, null) ::
165+
Row(-1, 2) ::
166+
Row(null, 2) ::
167+
Row(100, null) ::
168+
Row(null, 3) ::
169169
Row(null, null) :: Nil)
170170

171171
checkAnswer(
172172
sqlContext.sql(
173173
"""
174-
|SELECT key, value1
174+
|SELECT value1, key
175175
|FROM agg2
176176
|GROUP BY key, value1
177177
""".stripMargin),
178-
Row(1, 10) ::
179-
Row(null, -60) ::
180-
Row(1, 30) ::
181-
Row(2, 1) ::
182-
Row(null, -10) ::
183-
Row(2, -1) ::
184-
Row(2, null) ::
185-
Row(null, 100) ::
186-
Row(3, null) ::
178+
Row(10, 1) ::
179+
Row(-60, null) ::
180+
Row(30, 1) ::
181+
Row(1, 2) ::
182+
Row(-10, null) ::
183+
Row(-1, 2) ::
184+
Row(null, 2) ::
185+
Row(100, null) ::
186+
Row(null, 3) ::
187187
Row(null, null) :: Nil)
188188
}
189189

@@ -427,12 +427,13 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
427427
|SELECT
428428
| key,
429429
| sum(value + 1.5 * key),
430-
| mydoublesum(value)
430+
| mydoublesum(value),
431+
| mydoubleavg(value)
431432
|FROM agg1
432433
|GROUP BY key
433434
""".stripMargin).collect()
434435
}.getMessage
435-
assert(errorMessage.contains("is implemented based on new Aggregate Function interface"))
436+
assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
436437

437438
// TODO: once we support Hive UDAF in the new interface,
438439
// we can remove the following two tests.
@@ -448,7 +449,7 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
448449
|GROUP BY key
449450
""".stripMargin).collect()
450451
}.getMessage
451-
assert(errorMessage.contains("is implemented based on new Aggregate Function interface"))
452+
assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
452453

453454
// This will fall back to the old aggregate
454455
val newAggregateOperators = sqlContext.sql(

0 commit comments

Comments
 (0)