135 changes: 135 additions & 0 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
@@ -0,0 +1,135 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.hive

import scala.language.implicitConversions
import scala.util.parsing.combinator.syntactical.StandardTokenParsers
import scala.util.parsing.combinator.PackratParsers

import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.SqlLexical
/**
 * A parser that recognizes all HiveQL constructs together with several Spark SQL specific
 * extensions like CACHE TABLE and UNCACHE TABLE.
 */
private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with PackratParsers {

  def apply(input: String): LogicalPlan = {
    // Special-case out set commands since the value fields can be
    // complex to handle without RegexParsers. Also this approach
    // is clearer for the several possible cases of set commands.
    if (input.trim.toLowerCase.startsWith("set")) {
      input.trim.drop(3).split("=", 2).map(_.trim) match {
        case Array("") => // "set"
          SetCommand(None, None)
        case Array(key) => // "set key"
          SetCommand(Some(key), None)
        case Array(key, value) => // "set key=value"
          SetCommand(Some(key), Some(value))
      }
    } else if (input.trim.startsWith("!")) {
      ShellCommand(input.drop(1))
    } else {
      phrase(query)(new lexical.Scanner(input)) match {
        case Success(r, x) => r
        case x => sys.error(x.toString)
      }
    }
  }

  protected case class Keyword(str: String)

  protected val CACHE = Keyword("CACHE")
  protected val SET = Keyword("SET")
  protected val ADD = Keyword("ADD")
  protected val JAR = Keyword("JAR")
  protected val TABLE = Keyword("TABLE")
  protected val AS = Keyword("AS")
  protected val UNCACHE = Keyword("UNCACHE")
  protected val FILE = Keyword("FILE")
  protected val DFS = Keyword("DFS")
  protected val SOURCE = Keyword("SOURCE")

  // Accepts a keyword in any capitalization by OR-ing one parser per case variant.
  protected implicit def asParser(k: Keyword): Parser[String] =
    lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)

  protected def allCaseConverse(k: String): Parser[String] =
    lexical.allCaseVersions(k).map(x => x : Parser[String]).reduce(_ | _)

  // Reserved words are collected reflectively from the Keyword-typed members above.
  protected val reservedWords =
    this.getClass
      .getMethods
      .filter(_.getReturnType == classOf[Keyword])
      .map(_.invoke(this).asInstanceOf[Keyword].str)

  override val lexical = new SqlLexical(reservedWords)

  protected lazy val query: Parser[LogicalPlan] =
    cache | uncache | addJar | addFile | dfs | source | hiveQl

  protected lazy val hiveQl: Parser[LogicalPlan] =
    remainingQuery ^^ {
      case r => HiveQl.createPlan(r.trim())
    }
  /** Returns the whole input string that remains to be parsed. */
  protected lazy val remainingQuery: Parser[String] = new Parser[String] {
    def apply(in: Input) =
      Success(
        in.source.subSequence(in.offset, in.source.length).toString,
        in.drop(in.source.length()))
Review comment (Contributor): Nit: adjust indentation

    def apply(in: Input) =
      Success(
        in.source.subSequence(in.offset, in.source.length).toString,
        in.drop(in.source.length()))

  }

  /** Returns the entire input string. */
  protected lazy val allQuery: Parser[String] = new Parser[String] {
    def apply(in: Input) =
      Success(in.source.toString, in.drop(in.source.length()))
  }

  protected lazy val cache: Parser[LogicalPlan] =
    CACHE ~ TABLE ~> ident ~ opt(AS ~> hiveQl) ^^ {
      case tableName ~ None => CacheCommand(tableName, true)
      case tableName ~ Some(plan) =>
        CacheTableAsSelectCommand(tableName, plan)
    }

  protected lazy val uncache: Parser[LogicalPlan] =
    UNCACHE ~ TABLE ~> ident ^^ {
      case tableName => CacheCommand(tableName, false)
    }

  protected lazy val addJar: Parser[LogicalPlan] =
    ADD ~ JAR ~> remainingQuery ^^ {
      case rq => AddJar(rq.trim())
    }

  protected lazy val addFile: Parser[LogicalPlan] =
    ADD ~ FILE ~> remainingQuery ^^ {
      case rq => AddFile(rq.trim())
    }

  protected lazy val dfs: Parser[LogicalPlan] =
    DFS ~> allQuery ^^ {
      case aq => NativeCommand(aq.trim())
    }

  protected lazy val source: Parser[LogicalPlan] =
    SOURCE ~> remainingQuery ^^ {
      case rq => SourceCommand(rq.trim())
    }
}
Review comment (Contributor): Noticed lots of trailing spaces in this file, please remove them.
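
The least obvious part of the new parser above is asParser, which makes each Keyword case-insensitive by building one keyword parser per case variant and OR-ing them together. Below is a minimal, self-contained sketch of that trick, not part of the PR: allCaseVersions here is a hypothetical local stand-in for the SqlLexical helper of the same name that the PR relies on.

    import scala.util.parsing.combinator.syntactical.StandardTokenParsers

    object CaseInsensitiveKeywordDemo extends StandardTokenParsers {
      // All case variants of a word, e.g. "as" -> "as", "aS", "As", "AS".
      def allCaseVersions(s: String, prefix: String = ""): Stream[String] =
        if (s.isEmpty) {
          Stream(prefix)
        } else {
          allCaseVersions(s.tail, prefix + s.head.toLower) ++
            allCaseVersions(s.tail, prefix + s.head.toUpper)
        }

      // Register every variant as a reserved word, then accept any of them.
      lexical.reserved ++= allCaseVersions("cache")
      val cacheKeyword: Parser[String] =
        allCaseVersions("cache").map(keyword).reduce(_ | _)

      def main(args: Array[String]): Unit = {
        // Succeeds no matter how the keyword is capitalized.
        println(phrase(cacheKeyword)(new lexical.Scanner("CaChE")))
      }
    }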

57 changes: 13 additions & 44 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -126,6 +126,9 @@ private[hive] object HiveQl {
     "TOK_CREATETABLE",
     "TOK_DESCTABLE"
   ) ++ nativeCommands
+
+  // Parses Hive SQL queries along with several Spark SQL specific extensions.
+  protected val hiveSqlParser = new ExtendedHiveQlParser
 
   /**
    * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations
@@ -215,40 +218,19 @@ private[hive] object HiveQl {
   def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql))
 
   /** Returns a LogicalPlan for a given HiveQL string. */
-  def parseSql(sql: String): LogicalPlan = {
+  def parseSql(sql: String): LogicalPlan = hiveSqlParser(sql)
+
+  /** Creates LogicalPlan for a given HiveQL string. */
+  def createPlan(sql: String) = {
     try {
-      if (sql.trim.toLowerCase.startsWith("set")) {
-        // Split in two parts since we treat the part before the first "="
-        // as key, and the part after as value, which may contain other "=" signs.
-        sql.trim.drop(3).split("=", 2).map(_.trim) match {
-          case Array("") => // "set"
-            SetCommand(None, None)
-          case Array(key) => // "set key"
-            SetCommand(Some(key), None)
-          case Array(key, value) => // "set key=value"
-            SetCommand(Some(key), Some(value))
-        }
-      } else if (sql.trim.toLowerCase.startsWith("cache table")) {
-        sql.trim.drop(12).trim.split(" ").toSeq match {
-          case Seq(tableName) =>
-            CacheCommand(tableName, true)
-          case Seq(tableName, _, select @ _*) =>
-            CacheTableAsSelectCommand(tableName, createPlan(select.mkString(" ").trim))
-        }
-      } else if (sql.trim.toLowerCase.startsWith("uncache table")) {
-        CacheCommand(sql.trim.drop(14).trim, false)
-      } else if (sql.trim.toLowerCase.startsWith("add jar")) {
-        AddJar(sql.trim.drop(8).trim)
-      } else if (sql.trim.toLowerCase.startsWith("add file")) {
-        AddFile(sql.trim.drop(9))
-      } else if (sql.trim.toLowerCase.startsWith("dfs")) {
+      val tree = getAst(sql)
+      if (nativeCommands contains tree.getText) {
         NativeCommand(sql)
-      } else if (sql.trim.startsWith("source")) {
-        SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath })
-      } else if (sql.trim.startsWith("!")) {
-        ShellCommand(sql.drop(1))
       } else {
-        createPlan(sql)
+        nodeToPlan(tree) match {
+          case NativePlaceholder => NativeCommand(sql)
+          case other => other
+        }
       }
     } catch {
       case e: Exception => throw new ParseException(sql, e)
@@ -259,19 +241,6 @@ private[hive] object HiveQl {
         """.stripMargin)
     }
   }
 
-  /** Creates LogicalPlan for a given HiveQL string. */
-  def createPlan(sql: String) = {
-    val tree = getAst(sql)
-    if (nativeCommands contains tree.getText) {
-      NativeCommand(sql)
-    } else {
-      nodeToPlan(tree) match {
-        case NativePlaceholder => NativeCommand(sql)
-        case other => other
-      }
-    }
-  }
-
   def parseDdl(ddl: String): Seq[Attribute] = {
     val tree =
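Note that the SET handling deleted above does not disappear: it moves into ExtendedHiveQlParser.apply, which splits on the first "=" only, so values that themselves contain "=" stay intact. A standalone sketch of that splitting rule (hypothetical demo code, not from the PR):

    // Models the "set [key[=value]]" splitting used by ExtendedHiveQlParser.apply.
    object SetSplitDemo extends App {
      def splitSet(input: String): (Option[String], Option[String]) =
        input.trim.drop(3).split("=", 2).map(_.trim) match {
          case Array("")         => (None, None)               // "set"
          case Array(key)        => (Some(key), None)          // "set key"
          case Array(key, value) => (Some(key), Some(value))   // "set key=value"
        }

      println(splitSet("set"))             // (None,None)
      println(splitSet("set spark.x"))     // (Some(spark.x),None)
      println(splitSet("set spark.x=a=b")) // (Some(spark.x),Some(a=b))
    }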
6 changes: 6 additions & 0 deletions sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -88,4 +88,10 @@ class CachedTableSuite extends HiveComparisonTest {
     }
     assert(!TestHive.isCached("src"), "Table 'src' should not be cached")
   }
+
+  test("'CACHE TABLE tableName AS SELECT ..'") {
+    TestHive.sql("CACHE TABLE testCacheTable AS SELECT * FROM src")
+    assert(TestHive.isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
+    TestHive.uncacheTable("testCacheTable")
+  }
 }
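
For context, this is roughly how the new extension reads from user code. A hypothetical end-to-end sketch against the Spark 1.x-era HiveContext API, assuming a local Hive deployment that already has the usual src test table:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.hive.HiveContext

    object CacheTableAsSelectDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local"))
        val hiveContext = new HiveContext(sc)

        // Materializes the SELECT result and caches it under the given name.
        hiveContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM src")
        assert(hiveContext.isCached("testCacheTable"))

        // Subsequent reads come from the in-memory columnar cache.
        hiveContext.sql("SELECT count(*) FROM testCacheTable").collect().foreach(println)

        // UNCACHE TABLE goes through the same extended parser.
        hiveContext.sql("UNCACHE TABLE testCacheTable")
        sc.stop()
      }
    }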