@@ -21,9 +21,13 @@ import scala.util.control.NonFatal
2121import scala .reflect .runtime .universe .TypeTag
2222
2323import org .apache .spark .SparkFunSuite
24+
25+ import org .apache .spark .sql .catalyst .analysis .UnresolvedAttribute
26+ import org .apache .spark .sql .catalyst .expressions .BoundReference
27+ import org .apache .spark .sql .catalyst .util ._
28+
2429import org .apache .spark .sql .test .TestSQLContext
2530import org .apache .spark .sql .{Row , DataFrame }
26- import org .apache .spark .sql .catalyst .util ._
2731
2832/**
2933 * Base class for writing tests for individual physical operators. For an example of how this
@@ -48,6 +52,24 @@ class SparkPlanTest extends SparkFunSuite {
4852 }
4953 }
5054
/**
 * Runs the physical plan produced by `planFunction` over `input` and verifies that the
 * collected result matches the expected answer, failing the test with a descriptive
 * message otherwise.
 *
 * @param input the input data to be used.
 * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
 *                     the physical operator that's being tested.
 * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
 */
protected def checkAnswer[A <: Product : TypeTag](
    input: DataFrame,
    planFunction: SparkPlan => SparkPlan,
    expectedAnswer: Seq[A]): Unit = {
  // Convert the expected tuples/case classes into Rows, then delegate to the shared
  // Row-based checker. It returns Some(errorMessage) on mismatch and None on success,
  // so surfacing a failure is just a foreach over the Option.
  val expectedRows = expectedAnswer.map(Row.fromTuple)
  SparkPlanTest.checkAnswer(input, planFunction, expectedRows).foreach(fail)
}
5173 /**
5274 * Runs the plan and makes sure the answer matches the expected result.
5375 * @param input the input data to be used.
@@ -87,6 +109,23 @@ object SparkPlanTest {
87109
88110 val outputPlan = planFunction(input.queryExecution.sparkPlan)
89111
// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
val resolvedPlan = outputPlan transform {
  case plan: SparkPlan =>
    // Bind each attribute name exposed by the children's output to its ordinal position,
    // so unresolved references below can be rewritten into positional BoundReferences.
    val inputMap = plan.children.flatMap(_.output).zipWithIndex.map {
      case (attr, ordinal) => attr.name -> BoundReference(ordinal, attr.dataType, attr.nullable)
    }.toMap

    plan.transformExpressions {
      case UnresolvedAttribute(Seq(name)) =>
        // getOrElse's default is by-name, so sys.error only fires on a genuine miss.
        inputMap.getOrElse(name,
          sys.error(s"Invalid Test: Cannot resolve $name given input $inputMap"))
    }
}
128+
90129 def prepareAnswer (answer : Seq [Row ]): Seq [Row ] = {
91130 // Converts data to types that we can do equality comparison using Scala collections.
92131 // For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -105,7 +144,7 @@ object SparkPlanTest {
105144 }
106145
107146 val sparkAnswer : Seq [Row ] = try {
108- outputPlan .executeCollect().toSeq
147+ resolvedPlan .executeCollect().toSeq
109148 } catch {
110149 case NonFatal (e) =>
111150 val errorMessage =
0 commit comments