@@ -25,10 +25,6 @@ import scala.util.parsing.input.CharArrayReader.EofCh

import org.apache.spark.sql.catalyst.plans.logical._

-private[sql] object KeywordNormalizer {
-  def apply(str: String): String = str.toLowerCase()
-}

private[sql] abstract class AbstractSparkSQLParser
extends StandardTokenParsers with PackratParsers {

@@ -42,7 +38,7 @@ private[sql] abstract class AbstractSparkSQLParser
}

protected case class Keyword(str: String) {
-  def normalize: String = KeywordNormalizer(str)
+  def normalize: String = lexical.normalizeKeyword(str)
def parser: Parser[String] = normalize
}

@@ -90,13 +86,16 @@ class SqlLexical extends StdLexical {
reserved ++= keywords
}

/* Normalize the keyword string */
def normalizeKeyword(str: String): String = str.toLowerCase

delimiters += (
"@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>"
)

protected override def processIdent(name: String) = {
-    val token = KeywordNormalizer(name)
+    val token = normalizeKeyword(name)
if (reserved contains token) Keyword(token) else Identifier(name)
}
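
Moving normalization onto the lexical (instead of the removed standalone KeywordNormalizer object) gives subclasses a hook to change keyword handling. A minimal sketch of such an override, assuming only the SqlLexical class above (UpperCaseLexical is a hypothetical name, not part of this PR):

// Hypothetical sketch: normalize keywords to upper case instead of lower
// case by overriding the new hook. Reserved-word matching stays consistent
// because both Keyword.normalize and processIdent go through normalizeKeyword.
private[sql] class UpperCaseLexical extends SqlLexical {
  override def normalizeKeyword(str: String): String = str.toUpperCase
}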

@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
* Root class of SQL parser dialects. We do not guarantee binary compatibility
* for future releases, so this stays an internal interface for advanced users.
*/
@DeveloperApi
abstract class Dialect {
Contributor comment:

An abstract interface for adding a new SQL dialect. A `Dialect` is responsible for creating a logical plan from a string representation of a query. Since the `LogicalPlan` interface is not a public stable API, custom dialects will likely be tied to specific Spark releases.

Explicitly annotate this as @DeveloperApi.

// This is the main method to be implemented by the SQL parser.
def parse(sqlText: String): LogicalPlan
}
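
To make the contract concrete, here is a sketch of a user-defined dialect (hypothetical, not part of this PR) that simply delegates to the existing Catalyst parser; a real dialect would plug in its own parsing logic instead:

// Hypothetical example of the Dialect extension point. The only requirements
// are a public no-arg constructor and a parse method producing a LogicalPlan.
class MyCustomDialect extends Dialect {
  @transient protected val parser = new org.apache.spark.sql.catalyst.SqlParser

  override def parse(sqlText: String): LogicalPlan = parser.parse(sqlText)
}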
@@ -38,6 +38,8 @@ package object errors {
}
}

class DialectException(msg: String, cause: Throwable) extends Exception(msg, cause)

/**
* Wraps any exceptions that are thrown while executing `f` in a
* [[catalyst.errors.TreeNodeException TreeNodeException]], attaching the provided `tree`.
82 changes: 68 additions & 14 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConversions._
import scala.collection.immutable
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import com.google.common.reflect.TypeToken

@@ -32,9 +33,11 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.Dialect
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, expressions}
import org.apache.spark.sql.execution.{Filter, _}
import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
@@ -44,6 +47,45 @@ import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.{Partition, SparkContext}

/**
* Currently we support the default dialect named "sql", associated with the class
* [[DefaultDialect]].
*
* A custom SQL dialect can also be provided, for example in the Spark SQL CLI:
* {{{
* -- switch to the "hiveql" dialect
* spark-sql> SET spark.sql.dialect=hiveql;
* spark-sql> SELECT * FROM src LIMIT 1;
*
* -- switch to the "sql" dialect
* spark-sql> SET spark.sql.dialect=sql;
* spark-sql> SELECT * FROM src LIMIT 1;
*
* -- register a custom SQL dialect
* spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect;
* spark-sql> SELECT * FROM src LIMIT 1;
*
* -- register a non-existent SQL dialect
* spark-sql> SET spark.sql.dialect=NotExistedClass;
* spark-sql> SELECT * FROM src LIMIT 1;
*
* -- an exception will be thrown, and the dialect will revert to
* -- "sql" (for SQLContext) or "hiveql" (for HiveContext)
* }}}
*/
private[spark] class DefaultDialect extends Dialect {
@transient
protected val sqlParser = {
val catalystSqlParser = new catalyst.SqlParser
new SparkSQLParser(catalystSqlParser.parse)
}

override def parse(sqlText: String): LogicalPlan = {
sqlParser.parse(sqlText)
}
}
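
For reference, the dialect switching shown in the scaladoc above can also be done programmatically. A minimal sketch, assuming an existing SQLContext named sqlContext and the hypothetical MyCustomDialect from earlier:

// Sketch only: switch dialects through configuration rather than the CLI.
sqlContext.setConf("spark.sql.dialect", "sql")  // the built-in dialect
sqlContext.sql("SELECT 1")

// Any class extending Dialect with a no-arg constructor can be named here.
sqlContext.setConf("spark.sql.dialect", classOf[MyCustomDialect].getCanonicalName)
sqlContext.sql("SELECT 1")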

/**
* The entry point for working with structured data (rows and columns) in Spark. Allows the
* creation of [[DataFrame]] objects as well as the execution of SQL queries.
@@ -132,17 +174,27 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer

@transient
-protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_))
-
-@transient
-protected[sql] val sqlParser = {
-  val fallback = new catalyst.SqlParser
-  new SparkSQLParser(fallback.parse(_))
-}
+protected[sql] val ddlParser = new DDLParser((sql: String) => { getSQLDialect().parse(sql) })

protected[sql] def getSQLDialect(): Dialect = {
try {
val clazz = Utils.classForName(dialectClassName)
clazz.newInstance().asInstanceOf[Dialect]
} catch {
case NonFatal(e) =>
// Since no usable SQL dialect was found, even a SET command such as
// "SET spark.sql.dialect=sql" would fail, so automatically reset to the
// default dialect. Note that newInstance() above requires the dialect
// class to have a public no-arg constructor.
val dialect = conf.dialect
// reset the SQL dialect
conf.unsetConf(SQLConf.DIALECT)
// Rethrow the exception; the default dialect will take effect for the next query.
throw new DialectException(
s"""Instantiating dialect '$dialect' failed.
|Reverting to default dialect '${conf.dialect}'""".stripMargin, e)
}
}

-protected[sql] def parseSql(sql: String): LogicalPlan = {
-  ddlParser.parse(sql, false).getOrElse(sqlParser.parse(sql))
-}
+protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false)

protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql))

Expand All @@ -156,6 +208,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] val defaultSession = createSession()

protected[sql] def dialectClassName = if (conf.dialect == "sql") {
classOf[DefaultDialect].getCanonicalName
} else {
conf.dialect
}

sparkContext.getConf.getAll.foreach {
case (key, value) if key.startsWith("spark.sql") => setConf(key, value)
case _ =>
@@ -945,11 +1003,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group basic
*/
def sql(sqlText: String): DataFrame = {
if (conf.dialect == "sql") {
DataFrame(this, parseSql(sqlText))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}")
}
DataFrame(this, parseSql(sqlText))
}

/**
@@ -38,12 +38,12 @@ private[sql] class DDLParser(
parseQuery: String => LogicalPlan)
extends AbstractSparkSQLParser with DataTypeParser with Logging {

-def parse(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = {
+def parse(input: String, exceptionOnError: Boolean): LogicalPlan = {
try {
-  Some(parse(input))
+  parse(input)
} catch {
case ddlException: DDLException => throw ddlException
-  case _ if !exceptionOnError => None
+  case _ if !exceptionOnError => parseQuery(input)
case x: Throwable => throw x
}
}
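
The net effect, as a sketch (names as in this PR; assuming code with access to the package-private API and an existing SQLContext named ctx): DDL statements and plain queries now share one entry point, and anything that is not DDL falls through to the active dialect.

// Both calls go through SQLContext.parseSql -> ddlParser.parse(sql, false).
// The first is handled by the DDL grammar; the second is not DDL, so it
// falls back to parseQuery, i.e. getSQLDialect().parse.
val ddlPlan = ctx.parseSql("CREATE TEMPORARY TABLE t USING org.apache.spark.sql.json OPTIONS (path 'data.json')")
val queryPlan = ctx.parseSql("SELECT 1")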
22 changes: 22 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,13 +19,18 @@ package org.apache.spark.sql

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}

import org.apache.spark.sql.types._

/** A SQL dialect for testing purposes; it cannot be a nested class, since it is instantiated reflectively. */
class MyDialect extends DefaultDialect

class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
// Make sure the tables are loaded.
TestData
@@ -64,6 +69,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
}

test("SQL Dialect Switching to a new SQL parser") {
val newContext = new SQLContext(TestSQLContext.sparkContext)
newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName())
assert(newContext.getSQLDialect().getClass === classOf[MyDialect])
assert(newContext.sql("SELECT 1").collect() === Array(Row(1)))
}

test("SQL Dialect Switch to an invalid parser with alias") {
val newContext = new SQLContext(TestSQLContext.sparkContext)
newContext.sql("SET spark.sql.dialect=MyTestClass")
intercept[DialectException] {
newContext.sql("SELECT 1")
}
// verify that the dialect is reset to DefaultDialect
assert(newContext.getSQLDialect().getClass === classOf[DefaultDialect])
}

test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
checkAnswer(
sql("SELECT a FROM testData2 SORT BY a"),
41 changes: 25 additions & 16 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -20,6 +20,9 @@ package org.apache.spark.sql.hive
import java.io.{BufferedReader, InputStreamReader, PrintStream}
import java.sql.Timestamp

import org.apache.hadoop.hive.ql.parse.VariableSubstitution
import org.apache.spark.sql.catalyst.Dialect

import scala.collection.JavaConversions._
import scala.language.implicitConversions

@@ -42,6 +45,15 @@ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNative
import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy}
import org.apache.spark.sql.types._

/**
* The HiveQL dialect. This dialect is tightly bound to HiveContext.
*/
private[hive] class HiveQLDialect extends Dialect {
override def parse(sqlText: String): LogicalPlan = {
HiveQl.parseSql(sqlText)
}
}

/**
* An instance of the Spark SQL execution engine that integrates with data stored in Hive.
* Configuration for Hive is read from hive-site.xml on the classpath.
@@ -81,25 +93,16 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
protected[sql] def convertCTAS: Boolean =
getConf("spark.sql.hive.convertCTAS", "false").toBoolean

-override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
-  new this.QueryExecution(plan)
-
-@transient
-protected[sql] val ddlParserWithHiveQL = new DDLParser(HiveQl.parseSql(_))
-
-override def sql(sqlText: String): DataFrame = {
-  val substituted = new VariableSubstitution().substitute(hiveconf, sqlText)
-  // TODO: Create a framework for registering parsers instead of just hardcoding if statements.
-  if (conf.dialect == "sql") {
-    super.sql(substituted)
-  } else if (conf.dialect == "hiveql") {
-    val ddlPlan = ddlParserWithHiveQL.parse(sqlText, exceptionOnError = false)
-    DataFrame(this, ddlPlan.getOrElse(HiveQl.parseSql(substituted)))
-  } else {
-    sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'")
-  }
-}
+protected[sql] lazy val substitutor = new VariableSubstitution()
+
+protected[sql] override def parseSql(sql: String): LogicalPlan = {
+  super.parseSql(substitutor.substitute(hiveconf, sql))
+}
+
+override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
+  new this.QueryExecution(plan)

/**
* Invalidate and refresh all the cached metadata of the given table. For performance reasons,
* Spark SQL or the external data source library it uses might cache certain metadata about a
@@ -356,6 +359,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
}
}

override protected[sql] def dialectClassName = if (conf.dialect == "hiveql") {
classOf[HiveQLDialect].getCanonicalName
} else {
super.dialectClassName
}
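
With this override, dialect resolution in a HiveContext works as sketched below (a sketch only, assuming an existing HiveContext named hiveCtx):

// "hiveql" resolves to HiveQLDialect; "sql" falls through to
// SQLContext.dialectClassName and resolves to DefaultDialect; any other
// value is treated as a fully qualified Dialect class name.
hiveCtx.setConf("spark.sql.dialect", "hiveql")  // use the HiveQl parser
hiveCtx.setConf("spark.sql.dialect", "sql")     // use the Spark SQL parser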

@transient
private val hivePlanner = new SparkPlanner with HiveStrategies {
val hiveContext = self
@@ -107,7 +107,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
/** Fewer partitions to speed up testing. */
protected[sql] override lazy val conf: SQLConf = new SQLConf {
override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt
override def dialect: String = getConf(SQLConf.DIALECT, "hiveql")

// TODO as in unit test, conf.clear() probably be called, all of the value will be cleared.
// The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql"
override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql")
}
}
