diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 1a9dbcae8c083..522c8ccb3e960 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1317,8 +1317,10 @@ private[spark] object Utils extends Logging {
     val isSparkClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined ||
       SPARK_SQL_CLASS_REGEX.findFirstIn(className).isDefined
     val isScalaClass = className.startsWith(SCALA_CORE_CLASS_PREFIX)
+    val testClassName = System.getProperty("spark.callstack.testClass")
+    val isSparkTestSuiteClass = (testClassName != null) && className.startsWith(testClassName)
     // If the class is a Spark internal class or a Scala class, then exclude.
-    isSparkClass || isScalaClass
+    (isSparkClass || isScalaClass) && !isSparkTestSuiteClass
   }
 
   /**
@@ -1328,7 +1330,8 @@ private[spark] object Utils extends Logging {
    *
    * @param skipClass Function that is used to exclude non-user-code classes.
    */
-  def getCallSite(skipClass: String => Boolean = sparkInternalExclusionFunction): CallSite = {
+  def getCallSite(skipClass: String => Boolean = sparkInternalExclusionFunction):
+      CallSite = {
     // Keep crawling up the stack trace until we find the first function not inside of the spark
     // package. We track the last (shallowest) contiguous Spark method. This might be an RDD
     // transformation, a SparkContext function (such as parallelize), or anything else that leads
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 2ec46216e1cdb..b66dfeec83b4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -96,14 +96,14 @@ abstract class Expression extends TreeNode[Expression] {
     ctx.subExprEliminationExprs.get(this).map { subExprState =>
       // This expression is repeated which means that the code to evaluate it has already been added
       // as a function before. In that case, we just re-use it.
-      ExprCode(ctx.registerComment(this.toString), subExprState.isNull, subExprState.value)
+      ExprCode(ctx.registerComment(this.toOriginString), subExprState.isNull, subExprState.value)
     }.getOrElse {
       val isNull = ctx.freshName("isNull")
       val value = ctx.freshName("value")
       val ve = doGenCode(ctx, ExprCode("", isNull, value))
       if (ve.code.nonEmpty) {
         // Add `this` in the comment.
-        ve.copy(code = s"${ctx.registerComment(this.toString)}\n" + ve.code.trim)
+        ve.copy(code = s"${ctx.registerComment(this.toOriginString)}\n" + ve.code.trim)
       } else {
         ve
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
index 644a5b28a2151..43b0ec1880803 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
@@ -50,7 +50,7 @@ object ExpressionSet {
 class ExpressionSet protected(
     protected val baseSet: mutable.Set[Expression] = new mutable.HashSet,
     protected val originals: mutable.Buffer[Expression] = new ArrayBuffer)
-  extends Set[Expression] {
+  extends Set[Expression] with Serializable {
 
   protected def add(e: Expression): Unit = {
     if (!baseSet.contains(e.canonicalized)) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index 6a5a3e7933eea..e5662b073123c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -35,7 +35,7 @@ trait CodegenFallback extends Expression {
     val idx = ctx.references.length
     ctx.references += this
     val objectTerm = ctx.freshName("obj")
-    val placeHolder = ctx.registerComment(this.toString)
+    val placeHolder = ctx.registerComment(this.toOriginString)
     if (nullable) {
       ev.copy(code = s"""
         $placeHolder
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index d687a85c18b63..d2b8330223d69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -98,7 +98,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
       case e: ParseException =>
         throw e.withCommand(command)
       case e: AnalysisException =>
-        val position = Origin(e.line, e.startPosition)
+        val position = Origin(None, e.line, e.startPosition)
         throw new ParseException(Option(command), e.message, position, position)
     }
   }
@@ -150,7 +150,7 @@ case object ParseErrorListener extends BaseErrorListener {
       charPositionInLine: Int,
       msg: String,
       e: RecognitionException): Unit = {
-    val position = Origin(Some(line), Some(charPositionInLine))
+    val position = Origin(None, Some(line), Some(charPositionInLine))
     throw new ParseException(None, msg, position, position)
   }
 }
@@ -176,7 +176,7 @@ class ParseException(
     val builder = new StringBuilder
     builder ++= "\n" ++= message
     start match {
-      case Origin(Some(l), Some(p)) =>
+      case Origin(_, Some(l), Some(p)) =>
         builder ++= s"(line $l, pos $p)\n"
         command.foreach { cmd =>
           val (above, below) = cmd.split("\n").splitAt(l)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index 9619884edeafe..ad89afbad826f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -74,7 +74,7 @@ object ParserUtils {
 
   /** Get the origin (line and position) of the token. */
   def position(token: Token): Origin = {
-    Origin(Option(token.getLine), Option(token.getCharPositionInLine))
+    Origin(None, Option(token.getLine), Option(token.getCharPositionInLine))
   }
 
   /** Assert if a condition holds. If it doesn't throw a parse exception. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 3ebd815dce32c..9b3181dd2cfbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -41,6 +41,7 @@ import org.apache.spark.util.Utils
 private class MutableInt(var i: Int)
 
 case class Origin(
+  var callSite: Option[String] = None,
   line: Option[Int] = None,
   startPosition: Option[Int] = None)
 
@@ -58,15 +59,15 @@ object CurrentOrigin {
 
   def reset(): Unit = value.set(Origin())
 
-  def setPosition(line: Int, start: Int): Unit = {
+  def setPosition(callSite: String, line: Int, start: Int): Unit = {
     value.set(
-      value.get.copy(line = Some(line), startPosition = Some(start)))
+      value.get.copy(callSite = Some(callSite), line = Some(line), startPosition = Some(start)))
   }
 
   def withOrigin[A](o: Origin)(f: => A): A = {
+    val current = get
     set(o)
-    val ret = try f finally { reset() }
-    reset()
+    val ret = try f finally { set(current) }
     ret
   }
 }
@@ -442,6 +443,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
   override def toString: String = treeString
 
+  def toOriginString: String =
+    if (this.origin.callSite.isDefined) {
+      this.toString + " @ " + this.origin.callSite.get
+    } else {
+      this.toString
+    }
+
   /** Returns a string representation of the nodes in this tree */
   def treeString: String = generateTreeString(0, Nil, new StringBuilder).toString
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index 4005087dad05a..3d2a624ba3b30 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -155,6 +155,16 @@ package object util {
 
   def toPrettySQL(e: Expression): String = usePrettyExpression(e).sql
 
+  /**
+   * Returns the string representation of this expression that is safe to be put in
+   * code comments of generated code. The length is capped at 128 characters.
+   */
+  def toCommentSafeString(str: String): String = {
+    val len = math.min(str.length, 128)
+    val suffix = if (str.length > len) "..." else ""
+    str.substring(0, len).replace("*/", "\\*\\/").replace("\\u", "\\\\u") + suffix
+  }
+
   /* FIX ME
   implicit class debugLogging(a: Any) {
     def debugLogging() {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 6a188e7e55126..83a6ae9e3fb74 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.trees
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.SparkContext
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
@@ -117,7 +119,7 @@ class TreeNodeSuite extends SparkFunSuite {
   }
 
   test("preserves origin") {
-    CurrentOrigin.setPosition(1, 1)
+    CurrentOrigin.setPosition("TreeNodeSuite.scala:120", 1, 1)
     val add = Add(Literal(1), Literal(1))
     CurrentOrigin.reset()
 
@@ -125,10 +127,31 @@ class TreeNodeSuite extends SparkFunSuite {
       case Literal(1, _) => Literal(2)
     }
 
+    assert(transformed.origin.callSite.isDefined)
     assert(transformed.origin.line.isDefined)
     assert(transformed.origin.startPosition.isDefined)
   }
 
+  test("preserves origin thru SerDe") {
+    val sc = new SparkContext("local", "test")
+    val callSite = "TreeNodeSuite.scala:137"
+    val line = 1
+    val startPosition = 2
+    CurrentOrigin.setPosition(callSite, line, startPosition)
+    val add = Add(Literal(1), Literal(2))
+
+    val ser = sc.env.closureSerializer.newInstance()
+    val serBinary = ser.serialize(add)
+    val deadd = ser.deserialize[Expression](serBinary, Thread.currentThread.getContextClassLoader)
+
+    assert(deadd.origin.callSite.isDefined &&
+      deadd.origin.callSite.get == callSite)
+    assert(deadd.origin.line.isDefined &&
+      deadd.origin.line.get == line)
+    assert(deadd.origin.startPosition.isDefined &&
+      deadd.origin.startPosition.get == startPosition)
+  }
+
   test("foreach up") {
     val actual = new ArrayBuffer[String]()
     val expected = Seq("1", "2", "3", "4", "-", "*", "+")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 713f7941beeb2..898acb2f6da51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -26,10 +26,12 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils.getCallSite
 
 private[sql] object Column {
 
@@ -46,6 +48,16 @@ private[sql] object Column {
       case expr => usePrettyExpression(expr).sql
     }
   }
+
+  @scala.annotation.varargs
+  def updateExpressionsOrigin(cols: Column*): Unit = {
+    // Update Expression.origin using the callSite of an operation
+    val callSite = org.apache.spark.util.Utils.getCallSite().shortForm
+    cols.map(col => col.expr.foreach(e => e.origin.callSite = Some(callSite)))
+    // Update CurrentOrigin for setting origin for LogicalPlan node
+    CurrentOrigin.set(
+      Origin(Some(callSite), CurrentOrigin.get.line, CurrentOrigin.get.startPosition))
+  }
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 369b772d322c0..762e42217e40f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -683,6 +683,7 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = {
+    Column.updateExpressionsOrigin(joinExprs)
     // Note that in this function, we introduce a hack in the case of self-join to automatically
    // resolve ambiguous join conditions into ones that might make sense [SPARK-6231].
    // Consider this case: df.join(df, df("key") === df("key"))
@@ -967,6 +968,7 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def select(cols: Column*): DataFrame = withPlan {
+    Column.updateExpressionsOrigin(cols : _*)
     Project(cols.map(_.named), logicalPlan)
   }
 
@@ -1111,6 +1113,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def filter(condition: Column): Dataset[T] = withTypedPlan {
+    Column.updateExpressionsOrigin(condition)
     Filter(condition.expr, logicalPlan)
   }
 
@@ -1173,6 +1176,7 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def groupBy(cols: Column*): RelationalGroupedDataset = {
+    Column.updateExpressionsOrigin(cols : _*)
     RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType)
   }
 
@@ -1197,6 +1201,7 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def rollup(cols: Column*): RelationalGroupedDataset = {
+    Column.updateExpressionsOrigin(cols : _*)
     RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.RollupType)
   }
 
@@ -1221,6 +1226,7 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def cube(cols: Column*): RelationalGroupedDataset = {
+    Column.updateExpressionsOrigin(cols : _*)
     RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.CubeType)
   }
 
@@ -1419,7 +1425,10 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   @scala.annotation.varargs
-  def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*)
+  def agg(expr: Column, exprs: Column*): DataFrame = {
+    Column.updateExpressionsOrigin(exprs : _*)
+    groupBy().agg(expr, exprs : _*)
+  }
 
   /**
    * Returns a new [[Dataset]] by taking the first `n` rows. The difference between this function
@@ -1608,6 +1617,7 @@ class Dataset[T] private[sql](
    */
   @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0")
   def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
+    Column.updateExpressionsOrigin(input : _*)
     val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
 
     val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema)
@@ -1671,6 +1681,7 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   def withColumn(colName: String, col: Column): DataFrame = {
+    Column.updateExpressionsOrigin(col)
     val resolver = sparkSession.sessionState.analyzer.resolver
     val output = queryExecution.analyzed.output
     val shouldReplace = output.exists(f => resolver(f.name, colName))
@@ -1692,6 +1703,7 @@ class Dataset[T] private[sql](
    * Returns a new [[Dataset]] by adding a column with metadata.
    */
   private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
+    Column.updateExpressionsOrigin(col)
     val resolver = sparkSession.sessionState.analyzer.resolver
     val output = queryExecution.analyzed.output
     val shouldReplace = output.exists(f => resolver(f.name, colName))
@@ -1782,6 +1794,7 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   def drop(col: Column): DataFrame = {
+    Column.updateExpressionsOrigin(col)
     val expression = col match {
       case Column(u: UnresolvedAttribute) =>
         queryExecution.analyzed.resolveQuoted(
@@ -2218,6 +2231,7 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan {
+    Column.updateExpressionsOrigin(partitionExprs : _*)
     RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions))
   }
 
@@ -2233,6 +2247,7 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan {
+    Column.updateExpressionsOrigin(partitionExprs : _*)
     RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None)
   }
 
@@ -2528,6 +2543,7 @@ class Dataset[T] private[sql](
   }
 
   private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
+    Column.updateExpressionsOrigin(sortExprs : _*)
     val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
       col.expr match {
         case expr: SortOrder =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CodegenEmbedFileLineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CodegenEmbedFileLineSuite.scala
new file mode 100644
index 0000000000000..c6fee0f49fb59
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CodegenEmbedFileLineSuite.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.core.expressions.codegen
+
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
+import org.apache.spark.sql.execution.debug.codegenString
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.ResetSystemProperties
+
+class CodegenEmbedFileLineSuite extends PlanTest with SharedSQLContext
+  with ResetSystemProperties {
+  import testImplicits._
+  sparkConf.set("spark.sql.codegen.comments", "true")
+
+  test("filter String") {
+    val df = sparkContext.parallelize(1 to 1, 1).map(i => (i, -i)).toDF("k", "v")
+      .filter("k > 0")
+    validate(df, Array(" > 0\\) @ filter at CodegenEmbedFileLineSuite.scala:34"))
+  }
+
+  test("select Column") {
+    val df = sparkContext.parallelize(1 to 1, 1).toDF
+      .select($"value" + 1)
+    validate(df, Array(" \\+ 1\\) @ select at CodegenEmbedFileLineSuite.scala:40"))
+  }
+
+  test("selectExpr String") {
+    val df = sparkContext.parallelize(1 to 1, 1).toDF
+      .selectExpr("value + 2")
+    validate(df, Array(" \\+ 2\\) @ selectExpr at CodegenEmbedFileLineSuite.scala:46"))
+  }
+
+  test("filter Strings (two filters are combined into one plan)") {
+    val df = sparkContext.parallelize(1 to 1, 1).map(i => (i, -i)).toDF("k", "v")
+      .filter("k > 0")
+      .filter("v > 1")
+    validate(df,
+      Array(" > 0\\) @ filter at CodegenEmbedFileLineSuite.scala:52",
+        " > 1\\) @ filter at CodegenEmbedFileLineSuite.scala:53"),
+      Array(" > 1\\) @ filter at CodegenEmbedFileLineSuite.scala:52",
+        " > 0\\) @ filter at CodegenEmbedFileLineSuite.scala:53"))
+  }
+
+  test("selectExpr Strings") {
+    val df = sparkContext.parallelize(1 to 1, 1).map(i => (i, -i)).toDF("k", "v")
+      .selectExpr("k + 2", "v - 2")
+    validate(df,
+      Array(" \\+ 2\\) @ selectExpr at CodegenEmbedFileLineSuite.scala:63",
+        " - 2\\) @ selectExpr at CodegenEmbedFileLineSuite.scala:63"))
+  }
+
+  test("select and selectExpr") {
+    val df = sparkContext.parallelize(1 to 1, 1).toDF
+    val df1 = df.select($"value" + 1)
+    val df2 = df.selectExpr("value + 2")
+    validate(df1,
+      Array(" \\+ 1\\) @ select at CodegenEmbedFileLineSuite.scala:71"),
+      Array(" \\+ 2\\) @ select at CodegenEmbedFileLineSuite.scala:72"))
+    validate(df2,
+      Array(" \\+ 2\\) @ selectExpr at CodegenEmbedFileLineSuite.scala:72"),
+      Array(" \\+ 1\\) @ selectExpr at CodegenEmbedFileLineSuite.scala:71"))
+  }
+
+  test("filter and select") {
+    val df = sparkContext.parallelize(1 to 1, 1).toDF
+    val df1 = df.filter("value > 0")
+    val df2 = df1.select($"value" * 2)
+    validate(df2,
+      Array(" > 0\\) @ filter at CodegenEmbedFileLineSuite.scala:83",
+        " \\* 2\\) @ select at CodegenEmbedFileLineSuite.scala:84"))
+  }
+
+  test("no transformation") {
+    val df = sparkContext.parallelize(1 to 1, 1).toDF
+    validate(df,
+      Array.empty,
+      Array("CodegenEmbedFileLineSuite.scala"))
+  }
+
+
+  def validate(df: DataFrame,
+      expected: Array[String] = Array.empty, unexpected: Array[String] = Array.empty): Unit = {
+    val logicalPlan = df.logicalPlan
+    // As LogicalPlan.resolveOperators does,
+    // this routine also updates CurrentOrigin by logicalPlan.origin
+    val cg = CurrentOrigin.withOrigin(logicalPlan.origin) {
+      val queryExecution = sqlContext.sessionState.executePlan(logicalPlan)
+      codegenString(queryExecution.executedPlan)
+    }
+
+    if (cg.contains("Found 0 WholeStageCodegen subtrees")) {
+      return
+    }
+
+    expected.foreach { string =>
+      if (!string.r.findFirstIn(cg).isDefined) {
+        fail(
+          s"""
+             |=== FAIL: generated code must include: "$string" ===
+             |$cg
+           """.stripMargin
+        )
+      }
+    }
+    unexpected.foreach { string =>
+      if (string.r.findFirstIn(cg).isDefined) {
+        fail(
+          s"""
+             |=== FAIL: generated code must not include: "$string" ===
+             |$cg
+           """.stripMargin
+        )
+      }
+    }
+  }
+
+  override def beforeEach() {
+    super.beforeEach()
+    System.setProperty("spark.callstack.testClass", this.getClass.getName)
+  }
+}
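
For reference, below is a minimal sketch of how the call-site comments introduced by this patch could be inspected from a user application. It is not part of the patch: the object name CallSiteCommentDemo and the exact comment text are illustrative assumptions, and it assumes a Spark 2.x build with the change above applied. It uses only existing public APIs (SparkSession, Dataset.filter, and debugCodegen() from org.apache.spark.sql.execution.debug) plus the spark.sql.codegen.comments setting that the new test suite also enables.

// Hypothetical driver program, not part of this patch.
// With this patch applied and codegen comments enabled, the comment emitted for the
// filter predicate is expected to carry the user call site, along the lines of
// "(value#1 > 0) @ filter at CallSiteCommentDemo.scala:<line>".
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.debug._

object CallSiteCommentDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("callsite-comment-demo")
      // Enable codegen comments, as the new test suite does via sparkConf.
      .config("spark.sql.codegen.comments", "true")
      .getOrCreate()
    import spark.implicits._

    // A user-code transformation whose call site should be recorded in Expression.origin.
    val df = spark.range(10).toDF("value").filter($"value" > 0)

    // Prints the whole-stage generated Java source for each codegen subtree;
    // the predicate's comment should include "@ filter at ...".
    df.debugCodegen()

    spark.stop()
  }
}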