|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.execution |
19 | 19 |
|
20 | | -import org.scalatest.Matchers |
21 | | - |
22 | | -import org.apache.spark.SparkFunSuite |
| 20 | +import org.apache.spark.sql.test.TestSQLContext |
| 21 | +import org.scalatest.BeforeAndAfterAll |
23 | 22 |
|
24 | | -import org.apache.spark.sql.{SQLConf, SQLContext, Row} |
| 23 | +import org.apache.spark.sql.{SQLConf, Row} |
25 | 24 | import org.apache.spark.sql.catalyst.CatalystTypeConverters |
26 | 25 | import org.apache.spark.sql.catalyst.expressions._ |
27 | 26 | import org.apache.spark.sql.types._ |
28 | | -import org.apache.spark.sql.test.TestSQLContext |
29 | 27 |
|
30 | | -class UnsafeExternalSortSuite extends SparkFunSuite with Matchers { |
| 28 | +import scala.util.Random |
| 29 | + |
/**
 * Test suite for the [[UnsafeExternalSort]] physical operator.
 *
 * Extends `SparkPlanTest` so correctness is verified via `checkAnswer`, which
 * compares the operator's output against the given expected rows, and mixes in
 * `BeforeAndAfterAll` to toggle codegen around the whole suite.
 */
class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {

  // UnsafeExternalSort requires codegen; force it on for every test in this suite.
  override def beforeAll(): Unit = {
    TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
  }

  // Restore the default so other suites sharing TestSQLContext are unaffected.
  override def afterAll(): Unit = {
    TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
  }

  // NOTE(review): no longer referenced after the tuple-based rewrite of
  // "basic sorting" — candidate for removal once confirmed unused elsewhere.
  private def createRow(values: Any*): Row = {
    new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray)
  }

  test("basic sorting") {

    // Expected output: deliberately listed already sorted by
    // (word ASC, number DESC) — note "Hello" 9 before "Hello" 7.
    val inputData = Seq(
      ("Hello", 9),
      ("World", 4),
      ("Hello", 7),
      ("Skinny", 0),
      ("Constantinople", 9)
    )

    // Sort key: column 0 (String) ascending, then column 1 (Int) descending.
    val sortOrder: Seq[SortOrder] = Seq(
      SortOrder(BoundReference(0, StringType, nullable = false), Ascending),
      SortOrder(BoundReference(1, IntegerType, nullable = false), Descending))

    // Shuffle the input so the operator must actually sort, then check that
    // UnsafeExternalSort (as a local, non-global sort) recovers the expected order.
    checkAnswer(
      Random.shuffle(inputData),
      (input: SparkPlan) => new UnsafeExternalSort(sortOrder, global = false, input),
      inputData
    )
  }
}