@@ -914,6 +914,14 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
914914 """ .stripMargin).queryExecution.executedPlan
915915 val exchanges = planned.collect { case s : ShuffleExchangeExec => s }
916916 assert(exchanges.size == 3 )
917+
918+ val projects = planned.collect { case p : ProjectExec => p }
919+ assert(projects.exists(_.outputPartitioning match {
920+ case PartitioningCollection (Seq (HashPartitioning (Seq (k1 : AttributeReference ), _),
921+ HashPartitioning (Seq (k2 : AttributeReference ), _))) if k1.name == " t1id" =>
922+ true
923+ case _ => false
924+ }))
917925 }
918926 }
919927 }
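
For context (a sketch, not part of the commit): the new assertion expects the ProjectExec above the join to report a PartitioningCollection holding one HashPartitioning per equivalent key, so downstream operators can reuse either key without an extra shuffle. A minimal, self-contained illustration of that match, assuming a Spark 3.x catalyst classpath; the object name, attribute names, and numPartitions = 5 are arbitrary:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.types.LongType

object HashPartitioningMatchSketch extends App {
  // Two attributes standing in for the original join key and its alias.
  val t1id = AttributeReference("t1id", LongType)()
  val t2id = AttributeReference("t2id", LongType)()

  // The shape the test expects on a ProjectExec after the join.
  val partitioning = PartitioningCollection(
    Seq(HashPartitioning(Seq(t1id), 5), HashPartitioning(Seq(t2id), 5)))

  // The same match as the assertion above: both hash keys are single attributes.
  val matched = partitioning match {
    case PartitioningCollection(Seq(HashPartitioning(Seq(k1: AttributeReference), _),
      HashPartitioning(Seq(k2: AttributeReference), _))) =>
      k1.name == "t1id" && k2.name == "t2id"
    case _ => false
  }
  assert(matched)
}
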
@@ -959,7 +967,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
         val projects = planned.collect { case p: ProjectExec => p }
         assert(projects.exists(_.outputPartitioning match {
-          case RangePartitioning(Seq(_ @ SortOrder(ar: AttributeReference, _, _, _)), _) =>
+          case RangePartitioning(Seq(SortOrder(ar: AttributeReference, _, _, _)), _) =>
             ar.name == "id1"
           case _ => false
         }))
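
For context (a sketch, not part of the commit): the `_ @` binder bound nothing, so the pattern is simplified without changing what it matches. Under the same Spark 3.x assumption, the companion SortOrder.apply(child, direction) fills in the default null ordering, and SortOrder's four case-class fields are matched positionally, so the simplified pattern below is equivalent; the object name and numPartitions = 5 are arbitrary:

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.RangePartitioning
import org.apache.spark.sql.types.LongType

object RangePartitioningMatchSketch extends App {
  val id1 = AttributeReference("id1", LongType)()

  // RangePartitioning as produced by a sort on id1.
  val partitioning = RangePartitioning(Seq(SortOrder(id1, Ascending)), 5)

  // The simplified pattern: all four SortOrder fields matched positionally.
  val matched = partitioning match {
    case RangePartitioning(Seq(SortOrder(ar: AttributeReference, _, _, _)), _) =>
      ar.name == "id1"
    case _ => false
  }
  assert(matched)
}
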
@@ -970,23 +978,37 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
970978 " for partitioning and sortorder involving complex expressions" ) {
971979 withSQLConf(SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key -> " -1" ) {
972980 withTempView(" t1" , " t2" , " t3" ) {
973- spark.range(10 ).createTempView(" t1" )
974- spark.range(20 ).createTempView(" t2" )
975- spark.range(30 ).createTempView(" t3" )
981+ spark.range(10 ).select(col( " id " ).as( " id1 " )). createTempView(" t1" )
982+ spark.range(20 ).select(col( " id " ).as( " id2 " )). createTempView(" t2" )
983+ spark.range(30 ).select(col( " id " ).as( " id3 " )). createTempView(" t3" )
976984 val planned = sql(
977985 """
978- |SELECT t3.id as t3id
986+ |SELECT t3.id3 as t3id
979987 |FROM (
980- | SELECT t1.id as t1id
988+ | SELECT t1.id1 as t1id, t2.id2 as t2id
981989 | FROM t1, t2
982- | WHERE t1.id % 10 = t2.id % 10
990+ | WHERE t1.id1 * 10 = t2.id2 * 10
983991 |) t12, t3
984- |WHERE t1id % 10 = t3.id % 10
992+ |WHERE t1id * 10 = t3.id3 * 10
985993 """ .stripMargin).queryExecution.executedPlan
986994 val sortNodes = planned.collect { case s : SortExec => s }
987995 assert(sortNodes.size == 3 )
988996 val exchangeNodes = planned.collect { case e : ShuffleExchangeExec => e }
989997 assert(exchangeNodes.size == 3 )
998+
999+ val projects = planned.collect { case p : ProjectExec => p }
1000+ assert(projects.exists(_.outputPartitioning match {
1001+ case PartitioningCollection (Seq (HashPartitioning (Seq (Multiply (ar1, _, _)), _),
1002+ HashPartitioning (Seq (Multiply (ar2, _, _)), _))) =>
1003+ Seq (ar1, ar2) match {
1004+ case Seq (ar1 : AttributeReference , ar2 : AttributeReference ) =>
1005+ ar1.name == " t1id" && ar2.name == " id2"
1006+ case _ =>
1007+ false
1008+ }
1009+ case _ => false
1010+ }))
1011+
9901012 }
9911013 }
9921014 }
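
For context (a sketch, not part of the commit): this hunk covers the harder case, where the hash keys are complex expressions (key * 10) rather than bare attributes, so the match has to destructure a Multiply node before it can inspect the attribute underneath. A minimal illustration, assuming Spark 3.1+ (where Multiply carries a third field alongside left and right); the object name and numPartitions = 5 are arbitrary:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, Multiply}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.types.LongType

object ComplexExprPartitioningSketch extends App {
  val t1id = AttributeReference("t1id", LongType)()
  val id2 = AttributeReference("id2", LongType)()

  // Hash keys are expressions (key * 10), mirroring the join condition in the test.
  val partitioning = PartitioningCollection(Seq(
    HashPartitioning(Seq(Multiply(t1id, Literal(10L))), 5),
    HashPartitioning(Seq(Multiply(id2, Literal(10L))), 5)))

  // Destructure the Multiply nodes to reach the attributes underneath.
  val matched = partitioning match {
    case PartitioningCollection(Seq(HashPartitioning(Seq(Multiply(ar1, _, _)), _),
      HashPartitioning(Seq(Multiply(ar2, _, _)), _))) =>
      Seq(ar1, ar2) match {
        case Seq(a1: AttributeReference, a2: AttributeReference) =>
          a1.name == "t1id" && a2.name == "id2"
        case _ => false
      }
    case _ => false
  }
  assert(matched)
}
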
@@ -1025,6 +1047,14 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
10251047 """ .stripMargin).queryExecution.executedPlan
10261048 val exchanges = planned.collect { case s : ShuffleExchangeExec => s }
10271049 assert(exchanges.size == 2 )
1050+
1051+ val projects = planned.collect { case p : ProjectExec => p }
1052+ assert(projects.exists(_.outputPartitioning match {
1053+ case PartitioningCollection (Seq (HashPartitioning (Seq (k1 : AttributeReference ), _),
1054+ HashPartitioning (Seq (k2 : AttributeReference ), _))) =>
1055+ k1.name == " t1id" && k2.name == " t2id"
1056+ case _ => false
1057+ }))
10281058 }
10291059 }
10301060 }