diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index d5c7bb722921..1cc12d902218 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -203,8 +203,17 @@ case object ParseErrorListener extends BaseErrorListener {
       charPositionInLine: Int,
       msg: String,
       e: RecognitionException): Unit = {
-    val position = Origin(Some(line), Some(charPositionInLine))
-    throw new ParseException(None, msg, position, position)
+    val (start, stop) = offendingSymbol match {
+      case token: CommonToken =>
+        val start = Origin(Some(line), Some(token.getCharPositionInLine))
+        val length = token.getStopIndex - token.getStartIndex + 1
+        val stop = Origin(Some(line), Some(token.getCharPositionInLine + length))
+        (start, stop)
+      case _ =>
+        val start = Origin(Some(line), Some(charPositionInLine))
+        (start, start)
+    }
+    throw new ParseException(None, msg, start, stop)
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
index baaf01800b33..96a37992e8c8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
@@ -22,7 +22,8 @@ import org.apache.spark.SparkFunSuite
  * Test various parser errors.
  */
 class ErrorParserSuite extends SparkFunSuite {
-  def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = {
+  def intercept(sql: String, line: Int, startPosition: Int, stopPosition: Int,
+      messages: String*): Unit = {
     val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql))
 
     // Check position.
@@ -30,6 +31,8 @@ class ErrorParserSuite extends SparkFunSuite {
     assert(e.line.get === line)
     assert(e.startPosition.isDefined)
     assert(e.startPosition.get === startPosition)
+    assert(e.stop.startPosition.isDefined)
+    assert(e.stop.startPosition.get === stopPosition)
 
     // Check messages.
     val error = e.getMessage
@@ -39,23 +42,24 @@ class ErrorParserSuite extends SparkFunSuite {
   }
 
   test("no viable input") {
-    intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^")
+    intercept("select ((r + 1) ", 1, 16, 16,
+      "no viable alternative at input", "----------------^^^")
   }
 
   test("extraneous input") {
-    intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^")
-    intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^")
+    intercept("select 1 1", 1, 9, 10, "extraneous input '1' expecting", "---------^^^")
+    intercept("select *\nfrom r as q t", 2, 12, 13, "extraneous input", "------------^^^")
   }
 
   test("mismatched input") {
-    intercept("select * from r order by q from t", 1, 27,
+    intercept("select * from r order by q from t", 1, 27, 31,
       "mismatched input",
       "---------------------------^^^")
-    intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^")
+    intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, 4, "mismatched input", "^^^")
   }
 
   test("semantic errors") {
-    intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0,
+    intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, 11,
       "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported",
       "^^^")
   }
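
For reference, a minimal sketch (not part of the patch; `StopPositionSketch` and `stopColumn` are hypothetical names) of the position arithmetic the updated listener performs: the stop column is the offending token's starting column plus its length, where the length comes from the token's inclusive start/stop character indexes. In the `mismatched input` test above, the second `from` token starts at column 27 and spans 4 characters, so the reported stop column is 31, matching the new test expectation.

```scala
// Illustrative only: mirrors the arithmetic in ParseErrorListener.syntaxError
// without depending on ANTLR types. Columns are 0-based; startIndex/stopIndex
// are inclusive character offsets into the input.
object StopPositionSketch {
  def stopColumn(charPositionInLine: Int, startIndex: Int, stopIndex: Int): Int = {
    val length = stopIndex - startIndex + 1 // inclusive bounds, hence the +1
    charPositionInLine + length
  }

  def main(args: Array[String]): Unit = {
    // "select * from r order by q from t": the second 'from' starts at column 27
    // and occupies offsets 27..30, so length = 4 and the stop column is 31.
    println(stopColumn(charPositionInLine = 27, startIndex = 27, stopIndex = 30)) // 31
  }
}
```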