Skip to content

Commit 3121e78

Browse files
yhuai authored and rxin committed
[SPARK-9830][SPARK-11641][SQL][FOLLOW-UP] Remove AggregateExpression1 and update toString of Exchange
https://issues.apache.org/jira/browse/SPARK-9830 This is the follow-up PR for #9556 to address davies' comments. Author: Yin Huai <[email protected]> Closes #9607 from yhuai/removeAgg1-followup.
1 parent e281b87 commit 3121e78

File tree

10 files changed

+160
-54
lines changed

10 files changed

+160
-54
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -532,7 +532,7 @@ class Analyzer(
532532
case min: Min if isDistinct =>
533533
AggregateExpression(min, Complete, isDistinct = false)
534534
// We get an aggregate function, we need to wrap it in an AggregateExpression.
535-
case agg2: AggregateFunction => AggregateExpression(agg2, Complete, isDistinct)
535+
case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct)
536536
// This function is not an aggregate function, just return the resolved one.
537537
case other => other
538538
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 38 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -110,17 +110,21 @@ trait CheckAnalysis {
110110
case Aggregate(groupingExprs, aggregateExprs, child) =>
111111
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
112112
case aggExpr: AggregateExpression =>
113-
// TODO: Is it possible that the child of a agg function is another
114-
// agg function?
115-
aggExpr.aggregateFunction.children.foreach {
116-
// This is just a sanity check, our analysis rule PullOutNondeterministic should
117-
// already pull out those nondeterministic expressions and evaluate them in
118-
// a Project node.
119-
case child if !child.deterministic =>
113+
aggExpr.aggregateFunction.children.foreach { child =>
114+
child.foreach {
115+
case agg: AggregateExpression =>
116+
failAnalysis(
117+
s"It is not allowed to use an aggregate function in the argument of " +
118+
s"another aggregate function. Please use the inner aggregate function " +
119+
s"in a sub-query.")
120+
case other => // OK
121+
}
122+
123+
if (!child.deterministic) {
120124
failAnalysis(
121125
s"nondeterministic expression ${expr.prettyString} should not " +
122126
s"appear in the arguments of an aggregate function.")
123-
case child => // OK
127+
}
124128
}
125129
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
126130
failAnalysis(
@@ -133,19 +137,33 @@ trait CheckAnalysis {
133137
case e => e.children.foreach(checkValidAggregateExpression)
134138
}
135139

140+
def checkSupportedGroupingDataType(
141+
expressionString: String,
142+
dataType: DataType): Unit = dataType match {
143+
case BinaryType =>
144+
failAnalysis(s"expression $expressionString cannot be used in " +
145+
s"grouping expression because it is in binary type or its inner field is " +
146+
s"in binary type")
147+
case a: ArrayType =>
148+
failAnalysis(s"expression $expressionString cannot be used in " +
149+
s"grouping expression because it is in array type or its inner field is " +
150+
s"in array type")
151+
case m: MapType =>
152+
failAnalysis(s"expression $expressionString cannot be used in " +
153+
s"grouping expression because it is in map type or its inner field is " +
154+
s"in map type")
155+
case s: StructType =>
156+
s.fields.foreach { f =>
157+
checkSupportedGroupingDataType(expressionString, f.dataType)
158+
}
159+
case udt: UserDefinedType[_] =>
160+
checkSupportedGroupingDataType(expressionString, udt.sqlType)
161+
case _ => // OK
162+
}
163+
136164
def checkValidGroupingExprs(expr: Expression): Unit = {
137-
expr.dataType match {
138-
case BinaryType =>
139-
failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " +
140-
"in grouping expression")
141-
case a: ArrayType =>
142-
failAnalysis(s"array type expression ${expr.prettyString} cannot be used " +
143-
"in grouping expression")
144-
case m: MapType =>
145-
failAnalysis(s"map type expression ${expr.prettyString} cannot be used " +
146-
"in grouping expression")
147-
case _ => // OK
148-
}
165+
checkSupportedGroupingDataType(expr.prettyString, expr.dataType)
166+
149167
if (!expr.deterministic) {
150168
// This is just a sanity check, our analysis rule PullOutNondeterministic should
151169
// already pull out those nondeterministic expressions and evaluate them in

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ case class Average(child: Expression) extends DeclarativeAggregate {
3434
// Return data type.
3535
override def dataType: DataType = resultType
3636

37-
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
37+
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
3838

3939
override def checkInputDataTypes(): TypeCheckResult =
4040
TypeUtils.checkForNumericExpr(child.dataType, "function average")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
5757

5858
override def dataType: DataType = DoubleType
5959

60-
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
60+
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
6161

6262
override def checkInputDataTypes(): TypeCheckResult =
6363
TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
5050

5151
override def dataType: DataType = resultType
5252

53-
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
53+
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
5454

5555
override def checkInputDataTypes(): TypeCheckResult =
5656
TypeUtils.checkForNumericExpr(child.dataType, "function stddev")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
3232
override def dataType: DataType = resultType
3333

3434
override def inputTypes: Seq[AbstractDataType] =
35-
Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
35+
Seq(TypeCollection(LongType, DoubleType, DecimalType))
3636

3737
override def checkInputDataTypes(): TypeCheckResult =
3838
TypeUtils.checkForNumericExpr(child.dataType, "function sum")

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala

Lines changed: 102 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,59 @@ import org.apache.spark.sql.catalyst.plans.logical._
2323
import org.apache.spark.sql.catalyst.plans.Inner
2424
import org.apache.spark.sql.catalyst.dsl.expressions._
2525
import org.apache.spark.sql.catalyst.dsl.plans._
26+
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
2627
import org.apache.spark.sql.types._
2728

29+
import scala.beans.{BeanProperty, BeanInfo}
30+
31+
@BeanInfo
32+
private[sql] case class GroupableData(@BeanProperty data: Int)
33+
34+
private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
35+
36+
override def sqlType: DataType = IntegerType
37+
38+
override def serialize(obj: Any): Int = {
39+
obj match {
40+
case groupableData: GroupableData => groupableData.data
41+
}
42+
}
43+
44+
override def deserialize(datum: Any): GroupableData = {
45+
datum match {
46+
case data: Int => GroupableData(data)
47+
}
48+
}
49+
50+
override def userClass: Class[GroupableData] = classOf[GroupableData]
51+
52+
private[spark] override def asNullable: GroupableUDT = this
53+
}
54+
55+
@BeanInfo
56+
private[sql] case class UngroupableData(@BeanProperty data: Array[Int])
57+
58+
private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {
59+
60+
override def sqlType: DataType = ArrayType(IntegerType)
61+
62+
override def serialize(obj: Any): ArrayData = {
63+
obj match {
64+
case groupableData: UngroupableData => new GenericArrayData(groupableData.data)
65+
}
66+
}
67+
68+
override def deserialize(datum: Any): UngroupableData = {
69+
datum match {
70+
case data: Array[Int] => UngroupableData(data)
71+
}
72+
}
73+
74+
override def userClass: Class[UngroupableData] = classOf[UngroupableData]
75+
76+
private[spark] override def asNullable: UngroupableUDT = this
77+
}
78+
2879
case class TestFunction(
2980
children: Seq[Expression],
3081
inputTypes: Seq[AbstractDataType])
@@ -194,39 +245,65 @@ class AnalysisErrorSuite extends AnalysisTest {
194245
assert(error.message.contains("Conflicting attributes"))
195246
}
196247

197-
test("aggregation can't work on binary and map types") {
198-
val plan =
199-
Aggregate(
200-
AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil,
201-
Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
202-
LocalRelation(
203-
AttributeReference("a", BinaryType)(exprId = ExprId(2)),
204-
AttributeReference("b", IntegerType)(exprId = ExprId(1))))
248+
test("check grouping expression data types") {
249+
def checkDataType(dataType: DataType, shouldSuccess: Boolean): Unit = {
250+
val plan =
251+
Aggregate(
252+
AttributeReference("a", dataType)(exprId = ExprId(2)) :: Nil,
253+
Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
254+
LocalRelation(
255+
AttributeReference("a", dataType)(exprId = ExprId(2)),
256+
AttributeReference("b", IntegerType)(exprId = ExprId(1))))
257+
258+
shouldSuccess match {
259+
case true =>
260+
assertAnalysisSuccess(plan, true)
261+
case false =>
262+
assertAnalysisError(plan, "expression a cannot be used in grouping expression" :: Nil)
263+
}
205264

206-
assertAnalysisError(plan,
207-
"binary type expression a cannot be used in grouping expression" :: Nil)
265+
}
208266

209-
val plan2 =
210-
Aggregate(
211-
AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil,
212-
Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
213-
LocalRelation(
214-
AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
215-
AttributeReference("b", IntegerType)(exprId = ExprId(1))))
267+
val supportedDataTypes = Seq(
268+
StringType,
269+
NullType, BooleanType,
270+
ByteType, ShortType, IntegerType, LongType,
271+
FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
272+
DateType, TimestampType,
273+
new StructType()
274+
.add("f1", FloatType, nullable = true)
275+
.add("f2", StringType, nullable = true),
276+
new GroupableUDT())
277+
supportedDataTypes.foreach { dataType =>
278+
checkDataType(dataType, shouldSuccess = true)
279+
}
216280

217-
assertAnalysisError(plan2,
218-
"map type expression a cannot be used in grouping expression" :: Nil)
281+
val unsupportedDataTypes = Seq(
282+
BinaryType,
283+
ArrayType(IntegerType),
284+
MapType(StringType, LongType),
285+
new StructType()
286+
.add("f1", FloatType, nullable = true)
287+
.add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
288+
new UngroupableUDT())
289+
unsupportedDataTypes.foreach { dataType =>
290+
checkDataType(dataType, shouldSuccess = false)
291+
}
292+
}
219293

220-
val plan3 =
294+
test("we should fail analysis when we find nested aggregate functions") {
295+
val plan =
221296
Aggregate(
222-
AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)) :: Nil,
223-
Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
297+
AttributeReference("a", IntegerType)(exprId = ExprId(2)) :: Nil,
298+
Alias(sum(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1)))), "c")() :: Nil,
224299
LocalRelation(
225-
AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)),
300+
AttributeReference("a", IntegerType)(exprId = ExprId(2)),
226301
AttributeReference("b", IntegerType)(exprId = ExprId(1))))
227302

