|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.execution |
19 | 19 |
|
20 | | -import org.apache.spark.sql.{DataFrame, QueryTest} |
| 20 | +import org.apache.spark.sql.{DataFrame, QueryTest, Row} |
21 | 21 | import org.apache.spark.sql.internal.SQLConf |
22 | 22 | import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} |
| 23 | +import org.apache.spark.sql.types.StructType |
23 | 24 | import org.apache.spark.util.Utils |
24 | 25 |
|
25 | 26 | class RemoveRedundantProjectsSuite extends QueryTest with SharedSparkSession with SQLTestUtils { |
| 27 | + import testImplicits._ |
26 | 28 |
|
27 | | - private def assertProjectExecCount(df: DataFrame, expected: Integer): Unit = { |
| 29 | + private def assertProjectExecCount(df: DataFrame, expected: Int): Unit = { |
28 | 30 | withClue(df.queryExecution) { |
29 | 31 | val plan = df.queryExecution.executedPlan |
30 | 32 | val actual = plan.collectWithSubqueries { case p: ProjectExec => p }.size |
31 | 33 | assert(actual == expected) |
32 | 34 | } |
33 | 35 | } |
34 | 36 |
|
35 | | - private def assertProjectExec(query: String, enabled: Integer, disabled: Integer): Unit = { |
| 37 | + private def assertProjectExec(query: String, enabled: Int, disabled: Int): Unit = { |
36 | 38 | val df = sql(query) |
37 | 39 | assertProjectExecCount(df, enabled) |
38 | 40 | val result = df.collect() |
@@ -120,9 +122,13 @@ class RemoveRedundantProjectsSuite extends QueryTest with SharedSparkSession wit |
120 | 122 | } |
121 | 123 |
|
122 | 124 | test("subquery") { |
123 | | - testData |
124 | | - val query = "select key, value from testData where key in " + |
125 | | - "(select sum(a) from testView where a > 5 group by key)" |
126 | | - assertProjectExec(query, 0, 1) |
| 125 | + withTempView("testData") { |
| 126 | + val data = spark.sparkContext.parallelize((1 to 100).map(i => Row(i, i.toString))) |
| 127 | + val schema = new StructType().add("key", "int").add("value", "string") |
| 128 | + spark.createDataFrame(data, schema).createOrReplaceTempView("testData") |
| 129 | + val query = "select key, value from testData where key in " + |
| 130 | + "(select sum(a) from testView where a > 5 group by key)" |
| 131 | + assertProjectExec(query, 0, 1) |
| 132 | + } |
127 | 133 | } |
128 | 134 | } |
0 commit comments