diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index e83202d9e5ee3..920c1922c522c 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2907,6 +2907,12 @@
     ],
     "sqlState" : "42613"
   },
+  "INVALID_REGEXP_REPLACE" : {
+    "message" : [
+      "Could not perform regexp_replace for source = \"<source>\", pattern = \"<pattern>\", replacement = \"<replacement>\" and position = <position>."
+    ],
+    "sqlState" : "22023"
+  },
   "INVALID_SAVE_MODE" : {
     "message" : [
       "The specified save mode is invalid. Valid save modes include \"append\", \"overwrite\", \"ignore\", \"error\", \"errorifexists\", and \"default\"."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 970397c76a1cd..52460533efbe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -22,6 +22,7 @@ import java.util.regex.{Matcher, MatchResult, Pattern, PatternSyntaxException}
 
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
+import scala.util.control.NonFatal
 
 import org.apache.commons.text.StringEscapeUtils
 
@@ -700,7 +701,13 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
       m.region(position, source.length)
       result.delete(0, result.length())
       while (m.find) {
-        m.appendReplacement(result, lastReplacement)
+        try {
+          m.appendReplacement(result, lastReplacement)
+        } catch {
+          case NonFatal(e) =>
+            throw QueryExecutionErrors.invalidRegexpReplaceError(s.toString,
+              p.toString, r.toString, i.asInstanceOf[Int], e)
+        }
       }
       m.appendTail(result)
       UTF8String.fromString(result.toString)
@@ -748,7 +755,16 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
         $matcher.region($position, $source.length());
 
         while ($matcher.find()) {
-          $matcher.appendReplacement($termResult, $termLastReplacement);
+          try {
+            $matcher.appendReplacement($termResult, $termLastReplacement);
+          } catch (Throwable e) {
+            if (scala.util.control.NonFatal.apply(e)) {
+              throw QueryExecutionErrors.invalidRegexpReplaceError($source, $regexp.toString(),
+                $rep.toString(), $pos, e);
+            } else {
+              throw e;
+            }
+          }
         }
         $matcher.appendTail($termResult);
         ${ev.value} = UTF8String.fromString($termResult.toString());
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 4a23e9766fc5d..aebdf1160d808 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -362,6 +362,35 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
         "groupIndex" -> groupIndex.toString()))
   }
 
+  /**
+   * Builds the error raised when `regexp_replace` fails while applying the replacement
+   * string to a match, for example when the replacement refers to a missing group.
+   *
+   * @param source the input string the replacement was attempted on
+   * @param pattern the regular expression pattern
+   * @param replacement the replacement string
+   * @param position the position value the expression was evaluated with
+   * @param cause the underlying exception thrown during replacement
+   * @return a [[SparkRuntimeException]] with the `INVALID_REGEXP_REPLACE` error condition
+   */
+  def invalidRegexpReplaceError(
+      source: String,
+      pattern: String,
+      replacement: String,
+      position: Int,
+      cause: Throwable): RuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "INVALID_REGEXP_REPLACE",
+      messageParameters = Map(
+        "source" -> source,
+        "pattern" -> pattern,
+        "replacement" -> replacement,
+        "position" -> position.toString
+      ),
+      cause = cause
+    )
+  }
+
   def invalidUrlError(url: UTF8String, e: URISyntaxException): SparkIllegalArgumentException = {
     new SparkIllegalArgumentException(
       errorClass = "INVALID_URL",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index ec240d71b851f..ca47073f4ae4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.{SPARK_DOC_ROOT, SparkIllegalArgumentException, SparkRuntimeException}
 import org.apache.spark.sql.catalyst.expressions.Cast._
-import org.apache.spark.sql.execution.FormattedMode
+import org.apache.spark.sql.execution.{FormattedMode, WholeStageCodegenExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -1356,4 +1356,34 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("RegExpReplace throws the right exception when replace fails on a particular row") {
+    val tableName = "regexpReplaceException"
+    withTable(tableName) {
+      sql(s"CREATE TABLE IF NOT EXISTS $tableName(s STRING)")
+      sql(s"INSERT INTO $tableName VALUES('first last')")
+      Seq("NO_CODEGEN", "CODEGEN_ONLY").foreach { codegenMode =>
+        withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) {
+          val query = s"SELECT regexp_replace(s, '(?<first>[a-zA-Z]+) (?<last>[a-zA-Z]+)', " +
+            s"'$$3 $$1') FROM $tableName"
+          val df = sql(query)
+          val plan = df.queryExecution.executedPlan
+          assert(plan.isInstanceOf[WholeStageCodegenExec] == (codegenMode == "CODEGEN_ONLY"))
+          val exception = intercept[SparkRuntimeException] {
+            df.collect()
+          }
+          checkError(
+            exception = exception,
+            condition = "INVALID_REGEXP_REPLACE",
+            parameters = Map(
+              "source" -> "first last",
+              "pattern" -> "(?<first>[a-zA-Z]+) (?<last>[a-zA-Z]+)",
+              "replacement" -> "$3 $1",
+              "position" -> "1")
+          )
+          assert(exception.getCause.getMessage.contains("No group 3"))
+        }
+      }
+    }
+  }
 }