From f2433c2b6769bb7fcaf40b6e74a1eac3589aae82 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Thu, 16 Jun 2016 23:08:14 +0800 Subject: [PATCH 01/26] init commit --- .../spark/sql/catalyst/parser/SqlBase.g4 | 3 + .../spark/sql/execution/SparkSqlParser.scala | 33 +++++++ .../spark/sql/execution/command/macros.scala | 94 +++++++++++++++++++ .../sql/execution/command/DDLSuite.scala | 22 +++++ 4 files changed, 152 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 044f91038876a..e45ca79b25459 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -97,6 +97,9 @@ statement | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING (USING resource (',' resource)*)? #createFunction | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction + | CREATE TEMPORARY MACRO macroName=identifier + '('(columns=colTypeList)?')' expression #createMacro + | DROP TEMPORARY MACRO (IF EXISTS)? macroName=identifier #dropMacro | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN)? statement #explain | SHOW TABLES ((FROM | IN) db=identifier)? (LIKE? pattern=STRING)? #showTables diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index a0508ad6019bd..839d8f667faee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -26,6 +26,7 @@ import org.antlr.v4.runtime.tree.TerminalNode import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, ScriptInputOutputSchema} @@ -589,6 +590,38 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ctx.TEMPORARY != null) } + /** + * Create a [[CreateMacroCommand]] command. + * + * For example: + * {{{ + * CREATE TEMPORARY MACRO macro_name([col_name col_type, ...]) expression; + * }}} + */ + override def visitCreateMacro(ctx: CreateMacroContext): LogicalPlan = withOrigin(ctx) { + val arguments = Option(ctx.columns).toSeq.flatMap(visitCatalogColumns).map { col => + AttributeReference(col.name, CatalystSqlParser.parseDataType(col.dataType))() + } + val e = expression(ctx.expression) + CreateMacroCommand( + ctx.macroName.getText, + MacroFunctionWrapper(arguments, e)) + } + + /** + * Create a [[DropMacroCommand]] command. + * + * For example: + * {{{ + * DROP TEMPORARY MACRO [IF EXISTS] macro_name; + * }}} + */ + override def visitDropMacro(ctx: DropMacroContext): LogicalPlan = withOrigin(ctx) { + DropMacroCommand( + ctx.macroName.getText, + ctx.EXISTS != null) + } + /** * Create a [[DropTableCommand]] command. 
   */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala
new file mode 100644
index 0000000000000..792d97cf64c6b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * This class provides arguments and body expression of the macro.
+ */
+case class MacroFunctionWrapper(arguments: Seq[AttributeReference], body: Expression)
+
+/**
+ * The DDL command that creates a macro.
+ * To create a temporary macro, the syntax of using this command in SQL is:
+ * {{{
+ *   CREATE TEMPORARY MACRO macro_name([col_name col_type, ...]) expression;
+ * }}}
+ */
+case class CreateMacroCommand(macroName: String, macroFunction: MacroFunctionWrapper)
+  extends RunnableCommand {
+
+  override def run(sparkSession: SparkSession): Seq[Row] = {
+    val catalog = sparkSession.sessionState.catalog
+    val inputSet = AttributeSet(macroFunction.arguments)
+    val colNames = macroFunction.arguments.map(_.name)
+    val colToIndex: Map[String, Int] = colNames.zipWithIndex.toMap
+    macroFunction.body.transformUp {
+      case u @ UnresolvedAttribute(nameParts) =>
+        assert(nameParts.length == 1)
+        colToIndex.get(nameParts.head).getOrElse(
+          throw new AnalysisException(s"Cannot create temporary macro '$macroName', " +
+            s"cannot resolve: [${u}] given input columns: [${inputSet}]"))
+        u
+      case _: SubqueryExpression =>
+        throw new AnalysisException(s"Cannot create temporary macro '$macroName', " +
+          s"cannot support subquery for macro.")
+    }
+
+    val macroInfo = macroFunction.arguments.mkString(",") + "->" + macroFunction.body.toString
+    val info = new ExpressionInfo(macroInfo, macroName)
+    val builder = (children: Seq[Expression]) => {
+      if (children.size != colNames.size) {
+        throw new AnalysisException(s"actual number of arguments: ${children.size} != " +
+          s"expected number of arguments: ${colNames.size} for Macro $macroName")
+      }
+      macroFunction.body.transformUp {
+        case u @ UnresolvedAttribute(nameParts) =>
+          assert(nameParts.length == 1)
+          colToIndex.get(nameParts.head).map(children(_)).getOrElse(
+            throw new AnalysisException(s"Macro '$macroInfo' cannot resolve '$u' " +
+              s"given input expressions: [${children.mkString(",")}]"))
+      }
+    }
+    catalog.createTempFunction(macroName, info, builder, ignoreIfExists = false)
+    Seq.empty[Row]
+  }
+}
+
+/**
+ * The DDL command that drops a macro.
+ * ifExists: returns an error if the macro doesn't exist, unless this is true. + * {{{ + * DROP TEMPORARY MACRO [IF EXISTS] macro_name; + * }}} + */ +case class DropMacroCommand(macroName: String, ifExists: Boolean) + extends RunnableCommand { + + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + if (FunctionRegistry.builtin.functionExists(macroName)) { + throw new AnalysisException(s"Cannot drop native function '$macroName'") + } + catalog.dropTempFunction(macroName, ifExists) + Seq.empty[Row] + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index e15fcf4326be2..c4636bf26dd7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1309,4 +1309,26 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assertUnsupported("TRUNCATE TABLE my_tab PARTITION (age=10)") } + test("create/drop temporary macro") { + intercept[AnalysisException] { + sql(s"CREATE TEMPORARY MACRO simple_add_error(x int) x + y") + } + sql("CREATE TEMPORARY MACRO fixed_number() 42") + checkAnswer(sql("SELECT fixed_number()"), Row(42)) + sql("CREATE TEMPORARY MACRO string_len_plus_two(x string) length(x) + 2") + checkAnswer(sql("SELECT string_len_plus_two('abc')"), Row(5)) + sql("CREATE TEMPORARY MACRO simple_add(x int, y int) x + y") + checkAnswer(sql("SELECT simple_add(1, 2)"), Row(3)) + intercept[AnalysisException] { + sql(s"SELECT simple_add(1)") + } + sql("DROP TEMPORARY MACRO fixed_number") + intercept[AnalysisException] { + sql(s"DROP TEMPORARY MACRO abs") + } + intercept[AnalysisException] { + sql("DROP TEMPORARY MACRO SOME_MACRO") + } + sql("DROP TEMPORARY MACRO IF EXISTS SOME_MACRO") + } } From 0b93636c941fd3093ba9b93e49e75211aa077c90 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Fri, 17 Jun 2016 02:00:05 +0800 Subject: [PATCH 02/26] fix unit test --- .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index e0f6ccf04dd33..6bc84a700fc44 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1205,13 +1205,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assertUnsupportedFeature { sql("ALTER INDEX my_index ON my_table set IDXPROPERTIES (\"prop1\"=\"val1_new\")")} } - - test("create/drop macro commands are not supported") { - assertUnsupportedFeature { - sql("CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x))") - } - assertUnsupportedFeature { sql("DROP TEMPORARY MACRO SIGMOID") } - } } // for SPARK-2180 test From 301e9508d860283ea8d12f91f52d04d9d697b000 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sat, 18 Jun 2016 18:43:39 +0800 Subject: [PATCH 03/26] update --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../catalyst/analysis/FunctionRegistry.scala | 33 +++++++++++++ .../analysis/NoSuchItemException.scala | 3 ++ .../sql/catalyst/catalog/SessionCatalog.scala | 19 +++++++ .../spark/sql/execution/SparkSqlParser.scala | 28 ++++++++--- .../spark/sql/execution/command/macros.scala | 
49 +++++-------------- .../sql/execution/command/DDLSuite.scala | 6 +++ 7 files changed, 96 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index e45ca79b25459..9771ef5d7fd63 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -98,7 +98,7 @@ statement (USING resource (',' resource)*)? #createFunction | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction | CREATE TEMPORARY MACRO macroName=identifier - '('(columns=colTypeList)?')' expression #createMacro + '(' colTypeList? ')' expression #createMacro | DROP TEMPORARY MACRO (IF EXISTS)? macroName=identifier #dropMacro | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN)? statement #explain | SHOW TABLES ((FROM | IN) db=identifier)? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 42a8faa412a34..0608e392a98ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import scala.collection.mutable.HashSet import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} @@ -59,6 +60,10 @@ trait FunctionRegistry { /** Checks if a function with a given name exists. */ def functionExists(name: String): Boolean = lookupFunction(name).isDefined + def registerMacro(name: String, info: ExpressionInfo, builder: FunctionBuilder): Unit + + def dropMacro(name: String): Boolean + /** Clear all registered functions. 
*/ def clear(): Unit @@ -69,6 +74,8 @@ class SimpleFunctionRegistry extends FunctionRegistry { protected val functionBuilders = StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false) + val macros = new HashSet[String] + override def registerFunction( name: String, info: ExpressionInfo, @@ -101,8 +108,26 @@ class SimpleFunctionRegistry extends FunctionRegistry { functionBuilders.remove(name).isDefined } + override def registerMacro( + name: String, + info: ExpressionInfo, + builder: FunctionBuilder): Unit = synchronized { + functionBuilders.put(name, (info, builder)) + macros += name.toLowerCase() + } + + override def dropMacro(name: String): Boolean = synchronized { + if (macros.contains(name.toLowerCase)) { + macros -= name.toLowerCase + functionBuilders.remove(name).isDefined + } else { + false + } + } + override def clear(): Unit = synchronized { functionBuilders.clear() + macros.clear() } def copy(): SimpleFunctionRegistry = synchronized { @@ -144,6 +169,14 @@ object EmptyFunctionRegistry extends FunctionRegistry { throw new UnsupportedOperationException } + override def registerMacro(name: String, info: ExpressionInfo, builder: FunctionBuilder): Unit = { + throw new UnsupportedOperationException + } + + override def dropMacro(name: String): Boolean = { + throw new UnsupportedOperationException + } + override def clear(): Unit = { throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index 8febdcaee829b..dc7ffb69e3a3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -52,3 +52,6 @@ class NoSuchPartitionsException(db: String, table: String, specs: Seq[TableParti class NoSuchTempFunctionException(func: String) extends AnalysisException(s"Temporary function '$func' not found") + +class NoSuchTempMacroException(func: String) + extends AnalysisException(s"Temporary macro '$func' not found") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 1ec1bb1baf23b..e085d1b9c1302 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -758,6 +758,25 @@ class SessionCatalog( } } + /** Create a temporary macro. */ + def createTempMacro( + name: String, + info: ExpressionInfo, + funcDefinition: FunctionBuilder, + ignoreIfExists: Boolean): Unit = { + if (functionRegistry.functionExists(name) && !ignoreIfExists) { + throw new TempFunctionAlreadyExistsException(name) + } + functionRegistry.registerMacro(name, info, funcDefinition) + } + + /** Drop a temporary macro. 
*/ + def dropTempMacro(name: String, ignoreIfNotExists: Boolean): Unit = { + if (!functionRegistry.dropMacro(name) && !ignoreIfNotExists) { + throw new NoSuchTempMacroException(name) + } + } + protected def failFunctionLookup(name: String): Nothing = { throw new NoSuchFunctionException(db = currentDb, func = name) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 839d8f667faee..9ace36ea73bf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -25,15 +25,16 @@ import org.antlr.v4.runtime.tree.TerminalNode import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, ScriptInputOutputSchema} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, _} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructField} /** * Concrete parser for Spark SQL statements. @@ -599,13 +600,28 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitCreateMacro(ctx: CreateMacroContext): LogicalPlan = withOrigin(ctx) { - val arguments = Option(ctx.columns).toSeq.flatMap(visitCatalogColumns).map { col => - AttributeReference(col.name, CatalystSqlParser.parseDataType(col.dataType))() + val arguments = Option(ctx.colTypeList).map(visitColTypeList(_)) + .getOrElse(Seq.empty[StructField]).map { col => + AttributeReference(col.name, col.dataType, col.nullable, col.metadata)() } + val colToIndex: Map[String, Int] = arguments.map(_.name).zipWithIndex.toMap + if (colToIndex.size != arguments.size) { + throw operationNotAllowed( + s"Cannot support duplicate colNames for CREATE TEMPORARY MACRO ", ctx) + } + val macroFunction = expression(ctx.expression).transformUp { + case u: UnresolvedAttribute => + val index = colToIndex.get(u.name).getOrElse( + throw new ParseException( + s"Cannot find colName: [${u}] for CREATE TEMPORARY MACRO", ctx)) + BoundReference(index, arguments(index).dataType, arguments(index).nullable) + case _: SubqueryExpression => + throw operationNotAllowed(s"Cannot support Subquery for CREATE TEMPORARY MACRO", ctx) } - val e = expression(ctx.expression) + CreateMacroCommand( ctx.macroName.getText, - MacroFunctionWrapper(arguments, e)) + arguments, + macroFunction) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala index 792d97cf64c6b..9be5c54c0bffb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala @@ -18,14 +18,8 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import 
org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ -/** - * This class provides arguments and body expression of the macro. - */ -case class MacroFunctionWrapper(arguments: Seq[AttributeReference], body: Expression) - /** * The DDL command that creates a macro. * To create a temporary macro, the syntax of using this command in SQL is: @@ -33,42 +27,26 @@ case class MacroFunctionWrapper(arguments: Seq[AttributeReference], body: Expres * CREATE TEMPORARY MACRO macro_name([col_name col_type, ...]) expression; * }}} */ -case class CreateMacroCommand(macroName: String, macroFunction: MacroFunctionWrapper) +case class CreateMacroCommand( + macroName: String, + columns: Seq[AttributeReference], + macroFunction: Expression) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - val inputSet = AttributeSet(macroFunction.arguments) - val colNames = macroFunction.arguments.map(_.name) - val colToIndex: Map[String, Int] = colNames.zipWithIndex.toMap - macroFunction.body.transformUp { - case u @ UnresolvedAttribute(nameParts) => - assert(nameParts.length == 1) - colToIndex.get(nameParts.head).getOrElse( - throw new AnalysisException(s"Cannot create temporary macro '$macroName', " + - s"cannot resolve: [${u}] given input columns: [${inputSet}]")) - u - case _: SubqueryExpression => - throw new AnalysisException(s"Cannot create temporary macro '$macroName', " + - s"cannot support subquery for macro.") - } - - val macroInfo = macroFunction.arguments.mkString(",") + "->" + macroFunction.body.toString + val macroInfo = columns.mkString(",") + " -> " + macroFunction.toString val info = new ExpressionInfo(macroInfo, macroName) val builder = (children: Seq[Expression]) => { - if (children.size != colNames.size) { - throw new AnalysisException(s"actual number of arguments: ${children.size} != " + - s"expected number of arguments: ${colNames.size} for Macro $macroName") + if (children.size != columns.size) { + throw new AnalysisException(s"Actual number of columns: ${children.size} != " + + s"expected number of columns: ${columns.size} for Macro $macroName") } - macroFunction.body.transformUp { - case u @ UnresolvedAttribute(nameParts) => - assert(nameParts.length == 1) - colToIndex.get(nameParts.head).map(children(_)).getOrElse( - throw new AnalysisException(s"Macro '$macroInfo' cannot resolve '$u' " + - s"given input expressions: [${children.mkString(",")}]")) + macroFunction.transformUp { + case b: BoundReference => children(b.ordinal) } } - catalog.createTempFunction(macroName, info, builder, ignoreIfExists = false) + catalog.createTempMacro(macroName, info, builder, ignoreIfExists = false) Seq.empty[Row] } } @@ -85,10 +63,7 @@ case class DropMacroCommand(macroName: String, ifExists: Boolean) override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - if (FunctionRegistry.builtin.functionExists(macroName)) { - throw new AnalysisException(s"Cannot drop native function '$macroName'") - } - catalog.dropTempFunction(macroName, ifExists) + catalog.dropTempMacro(macroName, ifExists) Seq.empty[Row] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index c4636bf26dd7e..8301dc2ab5671 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1313,6 +1313,12 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { intercept[AnalysisException] { sql(s"CREATE TEMPORARY MACRO simple_add_error(x int) x + y") } + intercept[AnalysisException] { + sql(s"CREATE TEMPORARY MACRO simple_add_error(x int, x int) x + y") + } + intercept[AnalysisException] { + sql(s"CREATE TEMPORARY MACRO simple_add_error(x int) x NOT IN (select c2 from t2) ") + } sql("CREATE TEMPORARY MACRO fixed_number() 42") checkAnswer(sql("SELECT fixed_number()"), Row(42)) sql("CREATE TEMPORARY MACRO string_len_plus_two(x string) length(x) + 2") From 808a5fa509392d8e2b909020f4a711c4dc2437b5 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sat, 18 Jun 2016 18:50:33 +0800 Subject: [PATCH 04/26] update createTempMacro --- .../apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 5 ++--- .../org/apache/spark/sql/execution/command/macros.scala | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index e085d1b9c1302..9b80839f457fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -762,9 +762,8 @@ class SessionCatalog( def createTempMacro( name: String, info: ExpressionInfo, - funcDefinition: FunctionBuilder, - ignoreIfExists: Boolean): Unit = { - if (functionRegistry.functionExists(name) && !ignoreIfExists) { + funcDefinition: FunctionBuilder): Unit = { + if (functionRegistry.functionExists(name)) { throw new TempFunctionAlreadyExistsException(name) } functionRegistry.registerMacro(name, info, funcDefinition) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala index 9be5c54c0bffb..344a93d1327ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala @@ -46,7 +46,7 @@ case class CreateMacroCommand( case b: BoundReference => children(b.ordinal) } } - catalog.createTempMacro(macroName, info, builder, ignoreIfExists = false) + catalog.createTempMacro(macroName, info, builder) Seq.empty[Row] } } From f4ed3bc13cbc629d055ef74a127cf212217cd589 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Mon, 20 Jun 2016 16:03:57 +0800 Subject: [PATCH 05/26] address comments --- .../sql/catalyst/analysis/Analyzer.scala | 34 ++++++++------ .../spark/sql/execution/SparkSqlParser.scala | 23 ++------- .../spark/sql/execution/command/macros.scala | 47 +++++++++++++++++-- 3 files changed, 66 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a5755616329ab..59f505eb6f547 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -733,6 +733,24 @@ class Analyzer( } } + protected[sql] def resolveFunction(func: UnresolvedFunction) = { + catalog.lookupFunction(func.name, func.children) match { + // DISTINCT is not meaningful for a Max or a Min. 
+ case max: Max if func.isDistinct => + AggregateExpression(max, Complete, isDistinct = false) + case min: Min if func.isDistinct => + AggregateExpression(min, Complete, isDistinct = false) + // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within + // the context of a Window clause. They do not need to be wrapped in an + // AggregateExpression. + case wf: AggregateWindowFunction => wf + // We get an aggregate function, we need to wrap it in an AggregateExpression. + case agg: AggregateFunction => AggregateExpression(agg, Complete, func.isDistinct) + // This function is not an aggregate function, just return the resolved one. + case other => other + } + } + /** * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by * clauses. This rule is to convert ordinal positions to the corresponding expressions in the @@ -916,21 +934,7 @@ class Analyzer( } case u @ UnresolvedFunction(funcId, children, isDistinct) => withPosition(u) { - catalog.lookupFunction(funcId, children) match { - // DISTINCT is not meaningful for a Max or a Min. - case max: Max if isDistinct => - AggregateExpression(max, Complete, isDistinct = false) - case min: Min if isDistinct => - AggregateExpression(min, Complete, isDistinct = false) - // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within - // the context of a Window clause. They do not need to be wrapped in an - // AggregateExpression. - case wf: AggregateWindowFunction => wf - // We get an aggregate function, we need to wrap it in an AggregateExpression. - case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) - // This function is not an aggregate function, just return the resolved one. - case other => other - } + resolveFunction(u) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 9ace36ea73bf7..272bc924dfe48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -25,7 +25,6 @@ import org.antlr.v4.runtime.tree.TerminalNode import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser._ @@ -601,27 +600,11 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitCreateMacro(ctx: CreateMacroContext): LogicalPlan = withOrigin(ctx) { val arguments = Option(ctx.colTypeList).map(visitColTypeList(_)) - .getOrElse(Seq.empty[StructField]).map { col => - AttributeReference(col.name, col.dataType, col.nullable, col.metadata)() } - val colToIndex: Map[String, Int] = arguments.map(_.name).zipWithIndex.toMap - if (colToIndex.size != arguments.size) { - throw operationNotAllowed( - s"Cannot support duplicate colNames for CREATE TEMPORARY MACRO ", ctx) - } - val macroFunction = expression(ctx.expression).transformUp { - case u: UnresolvedAttribute => - val index = colToIndex.get(u.name).getOrElse( - throw new ParseException( - s"Cannot find colName: [${u}] for CREATE TEMPORARY MACRO", ctx)) - BoundReference(index, arguments(index).dataType, arguments(index).nullable) - case _: SubqueryExpression => - throw operationNotAllowed(s"Cannot support 
Subquery for CREATE TEMPORARY MACRO", ctx) - } - + .getOrElse(Seq.empty[StructField]) + val e = expression(ctx.expression) CreateMacroCommand( ctx.macroName.getText, - arguments, - macroFunction) + MacroFunctionWrapper(arguments, e)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala index 344a93d1327ea..2b26b5904cedb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala @@ -18,7 +18,14 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.StructField + +/** + * This class provides arguments and body expression of the macro function. + */ +case class MacroFunctionWrapper(columns: Seq[StructField], macroFunction: Expression) /** * The DDL command that creates a macro. @@ -29,13 +36,46 @@ import org.apache.spark.sql.catalyst.expressions._ */ case class CreateMacroCommand( macroName: String, - columns: Seq[AttributeReference], - macroFunction: Expression) + funcWrapper: MacroFunctionWrapper) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - val macroInfo = columns.mkString(",") + " -> " + macroFunction.toString + val columns = funcWrapper.columns.map { col => + AttributeReference(col.name, col.dataType, col.nullable, col.metadata)() } + val colToIndex: Map[String, Int] = columns.map(_.name).zipWithIndex.toMap + if (colToIndex.size != columns.size) { + throw new AnalysisException(s"Cannot support duplicate colNames " + + s"for CREATE TEMPORARY MACRO $macroName, actual columns: ${columns.mkString(",")}") + } + val macroFunction = funcWrapper.macroFunction.transformDown { + case u: UnresolvedAttribute => + val index = colToIndex.get(u.name).getOrElse( + throw new AnalysisException(s"Cannot find colName: ${u} " + + s"for CREATE TEMPORARY MACRO $macroName, actual columns: ${columns.mkString(",")}")) + BoundReference(index, columns(index).dataType, columns(index).nullable) + case u: UnresolvedFunction => + sparkSession.sessionState.analyzer.resolveFunction(u) + case s: SubqueryExpression => + throw new AnalysisException(s"Cannot support Subquery: ${s} " + + s"for CREATE TEMPORARY MACRO $macroName") + case u: UnresolvedGenerator => + throw new AnalysisException(s"Cannot support Generator: ${u} " + + s"for CREATE TEMPORARY MACRO $macroName") + } + if (!macroFunction.resolved) { + if (macroFunction.checkInputDataTypes().isFailure) { + macroFunction.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + throw new AnalysisException(s"Cannot resolve '${macroFunction.sql}' " + + s"for CREATE TEMPORARY MACRO $macroName, due to data type mismatch: $message") + } + } else { + throw new AnalysisException(s"Cannot resolve '${macroFunction.sql}' " + + s"for CREATE TEMPORARY MACRO $macroName") + } + } + val macroInfo = columns.mkString(",") + " -> " + funcWrapper.macroFunction.toString val info = new ExpressionInfo(macroInfo, macroName) val builder = (children: Seq[Expression]) => { if (children.size != columns.size) { @@ -43,6 +83,7 @@ case class CreateMacroCommand( s"expected number of columns: ${columns.size} for Macro $macroName") } macroFunction.transformUp { + // 
Skip to validate the input type because Analyzer will check it after ResolveFunctions. case b: BoundReference => children(b.ordinal) } } From af0136de2931aa390c4c83229622b53769952de3 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Mon, 20 Jun 2016 16:13:13 +0800 Subject: [PATCH 06/26] update --- .../scala/org/apache/spark/sql/execution/command/macros.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala index 2b26b5904cedb..94e52201fd3d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala @@ -48,7 +48,7 @@ case class CreateMacroCommand( throw new AnalysisException(s"Cannot support duplicate colNames " + s"for CREATE TEMPORARY MACRO $macroName, actual columns: ${columns.mkString(",")}") } - val macroFunction = funcWrapper.macroFunction.transformDown { + val macroFunction = funcWrapper.macroFunction.transform { case u: UnresolvedAttribute => val index = colToIndex.get(u.name).getOrElse( throw new AnalysisException(s"Cannot find colName: ${u} " + @@ -82,7 +82,7 @@ case class CreateMacroCommand( throw new AnalysisException(s"Actual number of columns: ${children.size} != " + s"expected number of columns: ${columns.size} for Macro $macroName") } - macroFunction.transformUp { + macroFunction.transform { // Skip to validate the input type because Analyzer will check it after ResolveFunctions. case b: BoundReference => children(b.ordinal) } From 9fe1881ffe2810ec445f0560a7920c167c22b2d7 Mon Sep 17 00:00:00 2001 From: lianhuiwang Date: Fri, 11 Nov 2016 02:25:54 +0800 Subject: [PATCH 07/26] update code --- .../sql/catalyst/analysis/Analyzer.scala | 34 ++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 600e384f61add..b85c9cb8fd96f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -766,24 +766,6 @@ class Analyzer( } } - protected[sql] def resolveFunction(func: UnresolvedFunction) = { - catalog.lookupFunction(func.name, func.children) match { - // DISTINCT is not meaningful for a Max or a Min. - case max: Max if func.isDistinct => - AggregateExpression(max, Complete, isDistinct = false) - case min: Min if func.isDistinct => - AggregateExpression(min, Complete, isDistinct = false) - // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within - // the context of a Window clause. They do not need to be wrapped in an - // AggregateExpression. - case wf: AggregateWindowFunction => wf - // We get an aggregate function, we need to wrap it in an AggregateExpression. - case agg: AggregateFunction => AggregateExpression(agg, Complete, func.isDistinct) - // This function is not an aggregate function, just return the resolved one. - case other => other - } - } - /** * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by * clauses. 
This rule is to convert ordinal positions to the corresponding expressions in the @@ -969,7 +951,21 @@ class Analyzer( } case u @ UnresolvedFunction(funcId, children, isDistinct) => withPosition(u) { - resolveFunction(u) + catalog.lookupFunction(funcId, children) match { + // DISTINCT is not meaningful for a Max or a Min. + case max: Max if isDistinct => + AggregateExpression(max, Complete, isDistinct = false) + case min: Min if isDistinct => + AggregateExpression(min, Complete, isDistinct = false) + // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within + // the context of a Window clause. They do not need to be wrapped in an + // AggregateExpression. + case wf: AggregateWindowFunction => wf + // We get an aggregate function, we need to wrap it in an AggregateExpression. + case agg: AggregateFunction => AggregateExpression(agg, Complete, func.isDistinct) + // This function is not an aggregate function, just return the resolved one. + case other => other + } } } } From b8ffdc9d9f021e4fcec396a7bff5703a6e3ed521 Mon Sep 17 00:00:00 2001 From: lianhuiwang Date: Fri, 11 Nov 2016 02:31:44 +0800 Subject: [PATCH 08/26] fix function --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../scala/org/apache/spark/sql/execution/command/macros.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b85c9cb8fd96f..dd68d60d3e839 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -962,7 +962,7 @@ class Analyzer( // AggregateExpression. case wf: AggregateWindowFunction => wf // We get an aggregate function, we need to wrap it in an AggregateExpression. - case agg: AggregateFunction => AggregateExpression(agg, Complete, func.isDistinct) + case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) // This function is not an aggregate function, just return the resolved one. 
case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala index dabef6b58e1ee..bca381e86c400 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala @@ -55,7 +55,7 @@ case class CreateMacroCommand( s"for CREATE TEMPORARY MACRO $macroName, actual columns: ${columns.mkString(",")}")) BoundReference(index, columns(index).dataType, columns(index).nullable) case u: UnresolvedFunction => - sparkSession.sessionState.analyzer.resolveFunction(u) + sparkSession.sessionState.catalog.lookupFunction(u.name, u.children) case s: SubqueryExpression => throw new AnalysisException(s"Cannot support Subquery: ${s} " + s"for CREATE TEMPORARY MACRO $macroName") From fb8b57a4d46f6856dc2c883c6e995c248dda6a3b Mon Sep 17 00:00:00 2001 From: lianhuiwang Date: Fri, 11 Nov 2016 09:26:54 +0800 Subject: [PATCH 09/26] update comment --- .../scala/org/apache/spark/sql/execution/command/macros.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala index bca381e86c400..a2dcd3e2d3138 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala @@ -85,7 +85,7 @@ case class CreateMacroCommand( s"expected number of columns: ${columns.size} for Macro $macroName") } macroFunction.transform { - // Skip to validate the input type because Analyzer will check it after ResolveFunctions. + // Skip to validate the input type because check it beforepupd. case b: BoundReference => children(b.ordinal) } } From e895a9c7b89d2a53f6747f1e7fa08f8e97b80ed4 Mon Sep 17 00:00:00 2001 From: lianhuiwang Date: Fri, 11 Nov 2016 09:27:34 +0800 Subject: [PATCH 10/26] update comments --- .../scala/org/apache/spark/sql/execution/command/macros.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala index a2dcd3e2d3138..0a67aa94d1244 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala @@ -85,7 +85,7 @@ case class CreateMacroCommand( s"expected number of columns: ${columns.size} for Macro $macroName") } macroFunction.transform { - // Skip to validate the input type because check it beforepupd. + // Skip to validate the input type because check it before. 
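At this point in the series the macro flow has settled into two phases: CreateMacroCommand resolves the body once at definition time (column references become BoundReferences keyed by argument position, and nested calls are resolved through the session catalog's lookupFunction), and the FunctionBuilder registered for the macro expands a call site by substituting the caller's argument expressions for those BoundReferences. The sketch below illustrates that substitution with Catalyst expression classes used directly; the values and the expand helper are illustrative assumptions, not code from these patches.

{{{
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

// Body stored for: CREATE TEMPORARY MACRO simple_add(x INT, y INT) x + y.
// At definition time x and y have been bound to ordinals 0 and 1.
val storedBody: Expression =
  Add(BoundReference(0, IntegerType, nullable = true),
      BoundReference(1, IntegerType, nullable = true))

// What the registered builder effectively does for SELECT simple_add(1, 2).
def expand(body: Expression, args: Seq[Expression]): Expression = {
  require(args.length == 2, "simple_add expects exactly two arguments")
  body.transform { case b: BoundReference => args(b.ordinal) }
}

val expanded = expand(storedBody, Seq(Literal(1), Literal(2)))
// expanded is the plain expression (1 + 2); the analyzer type-checks and
// evaluates it like any other expression, so no macro-specific runtime is needed.
}}}

Because expansion happens before analysis, mismatched argument arity or types surface as ordinary analysis errors against the expanded expression, which is what the DDLSuite cases in this series assert.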
case b: BoundReference => children(b.ordinal) } } From 651b485e2461df014a5c29780d1d2f6d4c1695e1 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sat, 27 May 2017 23:54:48 +0800 Subject: [PATCH 11/26] Merge branch 'master' of https://github.com/apache/spark into macro --- .../analysis/AlreadyExistException.scala | 3 ++ .../sql/catalyst/catalog/SessionCatalog.scala | 18 ++++++++++ .../spark/sql/execution/SparkSqlParser.scala | 33 ++++++++++++++++++- .../sql/execution/command/DDLSuite.scala | 29 ++++++++++++++++ .../spark/sql/hive/HiveStrategies.scala | 2 +- .../sql/hive/execution/HiveQuerySuite.scala | 7 ---- 6 files changed, 83 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index 57f7a80bedc6c..16ceac098e597 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -44,3 +44,6 @@ class PartitionsAlreadyExistException(db: String, table: String, specs: Seq[Tabl class FunctionAlreadyExistsException(db: String, func: String) extends AnalysisException(s"Function '$func' already exists in database '$db'") + +class TempMacroAlreadyExistsException(func: String) + extends AnalysisException(s"Temp macro '$func' already exists") \ No newline at end of file diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index f6653d384fe1d..e144a6afd444c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1020,6 +1020,24 @@ class SessionCatalog( } } + /** Create a temporary macro. */ + def createTempMacro( + name: String, + info: ExpressionInfo, + funcDefinition: FunctionBuilder): Unit = { + if (functionRegistry.functionExists(name)) { + throw new TempMacroAlreadyExistsException(name) + } + functionRegistry.registerMacro(name, info, funcDefinition) + } + + /** Drop a temporary macro. */ + def dropTempMacro(name: String, ignoreIfNotExists: Boolean): Unit = { + if (!functionRegistry.dropMacro(name) && !ignoreIfNotExists) { + throw new NoSuchTempMacroException(name) + } + } + /** * Retrieve the metadata of a metastore function. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3c58c6e1b6780..b3dc50bc8ad9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTable, _} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} /** * Concrete parser for Spark SQL statements. @@ -715,6 +715,37 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ctx.TEMPORARY != null) } + /** + * Create a [[CreateMacroCommand]] command. 
+ * + * For example: + * {{{ + * CREATE TEMPORARY MACRO macro_name([col_name col_type, ...]) expression; + * }}} + */ + override def visitCreateMacro(ctx: CreateMacroContext): LogicalPlan = withOrigin(ctx) { + val arguments = Option(ctx.colTypeList).map(visitColTypeList(_)) + .getOrElse(Seq.empty[StructField]) + val e = expression(ctx.expression) + CreateMacroCommand( + ctx.macroName.getText, + MacroFunctionWrapper(arguments, e)) + } + + /** + * Create a [[DropMacroCommand]] command. + * + * For example: + * {{{ + * DROP TEMPORARY MACRO [IF EXISTS] macro_name; + * }}} + */ + override def visitDropMacro(ctx: DropMacroContext): LogicalPlan = withOrigin(ctx) { + DropMacroCommand( + ctx.macroName.getText, + ctx.EXISTS != null) + } + /** * Create a [[DropTableCommand]] command. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index e4dd077715d0f..8d7b7836f7069 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1516,6 +1516,35 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { ) } + test("create/drop temporary macro") { + intercept[AnalysisException] { + sql(s"CREATE TEMPORARY MACRO simple_add_error(x int) x + y") + } + intercept[AnalysisException] { + sql(s"CREATE TEMPORARY MACRO simple_add_error(x int, x int) x + y") + } + intercept[AnalysisException] { + sql(s"CREATE TEMPORARY MACRO simple_add_error(x int) x NOT IN (select c2 from t2) ") + } + sql("CREATE TEMPORARY MACRO fixed_number() 42") + checkAnswer(sql("SELECT fixed_number()"), Row(42)) + sql("CREATE TEMPORARY MACRO string_len_plus_two(x string) length(x) + 2") + checkAnswer(sql("SELECT string_len_plus_two('abc')"), Row(5)) + sql("CREATE TEMPORARY MACRO simple_add(x int, y int) x + y") + checkAnswer(sql("SELECT simple_add(1, 2)"), Row(3)) + intercept[AnalysisException] { + sql(s"SELECT simple_add(1)") + } + sql("DROP TEMPORARY MACRO fixed_number") + intercept[AnalysisException] { + sql(s"DROP TEMPORARY MACRO abs") + } + intercept[AnalysisException] { + sql("DROP TEMPORARY MACRO SOME_MACRO") + } + sql("DROP TEMPORARY MACRO IF EXISTS SOME_MACRO") + } + test("create a data source table without schema") { import testImplicits._ withTempPath { tempDir => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 662fc80661513..4f83a834db4a3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -237,7 +237,7 @@ private[hive] trait HiveStrategies { !predicate.references.isEmpty && predicate.references.subsetOf(partitionKeyIds) } - + FunctionRegistry.scala pruneFilterProject( projectList, otherPredicates, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index cf33760360724..2de83cff5ef17 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1153,13 +1153,6 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd sql("ALTER INDEX my_index ON my_table set IDXPROPERTIES 
(\"prop1\"=\"val1_new\")")} } - test("create/drop macro commands are not supported") { - assertUnsupportedFeature { - sql("CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x))") - } - assertUnsupportedFeature { sql("DROP TEMPORARY MACRO SIGMOID") } - } - test("dynamic partitioning is allowed when hive.exec.dynamic.partition.mode is nonstrict") { val modeConfKey = "hive.exec.dynamic.partition.mode" withTable("with_parts") { From 314913df4a0345957c718622ab1924901a895b90 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 28 May 2017 00:07:25 +0800 Subject: [PATCH 12/26] Merge branch 'macro' of https://github.com/lianhuiwang/spark into macro --- .../spark/sql/catalyst/parser/SqlBase.g4 | 1 - .../sql/catalyst/catalog/SessionCatalog.scala | 460 +++-- .../spark/sql/execution/SparkSqlParser.scala | 288 +++- .../sql/execution/command/DDLSuite.scala | 1522 +++++++++++------ .../spark/sql/hive/HiveStrategies.scala | 1 - 5 files changed, 1534 insertions(+), 738 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 9ef669cd33096..a56d74ab0a5f5 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -132,7 +132,6 @@ statement | CREATE TEMPORARY MACRO macroName=identifier '(' colTypeList? ')' expression #createMacro | DROP TEMPORARY MACRO (IF EXISTS)? macroName=identifier #dropMacro - | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN)? statement #explain | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? statement #explain | SHOW TABLES ((FROM | IN) db=identifier)? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 56ccae08dbb4f..cdcc1c112d8a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -17,22 +17,28 @@ package org.apache.spark.sql.catalyst.catalog +import java.net.URI +import java.util.Locale import javax.annotation.concurrent.GuardedBy import scala.collection.mutable +import scala.util.{Failure, Success, Try} +import com.google.common.cache.{Cache, CacheBuilder} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{StructField, StructType} object SessionCatalog { val DEFAULT_DATABASE = "default" @@ -46,32 +52,37 @@ object SessionCatalog { * This class must be 
thread-safe. */ class SessionCatalog( - externalCatalog: ExternalCatalog, - globalTempViewManager: GlobalTempViewManager, - functionResourceLoader: FunctionResourceLoader, - functionRegistry: FunctionRegistry, - conf: CatalystConf, - hadoopConf: Configuration) extends Logging { + val externalCatalog: ExternalCatalog, + globalTempViewManager: GlobalTempViewManager, + functionRegistry: FunctionRegistry, + conf: SQLConf, + hadoopConf: Configuration, + parser: ParserInterface, + functionResourceLoader: FunctionResourceLoader) extends Logging { import SessionCatalog._ import CatalogTypes.TablePartitionSpec // For testing only. def this( - externalCatalog: ExternalCatalog, - functionRegistry: FunctionRegistry, - conf: CatalystConf) { + externalCatalog: ExternalCatalog, + functionRegistry: FunctionRegistry, + conf: SQLConf) { this( externalCatalog, new GlobalTempViewManager("global_temp"), - DummyFunctionResourceLoader, functionRegistry, conf, - new Configuration()) + new Configuration(), + new CatalystSqlParser(conf), + DummyFunctionResourceLoader) } // For testing only. def this(externalCatalog: ExternalCatalog) { - this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true)) + this( + externalCatalog, + new SimpleFunctionRegistry, + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) } /** List of temporary tables, mapping from table name to their logical plan. */ @@ -83,27 +94,43 @@ class SessionCatalog( // check whether the temporary table or function exists, then, if not, operate on // the corresponding item in the current database. @GuardedBy("this") - protected var currentDb = { - val defaultName = DEFAULT_DATABASE - val defaultDbDefinition = - CatalogDatabase(defaultName, "default database", conf.warehousePath, Map()) - // Initialize default database if it doesn't already exist - createDatabase(defaultDbDefinition, ignoreIfExists = true) - formatDatabaseName(defaultName) + protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) + + /** + * Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"), + * i.e. if this name only contains characters, numbers, and _. + * + * This method is intended to have the same behavior of + * org.apache.hadoop.hive.metastore.MetaStoreUtils.validateName. + */ + private def validateName(name: String): Unit = { + val validNameFormat = "([\\w_]+)".r + if (!validNameFormat.pattern.matcher(name).matches()) { + throw new AnalysisException(s"`$name` is not a valid name for tables/databases. " + + "Valid names only contain alphabet characters, numbers and _.") + } } /** * Format table name, taking into account case sensitivity. */ protected[this] def formatTableName(name: String): String = { - if (conf.caseSensitiveAnalysis) name else name.toLowerCase + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } /** * Format database name, taking into account case sensitivity. */ protected[this] def formatDatabaseName(name: String): String = { - if (conf.caseSensitiveAnalysis) name else name.toLowerCase + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) + } + + /** + * A cache of qualified table names to table relation plans. + */ + val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = { + val cacheSize = conf.tableRelationCacheSize + CacheBuilder.newBuilder().maximumSize(cacheSize).build[QualifiedTableName, LogicalPlan]() } /** @@ -112,10 +139,10 @@ class SessionCatalog( * does not contain a scheme, this path will not be changed after the default * FileSystem is changed. 
*/ - private def makeQualifiedPath(path: String): Path = { + private def makeQualifiedPath(path: URI): URI = { val hadoopPath = new Path(path) val fs = hadoopPath.getFileSystem(hadoopConf) - fs.makeQualified(hadoopPath) + fs.makeQualified(hadoopPath).toUri } private def requireDbExists(db: String): Unit = { @@ -137,6 +164,20 @@ class SessionCatalog( throw new TableAlreadyExistsException(db = db, table = name.table) } } + + private def checkDuplication(fields: Seq[StructField]): Unit = { + val columnNames = if (conf.caseSensitiveAnalysis) { + fields.map(_.name) + } else { + fields.map(_.name.toLowerCase) + } + if (columnNames.distinct.length != columnNames.length) { + val duplicateColumns = columnNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => x + } + throw new AnalysisException(s"Found duplicate column(s): ${duplicateColumns.mkString(", ")}") + } + } // ---------------------------------------------------------------------------- // Databases // ---------------------------------------------------------------------------- @@ -150,7 +191,8 @@ class SessionCatalog( s"${globalTempViewManager.database} is a system preserved database, " + "you cannot create a database with this name.") } - val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri).toString + validateName(dbName) + val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri) externalCatalog.createDatabase( dbDefinition.copy(name = dbName, locationUri = qualifiedPath), ignoreIfExists) @@ -208,9 +250,9 @@ class SessionCatalog( * Get the path for creating a non-default database when database location is not provided * by users. */ - def getDefaultDBPath(db: String): String = { + def getDefaultDBPath(db: String): URI = { val database = formatDatabaseName(db) - new Path(new Path(conf.warehousePath), database + ".db").toString + new Path(new Path(conf.warehousePath), database + ".db").toUri } // ---------------------------------------------------------------------------- @@ -233,7 +275,20 @@ class SessionCatalog( def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) - val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + validateName(table) + + val newTableDefinition = if (tableDefinition.storage.locationUri.isDefined + && !tableDefinition.storage.locationUri.get.isAbsolute) { + // make the location of the table qualified. + val qualifiedTableLocation = + makeQualifiedPath(tableDefinition.storage.locationUri.get) + tableDefinition.copy( + storage = tableDefinition.storage.copy(locationUri = Some(qualifiedTableLocation)), + identifier = TableIdentifier(table, Some(db))) + } else { + tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + } + requireDbExists(db) externalCatalog.createTable(newTableDefinition, ignoreIfExists) } @@ -257,6 +312,47 @@ class SessionCatalog( externalCatalog.alterTable(newTableDefinition) } + /** + * Alter the schema of a table identified by the provided table identifier. The new schema + * should still contain the existing bucket columns and partition columns used by the table. This + * method will also update any Spark SQL-related parameters stored as Hive table properties (such + * as the schema itself). 
+ * + * @param identifier TableIdentifier + * @param newSchema Updated schema to be used for the table (must contain existing partition and + * bucket columns, and partition columns need to be at the end) + */ + def alterTableSchema( + identifier: TableIdentifier, + newSchema: StructType): Unit = { + val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) + requireDbExists(db) + requireTableExists(tableIdentifier) + checkDuplication(newSchema) + + val catalogTable = externalCatalog.getTable(db, table) + val oldSchema = catalogTable.schema + + // not supporting dropping columns yet + val nonExistentColumnNames = oldSchema.map(_.name).filterNot(columnNameResolved(newSchema, _)) + if (nonExistentColumnNames.nonEmpty) { + throw new AnalysisException( + s""" + |Some existing schema fields (${nonExistentColumnNames.mkString("[", ",", "]")}) are + |not present in the new schema. We don't support dropping columns yet. + """.stripMargin) + } + + // assuming the newSchema has all partition columns at the end as required + externalCatalog.alterTableSchema(db, table, newSchema) + } + + private def columnNameResolved(schema: StructType, colName: String): Boolean = { + schema.fields.map(_.name).exists(conf.resolver(_, colName)) + } + /** * Return whether a table/view with the specified name exists. If no database is specified, check * with current database. @@ -298,15 +394,15 @@ class SessionCatalog( * If the specified table is not found in the database then a [[NoSuchTableException]] is thrown. */ def loadTable( - name: TableIdentifier, - loadPath: String, - isOverwrite: Boolean, - holdDDLTime: Boolean): Unit = { + name: TableIdentifier, + loadPath: String, + isOverwrite: Boolean, + isSrcLocal: Boolean): Unit = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = formatTableName(name.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Some(db))) - externalCatalog.loadTable(db, table, loadPath, isOverwrite, holdDDLTime) + externalCatalog.loadTable(db, table, loadPath, isOverwrite, isSrcLocal) } /** @@ -315,25 +411,26 @@ class SessionCatalog( * If the specified table is not found in the database then a [[NoSuchTableException]] is thrown. 
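alterTableSchema rejects a new schema that silently drops existing columns, and columnNameResolved uses the session resolver so the check honours case sensitivity. A standalone sketch with the resolver passed in explicitly instead of coming from SQLConf:

{{{
// Sketch of the "no dropped columns" guard in alterTableSchema above.
type Resolver = (String, String) => Boolean
val caseInsensitive: Resolver = (a, b) => a.equalsIgnoreCase(b)

def droppedColumns(oldCols: Seq[String], newCols: Seq[String], resolver: Resolver): Seq[String] =
  oldCols.filterNot(oldCol => newCols.exists(resolver(_, oldCol)))

droppedColumns(Seq("a", "b", "c"), Seq("A", "B", "c", "d"), caseInsensitive)  // Nil
droppedColumns(Seq("a", "b"), Seq("a"), caseInsensitive)                      // Seq("b")
}}}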
*/ def loadPartition( - name: TableIdentifier, - loadPath: String, - partition: TablePartitionSpec, - isOverwrite: Boolean, - holdDDLTime: Boolean, - inheritTableSpecs: Boolean): Unit = { + name: TableIdentifier, + loadPath: String, + spec: TablePartitionSpec, + isOverwrite: Boolean, + inheritTableSpecs: Boolean, + isSrcLocal: Boolean): Unit = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = formatTableName(name.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Some(db))) + requireNonEmptyValueInPartitionSpec(Seq(spec)) externalCatalog.loadPartition( - db, table, loadPath, partition, isOverwrite, holdDDLTime, inheritTableSpecs) + db, table, loadPath, spec, isOverwrite, inheritTableSpecs, isSrcLocal) } - def defaultTablePath(tableIdent: TableIdentifier): String = { + def defaultTablePath(tableIdent: TableIdentifier): URI = { val dbName = formatDatabaseName(tableIdent.database.getOrElse(getCurrentDatabase)) val dbLocation = getDatabaseMetadata(dbName).locationUri - new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString + new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toUri } // ---------------------------------------------- @@ -344,9 +441,9 @@ class SessionCatalog( * Create a local temporary view. */ def createTempView( - name: String, - tableDefinition: LogicalPlan, - overrideIfExists: Boolean): Unit = synchronized { + name: String, + tableDefinition: LogicalPlan, + overrideIfExists: Boolean): Unit = synchronized { val table = formatTableName(name) if (tempTables.contains(table) && !overrideIfExists) { throw new TempTableAlreadyExistsException(name) @@ -358,9 +455,9 @@ class SessionCatalog( * Create a global temporary view. */ def createGlobalTempView( - name: String, - viewDefinition: LogicalPlan, - overrideIfExists: Boolean): Unit = { + name: String, + viewDefinition: LogicalPlan, + overrideIfExists: Boolean): Unit = { globalTempViewManager.create(formatTableName(name), viewDefinition, overrideIfExists) } @@ -369,8 +466,8 @@ class SessionCatalog( * temp view is matched and altered, false otherwise. */ def alterTempViewDefinition( - name: TableIdentifier, - viewDefinition: LogicalPlan): Boolean = synchronized { + name: TableIdentifier, + viewDefinition: LogicalPlan): Boolean = synchronized { val viewName = formatTableName(name.table) if (name.database.isEmpty) { if (tempTables.contains(viewName)) { @@ -481,6 +578,7 @@ class SessionCatalog( if (oldName.database.isDefined || !tempTables.contains(oldTableName)) { requireTableExists(TableIdentifier(oldTableName, Some(db))) requireTableNotExists(TableIdentifier(newTableName, Some(db))) + validateName(newTableName) externalCatalog.renameTable(db, oldTableName, newTableName) } else { if (newName.database.isDefined) { @@ -507,9 +605,9 @@ class SessionCatalog( * the same name, then, if that does not exist, drop the table from the current database. */ def dropTable( - name: TableIdentifier, - ignoreIfNotExists: Boolean, - purge: Boolean): Unit = synchronized { + name: TableIdentifier, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = synchronized { val db = formatDatabaseName(name.database.getOrElse(currentDb)) val table = formatTableName(name.table) if (db == globalTempViewManager.database) { @@ -543,26 +641,41 @@ class SessionCatalog( * Note that, the global temp view database is also valid here, this will return the global temp * view matching the given name. 
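defaultTablePath and getDefaultDBPath now return java.net.URI instead of String, composed through Hadoop Path so scheme-less locations resolve consistently. A small sketch of that composition (requires hadoop-common; the warehouse URI is made up):

{{{
import java.net.URI
import org.apache.hadoop.fs.Path

// Sketch of the Path-based composition used by defaultTablePath above.
def tablePath(dbLocation: URI, table: String): URI =
  new Path(new Path(dbLocation), table).toUri

tablePath(new URI("file:/tmp/warehouse/db1.db"), "t1")
// -> file:/tmp/warehouse/db1.db/t1
}}}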
* - * If the relation is a view, the relation will be wrapped in a [[SubqueryAlias]] which will - * track the name of the view. + * If the relation is a view, we generate a [[View]] operator from the view description, and + * wrap the logical plan in a [[SubqueryAlias]] which will track the name of the view. + * + * @param name The name of the table/view that we look up. */ - def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = { + def lookupRelation(name: TableIdentifier): LogicalPlan = { synchronized { val db = formatDatabaseName(name.database.getOrElse(currentDb)) val table = formatTableName(name.table) - val relationAlias = alias.getOrElse(table) if (db == globalTempViewManager.database) { globalTempViewManager.get(table).map { viewDef => - SubqueryAlias(relationAlias, viewDef, Some(name)) + SubqueryAlias(table, viewDef) }.getOrElse(throw new NoSuchTableException(db, table)) } else if (name.database.isDefined || !tempTables.contains(table)) { val metadata = externalCatalog.getTable(db, table) - val view = Option(metadata.tableType).collect { - case CatalogTableType.VIEW => name + if (metadata.tableType == CatalogTableType.VIEW) { + val viewText = metadata.viewText.getOrElse(sys.error("Invalid view without text.")) + // The relation is a view, so we wrap the relation by: + // 1. Add a [[View]] operator over the relation to keep track of the view desc; + // 2. Wrap the logical plan in a [[SubqueryAlias]] which tracks the name of the view. + val child = View( + desc = metadata, + output = metadata.schema.toAttributes, + child = parser.parsePlan(viewText)) + SubqueryAlias(table, child) + } else { + val tableRelation = CatalogRelation( + metadata, + // we assume all the columns are nullable. + metadata.dataSchema.asNullable.toAttributes, + metadata.partitionSchema.asNullable.toAttributes) + SubqueryAlias(table, tableRelation) } - SubqueryAlias(relationAlias, SimpleCatalogRelation(db, metadata), view) } else { - SubqueryAlias(relationAlias, tempTables(table), Option(name)) + SubqueryAlias(table, tempTables(table)) } } } @@ -621,15 +734,22 @@ class SessionCatalog( /** * Refresh the cache entry for a metastore table, if any. */ - def refreshTable(name: TableIdentifier): Unit = { + def refreshTable(name: TableIdentifier): Unit = synchronized { + val dbName = formatDatabaseName(name.database.getOrElse(currentDb)) + val tableName = formatTableName(name.table) + // Go through temporary tables and invalidate them. - // If the database is defined, this is definitely not a temp table. + // If the database is defined, this may be a global temporary view. // If the database is not defined, there is a good chance this is a temp table. if (name.database.isEmpty) { - tempTables.get(formatTableName(name.table)).foreach(_.refresh()) - } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { - globalTempViewManager.get(formatTableName(name.table)).foreach(_.refresh()) + tempTables.get(tableName).foreach(_.refresh()) + } else if (dbName == globalTempViewManager.database) { + globalTempViewManager.get(tableName).foreach(_.refresh()) } + + // Also invalidate the table relation cache. + val qualifiedTableName = QualifiedTableName(dbName, tableName) + tableRelationCache.invalidate(qualifiedTableName) } /** @@ -657,14 +777,15 @@ class SessionCatalog( * If no database is specified, assume the table is in the current database. 
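lookupRelation now builds a View operator from the stored view text and a CatalogRelation otherwise, but the resolution order is unchanged: the global temp database wins, then the external catalog, then local temp views. A stand-in sketch of that precedence, with Strings replacing LogicalPlan:

{{{
// Sketch of the lookup precedence in lookupRelation above.
def lookup(
    db: Option[String],
    table: String,
    currentDb: String,
    globalTempDb: String,
    globalTempViews: Map[String, String],
    localTempViews: Map[String, String],
    catalogTables: Map[(String, String), String]): Option[String] = {
  val database = db.getOrElse(currentDb)
  if (database == globalTempDb) {
    globalTempViews.get(table)
  } else if (db.isDefined || !localTempViews.contains(table)) {
    catalogTables.get((database, table))
  } else {
    localTempViews.get(table)
  }
}
}}}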
*/ def createPartitions( - tableName: TableIdentifier, - parts: Seq[CatalogTablePartition], - ignoreIfExists: Boolean): Unit = { + tableName: TableIdentifier, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = { val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) requireExactMatchedPartitionSpec(parts.map(_.spec), getTableMetadata(tableName)) + requireNonEmptyValueInPartitionSpec(parts.map(_.spec)) externalCatalog.createPartitions(db, table, parts, ignoreIfExists) } @@ -673,16 +794,18 @@ class SessionCatalog( * If no database is specified, assume the table is in the current database. */ def dropPartitions( - tableName: TableIdentifier, - specs: Seq[TablePartitionSpec], - ignoreIfNotExists: Boolean, - purge: Boolean): Unit = { + tableName: TableIdentifier, + specs: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean, + purge: Boolean, + retainData: Boolean): Unit = { val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) requirePartialMatchedPartitionSpec(specs, getTableMetadata(tableName)) - externalCatalog.dropPartitions(db, table, specs, ignoreIfNotExists, purge) + requireNonEmptyValueInPartitionSpec(specs) + externalCatalog.dropPartitions(db, table, specs, ignoreIfNotExists, purge, retainData) } /** @@ -692,9 +815,9 @@ class SessionCatalog( * If no database is specified, assume the table is in the current database. */ def renamePartitions( - tableName: TableIdentifier, - specs: Seq[TablePartitionSpec], - newSpecs: Seq[TablePartitionSpec]): Unit = { + tableName: TableIdentifier, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = { val tableMetadata = getTableMetadata(tableName) val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) @@ -702,6 +825,8 @@ class SessionCatalog( requireTableExists(TableIdentifier(table, Option(db))) requireExactMatchedPartitionSpec(specs, tableMetadata) requireExactMatchedPartitionSpec(newSpecs, tableMetadata) + requireNonEmptyValueInPartitionSpec(specs) + requireNonEmptyValueInPartitionSpec(newSpecs) externalCatalog.renamePartitions(db, table, specs, newSpecs) } @@ -720,6 +845,7 @@ class SessionCatalog( requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) requireExactMatchedPartitionSpec(parts.map(_.spec), getTableMetadata(tableName)) + requireNonEmptyValueInPartitionSpec(parts.map(_.spec)) externalCatalog.alterPartitions(db, table, parts) } @@ -733,9 +859,31 @@ class SessionCatalog( requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) requireExactMatchedPartitionSpec(Seq(spec), getTableMetadata(tableName)) + requireNonEmptyValueInPartitionSpec(Seq(spec)) externalCatalog.getPartition(db, table, spec) } + /** + * List the names of all partitions that belong to the specified table, assuming it exists. + * + * A partial partition spec may optionally be provided to filter the partitions returned. + * For instance, if there exist partitions (a='1', b='2'), (a='1', b='3') and (a='2', b='4'), + * then a partial spec of (a='1') will return the first two only. 
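The listPartitionNames scaladoc describes partial-spec filtering; the semantics are simply "agree on every column the partial spec mentions". A sketch with plain maps, reproducing the (a='1') example from the comment:

{{{
// Sketch of partial partition-spec matching as described above.
def matchesPartialSpec(
    partition: Map[String, String],
    partialSpec: Map[String, String]): Boolean =
  partialSpec.forall { case (col, value) => partition.get(col).contains(value) }

val partitions = Seq(
  Map("a" -> "1", "b" -> "2"),
  Map("a" -> "1", "b" -> "3"),
  Map("a" -> "2", "b" -> "4"))

partitions.filter(matchesPartialSpec(_, Map("a" -> "1")))  // first two only
}}}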
+ */ + def listPartitionNames( + tableName: TableIdentifier, + partialSpec: Option[TablePartitionSpec] = None): Seq[String] = { + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + partialSpec.foreach { spec => + requirePartialMatchedPartitionSpec(Seq(spec), getTableMetadata(tableName)) + requireNonEmptyValueInPartitionSpec(Seq(spec)) + } + externalCatalog.listPartitionNames(db, table, partialSpec) + } + /** * List the metadata of all partitions that belong to the specified table, assuming it exists. * @@ -744,12 +892,16 @@ class SessionCatalog( * then a partial spec of (a='1') will return the first two only. */ def listPartitions( - tableName: TableIdentifier, - partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = { + tableName: TableIdentifier, + partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = { val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) + partialSpec.foreach { spec => + requirePartialMatchedPartitionSpec(Seq(spec), getTableMetadata(tableName)) + requireNonEmptyValueInPartitionSpec(Seq(spec)) + } externalCatalog.listPartitions(db, table, partialSpec) } @@ -758,13 +910,26 @@ class SessionCatalog( * satisfy the given partition-pruning predicate expressions. */ def listPartitionsByFilter( - tableName: TableIdentifier, - predicates: Seq[Expression]): Seq[CatalogTablePartition] = { + tableName: TableIdentifier, + predicates: Seq[Expression]): Seq[CatalogTablePartition] = { val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) - externalCatalog.listPartitionsByFilter(db, table, predicates) + externalCatalog.listPartitionsByFilter(db, table, predicates, conf.sessionLocalTimeZone) + } + + /** + * Verify if the input partition spec has any empty value. + */ + private def requireNonEmptyValueInPartitionSpec(specs: Seq[TablePartitionSpec]): Unit = { + specs.foreach { s => + if (s.values.exists(_.isEmpty)) { + val spec = s.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]") + throw new AnalysisException( + s"Partition spec is invalid. The spec ($spec) contains an empty partition column value") + } + } } /** @@ -772,8 +937,8 @@ class SessionCatalog( * The columns must be the same but the orders could be different. */ private def requireExactMatchedPartitionSpec( - specs: Seq[TablePartitionSpec], - table: CatalogTable): Unit = { + specs: Seq[TablePartitionSpec], + table: CatalogTable): Unit = { val defined = table.partitionColumnNames.sorted specs.foreach { s => if (s.keys.toSeq.sorted != defined) { @@ -790,8 +955,8 @@ class SessionCatalog( * That is, the columns of partition spec should be part of the defined partition spec. */ private def requirePartialMatchedPartitionSpec( - specs: Seq[TablePartitionSpec], - table: CatalogTable): Unit = { + specs: Seq[TablePartitionSpec], + table: CatalogTable): Unit = { val defined = table.partitionColumnNames specs.foreach { s => if (!s.keys.forall(defined.contains)) { @@ -855,24 +1020,6 @@ class SessionCatalog( } } - /** Create a temporary macro. 
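The two spec validations defined above differ only in strictness: an exact spec must name every partition column, while a partial spec may name a subset. Side by side, as a sketch over plain collections:

{{{
// Sketch of requireExactMatchedPartitionSpec vs requirePartialMatchedPartitionSpec.
def isExactSpec(spec: Map[String, String], partitionCols: Seq[String]): Boolean =
  spec.keys.toSeq.sorted == partitionCols.sorted

def isPartialSpec(spec: Map[String, String], partitionCols: Seq[String]): Boolean =
  spec.keys.forall(partitionCols.contains)

val cols = Seq("a", "b")
isExactSpec(Map("b" -> "2", "a" -> "1"), cols)  // true (column order does not matter)
isExactSpec(Map("a" -> "1"), cols)              // false
isPartialSpec(Map("a" -> "1"), cols)            // true
}}}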
*/ - def createTempMacro( - name: String, - info: ExpressionInfo, - funcDefinition: FunctionBuilder): Unit = { - if (functionRegistry.functionExists(name)) { - throw new TempFunctionAlreadyExistsException(name) - } - functionRegistry.registerMacro(name, info, funcDefinition) - } - - /** Drop a temporary macro. */ - def dropTempMacro(name: String, ignoreIfNotExists: Boolean): Unit = { - if (!functionRegistry.dropMacro(name) && !ignoreIfNotExists) { - throw new NoSuchTempMacroException(name) - } - } - /** * Retrieve the metadata of a metastore function. * @@ -904,7 +1051,7 @@ class SessionCatalog( * * This performs reflection to decide what type of [[Expression]] to return in the builder. */ - def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { + protected def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { // TODO: at least support UDAFs here throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.") } @@ -918,18 +1065,20 @@ class SessionCatalog( } /** - * Create a temporary function. - * This assumes no database is specified in `funcDefinition`. + * Registers a temporary or permanent function into a session-specific [[FunctionRegistry]] */ - def createTempFunction( - name: String, - info: ExpressionInfo, - funcDefinition: FunctionBuilder, - ignoreIfExists: Boolean): Unit = { - if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) { - throw new TempFunctionAlreadyExistsException(name) + def registerFunction( + funcDefinition: CatalogFunction, + ignoreIfExists: Boolean, + functionBuilder: Option[FunctionBuilder] = None): Unit = { + val func = funcDefinition.identifier + if (functionRegistry.functionExists(func.unquotedString) && !ignoreIfExists) { + throw new AnalysisException(s"Function $func already exists") } - functionRegistry.registerFunction(name, info, funcDefinition) + val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName) + val builder = + functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, funcDefinition.className)) + functionRegistry.registerFunction(func.unquotedString, info, builder) } /** @@ -941,22 +1090,37 @@ class SessionCatalog( } } + /** Create a temporary macro. */ + def createTempMacro( + name: String, + info: ExpressionInfo, + funcDefinition: FunctionBuilder): Unit = { + if (functionRegistry.functionExists(name)) { + throw new TempMacroAlreadyExistsException(name) + } + functionRegistry.registerMacro(name, info, funcDefinition) + } + + /** Drop a temporary macro. */ + def dropTempMacro(name: String, ignoreIfNotExists: Boolean): Unit = { + if (!functionRegistry.dropMacro(name) && !ignoreIfNotExists) { + throw new NoSuchTempMacroException(name) + } + } + /** * Returns whether it is a temporary function. If not existed, returns false. 
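createTempMacro and dropTempMacro move below the new registerFunction, but their contract is unchanged: creating a name that already exists fails, and dropping a missing name only fails when ignoreIfNotExists is false. A toy sketch of that contract, with a mutable Map standing in for FunctionRegistry and sys.error for the specific exception types:

{{{
import scala.collection.mutable

// Toy sketch of the temporary-macro bookkeeping contract described above.
class ToyMacroRegistry {
  private val macros = mutable.Map.empty[String, Seq[Int] => Int]

  def createTempMacro(name: String, body: Seq[Int] => Int): Unit = {
    if (macros.contains(name)) sys.error(s"Temporary macro '$name' already exists")
    macros(name) = body
  }

  def dropTempMacro(name: String, ignoreIfNotExists: Boolean): Unit = {
    if (macros.remove(name).isEmpty && !ignoreIfNotExists) {
      sys.error(s"Temporary macro '$name' not found")
    }
  }
}

val registry = new ToyMacroRegistry
registry.createTempMacro("double_it", args => args.head * 2)
registry.dropTempMacro("double_it", ignoreIfNotExists = false)
registry.dropTempMacro("double_it", ignoreIfNotExists = true)   // no-op
}}}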
*/ def isTemporaryFunction(name: FunctionIdentifier): Boolean = { // copied from HiveSessionCatalog - val hiveFunctions = Seq( - "hash", - "histogram_numeric", - "percentile") + val hiveFunctions = Seq("histogram_numeric") // A temporary function is a function that has been registered in functionRegistry // without a database name, and is neither a built-in function nor a Hive function name.database.isEmpty && functionRegistry.functionExists(name.funcName) && !FunctionRegistry.builtin.functionExists(name.funcName) && - !hiveFunctions.contains(name.funcName.toLowerCase) + !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } protected def failFunctionLookup(name: String): Nothing = { @@ -1001,8 +1165,8 @@ class SessionCatalog( * The name of this function in the FunctionRegistry will be `databaseName.functionName`. */ def lookupFunction( - name: FunctionIdentifier, - children: Seq[Expression]): Expression = synchronized { + name: FunctionIdentifier, + children: Seq[Expression]): Expression = synchronized { // Note: the implementation of this function is a little bit convoluted. // We probably shouldn't use a single FunctionRegistry to register all three kinds of functions // (built-in, temp, and external). @@ -1037,12 +1201,7 @@ class SessionCatalog( // catalog. So, it is possible that qualifiedName is not exactly the same as // catalogFunction.identifier.unquotedString (difference is on case-sensitivity). // At here, we preserve the input from the user. - val info = new ExpressionInfo( - catalogFunction.className, - qualifiedName.database.orNull, - qualifiedName.funcName) - val builder = makeFunctionBuilder(qualifiedName.unquotedString, catalogFunction.className) - createTempFunction(qualifiedName.unquotedString, info, builder, ignoreIfExists = false) + registerFunction(catalogFunction.copy(identifier = qualifiedName), ignoreIfExists = false) // Now, we need to create the Expression. functionRegistry.lookupFunction(qualifiedName.unquotedString, children) } @@ -1062,15 +1221,25 @@ class SessionCatalog( def listFunctions(db: String, pattern: String): Seq[(FunctionIdentifier, String)] = { val dbName = formatDatabaseName(db) requireDbExists(dbName) - val dbFunctions = externalCatalog.listFunctions(dbName, pattern) - .map { f => FunctionIdentifier(f, Some(dbName)) } - val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern) - .map { f => FunctionIdentifier(f) } + val dbFunctions = externalCatalog.listFunctions(dbName, pattern).map { f => + FunctionIdentifier(f, Some(dbName)) } + val loadedFunctions = + StringUtils.filterPattern(functionRegistry.listFunction(), pattern).map { f => + // In functionRegistry, function names are stored as an unquoted format. + Try(parser.parseFunctionIdentifier(f)) match { + case Success(e) => e + case Failure(_) => + // The names of some built-in functions are not parsable by our parser, e.g., % + FunctionIdentifier(f) + } + } val functions = dbFunctions ++ loadedFunctions + // The session catalog caches some persistent functions in the FunctionRegistry + // so there can be duplicates. 
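The comment ending this hunk notes that persistent functions cached in the FunctionRegistry can show up twice; the map/distinct that follows tags each name as SYSTEM or USER and drops the duplicates. In miniature:

{{{
// Miniature of the listFunctions tail: tag built-ins as SYSTEM, everything
// else as USER, then de-duplicate entries present in both sources.
val builtinNames = Set("abs", "concat")                 // stand-in for FunctionRegistry.functionSet
val collected = Seq("abs", "my_udf", "abs", "my_udf")   // registry ++ external catalog

collected.map {
  case f if builtinNames.contains(f) => (f, "SYSTEM")
  case f                             => (f, "USER")
}.distinct
// Seq((abs,SYSTEM), (my_udf,USER))
}}}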
functions.map { case f if FunctionRegistry.functionSet.contains(f.funcName) => (f, "SYSTEM") case f => (f, "USER") - } + }.distinct } @@ -1086,6 +1255,7 @@ class SessionCatalog( */ def reset(): Unit = synchronized { setCurrentDatabase(DEFAULT_DATABASE) + externalCatalog.setCurrentDatabase(DEFAULT_DATABASE) listDatabases().filter(_ != DEFAULT_DATABASE).foreach { db => dropDatabase(db, ignoreIfNotExists = false, cascade = true) } @@ -1099,9 +1269,10 @@ class SessionCatalog( dropTempFunction(func.funcName, ignoreIfNotExists = false) } } - tempTables.clear() + clearTempTables() globalTempViewManager.clear() functionRegistry.clear() + tableRelationCache.invalidateAll() // restore built-in functions FunctionRegistry.builtin.listFunction().foreach { f => val expressionInfo = FunctionRegistry.builtin.lookupFunction(f) @@ -1112,4 +1283,17 @@ class SessionCatalog( } } -} + /** + * Copy the current state of the catalog to another catalog. + * + * This function is synchronized on this [[SessionCatalog]] (the source) to make sure the copied + * state is consistent. The target [[SessionCatalog]] is not synchronized, and should not be + * because the target [[SessionCatalog]] should not be published at this point. The caller must + * synchronize on the target if this assumption does not hold. + */ + private[sql] def copyStateTo(target: SessionCatalog): Unit = synchronized { + target.currentDb = currentDb + // copy over temporary tables + tempTables.foreach(kv => target.tempTables.put(kv._1, kv._2)) + } +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 9522954f5f9eb..18c5d4ac5b2b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -17,21 +17,24 @@ package org.apache.spark.sql.execution +import java.util.Locale + import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.tree.TerminalNode -import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, ScriptInputOutputSchema} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTable, _} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.StructType /** * Concrete parser for Spark SQL statements. @@ -49,7 +52,7 @@ class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser { /** * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. 
*/ -class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { +class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { import org.apache.spark.sql.catalyst.parser.ParserUtils._ /** @@ -82,7 +85,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitResetConfiguration( - ctx: ResetConfigurationContext): LogicalPlan = withOrigin(ctx) { + ctx: ResetConfigurationContext): LogicalPlan = withOrigin(ctx) { ResetCommand } @@ -102,7 +105,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { logWarning(s"Partition specification is ignored: ${ctx.partitionSpec.getText}") } if (ctx.identifier != null) { - if (ctx.identifier.getText.toLowerCase != "noscan") { + if (ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") { throw new ParseException(s"Expected `NOSCAN` instead of `${ctx.identifier.getText}`", ctx) } AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier)) @@ -132,7 +135,26 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) { ShowTablesCommand( Option(ctx.db).map(_.getText), - Option(ctx.pattern).map(string)) + Option(ctx.pattern).map(string), + isExtended = false, + partitionSpec = None) + } + + /** + * Create a [[ShowTablesCommand]] logical plan. + * Example SQL : + * {{{ + * SHOW TABLE EXTENDED [(IN|FROM) database_name] LIKE 'identifier_with_wildcards' + * [PARTITION(partition_spec)]; + * }}} + */ + override def visitShowTable(ctx: ShowTableContext): LogicalPlan = withOrigin(ctx) { + val partitionSpec = Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec) + ShowTablesCommand( + Option(ctx.db).map(_.getText), + Option(ctx.pattern).map(string), + isExtended = true, + partitionSpec = partitionSpec) } /** @@ -156,7 +178,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitShowTblProperties( - ctx: ShowTblPropertiesContext): LogicalPlan = withOrigin(ctx) { + ctx: ShowTblPropertiesContext): LogicalPlan = withOrigin(ctx) { ShowTablePropertiesCommand( visitTableIdentifier(ctx.tableIdentifier), Option(ctx.key).map(visitTablePropertyKey)) @@ -233,7 +255,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Create an [[UncacheTableCommand]] logical plan. */ override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) { - UncacheTableCommand(visitTableIdentifier(ctx.tableIdentifier)) + UncacheTableCommand(visitTableIdentifier(ctx.tableIdentifier), ctx.EXISTS != null) } /** @@ -262,7 +284,11 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (statement == null) { null // This is enough since ParseException will raise later. } else if (isExplainableStatement(statement)) { - ExplainCommand(statement, extended = ctx.EXTENDED != null, codegen = ctx.CODEGEN != null) + ExplainCommand( + logicalPlan = statement, + extended = ctx.EXTENDED != null, + codegen = ctx.CODEGEN != null, + cost = ctx.COST != null) } else { ExplainCommand(OneRowRelation) } @@ -298,8 +324,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { DescribeTableCommand( visitTableIdentifier(ctx.tableIdentifier), partitionSpec, - ctx.EXTENDED != null, - ctx.FORMATTED != null) + ctx.EXTENDED != null || ctx.FORMATTED != null) } } @@ -312,7 +337,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Validate a create table statement and return the [[TableIdentifier]]. 
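This hunk, and several below it, replaces toLowerCase with toLowerCase(Locale.ROOT) for keyword comparisons such as NOSCAN and the ADD FILE/JAR resource types. The difference matters under locales with special casing rules; a minimal illustration:

{{{
import java.util.Locale

// Under the Turkish locale, 'I' lower-cases to a dotless 'ı' (U+0131), so a
// keyword containing 'I' would no longer compare equal to its ASCII form.
"FILE".toLowerCase(new Locale("tr"))  // "fıle", not equal to "file"
"FILE".toLowerCase(Locale.ROOT)       // "file", locale-independent
}}}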
*/ override def visitCreateTableHeader( - ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { + ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { val temporary = ctx.TEMPORARY != null val ifNotExists = ctx.EXISTS != null if (temporary && ifNotExists) { @@ -322,18 +347,30 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create a [[CreateTable]] logical plan. + * Create a table, returning a [[CreateTable]] logical plan. + * + * Expected format: + * {{{ + * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name + * USING table_provider + * [OPTIONS table_property_list] + * [PARTITIONED BY (col_name, col_name, ...)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [AS select_statement]; + * }}} */ - override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { + override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) if (external) { operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx) } - val options = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText - if (provider.toLowerCase == DDLUtils.HIVE_PROVIDER) { - throw new AnalysisException("Cannot create hive serde table with CREATE TABLE USING") - } val schema = Option(ctx.colTypeList()).map(createSchema) val partitionColumnNames = Option(ctx.partitionColumnNames) @@ -341,10 +378,17 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { .getOrElse(Array.empty[String]) val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) - // TODO: this may be wrong for non file-based data source like JDBC, which should be external - // even there is no `path` in options. We should consider allow the EXTERNAL keyword. + val location = Option(ctx.locationSpec).map(visitLocationSpec) val storage = DataSource.buildStorageFormatFromOptions(options) - val tableType = if (storage.locationUri.isDefined) { + + if (location.isDefined && storage.locationUri.isDefined) { + throw new ParseException( + "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " + + "you can only specify one of them.", ctx) + } + val customLocation = storage.locationUri.orElse(location.map(CatalogUtils.stringToURI(_))) + + val tableType = if (customLocation.isDefined) { CatalogTableType.EXTERNAL } else { CatalogTableType.MANAGED @@ -353,12 +397,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { val tableDesc = CatalogTable( identifier = table, tableType = tableType, - storage = storage, + storage = storage.copy(locationUri = customLocation), schema = schema.getOrElse(new StructType), provider = Some(provider), partitionColumnNames = partitionColumnNames, - bucketSpec = bucketSpec - ) + bucketSpec = bucketSpec, + comment = Option(ctx.comment).map(string)) // Determine the storage mode. val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists @@ -371,6 +415,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { operationNotAllowed("CREATE TEMPORARY TABLE ... USING ... 
AS query", ctx) } + // Don't allow explicit specification of schema for CTAS + if (schema.nonEmpty) { + operationNotAllowed( + "Schema may not be specified in a Create Table As Select (CTAS) statement", + ctx) + } CreateTable(tableDesc, mode, Some(query)) } else { if (temp) { @@ -380,7 +430,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " + "CREATE TEMPORARY VIEW ... USING ... instead") - CreateTempViewUsing(table, schema, replace = true, global = false, provider, options) + // Unlike CREATE TEMPORARY VIEW USING, CREATE TEMPORARY TABLE USING does not support + // IF NOT EXISTS. Users are not allowed to replace the existing temp table. + CreateTempViewUsing(table, schema, replace = false, global = false, provider, options) } else { CreateTable(tableDesc, mode, None) } @@ -391,7 +443,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Creates a [[CreateTempViewUsing]] logical plan. */ override def visitCreateTempViewUsing( - ctx: CreateTempViewUsingContext): LogicalPlan = withOrigin(ctx) { + ctx: CreateTempViewUsingContext): LogicalPlan = withOrigin(ctx) { CreateTempViewUsing( tableIdent = visitTableIdentifier(ctx.tableIdentifier()), userSpecifiedSchema = Option(ctx.colTypeList()).map(createSchema), @@ -453,7 +505,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. */ override def visitTablePropertyList( - ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { val properties = ctx.tableProperty.asScala.map { property => val key = visitTablePropertyKey(property.key) val value = visitTablePropertyValue(property.value) @@ -513,7 +565,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } else if (value.STRING != null) { string(value.STRING) } else if (value.booleanValue != null) { - value.getText.toLowerCase + value.getText.toLowerCase(Locale.ROOT) } else { value.getText } @@ -546,7 +598,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitSetDatabaseProperties( - ctx: SetDatabasePropertiesContext): LogicalPlan = withOrigin(ctx) { + ctx: SetDatabasePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterDatabasePropertiesCommand( ctx.identifier.getText, visitPropertyKeyValues(ctx.tablePropertyList)) @@ -597,7 +649,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { import ctx._ - val (user, system) = Option(ctx.identifier).map(_.getText.toLowerCase) match { + val (user, system) = Option(ctx.identifier).map(_.getText.toLowerCase(Locale.ROOT)) match { case None | Some("all") => (true, true) case Some("system") => (false, true) case Some("user") => (true, false) @@ -627,7 +679,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitCreateFunction(ctx: CreateFunctionContext): LogicalPlan = withOrigin(ctx) { val resources = ctx.resource.asScala.map { resource => - val resourceType = resource.identifier.getText.toLowerCase + val resourceType = resource.identifier.getText.toLowerCase(Locale.ROOT) resourceType match { case "jar" | "file" | "archive" => FunctionResource(FunctionResourceType.fromString(resourceType), string(resource.STRING)) @@ -721,6 +773,22 @@ class SparkSqlAstBuilder(conf: SQLConf) extends 
AstBuilder { ctx.VIEW != null) } + /** + * Create a [[AlterTableAddColumnsCommand]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 + * ADD COLUMNS (col_name data_type [COMMENT col_comment], ...); + * }}} + */ + override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = withOrigin(ctx) { + AlterTableAddColumnsCommand( + visitTableIdentifier(ctx.tableIdentifier), + visitColTypeList(ctx.columns) + ) + } + /** * Create an [[AlterTableSetPropertiesCommand]] command. * @@ -731,7 +799,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitSetTableProperties( - ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { + ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterTableSetPropertiesCommand( visitTableIdentifier(ctx.tableIdentifier), visitPropertyKeyValues(ctx.tablePropertyList), @@ -748,7 +816,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitUnsetTableProperties( - ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { + ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterTableUnsetPropertiesCommand( visitTableIdentifier(ctx.tableIdentifier), visitPropertyKeys(ctx.tablePropertyList), @@ -787,7 +855,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * is associated with physical tables */ override def visitAddTablePartition( - ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { + ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { if (ctx.VIEW != null) { operationNotAllowed("ALTER VIEW ... ADD PARTITION", ctx) } @@ -818,7 +886,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitRenameTablePartition( - ctx: RenameTablePartitionContext): LogicalPlan = withOrigin(ctx) { + ctx: RenameTablePartitionContext): LogicalPlan = withOrigin(ctx) { AlterTableRenamePartitionCommand( visitTableIdentifier(ctx.tableIdentifier), visitNonOptionalPartitionSpec(ctx.from), @@ -838,15 +906,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * is associated with physical tables */ override def visitDropTablePartitions( - ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) { + ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) { if (ctx.VIEW != null) { operationNotAllowed("ALTER VIEW ... DROP PARTITION", ctx) } AlterTableDropPartitionCommand( visitTableIdentifier(ctx.tableIdentifier), ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec), - ctx.EXISTS != null, - ctx.PURGE != null) + ifExists = ctx.EXISTS != null, + purge = ctx.PURGE != null, + retainData = false) } /** @@ -858,7 +927,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitRecoverPartitions( - ctx: RecoverPartitionsContext): LogicalPlan = withOrigin(ctx) { + ctx: RecoverPartitionsContext): LogicalPlan = withOrigin(ctx) { AlterTableRecoverPartitionsCommand(visitTableIdentifier(ctx.tableIdentifier)) } @@ -877,6 +946,33 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { visitLocationSpec(ctx.locationSpec)) } + /** + * Create a [[AlterTableChangeColumnCommand]] command. 
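visitAddTableColumns wires the new ALTER TABLE ... ADD COLUMNS syntax into AlterTableAddColumnsCommand, and visitChangeColumn (continued below) handles CHANGE COLUMN while rejecting PARTITION specs and FIRST/AFTER positions. A hypothetical end-to-end usage, assuming an active SparkSession named spark; the table and column names are made up:

{{{
// Hypothetical usage of the two column DDL paths handled above.
spark.sql("CREATE TABLE t (id INT) USING parquet")
spark.sql("ALTER TABLE t ADD COLUMNS (name STRING COMMENT 'display name')")
// CHANGE COLUMN keeps the name and type here and only updates the comment;
// PARTITION specs and FIRST/AFTER positions are rejected by visitChangeColumn.
spark.sql("ALTER TABLE t CHANGE COLUMN id id INT COMMENT 'primary key'")
}}}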
+ * + * For example: + * {{{ + * ALTER TABLE table [PARTITION partition_spec] + * CHANGE [COLUMN] column_old_name column_new_name column_dataType [COMMENT column_comment] + * [FIRST | AFTER column_name]; + * }}} + */ + override def visitChangeColumn(ctx: ChangeColumnContext): LogicalPlan = withOrigin(ctx) { + if (ctx.partitionSpec != null) { + operationNotAllowed("ALTER TABLE table PARTITION partition_spec CHANGE COLUMN", ctx) + } + + if (ctx.colPosition != null) { + operationNotAllowed( + "ALTER TABLE table [PARTITION partition_spec] CHANGE COLUMN ... FIRST | AFTER otherCol", + ctx) + } + + AlterTableChangeColumnCommand( + tableName = visitTableIdentifier(ctx.tableIdentifier), + columnName = ctx.identifier.getText, + newColumn = visitColType(ctx.colType)) + } + /** * Create location string. */ @@ -896,7 +992,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { .flatMap(_.orderedIdentifier.asScala) .map { orderedIdCtx => Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => - if (dir.toLowerCase != "asc") { + if (dir.toLowerCase(Locale.ROOT) != "asc") { operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) } } @@ -909,7 +1005,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Convert a nested constants list into a sequence of string sequences. */ override def visitNestedConstantList( - ctx: NestedConstantListContext): Seq[Seq[String]] = withOrigin(ctx) { + ctx: NestedConstantListContext): Seq[Seq[String]] = withOrigin(ctx) { ctx.constantList.asScala.map(visitConstantList) } @@ -949,13 +1045,13 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { val mayebePaths = remainder(ctx.identifier).trim ctx.op.getType match { case SqlBaseParser.ADD => - ctx.identifier.getText.toLowerCase match { + ctx.identifier.getText.toLowerCase(Locale.ROOT) match { case "file" => AddFileCommand(mayebePaths) case "jar" => AddJarCommand(mayebePaths) case other => operationNotAllowed(s"ADD with resource type '$other'", ctx) } case SqlBaseParser.LIST => - ctx.identifier.getText.toLowerCase match { + ctx.identifier.getText.toLowerCase(Locale.ROOT) match { case "files" | "file" => if (mayebePaths.length > 0) { ListFilesCommand(mayebePaths.split("\\s+")) @@ -975,10 +1071,10 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create a table, returning a [[CreateTable]] logical plan. + * Create a Hive serde table, returning a [[CreateTable]] logical plan. * - * This is not used to create datasource tables, which is handled through - * "CREATE TABLE ... USING ...". + * This is a legacy syntax for Hive compatibility, we recommend users to use the Spark SQL + * CREATE TABLE syntax to create Hive serde table, e.g. "CREATE TABLE ... USING hive ..." * * Note: several features are currently not supported - temporary tables, bucketing, * skewed columns and storage handlers (STORED BY). @@ -996,7 +1092,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * [AS select_statement]; * }}} */ - override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { + override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) { val (name, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) // TODO: implement temporary tables if (temp) { @@ -1007,33 +1103,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (ctx.skewSpec != null) { operationNotAllowed("CREATE TABLE ... 
SKEWED BY", ctx) } - if (ctx.bucketSpec != null) { - operationNotAllowed("CREATE TABLE ... CLUSTERED BY", ctx) - } - val comment = Option(ctx.STRING).map(string) + val dataCols = Option(ctx.columns).map(visitColTypeList).getOrElse(Nil) val partitionCols = Option(ctx.partitionColumns).map(visitColTypeList).getOrElse(Nil) val properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) val selectQuery = Option(ctx.query).map(plan) + val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) // Note: Hive requires partition columns to be distinct from the schema, so we need // to include the partition columns here explicitly val schema = StructType(dataCols ++ partitionCols) // Storage format - val defaultStorage: CatalogStorageFormat = { - val defaultStorageType = conf.getConfString("hive.default.fileformat", "textfile") - val defaultHiveSerde = HiveSerDe.sourceToSerDe(defaultStorageType) - CatalogStorageFormat( - locationUri = None, - inputFormat = defaultHiveSerde.flatMap(_.inputFormat) - .orElse(Some("org.apache.hadoop.mapred.TextInputFormat")), - outputFormat = defaultHiveSerde.flatMap(_.outputFormat) - .orElse(Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), - serde = defaultHiveSerde.flatMap(_.serde), - compressed = false, - properties = Map()) - } + val defaultStorage = HiveSerDe.getDefaultStorage(conf) validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx) val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) .getOrElse(CatalogStorageFormat.empty) @@ -1044,8 +1126,10 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (external && location.isEmpty) { operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) } + + val locUri = location.map(CatalogUtils.stringToURI(_)) val storage = CatalogStorageFormat( - locationUri = location, + locationUri = locUri, inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat), outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat), serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde), @@ -1065,10 +1149,11 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { tableType = tableType, storage = storage, schema = schema, + bucketSpec = bucketSpec, provider = Some(DDLUtils.HIVE_PROVIDER), partitionColumnNames = partitionCols.map(_.name), properties = properties, - comment = comment) + comment = Option(ctx.comment).map(string)) val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists @@ -1083,7 +1168,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { "CTAS statement." operationNotAllowed(errorMessage, ctx) } - // Just use whatever is projected in the select statement as our schema + + // Don't allow explicit specification of schema for CTAS. if (schema.nonEmpty) { operationNotAllowed( "Schema may not be specified in a Create Table As Select (CTAS) statement", @@ -1095,7 +1181,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. 
val newTableDesc = tableDesc.copy( - storage = CatalogStorageFormat.empty.copy(locationUri = location), + storage = CatalogStorageFormat.empty.copy(locationUri = locUri), provider = Some(conf.defaultDataSourceName)) CreateTable(newTableDesc, mode, Some(q)) } else { @@ -1111,13 +1197,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * For example: * {{{ * CREATE TABLE [IF NOT EXISTS] [db_name.]table_name - * LIKE [other_db_name.]existing_table_name + * LIKE [other_db_name.]existing_table_name [locationSpec] * }}} */ override def visitCreateTableLike(ctx: CreateTableLikeContext): LogicalPlan = withOrigin(ctx) { val targetTable = visitTableIdentifier(ctx.target) val sourceTable = visitTableIdentifier(ctx.source) - CreateTableLikeCommand(targetTable, sourceTable, ctx.EXISTS != null) + val location = Option(ctx.locationSpec).map(visitLocationSpec) + CreateTableLikeCommand(targetTable, sourceTable, location, ctx.EXISTS != null) } /** @@ -1126,7 +1213,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Format: STORED AS ... */ override def visitCreateFileFormat( - ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { (ctx.fileFormat, ctx.storageHandler) match { // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format case (c: TableFileFormatContext, null) => @@ -1145,7 +1232,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Create a [[CatalogStorageFormat]]. */ override def visitTableFileFormat( - ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { CatalogStorageFormat.empty.copy( inputFormat = Option(string(ctx.inFmt)), outputFormat = Option(string(ctx.outFmt))) @@ -1155,7 +1242,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Resolve a [[HiveSerDe]] based on the name given and return it as a [[CatalogStorageFormat]]. */ override def visitGenericFileFormat( - ctx: GenericFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + ctx: GenericFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { val source = ctx.identifier.getText HiveSerDe.sourceToSerDe(source) match { case Some(s) => @@ -1197,7 +1284,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Create SERDE row format name and properties pair. */ override def visitRowFormatSerde( - ctx: RowFormatSerdeContext): CatalogStorageFormat = withOrigin(ctx) { + ctx: RowFormatSerdeContext): CatalogStorageFormat = withOrigin(ctx) { import ctx._ CatalogStorageFormat.empty.copy( serde = Option(string(name)), @@ -1208,7 +1295,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Create a delimited row format properties object. */ override def visitRowFormatDelimited( - ctx: RowFormatDelimitedContext): CatalogStorageFormat = withOrigin(ctx) { + ctx: RowFormatDelimitedContext): CatalogStorageFormat = withOrigin(ctx) { // Collect the entries if any. def entry(key: String, value: Token): Seq[(String, String)] = { Option(value).toSeq.map(x => key -> string(x)) @@ -1242,16 +1329,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * ROW FORMAT ... STORED AS INPUTFORMAT ... OUTPUTFORMAT ... 
*/ private def validateRowFormatFileFormat( - rowFormatCtx: RowFormatContext, - createFileFormatCtx: CreateFileFormatContext, - parentCtx: ParserRuleContext): Unit = { + rowFormatCtx: RowFormatContext, + createFileFormatCtx: CreateFileFormatContext, + parentCtx: ParserRuleContext): Unit = { if (rowFormatCtx == null || createFileFormatCtx == null) { return } (rowFormatCtx, createFileFormatCtx.fileFormat) match { case (_, ffTable: TableFileFormatContext) => // OK case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) => - ffGeneric.identifier.getText.toLowerCase match { + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { case ("sequencefile" | "textfile" | "rcfile") => // OK case fmt => operationNotAllowed( @@ -1259,7 +1346,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { parentCtx) } case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) => - ffGeneric.identifier.getText.toLowerCase match { + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { case "textfile" => // OK case fmt => operationNotAllowed( s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx) @@ -1291,6 +1378,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (ctx.identifierList != null) { operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx) } else { + // CREATE VIEW ... AS INSERT INTO is not allowed. + ctx.query.queryNoWith match { + case s: SingleInsertQueryContext if s.insertInto != null => + operationNotAllowed("CREATE VIEW ... AS INSERT INTO", ctx) + case _: MultiInsertQueryContext => + operationNotAllowed("CREATE VIEW ... AS FROM ... [INSERT INTO ...]+", ctx) + case _ => // OK + } + val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl => icl.identifierComment.asScala.map { ic => ic.identifier.getText -> Option(ic.STRING).map(string) @@ -1337,12 +1433,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Create a [[ScriptInputOutputSchema]]. */ override protected def withScriptIOSchema( - ctx: QuerySpecificationContext, - inRowFormat: RowFormatContext, - recordWriter: Token, - outRowFormat: RowFormatContext, - recordReader: Token, - schemaLess: Boolean): ScriptInputOutputSchema = { + ctx: QuerySpecificationContext, + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = { if (recordWriter != null || recordReader != null) { // TODO: what does this message mean? throw new ParseException( @@ -1352,9 +1448,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { // Decode and input/output format. type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) def format( - fmt: RowFormatContext, - configKey: String, - defaultConfigValue: String): Format = fmt match { + fmt: RowFormatContext, + configKey: String, + defaultConfigValue: String): Format = fmt match { case c: RowFormatDelimitedContext => // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema // expects a seq of pairs in which the old parsers' token names are used as keys. @@ -1408,4 +1504,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { reader, writer, schemaLess) } -} + + /** + * Create a clause for DISTRIBUTE BY. 
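validateRowFormatFileFormat encodes a small compatibility matrix between ROW FORMAT and STORED AS clauses. Restated as a standalone predicate (a sketch, not the parser code; the strings are informal labels for the grammar contexts):

{{{
// Compatibility rules enforced by validateRowFormatFileFormat above.
def compatible(rowFormat: String, storedAs: String): Boolean =
  (rowFormat, storedAs) match {
    // Explicit INPUTFORMAT/OUTPUTFORMAT is always accepted.
    case (_, "inputformat/outputformat")                    => true
    case ("serde", "sequencefile" | "textfile" | "rcfile")  => true
    case ("delimited", "textfile")                          => true
    case _                                                  => false
  }

compatible("delimited", "textfile")  // true
compatible("delimited", "orc")       // false -> operationNotAllowed in the parser
}}}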
+ */ + override protected def withRepartitionByExpression( + ctx: QueryOrganizationContext, + expressions: Seq[Expression], + query: LogicalPlan): LogicalPlan = { + RepartitionByExpression(expressions, query, conf.numShufflePartitions) + } +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4b4d6ecbde2f7..56f3f8ee12ee6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -18,36 +18,129 @@ package org.apache.spark.sql.execution.command import java.io.File +import java.net.URI +import java.util.Locale import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogDatabase, CatalogStorageFormat} -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, SessionCatalog} +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { - private val escapedIdentifier = "`(.+)`".r +class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with BeforeAndAfterEach { override def afterEach(): Unit = { try { // drop all databases, tables and functions after each test spark.sessionState.catalog.reset() } finally { - Utils.deleteRecursively(new File("spark-warehouse")) + Utils.deleteRecursively(new File(spark.sessionState.conf.warehousePath)) super.afterEach() } } + protected override def generateTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable = { + val storage = + CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name))) + val metadata = new MetadataBuilder() + .putString("key", "value") + .build() + CatalogTable( + identifier = name, + tableType = CatalogTableType.EXTERNAL, + storage = storage, + schema = new StructType() + .add("col1", "int", nullable = true, metadata = metadata) + .add("col2", "string") + .add("a", "int") + .add("b", "int"), + provider = Some("parquet"), + partitionColumnNames = Seq("a", "b"), + createTime = 0L, + tracksPartitionsInCatalog = true) + } + + test("create a managed Hive source table") { + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") + val tabName = "tbl" + withTable(tabName) { + val e = intercept[AnalysisException] { + sql(s"CREATE TABLE $tabName (i INT, j STRING)") + }.getMessage + assert(e.contains("Hive support is 
required to CREATE Hive TABLE")) + } + } + + test("create an external Hive source table") { + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") + withTempDir { tempDir => + val tabName = "tbl" + withTable(tabName) { + val e = intercept[AnalysisException] { + sql( + s""" + |CREATE EXTERNAL TABLE $tabName (i INT, j STRING) + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |LOCATION '${tempDir.toURI}' + """.stripMargin) + }.getMessage + assert(e.contains("Hive support is required to CREATE Hive TABLE")) + } + } + } + + test("Create Hive Table As Select") { + import testImplicits._ + withTable("t", "t1") { + var e = intercept[AnalysisException] { + sql("CREATE TABLE t SELECT 1 as a, 1 as b") + }.getMessage + assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)")) + + spark.range(1).select('id as 'a, 'id as 'b).write.saveAsTable("t1") + e = intercept[AnalysisException] { + sql("CREATE TABLE t SELECT a, b from t1") + }.getMessage + assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)")) + } + } + +} + +abstract class DDLSuite extends QueryTest with SQLTestUtils { + + protected def isUsingHiveMetastore: Boolean = { + spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive" + } + + protected def generateTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable + + private val escapedIdentifier = "`(.+)`".r + + protected def normalizeCatalogTable(table: CatalogTable): CatalogTable = table + + private def normalizeSerdeProp(props: Map[String, String]): Map[String, String] = { + props.filterNot(p => Seq("serialization.format", "path").contains(p._1)) + } + + private def checkCatalogTables(expected: CatalogTable, actual: CatalogTable): Unit = { + assert(normalizeCatalogTable(actual) == normalizeCatalogTable(expected)) + } + /** * Strip backticks, if any, from the string. 
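The escapedIdentifier regex declared earlier in this hunk backs the cleanIdentifier helper whose scaladoc ends it; the helper's body is unchanged context and therefore not shown in the diff. A plausible standalone version, for reference only:

{{{
// A plausible reconstruction of the backtick-stripping helper documented above.
val escapedIdentifier = "`(.+)`".r

def cleanIdentifier(ident: String): String = ident match {
  case escapedIdentifier(inner) => inner
  case plain                    => plain
}

cleanIdentifier("`database`")  // "database"
cleanIdentifier("db1")         // "db1"
}}}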
*/ @@ -62,7 +155,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val e = intercept[AnalysisException] { sql(query) } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) } private def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = { @@ -71,47 +164,72 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { private def createDatabase(catalog: SessionCatalog, name: String): Unit = { catalog.createDatabase( - CatalogDatabase(name, "", spark.sessionState.conf.warehousePath, Map()), + CatalogDatabase( + name, "", CatalogUtils.stringToURI(spark.sessionState.conf.warehousePath), Map()), ignoreIfExists = false) } - private def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable = { - val storage = - CatalogStorageFormat( - locationUri = Some(catalog.defaultTablePath(name)), - inputFormat = None, - outputFormat = None, - serde = None, - compressed = false, - properties = Map()) - CatalogTable( - identifier = name, - tableType = CatalogTableType.EXTERNAL, - storage = storage, - schema = new StructType() - .add("col1", "int") - .add("col2", "string") - .add("a", "int") - .add("b", "int"), - provider = Some("hive"), - partitionColumnNames = Seq("a", "b"), - createTime = 0L, - tracksPartitionsInCatalog = true) - } - - private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { - catalog.createTable(generateTable(catalog, name), ignoreIfExists = false) + private def createTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): Unit = { + catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) } private def createTablePartition( - catalog: SessionCatalog, - spec: TablePartitionSpec, - tableName: TableIdentifier): Unit = { + catalog: SessionCatalog, + spec: TablePartitionSpec, + tableName: TableIdentifier): Unit = { val part = CatalogTablePartition( spec, CatalogStorageFormat(None, None, None, None, false, Map())) catalog.createPartitions(tableName, Seq(part), ignoreIfExists = false) } + private def getDBPath(dbName: String): URI = { + val warehousePath = makeQualifiedPath(spark.sessionState.conf.warehousePath) + new Path(CatalogUtils.URIToString(warehousePath), s"$dbName.db").toUri + } + + test("alter table: set location (datasource table)") { + testSetLocation(isDatasourceTable = true) + } + + test("alter table: set properties (datasource table)") { + testSetProperties(isDatasourceTable = true) + } + + test("alter table: unset properties (datasource table)") { + testUnsetProperties(isDatasourceTable = true) + } + + test("alter table: set serde (datasource table)") { + testSetSerde(isDatasourceTable = true) + } + + test("alter table: set serde partition (datasource table)") { + testSetSerdePartition(isDatasourceTable = true) + } + + test("alter table: change column (datasource table)") { + testChangeColumn(isDatasourceTable = true) + } + + test("alter table: add partition (datasource table)") { + testAddPartitions(isDatasourceTable = true) + } + + test("alter table: drop partition (datasource table)") { + testDropPartitions(isDatasourceTable = true) + } + + test("alter table: rename partition (datasource table)") { + testRenamePartitions(isDatasourceTable = true) + } + + test("drop table - data source table") { + testDropTable(isDatasourceTable = true) + } + test("the qualified path of a database is stored in 
the catalog") { val catalog = spark.sessionState.catalog @@ -125,76 +243,25 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert("file" === pathInCatalog.getScheme) val expectedPath = new Path(path).toUri assert(expectedPath.getPath === pathInCatalog.getPath) - - withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) { - sql(s"CREATE DATABASE db2") - val pathInCatalog2 = new Path(catalog.getDatabaseMetadata("db2").locationUri).toUri - assert("file" === pathInCatalog2.getScheme) - val expectedPath2 = new Path(spark.sessionState.conf.warehousePath + "/" + "db2.db").toUri - assert(expectedPath2.getPath === pathInCatalog2.getPath) - } - sql("DROP DATABASE db1") - sql("DROP DATABASE db2") - } - } - - private def makeQualifiedPath(path: String): String = { - // copy-paste from SessionCatalog - val hadoopPath = new Path(path) - val fs = hadoopPath.getFileSystem(sparkContext.hadoopConfiguration) - fs.makeQualified(hadoopPath).toString - } - - test("Create/Drop Database") { - withTempDir { tmpDir => - val path = tmpDir.getCanonicalPath - withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) { - val catalog = spark.sessionState.catalog - val databaseNames = Seq("db1", "`database`") - - databaseNames.foreach { dbName => - try { - val dbNameWithoutBackTicks = cleanIdentifier(dbName) - - sql(s"CREATE DATABASE $dbName") - val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) - val expectedLocation = makeQualifiedPath(s"$path/$dbNameWithoutBackTicks.db") - assert(db1 == CatalogDatabase( - dbNameWithoutBackTicks, - "", - expectedLocation, - Map.empty)) - sql(s"DROP DATABASE $dbName CASCADE") - assert(!catalog.databaseExists(dbNameWithoutBackTicks)) - } finally { - catalog.reset() - } - } - } } } test("Create Database using Default Warehouse Path") { - withSQLConf(SQLConf.WAREHOUSE_PATH.key -> "") { - // Will use the default location if and only if we unset the conf - spark.conf.unset(SQLConf.WAREHOUSE_PATH.key) - val catalog = spark.sessionState.catalog - val dbName = "db1" - try { - sql(s"CREATE DATABASE $dbName") - val db1 = catalog.getDatabaseMetadata(dbName) - val expectedLocation = makeQualifiedPath(s"spark-warehouse/$dbName.db") - assert(db1 == CatalogDatabase( - dbName, - "", - expectedLocation, - Map.empty)) - sql(s"DROP DATABASE $dbName CASCADE") - assert(!catalog.databaseExists(dbName)) - } finally { - catalog.reset() - } + val catalog = spark.sessionState.catalog + val dbName = "db1" + try { + sql(s"CREATE DATABASE $dbName") + val db1 = catalog.getDatabaseMetadata(dbName) + assert(db1 == CatalogDatabase( + dbName, + "", + getDBPath(dbName), + Map.empty)) + sql(s"DROP DATABASE $dbName CASCADE") + assert(!catalog.databaseExists(dbName)) + } finally { + catalog.reset() } } @@ -224,41 +291,37 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("Create Database - database already exists") { - withTempDir { tmpDir => - val path = tmpDir.getCanonicalPath - withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) { - val catalog = spark.sessionState.catalog - val databaseNames = Seq("db1", "`database`") + val catalog = spark.sessionState.catalog + val databaseNames = Seq("db1", "`database`") - databaseNames.foreach { dbName => - try { - val dbNameWithoutBackTicks = cleanIdentifier(dbName) - sql(s"CREATE DATABASE $dbName") - val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) - val expectedLocation = makeQualifiedPath(s"$path/$dbNameWithoutBackTicks.db") - assert(db1 == CatalogDatabase( - dbNameWithoutBackTicks, - "", - 
expectedLocation, - Map.empty)) - - intercept[DatabaseAlreadyExistsException] { - sql(s"CREATE DATABASE $dbName") - } - } finally { - catalog.reset() - } - } + databaseNames.foreach { dbName => + try { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + sql(s"CREATE DATABASE $dbName") + val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) + assert(db1 == CatalogDatabase( + dbNameWithoutBackTicks, + "", + getDBPath(dbNameWithoutBackTicks), + Map.empty)) + + // TODO: HiveExternalCatalog should throw DatabaseAlreadyExistsException + val e = intercept[AnalysisException] { + sql(s"CREATE DATABASE $dbName") + }.getMessage + assert(e.contains(s"already exists")) + } finally { + catalog.reset() } } } private def checkSchemaInCreatedDataSourceTable( - path: File, - userSpecifiedSchema: Option[String], - userSpecifiedPartitionCols: Option[String], - expectedSchema: StructType, - expectedPartitionCols: Seq[String]): Unit = { + path: File, + userSpecifiedSchema: Option[String], + userSpecifiedPartitionCols: Option[String], + expectedSchema: StructType, + expectedPartitionCols: Seq[String]): Unit = { val tabName = "tab1" withTable(tabName) { val partitionClause = @@ -322,7 +385,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { pathToPartitionedTable, userSpecifiedSchema = Option("num int, str string"), userSpecifiedPartitionCols = partitionCols, - expectedSchema = new StructType().add("num", IntegerType).add("str", StringType), + expectedSchema = new StructType().add("str", StringType).add("num", IntegerType), expectedPartitionCols = partitionCols.map(Seq(_)).getOrElse(Seq.empty[String])) } } @@ -360,7 +423,13 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { pathToNonPartitionedTable, userSpecifiedSchema = Option("num int, str string"), userSpecifiedPartitionCols = partitionCols, - expectedSchema = new StructType().add("num", IntegerType).add("str", StringType), + expectedSchema = if (partitionCols.isDefined) { + // we skipped inference, so the partition col is ordered at the end + new StructType().add("str", StringType).add("num", IntegerType) + } else { + // no inferred partitioning, so schema is in original order + new StructType().add("num", IntegerType).add("str", StringType) + }, expectedPartitionCols = partitionCols.map(Seq(_)).getOrElse(Seq.empty[String])) } } @@ -384,7 +453,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val e = intercept[AnalysisException] { sql("CREATE TABLE tbl(a int, b string) USING json PARTITIONED BY (c)") } - assert(e.message == "partition column c is not defined in table `tbl`, " + + assert(e.message == "partition column c is not defined in table tbl, " + "defined table columns are: a, b") } @@ -392,7 +461,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val e = intercept[AnalysisException] { sql("CREATE TABLE tbl(a int, b string) USING json CLUSTERED BY (c) INTO 4 BUCKETS") } - assert(e.message == "bucket column c is not defined in table `tbl`, " + + assert(e.message == "bucket column c is not defined in table tbl, " + "defined table columns are: a, b") } @@ -459,61 +528,43 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - test("desc table for parquet data source table using in-memory catalog") { - assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") - val tabName = "tab1" - withTable(tabName) { - sql(s"CREATE TABLE $tabName(a int comment 'test') 
USING parquet ") - - checkAnswer( - sql(s"DESC $tabName").select("col_name", "data_type", "comment"), - Row("a", "int", "test") - ) - } - } - test("Alter/Describe Database") { - withTempDir { tmpDir => - val path = tmpDir.getCanonicalPath - withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) { - val catalog = spark.sessionState.catalog - val databaseNames = Seq("db1", "`database`") + val catalog = spark.sessionState.catalog + val databaseNames = Seq("db1", "`database`") - databaseNames.foreach { dbName => - try { - val dbNameWithoutBackTicks = cleanIdentifier(dbName) - val location = makeQualifiedPath(s"$path/$dbNameWithoutBackTicks.db") + databaseNames.foreach { dbName => + try { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + val location = getDBPath(dbNameWithoutBackTicks) - sql(s"CREATE DATABASE $dbName") + sql(s"CREATE DATABASE $dbName") - checkAnswer( - sql(s"DESCRIBE DATABASE EXTENDED $dbName"), - Row("Database Name", dbNameWithoutBackTicks) :: - Row("Description", "") :: - Row("Location", location) :: - Row("Properties", "") :: Nil) + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", CatalogUtils.URIToString(location)) :: + Row("Properties", "") :: Nil) - sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") + sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") - checkAnswer( - sql(s"DESCRIBE DATABASE EXTENDED $dbName"), - Row("Database Name", dbNameWithoutBackTicks) :: - Row("Description", "") :: - Row("Location", location) :: - Row("Properties", "((a,a), (b,b), (c,c))") :: Nil) + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", CatalogUtils.URIToString(location)) :: + Row("Properties", "((a,a), (b,b), (c,c))") :: Nil) - sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") + sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") - checkAnswer( - sql(s"DESCRIBE DATABASE EXTENDED $dbName"), - Row("Database Name", dbNameWithoutBackTicks) :: - Row("Description", "") :: - Row("Location", location) :: - Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil) - } finally { - catalog.reset() - } - } + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", CatalogUtils.URIToString(location)) :: + Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil) + } finally { + catalog.reset() } } } @@ -528,7 +579,12 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { var message = intercept[AnalysisException] { sql(s"DROP DATABASE $dbName") }.getMessage - assert(message.contains(s"Database '$dbNameWithoutBackTicks' not found")) + // TODO: Unify the exception. + if (isUsingHiveMetastore) { + assert(message.contains(s"NoSuchObjectException: $dbNameWithoutBackTicks")) + } else { + assert(message.contains(s"Database '$dbNameWithoutBackTicks' not found")) + } message = intercept[AnalysisException] { sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") @@ -557,7 +613,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val message = intercept[AnalysisException] { sql(s"DROP DATABASE $dbName RESTRICT") }.getMessage - assert(message.contains(s"Database '$dbName' is not empty. One or more tables exist")) + assert(message.contains(s"Database $dbName is not empty. 
One or more tables exist")) + catalog.dropTable(tableIdent1, ignoreIfNotExists = false, purge = false) @@ -588,7 +645,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { createTable(catalog, tableIdent1) val expectedTableIdent = tableIdent1.copy(database = Some("default")) val expectedTable = generateTable(catalog, expectedTableIdent) - assert(catalog.getTableMetadata(tableIdent1) === expectedTable) + checkCatalogTables(expectedTable, catalog.getTableMetadata(tableIdent1)) } test("create table in a specific db") { @@ -597,7 +654,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val tableIdent1 = TableIdentifier("tab1", Some("dbx")) createTable(catalog, tableIdent1) val expectedTable = generateTable(catalog, tableIdent1) - assert(catalog.getTableMetadata(tableIdent1) === expectedTable) + checkCatalogTables(expectedTable, catalog.getTableMetadata(tableIdent1)) } test("create table using") { @@ -618,7 +675,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val table = catalog.getTableMetadata(TableIdentifier("tbl")) assert(table.tableType == CatalogTableType.MANAGED) assert(table.provider == Some("parquet")) - assert(table.schema == new StructType().add("a", IntegerType).add("b", IntegerType)) + // a is ordered last since it is a user-specified partitioning column + assert(table.schema == new StructType().add("b", IntegerType).add("a", IntegerType)) assert(table.partitionColumnNames == Seq("a")) } } @@ -637,21 +695,28 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("create temporary view using") { - val csvFile = - Thread.currentThread().getContextClassLoader.getResource("test-data/cars.csv").toString - withView("testview") { - sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + - "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + - s"OPTIONS (PATH '$csvFile')") - - checkAnswer( - sql("select c1, c2 from testview order by c1 limit 1"), + // when we test the HiveCatalogedDDLSuite, it will failed because the csvFile path above + // starts with 'jar:', and it is an illegal parameter for Path, so here we copy it + // to a temp file by withResourceTempPath + withResourceTempPath("test-data/cars.csv") { tmpFile => + withView("testview") { + sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + + s"OPTIONS (PATH '${tmpFile.toURI}')") + + checkAnswer( + sql("select c1, c2 from testview order by c1 limit 1"), Row("1997", "Ford") :: Nil) - // Fails if creating a new view with the same name - intercept[TempTableAlreadyExistsException] { - sql(s"CREATE TEMPORARY VIEW testview USING " + - s"org.apache.spark.sql.execution.datasources.csv.CSVFileFormat OPTIONS (PATH '$csvFile')") + // Fails if creating a new view with the same name + intercept[TempTableAlreadyExistsException] { + sql( + s""" + |CREATE TEMPORARY VIEW testview + |USING org.apache.spark.sql.execution.datasources.csv.CSVFileFormat + |OPTIONS (PATH '${tmpFile.toURI}') + """.stripMargin) + } } } } @@ -777,46 +842,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - test("alter table: set location") { - testSetLocation(isDatasourceTable = false) - } - - test("alter table: set location (datasource table)") { - testSetLocation(isDatasourceTable = true) - } - - test("alter table: set properties") { - 
testSetProperties(isDatasourceTable = false) - } - - test("alter table: set properties (datasource table)") { - testSetProperties(isDatasourceTable = true) - } - - test("alter table: unset properties") { - testUnsetProperties(isDatasourceTable = false) - } - - test("alter table: unset properties (datasource table)") { - testUnsetProperties(isDatasourceTable = true) - } - - test("alter table: set serde") { - testSetSerde(isDatasourceTable = false) - } - - test("alter table: set serde (datasource table)") { - testSetSerde(isDatasourceTable = true) - } - - test("alter table: set serde partition") { - testSetSerdePartition(isDatasourceTable = false) - } - - test("alter table: set serde partition (datasource table)") { - testSetSerdePartition(isDatasourceTable = true) - } - test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -841,14 +866,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES") } - test("alter table: add partition") { - testAddPartitions(isDatasourceTable = false) - } - - test("alter table: add partition (datasource table)") { - testAddPartitions(isDatasourceTable = true) - } - test("alter table: recover partitions (sequential)") { withSQLConf("spark.rdd.parallelListingThreshold" -> "10") { testRecoverPartitions() @@ -861,7 +878,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def testRecoverPartitions() { + protected def testRecoverPartitions() { val catalog = spark.sessionState.catalog // table to alter does not exist intercept[AnalysisException] { @@ -875,7 +892,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) val part2 = Map("a" -> "2", "b" -> "6") - val root = new Path(catalog.getTableMetadata(tableIdent).storage.locationUri.get) + val root = new Path(catalog.getTableMetadata(tableIdent).location) val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) @@ -900,8 +917,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("ALTER TABLE tab1 RECOVER PARTITIONS") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) - assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1") - assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") + if (!isUsingHiveMetastore) { + assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1") + assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") + } else { + // After ALTER TABLE, the statistics of the first partition is removed by Hive megastore + assert(catalog.getPartition(tableIdent, part1).parameters.get("numFiles").isEmpty) + assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") + } } finally { fs.delete(root, true) } @@ -911,73 +934,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')") } - test("alter table: drop partition") { - testDropPartitions(isDatasourceTable = false) - } - - test("alter table: drop partition (datasource table)") { - testDropPartitions(isDatasourceTable = true) - } - test("alter table: drop partition is not supported for 
views") { assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')") } - test("alter table: rename partition") { - testRenamePartitions(isDatasourceTable = false) - } - test("alter table: rename partition (datasource table)") { - testRenamePartitions(isDatasourceTable = true) - } - - test("show tables") { - withTempView("show1a", "show2b") { - sql( - """ - |CREATE TEMPORARY TABLE show1a - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - | - |) - """.stripMargin) - sql( - """ - |CREATE TEMPORARY TABLE show2b - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - |) - """.stripMargin) - checkAnswer( - sql("SHOW TABLES IN default 'show1*'"), - Row("", "show1a", true) :: Nil) - - checkAnswer( - sql("SHOW TABLES IN default 'show1*|show2*'"), - Row("", "show1a", true) :: - Row("", "show2b", true) :: Nil) - - checkAnswer( - sql("SHOW TABLES 'show1*|show2*'"), - Row("", "show1a", true) :: - Row("", "show2b", true) :: Nil) - - assert( - sql("SHOW TABLES").count() >= 2) - assert( - sql("SHOW TABLES IN default").count() >= 2) - } - } - - test("show databases") { - sql("CREATE DATABASE showdb2B") - sql("CREATE DATABASE showdb1A") + test("show databases") { + sql("CREATE DATABASE showdb2B") + sql("CREATE DATABASE showdb1A") // check the result as well as its order checkDataset(sql("SHOW DATABASES"), Row("default"), Row("showdb1a"), Row("showdb2b")) @@ -1000,11 +964,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { Nil) } - test("drop table - temporary table") { + test("drop view - temporary view") { val catalog = spark.sessionState.catalog sql( """ - |CREATE TEMPORARY TABLE tab1 + |CREATE TEMPORARY VIEW tab1 |USING org.apache.spark.sql.sources.DDLScanSource |OPTIONS ( | From '1', @@ -1013,26 +977,18 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |) """.stripMargin) assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"))) - sql("DROP TABLE tab1") + sql("DROP VIEW tab1") assert(catalog.listTables("default") == Nil) } - test("drop table") { - testDropTable(isDatasourceTable = false) - } - - test("drop table - data source table") { - testDropTable(isDatasourceTable = true) - } - - private def testDropTable(isDatasourceTable: Boolean): Unit = { + protected def testDropTable(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) assert(catalog.listTables("dbx") == Seq(tableIdent)) sql("DROP TABLE dbx.tab1") assert(catalog.listTables("dbx") == Nil) @@ -1056,23 +1012,20 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { e.getMessage.contains("Cannot drop a table with DROP VIEW. 
Please use DROP TABLE instead")) } - private def convertToDatasourceTable( - catalog: SessionCatalog, - tableIdent: TableIdentifier): Unit = { - catalog.alterTable(catalog.getTableMetadata(tableIdent).copy( - provider = Some("csv"))) - } - - private def testSetProperties(isDatasourceTable: Boolean): Unit = { + protected def testSetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { - catalog.getTableMetadata(tableIdent).properties + if (isUsingHiveMetastore) { + normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties + } else { + catalog.getTableMetadata(tableIdent).properties + } } assert(getProps.isEmpty) // set table properties @@ -1088,16 +1041,20 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def testUnsetProperties(isDatasourceTable: Boolean): Unit = { + protected def testUnsetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { - catalog.getTableMetadata(tableIdent).properties + if (isUsingHiveMetastore) { + normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties + } else { + catalog.getTableMetadata(tableIdent).properties + } } // unset table properties sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('j' = 'am', 'p' = 'an', 'c' = 'lan', 'x' = 'y')") @@ -1121,49 +1078,46 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(getProps == Map("x" -> "y")) } - private def testSetLocation(isDatasourceTable: Boolean): Unit = { + protected def testSetLocation(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val partSpec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, partSpec, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isDefined) - assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty) - assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isEmpty) - assert(catalog.getPartition(tableIdent, partSpec).storage.properties.isEmpty) + assert(normalizeSerdeProp(catalog.getTableMetadata(tableIdent).storage.properties).isEmpty) + assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined) + assert( + normalizeSerdeProp(catalog.getPartition(tableIdent, partSpec).storage.properties).isEmpty) + // Verify that the location is set 
to the expected string - def verifyLocation(expected: String, spec: Option[TablePartitionSpec] = None): Unit = { + def verifyLocation(expected: URI, spec: Option[TablePartitionSpec] = None): Unit = { val storageFormat = spec .map { s => catalog.getPartition(tableIdent, s).storage } .getOrElse { catalog.getTableMetadata(tableIdent).storage } - if (isDatasourceTable) { - if (spec.isDefined) { - assert(storageFormat.properties.isEmpty) - assert(storageFormat.locationUri === Some(expected)) - } else { - assert(storageFormat.locationUri === Some(expected)) - } - } else { - assert(storageFormat.locationUri === Some(expected)) - } + // TODO(gatorsmile): fix the bug in alter table set location. + // if (isUsingHiveMetastore) { + // assert(storageFormat.properties.get("path") === expected) + // } + assert(storageFormat.locationUri === Some(expected)) } // set table location sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") - verifyLocation("/path/to/your/lovely/heart") + verifyLocation(new URI("/path/to/your/lovely/heart")) // set table partition location sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways'") - verifyLocation("/path/to/part/ways", Some(partSpec)) + verifyLocation(new URI("/path/to/part/ways"), Some(partSpec)) // set table location without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET LOCATION '/swanky/steak/place'") - verifyLocation("/swanky/steak/place") + verifyLocation(new URI("/swanky/steak/place")) // set table partition location without explicitly specifying database sql("ALTER TABLE tab1 PARTITION (a='1', b='2') SET LOCATION 'vienna'") - verifyLocation("vienna", Some(partSpec)) + verifyLocation(new URI("vienna"), Some(partSpec)) // table to alter does not exist intercept[AnalysisException] { sql("ALTER TABLE dbx.does_not_exist SET LOCATION '/mister/spark'") @@ -1174,16 +1128,33 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def testSetSerde(isDatasourceTable: Boolean): Unit = { + protected def testSetSerde(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) + def checkSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { + val serdeProp = catalog.getTableMetadata(tableIdent).storage.properties + if (isUsingHiveMetastore) { + assert(normalizeSerdeProp(serdeProp) == expectedSerdeProps) + } else { + assert(serdeProp == expectedSerdeProps) + } + } + if (isUsingHiveMetastore) { + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(expectedSerde)) + } else { + assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) } - assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) - assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty) + checkSerdeProps(Map.empty[String, String]) // set table serde and/or properties (should fail on datasource tables) if (isDatasourceTable) { val e1 = intercept[AnalysisException] { @@ 
-1196,45 +1167,61 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(e1.getMessage.contains("datasource")) assert(e2.getMessage.contains("datasource")) } else { - sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.jadoop'") - assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.jadoop")) - assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty) - sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " + + val newSerde = "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + sql(s"ALTER TABLE dbx.tab1 SET SERDE '$newSerde'") + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(newSerde)) + checkSerdeProps(Map.empty[String, String]) + val serde2 = "org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe" + sql(s"ALTER TABLE dbx.tab1 SET SERDE '$serde2' " + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") - assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.madoop")) - assert(catalog.getTableMetadata(tableIdent).storage.properties == - Map("k" -> "v", "kay" -> "vee")) + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(serde2)) + checkSerdeProps(Map("k" -> "v", "kay" -> "vee")) } // set serde properties only sql("ALTER TABLE dbx.tab1 SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')") - assert(catalog.getTableMetadata(tableIdent).storage.properties == - Map("k" -> "vvv", "kay" -> "vee")) + checkSerdeProps(Map("k" -> "vvv", "kay" -> "vee")) // set things without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET SERDEPROPERTIES ('kay' = 'veee')") - assert(catalog.getTableMetadata(tableIdent).storage.properties == - Map("k" -> "vvv", "kay" -> "veee")) + checkSerdeProps(Map("k" -> "vvv", "kay" -> "veee")) // table to alter does not exist intercept[AnalysisException] { sql("ALTER TABLE does_not_exist SET SERDEPROPERTIES ('x' = 'y')") } } - private def testSetSerdePartition(isDatasourceTable: Boolean): Unit = { + protected def testSetSerdePartition(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val spec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, spec, tableIdent) createTablePartition(catalog, Map("a" -> "1", "b" -> "3"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "2"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "3"), tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) + def checkPartitionSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { + val serdeProp = catalog.getPartition(tableIdent, spec).storage.properties + if (isUsingHiveMetastore) { + assert(normalizeSerdeProp(serdeProp) == expectedSerdeProps) + } else { + assert(serdeProp == expectedSerdeProps) + } } - assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty) - assert(catalog.getPartition(tableIdent, spec).storage.properties.isEmpty) + if (isUsingHiveMetastore) { + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + assert(catalog.getPartition(tableIdent, spec).storage.serde == 
Some(expectedSerde)) + } else { + assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty) + } + checkPartitionSerdeProps(Map.empty[String, String]) // set table serde and/or properties (should fail on datasource tables) if (isDatasourceTable) { val e1 = intercept[AnalysisException] { @@ -1249,26 +1236,23 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } else { sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) SET SERDE 'org.apache.jadoop'") assert(catalog.getPartition(tableIdent, spec).storage.serde == Some("org.apache.jadoop")) - assert(catalog.getPartition(tableIdent, spec).storage.properties.isEmpty) + checkPartitionSerdeProps(Map.empty[String, String]) sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) SET SERDE 'org.apache.madoop' " + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") assert(catalog.getPartition(tableIdent, spec).storage.serde == Some("org.apache.madoop")) - assert(catalog.getPartition(tableIdent, spec).storage.properties == - Map("k" -> "v", "kay" -> "vee")) + checkPartitionSerdeProps(Map("k" -> "v", "kay" -> "vee")) } // set serde properties only maybeWrapException(isDatasourceTable) { sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) " + "SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')") - assert(catalog.getPartition(tableIdent, spec).storage.properties == - Map("k" -> "vvv", "kay" -> "vee")) + checkPartitionSerdeProps(Map("k" -> "vvv", "kay" -> "vee")) } // set things without explicitly specifying database catalog.setCurrentDatabase("dbx") maybeWrapException(isDatasourceTable) { sql("ALTER TABLE tab1 PARTITION (a=1, b=2) SET SERDEPROPERTIES ('kay' = 'veee')") - assert(catalog.getPartition(tableIdent, spec).storage.properties == - Map("k" -> "vvv", "kay" -> "veee")) + checkPartitionSerdeProps(Map("k" -> "vvv", "kay" -> "veee")) } // table to alter does not exist intercept[AnalysisException] { @@ -1276,7 +1260,10 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def testAddPartitions(isDatasourceTable: Boolean): Unit = { + protected def testAddPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1285,20 +1272,25 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val part4 = Map("a" -> "4", "b" -> "8") val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) // basic add partition sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty) - assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris")) - assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isDefined) + val partitionLocation = if (isUsingHiveMetastore) { + val tableLocation = 
catalog.getTableMetadata(tableIdent).storage.locationUri + assert(tableLocation.isDefined) + makeQualifiedPath(new Path(tableLocation.get.toString, "paris").toString) + } else { + new URI("paris") + } + + assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option(partitionLocation)) + assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isDefined) // add partitions without explicitly specifying database catalog.setCurrentDatabase("dbx") @@ -1327,33 +1319,35 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { Set(part1, part2, part3, part4, part5)) } - private def testDropPartitions(isDatasourceTable: Boolean): Unit = { + protected def testDropPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") val part2 = Map("a" -> "2", "b" -> "6") val part3 = Map("a" -> "3", "b" -> "7") val part4 = Map("a" -> "4", "b" -> "8") + val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) createTablePartition(catalog, part4, tableIdent) + createTablePartition(catalog, part5, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(part1, part2, part3, part4)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + Set(part1, part2, part3, part4, part5)) // basic drop partition sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')") - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part5)) // drop partitions without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='2', b ='6')") - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part5)) // table to alter does not exist intercept[AnalysisException] { @@ -1367,28 +1361,32 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { // partition to drop does not exist when using IF EXISTS sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='300')") - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part5)) // partition spec in DROP PARTITION should be case insensitive by default sql("ALTER TABLE tab1 DROP PARTITION (A='1', B='5')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part5)) + + // use int literal as partition value for int type partition column + sql("ALTER TABLE tab1 DROP PARTITION (a=9, b=9)") assert(catalog.listPartitions(tableIdent).isEmpty) } - private def testRenamePartitions(isDatasourceTable: Boolean): Unit = { + protected def testRenamePartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", 
Some("dbx")) val part1 = Map("a" -> "1", "b" -> "q") val part2 = Map("a" -> "2", "b" -> "c") val part3 = Map("a" -> "3", "b" -> "p") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } // basic rename partition sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") @@ -1418,6 +1416,26 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { Set(Map("a" -> "1", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) } + protected def testChangeColumn(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } + val catalog = spark.sessionState.catalog + val resolver = spark.sessionState.conf.resolver + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent, isDatasourceTable) + def getMetadata(colName: String): Metadata = { + val column = catalog.getTableMetadata(tableIdent).schema.fields.find { field => + resolver(field.name, colName) + } + column.map(_.metadata).getOrElse(Metadata.empty) + } + // Ensure that change column will preserve other metadata fields. + sql("ALTER TABLE dbx.tab1 CHANGE COLUMN col1 col1 INT COMMENT 'this is col1'") + assert(getMetadata("col1").getString("key") == "value") + } + test("drop build-in function") { Seq("true", "false").foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { @@ -1479,8 +1497,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("DESCRIBE FUNCTION 'concat'"), Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") :: Row("Function: concat") :: - Row("Usage: concat(str1, str2, ..., strN) " + - "- Returns the concatenation of `str1`, `str2`, ..., `strN`.") :: Nil + Row("Usage: concat(str1, str2, ..., strN) - " + + "Returns the concatenation of str1, str2, ..., strN.") :: Nil ) // extended mode checkAnswer( @@ -1525,54 +1543,19 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("DROP TEMPORARY MACRO SOME_MACRO") } sql("DROP TEMPORARY MACRO IF EXISTS SOME_MACRO") - sql("CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0d / (1.0d + EXP(-x))") - checkAnswer(sql("SELECT SIGMOID(1.0)"), Row(0.7310585786300049)) - sql("DROP TEMPORARY MACRO SIGMOID") } - test("select/insert into the managed table") { - assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") - val tabName = "tbl" - withTable(tabName) { - sql(s"CREATE TABLE $tabName (i INT, j STRING)") - val catalogTable = - spark.sessionState.catalog.getTableMetadata(TableIdentifier(tabName, Some("default"))) - assert(catalogTable.tableType == CatalogTableType.MANAGED) - - var message = intercept[AnalysisException] { - sql(s"INSERT OVERWRITE TABLE $tabName SELECT 1, 'a'") - }.getMessage - assert(message.contains("Hive support is required to insert into the following tables")) - message = intercept[AnalysisException] { - sql(s"SELECT * FROM $tabName") - }.getMessage - assert(message.contains("Hive support is required to select over the following tables")) - } - } + 
test("create a data source table without schema") { + import testImplicits._ + withTempPath { tempDir => + withTable("tab1", "tab2") { + (("a", "b") :: Nil).toDF().write.json(tempDir.getCanonicalPath) - test("select/insert into external table") { - assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") - withTempDir { tempDir => - val tabName = "tbl" - withTable(tabName) { - sql( - s""" - |CREATE EXTERNAL TABLE $tabName (i INT, j STRING) - |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' - |LOCATION '$tempDir' - """.stripMargin) - val catalogTable = - spark.sessionState.catalog.getTableMetadata(TableIdentifier(tabName, Some("default"))) - assert(catalogTable.tableType == CatalogTableType.EXTERNAL) + val e = intercept[AnalysisException] { sql("CREATE TABLE tab1 USING json") }.getMessage + assert(e.contains("Unable to infer schema for JSON. It must be specified manually")) - var message = intercept[AnalysisException] { - sql(s"INSERT OVERWRITE TABLE $tabName SELECT 1, 'a'") - }.getMessage - assert(message.contains("Hive support is required to insert into the following tables")) - message = intercept[AnalysisException] { - sql(s"SELECT * FROM $tabName") - }.getMessage - assert(message.contains("Hive support is required to select over the following tables")) + sql(s"CREATE TABLE tab2 using json location '${tempDir.toURI}'") + checkAnswer(spark.table("tab2"), Row("a", "b")) } } } @@ -1600,22 +1583,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - test("Create Hive Table As Select") { - import testImplicits._ - withTable("t", "t1") { - var e = intercept[AnalysisException] { - sql("CREATE TABLE t SELECT 1 as a, 1 as b") - }.getMessage - assert(e.contains("Hive support is required to use CREATE Hive TABLE AS SELECT")) - - spark.range(1).select('id as 'a, 'id as 'b).write.saveAsTable("t1") - e = intercept[AnalysisException] { - sql("CREATE TABLE t SELECT a, b from t1") - }.getMessage - assert(e.contains("Hive support is required to use CREATE Hive TABLE AS SELECT")) - } - } - test("Create Data Source Table As Select") { import testImplicits._ withTable("t", "t1", "t2") { @@ -1629,17 +1596,20 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("drop current database") { - sql("CREATE DATABASE temp") - sql("USE temp") - sql("DROP DATABASE temp") - val e = intercept[AnalysisException] { - sql("CREATE TABLE t (a INT, b INT)") + withDatabase("temp") { + sql("CREATE DATABASE temp") + sql("USE temp") + sql("DROP DATABASE temp") + val e = intercept[AnalysisException] { + sql("CREATE TABLE t (a INT, b INT) USING parquet") }.getMessage - assert(e.contains("Database 'temp' not found")) + assert(e.contains("Database 'temp' not found")) + } } test("drop default database") { - Seq("true", "false").foreach { caseSensitive => + val caseSensitiveOptions = if (isUsingHiveMetastore) Seq("false") else Seq("true", "false") + caseSensitiveOptions.foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { var message = intercept[AnalysisException] { sql("DROP DATABASE default") @@ -1741,22 +1711,39 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } + test("block creating duplicate temp table") { + withView("t_temp") { + sql("CREATE TEMPORARY VIEW t_temp AS SELECT 1, 2") + val e = intercept[TempTableAlreadyExistsException] { + sql("CREATE TEMPORARY TABLE t_temp (c3 int, c4 string) USING JSON") + }.getMessage + assert(e.contains("Temporary table 't_temp' already 
exists")) + } + } + test("truncate table - external table, temporary table, view (not allowed)") { import testImplicits._ - val path = Utils.createTempDir().getAbsolutePath - (1 to 10).map { i => (i, i) }.toDF("a", "b").createTempView("my_temp_tab") - sql(s"CREATE EXTERNAL TABLE my_ext_tab LOCATION '$path'") - sql(s"CREATE VIEW my_view AS SELECT 1") - intercept[NoSuchTableException] { - sql("TRUNCATE TABLE my_temp_tab") + withTempPath { tempDir => + withTable("my_ext_tab") { + (("a", "b") :: Nil).toDF().write.parquet(tempDir.getCanonicalPath) + (1 to 10).map { i => (i, i) }.toDF("a", "b").createTempView("my_temp_tab") + sql(s"CREATE TABLE my_ext_tab using parquet LOCATION '${tempDir.toURI}'") + sql(s"CREATE VIEW my_view AS SELECT 1") + intercept[NoSuchTableException] { + sql("TRUNCATE TABLE my_temp_tab") + } + assertUnsupported("TRUNCATE TABLE my_ext_tab") + assertUnsupported("TRUNCATE TABLE my_view") + } } - assertUnsupported("TRUNCATE TABLE my_ext_tab") - assertUnsupported("TRUNCATE TABLE my_view") } test("truncate table - non-partitioned table (not allowed)") { - sql("CREATE TABLE my_tab (age INT, name STRING)") - assertUnsupported("TRUNCATE TABLE my_tab PARTITION (age=10)") + withTable("my_tab") { + sql("CREATE TABLE my_tab (age INT, name STRING) using parquet") + sql("INSERT INTO my_tab values (10, 'a')") + assertUnsupported("TRUNCATE TABLE my_tab PARTITION (age=10)") + } } test("SPARK-16034 Partition columns should match when appending to existing data source tables") { @@ -1824,7 +1811,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { withTable(tabName) { sql(s"CREATE TABLE $tabName(col1 int, col2 string) USING parquet ") val message = intercept[AnalysisException] { - sql(s"SHOW COLUMNS IN $db.showcolumn FROM ${db.toUpperCase}") + sql(s"SHOW COLUMNS IN $db.showcolumn FROM ${db.toUpperCase(Locale.ROOT)}") }.getMessage assert(message.contains("SHOW COLUMNS with conflicting databases")) } @@ -1838,4 +1825,525 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val rows: Seq[Row] = df.toLocalIterator().asScala.toSeq assert(rows.length > 0) } -} + + test("SET LOCATION for managed table") { + withTable("tbl") { + withTempDir { dir => + sql("CREATE TABLE tbl(i INT) USING parquet") + sql("INSERT INTO tbl SELECT 1") + checkAnswer(spark.table("tbl"), Row(1)) + val defaultTablePath = spark.sessionState.catalog + .getTableMetadata(TableIdentifier("tbl")).storage.locationUri.get + try { + sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") + spark.catalog.refreshTable("tbl") + // SET LOCATION won't move data from previous table path to new table path. + assert(spark.table("tbl").count() == 0) + // the previous table path should be still there. + assert(new File(defaultTablePath).exists()) + + sql("INSERT INTO tbl SELECT 2") + checkAnswer(spark.table("tbl"), Row(2)) + // newly inserted data will go to the new table path. + assert(dir.listFiles().nonEmpty) + + sql("DROP TABLE tbl") + // the new table path will be removed after DROP TABLE. 
+ assert(!dir.exists()) + } finally { + Utils.deleteRecursively(new File(defaultTablePath)) + } + } + } + } + + test("insert data to a data source table which has a non-existing location should succeed") { + withTable("t") { + withTempDir { dir => + spark.sql( + s""" + |CREATE TABLE t(a string, b int) + |USING parquet + |OPTIONS(path "${dir.toURI}") + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + dir.delete + assert(!dir.exists) + spark.sql("INSERT INTO TABLE t SELECT 'c', 1") + assert(dir.exists) + checkAnswer(spark.table("t"), Row("c", 1) :: Nil) + + Utils.deleteRecursively(dir) + assert(!dir.exists) + spark.sql("INSERT OVERWRITE TABLE t SELECT 'c', 1") + assert(dir.exists) + checkAnswer(spark.table("t"), Row("c", 1) :: Nil) + + val newDirFile = new File(dir, "x") + val newDir = newDirFile.toURI + spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'") + spark.sessionState.catalog.refreshTable(TableIdentifier("t")) + + val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table1.location == newDir) + assert(!newDirFile.exists) + + spark.sql("INSERT INTO TABLE t SELECT 'c', 1") + assert(newDirFile.exists) + checkAnswer(spark.table("t"), Row("c", 1) :: Nil) + } + } + } + + test("insert into a data source table with a non-existing partition location should succeed") { + withTable("t") { + withTempDir { dir => + spark.sql( + s""" + |CREATE TABLE t(a int, b int, c int, d int) + |USING parquet + |PARTITIONED BY(a, b) + |LOCATION "${dir.toURI}" + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") + checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) + + val partLoc = new File(s"${dir.getAbsolutePath}/a=1") + Utils.deleteRecursively(partLoc) + assert(!partLoc.exists()) + // insert overwrite into a partition which location has been deleted. + spark.sql("INSERT OVERWRITE TABLE t PARTITION(a=1, b=2) SELECT 7, 8") + assert(partLoc.exists()) + checkAnswer(spark.table("t"), Row(7, 8, 1, 2) :: Nil) + } + } + } + + test("read data from a data source table which has a non-existing location should succeed") { + withTable("t") { + withTempDir { dir => + spark.sql( + s""" + |CREATE TABLE t(a string, b int) + |USING parquet + |OPTIONS(path "${dir.toURI}") + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + dir.delete() + checkAnswer(spark.table("t"), Nil) + + val newDirFile = new File(dir, "x") + val newDir = newDirFile.toURI + spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'") + + val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table1.location == newDir) + assert(!newDirFile.exists()) + checkAnswer(spark.table("t"), Nil) + } + } + } + + test("read data from a data source table with non-existing partition location should succeed") { + withTable("t") { + withTempDir { dir => + spark.sql( + s""" + |CREATE TABLE t(a int, b int, c int, d int) + |USING parquet + |PARTITIONED BY(a, b) + |LOCATION "${dir.toURI}" + """.stripMargin) + spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") + checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) + + // select from a partition which location has been deleted. 
+ Utils.deleteRecursively(dir) + assert(!dir.exists()) + spark.sql("REFRESH TABLE t") + checkAnswer(spark.sql("select * from t where a=1 and b=2"), Nil) + } + } + } + + test("create datasource table with a non-existing location") { + withTable("t", "t1") { + withTempPath { dir => + spark.sql(s"CREATE TABLE t(a int, b int) USING parquet LOCATION '${dir.toURI}'") + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t SELECT 1, 2") + assert(dir.exists()) + + checkAnswer(spark.table("t"), Row(1, 2)) + } + // partition table + withTempPath { dir => + spark.sql( + s"CREATE TABLE t1(a int, b int) USING parquet PARTITIONED BY(a) LOCATION '${dir.toURI}'") + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t1 PARTITION(a=1) SELECT 2") + + val partDir = new File(dir, "a=1") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(2, 1)) + } + } + } + + Seq(true, false).foreach { shouldDelete => + val tcName = if (shouldDelete) "non-existing" else "existed" + test(s"CTAS for external data source table with a $tcName location") { + withTable("t", "t1") { + withTempDir { dir => + if (shouldDelete) dir.delete() + spark.sql( + s""" + |CREATE TABLE t + |USING parquet + |LOCATION '${dir.toURI}' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + } + // partition table + withTempDir { dir => + if (shouldDelete) dir.delete() + spark.sql( + s""" + |CREATE TABLE t1 + |USING parquet + |PARTITIONED BY(a, b) + |LOCATION '${dir.toURI}' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + val partDir = new File(dir, "a=3") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } + } + } + } + + Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars => + test(s"data source table:partition column name containing $specialChars") { + // On Windows, it looks colon in the file name is illegal by default. See + // https://support.microsoft.com/en-us/help/289627 + assume(!Utils.isWindows || specialChars != "a:b") + + withTable("t") { + withTempDir { dir => + spark.sql( + s""" + |CREATE TABLE t(a string, `$specialChars` string) + |USING parquet + |PARTITIONED BY(`$specialChars`) + |LOCATION '${dir.toURI}' + """.stripMargin) + + assert(dir.listFiles().isEmpty) + spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1") + val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2" + val partFile = new File(dir, partEscaped) + assert(partFile.listFiles().nonEmpty) + checkAnswer(spark.table("t"), Row("1", "2") :: Nil) + } + } + } + } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"location uri contains $specialChars for datasource table") { + // On Windows, it looks colon in the file name is illegal by default. 
See + // https://support.microsoft.com/en-us/help/289627 + assume(!Utils.isWindows || specialChars != "a:b") + + withTable("t", "t1") { + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. + val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\") + spark.sql( + s""" + |CREATE TABLE t(a string) + |USING parquet + |LOCATION '$escapedLoc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(loc.getAbsolutePath)) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + spark.sql("INSERT INTO TABLE t SELECT 1") + assert(loc.listFiles().nonEmpty) + checkAnswer(spark.table("t"), Row("1") :: Nil) + } + + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. + val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\") + spark.sql( + s""" + |CREATE TABLE t1(a string, b string) + |USING parquet + |PARTITIONED BY(b) + |LOCATION '$escapedLoc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(loc.getAbsolutePath)) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") + val partFile = new File(loc, "b=2") + assert(partFile.listFiles().nonEmpty) + checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) + + spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") + val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14") + assert(!partFile1.exists()) + + if (!Utils.isWindows) { + // Actual path becomes "b=2017-03-03%2012%3A13%253A14" on Windows. + val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") + assert(partFile2.listFiles().nonEmpty) + checkAnswer( + spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + } + } + } + } + } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"location uri contains $specialChars for database") { + // On Windows, it looks colon in the file name is illegal by default. See + // https://support.microsoft.com/en-us/help/289627 + assume(!Utils.isWindows || specialChars != "a:b") + + withDatabase ("tmpdb") { + withTable("t") { + withTempDir { dir => + val loc = new File(dir, specialChars) + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. + val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\") + spark.sql(s"CREATE DATABASE tmpdb LOCATION '$escapedLoc'") + spark.sql("USE tmpdb") + + import testImplicits._ + Seq(1).toDF("a").write.saveAsTable("t") + val tblloc = new File(loc, "t") + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(tblloc.getAbsolutePath)) + assert(tblloc.listFiles().nonEmpty) + } + } + } + } + } + + test("the qualified path of a datasource table is stored in the catalog") { + withTable("t", "t1") { + withTempDir { dir => + assert(!dir.getAbsolutePath.startsWith("file:/")) + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. 
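+        // Doubling every backslash keeps a Windows path such as C:\Users\tmp intact once it
+        // is embedded in the LOCATION clause of the CREATE TABLE statement below.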
+ val escapedDir = dir.getAbsolutePath.replace("\\", "\\\\") + spark.sql( + s""" + |CREATE TABLE t(a string) + |USING parquet + |LOCATION '$escapedDir' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location.toString.startsWith("file:/")) + } + + withTempDir { dir => + assert(!dir.getAbsolutePath.startsWith("file:/")) + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. + val escapedDir = dir.getAbsolutePath.replace("\\", "\\\\") + spark.sql( + s""" + |CREATE TABLE t1(a string, b string) + |USING parquet + |PARTITIONED BY(b) + |LOCATION '$escapedDir' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location.toString.startsWith("file:/")) + } + } + } + + val supportedNativeFileFormatsForAlterTableAddColumns = Seq("parquet", "json", "csv") + + supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => + test(s"alter datasource table add columns - $provider") { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int) USING $provider") + sql("INSERT INTO t1 VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 is null"), + Seq(Row(1, null)) + ) + + sql("INSERT INTO t1 VALUES (3, 2)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 2"), + Seq(Row(3, 2)) + ) + } + } + } + + supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => + test(s"alter datasource table add columns - partitioned - $provider") { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int, c2 int) USING $provider PARTITIONED BY (c2)") + sql("INSERT INTO t1 PARTITION(c2 = 2) VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c3 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null, 2)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 is null"), + Seq(Row(1, null, 2)) + ) + sql("INSERT INTO t1 PARTITION(c2 =1) VALUES (2, 3)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 = 3"), + Seq(Row(2, 3, 1)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 1"), + Seq(Row(2, 3, 1)) + ) + } + } + } + + test("alter datasource table add columns - text format not supported") { + withTable("t1") { + sql("CREATE TABLE t1 (c1 int) USING text") + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") + }.getMessage + assert(e.contains("ALTER ADD COLUMNS does not support datasource table with type")) + } + } + + test("alter table add columns -- not support temp view") { + withTempView("tmp_v") { + sql("CREATE TEMPORARY VIEW tmp_v AS SELECT 1 AS c1, 2 AS c2") + val e = intercept[AnalysisException] { + sql("ALTER TABLE tmp_v ADD COLUMNS (c3 INT)") + } + assert(e.message.contains("ALTER ADD COLUMNS does not support views")) + } + } + + test("alter table add columns -- not support view") { + withView("v1") { + sql("CREATE VIEW v1 AS SELECT 1 AS c1, 2 AS c2") + val e = intercept[AnalysisException] { + sql("ALTER TABLE v1 ADD COLUMNS (c3 INT)") + } + assert(e.message.contains("ALTER ADD COLUMNS does not support views")) + } + } + + test("alter table add columns with existing column name") { + withTable("t1") { + sql("CREATE TABLE t1 (c1 int) USING PARQUET") + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (c1 string)") + }.getMessage + assert(e.contains("Found duplicate column(s)")) + } + } + + Seq(true, false).foreach { caseSensitive => + test(s"alter table add 
columns with existing column name - caseSensitive $caseSensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { + withTable("t1") { + sql("CREATE TABLE t1 (c1 int) USING PARQUET") + if (!caseSensitive) { + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + }.getMessage + assert(e.contains("Found duplicate column(s)")) + } else { + if (isUsingHiveMetastore) { + // hive catalog will still complains that c1 is duplicate column name because hive + // identifiers are case insensitive. + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + }.getMessage + assert(e.contains("HiveException")) + } else { + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + assert(spark.table("t1").schema + .equals(new StructType().add("c1", IntegerType).add("C1", StringType))) + } + } + } + } + } + + test(s"basic DDL using locale tr - caseSensitive $caseSensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { + withLocale("tr") { + val dbName = "DaTaBaSe_I" + withDatabase(dbName) { + sql(s"CREATE DATABASE $dbName") + sql(s"USE $dbName") + + val tabName = "tAb_I" + withTable(tabName) { + sql(s"CREATE TABLE $tabName(col_I int) USING PARQUET") + sql(s"INSERT OVERWRITE TABLE $tabName SELECT 1") + checkAnswer(sql(s"SELECT col_I FROM $tabName"), Row(1) :: Nil) + } + } + } + } + } + } +} \ No newline at end of file diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 4f83a834db4a3..703a74278e155 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -237,7 +237,6 @@ private[hive] trait HiveStrategies { !predicate.references.isEmpty && predicate.references.subsetOf(partitionKeyIds) } - FunctionRegistry.scala pruneFilterProject( projectList, otherPredicates, From 3d05e4f3509d32fa85618bfb475b648261a0694f Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 28 May 2017 00:23:55 +0800 Subject: [PATCH 13/26] reformat code. 
--- .../analysis/AlreadyExistException.scala | 2 +- .../sql/catalyst/analysis/Analyzer.scala | 1115 +++++++++++------ .../spark/sql/execution/SparkSqlParser.scala | 68 +- .../sql/execution/command/DDLSuite.scala | 52 +- .../spark/sql/hive/HiveStrategies.scala | 1 + 5 files changed, 814 insertions(+), 424 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index 16ceac098e597..eecdcf6ffa781 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -46,4 +46,4 @@ class FunctionAlreadyExistsException(db: String, func: String) extends AnalysisException(s"Function '$func' already exists in database '$db'") class TempMacroAlreadyExistsException(func: String) - extends AnalysisException(s"Temp macro '$func' already exists") \ No newline at end of file + extends AnalysisException(s"Temp macro '$func' already exists") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dd68d60d3e839..cf6f3939a24d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -21,18 +21,20 @@ import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.objects.NewInstance +import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -41,13 +43,45 @@ import org.apache.spark.sql.types._ * to resolve attribute references. 
*/ object SimpleAnalyzer extends Analyzer( - new SessionCatalog( - new InMemoryCatalog, - EmptyFunctionRegistry, - new SimpleCatalystConf(caseSensitiveAnalysis = true)) { - override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean) {} - }, - new SimpleCatalystConf(caseSensitiveAnalysis = true)) + new SessionCatalog( + new InMemoryCatalog, + EmptyFunctionRegistry, + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) { + override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean) {} + }, + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) + +/** + * Provides a way to keep state during the analysis, this enables us to decouple the concerns + * of analysis environment from the catalog. + * + * Note this is thread local. + * + * @param defaultDatabase The default database used in the view resolution, this overrules the + * current catalog database. + * @param nestedViewDepth The nested depth in the view resolution, this enables us to limit the + * depth of nested views. + */ +case class AnalysisContext( + defaultDatabase: Option[String] = None, + nestedViewDepth: Int = 0) + +object AnalysisContext { + private val value = new ThreadLocal[AnalysisContext]() { + override def initialValue: AnalysisContext = AnalysisContext() + } + + def get: AnalysisContext = value.get() + private def set(context: AnalysisContext): Unit = value.set(context) + + def withAnalysisContext[A](database: Option[String])(f: => A): A = { + val originContext = value.get() + val context = AnalysisContext(defaultDatabase = database, + nestedViewDepth = originContext.nestedViewDepth + 1) + set(context) + try f finally { set(originContext) } + } +} /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and @@ -56,11 +90,11 @@ object SimpleAnalyzer extends Analyzer( */ class Analyzer( catalog: SessionCatalog, - conf: CatalystConf, + conf: SQLConf, maxIterations: Int) extends RuleExecutor[LogicalPlan] with CheckAnalysis { - def this(catalog: SessionCatalog, conf: CatalystConf) = { + def this(catalog: SessionCatalog, conf: SQLConf) = { this(catalog, conf, conf.optimizerMaxIterations) } @@ -73,7 +107,19 @@ class Analyzer( */ val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil + /** + * Override to provide rules to do post-hoc resolution. Note that these rules will be executed + * in an individual batch. This batch is to run right after the normal resolution batch and + * execute its rules in one pass. 
+ */ + val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil + lazy val batches: Seq[Batch] = Seq( + Batch("Hints", fixedPoint, + new ResolveHints.ResolveBroadcastHints(conf), + ResolveHints.RemoveAllHints), + Batch("Simple Sanity Check", Once, + LookupFunctions), Batch("Substitution", fixedPoint, CTESubstitution, WindowsSubstitution, @@ -81,46 +127,54 @@ class Analyzer( new SubstituteUnresolvedOrdinals(conf)), Batch("Resolution", fixedPoint, ResolveTableValuedFunctions :: - ResolveRelations :: - ResolveReferences :: - ResolveCreateNamedStruct :: - ResolveDeserializer :: - ResolveNewInstance :: - ResolveUpCast :: - ResolveGroupingAnalytics :: - ResolvePivot :: - ResolveOrdinalInOrderByAndGroupBy :: - ResolveMissingReferences :: - ExtractGenerator :: - ResolveGenerate :: - ResolveFunctions :: - ResolveAliases :: - ResolveSubquery :: - ResolveWindowOrder :: - ResolveWindowFrame :: - ResolveNaturalAndUsingJoin :: - ExtractWindowExpressions :: - GlobalAggregates :: - ResolveAggregateFunctions :: - TimeWindowing :: - ResolveInlineTables :: - TypeCoercion.typeCoercionRules ++ - extendedResolutionRules : _*), + ResolveRelations :: + ResolveReferences :: + ResolveCreateNamedStruct :: + ResolveDeserializer :: + ResolveNewInstance :: + ResolveUpCast :: + ResolveGroupingAnalytics :: + ResolvePivot :: + ResolveOrdinalInOrderByAndGroupBy :: + ResolveAggAliasInGroupBy :: + ResolveMissingReferences :: + ExtractGenerator :: + ResolveGenerate :: + ResolveFunctions :: + ResolveAliases :: + ResolveSubquery :: + ResolveWindowOrder :: + ResolveWindowFrame :: + ResolveNaturalAndUsingJoin :: + ExtractWindowExpressions :: + GlobalAggregates :: + ResolveAggregateFunctions :: + TimeWindowing :: + ResolveInlineTables(conf) :: + ResolveTimeZone(conf) :: + TypeCoercion.typeCoercionRules ++ + extendedResolutionRules : _*), + Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), + Batch("View", Once, + AliasViewChild(conf)), Batch("Nondeterministic", Once, PullOutNondeterministic), Batch("UDF", Once, HandleNullInputsForUDF), Batch("FixNullability", Once, FixNullability), + Batch("Subquery", Once, + UpdateOuterReferences), Batch("Cleanup", fixedPoint, - CleanupAliases) + CleanupAliases, + EliminateBarriers) ) /** * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -132,12 +186,8 @@ class Analyzer( def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = { plan transformDown { case u : UnresolvedRelation => - val substituted = cteRelations.find(x => resolver(x._1, u.tableIdentifier.table)) - .map(_._2).map { relation => - val withAlias = u.alias.map(SubqueryAlias(_, relation, None)) - withAlias.getOrElse(relation) - } - substituted.getOrElse(u) + cteRelations.find(x => resolver(x._1, u.tableIdentifier.table)) + .map(_._2).getOrElse(u) case other => // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE. other transformExpressions { @@ -152,7 +202,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. 
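+   * A minimal illustration:
+   * {{{
+   *   SELECT max(a) OVER w FROM t WINDOW w AS (PARTITION BY b);
+   * }}}
+   * Here the reference to the named window `w` is replaced by its WindowSpecDefinition before
+   * further resolution.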
*/ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -178,9 +228,10 @@ class Analyzer( expr.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) => child match { case ne: NamedExpression => ne + case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil) case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) - case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() + case c @ Cast(ne: NamedExpression, _, _) => Alias(c, ne.name)() case e: ExtractValue => Alias(e, toPrettySQL(e))() case e if optGenAliasFunc.isDefined => Alias(child, optGenAliasFunc.get.apply(e))() @@ -193,7 +244,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -437,14 +488,16 @@ class Analyzer( case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => val singleAgg = aggregates.size == 1 def outputName(value: Literal, aggregate: Expression): String = { + val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + val stringValue: String = Option(utf8Value).map(_.toString).getOrElse("null") if (singleAgg) { - value.toString + stringValue } else { val suffix = aggregate match { case n: NamedExpression => n.name - case _ => aggregate.sql + case _ => toPrettySQL(aggregate) } - value + "_" + suffix + stringValue + "_" + suffix } } if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) { @@ -461,20 +514,21 @@ class Analyzer( val pivotAggs = namedAggExps.map { a => Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues) .toAggregateExpression() - , "__pivot_" + a.sql)() + , "__pivot_" + a.sql)() } - val secondAgg = Aggregate(groupByExprs, groupByExprs ++ pivotAggs, firstAgg) + val groupByExprsAttr = groupByExprs.map(_.toAttribute) + val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg) val pivotAggAttribute = pivotAggs.map(_.toAttribute) val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) => aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) => Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))() } } - Project(groupByExprs ++ pivotOutputs, secondAgg) + Project(groupByExprsAttr ++ pivotOutputs, secondAgg) } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => def ifExpr(expr: Expression) = { - If(EqualTo(pivotColumn, value), expr, Literal(null)) + If(EqualNullSafe(pivotColumn, value), expr, Literal(null)) } aggregates.map { aggregate => val filteredAggregate = aggregate.transformDown { @@ -509,32 +563,102 @@ class Analyzer( * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. 
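+   * For example, the unresolved `t` in `SELECT * FROM t` is looked up in the session catalog;
+   * when `t` is referenced from inside a view, the database recorded for that view is used as
+   * the default database for the lookup.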
*/ object ResolveRelations extends Rule[LogicalPlan] { - private def lookupTableFromCatalog(u: UnresolvedRelation): LogicalPlan = { + + // If the unresolved relation is running directly on files, we just return the original + // UnresolvedRelation, the plan will get resolved later. Else we look up the table from catalog + // and change the default database name(in AnalysisContext) if it is a view. + // We usually look up a table from the default database if the table identifier has an empty + // database part, for a view the default database should be the currentDb when the view was + // created. When the case comes to resolving a nested view, the view may have different default + // database with that the referenced view has, so we need to use + // `AnalysisContext.defaultDatabase` to track the current default database. + // When the relation we resolve is a view, we fetch the view.desc(which is a CatalogTable), and + // then set the value of `CatalogTable.viewDefaultDatabase` to + // `AnalysisContext.defaultDatabase`, we look up the relations that the view references using + // the default database. + // For example: + // |- view1 (defaultDatabase = db1) + // |- operator + // |- table2 (defaultDatabase = db1) + // |- view2 (defaultDatabase = db2) + // |- view3 (defaultDatabase = db3) + // |- view4 (defaultDatabase = db4) + // In this case, the view `view1` is a nested view, it directly references `table2`, `view2` + // and `view4`, the view `view2` references `view3`. On resolving the table, we look up the + // relations `table2`, `view2`, `view4` using the default database `db1`, and look up the + // relation `view3` using the default database `db2`. + // + // Note this is compatible with the views defined by older versions of Spark(before 2.2), which + // have empty defaultDatabase and all the relations in viewText have database part defined. + def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match { + case u: UnresolvedRelation if !isRunningDirectlyOnFiles(u.tableIdentifier) => + val defaultDatabase = AnalysisContext.get.defaultDatabase + val relation = lookupTableFromCatalog(u, defaultDatabase) + resolveRelation(relation) + // The view's child should be a logical plan parsed from the `desc.viewText`, the variable + // `viewText` should be defined, or else we throw an error on the generation of the View + // operator. + case view @ View(desc, _, child) if !child.resolved => + // Resolve all the UnresolvedRelations and Views in the child. + val newChild = AnalysisContext.withAnalysisContext(desc.viewDefaultDatabase) { + if (AnalysisContext.get.nestedViewDepth > conf.maxNestedViewDepth) { + view.failAnalysis(s"The depth of view ${view.desc.identifier} exceeds the maximum " + + s"view resolution depth (${conf.maxNestedViewDepth}). Analysis is aborted to " + + "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + + "aroud this.") + } + execute(child) + } + view.copy(child = newChild) + case p @ SubqueryAlias(_, view: View) => + val newChild = resolveRelation(view) + p.copy(child = newChild) + case _ => plan + } + + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => + EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { + case v: View => + u.failAnalysis(s"Inserting into a view is not allowed. 
View: ${v.desc.identifier}.") + case other => i.copy(table = other) + } + case u: UnresolvedRelation => resolveRelation(u) + } + + // Look up the table with the given name from catalog. The database we used is decided by the + // precedence: + // 1. Use the database part of the table identifier, if it is defined; + // 2. Use defaultDatabase, if it is defined(In this case, no temporary objects can be used, + // and the default database is only used to look up a view); + // 3. Use the currentDb of the SessionCatalog. + private def lookupTableFromCatalog( + u: UnresolvedRelation, + defaultDatabase: Option[String] = None): LogicalPlan = { + val tableIdentWithDb = u.tableIdentifier.copy( + database = u.tableIdentifier.database.orElse(defaultDatabase)) try { - catalog.lookupRelation(u.tableIdentifier, u.alias) + catalog.lookupRelation(tableIdentWithDb) } catch { case _: NoSuchTableException => - u.failAnalysis(s"Table or view not found: ${u.tableName}") + u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}") + // If the database is defined and that database is not found, throw an AnalysisException. + // Note that if the database is not defined, it is possible we are looking up a temp view. + case e: NoSuchDatabaseException => + u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " + + s"database ${e.db} doesn't exsits.") } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => - i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u))) - case u: UnresolvedRelation => - val table = u.tableIdentifier - if (table.database.isDefined && conf.runSQLonFile && !catalog.isTemporaryTable(table) && - (!catalog.databaseExists(table.database.get) || !catalog.tableExists(table))) { - // If the database part is specified, and we support running SQL directly on files, and - // it's not a temporary view, and the table does not exist, then let's just return the - // original UnresolvedRelation. It is possible we are matching a query like "select * - // from parquet.`/path/to/query`". The plan will get resolved later. - // Note that we are testing (!db_exists || !table_exists) because the catalog throws - // an exception from tableExists if the database does not exist. - u - } else { - lookupTableFromCatalog(u) - } + // If the database part is specified, and we support running SQL directly on files, and + // it's not a temporary view, and the table does not exist, then let's just return the + // original UnresolvedRelation. It is possible we are matching a query like "select * + // from parquet.`/path/to/query`". The plan will get resolved in the rule `ResolveDataSource`. + // Note that we are testing (!db_exists || !table_exists) because the catalog throws + // an exception from tableExists if the database does not exist. + private def isRunningDirectlyOnFiles(table: TableIdentifier): Boolean = { + table.database.isDefined && conf.runSQLonFile && !catalog.isTemporaryTable(table) && + (!catalog.databaseExists(table.database.get) || !catalog.tableExists(table)) } } @@ -547,7 +671,9 @@ class Analyzer( * Generate a new logical plan for the right child with different expression IDs * for all conflicting attributes. */ - private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { + private def dedupRight (left: LogicalPlan, oriRight: LogicalPlan): LogicalPlan = { + // Remove analysis barrier if any. 
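+      // The barrier only marks the right subtree as already analyzed; strip it here so that
+      // attribute ids conflicting with the left side can still be detected and rewritten, and
+      // re-attach it to the rewritten plan at the end.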
+ val right = EliminateBarriers(oriRight) val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + s"between $left and $right") @@ -590,7 +716,7 @@ class Analyzer( * that this rule cannot handle. When that is the case, there must be another rule * that resolves these conflicts. Otherwise, the analysis will fail. */ - right + oriRight case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) val newRight = right transformUp { @@ -598,14 +724,73 @@ class Analyzer( } transformUp { case other => other transformExpressions { case a: Attribute => - attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier) + dedupAttr(a, attributeRewrites) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } - newRight + AnalysisBarrier(newRight) + } + } + + private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { + attrMap.get(attr).getOrElse(attr).withQualifier(attr.qualifier) + } + + /** + * The outer plan may have been de-duplicated and the function below updates the + * outer references to refer to the de-duplicated attributes. + * + * For example (SQL): + * {{{ + * SELECT * FROM t1 + * INTERSECT + * SELECT * FROM t1 + * WHERE EXISTS (SELECT 1 + * FROM t2 + * WHERE t1.c1 = t2.c1) + * }}} + * Plan before resolveReference rule. + * 'Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- 'Project [*] + * +- Filter exists#257 [c1#245] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#245) = c1#251) + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#245,c2#246] parquet + * Plan after the resolveReference rule. + * Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- Project [c1#259, c2#260] + * +- Filter exists#257 [c1#259] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#259) = c1#251) => Updated + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated. + */ + private def dedupOuterReferencesInSubquery( + plan: LogicalPlan, + attrMap: AttributeMap[Attribute]): LogicalPlan = { + plan transformDown { case currentFragment => + currentFragment transformExpressions { + case OuterReference(a: Attribute) => + OuterReference(dedupAttr(a, attrMap)) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap)) + } } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -663,11 +848,10 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q transformExpressionsUp { + q.transformExpressionsUp { case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = - withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + // Leave unchanged if resolution fails. Hopefully will be resolved next round. 
+ val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -690,8 +874,8 @@ class Analyzer( * Build a project list for Project/Aggregate and expand the star if possible */ private def buildExpandedProjectList( - exprs: Seq[NamedExpression], - child: LogicalPlan): Seq[NamedExpression] = { + exprs: Seq[NamedExpression], + child: LogicalPlan): Seq[NamedExpression] = { exprs.flatMap { // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*") case s: Star => s.expand(child, resolver) @@ -766,30 +950,30 @@ class Analyzer( } } - /** - * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by - * clauses. This rule is to convert ordinal positions to the corresponding expressions in the - * select list. This support is introduced in Spark 2.0. - * - * - When the sort references or group by expressions are not integer but foldable expressions, - * just ignore them. - * - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the position - * numbers too. - * - * Before the release of Spark 2.0, the literals in order/sort by and group by clauses - * have no effect on the results. - */ + /** + * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by + * clauses. This rule is to convert ordinal positions to the corresponding expressions in the + * select list. This support is introduced in Spark 2.0. + * + * - When the sort references or group by expressions are not integer but foldable expressions, + * just ignore them. + * - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the position + * numbers too. + * + * Before the release of Spark 2.0, the literals in order/sort by and group by clauses + * have no effect on the results. + */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. - case s @ Sort(orders, global, child) + case Sort(orders, global, child) if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { - case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) => + case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) => if (index > 0 && index <= child.output.size) { - SortOrder(child.output(index - 1), direction, nullOrdering) + SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty) } else { s.failAnalysis( s"ORDER BY position $index is not in select list " + @@ -801,17 +985,11 @@ class Analyzer( // Replace the index with the corresponding expression in aggregateExpressions. 
The index is // a 1-base position of aggregateExpressions, which is output columns (select expression) - case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && + case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && groups.exists(_.isInstanceOf[UnresolvedOrdinal]) => val newGroups = groups.map { - case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => - aggs(index - 1) match { - case e if ResolveAggregateFunctions.containsAggregate(e) => - ordinal.failAnalysis( - s"GROUP BY position $index is an aggregate function, and " + - "aggregate functions are not allowed in GROUP BY") - case o => o - } + case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => + aggs(index - 1) case ordinal @ UnresolvedOrdinal(index) => ordinal.failAnalysis( s"GROUP BY position $index is not in select list " + @@ -822,6 +1000,41 @@ class Analyzer( } } + /** + * Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses. + * This rule is expected to run after [[ResolveReferences]] applied. + */ + object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { + + // This is a strict check though, we put this to apply the rule only if the expression is not + // resolvable by child. + private def notResolvableByChild(attrName: String, child: LogicalPlan): Boolean = { + !child.output.exists(a => resolver(a.name, attrName)) + } + + private def mayResolveAttrByAggregateExprs( + exprs: Seq[Expression], aggs: Seq[NamedExpression], child: LogicalPlan): Seq[Expression] = { + exprs.map { _.transform { + case u: UnresolvedAttribute if notResolvableByChild(u.name, child) => + aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u) + }} + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + case agg @ Aggregate(groups, aggs, child) + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(!_.resolved) => + agg.copy(groupingExpressions = mayResolveAttrByAggregateExprs(groups, aggs, child)) + + case gs @ GroupingSets(selectedGroups, groups, child, aggs) + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(_.isInstanceOf[UnresolvedAttribute]) => + gs.copy( + selectedGroupByExprs = selectedGroups.map(mayResolveAttrByAggregateExprs(_, aggs, child)), + groupByExprs = mayResolveAttrByAggregateExprs(groups, aggs, child)) + } + } + /** * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT * clause. This rule detects such queries and adds the required attributes to the original @@ -831,11 +1044,13 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { // Skip sort with aggregate. 
This will be handled in ResolveAggregateFunctions + case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if child.resolved => + case s @ Sort(order, _, orgChild) if !s.resolved && orgChild.resolved => + val child = EliminateBarriers(orgChild) try { val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) @@ -856,7 +1071,8 @@ class Analyzer( case ae: AnalysisException => s } - case f @ Filter(cond, child) if child.resolved => + case f @ Filter(cond, orgChild) if !f.resolved && orgChild.resolved => + val child = EliminateBarriers(orgChild) try { val newCond = resolveExpressionRecursively(cond, child) val requiredAttrs = newCond.references.filter(_.resolved) @@ -883,7 +1099,7 @@ class Analyzer( */ private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { if (missingAttrs.isEmpty) { - return plan + return AnalysisBarrier(plan) } plan match { case p: Project => @@ -932,11 +1148,30 @@ class Analyzer( } } + /** + * Checks whether a function identifier referenced by an [[UnresolvedFunction]] is defined in the + * function registry. Note that this rule doesn't try to resolve the [[UnresolvedFunction]]. It + * only performs simple existence check according to the function identifier to quickly identify + * undefined functions without triggering relation resolution, which may incur potentially + * expensive partition/schema discovery process in some cases. + * + * @see [[ResolveFunctions]] + * @see https://issues.apache.org/jira/browse/SPARK-19737 + */ + object LookupFunctions extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + case f: UnresolvedFunction if !catalog.functionExists(f.name) => + withPosition(f) { + throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName) + } + } + } + /** * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1001,159 +1236,212 @@ class Analyzer( } /** - * Pull out all (outer) correlated predicates from a given subquery. This method removes the - * correlated predicates from subquery [[Filter]]s and adds the references of these predicates - * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to - * be able to evaluate the predicates at the top level. - * - * This method returns the rewritten subquery and correlated predicates. + * Validates to make sure the outer references appearing inside the subquery + * are legal. This function also returns the list of expressions + * that contain outer references. These outer references would be kept as children + * of subquery expressions by the caller of this function. */ - private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = { - val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] - - /** Make sure a plans' subtree does not contain a tagged predicate. 
*/ - def failOnOuterReferenceInSubTree(p: LogicalPlan, msg: String): Unit = { - if (p.collect(predicateMap).nonEmpty) { - failAnalysis(s"Accessing outer query column is not allowed in $msg: $p") + private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = { + val outerReferences = ArrayBuffer.empty[Expression] + + // Validate that correlated aggregate expression do not contain a mixture + // of outer and local references. + def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { + expr.foreach { + case a: AggregateExpression if containsOuter(a) => + val outer = a.collect { case OuterReference(e) => e.toAttribute } + val local = a.references -- outer + if (local.nonEmpty) { + val msg = + s""" + |Found an aggregate expression in a correlated predicate that has both + |outer and local references, which is not supported yet. + |Aggregate expression: ${SubExprUtils.stripOuterReference(a).sql}, + |Outer references: ${outer.map(_.sql).mkString(", ")}, + |Local references: ${local.map(_.sql).mkString(", ")}. + """.stripMargin.replace("\n", " ").trim() + failAnalysis(msg) + } + case _ => } } - /** Helper function for locating outer references. */ - def containsOuter(e: Expression): Boolean = { - e.find(_.isInstanceOf[OuterReference]).isDefined + // Make sure a plan's subtree does not contain outer references + def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { + if (hasOuterReferences(p)) { + failAnalysis(s"Accessing outer query column is not allowed in:\n$p") + } } - /** Make sure a plans' expressions do not contain a tagged predicate. */ - def failOnOuterReference(p: LogicalPlan): Unit = { - if (p.expressions.exists(containsOuter)) { + // Make sure a plan's expressions do not contain : + // 1. Aggregate expressions that have mixture of outer and local references. + // 2. Expressions containing outer references on plan nodes other than Filter. + def failOnInvalidOuterReference(p: LogicalPlan): Unit = { + p.expressions.foreach(checkMixedReferencesInsideAggregateExpr) + if (!p.isInstanceOf[Filter] && p.expressions.exists(containsOuter)) { failAnalysis( - s"Correlated predicates are not supported outside of WHERE/HAVING clauses: $p") + "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + + s"clauses:\n$p") } } - /** Determine which correlated predicate references are missing from this plan. */ - def missingReferences(p: LogicalPlan): AttributeSet = { - val localPredicateReferences = p.collect(predicateMap) - .flatten - .map(_.references) - .reduceOption(_ ++ _) - .getOrElse(AttributeSet.empty) - localPredicateReferences -- p.outputSet + // SPARK-17348: A potential incorrect result case. + // When a correlated predicate is a non-equality predicate, + // certain operators are not permitted from the operator + // hosting the correlated predicate up to the operator on the outer table. + // Otherwise, the pull up of the correlated predicate + // will generate a plan with a different semantics + // which could return incorrect result. + // Currently we check for Aggregate and Window operators + // + // Below shows an example of a Logical Plan during Analyzer phase that + // show this problem. Pulling the correlated predicate [outer(c2#77) >= ..] + // through the Aggregate (or Window) operator could alter the result of + // the Aggregate. 
+ // + // Project [c1#76] + // +- Project [c1#87, c2#88] + // : (Aggregate or Window operator) + // : +- Filter [outer(c2#77) >= c2#88)] + // : +- SubqueryAlias t2, `t2` + // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] + // : +- LocalRelation [_1#84, _2#85] + // +- SubqueryAlias t1, `t1` + // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] + // +- LocalRelation [_1#73, _2#74] + def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { + if (found) { + // Report a non-supported case as an exception + failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") + } } - // Simplify the predicates before pulling them out. - val transformed = BooleanSimplification(sub) transformUp { - case f @ Filter(cond, child) => + var foundNonEqualCorrelatedPred : Boolean = false + + // Simplify the predicates before validating any unsupported correlation patterns + // in the plan. + BooleanSimplification(sub).foreachUp { + + // Whitelist operators allowed in a correlated subquery + // There are 4 categories: + // 1. Operators that are allowed anywhere in a correlated subquery, and, + // by definition of the operators, they either do not contain + // any columns or cannot host outer references. + // 2. Operators that are allowed anywhere in a correlated subquery + // so long as they do not host outer references. + // 3. Operators that need special handlings. These operators are + // Project, Filter, Join, Aggregate, and Generate. + // + // Any operators that are not in the above list are allowed + // in a correlated subquery only if they are not on a correlation path. + // In other word, these operators are allowed only under a correlation point. + // + // A correlation path is defined as the sub-tree of all the operators that + // are on the path from the operator hosting the correlated expressions + // up to the operator producing the correlated values. + + // Category 1: + // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias + case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => + + // Category 2: + // These operators can be anywhere in a correlated subquery. + // so long as they do not host outer references in the operators. + case s: Sort => + failOnInvalidOuterReference(s) + case r: RepartitionByExpression => + failOnInvalidOuterReference(r) + + // Category 3: + // Filter is one of the two operators allowed to host correlated expressions. + // The other operator is Join. Filter can be anywhere in a correlated subquery. + case f: Filter => // Find all predicates with an outer reference. - val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) - - // Rewrite the filter without the correlated predicates if any. 
- correlated match { - case Nil => f - case xs if local.nonEmpty => - val newFilter = Filter(local.reduce(And), child) - predicateMap += newFilter -> xs - newFilter - case xs => - predicateMap += child -> xs - child - } - case p @ Project(expressions, child) => - failOnOuterReference(p) - val referencesToAdd = missingReferences(p) - if (referencesToAdd.nonEmpty) { - Project(expressions ++ referencesToAdd, child) - } else { - p + val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) + + // Find any non-equality correlated predicates + foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { + case _: EqualTo | _: EqualNullSafe => false + case _ => true } - case a @ Aggregate(grouping, expressions, child) => - failOnOuterReference(a) - val referencesToAdd = missingReferences(a) - if (referencesToAdd.nonEmpty) { - Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) - } else { - a + + failOnInvalidOuterReference(f) + // The aggregate expressions are treated in a special way by getOuterReferences. If the + // aggregate expression contains only outer reference attributes then the entire aggregate + // expression is isolated as an OuterReference. + // i.e min(OuterReference(b)) => OuterReference(min(b)) + outerReferences ++= getOuterReferences(correlated) + + // Project cannot host any correlated expressions + // but can be anywhere in a correlated subquery. + case p: Project => + failOnInvalidOuterReference(p) + + // Aggregate cannot host any correlated expressions + // It can be on a correlation path if the correlation contains + // only equality correlated predicates. + // It cannot be on a correlation path if the correlation has + // non-equality correlated predicates. + case a: Aggregate => + failOnInvalidOuterReference(a) + failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) + + // Join can host correlated expressions. + case j @ Join(left, right, joinType, _) => + joinType match { + // Inner join, like Filter, can be anywhere. + case _: InnerLike => + failOnInvalidOuterReference(j) + + // Left outer join's right operand cannot be on a correlation path. + // LeftAnti and ExistenceJoin are special cases of LeftOuter. + // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame + // so it should not show up here in Analysis phase. This is just a safety net. + // + // LeftSemi does not allow output from the right operand. + // Any correlated references in the subplan + // of the right operand cannot be pulled up. + case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => + failOnInvalidOuterReference(j) + failOnOuterReferenceInSubTree(right) + + // Likewise, Right outer join's left operand cannot be on a correlation path. + case RightOuter => + failOnInvalidOuterReference(j) + failOnOuterReferenceInSubTree(left) + + // Any other join types not explicitly listed above, + // including Full outer join, are treated as Category 4. 
+ case _ => + failOnOuterReferenceInSubTree(j) } - case j @ Join(left, _, RightOuter, _) => - failOnOuterReference(j) - failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN") - j - case j @ Join(_, right, jt, _) if !jt.isInstanceOf[InnerLike] => - failOnOuterReference(j) - failOnOuterReferenceInSubTree(right, "a LEFT (OUTER) JOIN") - j - case u: Union => - failOnOuterReferenceInSubTree(u, "a UNION") - u - case s: SetOperation => - failOnOuterReferenceInSubTree(s.right, "an INTERSECT/EXCEPT") - s - case e: Expand => - failOnOuterReferenceInSubTree(e, "an EXPAND") - e - case l : LocalLimit => - failOnOuterReferenceInSubTree(l, "a LIMIT") - l - // Since LIMIT is represented as GlobalLimit(, (LocalLimit (, child)) - // and we are walking bottom up, we will fail on LocalLimit before - // reaching GlobalLimit. - // The code below is just a safety net. - case g : GlobalLimit => - failOnOuterReferenceInSubTree(g, "a LIMIT") - g - case s : Sample => - failOnOuterReferenceInSubTree(s, "a TABLESAMPLE") - s - case p => - failOnOuterReference(p) - p - } - (transformed, predicateMap.values.flatten.toSeq) - } - /** - * Rewrite the subquery in a safe way by preventing that the subquery and the outer use the same - * attributes. - */ - private def rewriteSubQuery( - sub: LogicalPlan, - outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { - // Pull out the tagged predicates and rewrite the subquery in the process. - val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub) - - // Make sure the inner and the outer query attributes do not collide. - val outputSet = outer.map(_.outputSet).reduce(_ ++ _) - val duplicates = basePlan.outputSet.intersect(outputSet) - val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { - val aliasMap = AttributeMap(duplicates.map { dup => - dup -> Alias(dup, dup.toString)() - }.toSeq) - val aliasedExpressions = basePlan.output.map { ref => - aliasMap.getOrElse(ref, ref) - } - val aliasedProjection = Project(aliasedExpressions, basePlan) - val aliasedConditions = baseConditions.map(_.transform { - case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute - }) - (aliasedProjection, aliasedConditions) - } else { - (basePlan, baseConditions) + // Generator with join=true, i.e., expressed with + // LATERAL VIEW [OUTER], similar to inner join, + // allows to have correlation under it + // but must not host any outer references. + // Note: + // Generator with join=false is treated as Category 4. + case g: Generate if g.join => + failOnInvalidOuterReference(g) + + // Category 4: Any other operators not in the above 3 categories + // cannot be on a correlation path, that is they are allowed only + // under a correlation point but they and their descendant operators + // are not allowed to have any correlated expressions. + case p => + failOnOuterReferenceInSubTree(p) } - // Remove outer references from the correlated predicates. We wait with extracting - // these until collisions between the inner and outer query attributes have been - // solved. - val conditions = deDuplicatedConditions.map(_.transform { - case OuterReference(ref) => ref - }) - (plan, conditions) + outerReferences } /** - * Resolve and rewrite a subquery. The subquery is resolved using its outer plans. This method + * Resolves the subquery. The subquery is resolved using its outer plans. This method * will resolve the subquery by alternating between the regular analyzer and by applying the * resolveOuterReferences rule. 
* - * All correlated conditions are pulled out of the subquery as soon as the subquery is resolved. + * Outer references from the correlated predicates are updated as children of + * Subquery expression. */ private def resolveSubQuery( e: SubqueryExpression, @@ -1176,7 +1464,8 @@ class Analyzer( } } while (!current.resolved && !current.fastEquals(previous)) - // Step 2: Pull out the predicates if the plan is resolved. + // Step 2: If the subquery plan is fully resolved, pull the outer references and record + // them as children of SubqueryExpression. if (current.resolved) { // Make sure the resolved query has the required number of output columns. This is only // needed for Scalar and IN subqueries. @@ -1184,44 +1473,44 @@ class Analyzer( failAnalysis(s"The number of columns in the subquery (${current.output.size}) " + s"does not match the required number of columns ($requiredColumns)") } - // Pullout predicates and construct a new plan. - f.tupled(rewriteSubQuery(current, plans)) + // Validate the outer reference and record the outer references as children of + // subquery expression. + f(current, checkAndGetOuterReferences(current)) } else { e.withNewPlan(current) } } /** - * Resolve and rewrite all subqueries in a LogicalPlan. This method transforms IN and EXISTS - * expressions into PredicateSubquery expression once the are resolved. + * Resolves the subquery. Apart of resolving the subquery and outer references (if any) + * in the subquery plan, the children of subquery expression are updated to record the + * outer references. This is needed to make sure + * (1) The column(s) referred from the outer query are not pruned from the plan during + * optimization. + * (2) Any aggregate expression(s) that reference outer attributes are pushed down to + * outer plan to get evaluated. */ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { - case s @ ScalarSubquery(sub, conditions, exprId) - if sub.resolved && conditions.isEmpty && sub.output.size != 1 => - failAnalysis(s"Scalar subquery must return only one column, but got ${sub.output.size}") case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) - case e @ Exists(sub, exprId) => - resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId)) - case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => + case e @ Exists(sub, _, exprId) if !sub.resolved => + resolveSubQuery(e, plans)(Exists(_, _, exprId)) + case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved => // Get the left hand side expressions. - val expressions = e match { + val expressions = value match { case cns : CreateNamedStruct => cns.valExprs case expr => Seq(expr) } - resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => - // Construct the IN conditions. - val inConditions = expressions.zip(rewrite.output).map(EqualTo.tupled) - PredicateSubquery(rewrite, inConditions ++ conditions, nullAware = true, exprId) - } + val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId)) + In(value, Seq(expr)) } } /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. 
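+      // e.g. for `... GROUP BY a HAVING sum(b) > (SELECT avg(b) FROM t2 WHERE t2.a = t1.a)`
+      // the correlated `t1.a` may need to be resolved against either of the two plans.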
case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1236,7 +1525,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1262,10 +1551,12 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) => + apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier) case filter @ Filter(havingCondition, - aggregate @ Aggregate(grouping, originalAggExprs, child)) - if aggregate.resolved => + aggregate @ Aggregate(grouping, originalAggExprs, child)) + if aggregate.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause try { @@ -1292,8 +1583,8 @@ class Analyzer( alias.toAttribute // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]]. case e: Expression if grouping.exists(_.semanticEquals(e)) && - !ResolveGroupingAnalytics.hasGroupingFunction(e) && - !aggregate.output.exists(_.semanticEquals(e)) => + !ResolveGroupingAnalytics.hasGroupingFunction(e) && + !aggregate.output.exists(_.semanticEquals(e)) => e match { case ne: NamedExpression => aggregateExpressions += ne @@ -1322,6 +1613,8 @@ class Analyzer( case ae: AnalysisException => filter } + case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => + apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. @@ -1418,16 +1711,23 @@ class Analyzer( case _ => expr } - /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */ private object AliasedGenerator { - def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { - case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil)) - case MultiAlias(g: Generator, names) if g.resolved => Some(g, names) + /** + * Extracts a [[Generator]] expression, any names assigned by aliases to the outputs + * and the outer flag. The outer flag is used when joining the generator output. 
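For context, the Filter-over-Aggregate shape handled here is what a HAVING clause produces; a minimal sketch, assuming a SparkSession `spark` and a table t(key INT, v INT):

    // The HAVING condition `count(*) > 1` is resolved as though it appeared in the
    // aggregate and is then projected away again by ResolveAggregateFunctions.
    spark.sql("SELECT key, count(*) AS cnt FROM t GROUP BY key HAVING count(*) > 1")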
+ * @param e the [[Expression]] + * @return (the [[Generator]], seq of output names, outer flag) + */ + def unapply(e: Expression): Option[(Generator, Seq[String], Boolean)] = e match { + case Alias(GeneratorOuter(g: Generator), name) if g.resolved => Some((g, name :: Nil, true)) + case MultiAlias(GeneratorOuter(g: Generator), names) if g.resolved => Some(g, names, true) + case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil, false)) + case MultiAlias(g: Generator, names) if g.resolved => Some(g, names, false) case _ => None } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1443,7 +1743,7 @@ class Analyzer( var resolvedGenerator: Generate = null val newProjectList = projectList.flatMap { - case AliasedGenerator(generator, names) if generator.childrenResolved => + case AliasedGenerator(generator, names, outer) if generator.childrenResolved => // It's a sanity check, this should not happen as the previous case will throw // exception earlier. assert(resolvedGenerator == null, "More than one generator found in SELECT.") @@ -1452,7 +1752,7 @@ class Analyzer( Generate( generator, join = projectList.size > 1, // Only join if there are other expressions in SELECT. - outer = false, + outer = outer, qualifier = None, generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), child) @@ -1485,7 +1785,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -1496,8 +1796,8 @@ class Analyzer( * names is empty names are assigned from field names in generator. */ private[analysis] def makeGeneratorOutput( - generator: Generator, - names: Seq[String]): Seq[Attribute] = { + generator: Generator, + names: Seq[String]): Seq[Attribute] = { val elementAttrs = generator.elementSchema.toAttributes if (names.length == elementAttrs.length) { @@ -1509,8 +1809,8 @@ class Analyzer( } else { failAnalysis( "The number of aliases supplied in the AS clause does not match the number of columns " + - s"output by the UDTF expected ${elementAttrs.size} aliases but got " + - s"${names.mkString(",")} ") + s"output by the UDTF expected ${elementAttrs.size} aliases but got " + + s"${names.mkString(",")} ") } } } @@ -1651,8 +1951,8 @@ class Analyzer( // Extract Windowed AggregateExpression case we @ WindowExpression( - ae @ AggregateExpression(function, _, _, _), - spec: WindowSpecDefinition) => + ae @ AggregateExpression(function, _, _, _), + spec: WindowSpecDefinition) => val newChildren = function.children.map(extractExpr) val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] val newAgg = ae.copy(aggregateFunction = newFunction) @@ -1679,8 +1979,8 @@ class Analyzer( * Adds operators for Window Expressions. Every Window operator handles a single Window Spec. 
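A sketch of the outer flag extracted here, assuming a table logs(id INT, tags ARRAY<STRING>) in which some rows carry an empty array, and assuming the *_outer SQL generator variants are parsed into GeneratorOuter wrappers:

    // outer = false: rows whose tags array is empty produce no output rows.
    spark.sql("SELECT id, explode(tags) AS tag FROM logs")
    // explode_outer is extracted by the pattern above with outer = true,
    // so such rows are kept with tag = NULL.
    spark.sql("SELECT id, explode_outer(tags) AS tag FROM logs")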
*/ private def addWindow( - expressionsWithWindowFunctions: Seq[NamedExpression], - child: LogicalPlan): LogicalPlan = { + expressionsWithWindowFunctions: Seq[NamedExpression], + child: LogicalPlan): LogicalPlan = { // First, we need to extract all WindowExpressions from expressionsWithWindowFunctions // and put those extracted WindowExpressions to extractedWindowExprBuffer. // This step is needed because it is possible that an expression contains multiple @@ -1751,8 +2051,8 @@ class Analyzer( // a resolved Aggregate will not have Window Functions. case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) if child.resolved && - hasWindowFunction(aggregateExprs) && - a.expressions.forall(_.resolved) => + hasWindowFunction(aggregateExprs) && + a.expressions.forall(_.resolved) => val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) // Create an Aggregate operator to evaluate aggregation functions. val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) @@ -1769,7 +2069,7 @@ class Analyzer( // Aggregate without Having clause. case a @ Aggregate(groupingExprs, aggregateExprs, child) if hasWindowFunction(aggregateExprs) && - a.expressions.forall(_.resolved) => + a.expressions.forall(_.resolved) => val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) // Create an Aggregate operator to evaluate aggregation functions. val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) @@ -1802,33 +2102,42 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f + case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) => + val nondeterToAttr = getNondeterToAttr(a.groupingExpressions) + val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child) + a.transformExpressions { case e => + nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e) + }.copy(child = newChild) + // todo: It's hard to write a general rule to pull out nondeterministic expressions // from LogicalPlan, currently we only do it for UnaryNode which has same output // schema with its child. 
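A sketch of the new Aggregate case, assuming an active SparkSession `spark`: a nondeterministic grouping expression is made safe by materializing its nondeterministic leaves once per input row in a Project below the Aggregate, so the grouping key and any later references resolve to the same attribute.

    import org.apache.spark.sql.functions.rand

    // rand(42) is nondeterministic; it is aliased in a child Project and the
    // Aggregate then groups on the resulting attribute instead.
    spark.range(100)
      .groupBy((rand(42) * 3).cast("int").as("bucket"))
      .count()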
case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) => - val nondeterministicExprs = p.expressions.filterNot(_.deterministic).flatMap { expr => - val leafNondeterministic = expr.collect { - case n: Nondeterministic => n - } - leafNondeterministic.map { e => - val ne = e match { - case n: NamedExpression => n - case _ => Alias(e, "_nondeterministic")(isGenerated = true) - } - new TreeNodeRef(e) -> ne - } - }.toMap + val nondeterToAttr = getNondeterToAttr(p.expressions) val newPlan = p.transformExpressions { case e => - nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e) + nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e) } - val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child) + val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child) Project(p.output, newPlan.withNewChildren(newChild :: Nil)) } + + private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = { + exprs.filterNot(_.deterministic).flatMap { expr => + val leafNondeterministic = expr.collect { case n: Nondeterministic => n } + leafNondeterministic.distinct.map { e => + val ne = e match { + case n: NamedExpression => n + case _ => Alias(e, "_nondeterministic")(isGenerated = true) + } + e -> ne + } + }.toMap + } } /** @@ -1838,12 +2147,12 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { - case udf @ ScalaUDF(func, _, inputs, _) => + case udf @ ScalaUDF(func, _, inputs, _, _, _) => val parameterTypes = ScalaReflection.getParameterTypes(func) assert(parameterTypes.length == inputs.length) @@ -1903,18 +2212,10 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) - if left.resolved && right.resolved && j.duplicateResolved => - // Resolve the column names referenced in using clause from both the legs of join. 
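As a reminder of what HandleNullInputsForUDF guards (a sketch, assuming a SparkSession `spark` with spark.implicits._ in scope): a Scala UDF over a primitive parameter cannot observe SQL NULL, so the rule wraps the call to return null for a null input instead of feeding the UDF a default value.

    import org.apache.spark.sql.functions.udf
    import spark.implicits._

    val plusOne = udf((x: Int) => x + 1)      // Int is a primitive parameter type
    val df = Seq(Some(1), None).toDF("v")     // "v" is a nullable INT column
    // With the null check in place, the None row yields NULL rather than the
    // UDF being invoked with 0.
    df.select(plusOne($"v").as("v_plus_one")).show()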
- val lCols = usingCols.flatMap(col => left.resolveQuoted(col.name, resolver)) - val rCols = usingCols.flatMap(col => right.resolveQuoted(col.name, resolver)) - if ((lCols.length == usingCols.length) && (rCols.length == usingCols.length)) { - val joinNames = lCols.map(exp => exp.name) - commonNaturalJoinProcessing(left, right, joinType, joinNames, None) - } else { - j - } + if left.resolved && right.resolved && j.duplicateResolved => + commonNaturalJoinProcessing(left, right, joinType, usingCols, None) case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => // find common column names from both sides val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) @@ -1929,18 +2230,16 @@ class Analyzer( joinNames: Seq[String], condition: Option[Expression]) = { val leftKeys = joinNames.map { keyName => - val joinColumn = left.output.find(attr => resolver(attr.name, keyName)) - assert( - joinColumn.isDefined, - s"$keyName should exist in ${left.output.map(_.name).mkString(",")}") - joinColumn.get + left.output.find(attr => resolver(attr.name, keyName)).getOrElse { + throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the left " + + s"side of the join. The left-side columns: [${left.output.map(_.name).mkString(", ")}]") + } } val rightKeys = joinNames.map { keyName => - val joinColumn = right.output.find(attr => resolver(attr.name, keyName)) - assert( - joinColumn.isDefined, - s"$keyName should exist in ${right.output.map(_.name).mkString(",")}") - joinColumn.get + right.output.find(attr => resolver(attr.name, keyName)).getOrElse { + throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the right " + + s"side of the join. The right-side columns: [${right.output.map(_.name).mkString(", ")}]") + } } val joinPairs = leftKeys.zip(rightKeys) @@ -1978,7 +2277,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -1993,8 +2292,21 @@ class Analyzer( validateTopLevelTupleFields(deserializer, inputs) val resolved = resolveExpression( deserializer, LocalRelation(inputs), throws = true) - validateNestedTupleFields(resolved) - resolved + val result = resolved transformDown { + case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved => + inputData.dataType match { + case ArrayType(et, cn) => + val expr = MapObjects(func, inputData, et, cn, cls) transformUp { + case UnresolvedExtractValue(child, fieldName) if child.resolved => + ExtractValue(child, fieldName, resolver) + } + expr + case other => + throw new AnalysisException("need an array field but got " + other.simpleString) + } + } + validateNestedTupleFields(result) + result } } @@ -2051,7 +2363,7 @@ class Analyzer( * constructed is an inner class. 
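The new error path above is easy to hit; a sketch, assuming a SparkSession `spark` and tables t1(id INT, x INT) and t2(id INT, y INT):

    // Resolves: `id` exists on both sides, becomes the join key, and is
    // emitted once in the output.
    spark.sql("SELECT * FROM t1 JOIN t2 USING (id)")
    // Fails with the AnalysisException introduced above, because `y` cannot be
    // resolved on the left side of the join.
    spark.sql("SELECT * FROM t1 JOIN t2 USING (y)")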
*/ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2074,39 +2386,29 @@ class Analyzer( */ object ResolveUpCast extends Rule[LogicalPlan] { private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { - throw new AnalysisException(s"Cannot up cast ${from.sql} from " + + val fromStr = from match { + case l: LambdaVariable => "array element" + case e => e.sql + } + throw new AnalysisException(s"Cannot up cast $fromStr from " + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + "You can either add an explicit cast to the input data or choose a higher precision " + "type of the field in the target object") } - private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { - val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from) - val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to) - toPrecedence > 0 && fromPrecedence > toPrecedence - } - - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p case p if p.resolved => p case p => p transformExpressions { case u @ UpCast(child, _, _) if !child.resolved => u - case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match { - case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => - fail(child, to, walkedTypePath) - case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => - fail(child, to, walkedTypePath) - case (from, to) if illegalNumericPrecedence(from, to) => - fail(child, to, walkedTypePath) - case (TimestampType, DateType) => - fail(child, DateType, walkedTypePath) - case (StringType, to: NumericType) => - fail(child, to, walkedTypePath) - case _ => Cast(child, dataType.asNullable) - } + case UpCast(child, dataType, walkedTypePath) + if Cast.mayTruncate(child.dataType, dataType) => + fail(child, dataType, walkedTypePath) + + case UpCast(child, dataType, walkedTypePath) => Cast(child, dataType.asNullable) } } } @@ -2118,7 +2420,7 @@ class Analyzer( */ object EliminateSubqueryAliases extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case SubqueryAlias(_, child, _) => child + case SubqueryAlias(_, child) => child } } @@ -2149,7 +2451,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2178,6 +2480,23 @@ object CleanupAliases extends Rule[LogicalPlan] { } } +/** Remove the barrier nodes of analysis */ +object EliminateBarriers extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case AnalysisBarrier(child) => child + } +} + +/** + * Ignore event time watermark in batch query, which is only supported in Structured Streaming. + * TODO: add this rule into analyzer rule list. 
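The up-cast failure above is most commonly seen when mapping a DataFrame onto a narrower case class; a sketch, assuming a SparkSession with spark.implicits._ in scope:

    case class Rec(i: Int)
    val longs = Seq(1L, 2L).toDF("i")
    // Cast.mayTruncate(LongType, IntegerType) is true, so analysis fails with the
    // "Cannot up cast ... from bigint to int as it may truncate" error shown above.
    longs.as[Rec]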
+ */ +object EliminateEventTimeWatermark extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case EventTimeWatermark(_, _, child) if !child.isStreaming => child + } +} + /** * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to * figure out how many windows a time column can map to, we over-estimate the number of windows and @@ -2217,7 +2536,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = @@ -2225,29 +2544,35 @@ object TimeWindowing extends Rule[LogicalPlan] { // Only support a single window expression for now if (windowExpressions.size == 1 && - windowExpressions.head.timeColumn.resolved && - windowExpressions.head.checkInputDataTypes().isSuccess) { + windowExpressions.head.timeColumn.resolved && + windowExpressions.head.checkInputDataTypes().isSuccess) { val window = windowExpressions.head - val windowAttr = AttributeReference("window", window.dataType)() + + val metadata = window.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + val windowAttr = + AttributeReference("window", window.dataType, metadata = metadata)() val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt val windows = Seq.tabulate(maxNumOverlapping + 1) { i => val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) / window.slideDuration) val windowStart = (windowId + i - maxNumOverlapping) * - window.slideDuration + window.startTime + window.slideDuration + window.startTime val windowEnd = windowStart + window.windowDuration CreateNamedStruct( Literal(WINDOW_START) :: windowStart :: - Literal(WINDOW_END) :: windowEnd :: Nil) + Literal(WINDOW_END) :: windowEnd :: Nil) } val projections = windows.map(_ +: p.children.head.output) val filterExpr = window.timeColumn >= windowAttr.getField(WINDOW_START) && - window.timeColumn < windowAttr.getField(WINDOW_END) + window.timeColumn < windowAttr.getField(WINDOW_END) val expandedPlan = Filter(filterExpr, @@ -2260,7 +2585,7 @@ object TimeWindowing extends Rule[LogicalPlan] { substitutedPlan.withNewChildren(expandedPlan :: Nil) } else if (windowExpressions.size > 1) { p.failAnalysis("Multiple time window expressions would result in a cartesian product " + - "of rows, therefore they are not currently not supported.") + "of rows, therefore they are currently not supported.") } else { p // Return unchanged. Analyzer will throw exception later } @@ -2282,3 +2607,67 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { CreateNamedStruct(children.toList) } } + +/** + * The aggregate expressions from subquery referencing outer query block are pushed + * down to the outer query block for evaluation. This rule below updates such outer references + * as AttributeReference referring attributes from the parent/outer query block. + * + * For example (SQL): + * {{{ + * SELECT l.a FROM l GROUP BY 1 HAVING EXISTS (SELECT 1 FROM r WHERE r.d < min(l.b)) + * }}} + * Plan before the rule. 
+ * Project [a#226] + * +- Filter exists#245 [min(b#227)#249] + * : +- Project [1 AS 1#247] + * : +- Filter (d#238 < min(outer(b#227))) <----- + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] + * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] + * Plan after the rule. + * Project [a#226] + * +- Filter exists#245 [min(b#227)#249] + * : +- Project [1 AS 1#247] + * : +- Filter (d#238 < outer(min(b#227)#249)) <----- + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] + * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] + */ +object UpdateOuterReferences extends Rule[LogicalPlan] { + private def stripAlias(expr: Expression): Expression = expr match { case a: Alias => a.child } + + private def updateOuterReferenceInSubquery( + plan: LogicalPlan, + refExprs: Seq[Expression]): LogicalPlan = { + plan transformAllExpressions { case e => + val outerAlias = + refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e))) + outerAlias match { + case Some(a: Alias) => OuterReference(a.toAttribute) + case _ => e + } + } + } + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transform { + case f @ Filter(_, a: Aggregate) if f.resolved => + f transformExpressions { + case s: SubqueryExpression if s.children.nonEmpty => + // Collect the aliases from output of aggregate. + val outerAliases = a.aggregateExpressions collect { case a: Alias => a } + // Update the subquery plan to record the OuterReference to point to outer query plan. + s.withNewPlan(updateOuterReferenceInSubquery(s.plan, outerAliases)) + } + } + } +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 18c5d4ac5b2b0..d26bd785970ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -85,7 +85,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitResetConfiguration( - ctx: ResetConfigurationContext): LogicalPlan = withOrigin(ctx) { + ctx: ResetConfigurationContext): LogicalPlan = withOrigin(ctx) { ResetCommand } @@ -178,7 +178,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitShowTblProperties( - ctx: ShowTblPropertiesContext): LogicalPlan = withOrigin(ctx) { + ctx: ShowTblPropertiesContext): LogicalPlan = withOrigin(ctx) { ShowTablePropertiesCommand( visitTableIdentifier(ctx.tableIdentifier), Option(ctx.key).map(visitTablePropertyKey)) @@ -337,7 +337,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * Validate a create table statement and return the [[TableIdentifier]]. */ override def visitCreateTableHeader( - ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { + ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { val temporary = ctx.TEMPORARY != null val ifNotExists = ctx.EXISTS != null if (temporary && ifNotExists) { @@ -443,7 +443,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * Creates a [[CreateTempViewUsing]] logical plan. 
*/ override def visitCreateTempViewUsing( - ctx: CreateTempViewUsingContext): LogicalPlan = withOrigin(ctx) { + ctx: CreateTempViewUsingContext): LogicalPlan = withOrigin(ctx) { CreateTempViewUsing( tableIdent = visitTableIdentifier(ctx.tableIdentifier()), userSpecifiedSchema = Option(ctx.colTypeList()).map(createSchema), @@ -505,7 +505,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. */ override def visitTablePropertyList( - ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { val properties = ctx.tableProperty.asScala.map { property => val key = visitTablePropertyKey(property.key) val value = visitTablePropertyValue(property.value) @@ -598,7 +598,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitSetDatabaseProperties( - ctx: SetDatabasePropertiesContext): LogicalPlan = withOrigin(ctx) { + ctx: SetDatabasePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterDatabasePropertiesCommand( ctx.identifier.getText, visitPropertyKeyValues(ctx.tablePropertyList)) @@ -799,7 +799,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitSetTableProperties( - ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { + ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterTableSetPropertiesCommand( visitTableIdentifier(ctx.tableIdentifier), visitPropertyKeyValues(ctx.tablePropertyList), @@ -816,7 +816,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitUnsetTableProperties( - ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { + ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterTableUnsetPropertiesCommand( visitTableIdentifier(ctx.tableIdentifier), visitPropertyKeys(ctx.tablePropertyList), @@ -855,7 +855,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * is associated with physical tables */ override def visitAddTablePartition( - ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { + ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { if (ctx.VIEW != null) { operationNotAllowed("ALTER VIEW ... ADD PARTITION", ctx) } @@ -886,7 +886,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitRenameTablePartition( - ctx: RenameTablePartitionContext): LogicalPlan = withOrigin(ctx) { + ctx: RenameTablePartitionContext): LogicalPlan = withOrigin(ctx) { AlterTableRenamePartitionCommand( visitTableIdentifier(ctx.tableIdentifier), visitNonOptionalPartitionSpec(ctx.from), @@ -906,7 +906,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * is associated with physical tables */ override def visitDropTablePartitions( - ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) { + ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) { if (ctx.VIEW != null) { operationNotAllowed("ALTER VIEW ... 
DROP PARTITION", ctx) } @@ -927,7 +927,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitRecoverPartitions( - ctx: RecoverPartitionsContext): LogicalPlan = withOrigin(ctx) { + ctx: RecoverPartitionsContext): LogicalPlan = withOrigin(ctx) { AlterTableRecoverPartitionsCommand(visitTableIdentifier(ctx.tableIdentifier)) } @@ -1005,7 +1005,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * Convert a nested constants list into a sequence of string sequences. */ override def visitNestedConstantList( - ctx: NestedConstantListContext): Seq[Seq[String]] = withOrigin(ctx) { + ctx: NestedConstantListContext): Seq[Seq[String]] = withOrigin(ctx) { ctx.constantList.asScala.map(visitConstantList) } @@ -1213,7 +1213,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * Format: STORED AS ... */ override def visitCreateFileFormat( - ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { (ctx.fileFormat, ctx.storageHandler) match { // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format case (c: TableFileFormatContext, null) => @@ -1232,7 +1232,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * Create a [[CatalogStorageFormat]]. */ override def visitTableFileFormat( - ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { CatalogStorageFormat.empty.copy( inputFormat = Option(string(ctx.inFmt)), outputFormat = Option(string(ctx.outFmt))) @@ -1242,7 +1242,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * Resolve a [[HiveSerDe]] based on the name given and return it as a [[CatalogStorageFormat]]. */ override def visitGenericFileFormat( - ctx: GenericFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + ctx: GenericFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { val source = ctx.identifier.getText HiveSerDe.sourceToSerDe(source) match { case Some(s) => @@ -1284,7 +1284,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * Create SERDE row format name and properties pair. */ override def visitRowFormatSerde( - ctx: RowFormatSerdeContext): CatalogStorageFormat = withOrigin(ctx) { + ctx: RowFormatSerdeContext): CatalogStorageFormat = withOrigin(ctx) { import ctx._ CatalogStorageFormat.empty.copy( serde = Option(string(name)), @@ -1295,7 +1295,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * Create a delimited row format properties object. */ override def visitRowFormatDelimited( - ctx: RowFormatDelimitedContext): CatalogStorageFormat = withOrigin(ctx) { + ctx: RowFormatDelimitedContext): CatalogStorageFormat = withOrigin(ctx) { // Collect the entries if any. def entry(key: String, value: Token): Seq[(String, String)] = { Option(value).toSeq.map(x => key -> string(x)) @@ -1329,9 +1329,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * ROW FORMAT ... STORED AS INPUTFORMAT ... OUTPUTFORMAT ... 
*/ private def validateRowFormatFileFormat( - rowFormatCtx: RowFormatContext, - createFileFormatCtx: CreateFileFormatContext, - parentCtx: ParserRuleContext): Unit = { + rowFormatCtx: RowFormatContext, + createFileFormatCtx: CreateFileFormatContext, + parentCtx: ParserRuleContext): Unit = { if (rowFormatCtx == null || createFileFormatCtx == null) { return } @@ -1433,12 +1433,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * Create a [[ScriptInputOutputSchema]]. */ override protected def withScriptIOSchema( - ctx: QuerySpecificationContext, - inRowFormat: RowFormatContext, - recordWriter: Token, - outRowFormat: RowFormatContext, - recordReader: Token, - schemaLess: Boolean): ScriptInputOutputSchema = { + ctx: QuerySpecificationContext, + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = { if (recordWriter != null || recordReader != null) { // TODO: what does this message mean? throw new ParseException( @@ -1448,9 +1448,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { // Decode and input/output format. type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) def format( - fmt: RowFormatContext, - configKey: String, - defaultConfigValue: String): Format = fmt match { + fmt: RowFormatContext, + configKey: String, + defaultConfigValue: String): Format = fmt match { case c: RowFormatDelimitedContext => // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema // expects a seq of pairs in which the old parsers' token names are used as keys. @@ -1509,9 +1509,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * Create a clause for DISTRIBUTE BY. 
*/ override protected def withRepartitionByExpression( - ctx: QueryOrganizationContext, - expressions: Seq[Expression], - query: LogicalPlan): LogicalPlan = { + ctx: QueryOrganizationContext, + expressions: Seq[Expression], + query: LogicalPlan): LogicalPlan = { RepartitionByExpression(expressions, query, conf.numShufflePartitions) } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 56f3f8ee12ee6..6a978a2e88c7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -48,9 +48,9 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo } protected override def generateTable( - catalog: SessionCatalog, - name: TableIdentifier, - isDataSource: Boolean = true): CatalogTable = { + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable = { val storage = CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name))) val metadata = new MetadataBuilder() @@ -125,9 +125,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def generateTable( - catalog: SessionCatalog, - name: TableIdentifier, - isDataSource: Boolean = true): CatalogTable + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable private val escapedIdentifier = "`(.+)`".r @@ -170,16 +170,16 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } private def createTable( - catalog: SessionCatalog, - name: TableIdentifier, - isDataSource: Boolean = true): Unit = { - catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): Unit = { + catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) } private def createTablePartition( - catalog: SessionCatalog, - spec: TablePartitionSpec, - tableName: TableIdentifier): Unit = { + catalog: SessionCatalog, + spec: TablePartitionSpec, + tableName: TableIdentifier): Unit = { val part = CatalogTablePartition( spec, CatalogStorageFormat(None, None, None, None, false, Map())) catalog.createPartitions(tableName, Seq(part), ignoreIfExists = false) @@ -317,11 +317,11 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } private def checkSchemaInCreatedDataSourceTable( - path: File, - userSpecifiedSchema: Option[String], - userSpecifiedPartitionCols: Option[String], - expectedSchema: StructType, - expectedPartitionCols: Seq[String]): Unit = { + path: File, + userSpecifiedSchema: Option[String], + userSpecifiedPartitionCols: Option[String], + expectedSchema: StructType, + expectedPartitionCols: Seq[String]): Unit = { val tabName = "tab1" withTable(tabName) { val partitionClause = @@ -968,13 +968,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val catalog = spark.sessionState.catalog sql( """ - |CREATE TEMPORARY VIEW tab1 - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - |) + |CREATE TEMPORARY VIEW tab1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) """.stripMargin) assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"))) sql("DROP VIEW tab1") @@ -1498,7 
+1498,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") :: Row("Function: concat") :: Row("Usage: concat(str1, str2, ..., strN) - " + - "Returns the concatenation of str1, str2, ..., strN.") :: Nil + "Returns the concatenation of str1, str2, ..., strN.") :: Nil ) // extended mode checkAnswer( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 703a74278e155..bb636d623248d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -237,6 +237,7 @@ private[hive] trait HiveStrategies { !predicate.references.isEmpty && predicate.references.subsetOf(partitionKeyIds) } + pruneFilterProject( projectList, otherPredicates, From 22d8b1acbba87149add028c7f9f053f40ee55bb0 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 28 May 2017 00:40:20 +0800 Subject: [PATCH 14/26] reformat code. --- .../sql/catalyst/analysis/Analyzer.scala | 112 ++++++++-------- .../sql/catalyst/catalog/SessionCatalog.scala | 120 +++++++++--------- .../sql/execution/command/DDLSuite.scala | 4 +- .../spark/sql/hive/HiveStrategies.scala | 2 +- 4 files changed, 119 insertions(+), 119 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cf6f3939a24d1..16cce2693a1da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -127,33 +127,33 @@ class Analyzer( new SubstituteUnresolvedOrdinals(conf)), Batch("Resolution", fixedPoint, ResolveTableValuedFunctions :: - ResolveRelations :: - ResolveReferences :: - ResolveCreateNamedStruct :: - ResolveDeserializer :: - ResolveNewInstance :: - ResolveUpCast :: - ResolveGroupingAnalytics :: - ResolvePivot :: - ResolveOrdinalInOrderByAndGroupBy :: - ResolveAggAliasInGroupBy :: - ResolveMissingReferences :: - ExtractGenerator :: - ResolveGenerate :: - ResolveFunctions :: - ResolveAliases :: - ResolveSubquery :: - ResolveWindowOrder :: - ResolveWindowFrame :: - ResolveNaturalAndUsingJoin :: - ExtractWindowExpressions :: - GlobalAggregates :: - ResolveAggregateFunctions :: - TimeWindowing :: - ResolveInlineTables(conf) :: - ResolveTimeZone(conf) :: - TypeCoercion.typeCoercionRules ++ - extendedResolutionRules : _*), + ResolveRelations :: + ResolveReferences :: + ResolveCreateNamedStruct :: + ResolveDeserializer :: + ResolveNewInstance :: + ResolveUpCast :: + ResolveGroupingAnalytics :: + ResolvePivot :: + ResolveOrdinalInOrderByAndGroupBy :: + ResolveAggAliasInGroupBy :: + ResolveMissingReferences :: + ExtractGenerator :: + ResolveGenerate :: + ResolveFunctions :: + ResolveAliases :: + ResolveSubquery :: + ResolveWindowOrder :: + ResolveWindowFrame :: + ResolveNaturalAndUsingJoin :: + ExtractWindowExpressions :: + GlobalAggregates :: + ResolveAggregateFunctions :: + TimeWindowing :: + ResolveInlineTables(conf) :: + ResolveTimeZone(conf) :: + TypeCoercion.typeCoercionRules ++ + extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), Batch("View", Once, AliasViewChild(conf)), @@ -514,7 +514,7 @@ class Analyzer( val pivotAggs = namedAggExps.map { a => Alias(PivotFirst(namedPivotCol.toAttribute, 
a.toAttribute, castPivotValues) .toAggregateExpression() - , "__pivot_" + a.sql)() + , "__pivot_" + a.sql)() } val groupByExprsAttr = groupByExprs.map(_.toAttribute) val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg) @@ -633,8 +633,8 @@ class Analyzer( // and the default database is only used to look up a view); // 3. Use the currentDb of the SessionCatalog. private def lookupTableFromCatalog( - u: UnresolvedRelation, - defaultDatabase: Option[String] = None): LogicalPlan = { + u: UnresolvedRelation, + defaultDatabase: Option[String] = None): LogicalPlan = { val tableIdentWithDb = u.tableIdentifier.copy( database = u.tableIdentifier.database.orElse(defaultDatabase)) try { @@ -874,8 +874,8 @@ class Analyzer( * Build a project list for Project/Aggregate and expand the star if possible */ private def buildExpandedProjectList( - exprs: Seq[NamedExpression], - child: LogicalPlan): Seq[NamedExpression] = { + exprs: Seq[NamedExpression], + child: LogicalPlan): Seq[NamedExpression] = { exprs.flatMap { // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*") case s: Star => s.expand(child, resolver) @@ -1022,13 +1022,13 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case agg @ Aggregate(groups, aggs, child) - if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && - groups.exists(!_.resolved) => + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(!_.resolved) => agg.copy(groupingExpressions = mayResolveAttrByAggregateExprs(groups, aggs, child)) case gs @ GroupingSets(selectedGroups, groups, child, aggs) - if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && - groups.exists(_.isInstanceOf[UnresolvedAttribute]) => + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(_.isInstanceOf[UnresolvedAttribute]) => gs.copy( selectedGroupByExprs = selectedGroups.map(mayResolveAttrByAggregateExprs(_, aggs, child)), groupByExprs = mayResolveAttrByAggregateExprs(groups, aggs, child)) @@ -1555,7 +1555,7 @@ class Analyzer( case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) => apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier) case filter @ Filter(havingCondition, - aggregate @ Aggregate(grouping, originalAggExprs, child)) + aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause @@ -1583,8 +1583,8 @@ class Analyzer( alias.toAttribute // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]]. case e: Expression if grouping.exists(_.semanticEquals(e)) && - !ResolveGroupingAnalytics.hasGroupingFunction(e) && - !aggregate.output.exists(_.semanticEquals(e)) => + !ResolveGroupingAnalytics.hasGroupingFunction(e) && + !aggregate.output.exists(_.semanticEquals(e)) => e match { case ne: NamedExpression => aggregateExpressions += ne @@ -1796,8 +1796,8 @@ class Analyzer( * names is empty names are assigned from field names in generator. 
*/ private[analysis] def makeGeneratorOutput( - generator: Generator, - names: Seq[String]): Seq[Attribute] = { + generator: Generator, + names: Seq[String]): Seq[Attribute] = { val elementAttrs = generator.elementSchema.toAttributes if (names.length == elementAttrs.length) { @@ -1809,8 +1809,8 @@ class Analyzer( } else { failAnalysis( "The number of aliases supplied in the AS clause does not match the number of columns " + - s"output by the UDTF expected ${elementAttrs.size} aliases but got " + - s"${names.mkString(",")} ") + s"output by the UDTF expected ${elementAttrs.size} aliases but got " + + s"${names.mkString(",")} ") } } } @@ -1951,8 +1951,8 @@ class Analyzer( // Extract Windowed AggregateExpression case we @ WindowExpression( - ae @ AggregateExpression(function, _, _, _), - spec: WindowSpecDefinition) => + ae @ AggregateExpression(function, _, _, _), + spec: WindowSpecDefinition) => val newChildren = function.children.map(extractExpr) val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] val newAgg = ae.copy(aggregateFunction = newFunction) @@ -1979,8 +1979,8 @@ class Analyzer( * Adds operators for Window Expressions. Every Window operator handles a single Window Spec. */ private def addWindow( - expressionsWithWindowFunctions: Seq[NamedExpression], - child: LogicalPlan): LogicalPlan = { + expressionsWithWindowFunctions: Seq[NamedExpression], + child: LogicalPlan): LogicalPlan = { // First, we need to extract all WindowExpressions from expressionsWithWindowFunctions // and put those extracted WindowExpressions to extractedWindowExprBuffer. // This step is needed because it is possible that an expression contains multiple @@ -2051,8 +2051,8 @@ class Analyzer( // a resolved Aggregate will not have Window Functions. case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) if child.resolved && - hasWindowFunction(aggregateExprs) && - a.expressions.forall(_.resolved) => + hasWindowFunction(aggregateExprs) && + a.expressions.forall(_.resolved) => val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) // Create an Aggregate operator to evaluate aggregation functions. val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) @@ -2069,7 +2069,7 @@ class Analyzer( // Aggregate without Having clause. case a @ Aggregate(groupingExprs, aggregateExprs, child) if hasWindowFunction(aggregateExprs) && - a.expressions.forall(_.resolved) => + a.expressions.forall(_.resolved) => val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) // Create an Aggregate operator to evaluate aggregation functions. 
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) @@ -2214,7 +2214,7 @@ class Analyzer( object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) - if left.resolved && right.resolved && j.duplicateResolved => + if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => // find common column names from both sides @@ -2544,8 +2544,8 @@ object TimeWindowing extends Rule[LogicalPlan] { // Only support a single window expression for now if (windowExpressions.size == 1 && - windowExpressions.head.timeColumn.resolved && - windowExpressions.head.checkInputDataTypes().isSuccess) { + windowExpressions.head.timeColumn.resolved && + windowExpressions.head.checkInputDataTypes().isSuccess) { val window = windowExpressions.head val metadata = window.timeColumn match { @@ -2560,19 +2560,19 @@ object TimeWindowing extends Rule[LogicalPlan] { val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) / window.slideDuration) val windowStart = (windowId + i - maxNumOverlapping) * - window.slideDuration + window.startTime + window.slideDuration + window.startTime val windowEnd = windowStart + window.windowDuration CreateNamedStruct( Literal(WINDOW_START) :: windowStart :: - Literal(WINDOW_END) :: windowEnd :: Nil) + Literal(WINDOW_END) :: windowEnd :: Nil) } val projections = windows.map(_ +: p.children.head.output) val filterExpr = window.timeColumn >= windowAttr.getField(WINDOW_START) && - window.timeColumn < windowAttr.getField(WINDOW_END) + window.timeColumn < windowAttr.getField(WINDOW_END) val expandedPlan = Filter(filterExpr, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index cdcc1c112d8a4..f2b9c8ab5f2cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -52,21 +52,21 @@ object SessionCatalog { * This class must be thread-safe. */ class SessionCatalog( - val externalCatalog: ExternalCatalog, - globalTempViewManager: GlobalTempViewManager, - functionRegistry: FunctionRegistry, - conf: SQLConf, - hadoopConf: Configuration, - parser: ParserInterface, - functionResourceLoader: FunctionResourceLoader) extends Logging { + val externalCatalog: ExternalCatalog, + globalTempViewManager: GlobalTempViewManager, + functionRegistry: FunctionRegistry, + conf: SQLConf, + hadoopConf: Configuration, + parser: ParserInterface, + functionResourceLoader: FunctionResourceLoader) extends Logging { import SessionCatalog._ import CatalogTypes.TablePartitionSpec // For testing only. 
def this( - externalCatalog: ExternalCatalog, - functionRegistry: FunctionRegistry, - conf: SQLConf) { + externalCatalog: ExternalCatalog, + functionRegistry: FunctionRegistry, + conf: SQLConf) { this( externalCatalog, new GlobalTempViewManager("global_temp"), @@ -323,8 +323,8 @@ class SessionCatalog( * bucket columns, and partition columns need to be at the end) */ def alterTableSchema( - identifier: TableIdentifier, - newSchema: StructType): Unit = { + identifier: TableIdentifier, + newSchema: StructType): Unit = { val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(identifier.table) val tableIdentifier = TableIdentifier(table, Some(db)) @@ -394,10 +394,10 @@ class SessionCatalog( * If the specified table is not found in the database then a [[NoSuchTableException]] is thrown. */ def loadTable( - name: TableIdentifier, - loadPath: String, - isOverwrite: Boolean, - isSrcLocal: Boolean): Unit = { + name: TableIdentifier, + loadPath: String, + isOverwrite: Boolean, + isSrcLocal: Boolean): Unit = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = formatTableName(name.table) requireDbExists(db) @@ -411,12 +411,12 @@ class SessionCatalog( * If the specified table is not found in the database then a [[NoSuchTableException]] is thrown. */ def loadPartition( - name: TableIdentifier, - loadPath: String, - spec: TablePartitionSpec, - isOverwrite: Boolean, - inheritTableSpecs: Boolean, - isSrcLocal: Boolean): Unit = { + name: TableIdentifier, + loadPath: String, + spec: TablePartitionSpec, + isOverwrite: Boolean, + inheritTableSpecs: Boolean, + isSrcLocal: Boolean): Unit = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = formatTableName(name.table) requireDbExists(db) @@ -441,9 +441,9 @@ class SessionCatalog( * Create a local temporary view. */ def createTempView( - name: String, - tableDefinition: LogicalPlan, - overrideIfExists: Boolean): Unit = synchronized { + name: String, + tableDefinition: LogicalPlan, + overrideIfExists: Boolean): Unit = synchronized { val table = formatTableName(name) if (tempTables.contains(table) && !overrideIfExists) { throw new TempTableAlreadyExistsException(name) @@ -455,9 +455,9 @@ class SessionCatalog( * Create a global temporary view. */ def createGlobalTempView( - name: String, - viewDefinition: LogicalPlan, - overrideIfExists: Boolean): Unit = { + name: String, + viewDefinition: LogicalPlan, + overrideIfExists: Boolean): Unit = { globalTempViewManager.create(formatTableName(name), viewDefinition, overrideIfExists) } @@ -466,8 +466,8 @@ class SessionCatalog( * temp view is matched and altered, false otherwise. */ def alterTempViewDefinition( - name: TableIdentifier, - viewDefinition: LogicalPlan): Boolean = synchronized { + name: TableIdentifier, + viewDefinition: LogicalPlan): Boolean = synchronized { val viewName = formatTableName(name.table) if (name.database.isEmpty) { if (tempTables.contains(viewName)) { @@ -605,9 +605,9 @@ class SessionCatalog( * the same name, then, if that does not exist, drop the table from the current database. 
*/ def dropTable( - name: TableIdentifier, - ignoreIfNotExists: Boolean, - purge: Boolean): Unit = synchronized { + name: TableIdentifier, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = synchronized { val db = formatDatabaseName(name.database.getOrElse(currentDb)) val table = formatTableName(name.table) if (db == globalTempViewManager.database) { @@ -777,9 +777,9 @@ class SessionCatalog( * If no database is specified, assume the table is in the current database. */ def createPartitions( - tableName: TableIdentifier, - parts: Seq[CatalogTablePartition], - ignoreIfExists: Boolean): Unit = { + tableName: TableIdentifier, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = { val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) @@ -794,11 +794,11 @@ class SessionCatalog( * If no database is specified, assume the table is in the current database. */ def dropPartitions( - tableName: TableIdentifier, - specs: Seq[TablePartitionSpec], - ignoreIfNotExists: Boolean, - purge: Boolean, - retainData: Boolean): Unit = { + tableName: TableIdentifier, + specs: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean, + purge: Boolean, + retainData: Boolean): Unit = { val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) @@ -815,9 +815,9 @@ class SessionCatalog( * If no database is specified, assume the table is in the current database. */ def renamePartitions( - tableName: TableIdentifier, - specs: Seq[TablePartitionSpec], - newSpecs: Seq[TablePartitionSpec]): Unit = { + tableName: TableIdentifier, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = { val tableMetadata = getTableMetadata(tableName) val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) @@ -871,8 +871,8 @@ class SessionCatalog( * then a partial spec of (a='1') will return the first two only. */ def listPartitionNames( - tableName: TableIdentifier, - partialSpec: Option[TablePartitionSpec] = None): Seq[String] = { + tableName: TableIdentifier, + partialSpec: Option[TablePartitionSpec] = None): Seq[String] = { val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) @@ -892,8 +892,8 @@ class SessionCatalog( * then a partial spec of (a='1') will return the first two only. */ def listPartitions( - tableName: TableIdentifier, - partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = { + tableName: TableIdentifier, + partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = { val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) @@ -910,8 +910,8 @@ class SessionCatalog( * satisfy the given partition-pruning predicate expressions. */ def listPartitionsByFilter( - tableName: TableIdentifier, - predicates: Seq[Expression]): Seq[CatalogTablePartition] = { + tableName: TableIdentifier, + predicates: Seq[Expression]): Seq[CatalogTablePartition] = { val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) @@ -937,8 +937,8 @@ class SessionCatalog( * The columns must be the same but the orders could be different. 
*/ private def requireExactMatchedPartitionSpec( - specs: Seq[TablePartitionSpec], - table: CatalogTable): Unit = { + specs: Seq[TablePartitionSpec], + table: CatalogTable): Unit = { val defined = table.partitionColumnNames.sorted specs.foreach { s => if (s.keys.toSeq.sorted != defined) { @@ -955,8 +955,8 @@ class SessionCatalog( * That is, the columns of partition spec should be part of the defined partition spec. */ private def requirePartialMatchedPartitionSpec( - specs: Seq[TablePartitionSpec], - table: CatalogTable): Unit = { + specs: Seq[TablePartitionSpec], + table: CatalogTable): Unit = { val defined = table.partitionColumnNames specs.foreach { s => if (!s.keys.forall(defined.contains)) { @@ -1068,9 +1068,9 @@ class SessionCatalog( * Registers a temporary or permanent function into a session-specific [[FunctionRegistry]] */ def registerFunction( - funcDefinition: CatalogFunction, - ignoreIfExists: Boolean, - functionBuilder: Option[FunctionBuilder] = None): Unit = { + funcDefinition: CatalogFunction, + ignoreIfExists: Boolean, + functionBuilder: Option[FunctionBuilder] = None): Unit = { val func = funcDefinition.identifier if (functionRegistry.functionExists(func.unquotedString) && !ignoreIfExists) { throw new AnalysisException(s"Function $func already exists") @@ -1165,8 +1165,8 @@ class SessionCatalog( * The name of this function in the FunctionRegistry will be `databaseName.functionName`. */ def lookupFunction( - name: FunctionIdentifier, - children: Seq[Expression]): Expression = synchronized { + name: FunctionIdentifier, + children: Seq[Expression]): Expression = synchronized { // Note: the implementation of this function is a little bit convoluted. // We probably shouldn't use a single FunctionRegistry to register all three kinds of functions // (built-in, temp, and external). 
@@ -1296,4 +1296,4 @@ class SessionCatalog( // copy over temporary tables tempTables.foreach(kv => target.tempTables.put(kv._1, kv._2)) } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 6a978a2e88c7b..8d7b7836f7069 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -173,7 +173,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { catalog: SessionCatalog, name: TableIdentifier, isDataSource: Boolean = true): Unit = { - catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) + catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) } private def createTablePartition( @@ -2346,4 +2346,4 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } } -} \ No newline at end of file +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index bb636d623248d..662fc80661513 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -237,7 +237,7 @@ private[hive] trait HiveStrategies { !predicate.references.isEmpty && predicate.references.subsetOf(partitionKeyIds) } - + pruneFilterProject( projectList, otherPredicates, From d91f6335cb73b4ef0d46d3fb0eec77abfd9b452c Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 28 May 2017 00:43:37 +0800 Subject: [PATCH 15/26] reformat code. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 16cce2693a1da..ee688b8b654bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -874,8 +874,8 @@ class Analyzer( * Build a project list for Project/Aggregate and expand the star if possible */ private def buildExpandedProjectList( - exprs: Seq[NamedExpression], - child: LogicalPlan): Seq[NamedExpression] = { + exprs: Seq[NamedExpression], + child: LogicalPlan): Seq[NamedExpression] = { exprs.flatMap { // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*") case s: Star => s.expand(child, resolver) @@ -1022,13 +1022,13 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case agg @ Aggregate(groups, aggs, child) - if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && - groups.exists(!_.resolved) => + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(!_.resolved) => agg.copy(groupingExpressions = mayResolveAttrByAggregateExprs(groups, aggs, child)) case gs @ GroupingSets(selectedGroups, groups, child, aggs) - if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && - groups.exists(_.isInstanceOf[UnresolvedAttribute]) => + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(_.isInstanceOf[UnresolvedAttribute]) => gs.copy( selectedGroupByExprs = 
selectedGroups.map(mayResolveAttrByAggregateExprs(_, aggs, child)), groupByExprs = mayResolveAttrByAggregateExprs(groups, aggs, child)) @@ -1556,7 +1556,7 @@ class Analyzer( apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier) case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) - if aggregate.resolved => + if aggregate.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause try { @@ -2667,7 +2667,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { val outerAliases = a.aggregateExpressions collect { case a: Alias => a } // Update the subquery plan to record the OuterReference to point to outer query plan. s.withNewPlan(updateOuterReferenceInSubquery(s.plan, outerAliases)) - } + } } } } \ No newline at end of file From 1eb23c75b0be7b93980b44f0a9fbaab6a489996e Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 28 May 2017 00:45:28 +0800 Subject: [PATCH 16/26] reformat code. --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ee688b8b654bd..85cf8ddbaacf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2670,4 +2670,4 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { } } } -} \ No newline at end of file +} From ad851098de14a105846f9f060d0e3a3b26df266a Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 28 May 2017 00:51:41 +0800 Subject: [PATCH 17/26] remove type check for macro as same with hive. 
--- .../apache/spark/sql/execution/command/macros.scala | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala index 0a67aa94d1244..26c8cf3858586 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala @@ -63,19 +63,6 @@ case class CreateMacroCommand( throw new AnalysisException(s"Cannot support Generator: ${u} " + s"for CREATE TEMPORARY MACRO $macroName") } - macroFunction.transformUp { - case e: Expression if !e.resolved => - if (e.checkInputDataTypes().isFailure) { - e.checkInputDataTypes() match { - case TypeCheckResult.TypeCheckFailure(message) => - throw new AnalysisException(s"Cannot resolve '${e.sql}' " + - s"for CREATE TEMPORARY MACRO $macroName, due to data type mismatch: $message") - } - } else { - throw new AnalysisException(s"Cannot resolve '${e.sql}' " + - s"for CREATE TEMPORARY MACRO $macroName ") - } - } val macroInfo = columns.mkString(",") + " -> " + funcWrapper.macroFunction.toString val info = new ExpressionInfo(macroInfo, macroName) From b52698ffdb5b5c809f749b7793f276bb9b305a0c Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 28 May 2017 00:54:50 +0800 Subject: [PATCH 18/26] add import --- .../scala/org/apache/spark/sql/execution/SparkSqlParser.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index d26bd785970ef..b3dc50bc8ad9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTable, _} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} /** * Concrete parser for Spark SQL statements. 
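As a point of reference for the macro command touched in the commit above: the following stand-alone Scala sketch is not part of the patch (the object name and the IntegerType parameters are illustrative assumptions). It shows the positional bind-and-substitute step that the registered macro builder performs — parameters are stored as BoundReferences by ordinal, and each invocation replaces them with the caller's expressions.

import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Expression, Literal}
import org.apache.spark.sql.types.IntegerType

object MacroExpansionSketch {
  // Body of a macro such as `simple_add(x INT, y INT) x + y`,
  // with both parameters bound by ordinal (x -> 0, y -> 1).
  val body: Expression =
    Add(
      BoundReference(0, IntegerType, nullable = true),
      BoundReference(1, IntegerType, nullable = true))

  // What the registered FunctionBuilder does per call: substitute each bound
  // parameter with the expression supplied at the call site.
  def expand(children: Seq[Expression]): Expression = body.transform {
    case b: BoundReference => children(b.ordinal)
  }

  def main(args: Array[String]): Unit = {
    // SELECT simple_add(1, 9) expands to the expression (1 + 9).
    println(expand(Seq(Literal(1), Literal(9))).sql)
  }
}

In the patch itself, the builder in macros.scala additionally validates the argument count before performing this substitution.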
From 3eacebc38ac5e50981bbad6415365c8abca6111f Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 28 May 2017 01:33:09 +0800 Subject: [PATCH 19/26] treat macro as temp function like hive --- .../analysis/AlreadyExistException.scala | 3 -- .../catalyst/analysis/FunctionRegistry.scala | 35 ------------------- .../sql/catalyst/catalog/SessionCatalog.scala | 8 ++--- 3 files changed, 4 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index eecdcf6ffa781..57f7a80bedc6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -44,6 +44,3 @@ class PartitionsAlreadyExistException(db: String, table: String, specs: Seq[Tabl class FunctionAlreadyExistsException(db: String, func: String) extends AnalysisException(s"Function '$func' already exists in database '$db'") - -class TempMacroAlreadyExistsException(func: String) - extends AnalysisException(s"Temp macro '$func' already exists") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b0ea86f709cfb..a55558c56c377 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import java.lang.reflect.Modifier - -import scala.collection.mutable.HashSet import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} @@ -64,10 +61,6 @@ trait FunctionRegistry { /** Checks if a function with a given name exists. */ def functionExists(name: String): Boolean = lookupFunction(name).isDefined - def registerMacro(name: String, info: ExpressionInfo, builder: FunctionBuilder): Unit - - def dropMacro(name: String): Boolean - /** Clear all registered functions. 
*/ def clear(): Unit @@ -80,8 +73,6 @@ class SimpleFunctionRegistry extends FunctionRegistry { protected val functionBuilders = StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false) - val macros = new HashSet[String] - override def registerFunction( name: String, info: ExpressionInfo, @@ -114,26 +105,8 @@ class SimpleFunctionRegistry extends FunctionRegistry { functionBuilders.remove(name).isDefined } - override def registerMacro( - name: String, - info: ExpressionInfo, - builder: FunctionBuilder): Unit = synchronized { - functionBuilders.put(name, (info, builder)) - macros += name.toLowerCase() - } - - override def dropMacro(name: String): Boolean = synchronized { - if (macros.contains(name.toLowerCase)) { - macros -= name.toLowerCase - functionBuilders.remove(name).isDefined - } else { - false - } - } - override def clear(): Unit = synchronized { functionBuilders.clear() - macros.clear() } override def clone(): SimpleFunctionRegistry = synchronized { @@ -175,14 +148,6 @@ object EmptyFunctionRegistry extends FunctionRegistry { throw new UnsupportedOperationException } - override def registerMacro(name: String, info: ExpressionInfo, builder: FunctionBuilder): Unit = { - throw new UnsupportedOperationException - } - - override def dropMacro(name: String): Boolean = { - throw new UnsupportedOperationException - } - override def clear(): Unit = { throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index f2b9c8ab5f2cb..f2983724810d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1094,16 +1094,16 @@ class SessionCatalog( def createTempMacro( name: String, info: ExpressionInfo, - funcDefinition: FunctionBuilder): Unit = { + functionBuilder: FunctionBuilder): Unit = { if (functionRegistry.functionExists(name)) { - throw new TempMacroAlreadyExistsException(name) + throw new AnalysisException(s"Function $name already exists") } - functionRegistry.registerMacro(name, info, funcDefinition) + functionRegistry.registerFunction(name, info, functionBuilder) } /** Drop a temporary macro. */ def dropTempMacro(name: String, ignoreIfNotExists: Boolean): Unit = { - if (!functionRegistry.dropMacro(name) && !ignoreIfNotExists) { + if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) { throw new NoSuchTempMacroException(name) } } From fce112147449278f0078321306f1a1fe34ab938b Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 28 May 2017 01:36:01 +0800 Subject: [PATCH 20/26] add Modifier for FunctionRegistry. 
--- .../apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a55558c56c377..a4c7f7a8de223 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.lang.reflect.Modifier + import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} From eaff4e966e07dd0df36ff85952462d68cb9474f9 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 28 May 2017 01:44:30 +0800 Subject: [PATCH 21/26] update comments. --- .../scala/org/apache/spark/sql/execution/command/macros.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala index 26c8cf3858586..3c5c22b62d2cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala @@ -72,7 +72,7 @@ case class CreateMacroCommand( s"expected number of columns: ${columns.size} for Macro $macroName") } macroFunction.transform { - // Skip to validate the input type because check it before. + // Skip to validate the input type because check it at runtime. case b: BoundReference => children(b.ordinal) } } From 97632a9a3dab1322929c9011005fe1422e1cd748 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 28 May 2017 02:16:28 +0800 Subject: [PATCH 22/26] add dropMacro(). 
--- .../catalyst/expressions/ExpressionInfo.java | 20 ++++++++++++++++--- .../catalyst/analysis/FunctionRegistry.scala | 15 ++++++++++++++ .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../spark/sql/execution/command/macros.scala | 2 +- 4 files changed, 34 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java index 4565ed44877a5..a0c7795651200 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java @@ -26,6 +26,7 @@ public class ExpressionInfo { private String name; private String extended; private String db; + private boolean macro; public String getClassName() { return className; @@ -47,19 +48,32 @@ public String getDb() { return db; } - public ExpressionInfo(String className, String db, String name, String usage, String extended) { + public boolean isMacro() { + return macro; + } + + public ExpressionInfo(String className, String db, String name, String usage, String extended, boolean macro) { this.className = className; this.db = db; this.name = name; this.usage = usage; this.extended = extended; + this.macro = macro; + } + + public ExpressionInfo(String className, String db, String name, String usage, String extended) { + this(className, db, name, usage, extended, false); } public ExpressionInfo(String className, String name) { - this(className, null, name, null, null); + this(className, null, name, null, null, false); + } + + public ExpressionInfo(String className, String name, boolean macro) { + this(className, null, name, null, null, macro); } public ExpressionInfo(String className, String db, String name) { - this(className, db, name, null, null); + this(className, db, name, null, null, false); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a4c7f7a8de223..201b064e2ce8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -60,6 +60,9 @@ trait FunctionRegistry { /** Drop a function and return whether the function existed. */ def dropFunction(name: String): Boolean + /** Drop a macro and return whether the macro existed. */ + def dropMacro(name: String): Boolean + /** Checks if a function with a given name exists. 
*/ def functionExists(name: String): Boolean = lookupFunction(name).isDefined @@ -107,6 +110,14 @@ class SimpleFunctionRegistry extends FunctionRegistry { functionBuilders.remove(name).isDefined } + override def dropMacro(name: String): Boolean = synchronized { + if (functionBuilders.get(name).map(_._1).filter(_.isMacro).isDefined) { + functionBuilders.remove(name).isDefined + } else { + false + } + } + override def clear(): Unit = synchronized { functionBuilders.clear() } @@ -146,6 +157,10 @@ object EmptyFunctionRegistry extends FunctionRegistry { throw new UnsupportedOperationException } + override def dropMacro(name: String): Boolean = { + throw new UnsupportedOperationException + } + override def dropFunction(name: String): Boolean = { throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index f2983724810d9..1935cf5e2ae6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1103,7 +1103,7 @@ class SessionCatalog( /** Drop a temporary macro. */ def dropTempMacro(name: String, ignoreIfNotExists: Boolean): Unit = { - if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) { + if (!functionRegistry.dropMacro(name) && !ignoreIfNotExists) { throw new NoSuchTempMacroException(name) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala index 3c5c22b62d2cb..3a25a67c58798 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala @@ -65,7 +65,7 @@ case class CreateMacroCommand( } val macroInfo = columns.mkString(",") + " -> " + funcWrapper.macroFunction.toString - val info = new ExpressionInfo(macroInfo, macroName) + val info = new ExpressionInfo(macroInfo, macroName, true) val builder = (children: Seq[Expression]) => { if (children.size != columns.size) { throw new AnalysisException(s"Actual number of columns: ${children.size} != " + From 4ee32e95e61015fe608884d3096ac82a781dd767 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 28 May 2017 02:20:25 +0800 Subject: [PATCH 23/26] reformat code style --- .../apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 1935cf5e2ae6b..c938b6acca586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1092,9 +1092,9 @@ class SessionCatalog( /** Create a temporary macro. 
*/ def createTempMacro( - name: String, - info: ExpressionInfo, - functionBuilder: FunctionBuilder): Unit = { + name: String, + info: ExpressionInfo, + functionBuilder: FunctionBuilder): Unit = { if (functionRegistry.functionExists(name)) { throw new AnalysisException(s"Function $name already exists") } From b539e94eae58847c9da13a3cb94932b17ea2fc6e Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 30 May 2017 16:38:17 +0800 Subject: [PATCH 24/26] address some comments. --- .../catalyst/expressions/ExpressionInfo.java | 28 +++++--- .../catalyst/analysis/FunctionRegistry.scala | 71 +++++++++++++++---- .../sql/catalyst/catalog/SessionCatalog.scala | 15 ++-- .../expressions/complexTypeCreator.scala | 3 +- .../spark/sql/execution/command/macros.scala | 2 +- .../internal/BaseSessionStateBuilder.scala | 2 +- 6 files changed, 87 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java index a0c7795651200..f24b741cd96ed 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java @@ -21,12 +21,16 @@ * Expression information, will be used to describe a expression. */ public class ExpressionInfo { + + public enum FunctionType { + BUILTIN, PERSISTENT, TEMPORARY; + } private String className; private String usage; private String name; private String extended; private String db; - private boolean macro; + private FunctionType functionType; public String getClassName() { return className; @@ -48,32 +52,36 @@ public String getDb() { return db; } - public boolean isMacro() { - return macro; + public FunctionType getFunctionType() { + return functionType; } - public ExpressionInfo(String className, String db, String name, String usage, String extended, boolean macro) { + public ExpressionInfo(String className, String db, String name, String usage, String extended, FunctionType functionType) { this.className = className; this.db = db; this.name = name; this.usage = usage; this.extended = extended; - this.macro = macro; + this.functionType = functionType; } public ExpressionInfo(String className, String db, String name, String usage, String extended) { - this(className, db, name, usage, extended, false); + this(className, db, name, usage, extended, FunctionType.TEMPORARY); } public ExpressionInfo(String className, String name) { - this(className, null, name, null, null, false); + this(className, null, name, null, null, FunctionType.TEMPORARY); } - public ExpressionInfo(String className, String name, boolean macro) { - this(className, null, name, null, null, macro); + public ExpressionInfo(String className, String name, FunctionType functionType) { + this(className, null, name, null, null, functionType); } public ExpressionInfo(String className, String db, String name) { - this(className, db, name, null, null, false); + this(className, db, name, null, null, FunctionType.TEMPORARY); + } + + public ExpressionInfo(String className, String db, String name, FunctionType functionType) { + this(className, db, name, null, null, functionType); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 201b064e2ce8d..1fe2efd7fff99 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -26,6 +26,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo.FunctionType import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.xml._ import org.apache.spark.sql.catalyst.util.StringKeyHashMap @@ -60,9 +61,6 @@ trait FunctionRegistry { /** Drop a function and return whether the function existed. */ def dropFunction(name: String): Boolean - /** Drop a macro and return whether the macro existed. */ - def dropMacro(name: String): Boolean - /** Checks if a function with a given name exists. */ def functionExists(name: String): Boolean = lookupFunction(name).isDefined @@ -110,12 +108,55 @@ class SimpleFunctionRegistry extends FunctionRegistry { functionBuilders.remove(name).isDefined } - override def dropMacro(name: String): Boolean = synchronized { - if (functionBuilders.get(name).map(_._1).filter(_.isMacro).isDefined) { - functionBuilders.remove(name).isDefined + override def clear(): Unit = synchronized { + functionBuilders.clear() + } + + override def clone(): SimpleFunctionRegistry = synchronized { + val registry = new SimpleFunctionRegistry + functionBuilders.iterator.foreach { case (name, (info, builder)) => + registry.registerFunction(name, info, builder) + } + registry + } +} + +class SystemFunctionRegistry(builtin: SimpleFunctionRegistry) extends SimpleFunctionRegistry { + + override def registerFunction( + name: String, + info: ExpressionInfo, + builder: FunctionBuilder): Unit = synchronized { + if (info.getFunctionType.equals(FunctionType.BUILTIN)) { + builtin.registerFunction(name, info, builder) } else { - false + functionBuilders.put(name, (info, builder)) + } + } + + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + val func = synchronized { + functionBuilders.get(name).map(_._2).orElse(builtin.lookupFunctionBuilder(name)).getOrElse { + throw new AnalysisException(s"undefined function $name") + } } + func(children) + } + + override def listFunction(): Seq[String] = synchronized { + (functionBuilders.iterator.map(_._1).toList ++ builtin.listFunction()).sorted + } + + override def lookupFunction(name: String): Option[ExpressionInfo] = synchronized { + functionBuilders.get(name).map(_._1).orElse(builtin.lookupFunction(name)) + } + + override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = synchronized { + functionBuilders.get(name).map(_._2).orElse(builtin.lookupFunctionBuilder(name)) + } + + override def dropFunction(name: String): Boolean = synchronized { + functionBuilders.remove(name).isDefined } override def clear(): Unit = synchronized { @@ -123,7 +164,7 @@ class SimpleFunctionRegistry extends FunctionRegistry { } override def clone(): SimpleFunctionRegistry = synchronized { - val registry = new SimpleFunctionRegistry + val registry = new SystemFunctionRegistry(builtin.clone()) functionBuilders.iterator.foreach { case (name, (info, builder)) => registry.registerFunction(name, info, builder) } @@ -157,10 +198,6 @@ object EmptyFunctionRegistry extends FunctionRegistry { throw new UnsupportedOperationException } - override def dropMacro(name: 
String): Boolean = { - throw new UnsupportedOperationException - } - override def dropFunction(name: String): Boolean = { throw new UnsupportedOperationException } @@ -471,6 +508,8 @@ object FunctionRegistry { fr } + val systemRegistry = new SystemFunctionRegistry(builtin) + val functionSet: Set[String] = builtin.listFunction().toSet /** See usage above. */ @@ -534,7 +573,8 @@ object FunctionRegistry { } val clazz = scala.reflect.classTag[Cast].runtimeClass val usage = "_FUNC_(expr) - Casts the value `expr` to the target data type `_FUNC_`." - (name, (new ExpressionInfo(clazz.getCanonicalName, null, name, usage, null), builder)) + (name, (new ExpressionInfo(clazz.getCanonicalName, null, name, usage, null, + FunctionType.BUILTIN), builder)) } /** @@ -544,9 +584,10 @@ object FunctionRegistry { val clazz = scala.reflect.classTag[T].runtimeClass val df = clazz.getAnnotation(classOf[ExpressionDescription]) if (df != null) { - new ExpressionInfo(clazz.getCanonicalName, null, name, df.usage(), df.extended()) + new ExpressionInfo(clazz.getCanonicalName, null, name, df.usage(), df.extended(), + ExpressionInfo.FunctionType.BUILTIN) } else { - new ExpressionInfo(clazz.getCanonicalName, name) + new ExpressionInfo(clazz.getCanonicalName, name, ExpressionInfo.FunctionType.BUILTIN) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c938b6acca586..21c0fe1c672d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo.FunctionType import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils @@ -1095,15 +1096,12 @@ class SessionCatalog( name: String, info: ExpressionInfo, functionBuilder: FunctionBuilder): Unit = { - if (functionRegistry.functionExists(name)) { - throw new AnalysisException(s"Function $name already exists") - } functionRegistry.registerFunction(name, info, functionBuilder) } /** Drop a temporary macro. 
*/ def dropTempMacro(name: String, ignoreIfNotExists: Boolean): Unit = { - if (!functionRegistry.dropMacro(name) && !ignoreIfNotExists) { + if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) { throw new NoSuchTempMacroException(name) } } @@ -1144,7 +1142,8 @@ class SessionCatalog( new ExpressionInfo( metadata.className, qualifiedName.database.orNull, - qualifiedName.identifier) + qualifiedName.identifier, + FunctionType.PERSISTENT) } else { failFunctionLookup(name.funcName) } @@ -1266,7 +1265,11 @@ class SessionCatalog( if (func.database.isDefined) { dropFunction(func, ignoreIfNotExists = false) } else { - dropTempFunction(func.funcName, ignoreIfNotExists = false) + val functionType = functionRegistry.lookupFunction(func.funcName).map(_.getFunctionType) + .getOrElse(FunctionType.TEMPORARY) + if (!functionType.equals(FunctionType.BUILTIN)) { + dropTempFunction(func.funcName, ignoreIfNotExists = false) + } } } clearTempTables() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index b6675a84ece48..d2a35c5131cfd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -260,7 +260,8 @@ object CreateStruct extends FunctionBuilder { null, "struct", "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.", - "") + "", + ExpressionInfo.FunctionType.BUILTIN) ("struct", (info, this)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala index 3a25a67c58798..3c5c22b62d2cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala @@ -65,7 +65,7 @@ case class CreateMacroCommand( } val macroInfo = columns.mkString(",") + " -> " + funcWrapper.macroFunction.toString - val info = new ExpressionInfo(macroInfo, macroName, true) + val info = new ExpressionInfo(macroInfo, macroName) val builder = (children: Seq[Expression]) => { if (children.size != columns.size) { throw new AnalysisException(s"Actual number of columns: ${children.size} != " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2a801d87b12eb..4f43badbc6cd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -95,7 +95,7 @@ abstract class BaseSessionStateBuilder( * This either gets cloned from a pre-existing version or cloned from the built-in registry. */ protected lazy val functionRegistry: FunctionRegistry = { - parentState.map(_.functionRegistry).getOrElse(FunctionRegistry.builtin).clone() + parentState.map(_.functionRegistry).getOrElse(FunctionRegistry.systemRegistry).clone() } /** From 1563f12d78a9c32bf4bed69cb9f86a7d00eb18ef Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 30 May 2017 22:42:31 +0800 Subject: [PATCH 25/26] address some comments. 
--- .../catalyst/expressions/ExpressionInfo.java | 5 +- .../apache/spark/sql/AnalysisException.scala | 9 + .../catalyst/analysis/FunctionRegistry.scala | 3 +- .../analysis/NoSuchItemException.scala | 3 - .../sql/catalyst/catalog/SessionCatalog.scala | 6 +- .../spark/sql/execution/SparkSqlParser.scala | 5 +- .../spark/sql/execution/command/macros.scala | 41 ++- .../test/resources/sql-tests/inputs/macro.sql | 56 +++ .../resources/sql-tests/results/macro.sql.out | 344 ++++++++++++++++++ .../sql/execution/command/DDLSuite.scala | 29 -- 10 files changed, 447 insertions(+), 54 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/macro.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/macro.sql.out diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java index f24b741cd96ed..681ab8669f842 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java @@ -25,6 +25,7 @@ public class ExpressionInfo { public enum FunctionType { BUILTIN, PERSISTENT, TEMPORARY; } + private String className; private String usage; private String name; @@ -65,10 +66,6 @@ public ExpressionInfo(String className, String db, String name, String usage, St this.functionType = functionType; } - public ExpressionInfo(String className, String db, String name, String usage, String extended) { - this(className, db, name, usage, extended, FunctionType.TEMPORARY); - } - public ExpressionInfo(String className, String name) { this(className, null, name, null, null, FunctionType.TEMPORARY); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 50ee6cd4085ea..4b75924d87b68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -55,3 +55,12 @@ class AnalysisException protected[sql] ( s"$message;$lineAnnotation$positionAnnotation" } } + +object AnalysisException { + /** + * Create a no such temporary macro exception. 
+ */ + def noSuchTempMacroException(func: String): AnalysisException = { + new AnalysisException(s"Temporary macro '$func' not found") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1fe2efd7fff99..0d87bb320e616 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -144,7 +144,7 @@ class SystemFunctionRegistry(builtin: SimpleFunctionRegistry) extends SimpleFunc } override def listFunction(): Seq[String] = synchronized { - (functionBuilders.iterator.map(_._1).toList ++ builtin.listFunction()).sorted + (functionBuilders.iterator.map(_._1).toList ++ builtin.listFunction()).distinct.sorted } override def lookupFunction(name: String): Option[ExpressionInfo] = synchronized { @@ -160,6 +160,7 @@ class SystemFunctionRegistry(builtin: SimpleFunctionRegistry) extends SimpleFunc } override def clear(): Unit = synchronized { + builtin.clear() functionBuilders.clear() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index b4e84b7c25ab0..f5aae60431c15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -52,6 +52,3 @@ class NoSuchPartitionsException(db: String, table: String, specs: Seq[TableParti class NoSuchTempFunctionException(func: String) extends AnalysisException(s"Temporary function '$func' not found") - -class NoSuchTempMacroException(func: String) - extends AnalysisException(s"Temporary macro '$func' not found") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 21c0fe1c672d2..ba27645572522 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1102,7 +1102,7 @@ class SessionCatalog( /** Drop a temporary macro. 
*/ def dropTempMacro(name: String, ignoreIfNotExists: Boolean): Unit = { if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) { - throw new NoSuchTempMacroException(name) + throw AnalysisException.noSuchTempMacroException(name) } } @@ -1265,8 +1265,8 @@ class SessionCatalog( if (func.database.isDefined) { dropFunction(func, ignoreIfNotExists = false) } else { - val functionType = functionRegistry.lookupFunction(func.funcName).map(_.getFunctionType) - .getOrElse(FunctionType.TEMPORARY) + val functionType = functionRegistry.lookupFunction(func.funcName) + .map(_.getFunctionType).getOrElse(FunctionType.TEMPORARY) if (!functionType.equals(FunctionType.BUILTIN)) { dropTempFunction(func.funcName, ignoreIfNotExists = false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index b3dc50bc8ad9a..05e722bd88a7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -724,12 +724,11 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitCreateMacro(ctx: CreateMacroContext): LogicalPlan = withOrigin(ctx) { - val arguments = Option(ctx.colTypeList).map(visitColTypeList(_)) - .getOrElse(Seq.empty[StructField]) + val columns = createSchema(ctx.colTypeList) val e = expression(ctx.expression) CreateMacroCommand( ctx.macroName.getText, - MacroFunctionWrapper(arguments, e)) + MacroFunctionWrapper(columns, e)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala index 3c5c22b62d2cb..9c62a7783343f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala @@ -20,12 +20,14 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.types.StructType /** * This class provides arguments and body expression of the macro function. */ -case class MacroFunctionWrapper(columns: Seq[StructField], macroFunction: Expression) +case class MacroFunctionWrapper(columns: StructType, macroFunction: Expression) + /** * The DDL command that creates a macro. 
@@ -41,16 +43,33 @@ case class CreateMacroCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - val columns = funcWrapper.columns.map { col => - AttributeReference(col.name, col.dataType, col.nullable, col.metadata)() } - val colToIndex: Map[String, Int] = columns.map(_.name).zipWithIndex.toMap + val columns = funcWrapper.columns + val columnAttrs = columns.toAttributes + def formatName: (String => String) = + if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { + (name: String) => name + } else { + (name: String) => name.toLowerCase + } + val colToIndex: Map[String, Int] = columnAttrs.map(_.name).map(formatName).zipWithIndex.toMap if (colToIndex.size != columns.size) { throw new AnalysisException(s"Cannot support duplicate colNames " + s"for CREATE TEMPORARY MACRO $macroName, actual columns: ${columns.mkString(",")}") } + + try { + val plan = Project(Seq(Alias(funcWrapper.macroFunction, "m")()), LocalRelation(columnAttrs)) + val analyzed = sparkSession.sessionState.analyzer.execute(plan) + sparkSession.sessionState.analyzer.checkAnalysis(analyzed) + } catch { + case a: AnalysisException => + throw new AnalysisException(s"CREATE TEMPORARY MACRO $macroName " + + s"with exception: ${a.getMessage}") + } + val macroFunction = funcWrapper.macroFunction.transform { case u: UnresolvedAttribute => - val index = colToIndex.get(u.name).getOrElse( + val index = colToIndex.get(formatName(u.name)).getOrElse( throw new AnalysisException(s"Cannot find colName: ${u} " + s"for CREATE TEMPORARY MACRO $macroName, actual columns: ${columns.mkString(",")}")) BoundReference(index, columns(index).dataType, columns(index).nullable) @@ -64,15 +83,15 @@ case class CreateMacroCommand( s"for CREATE TEMPORARY MACRO $macroName") } - val macroInfo = columns.mkString(",") + " -> " + funcWrapper.macroFunction.toString - val info = new ExpressionInfo(macroInfo, macroName) + val columnLength: Int = columns.length + val info = new ExpressionInfo(macroName, macroName) val builder = (children: Seq[Expression]) => { - if (children.size != columns.size) { + if (children.size != columnLength) { throw new AnalysisException(s"Actual number of columns: ${children.size} != " + - s"expected number of columns: ${columns.size} for Macro $macroName") + s"expected number of columns: ${columnLength} for Macro $macroName") } macroFunction.transform { - // Skip to validate the input type because check it at runtime. + // Skip to validate the input type because check it before. 
case b: BoundReference => children(b.ordinal) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/macro.sql b/sql/core/src/test/resources/sql-tests/inputs/macro.sql new file mode 100644 index 0000000000000..e6b08521ce525 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/macro.sql @@ -0,0 +1,56 @@ +CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x)); +SELECT SIGMOID(2); +DROP TEMPORARY MACRO SIGMOID; + +CREATE TEMPORARY MACRO FIXED_NUMBER() 1; +SELECT FIXED_NUMBER() + 1; +DROP TEMPORARY MACRO FIXED_NUMBER; + +CREATE TEMPORARY MACRO SIMPLE_ADD (x INT, y INT) x + y; +SELECT SIMPLE_ADD(1, 9); +DROP TEMPORARY MACRO SIMPLE_ADD; + +CREATE TEMPORARY MACRO flr(d bigint) FLOOR(d/10)*10; +SELECT flr(12); +DROP TEMPORARY MACRO flr; + +CREATE TEMPORARY MACRO STRING_LEN(x string) length(x); +CREATE TEMPORARY MACRO STRING_LEN_PLUS_ONE(x string) length(x)+1; +CREATE TEMPORARY MACRO STRING_LEN_PLUS_TWO(x string) length(x)+2; +create table macro_test (x string) using parquet;; +insert into table macro_test values ("bb"), ("a"), ("ccc"); +SELECT CONCAT(STRING_LEN(x), ":", STRING_LEN_PLUS_ONE(x), ":", STRING_LEN_PLUS_TWO(x)) a +FROM macro_test; +SELECT CONCAT(STRING_LEN(x), ":", STRING_LEN_PLUS_ONE(x), ":", STRING_LEN_PLUS_TWO(x)) a +FROM +macro_test +sort by a; +drop table macro_test; + +CREATE TABLE macro_testing(a int, b int, c int) using parquet;; +insert into table macro_testing values (1,2,3); +insert into table macro_testing values (4,5,6); +CREATE TEMPORARY MACRO math_square(x int) x*x; +CREATE TEMPORARY MACRO math_add(x int) x+x; +select math_square(a), math_square(b),factorial(a), factorial(b), math_add(a), math_add(b),int(c) +from macro_testing order by int(c); +drop table macro_testing; + +CREATE TEMPORARY MACRO max(x int, y int) x + y; +SELECT max(1, 2); +DROP TEMPORARY MACRO max; +SELECT max(2); + +CREATE TEMPORARY MACRO c() 3E9; +SELECT floor(c()/10); +DROP TEMPORARY MACRO c; + +CREATE TEMPORARY MACRO fixed_number() 42; +DROP TEMPORARY FUNCTION fixed_number; +DROP TEMPORARY MACRO IF EXISTS fixed_number; + +-- invalid queries +CREATE TEMPORARY MACRO simple_add_error(x int) x + y; +CREATE TEMPORARY MACRO simple_add_error(x int, x int) x + y; +CREATE TEMPORARY MACRO simple_add_error(x int) x NOT IN (select c2); +DROP TEMPORARY MACRO SOME_MACRO; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/macro.sql.out b/sql/core/src/test/resources/sql-tests/results/macro.sql.out new file mode 100644 index 0000000000000..cdc80cb4f84f8 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/macro.sql.out @@ -0,0 +1,344 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 41 + + +-- !query 0 +CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x)) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT SIGMOID(2) +-- !query 1 schema +struct<(CAST(1.0 AS DOUBLE) / (CAST(1.0 AS DOUBLE) + EXP(CAST((- 2) AS DOUBLE)))):double> +-- !query 1 output +0.8807970779778823 + + +-- !query 2 +DROP TEMPORARY MACRO SIGMOID +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE TEMPORARY MACRO FIXED_NUMBER() 1 +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +SELECT FIXED_NUMBER() + 1 +-- !query 4 schema +struct<(1 + 1):int> +-- !query 4 output +2 + + +-- !query 5 +DROP TEMPORARY MACRO FIXED_NUMBER +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +CREATE TEMPORARY MACRO SIMPLE_ADD (x INT, y INT) x + y +-- !query 6 schema +struct<> +-- 
!query 6 output + + + +-- !query 7 +SELECT SIMPLE_ADD(1, 9) +-- !query 7 schema +struct<(1 + 9):int> +-- !query 7 output +10 + + +-- !query 8 +DROP TEMPORARY MACRO SIMPLE_ADD +-- !query 8 schema +struct<> +-- !query 8 output + + + +-- !query 9 +CREATE TEMPORARY MACRO flr(d bigint) FLOOR(d/10)*10 +-- !query 9 schema +struct<> +-- !query 9 output + + + +-- !query 10 +SELECT flr(12) +-- !query 10 schema +struct<(FLOOR((CAST(12 AS DOUBLE) / CAST(10 AS DOUBLE))) * CAST(10 AS BIGINT)):bigint> +-- !query 10 output +10 + + +-- !query 11 +DROP TEMPORARY MACRO flr +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +CREATE TEMPORARY MACRO STRING_LEN(x string) length(x) +-- !query 12 schema +struct<> +-- !query 12 output + + + +-- !query 13 +CREATE TEMPORARY MACRO STRING_LEN_PLUS_ONE(x string) length(x)+1 +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +CREATE TEMPORARY MACRO STRING_LEN_PLUS_TWO(x string) length(x)+2 +-- !query 14 schema +struct<> +-- !query 14 output + + + +-- !query 15 +create table macro_test (x string) using parquet +-- !query 15 schema +struct<> +-- !query 15 output + + + +-- !query 16 +insert into table macro_test values ("bb"), ("a"), ("ccc") +-- !query 16 schema +struct<> +-- !query 16 output + + + +-- !query 17 +SELECT CONCAT(STRING_LEN(x), ":", STRING_LEN_PLUS_ONE(x), ":", STRING_LEN_PLUS_TWO(x)) a +FROM macro_test +-- !query 17 schema +struct +-- !query 17 output +1:2:3 +2:3:4 +3:4:5 + + +-- !query 18 +SELECT CONCAT(STRING_LEN(x), ":", STRING_LEN_PLUS_ONE(x), ":", STRING_LEN_PLUS_TWO(x)) a +FROM +macro_test +sort by a +-- !query 18 schema +struct +-- !query 18 output +1:2:3 +2:3:4 +3:4:5 + + +-- !query 19 +drop table macro_test +-- !query 19 schema +struct<> +-- !query 19 output + + + +-- !query 20 +CREATE TABLE macro_testing(a int, b int, c int) using parquet +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +insert into table macro_testing values (1,2,3) +-- !query 21 schema +struct<> +-- !query 21 output + + + +-- !query 22 +insert into table macro_testing values (4,5,6) +-- !query 22 schema +struct<> +-- !query 22 output + + + +-- !query 23 +CREATE TEMPORARY MACRO math_square(x int) x*x +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +CREATE TEMPORARY MACRO math_add(x int) x+x +-- !query 24 schema +struct<> +-- !query 24 output + + + +-- !query 25 +select math_square(a), math_square(b),factorial(a), factorial(b), math_add(a), math_add(b),int(c) +from macro_testing order by int(c) +-- !query 25 schema +struct<(a * a):int,(b * b):int,factorial(a):bigint,factorial(b):bigint,(a + a):int,(b + b):int,c:int> +-- !query 25 output +1 4 1 2 2 4 3 +16 25 24 120 8 10 6 + + +-- !query 26 +drop table macro_testing +-- !query 26 schema +struct<> +-- !query 26 output + + + +-- !query 27 +CREATE TEMPORARY MACRO max(x int, y int) x + y +-- !query 27 schema +struct<> +-- !query 27 output + + + +-- !query 28 +SELECT max(1, 2) +-- !query 28 schema +struct<(1 + 2):int> +-- !query 28 output +3 + + +-- !query 29 +DROP TEMPORARY MACRO max +-- !query 29 schema +struct<> +-- !query 29 output + + + +-- !query 30 +SELECT max(2) +-- !query 30 schema +struct +-- !query 30 output +2 + + +-- !query 31 +CREATE TEMPORARY MACRO c() 3E9 +-- !query 31 schema +struct<> +-- !query 31 output + + + +-- !query 32 +SELECT floor(c()/10) +-- !query 32 schema +struct +-- !query 32 output +300000000 + + +-- !query 33 +DROP TEMPORARY MACRO c +-- !query 33 schema +struct<> +-- !query 33 output + + + +-- !query 34 +CREATE 
TEMPORARY MACRO fixed_number() 42 +-- !query 34 schema +struct<> +-- !query 34 output + + + +-- !query 35 +DROP TEMPORARY FUNCTION fixed_number +-- !query 35 schema +struct<> +-- !query 35 output + + + +-- !query 36 +DROP TEMPORARY MACRO IF EXISTS fixed_number +-- !query 36 schema +struct<> +-- !query 36 output + + + +-- !query 37 +CREATE TEMPORARY MACRO simple_add_error(x int) x + y +-- !query 37 schema +struct<> +-- !query 37 output +org.apache.spark.sql.AnalysisException +CREATE TEMPORARY MACRO simple_add_error with exception: cannot resolve '`y`' given input columns: [x]; line 1 pos 51; + + +-- !query 38 +CREATE TEMPORARY MACRO simple_add_error(x int, x int) x + y +-- !query 38 schema +struct<> +-- !query 38 output +org.apache.spark.sql.AnalysisException +Cannot support duplicate colNames for CREATE TEMPORARY MACRO simple_add_error, actual columns: StructField(x,IntegerType,true),StructField(x,IntegerType,true); + + +-- !query 39 +CREATE TEMPORARY MACRO simple_add_error(x int) x NOT IN (select c2) +-- !query 39 schema +struct<> +-- !query 39 output +org.apache.spark.sql.AnalysisException +CREATE TEMPORARY MACRO simple_add_error with exception: cannot resolve '`c2`' given input columns: []; line 1 pos 64; + + +-- !query 40 +DROP TEMPORARY MACRO SOME_MACRO +-- !query 40 schema +struct<> +-- !query 40 output +org.apache.spark.sql.AnalysisException +Temporary macro 'SOME_MACRO' not found; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 8d7b7836f7069..e4dd077715d0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1516,35 +1516,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { ) } - test("create/drop temporary macro") { - intercept[AnalysisException] { - sql(s"CREATE TEMPORARY MACRO simple_add_error(x int) x + y") - } - intercept[AnalysisException] { - sql(s"CREATE TEMPORARY MACRO simple_add_error(x int, x int) x + y") - } - intercept[AnalysisException] { - sql(s"CREATE TEMPORARY MACRO simple_add_error(x int) x NOT IN (select c2 from t2) ") - } - sql("CREATE TEMPORARY MACRO fixed_number() 42") - checkAnswer(sql("SELECT fixed_number()"), Row(42)) - sql("CREATE TEMPORARY MACRO string_len_plus_two(x string) length(x) + 2") - checkAnswer(sql("SELECT string_len_plus_two('abc')"), Row(5)) - sql("CREATE TEMPORARY MACRO simple_add(x int, y int) x + y") - checkAnswer(sql("SELECT simple_add(1, 2)"), Row(3)) - intercept[AnalysisException] { - sql(s"SELECT simple_add(1)") - } - sql("DROP TEMPORARY MACRO fixed_number") - intercept[AnalysisException] { - sql(s"DROP TEMPORARY MACRO abs") - } - intercept[AnalysisException] { - sql("DROP TEMPORARY MACRO SOME_MACRO") - } - sql("DROP TEMPORARY MACRO IF EXISTS SOME_MACRO") - } - test("create a data source table without schema") { import testImplicits._ withTempPath { tempDir => From 4d8e843fb490845b8e5b55033ccac9bba93b7591 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Wed, 31 May 2017 02:01:01 +0800 Subject: [PATCH 26/26] update --- .../spark/sql/execution/command/macros.scala | 53 +++++----- .../test/resources/sql-tests/inputs/macro.sql | 11 +- .../resources/sql-tests/results/macro.sql.out | 100 +++++++++++++++--- 3 files changed, 124 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala index 9c62a7783343f..d3fbd94e39275 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.execution.command +import scala.collection.mutable + import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.types.StructType @@ -43,52 +44,50 @@ case class CreateMacroCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - val columns = funcWrapper.columns - val columnAttrs = columns.toAttributes - def formatName: (String => String) = - if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { + val columns = funcWrapper.columns.map(_.name) + val columnAttrs = funcWrapper.columns.toAttributes + def formatName = if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { (name: String) => name } else { (name: String) => name.toLowerCase } - val colToIndex: Map[String, Int] = columnAttrs.map(_.name).map(formatName).zipWithIndex.toMap + val colToIndex: Map[String, Int] = columns.map(formatName).zipWithIndex.toMap if (colToIndex.size != columns.size) { - throw new AnalysisException(s"Cannot support duplicate colNames " + - s"for CREATE TEMPORARY MACRO $macroName, actual columns: ${columns.mkString(",")}") + throw new AnalysisException(s"Failed to CREATE TEMPORARY MACRO $macroName, because " + + s"at least one parameter name was used more than once : ${columns.mkString(",")}") } - try { + val resolvedMacroFunction = try { val plan = Project(Seq(Alias(funcWrapper.macroFunction, "m")()), LocalRelation(columnAttrs)) - val analyzed = sparkSession.sessionState.analyzer.execute(plan) + val analyzed @ Project(Seq(named), _) = sparkSession.sessionState.analyzer.execute(plan) sparkSession.sessionState.analyzer.checkAnalysis(analyzed) + named.children.head } catch { case a: AnalysisException => - throw new AnalysisException(s"CREATE TEMPORARY MACRO $macroName " + - s"with exception: ${a.getMessage}") + throw new AnalysisException(s"Failed to CREATE TEMPORARY MACRO $macroName, because of " + + s"exception: ${a.getMessage}") } - val macroFunction = funcWrapper.macroFunction.transform { - case u: UnresolvedAttribute => + val foundColumns: mutable.Set[String] = new mutable.HashSet() + val macroFunction = resolvedMacroFunction.transform { + case u: AttributeReference => val index = colToIndex.get(formatName(u.name)).getOrElse( - throw new AnalysisException(s"Cannot find colName: ${u} " + - s"for CREATE TEMPORARY MACRO $macroName, actual columns: ${columns.mkString(",")}")) - BoundReference(index, columns(index).dataType, columns(index).nullable) - case u: UnresolvedFunction => - sparkSession.sessionState.catalog.lookupFunction(u.name, u.children) - case s: SubqueryExpression => - throw new AnalysisException(s"Cannot support Subquery: ${s} " + - s"for CREATE TEMPORARY MACRO $macroName") - case u: UnresolvedGenerator => - throw new AnalysisException(s"Cannot support Generator: ${u} " + - s"for CREATE TEMPORARY MACRO $macroName") + throw new AnalysisException(s"Failed to CREATE TEMPORARY MACRO $macroName, because " + + s"it cannot find colName: ${u.name}, actual columns: ${columns.mkString(",")}")) + 
foundColumns.add(formatName(u.name))
+        BoundReference(index, u.dataType, u.nullable)
+    }
+    if (foundColumns.size != columns.size) {
+      throw new AnalysisException(s"Failed to CREATE TEMPORARY MACRO $macroName, because its " +
+        s"body uses only ${foundColumns.size} of the declared columns: ${columns.mkString(",")}")
     }
     val columnLength: Int = columns.length
     val info = new ExpressionInfo(macroName, macroName)
     val builder = (children: Seq[Expression]) => {
       if (children.size != columnLength) {
-        throw new AnalysisException(s"Actual number of columns: ${children.size} != " +
-          s"expected number of columns: ${columnLength} for Macro $macroName")
+        throw new AnalysisException(s"Macro $macroName expects ${columnLength} arguments, " +
+          s"but ${children.size} arguments were provided")
       }
       macroFunction.transform {
         // Skip to validate the input type because check it before.
diff --git a/sql/core/src/test/resources/sql-tests/inputs/macro.sql b/sql/core/src/test/resources/sql-tests/inputs/macro.sql
index e6b08521ce525..fbbe5b368c7e9 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/macro.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/macro.sql
@@ -49,8 +49,17 @@ CREATE TEMPORARY MACRO fixed_number() 42;
 DROP TEMPORARY FUNCTION fixed_number;
 DROP TEMPORARY MACRO IF EXISTS fixed_number;
 
+CREATE TEMPORARY MACRO add_bigint_int(x bigint, y int) x + y;
+SELECT add_bigint_int(1, 1.5);
+DROP TEMPORARY MACRO add_bigint_int;
+
 -- invalid queries
 CREATE TEMPORARY MACRO simple_add_error(x int) x + y;
 CREATE TEMPORARY MACRO simple_add_error(x int, x int) x + y;
 CREATE TEMPORARY MACRO simple_add_error(x int) x NOT IN (select c2);
-DROP TEMPORARY MACRO SOME_MACRO;
\ No newline at end of file
+DROP TEMPORARY MACRO SOME_MACRO;
+CREATE TEMPORARY MACRO macro_add(x int, y int, z int) x + y;
+CREATE TEMPORARY MACRO macro_add(x int, x int) x + x;
+CREATE TEMPORARY MACRO macro_add(x int, y int) x + y;
+SELECT macro_add(1, 2, 3);
+DROP TEMPORARY MACRO macro_add;
diff --git a/sql/core/src/test/resources/sql-tests/results/macro.sql.out b/sql/core/src/test/resources/sql-tests/results/macro.sql.out
index cdc80cb4f84f8..aec4b358b7dce 100644
--- a/sql/core/src/test/resources/sql-tests/results/macro.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/macro.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 41
+-- Number of queries: 49
 
 
 -- !query 0
@@ -309,36 +309,112 @@ struct<>
 
 
 -- !query 37
-CREATE TEMPORARY MACRO simple_add_error(x int) x + y
+CREATE TEMPORARY MACRO add_bigint_int(x bigint, y int) x + y
 -- !query 37 schema
 struct<>
 -- !query 37 output
-org.apache.spark.sql.AnalysisException
-CREATE TEMPORARY MACRO simple_add_error with exception: cannot resolve '`y`' given input columns: [x]; line 1 pos 51;
+
 
 
 -- !query 38
-CREATE TEMPORARY MACRO simple_add_error(x int, x int) x + y
+SELECT add_bigint_int(1, 1.5)
 -- !query 38 schema
-struct<>
+struct<(CAST(1 AS BIGINT) + CAST(1.5 AS BIGINT)):bigint>
 -- !query 38 output
-org.apache.spark.sql.AnalysisException
-Cannot support duplicate colNames for CREATE TEMPORARY MACRO simple_add_error, actual columns: StructField(x,IntegerType,true),StructField(x,IntegerType,true);
+2
 
 
 -- !query 39
-CREATE TEMPORARY MACRO simple_add_error(x int) x NOT IN (select c2)
+DROP TEMPORARY MACRO add_bigint_int
 -- !query 39 schema
 struct<>
 -- !query 39 output
-org.apache.spark.sql.AnalysisException
-CREATE TEMPORARY MACRO simple_add_error with exception: cannot resolve '`c2`' given input columns: []; line 1 pos 64;
+
 
 
 -- !query 40
-DROP 
TEMPORARY MACRO SOME_MACRO
+multiply
+
+CREATE TEMPORARY MACRO simple_add_error(x int) x + y
 -- !query 40 schema
 struct<>
 -- !query 40 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+extraneous input 'multiply' expecting {'(', 'SELECT', 'FROM', 'ADD', 'DESC', 'WITH', 'VALUES', 'CREATE', 'TABLE', 'INSERT', 'DELETE', 'DESCRIBE', 'EXPLAIN', 'SHOW', 'USE', 'DROP', 'ALTER', 'MAP', 'SET', 'RESET', 'START', 'COMMIT', 'ROLLBACK', 'REDUCE', 'REFRESH', 'CLEAR', 'CACHE', 'UNCACHE', 'DFS', 'TRUNCATE', 'ANALYZE', 'LIST', 'REVOKE', 'GRANT', 'LOCK', 'UNLOCK', 'MSCK', 'EXPORT', 'IMPORT', 'LOAD'}(line 1, pos 0)
+
+== SQL ==
+multiply
+^^^
+
+CREATE TEMPORARY MACRO simple_add_error(x int) x + y
+
+
+-- !query 41
+CREATE TEMPORARY MACRO simple_add_error(x int, x int) x + y
+-- !query 41 schema
+struct<>
+-- !query 41 output
+org.apache.spark.sql.AnalysisException
+Failed to CREATE TEMPORARY MACRO simple_add_error, because at least one parameter name was used more than once : x,x;
+
+
+-- !query 42
+CREATE TEMPORARY MACRO simple_add_error(x int) x NOT IN (select c2)
+-- !query 42 schema
+struct<>
+-- !query 42 output
+org.apache.spark.sql.AnalysisException
+Failed to CREATE TEMPORARY MACRO simple_add_error, because of exception: cannot resolve '`c2`' given input columns: []; line 1 pos 64;
+
+
+-- !query 43
+DROP TEMPORARY MACRO SOME_MACRO
+-- !query 43 schema
+struct<>
+-- !query 43 output
 org.apache.spark.sql.AnalysisException
 Temporary macro 'SOME_MACRO' not found;
+
+
+-- !query 44
+CREATE TEMPORARY MACRO macro_add(x int, y int, z int) x + y
+-- !query 44 schema
+struct<>
+-- !query 44 output
+org.apache.spark.sql.AnalysisException
+Failed to CREATE TEMPORARY MACRO macro_add, because its body uses only 2 of the declared columns: x,y,z;
+
+
+-- !query 45
+CREATE TEMPORARY MACRO macro_add(x int, x int) x + x
+-- !query 45 schema
+struct<>
+-- !query 45 output
+org.apache.spark.sql.AnalysisException
+Failed to CREATE TEMPORARY MACRO macro_add, because at least one parameter name was used more than once : x,x;
+
+
+-- !query 46
+CREATE TEMPORARY MACRO macro_add(x int, y int) x + y
+-- !query 46 schema
+struct<>
+-- !query 46 output
+
+
+
+-- !query 47
+SELECT macro_add(1, 2, 3)
+-- !query 47 schema
+struct<>
+-- !query 47 output
+org.apache.spark.sql.AnalysisException
+Macro macro_add expects 2 arguments, but 3 arguments were provided; line 1 pos 7
+
+
+-- !query 48
+DROP TEMPORARY MACRO macro_add
+-- !query 48 schema
+struct<>
+-- !query 48 output
+
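The test results above exercise the two halves of the feature: CREATE TEMPORARY MACRO stores the analyzed body with each parameter bound to its position, and every later call substitutes the call-site expressions into that stored body. The stand-alone sketch below illustrates only that substitution step. It is not code from this patch; the object name, the hypothetical simple_add macro, and the expand helper are invented for illustration, and it assumes Spark's catalyst expression classes (BoundReference, Literal, Add) are on the classpath.

{{{
import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Expression, Literal}
import org.apache.spark.sql.types.IntegerType

object MacroExpansionSketch {
  // Stored body of a hypothetical `simple_add(x int, y int) x + y`, with each
  // parameter already rewritten to a BoundReference holding its position.
  private val storedBody: Expression =
    Add(BoundReference(0, IntegerType, nullable = true),
        BoundReference(1, IntegerType, nullable = true))

  // The substitution a registered macro builder performs: replace every
  // BoundReference with the expression supplied at the call site.
  def expand(children: Seq[Expression]): Expression = {
    require(children.size == 2, s"simple_add expects 2 arguments, got ${children.size}")
    storedBody.transform {
      case b: BoundReference => children(b.ordinal)
    }
  }

  def main(args: Array[String]): Unit = {
    // simple_add(1, 2) expands to (1 + 2) before the analyzer ever sees it.
    println(expand(Seq(Literal(1), Literal(2))).sql)
  }
}
}}}

Because the expanded expression is resolved again in the context of the calling query, argument types are coerced per call, which is consistent with the bigint schema recorded above for SELECT add_bigint_int(1, 1.5).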