From 78ba03d7612992e334d4dc047a88e2211506adda Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Tue, 3 Sep 2024 12:29:55 +0200 Subject: [PATCH 1/9] Introduce LEAVE and ITERATE statements --- .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 2 + .../sql/catalyst/parser/SqlBaseParser.g4 | 10 +++ .../sql/catalyst/parser/AstBuilder.scala | 46 ++++++++++++- .../parser/SqlScriptingLogicalOperators.scala | 4 ++ .../scripting/SqlScriptingExecutionNode.scala | 67 +++++++++++++++++-- .../scripting/SqlScriptingInterpreter.scala | 17 +++-- 6 files changed, 133 insertions(+), 13 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index acfc0011f5d0..6793cb46852b 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -276,6 +276,7 @@ INTO: 'INTO'; INVOKER: 'INVOKER'; IS: 'IS'; ITEMS: 'ITEMS'; +ITERATE: 'ITERATE'; JOIN: 'JOIN'; KEYS: 'KEYS'; LANGUAGE: 'LANGUAGE'; @@ -283,6 +284,7 @@ LAST: 'LAST'; LATERAL: 'LATERAL'; LAZY: 'LAZY'; LEADING: 'LEADING'; +LEAVE: 'LEAVE'; LEFT: 'LEFT'; LIKE: 'LIKE'; ILIKE: 'ILIKE'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 5b8805821b04..1a6ba243ac7b 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -65,6 +65,8 @@ compoundStatement | beginEndCompoundBlock | ifElseStatement | whileStatement + | leaveStatement + | iterateStatement ; setStatementWithOptionalVarKeyword @@ -83,6 +85,14 @@ ifElseStatement (ELSE elseBody=compoundBody)? 
END IF ; +leaveStatement + : LEAVE multipartIdentifier + ; + +iterateStatement + : ITERATE multipartIdentifier + ; + singleStatement : (statement|setResetStatement) SEMICOLON* EOF ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b0922542c562..8667261fa68f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer, Set} import scala.jdk.CollectionConverters._ import scala.util.{Left, Right} -import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.antlr.v4.runtime.{ParserRuleContext, RuleContext, Token} import org.antlr.v4.runtime.misc.Interval import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} @@ -261,6 +261,50 @@ class AstBuilder extends DataTypeAstBuilder WhileStatement(condition, body, Some(labelText)) } + private def leaveOrIterateContextHasLabel( + ctx: RuleContext, label: String, isLeave: Boolean): Boolean = { + ctx match { + case c: BeginEndCompoundBlockContext + if isLeave && + Option(c.beginLabel()).isDefined && + c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + => true + case c: WhileStatementContext + if Option(c.beginLabel()).isDefined && + c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + => true + case _ => false + } + } + + override def visitLeaveStatement(ctx: LeaveStatementContext): LeaveStatement = { + val labelText = ctx.multipartIdentifier().getText.toLowerCase(Locale.ROOT) + var parentCtx = ctx.parent + + while (Option(parentCtx).isDefined) { + if (leaveOrIterateContextHasLabel(parentCtx, labelText, isLeave = true)) { + return LeaveStatement(labelText) + } + parentCtx = parentCtx.parent + } + + throw SparkException.internalError("No matching block (with same label) found!") + } + + override def visitIterateStatement(ctx: IterateStatementContext): IterateStatement = { + val labelText = ctx.multipartIdentifier().getText.toLowerCase(Locale.ROOT) + var parentCtx = ctx.parent + + while (Option(parentCtx).isDefined) { + if (leaveOrIterateContextHasLabel(parentCtx, labelText, isLeave = false)) { + return IterateStatement(labelText) + } + parentCtx = parentCtx.parent + } + + throw SparkException.internalError("No matching block (with same label) found!") + } + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { Option(ctx.statement().asInstanceOf[ParserRuleContext]) .orElse(Option(ctx.setResetStatement().asInstanceOf[ParserRuleContext])) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index 4a5259f09a8a..6b7ffabb625b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -89,3 +89,7 @@ case class WhileStatement( condition: SingleStatement, body: CompoundBody, label: Option[String]) extends CompoundPlanStatement + +case class LeaveStatement(label: String) extends CompoundPlanStatement + +case class IterateStatement(label: String) extends CompoundPlanStatement diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 7085366c3b7a..837d24702743 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType @@ -141,11 +141,14 @@ class SingleStatementExec( * @param collection * Collection of child execution nodes. */ -abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundStatementExec]) +abstract class CompoundNestedStatementIteratorExec( + collection: Seq[CompoundStatementExec], + label: Option[String] = None) extends NonLeafStatementExec { private var localIterator = collection.iterator private var curr = if (localIterator.hasNext) Some(localIterator.next()) else None + private var stopIteration = false private lazy val treeIterator: Iterator[CompoundStatementExec] = new Iterator[CompoundStatementExec] { @@ -157,7 +160,7 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState case _ => throw SparkException.internalError( "Unknown statement type encountered during SQL script interpretation.") } - localIterator.hasNext || childHasNext + !stopIteration && (localIterator.hasNext || childHasNext) } @scala.annotation.tailrec @@ -165,12 +168,21 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState curr match { case None => throw SparkException.internalError( "No more elements to iterate through in the current SQL compound statement.") + case Some(leaveStatement: LeaveStatementExec) => + handleLeaveStatement(leaveStatement) + curr = None + leaveStatement case Some(statement: LeafStatementExec) => curr = if (localIterator.hasNext) Some(localIterator.next()) else None statement case Some(body: NonLeafStatementExec) => if (body.getTreeIterator.hasNext) { - body.getTreeIterator.next() + body.getTreeIterator.next() match { + case leaveStatement: LeaveStatementExec => + handleLeaveStatement(leaveStatement) + leaveStatement + case other => other + } } else { curr = if (localIterator.hasNext) Some(localIterator.next()) else None next() @@ -187,6 +199,19 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState collection.foreach(_.reset()) localIterator = collection.iterator curr = if (localIterator.hasNext) Some(localIterator.next()) else None + stopIteration = false + } + + private def handleLeaveStatement(leaveStatement: LeaveStatementExec): Unit = { + if (!leaveStatement.hasBeenMatched) { + // Stop the iteration. + stopIteration = true + + // TODO: Variable cleanup (once we add SQL script execution logic). + + // Check if label has been matched. 
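+ // If this compound's label matches, the LEAVE statement is consumed here; otherwise
+ // hasBeenMatched stays false and the statement keeps propagating to the enclosing compounds.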
+ leaveStatement.hasBeenMatched = label.isDefined && label.get.equals(leaveStatement.label) + } } } @@ -195,8 +220,8 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState * @param statements * Executable nodes for nested statements within the CompoundBody. */ -class CompoundBodyExec(statements: Seq[CompoundStatementExec]) - extends CompoundNestedStatementIteratorExec(statements) +class CompoundBodyExec(statements: Seq[CompoundStatementExec], label: Option[String] = None) + extends CompoundNestedStatementIteratorExec(statements, label) /** * Executable node for IfElseStatement. @@ -282,6 +307,7 @@ class IfElseStatementExec( class WhileStatementExec( condition: SingleStatementExec, body: CompoundBodyExec, + label: Option[String], session: SparkSession) extends NonLeafStatementExec { private object WhileState extends Enumeration { @@ -308,6 +334,25 @@ class WhileStatementExec( condition case WhileState.Body => val retStmt = body.getTreeIterator.next() + + retStmt match { + case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched => + if (label.contains(leaveStatementExec.label)) { + leaveStatementExec.hasBeenMatched = true + } + curr = None + return retStmt + case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched => + if (label.contains(iterStatementExec.label)) { + iterStatementExec.hasBeenMatched = true + } + state = WhileState.Condition + curr = Some(condition) + condition.reset() + return retStmt + case _ => + } + if (!body.getTreeIterator.hasNext) { state = WhileState.Condition curr = Some(condition) @@ -326,3 +371,13 @@ class WhileStatementExec( body.reset() } } + +class LeaveStatementExec(val label: String) extends LeafStatementExec { + var hasBeenMatched: Boolean = false + override def reset(): Unit = hasBeenMatched = false +} + +class IterateStatementExec(val label: String) extends LeafStatementExec { + var hasBeenMatched: Boolean = false + override def reset(): Unit = hasBeenMatched = false +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 08b4f9728628..8a5a9774d42f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier -import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, IfElseStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement, SingleStatement, WhileStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.Origin @@ -71,9 +71,9 @@ case class SqlScriptingInterpreter() { private def transformTreeIntoExecutable( node: CompoundPlanStatement, session: SparkSession): CompoundStatementExec = node match { - case body: CompoundBody => + case CompoundBody(collection, label) => // TODO [SPARK-48530]: Current logic doesn't support scoped variables and shadowing. 
- val variables = body.collection.flatMap { + val variables = collection.flatMap { case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan) case _ => None } @@ -82,7 +82,8 @@ case class SqlScriptingInterpreter() { .map(new SingleStatementExec(_, Origin(), isInternal = true)) .reverse new CompoundBodyExec( - body.collection.map(st => transformTreeIntoExecutable(st, session)) ++ dropVariables) + collection.map(st => transformTreeIntoExecutable(st, session)) ++ dropVariables, + label) case IfElseStatement(conditions, conditionalBodies, elseBody) => val conditionsExec = conditions.map(condition => new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false)) @@ -92,12 +93,16 @@ case class SqlScriptingInterpreter() { transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec]) new IfElseStatementExec( conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session) - case WhileStatement(condition, body, _) => + case WhileStatement(condition, body, label) => val conditionExec = new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false) val bodyExec = transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] - new WhileStatementExec(conditionExec, bodyExec, session) + new WhileStatementExec(conditionExec, bodyExec, label, session) + case leaveStatement: LeaveStatement => + new LeaveStatementExec(leaveStatement.label) + case iterateStatement: IterateStatement => + new IterateStatementExec(iterateStatement.label) case sparkStatement: SingleStatement => new SingleStatementExec( sparkStatement.parsedPlan, From 32bfb68da7517b6eb502275e354a17048b5eda14 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Tue, 3 Sep 2024 16:15:22 +0200 Subject: [PATCH 2/9] Add missing stuff for new keywords --- docs/sql-ref-ansi-compliance.md | 2 ++ .../org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 | 4 ++++ .../scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala | 2 +- .../spark/sql/scripting/SqlScriptingExecutionNode.scala | 2 +- .../test/resources/sql-tests/results/ansi/keywords.sql.out | 2 ++ .../src/test/resources/sql-tests/results/keywords.sql.out | 2 ++ .../spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala | 2 +- .../hive/thriftserver/ThriftServerWithSparkContextSuite.scala | 2 +- 8 files changed, 14 insertions(+), 4 deletions(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index f5e1ddfd3c57..0ac19e2ae943 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -556,6 +556,7 @@ Below is a list of all the keywords in Spark SQL. |INVOKER|non-reserved|non-reserved|non-reserved| |IS|reserved|non-reserved|reserved| |ITEMS|non-reserved|non-reserved|non-reserved| +|ITERATE|non-reserved|non-reserved|non-reserved| |JOIN|reserved|strict-non-reserved|reserved| |KEYS|non-reserved|non-reserved|non-reserved| |LANGUAGE|non-reserved|non-reserved|reserved| @@ -563,6 +564,7 @@ Below is a list of all the keywords in Spark SQL. 
|LATERAL|reserved|strict-non-reserved|reserved| |LAZY|non-reserved|non-reserved|non-reserved| |LEADING|reserved|non-reserved|reserved| +|LEAVE|non-reserved|non-reserved|non-reserved| |LEFT|reserved|strict-non-reserved|reserved| |LIKE|non-reserved|non-reserved|reserved| |ILIKE|non-reserved|non-reserved|non-reserved| diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 1a6ba243ac7b..6a23bd394c8c 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1588,10 +1588,12 @@ ansiNonReserved | INTERVAL | INVOKER | ITEMS + | ITERATE | KEYS | LANGUAGE | LAST | LAZY + | LEAVE | LIKE | ILIKE | LIMIT @@ -1937,11 +1939,13 @@ nonReserved | INVOKER | IS | ITEMS + | ITERATE | KEYS | LANGUAGE | LAST | LAZY | LEADING + | LEAVE | LIKE | LONG | ILIKE diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala index 9977dcd83d6a..3b93c470478c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala @@ -200,7 +200,7 @@ class SQLKeywordSuite extends SQLKeywordUtils { withTempDir { dir => val tmpFile = new File(dir, "tmp") val is = Thread.currentThread().getContextClassLoader - .getResourceAsStream("ansi-sql-2016-reserved-keywords.txt") + .getResourceAsStream("ansi-sql-2016-reserved- ke.txt") Files.copy(is, tmpFile.toPath) val reservedKeywordsInSql2016 = Files.readAllLines(tmpFile.toPath) .asScala.filterNot(_.startsWith("--")).map(_.trim).toSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 837d24702743..d59a34a2353a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} -import org.apache.spark.sql.catalyst.plans.logical.{DropVariable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index 5735e5eef68e..b2f3fdda74db 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -163,6 +163,7 @@ INTO true INVOKER false IS true ITEMS false +ITERATE false JOIN true KEYS false LANGUAGE false @@ -170,6 +171,7 @@ LAST false LATERAL true LAZY false LEADING true +LEAVE false LEFT true LIKE false LIMIT false diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index ca48e851e717..ce9fd580b2ff 100644 --- 
a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -163,6 +163,7 @@ INTO false INVOKER false IS false ITEMS false +ITERATE false JOIN false KEYS false LANGUAGE false @@ -170,6 +171,7 @@ LAST false LATERAL false LAZY false LEADING false +LEAVE false LEFT false LIKE false LIMIT false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 5c36f9e19e6d..7d00fb22c613 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -55,7 +55,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi case class TestWhile( condition: TestWhileCondition, body: CompoundBodyExec) - extends WhileStatementExec(condition, body, spark) { + extends WhileStatementExec(condition, body, None, spark) { private var callCount: Int = 0 diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index 7005f0e951b2..2e3457dab09b 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } From 61d4237cb677a76a1d2d17175080e36baffbdf95 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Tue, 3 Sep 2024 16:57:16 +0200 Subject: [PATCH 3/9] Add proper error messages --- .../resources/error/error-conditions.json | 6 +++ .../sql/catalyst/parser/AstBuilder.scala | 44 ++++++++++--------- .../spark/sql/errors/SqlScriptingErrors.scala | 13 ++++++ .../parser/SqlScriptingParserSuite.scala | 1 - 4 files changed, 43 insertions(+), 21 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index e2725a98a63b..bfca55620cde 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2489,6 +2489,12 @@ ], "sqlState" : "F0000" }, + 
"INVALID_LABEL_USAGE_IN_STATEMENT" : { + "message" : [ + "The label used in the statement does not belong to any surrounding block." + ], + "sqlState" : "42K0L" + }, "INVALID_LAMBDA_FUNCTION_CALL" : { "message" : [ "Invalid lambda function call." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8667261fa68f..77a35df16e4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -277,33 +277,37 @@ class AstBuilder extends DataTypeAstBuilder } } - override def visitLeaveStatement(ctx: LeaveStatementContext): LeaveStatement = { - val labelText = ctx.multipartIdentifier().getText.toLowerCase(Locale.ROOT) - var parentCtx = ctx.parent + override def visitLeaveStatement(ctx: LeaveStatementContext): LeaveStatement = + withOrigin(ctx) { + val labelText = ctx.multipartIdentifier().getText.toLowerCase(Locale.ROOT) + var parentCtx = ctx.parent - while (Option(parentCtx).isDefined) { - if (leaveOrIterateContextHasLabel(parentCtx, labelText, isLeave = true)) { - return LeaveStatement(labelText) + while (Option(parentCtx).isDefined) { + if (leaveOrIterateContextHasLabel(parentCtx, labelText, isLeave = true)) { + return LeaveStatement(labelText) + } + parentCtx = parentCtx.parent } - parentCtx = parentCtx.parent - } - throw SparkException.internalError("No matching block (with same label) found!") - } + throw SqlScriptingErrors.invalidLabelUsageInStatement( + CurrentOrigin.get, labelText, "LEAVE") + } - override def visitIterateStatement(ctx: IterateStatementContext): IterateStatement = { - val labelText = ctx.multipartIdentifier().getText.toLowerCase(Locale.ROOT) - var parentCtx = ctx.parent + override def visitIterateStatement(ctx: IterateStatementContext): IterateStatement = + withOrigin(ctx) { + val labelText = ctx.multipartIdentifier().getText.toLowerCase(Locale.ROOT) + var parentCtx = ctx.parent - while (Option(parentCtx).isDefined) { - if (leaveOrIterateContextHasLabel(parentCtx, labelText, isLeave = false)) { - return IterateStatement(labelText) + while (Option(parentCtx).isDefined) { + if (leaveOrIterateContextHasLabel(parentCtx, labelText, isLeave = false)) { + return IterateStatement(labelText) + } + parentCtx = parentCtx.parent } - parentCtx = parentCtx.parent - } - throw SparkException.internalError("No matching block (with same label) found!") - } + throw SqlScriptingErrors.invalidLabelUsageInStatement( + CurrentOrigin.get, labelText, "ITERATE") + } override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { Option(ctx.statement().asInstanceOf[ParserRuleContext]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index 61661b1d32f3..022fe388bc15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -84,4 +84,17 @@ private[sql] object SqlScriptingErrors { cause = null, messageParameters = Map("invalidStatement" -> toSQLStmt(stmt))) } + + def invalidLabelUsageInStatement( + origin: Origin, + labelName: String, + statementType: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT", + 
cause = null, + messageParameters = Map( + "labelName" -> toSQLStmt(labelName), + "statementType" -> statementType)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 5fc3ade408bd..031dc9b5bae4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -666,7 +666,6 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { head.asInstanceOf[SingleStatement].getText == "SELECT 42") assert(whileStmt.label.contains("lbl")) - } // Helper methods From 3a0941668591c20d593f067c0b3e6b5651a80c21 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Tue, 3 Sep 2024 17:06:41 +0200 Subject: [PATCH 4/9] Fix accidental edit --- .../scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala index 3b93c470478c..9977dcd83d6a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala @@ -200,7 +200,7 @@ class SQLKeywordSuite extends SQLKeywordUtils { withTempDir { dir => val tmpFile = new File(dir, "tmp") val is = Thread.currentThread().getContextClassLoader - .getResourceAsStream("ansi-sql-2016-reserved- ke.txt") + .getResourceAsStream("ansi-sql-2016-reserved-keywords.txt") Files.copy(is, tmpFile.toPath) val reservedKeywordsInSql2016 = Files.readAllLines(tmpFile.toPath) .asScala.filterNot(_.startsWith("--")).map(_.trim).toSet From b5b6b257d6ba134e7912c3a59cc54f97d7bf2593 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Tue, 3 Sep 2024 23:23:05 +0200 Subject: [PATCH 5/9] Add tests --- .../parser/SqlScriptingParserSuite.scala | 178 ++++++++++++++++++ .../SqlScriptingExecutionNodeSuite.scala | 107 ++++++++++- .../SqlScriptingInterpreterSuite.scala | 143 ++++++++++++++ 3 files changed, 426 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 031dc9b5bae4..441a20d47fd9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -668,6 +668,184 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(whileStmt.label.contains("lbl")) } + test("leave compound block") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | LEAVE lbl; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 2) + assert(tree.collection.head.isInstanceOf[SingleStatement]) + assert(tree.collection(1).isInstanceOf[LeaveStatement]) + } + + test("leave while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | SELECT 1; + | LEAVE lbl; + | END WHILE; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[WhileStatement]) + + val whileStmt = 
tree.collection.head.asInstanceOf[WhileStatement] + assert(whileStmt.condition.isInstanceOf[SingleStatement]) + assert(whileStmt.condition.getText == "1 = 1") + + assert(whileStmt.body.isInstanceOf[CompoundBody]) + assert(whileStmt.body.collection.length == 2) + + assert(whileStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(whileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(whileStmt.body.collection(1).isInstanceOf[LeaveStatement]) + assert(whileStmt.body.collection(1).asInstanceOf[LeaveStatement].label == "lbl") + } + + test ("iterate compound block - should fail") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | ITERATE lbl; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + }, + errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT", + parameters = Map("labelName" -> "LBL", "statementType" -> "ITERATE")) + } + + test("iterate while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | SELECT 1; + | ITERATE lbl; + | END WHILE; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[WhileStatement]) + + val whileStmt = tree.collection.head.asInstanceOf[WhileStatement] + assert(whileStmt.condition.isInstanceOf[SingleStatement]) + assert(whileStmt.condition.getText == "1 = 1") + + assert(whileStmt.body.isInstanceOf[CompoundBody]) + assert(whileStmt.body.collection.length == 2) + + assert(whileStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(whileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(whileStmt.body.collection(1).isInstanceOf[IterateStatement]) + assert(whileStmt.body.collection(1).asInstanceOf[IterateStatement].label == "lbl") + } + + test("leave with wrong label - should fail") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | LEAVE randomlbl; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + }, + errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT", + parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "LEAVE")) + } + + test("iterate with wrong label - should fail") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | ITERATE randomlbl; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + }, + errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT", + parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "ITERATE")) + } + + test("leave outer loop from nested while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | lbl2: WHILE 2 = 2 DO + | SELECT 1; + | LEAVE lbl; + | END WHILE; + | END WHILE; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[WhileStatement]) + + val whileStmt = tree.collection.head.asInstanceOf[WhileStatement] + assert(whileStmt.condition.isInstanceOf[SingleStatement]) + assert(whileStmt.condition.getText == "1 = 1") + + assert(whileStmt.body.isInstanceOf[CompoundBody]) + assert(whileStmt.body.collection.length == 1) + + val nestedWhileStmt = whileStmt.body.collection.head.asInstanceOf[WhileStatement] + assert(nestedWhileStmt.condition.isInstanceOf[SingleStatement]) + assert(nestedWhileStmt.condition.getText == "2 = 2") + + assert(nestedWhileStmt.body.collection.head.isInstanceOf[SingleStatement]) + 
assert(nestedWhileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(nestedWhileStmt.body.collection(1).isInstanceOf[LeaveStatement]) + assert(nestedWhileStmt.body.collection(1).asInstanceOf[LeaveStatement].label == "lbl") + } + + test("iterate outer loop from nested while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | lbl2: WHILE 2 = 2 DO + | SELECT 1; + | ITERATE lbl; + | END WHILE; + | END WHILE; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[WhileStatement]) + + val whileStmt = tree.collection.head.asInstanceOf[WhileStatement] + assert(whileStmt.condition.isInstanceOf[SingleStatement]) + assert(whileStmt.condition.getText == "1 = 1") + + assert(whileStmt.body.isInstanceOf[CompoundBody]) + assert(whileStmt.body.collection.length == 1) + + val nestedWhileStmt = whileStmt.body.collection.head.asInstanceOf[WhileStatement] + assert(nestedWhileStmt.condition.isInstanceOf[SingleStatement]) + assert(nestedWhileStmt.condition.getText == "2 = 2") + + assert(nestedWhileStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(nestedWhileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(nestedWhileStmt.body.collection(1).isInstanceOf[IterateStatement]) + assert(nestedWhileStmt.body.collection(1).asInstanceOf[IterateStatement].label == "lbl") + } + // Helper methods def cleanupStatementString(statementStr: String): String = { statementStr diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 7d00fb22c613..99692339597f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -35,6 +35,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi override def reset(): Unit = () } + case class TestLeaveStatement(labelText: String) extends LeaveStatementExec(labelText) + + case class TestIterateStatement(labelText: String) extends IterateStatementExec(labelText) + case class TestIfElseCondition(condVal: Boolean, description: String) extends SingleStatementExec( parsedPlan = Project(Seq(Alias(Literal(condVal), description)()), OneRowRelation()), @@ -54,8 +58,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi case class TestWhile( condition: TestWhileCondition, - body: CompoundBodyExec) - extends WhileStatementExec(condition, body, None, spark) { + body: CompoundBodyExec, + label: Option[String] = None) + extends WhileStatementExec(condition, body, label, spark) { private var callCount: Int = 0 @@ -77,6 +82,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi case TestLeafStatement(testVal) => testVal case TestIfElseCondition(_, description) => description case TestWhileCondition(_, _, description) => description + case TestLeaveStatement(label) => label + case TestIterateStatement(label) => label case _ => fail("Unexpected statement type") } @@ -314,4 +321,100 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "con2", "body1", "con2", "con1")) } + test("leave compound block") { + val iter = new CompoundBodyExec( + statements = Seq( + TestLeafStatement("one"), + 
TestLeaveStatement("lbl") + ), + label = Some("lbl") + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("one", "lbl")) + } + + test("leave while loop") { + val iter = new CompoundBodyExec( + statements = Seq( + TestWhile( + condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + TestLeaveStatement("lbl")) + ), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "body1", "lbl")) + } + + test("iterate while loop") { + val iter = new CompoundBodyExec( + statements = Seq( + TestWhile( + condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + TestIterateStatement("lbl"), + TestLeafStatement("body2")) + ), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "body1", "lbl", "con1", "body1", "lbl", "con1")) + } + + test("leave outer loop from nested while loop") { + val iter = new CompoundBodyExec( + statements = Seq( + TestWhile( + condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"), + body = new CompoundBodyExec(Seq( + TestWhile( + condition = TestWhileCondition(condVal = true, reps = 2, description = "con2"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + TestLeaveStatement("lbl")) + ), + label = Some("lbl2") + ) + )), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "con2", "body1", "lbl")) + } + + test("iterate outer loop from nested while loop") { + val iter = new CompoundBodyExec( + statements = Seq( + TestWhile( + condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"), + body = new CompoundBodyExec(Seq( + TestWhile( + condition = TestWhileCondition(condVal = true, reps = 2, description = "con2"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + TestIterateStatement("lbl"), + TestLeafStatement("body2")) + ), + label = Some("lbl2") + ) + )), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "con1", "con2", "body1", "lbl", + "con1", "con2", "body1", "lbl", + "con1")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 592516de84c1..7e5419f6fc17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row} import org.apache.spark.sql.catalyst.QueryPlanningTracker +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parseScript import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.test.SharedSparkSession @@ -536,4 +537,146 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(commands, expected) } } + + test("leave compound block") { + val sqlScriptText = + """ + 
|lbl: BEGIN + | SELECT 1; + | LEAVE lbl; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)) // select + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | SELECT 1; + | LEAVE lbl; + | END WHILE; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)) // select + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate compound block - should fail") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | ITERATE lbl; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + }, + errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT", + parameters = Map("labelName" -> "LBL", "statementType" -> "ITERATE")) + } + + test("iterate while loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: WHILE x < 2 DO + | SET x = x + 1; + | ITERATE lbl; + | SET x = x + 2; + | END WHILE; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq.empty[Row], // set x = 2 + Seq(Row(2)), // select + Seq.empty[Row] // drop + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave with wrong label - should fail") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | LEAVE randomlbl; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + }, + errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT", + parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "LEAVE")) + } + + test("iterate with wrong label - should fail") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | ITERATE randomlbl; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + }, + errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT", + parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "ITERATE")) + } + + test("leave outer loop from nested while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | lbl2: WHILE 2 = 2 DO + | SELECT 1; + | LEAVE lbl; + | END WHILE; + | END WHILE; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)) // select + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate outer loop from nested while loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: WHILE x < 2 DO + | SET x = x + 1; + | lbl2: WHILE 2 = 2 DO + | SELECT 1; + | ITERATE lbl; + | END WHILE; + | END WHILE; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq(Row(1)), // select 1 + Seq.empty[Row], // set x= 2 + Seq(Row(1)), // select 1 + Seq(Row(2)), // select x + Seq.empty[Row] // drop + ) + verifySqlScriptResult(sqlScriptText, expected) + } } From 0e8fa23d423db85f88a2e05d0930fa940a50f441 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Wed, 4 Sep 2024 20:39:32 +0200 Subject: [PATCH 6/9] Remove test exec nodes for LEAVE and ITERATE statements --- .../SqlScriptingExecutionNodeSuite.scala | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 99692339597f..97a21c505fdd 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -35,10 +35,6 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi override def reset(): Unit = () } - case class TestLeaveStatement(labelText: String) extends LeaveStatementExec(labelText) - - case class TestIterateStatement(labelText: String) extends IterateStatementExec(labelText) - case class TestIfElseCondition(condVal: Boolean, description: String) extends SingleStatementExec( parsedPlan = Project(Seq(Alias(Literal(condVal), description)()), OneRowRelation()), @@ -82,8 +78,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi case TestLeafStatement(testVal) => testVal case TestIfElseCondition(_, description) => description case TestWhileCondition(_, _, description) => description - case TestLeaveStatement(label) => label - case TestIterateStatement(label) => label + case leaveStmt: LeaveStatementExec => leaveStmt.label + case iterateStmt: IterateStatementExec => iterateStmt.label case _ => fail("Unexpected statement type") } @@ -325,7 +321,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val iter = new CompoundBodyExec( statements = Seq( TestLeafStatement("one"), - TestLeaveStatement("lbl") + new LeaveStatementExec("lbl") ), label = Some("lbl") ).getTreeIterator @@ -340,7 +336,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"), body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), - TestLeaveStatement("lbl")) + new LeaveStatementExec("lbl")) ), label = Some("lbl") ) @@ -357,7 +353,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"), body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), - TestIterateStatement("lbl"), + new IterateStatementExec("lbl"), TestLeafStatement("body2")) ), label = Some("lbl") @@ -378,7 +374,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi condition = TestWhileCondition(condVal = true, reps = 2, description = "con2"), body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), - TestLeaveStatement("lbl")) + new LeaveStatementExec("lbl")) ), label = Some("lbl2") ) @@ -401,7 +397,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi condition = TestWhileCondition(condVal = true, reps = 2, description = "con2"), body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), - TestIterateStatement("lbl"), + new IterateStatementExec("lbl"), TestLeafStatement("body2")) ), label = Some("lbl2") From 22e49fa31931007b3a1a05af80bba877ba81c945 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Wed, 4 Sep 2024 20:52:59 +0200 Subject: [PATCH 7/9] Add error class for usage of ITERATE statement for compound blocks --- .../main/resources/error/error-conditions.json | 6 ++++++ .../spark/sql/catalyst/parser/AstBuilder.scala | 16 +++++++++------- .../spark/sql/errors/SqlScriptingErrors.scala | 10 ++++++++++ .../parser/SqlScriptingParserSuite.scala | 4 ++-- .../scripting/SqlScriptingInterpreterSuite.scala | 4 ++-- 5 files changed, 29 insertions(+), 11 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json index bfca55620cde..92f99c2e95ce 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2453,6 +2453,12 @@ }, "sqlState" : "42K0K" }, + "INVALID_ITERATE_LABEL_USAGE_FOR_COMPOUND" : { + "message" : [ + "The label <labelName> used in the ITERATE statement cannot belong to a compound (BEGIN...END) body." + ], + "sqlState" : "42K0L" + }, "INVALID_JOIN_TYPE_FOR_JOINWITH" : { "message" : [ "Invalid join type in joinWith: <joinType>." ], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 77a35df16e4f..ede15a397ac1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -262,13 +262,15 @@ class AstBuilder extends DataTypeAstBuilder } private def leaveOrIterateContextHasLabel( - ctx: RuleContext, label: String, isLeave: Boolean): Boolean = { + ctx: RuleContext, label: String, isIterate: Boolean): Boolean = { ctx match { case c: BeginEndCompoundBlockContext - if isLeave && - Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) - => true + if Option(c.beginLabel()).isDefined && + c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) => + if (isIterate) { + throw SqlScriptingErrors.invalidIterateLabelUsageForCompound(CurrentOrigin.get, label) + } + true case c: WhileStatementContext if Option(c.beginLabel()).isDefined && c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) @@ -283,7 +285,7 @@ class AstBuilder extends DataTypeAstBuilder var parentCtx = ctx.parent while (Option(parentCtx).isDefined) { - if (leaveOrIterateContextHasLabel(parentCtx, labelText, isLeave = true)) { + if (leaveOrIterateContextHasLabel(parentCtx, labelText, isIterate = false)) { return LeaveStatement(labelText) } parentCtx = parentCtx.parent @@ -299,7 +301,7 @@ class AstBuilder extends DataTypeAstBuilder var parentCtx = ctx.parent while (Option(parentCtx).isDefined) { - if (leaveOrIterateContextHasLabel(parentCtx, labelText, isLeave = false)) { + if (leaveOrIterateContextHasLabel(parentCtx, labelText, isIterate = true)) { return IterateStatement(labelText) } parentCtx = parentCtx.parent diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index 022fe388bc15..f053179f3e24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -97,4 +97,14 @@ private[sql] object SqlScriptingErrors { "labelName" -> toSQLStmt(labelName), "statementType" -> statementType)) } + + def invalidIterateLabelUsageForCompound( + origin: Origin, + labelName: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "INVALID_ITERATE_LABEL_USAGE_FOR_COMPOUND", + cause = null, + messageParameters = Map("labelName" -> toSQLStmt(labelName))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index
441a20d47fd9..c2bc698731fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -719,8 +719,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { exception = intercept[SqlScriptingException] { parseScript(sqlScriptText) }, - errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT", - parameters = Map("labelName" -> "LBL", "statementType" -> "ITERATE")) + errorClass = "INVALID_ITERATE_LABEL_USAGE_FOR_COMPOUND", + parameters = Map("labelName" -> "LBL")) } test("iterate while loop") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 7e5419f6fc17..e9510a950d38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -577,8 +577,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { exception = intercept[SqlScriptingException] { parseScript(sqlScriptText) }, - errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT", - parameters = Map("labelName" -> "LBL", "statementType" -> "ITERATE")) + errorClass = "INVALID_ITERATE_LABEL_USAGE_FOR_COMPOUND", + parameters = Map("labelName" -> "LBL")) } test("iterate while loop") { From 23bb4139662810a1eafcb33716d5269b707ae428 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Wed, 4 Sep 2024 23:03:21 +0200 Subject: [PATCH 8/9] Add missing comments --- .../parser/SqlScriptingLogicalOperators.scala | 14 +++++++ .../scripting/SqlScriptingExecutionNode.scala | 39 +++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index 6b7ffabb625b..dbb29a71323e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -90,6 +90,20 @@ case class WhileStatement( body: CompoundBody, label: Option[String]) extends CompoundPlanStatement +/** + * Logical operator for LEAVE statement. + * The statement can be used with both compound bodies and any kind of loop. + * When used, the corresponding body/loop execution is skipped and the execution continues + * with the next statement after the body/loop. + * @param label Label of the compound or loop to leave. + */ case class LeaveStatement(label: String) extends CompoundPlanStatement +/** + * Logical operator for ITERATE statement. + * The statement can be used only for loops. + * When used, the rest of the loop is skipped and the loop execution continues + * with the next iteration. + * @param label Label of the loop to iterate.
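+ * For example (mirroring the interpreter-suite test added earlier in this series), the
+ * second SET below is never reached, so the loop finishes with x == 2:
+ * {{{
+ *   lbl: WHILE x < 2 DO
+ *     SET x = x + 1;
+ *     ITERATE lbl;
+ *     SET x = x + 2;
+ *   END WHILE;
+ * }}}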
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index d59a34a2353a..c2e6abf184b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -140,6 +140,8 @@ class SingleStatementExec(
  * Implements recursive iterator logic over all child execution nodes.
  * @param collection
  *   Collection of child execution nodes.
+ * @param label
+ *   Label set by the user, or None otherwise.
  */
 abstract class CompoundNestedStatementIteratorExec(
     collection: Seq[CompoundStatementExec],
@@ -148,6 +150,8 @@ abstract class CompoundNestedStatementIteratorExec(
   private var localIterator = collection.iterator
   private var curr = if (localIterator.hasNext) Some(localIterator.next()) else None

+  /** Used to stop the iteration when a LEAVE statement is encountered. */
   private var stopIteration = false

   private lazy val treeIterator: Iterator[CompoundStatementExec] =
@@ -202,6 +206,7 @@ abstract class CompoundNestedStatementIteratorExec(
     stopIteration = false
   }

+  /** Actions taken when a LEAVE statement is encountered, to stop the execution of this compound. */
   private def handleLeaveStatement(leaveStatement: LeaveStatementExec): Unit = {
     if (!leaveStatement.hasBeenMatched) {
       // Stop the iteration.
@@ -219,6 +224,8 @@ abstract class CompoundNestedStatementIteratorExec(
  * Executable node for CompoundBody.
  * @param statements
  *   Executable nodes for nested statements within the CompoundBody.
+ * @param label
+ *   Label set by the user on the CompoundBody, or None otherwise.
  */
 class CompoundBodyExec(statements: Seq[CompoundStatementExec], label: Option[String] = None)
   extends CompoundNestedStatementIteratorExec(statements, label)
@@ -302,6 +309,7 @@ class IfElseStatementExec(
  * Executable node for WhileStatement.
  * @param condition Executable node for the condition.
  * @param body Executable node for the body.
+ * @param label Label set by the user on the WhileStatement, or None otherwise.
  * @param session Spark session that SQL script is executed within.
  */
 class WhileStatementExec(
@@ -335,6 +343,7 @@ class WhileStatementExec(
       case WhileState.Body =>
         val retStmt = body.getTreeIterator.next()

+        // Handle a LEAVE or ITERATE statement if one has been encountered.
         retStmt match {
           case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched =>
             if (label.contains(leaveStatementExec.label)) {
@@ -372,12 +381,42 @@ class WhileStatementExec(
   }
 }

+/**
+ * Executable node for LeaveStatement.
+ * @param label Label of the compound or loop to leave.
+ */
 class LeaveStatementExec(val label: String) extends LeafStatementExec {
+  /**
+   * The label specified in the LEAVE statement might not belong to the immediate surrounding
+   * compound, but to any surrounding compound.
+   * Iteration logic is recursive, i.e. when iterating through a compound, if another
+   * compound is encountered, next() will be called to iterate its body. The same logic
+   * is applied to any other compound down the traversal tree.
+   * In such cases, when a LEAVE statement is encountered (as a leaf of the traversal tree),
+   * it is propagated upwards and the logic tries to match it against the labels of the
+   * surrounding compounds.
+   * Once the match is found, this flag is set to true to indicate that the search should stop.
+   */
   var hasBeenMatched: Boolean = false
   override def reset(): Unit = hasBeenMatched = false
 }

+/**
+ * Executable node for IterateStatement.
+ * @param label Label of the loop to iterate.
+ */
 class IterateStatementExec(val label: String) extends LeafStatementExec {
+  /**
+   * The label specified in the ITERATE statement might not belong to the immediate compound,
+   * but to any surrounding compound.
+   * Iteration logic is recursive, i.e. when iterating through a compound, if another
+   * compound is encountered, next() will be called to iterate its body. The same logic
+   * is applied to any other compound down the traversal tree.
+   * In such cases, when an ITERATE statement is encountered (as a leaf of the traversal tree),
+   * it is propagated upwards and the logic tries to match it against the labels of the
+   * surrounding compounds.
+   * Once the match is found, this flag is set to true to indicate that the search should stop.
+   */
   var hasBeenMatched: Boolean = false
   override def reset(): Unit = hasBeenMatched = false
 }
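The propagation described in these comments is easiest to see with nested labels (hypothetical script): the LEAVE leaf surfaces from the innermost iterator and bubbles upwards until the frame whose label matches flips hasBeenMatched, stopping every compound in between.

    outer: BEGIN
      inner: BEGIN
        LEAVE outer;  -- no match in inner, so the leaf propagates upwards
        SELECT 1;     -- skipped: inner stops iterating
      END;
      SELECT 2;       -- skipped too: outer matched the label and stops
    END;
    -- execution resumes here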
From ca986325dcca7ca5e0bbd712c9650e29f22dc9cf Mon Sep 17 00:00:00 2001
From: David Milicevic
Date: Thu, 5 Sep 2024 09:48:06 +0200
Subject: [PATCH 9/9] Group error classes

---
 .../resources/error/error-conditions.json    | 22 ++++++++++++-------
 .../sql/catalyst/parser/AstBuilder.scala     |  4 ++--
 .../spark/sql/errors/SqlScriptingErrors.scala |  6 ++---
 .../parser/SqlScriptingParserSuite.scala     |  6 ++---
 .../SqlScriptingInterpreterSuite.scala       |  6 ++---
 5 files changed, 25 insertions(+), 19 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 92f99c2e95ce..7ec5709da9a4 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2453,12 +2453,6 @@
     },
     "sqlState" : "42K0K"
   },
-  "INVALID_ITERATE_LABEL_USAGE_FOR_COMPOUND" : {
-    "message" : [
-      "The label <labelName> used in the ITERATE statement cannot belong to a compound (BEGIN...END) body."
-    ],
-    "sqlState" : "42K0L"
-  },
   "INVALID_JOIN_TYPE_FOR_JOINWITH" : {
     "message" : [
       "Invalid join type in joinWith: <joinType>."
@@ -2495,10 +2489,22 @@
     ],
     "sqlState" : "F0000"
   },
-  "INVALID_LABEL_USAGE_IN_STATEMENT" : {
+  "INVALID_LABEL_USAGE" : {
     "message" : [
-      "The label <labelName> used in the <statementType> statement does not belong to any surrounding block."
+      "The usage of the label <labelName> is invalid."
     ],
+    "subClass" : {
+      "DOES_NOT_EXIST" : {
+        "message" : [
+          "Label was used in the <statementType> statement, but the label does not belong to any surrounding block."
+        ]
+      },
+      "ITERATE_IN_COMPOUND" : {
+        "message" : [
+          "ITERATE statement cannot be used with a label that belongs to a compound (BEGIN...END) body."
+        ]
+      }
+    },
     "sqlState" : "42K0L"
   },
   "INVALID_LAMBDA_FUNCTION_CALL" : {
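In SQL terms, the two sub-classes correspond to two different user mistakes (labels and syntax below are illustrative):

    -- INVALID_LABEL_USAGE.DOES_NOT_EXIST: nothing in scope is labeled randomlbl.
    lbl: WHILE x < 10 DO
      LEAVE randomlbl;
    END WHILE;

    -- INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND: lbl names a compound, not a loop.
    lbl: BEGIN
      ITERATE lbl;
    END;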
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index ede15a397ac1..f4638920af3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -291,7 +291,7 @@ class AstBuilder extends DataTypeAstBuilder
       parentCtx = parentCtx.parent
     }

-    throw SqlScriptingErrors.invalidLabelUsageInStatement(
+    throw SqlScriptingErrors.labelDoesNotExist(
       CurrentOrigin.get, labelText, "LEAVE")
   }

@@ -307,7 +307,7 @@ class AstBuilder extends DataTypeAstBuilder
       parentCtx = parentCtx.parent
     }

-    throw SqlScriptingErrors.invalidLabelUsageInStatement(
+    throw SqlScriptingErrors.labelDoesNotExist(
       CurrentOrigin.get, labelText, "ITERATE")
   }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
index f053179f3e24..591d2e3e53d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
@@ -85,13 +85,13 @@ private[sql] object SqlScriptingErrors {
       messageParameters = Map("invalidStatement" -> toSQLStmt(stmt)))
   }

-  def invalidLabelUsageInStatement(
+  def labelDoesNotExist(
       origin: Origin,
       labelName: String,
       statementType: String): Throwable = {
     new SqlScriptingException(
       origin = origin,
-      errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT",
+      errorClass = "INVALID_LABEL_USAGE.DOES_NOT_EXIST",
       cause = null,
       messageParameters = Map(
         "labelName" -> toSQLStmt(labelName),
@@ -103,7 +103,7 @@ private[sql] object SqlScriptingErrors {
       labelName: String): Throwable = {
     new SqlScriptingException(
       origin = origin,
-      errorClass = "INVALID_ITERATE_LABEL_USAGE_FOR_COMPOUND",
+      errorClass = "INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND",
       cause = null,
       messageParameters = Map("labelName" -> toSQLStmt(labelName)))
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index c2bc698731fb..465c2d408f26 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -719,7 +719,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
       exception = intercept[SqlScriptingException] {
         parseScript(sqlScriptText)
       },
-      errorClass = "INVALID_ITERATE_LABEL_USAGE_FOR_COMPOUND",
+      errorClass = "INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND",
       parameters = Map("labelName" -> "LBL"))
   }

@@ -761,7 +761,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
       exception = intercept[SqlScriptingException] {
         parseScript(sqlScriptText)
       },
-      errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT",
+      errorClass = "INVALID_LABEL_USAGE.DOES_NOT_EXIST",
       parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "LEAVE"))
   }

@@ -776,7 +776,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
       exception = intercept[SqlScriptingException] {
         parseScript(sqlScriptText)
       },
-      errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT",
+      errorClass = "INVALID_LABEL_USAGE.DOES_NOT_EXIST",
       parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "ITERATE"))
   }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index e9510a950d38..5568f85fc476 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -577,7 +577,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
       exception = intercept[SqlScriptingException] {
         parseScript(sqlScriptText)
       },
-      errorClass = "INVALID_ITERATE_LABEL_USAGE_FOR_COMPOUND",
+      errorClass = "INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND",
       parameters = Map("labelName" -> "LBL"))
   }

@@ -616,7 +616,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
       exception = intercept[SqlScriptingException] {
         parseScript(sqlScriptText)
       },
-      errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT",
+      errorClass = "INVALID_LABEL_USAGE.DOES_NOT_EXIST",
       parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "LEAVE"))
   }

@@ -631,7 +631,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
       exception = intercept[SqlScriptingException] {
         parseScript(sqlScriptText)
       },
-      errorClass = "INVALID_LABEL_USAGE_IN_STATEMENT",
+      errorClass = "INVALID_LABEL_USAGE.DOES_NOT_EXIST",
       parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "ITERATE"))
   }
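As a closing illustration of the behavior the interpreter-suite tests exercise (hypothetical script; assumes variables x, y, z declared earlier and starting at 0), label matching also lets ITERATE target an outer loop from within an inner one, skipping the remainder of both bodies for that pass:

    outer: WHILE x < 3 DO
      SET x = x + 1;
      inner: WHILE y < 3 DO
        ITERATE outer;  -- matched against outer's label while bubbling up
        SET y = y + 1;  -- never executed
      END WHILE;
      SET z = z + 1;    -- skipped whenever the inner ITERATE fires
    END WHILE;
    -- loop exits with x = 3, y and z still 0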