Skip to content

Commit c66874a

Browse files
committed
add more assertions in tests
1 parent f4fd12e commit c66874a

File tree

2 files changed

+39
-11
lines changed

2 files changed

+39
-11
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ trait AliasAwareOutputOrdering extends AliasAwareOutputExpression {
6565

6666
final override def outputOrdering: Seq[SortOrder] = {
6767
if (hasAlias) {
68-
orderingExpressions.map { sortOrder =>
69-
normalizeExpression(sortOrder).asInstanceOf[SortOrder]
70-
}
68+
orderingExpressions.map(normalizeExpression(_).asInstanceOf[SortOrder])
7169
} else {
7270
orderingExpressions
7371
}

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,14 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
914914
""".stripMargin).queryExecution.executedPlan
915915
val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
916916
assert(exchanges.size == 3)
917+
918+
val projects = planned.collect { case p: ProjectExec => p }
919+
assert(projects.exists(_.outputPartitioning match {
920+
case PartitioningCollection(Seq(HashPartitioning(Seq(k1: AttributeReference), _),
921+
HashPartitioning(Seq(k2: AttributeReference), _))) if k1.name == "t1id" =>
922+
true
923+
case _ => false
924+
}))
917925
}
918926
}
919927
}
@@ -959,7 +967,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
959967

960968
val projects = planned.collect { case p: ProjectExec => p }
961969
assert(projects.exists(_.outputPartitioning match {
962-
case RangePartitioning(Seq(_@SortOrder(ar: AttributeReference, _, _, _)), _) =>
970+
case RangePartitioning(Seq(SortOrder(ar: AttributeReference, _, _, _)), _) =>
963971
ar.name == "id1"
964972
case _ => false
965973
}))
@@ -970,23 +978,37 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
970978
"for partitioning and sortorder involving complex expressions") {
971979
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
972980
withTempView("t1", "t2", "t3") {
973-
spark.range(10).createTempView("t1")
974-
spark.range(20).createTempView("t2")
975-
spark.range(30).createTempView("t3")
981+
spark.range(10).select(col("id").as("id1")).createTempView("t1")
982+
spark.range(20).select(col("id").as("id2")).createTempView("t2")
983+
spark.range(30).select(col("id").as("id3")).createTempView("t3")
976984
val planned = sql(
977985
"""
978-
|SELECT t3.id as t3id
986+
|SELECT t3.id3 as t3id
979987
|FROM (
980-
| SELECT t1.id as t1id
988+
| SELECT t1.id1 as t1id, t2.id2 as t2id
981989
| FROM t1, t2
982-
| WHERE t1.id % 10 = t2.id % 10
990+
| WHERE t1.id1 * 10 = t2.id2 * 10
983991
|) t12, t3
984-
|WHERE t1id % 10 = t3.id % 10
992+
|WHERE t1id * 10 = t3.id3 * 10
985993
""".stripMargin).queryExecution.executedPlan
986994
val sortNodes = planned.collect { case s: SortExec => s }
987995
assert(sortNodes.size == 3)
988996
val exchangeNodes = planned.collect { case e: ShuffleExchangeExec => e }
989997
assert(exchangeNodes.size == 3)
998+
999+
val projects = planned.collect { case p: ProjectExec => p }
1000+
assert(projects.exists(_.outputPartitioning match {
1001+
case PartitioningCollection(Seq(HashPartitioning(Seq(Multiply(ar1, _, _)), _),
1002+
HashPartitioning(Seq(Multiply(ar2, _, _)), _))) =>
1003+
Seq(ar1, ar2) match {
1004+
case Seq(ar1: AttributeReference, ar2: AttributeReference) =>
1005+
ar1.name == "t1id" && ar2.name == "id2"
1006+
case _ =>
1007+
false
1008+
}
1009+
case _ => false
1010+
}))
1011+
9901012
}
9911013
}
9921014
}
@@ -1025,6 +1047,14 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
10251047
""".stripMargin).queryExecution.executedPlan
10261048
val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
10271049
assert(exchanges.size == 2)
1050+
1051+
val projects = planned.collect { case p: ProjectExec => p }
1052+
assert(projects.exists(_.outputPartitioning match {
1053+
case PartitioningCollection(Seq(HashPartitioning(Seq(k1: AttributeReference), _),
1054+
HashPartitioning(Seq(k2: AttributeReference), _))) =>
1055+
k1.name == "t1id" && k2.name == "t2id"
1056+
case _ => false
1057+
}))
10281058
}
10291059
}
10301060
}

0 commit comments

Comments
 (0)