Skip to content

Commit ce5aba3

Browse files
hvanhovell authored and yhuai committed
[SPARK-9241][SQL] Supporting multiple DISTINCT columns - follow-up (3)
This PR is a 2nd follow-up for [SPARK-9241](https://issues.apache.org/jira/browse/SPARK-9241). It contains the following improvements:

* Fix for a potential bug in distinct child expression and attribute alignment.
* Improved handling of duplicate distinct child expressions.
* Added test for distinct UDAF with multiple children.

cc yhuai

Author: Herman van Hovell <[email protected]>

Closes #9566 from hvanhovell/SPARK-9241-followup-2.

(cherry picked from commit 21c562f)

Signed-off-by: Yin Huai <[email protected]>
1 parent ff7d869 commit ce5aba3

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,12 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
151151
}
152152

153153
// Setup unique distinct aggregate children.
154-
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
155-
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
156-
val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
154+
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
155+
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
156+
val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
157157

158158
// Setup expand & aggregate operators for distinct aggregate expressions.
159+
val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
159160
val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
160161
case ((group, expressions), i) =>
161162
val id = Literal(i + 1)
@@ -170,7 +171,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
170171
val operators = expressions.map { e =>
171172
val af = e.aggregateFunction
172173
val naf = patchAggregateFunctionChildren(af) { x =>
173-
evalWithinGroup(id, distinctAggChildAttrMap(x))
174+
evalWithinGroup(id, distinctAggChildAttrLookup(x))
174175
}
175176
(e, e.copy(aggregateFunction = naf, isDistinct = false))
176177
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,36 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
6666
}
6767
}
6868

69+
/**
 * Test UDAF taking two Long inputs that returns the sum of their row-wise products.
 *
 * Rows where either input is null are ignored. Because the buffer starts at 0 and is never
 * set to null, a group with no contributing rows evaluates to 0 rather than null.
 */
class LongProductSum extends UserDefinedAggregateFunction {
70+
// Two Long input columns, "a" and "b".
def inputSchema: StructType = new StructType()
71+
.add("a", LongType)
72+
.add("b", LongType)
73+
74+
// Single Long buffer slot; despite the name "product" it holds the running sum of products.
def bufferSchema: StructType = new StructType()
75+
.add("product", LongType)
76+
77+
// The aggregate result is a Long.
def dataType: DataType = LongType
78+
79+
def deterministic: Boolean = true
80+
81+
// Start the running sum at zero.
def initialize(buffer: MutableAggregationBuffer): Unit = {
82+
buffer(0) = 0L
83+
}
84+
85+
// Add input(0) * input(1) to the running sum; rows with a null in either column are skipped.
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
86+
if (!(input.isNullAt(0) || input.isNullAt(1))) {
87+
buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1)
88+
}
89+
}
90+
91+
// Combine two partial results by adding buffer2's sum into buffer1.
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
92+
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
93+
}
94+
95+
// The final result is the accumulated sum itself.
def evaluate(buffer: Row): Any =
96+
buffer.getLong(0)
97+
}
98+
6999
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
70100
import testImplicits._
71101

@@ -110,6 +140,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
110140
// Register UDAFs
111141
sqlContext.udf.register("mydoublesum", new MyDoubleSum)
112142
sqlContext.udf.register("mydoubleavg", new MyDoubleAvg)
143+
sqlContext.udf.register("longProductSum", new LongProductSum)
113144
}
114145

115146
override def afterAll(): Unit = {
@@ -545,19 +576,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
545576
| count(distinct value2),
546577
| sum(distinct value2),
547578
| count(distinct value1, value2),
579+
| longProductSum(distinct value1, value2),
548580
| count(value1),
549581
| sum(value1),
550582
| count(value2),
551583
| sum(value2),
584+
| longProductSum(value1, value2),
552585
| count(*),
553586
| count(1)
554587
|FROM agg2
555588
|GROUP BY key
556589
""".stripMargin),
557-
Row(null, 3, 30, 3, 60, 3, 3, 30, 3, 60, 4, 4) ::
558-
Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) ::
559-
Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) ::
560-
Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil)
590+
Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) ::
591+
Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) ::
592+
Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) ::
593+
Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
561594
}
562595

563596
test("test count") {

0 commit comments

Comments
 (0)