228-
assertAnalysisError(plan3,
229-
"array type expression a cannot be used in grouping expression" :: Nil)
303+
assertAnalysisError(
304+
plan,
305+
"It is not allowed to use an aggregate function in the argument of " +
306+
"another aggregate function." :: Nil)
230307
}
231308

232309
test("Join can't work on binary and map types") {

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -474,6 +474,7 @@ private[spark] object SQLConf {
474474
object Deprecated {
475475
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
476476
val EXTERNAL_SORT = "spark.sql.planner.externalSort"
477+
val USE_SQL_AGGREGATE2 = "spark.sql.useAggregate2"
477478
}
478479
}
479480

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -44,14 +44,14 @@ case class Exchange(
4444
override def nodeName: String = {
4545
val extraInfo = coordinator match {
4646
case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated =>
47-
"Shuffle"
47+
s"(coordinator id: ${System.identityHashCode(coordinator)})"
4848
case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated =>
49-
"May shuffle"
50-
case None => "Shuffle without coordinator"
49+
s"(coordinator id: ${System.identityHashCode(coordinator)})"
50+
case None => ""
5151
}
5252

5353
val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange"
54-
s"$simpleNodeName($extraInfo)"
54+
s"${simpleNodeName}${extraInfo}"
5555
}
5656

5757
/**

sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,16 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
111111
}
112112
(keyValueOutput, runFunc)
113113

114+
case Some((SQLConf.Deprecated.USE_SQL_AGGREGATE2, Some(value))) =>
115+
val runFunc = (sqlContext: SQLContext) => {
116+
logWarning(
117+
s"Property ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} is deprecated and " +
118+
s"will be ignored. ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} will " +
119+
s"continue to be true.")
120+
Seq(Row(SQLConf.Deprecated.USE_SQL_AGGREGATE2, "true"))
121+
}
122+
(keyValueOutput, runFunc)
123+
114124
// Configures a single property.
115125
case Some((key, Some(value))) =>
116126
val runFunc = (sqlContext: SQLContext) => {

0 commit comments

Comments
 (0)