
Commit a18be55

add ut

1 parent fa61cda commit a18be55

File tree

1 file changed: +181 −0 lines changed
Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}

class QueryFragmentSuite extends QueryTest with SQLTestUtils with SharedSQLContext {
  import testImplicits._

  setupTestData()

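  // Inner join: each side aggregates 100000 rows into 50 groups of 2000 rows, so both
  // aggregated outputs fall well below the (raised) broadcast threshold and the planned
  // sort merge join should be switched to a broadcast join at runtime.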
test("adaptive optimization: transform sort merge join to broadcast join for inner join") {
31+
withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true",
32+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") {
33+
val numInputPartitions: Int = 2
34+
val df1 = sqlContext.range(0, 100000, 1, numInputPartitions)
35+
.selectExpr("id % 50 as key1", "id as value1")
36+
.groupBy("key1")
37+
.agg($"key1", count("value1") as "cnt1")
38+
val df2 = sqlContext.range(0, 100000, 1, numInputPartitions)
39+
.selectExpr("id % 50 as key2", "id as value2")
40+
.groupBy("key2")
41+
.agg($"key2", count("value2") as "cnt2")
42+
val join1 = df1.join(df2, col("key1") === col("key2"))
43+
.select(col("key1"), col("cnt1"), col("cnt2"))
44+
checkAnswer(join1,
45+
sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "2000 as cnt2").collect())
46+
47+
val df3 = sqlContext.range(0, 100000, 1, numInputPartitions)
48+
.selectExpr("id as key3", "id as value3")
49+
.groupBy("key3")
50+
.agg($"key3", count("value3") as "cnt3")
51+
val join2 = df3.join(df1, col("key3") === col("key1"))
52+
.select(col("key1"), col("cnt1"), col("cnt3"))
53+
checkAnswer(join2,
54+
sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "1 as cnt3").collect())
55+
}
56+
}
57+
58+
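  // Outer joins: same aggregated inputs as above. Only the non-null-producing side of a
  // left/right outer join can be broadcast, so the runtime rewrite must respect the join
  // side; join3 additionally checks that the 99950 unmatched df3 keys survive with nulls
  // for key1 and cnt1.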
test("adaptive optimization: transform sort merge join to broadcast join for outer join") {
59+
withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true",
60+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") {
61+
val numInputPartitions: Int = 2
62+
val df1 = sqlContext.range(0, 100000, 1, numInputPartitions)
63+
.selectExpr("id % 50 as key1", "id as value1")
64+
.groupBy("key1")
65+
.agg($"key1", count("value1") as "cnt1")
66+
val df2 = sqlContext.range(0, 100000, 1, numInputPartitions)
67+
.selectExpr("id % 50 as key2", "id as value2")
68+
.groupBy("key2")
69+
.agg($"key2", count("value2") as "cnt2")
70+
val join1 = df1.join(df2, col("key1") === col("key2"), "left_outer")
71+
.select(col("key1"), col("cnt1"), col("cnt2"))
72+
checkAnswer(join1,
73+
sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "2000 as cnt2").collect())
74+
75+
val join2 = df1.join(df2, col("key1") === col("key2"), "right_outer")
76+
.select(col("key1"), col("cnt1"), col("cnt2"))
77+
checkAnswer(join2,
78+
sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "2000 as cnt2").collect())
79+
80+
val df3 = sqlContext.range(0, 100000, 1, numInputPartitions)
81+
.selectExpr("id as key3", "id as value3")
82+
.groupBy("key3")
83+
.agg($"key3", count("value3") as "cnt3")
84+
val join3 = df3.join(df1, col("key3") === col("key1"), "left_outer")
85+
.select(col("key1"), col("cnt1"), col("cnt3"))
86+
checkAnswer(join3,
87+
sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "1 as cnt3")
88+
.union(sqlContext.range(0, 99950).selectExpr("null as key", "null as cnt1", "1 as cnt3"))
89+
.collect())
90+
91+
val join4 = df3.join(df1, col("key3") === col("key1"), "right_outer")
92+
.select(col("key1"), col("cnt1"), col("cnt3"))
93+
checkAnswer(join4,
94+
sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1", "1 as cnt3").collect())
95+
}
96+
}
97+
98+
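  // Left semi join: the rewrite must preserve semi join semantics, returning each
  // matching left-side row exactly once and no columns from the right side.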
test("adaptive optimization: transform sort merge join to broadcast join for left semi join") {
99+
withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true",
100+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") {
101+
val numInputPartitions: Int = 2
102+
val df1 = sqlContext.range(0, 100000, 1, numInputPartitions)
103+
.selectExpr("id % 50 as key1", "id as value1")
104+
.groupBy("key1")
105+
.agg($"key1", count("value1") as "cnt1")
106+
val df2 = sqlContext.range(0, 100000, 1, numInputPartitions)
107+
.selectExpr("id % 50 as key2", "id as value2")
108+
.groupBy("key2")
109+
.agg($"key2", count("value2") as "cnt2")
110+
val join1 = df1.join(df2, col("key1") === col("key2"), "leftsemi")
111+
.select(col("key1"), col("cnt1"))
112+
113+
checkAnswer(join1,
114+
sqlContext.range(0, 50).selectExpr("id as key", "2000 as cnt1").collect())
115+
116+
val df3 = sqlContext.range(0, 100000, 1, numInputPartitions)
117+
.selectExpr("id as key3", "id as value3")
118+
.groupBy("key3")
119+
.agg($"key3", count("value3") as "cnt3")
120+
val join2 = df3.join(df1, col("key3") === col("key1"), "leftsemi")
121+
.select(col("key3"), col("cnt3"))
122+
123+
checkAnswer(join2,
124+
sqlContext.range(0, 50).selectExpr("id as key3", "1 as cnt3").collect())
125+
}
126+
}
127+
128+
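  // Left anti join: df1 uses id % 100 here, giving keys 0..99 with 1000 rows each,
  // while df2 only has keys 0..49, so the anti join keeps keys 50..99.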
test("adaptive optimization: transform sort merge join to broadcast join for left anti join") {
129+
withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true",
130+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") {
131+
val numInputPartitions: Int = 2
132+
val df1 = sqlContext.range(0, 100000, 1, numInputPartitions)
133+
.selectExpr("id % 100 as key1", "id as value1")
134+
.groupBy("key1")
135+
.agg($"key1", count("value1") as "cnt1")
136+
val df2 = sqlContext.range(0, 100000, 1, numInputPartitions)
137+
.selectExpr("id % 50 as key2", "id as value2")
138+
.groupBy("key2")
139+
.agg($"key2", count("value2") as "cnt2")
140+
val join1 = df1.join(df2, col("key1") === col("key2"), "leftanti")
141+
.select(col("key1"), col("cnt1"))
142+
checkAnswer(join1,
143+
sqlContext.range(50, 100).selectExpr("id as key", "1000 as cnt1").collect())
144+
145+
val df3 = sqlContext.range(0, 100000, 1, numInputPartitions)
146+
.selectExpr("id as key3", "id as value3")
147+
.groupBy("key3")
148+
.agg($"key3", count("value3") as "cnt3")
149+
val join2 = df3.join(df1, col("key3") === col("key1"), "leftanti")
150+
.select(col("key3"), col("cnt3"))
151+
152+
checkAnswer(join2,
153+
sqlContext.range(100, 100000).selectExpr("id as key3", "1 as cnt3").collect())
154+
}
155+
}
156+
157+
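  // Existence join: exercises the rewrite through the SQL path, using IN subqueries
  // over temp tables, which the planner turns into semi/existence joins.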
test("adaptive optimization: transform sort merge join to broadcast join for existence join") {
158+
withSQLConf(SQLConf.ADAPTIVE_EXECUTION2_ENABLED.key -> "true",
159+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") {
160+
val numInputPartitions: Int = 2
161+
sqlContext.range(0, 100000, 1, numInputPartitions)
162+
.selectExpr("id % 50 as key1", "id as value1")
163+
.registerTempTable("testData")
164+
sqlContext.range(0, 100000, 1, numInputPartitions)
165+
.selectExpr("id % 50 as key2", "id as value2")
166+
.registerTempTable("testData2")
167+
val join1 = sqlContext.sql("select key1, cnt1 from " +
168+
"(select key1, count(value1) as cnt1 from testData group by key1) t1 " +
169+
"where key1 in (select distinct key2 from testData2)")
170+
checkAnswer(join1,
171+
sqlContext.range(0, 50).selectExpr("id as key1", "2000 as cnt1").collect())
172+
sqlContext.range(0, 100000, 1, numInputPartitions)
173+
.selectExpr("id as key3", "id as value3")
174+
.registerTempTable("testData3")
175+
val join2 = sqlContext.sql("select key3, value3 from testData3 " +
176+
"where key3 in (select distinct key2 from testData2)")
177+
checkAnswer(join2,
178+
sqlContext.range(0, 50).selectExpr("id as key3", "id as value3").collect())
179+
}
180+
}
181+
}
