Skip to content

Commit 7eafecf

Browse files
committed
Port test to SparkPlanTest
1 parent d468a88 commit 7eafecf

File tree

3 files changed

+29
-80
lines changed

3 files changed

+29
-80
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ case class UnsafeExternalSort(
269269
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
270270

271271
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
272-
assert (codegenEnabled)
272+
assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled")
273273
def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
274274
val ordering = newOrdering(sortOrder, child.output)
275275
val prefixComparator = new PrefixComparator {

sql/core/src/test/scala/org/apache/spark/sql/UnsafeSortMergeJoinSuite.scala

Lines changed: 0 additions & 52 deletions
This file was deleted.

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,47 +17,48 @@
1717

1818
package org.apache.spark.sql.execution
1919

20-
import org.scalatest.Matchers
21-
22-
import org.apache.spark.SparkFunSuite
20+
import org.apache.spark.sql.test.TestSQLContext
21+
import org.scalatest.BeforeAndAfterAll
2322

24-
import org.apache.spark.sql.{SQLConf, SQLContext, Row}
23+
import org.apache.spark.sql.{SQLConf, Row}
2524
import org.apache.spark.sql.catalyst.CatalystTypeConverters
2625
import org.apache.spark.sql.catalyst.expressions._
2726
import org.apache.spark.sql.types._
28-
import org.apache.spark.sql.test.TestSQLContext
2927

30-
class UnsafeExternalSortSuite extends SparkFunSuite with Matchers {
28+
import scala.util.Random
29+
30+
class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
31+
32+
override def beforeAll(): Unit = {
33+
TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
34+
}
35+
36+
override def afterAll(): Unit = {
37+
TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
38+
}
3139

3240
private def createRow(values: Any*): Row = {
3341
new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray)
3442
}
3543

3644
test("basic sorting") {
37-
val sc = TestSQLContext.sparkContext
38-
val sqlContext = new SQLContext(sc)
39-
sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, "true")
4045

41-
val schema: StructType = StructType(
42-
StructField("word", StringType, nullable = false) ::
43-
StructField("number", IntegerType, nullable = false) :: Nil)
46+
val inputData = Seq(
47+
("Hello", 9),
48+
("World", 4),
49+
("Hello", 7),
50+
("Skinny", 0),
51+
("Constantinople", 9)
52+
)
53+
4454
val sortOrder: Seq[SortOrder] = Seq(
4555
SortOrder(BoundReference(0, StringType, nullable = false), Ascending),
4656
SortOrder(BoundReference(1, IntegerType, nullable = false), Descending))
47-
val rowsToSort: Seq[Row] = Seq(
48-
createRow("Hello", 9),
49-
createRow("World", 4),
50-
createRow("Hello", 7),
51-
createRow("Skinny", 0),
52-
createRow("Constantinople", 9))
53-
SparkPlan.currentContext.set(sqlContext)
54-
val input =
55-
new PhysicalRDD(schema.toAttributes.map(_.toAttribute), sc.parallelize(rowsToSort, 1))
56-
// Treat the existing sort operators as the source-of-truth for this test
57-
val defaultSorted = new Sort(sortOrder, global = false, input).executeCollect()
58-
val externalSorted = new ExternalSort(sortOrder, global = false, input).executeCollect()
59-
val unsafeSorted = new UnsafeExternalSort(sortOrder, global = false, input).executeCollect()
60-
assert (defaultSorted === externalSorted)
61-
assert (unsafeSorted === externalSorted)
57+
58+
checkAnswer(
59+
Random.shuffle(inputData),
60+
(input: SparkPlan) => new UnsafeExternalSort(sortOrder, global = false, input),
61+
inputData
62+
)
6263
}
6364
}

0 commit comments

Comments
 (0)