
Commit 1e6e666

[SPARK-7462] By default retain group by columns in aggregate
1 parent 22ab70e commit 1e6e666

5 files changed: +203 -145 lines changed


sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 9 additions & 1 deletion
@@ -158,7 +158,15 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
       case expr: NamedExpression => expr
       case expr: Expression => Alias(expr, expr.prettyString)()
     }
-    DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+    if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+      val retainedExprs = groupingExprs.map {
+        case expr: NamedExpression => expr
+        case expr: Expression => Alias(expr, expr.prettyString)()
+      }
+      DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan))
+    } else {
+      DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+    }
   }

   /**
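
The upshot for callers: GroupedData.agg now prepends the (aliased) grouping expressions to the aggregate list, so the grouping columns appear in the output without being requested explicitly. A minimal sketch of the new default behavior, assuming a Spark 1.x SQLContext named sqlContext; the data and the column names "dept" and "salary" are illustrative, not taken from the commit:

import sqlContext.implicits._
import org.apache.spark.sql.functions._

// Hypothetical example data.
val df = Seq(("eng", 100), ("eng", 200), ("ops", 50)).toDF("dept", "salary")

// Before this commit the result schema was just [SUM(salary)], and the grouping
// column had to be listed by hand: df.groupBy("dept").agg($"dept", sum($"salary")).
// With spark.sql.retainGroupColumns=true (the new default) it is kept automatically:
df.groupBy("dept").agg(sum($"salary"))
// result schema: [dept, SUM(salary)]
// expected rows under these assumptions: Row("eng", 300), Row("ops", 50)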

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 6 additions & 0 deletions
@@ -71,6 +71,9 @@ private[spark] object SQLConf {
   // See SPARK-6231.
   val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = "spark.sql.selfJoinAutoResolveAmbiguity"

+  // Whether to retain group by columns or not in GroupedData.agg.
+  val DATAFRAME_RETAIN_GROUP_COLUMNS = "spark.sql.retainGroupColumns"
+
   val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"

   val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI"

@@ -233,6 +236,9 @@ private[sql] class SQLConf extends Serializable {

   private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY, "true").toBoolean
+
+  private[spark] def dataFrameRetainGroupColumns: Boolean =
+    getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean

   /** ********************** SQLConf functionality methods ************ */
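For reference, a short sketch of toggling the new flag from user code, assuming an existing SQLContext named sqlContext; setConf(key, value) is the standard SQLContext API, and "true" is the default, matching the getConf fallback above:

// Opt out: restore the old behavior in which agg() emits only the aggregate columns.
sqlContext.setConf("spark.sql.retainGroupColumns", "false")

// Opt back in to the new default.
sqlContext.setConf("spark.sql.retainGroupColumns", "true")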

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@ (new file)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types.DecimalType


class DataFrameAggregateSuite extends QueryTest {

  test("groupBy") {
    checkAnswer(
      testData2.groupBy("a").agg(sum($"b")),
      Seq(Row(1, 3), Row(2, 3), Row(3, 3))
    )
    checkAnswer(
      testData2.groupBy("a").agg(sum($"b").as("totB")).agg(sum('totB)),
      Row(9)
    )
    checkAnswer(
      testData2.groupBy("a").agg(count("*")),
      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
    )
    checkAnswer(
      testData2.groupBy("a").agg(Map("*" -> "count")),
      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
    )
    checkAnswer(
      testData2.groupBy("a").agg(Map("b" -> "sum")),
      Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
    )

    val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
      .toDF("key", "value1", "value2", "rest")

    checkAnswer(
      df1.groupBy("key").min(),
      df1.groupBy("key").min("value1", "value2").collect()
    )
    checkAnswer(
      df1.groupBy("key").min("value2"),
      Seq(Row("a", 0), Row("b", 4))
    )
  }

  test("spark.sql.retainGroupColumns config") {
    checkAnswer(
      testData2.groupBy("a").agg(sum($"b")),
      Seq(Row(1, 3), Row(2, 3), Row(3, 3))
    )

    TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "false")
    checkAnswer(
      testData2.groupBy("a").agg(sum($"b")),
      Seq(Row(3), Row(3), Row(3))
    )
    TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "true")
  }

  test("agg without groups") {
    checkAnswer(
      testData2.agg(sum('b)),
      Row(9)
    )
  }

  test("average") {
    checkAnswer(
      testData2.agg(avg('a)),
      Row(2.0))

    checkAnswer(
      testData2.agg(avg('a), sumDistinct('a)), // non-partial
      Row(2.0, 6.0) :: Nil)

    checkAnswer(
      decimalData.agg(avg('a)),
      Row(new java.math.BigDecimal(2.0)))
    checkAnswer(
      decimalData.agg(avg('a), sumDistinct('a)), // non-partial
      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)

    checkAnswer(
      decimalData.agg(avg('a cast DecimalType(10, 2))),
      Row(new java.math.BigDecimal(2.0)))
    // non-partial
    checkAnswer(
      decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))),
      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
  }

  test("null average") {
    checkAnswer(
      testData3.agg(avg('b)),
      Row(2.0))

    checkAnswer(
      testData3.agg(avg('b), countDistinct('b)),
      Row(2.0, 1))

    checkAnswer(
      testData3.agg(avg('b), sumDistinct('b)), // non-partial
      Row(2.0, 2.0))
  }

  test("zero average") {
    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
    checkAnswer(
      emptyTableData.agg(avg('a)),
      Row(null))

    checkAnswer(
      emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial
      Row(null, null))
  }

  test("count") {
    assert(testData2.count() === testData2.map(_ => 1).count())

    checkAnswer(
      testData2.agg(count('a), sumDistinct('a)), // non-partial
      Row(6, 6.0))
  }

  test("null count") {
    checkAnswer(
      testData3.groupBy('a).agg(count('b)),
      Seq(Row(1,0), Row(2, 1))
    )

    checkAnswer(
      testData3.groupBy('a).agg(count('a + 'b)),
      Seq(Row(1,0), Row(2, 1))
    )

    checkAnswer(
      testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)),
      Row(2, 1, 2, 2, 1)
    )

    checkAnswer(
      testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial
      Row(1, 1, 2)
    )
  }

  test("zero count") {
    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
    assert(emptyTableData.count() === 0)

    checkAnswer(
      emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
      Row(0, null))
  }

  test("zero sum") {
    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
    checkAnswer(
      emptyTableData.agg(sum('a)),
      Row(null))
  }

  test("zero sum distinct") {
    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
    checkAnswer(
      emptyTableData.agg(sumDistinct('a)),
      Row(null))
  }
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 0 additions & 142 deletions
@@ -22,7 +22,6 @@ import scala.language.postfixOps
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext}
-import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
 import org.apache.spark.sql.test.TestSQLContext.implicits._

@@ -165,48 +164,6 @@ class DataFrameSuite extends QueryTest {
     testData.select('key).collect().toSeq)
   }

-  test("groupBy") {
-    checkAnswer(
-      testData2.groupBy("a").agg($"a", sum($"b")),
-      Seq(Row(1, 3), Row(2, 3), Row(3, 3))
-    )
-    checkAnswer(
-      testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
-      Row(9)
-    )
-    checkAnswer(
-      testData2.groupBy("a").agg(col("a"), count("*")),
-      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
-    )
-    checkAnswer(
-      testData2.groupBy("a").agg(Map("*" -> "count")),
-      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
-    )
-    checkAnswer(
-      testData2.groupBy("a").agg(Map("b" -> "sum")),
-      Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
-    )
-
-    val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
-      .toDF("key", "value1", "value2", "rest")
-
-    checkAnswer(
-      df1.groupBy("key").min(),
-      df1.groupBy("key").min("value1", "value2").collect()
-    )
-    checkAnswer(
-      df1.groupBy("key").min("value2"),
-      Seq(Row("a", 0), Row("b", 4))
-    )
-  }
-
-  test("agg without groups") {
-    checkAnswer(
-      testData2.agg(sum('b)),
-      Row(9)
-    )
-  }
-
   test("convert $\"attribute name\" into unresolved attribute") {
     checkAnswer(
       testData.where($"key" === lit(1)).select($"value"),

@@ -303,105 +260,6 @@ class DataFrameSuite extends QueryTest {
     mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
   }

-  test("average") {
-    checkAnswer(
-      testData2.agg(avg('a)),
-      Row(2.0))
-
-    checkAnswer(
-      testData2.agg(avg('a), sumDistinct('a)), // non-partial
-      Row(2.0, 6.0) :: Nil)
-
-    checkAnswer(
-      decimalData.agg(avg('a)),
-      Row(new java.math.BigDecimal(2.0)))
-    checkAnswer(
-      decimalData.agg(avg('a), sumDistinct('a)), // non-partial
-      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
-
-    checkAnswer(
-      decimalData.agg(avg('a cast DecimalType(10, 2))),
-      Row(new java.math.BigDecimal(2.0)))
-    // non-partial
-    checkAnswer(
-      decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))),
-      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
-  }
-
-  test("null average") {
-    checkAnswer(
-      testData3.agg(avg('b)),
-      Row(2.0))
-
-    checkAnswer(
-      testData3.agg(avg('b), countDistinct('b)),
-      Row(2.0, 1))
-
-    checkAnswer(
-      testData3.agg(avg('b), sumDistinct('b)), // non-partial
-      Row(2.0, 2.0))
-  }
-
-  test("zero average") {
-    checkAnswer(
-      emptyTableData.agg(avg('a)),
-      Row(null))
-
-    checkAnswer(
-      emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial
-      Row(null, null))
-  }
-
-  test("count") {
-    assert(testData2.count() === testData2.map(_ => 1).count())
-
-    checkAnswer(
-      testData2.agg(count('a), sumDistinct('a)), // non-partial
-      Row(6, 6.0))
-  }
-
-  test("null count") {
-    checkAnswer(
-      testData3.groupBy('a).agg('a, count('b)),
-      Seq(Row(1,0), Row(2, 1))
-    )
-
-    checkAnswer(
-      testData3.groupBy('a).agg('a, count('a + 'b)),
-      Seq(Row(1,0), Row(2, 1))
-    )
-
-    checkAnswer(
-      testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)),
-      Row(2, 1, 2, 2, 1)
-    )
-
-    checkAnswer(
-      testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial
-      Row(1, 1, 2)
-    )
-  }
-
-  test("zero count") {
-    assert(emptyTableData.count() === 0)
-
-    checkAnswer(
-      emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
-      Row(0, null))
-  }
-
-  test("zero sum") {
-    checkAnswer(
-      emptyTableData.agg(sum('a)),
-      Row(null))
-  }
-
-  test("zero sum distinct") {
-    checkAnswer(
-      emptyTableData.agg(sumDistinct('a)),
-      Row(null))
-  }
-
   test("except") {
     checkAnswer(
       lowerCaseData.except(upperCaseData),
sql/core/src/test/scala/org/apache/spark/sql/TestData.scala

Lines changed: 0 additions & 2 deletions
@@ -86,8 +86,6 @@ object TestData {
     TestData3(2, Some(2)) :: Nil).toDF()
   testData3.registerTempTable("testData3")

-  val emptyTableData = logical.LocalRelation($"a".int, $"b".int)
-
   case class UpperCaseData(N: Int, L: String)
   val upperCaseData =
     TestSQLContext.sparkContext.parallelize(
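
The shared emptyTableData relation is no longer needed: each empty-input test in the new DataFrameAggregateSuite builds its own empty frame through the implicits, as seen above:

  val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")

This keeps those tests self-contained and removes TestData's direct use of logical.LocalRelation.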
