diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 8e30349f50f0..1eee09e60ea1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -22,8 +22,10 @@ import scala.collection.JavaConverters._ import com.google.common.util.concurrent.AtomicLongMap import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide +import org.apache.spark.util.Utils object RuleExecutor { protected val timeMap = AtomicLongMap.create[String]() @@ -46,15 +48,24 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { /** * An execution strategy for rules that indicates the maximum number of executions. If the * execution reaches fix point (i.e. converge) before maxIterations, it will stop. + * If throwsExceptionUponMaxIterations is equal to true, it will issue an exception + * TreeNodeException. */ - abstract class Strategy { def maxIterations: Int } + abstract class Strategy { + def maxIterations: Int + def throwsExceptionUponMaxIterations: Boolean + } /** A strategy that only runs once. */ - case object Once extends Strategy { val maxIterations = 1 } + case object Once extends Strategy { + override val maxIterations = 1 + override val throwsExceptionUponMaxIterations = false + } /** A strategy that runs until fix point or maxIterations times, whichever comes first. */ - case class FixedPoint(maxIterations: Int) extends Strategy - + case class FixedPoint(maxIterations: Int) extends Strategy { + override val throwsExceptionUponMaxIterations: Boolean = Utils.isTesting + } /** A batch of rules. */ protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) @@ -98,7 +109,12 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (iteration > batch.strategy.maxIterations) { // Only log if this is a rule that is supposed to run more than once. if (iteration != 2) { - logInfo(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") + val msg = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}" + if (batch.strategy.throwsExceptionUponMaxIterations) { + throw new TreeNodeException(curPlan, msg, null) + } else { + logTrace(msg) + } } continue = false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala index a7de7b052bdc..d14e24acdf74 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.trees import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} class RuleExecutorSuite extends SparkFunSuite { @@ -46,9 +48,13 @@ class RuleExecutorSuite extends SparkFunSuite { test("to maxIterations") { object ToFixedPoint extends RuleExecutor[Expression] { + System.setProperty("spark.testing", "true") val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil } - assert(ToFixedPoint.execute(Literal(100)) === Literal(90)) + val message = intercept[TreeNodeException[LogicalPlan]] { + ToFixedPoint.execute(Literal(100)) + }.getMessage + assert(message.contains("Max iterations (10) reached for batch fixedPoint")) } }