diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index c162c1a0c5789..10f2e2416bb64 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -205,6 +205,8 @@ atomExpression | whenExpression | (functionName LPAREN) => function | tableOrColumn + | (LPAREN KW_SELECT) => subQueryExpression + -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP) subQueryExpression) | LPAREN! expression RPAREN! ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 8099751900a42..b16025a17e2e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -667,6 +667,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) case other => UnresolvedExtractValue(other, Literal(cleanIdentifier(attr))) } + case Token("TOK_SUBQUERY_EXPR", Token("TOK_SUBQUERY_OP", Nil) :: subquery :: Nil) => + ScalarSubquery(nodeToPlan(subquery)) /* Stars (*) */ case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 26c3d286b19fa..04e56a8fdaf7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -80,6 +80,7 @@ class Analyzer( ResolveGenerate :: ResolveFunctions :: ResolveAliases :: + ResolveSubquery :: ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalJoin :: @@ -120,7 +121,14 @@ class Analyzer( withAlias.getOrElse(relation) } substituted.getOrElse(u) + case other => + // This can't be done in ResolveSubquery because that does not know the CTE. + other transformExpressions { + case e: SubqueryExpression => + e.withNewPlan(substituteCTE(e.query, cteRelations)) + } } + } } @@ -693,6 +701,30 @@ class Analyzer( } } + /** + * This rule resolve subqueries inside expressions. + * + * Note: CTE are handled in CTESubstitution. + */ + object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper { + + private def hasSubquery(e: Expression): Boolean = { + e.find(_.isInstanceOf[SubqueryExpression]).isDefined + } + + private def hasSubquery(q: LogicalPlan): Boolean = { + q.expressions.exists(hasSubquery) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case q: LogicalPlan if q.childrenResolved && hasSubquery(q) => + q transformExpressions { + case e: SubqueryExpression if !e.query.resolved => + e.withNewPlan(execute(e.query)) + } + } + } + /** * Turns projections that contain aggregate expressions into aggregations. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala new file mode 100644 index 0000000000000..a8f5e1f63d4c7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} +import org.apache.spark.sql.types.DataType + +/** + * An interface for subquery that is used in expressions. + */ +abstract class SubqueryExpression extends LeafExpression { + + /** + * The logical plan of the query. + */ + def query: LogicalPlan + + /** + * Either a logical plan or a physical plan. The generated tree string (explain output) uses this + * field to explain the subquery. + */ + def plan: QueryPlan[_] + + /** + * Updates the query with new logical plan. + */ + def withNewPlan(plan: LogicalPlan): SubqueryExpression +} + +/** + * A subquery that will return only one row and one column. + * + * This will be converted into [[execution.ScalarSubquery]] during physical planning. + * + * Note: `exprId` is used to have unique name in explain string output. 
+ */ +case class ScalarSubquery( + query: LogicalPlan, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression with Unevaluable { + + override def plan: LogicalPlan = Subquery(toString, query) + + override lazy val resolved: Boolean = query.resolved + + override def dataType: DataType = query.schema.fields.head.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + if (query.schema.length != 1) { + TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " + + query.schema.length.toString) + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def foldable: Boolean = false + override def nullable: Boolean = true + + override def withNewPlan(plan: LogicalPlan): ScalarSubquery = ScalarSubquery(plan, exprId) + + override def toString: String = s"subquery#${exprId.id}" + + // TODO: support sql() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 902e18081bddf..f1f438075164e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -88,7 +88,19 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), - ConvertToLocalRelation) :: Nil + ConvertToLocalRelation) :: + Batch("Subquery", Once, + OptimizeSubqueries) :: Nil + } + + /** + * Optimize all the subqueries inside expression. + */ + object OptimizeSubqueries extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case subquery: SubqueryExpression => + subquery.withNewPlan(Optimizer.this.execute(subquery.query)) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 05f5bdbfc0769..86bd33f526209 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.Subquery import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} @@ -226,4 +227,9 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" 
else "" override def simpleString: String = statePrefix + super.simpleString + + override def treeChildren: Seq[PlanType] = { + val subqueries = expressions.flatMap(_.collect {case e: SubqueryExpression => e}) + children ++ subqueries.map(e => e.plan.asInstanceOf[PlanType]) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 30df2a84f62c4..e46ce1cee7c6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -448,6 +448,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } + /** + * All the nodes that will be used to generate tree string. + */ + protected def treeChildren: Seq[BaseType] = children + /** * Appends the string represent of this node and its children to the given StringBuilder. * @@ -470,9 +475,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { builder.append(simpleString) builder.append("\n") - if (children.nonEmpty) { - children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) - children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) + if (treeChildren.nonEmpty) { + treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) + treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder) } builder diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index 8d7d6b5bf52ea..ed7121831ac29 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.BooleanType import org.apache.spark.unsafe.types.CalendarInterval class CatalystQlSuite extends PlanTest { @@ -201,4 +202,10 @@ class CatalystQlSuite extends PlanTest { parser.parsePlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " + "from windowData") } + + test("subquery") { + parser.parsePlan("select (select max(b) from s) ss from t") + parser.parsePlan("select * from t where a = (select b from s)") + parser.parsePlan("select * from t group by g having a > (select b from s)") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index e0cec09742eba..ca6dcd8bdfb84 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -113,6 +113,17 @@ class AnalysisErrorSuite extends AnalysisTest { val dateLit = Literal.create(null, DateType) + errorTest( + "scalar subquery with 2 columns", + testRelation.select( + (ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)), + "Scalar subquery must return only one column, but got 2" :: Nil) + + errorTest( + "scalar subquery with no column", + 
testRelation.select(ScalarSubquery(LocalRelation()).as('a)), + "Scalar subquery must return only one column, but got 0" :: Nil) + errorTest( "single invalid type, single arg", testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d58b99655c1eb..55325c1662e2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -884,6 +884,7 @@ class SQLContext private[sql]( @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = Seq( + Batch("Subquery", Once, PlanSubqueries(self)), Batch("Add exchange", Once, EnsureRequirements(self)), Batch("Whole stage codegen", Once, CollapseCodegenStages(self)) ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index c72b8dc70708f..872ccde883060 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ import org.apache.spark.Logging import org.apache.spark.rdd.{RDD, RDDOperationScope} @@ -31,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric} import org.apache.spark.sql.types.DataType +import org.apache.spark.util.ThreadUtils /** * The base class for physical operators. @@ -112,16 +115,58 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ final def execute(): RDD[InternalRow] = { RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() + waitForSubqueries() doExecute() } } + // All the subqueries and their Future of results. + @transient private val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]() + + /** + * Collects all the subqueries and create a Future to take the first two rows of them. + */ + protected def prepareSubqueries(): Unit = { + val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e}) + allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e => + val futureResult = Future { + // We only need the first row, try to take two rows so we can throw an exception if there + // are more than one rows returned. + e.executedPlan.executeTake(2) + }(SparkPlan.subqueryExecutionContext) + queryResults += e -> futureResult + } + } + + /** + * Waits for all the subqueries to finish and updates the results. + */ + protected def waitForSubqueries(): Unit = { + // fill in the result of subqueries + queryResults.foreach { + case (e, futureResult) => + val rows = Await.result(futureResult, Duration.Inf) + if (rows.length > 1) { + sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}") + } + if (rows.length == 1) { + assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column") + e.updateResult(rows(0).get(0, e.dataType)) + } else { + // There is no rows returned, the result should be null. 
+ e.updateResult(null) + } + } + queryResults.clear() + } + /** * Prepare a SparkPlan for execution. It's idempotent. */ final def prepare(): Unit = { if (prepareCalled.compareAndSet(false, true)) { doPrepare() + prepareSubqueries() children.foreach(_.prepare()) } } @@ -231,6 +276,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } +object SparkPlan { + private[execution] val subqueryExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("subquery", 16)) +} + private[sql] trait LeafNode extends SparkPlan { override def children: Seq[SparkPlan] = Nil override def producedAttributes: AttributeSet = outputSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index f35efb5b24b1f..116013f307782 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -73,9 +73,10 @@ trait CodegenSupport extends SparkPlan { /** * Returns Java source code to process the rows from upstream. */ - def produce(ctx: CodegenContext, parent: CodegenSupport): String = { + final def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent ctx.freshNamePrefix = variablePrefix + waitForSubqueries() doProduce(ctx) } @@ -101,7 +102,7 @@ trait CodegenSupport extends SparkPlan { /** * Consume the columns generated from current SparkPlan, call it's parent. */ - def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + final def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { if (input != null) { assert(input.length == output.length) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 4b82d5563460b..55bddd196ec46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -343,3 +343,18 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl protected override def doExecute(): RDD[InternalRow] = child.execute() } + +/** + * A plan as subquery. + * + * This is used to generate tree string for SparkScalarSubquery. + */ +case class Subquery(name: String, child: SparkPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala new file mode 100644 index 0000000000000..9c645c78e8732 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.{expressions, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.DataType + +/** + * A subquery that will return only one row and one column. + * + * This is the physical copy of ScalarSubquery to be used inside SparkPlan. + */ +case class ScalarSubquery( + @transient executedPlan: SparkPlan, + exprId: ExprId) + extends SubqueryExpression { + + override def query: LogicalPlan = throw new UnsupportedOperationException + override def withNewPlan(plan: LogicalPlan): SubqueryExpression = { + throw new UnsupportedOperationException + } + override def plan: SparkPlan = Subquery(simpleString, executedPlan) + + override def dataType: DataType = executedPlan.schema.fields.head.dataType + override def nullable: Boolean = true + override def toString: String = s"subquery#${exprId.id}" + + // the first column in first row from `query`. + private var result: Any = null + + def updateResult(v: Any): Unit = { + result = v + } + + override def eval(input: InternalRow): Any = result + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + Literal.create(result, dataType).genCode(ctx, ev) + } +} + +/** + * Convert the subquery from logical plan into executed plan. 
+ */ +private[sql] case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] { + def apply(plan: SparkPlan): SparkPlan = { + plan.transformAllExpressions { + case subquery: expressions.ScalarSubquery => + val sparkPlan = sqlContext.planner.plan(ReturnAnswer(subquery.query)).next() + val executedPlan = sqlContext.prepareForExecution.execute(sparkPlan) + ScalarSubquery(executedPlan, subquery.exprId) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f665a1c87bd78..d2ddf0af4f041 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,18 +21,13 @@ import java.math.MathContext import java.sql.Timestamp import org.apache.spark.AccumulatorSuite -import org.apache.spark.sql.catalyst.CatalystQl -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.parser.ParserConf -import org.apache.spark.sql.execution.{aggregate, SparkQl} +import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} -import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ - class SQLQuerySuite extends QueryTest with SharedSQLContext { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala new file mode 100644 index 0000000000000..e851eb02f01b3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class SubquerySuite extends QueryTest with SharedSQLContext { + + test("simple uncorrelated scalar subquery") { + assertResult(Array(Row(1))) { + sql("select (select 1 as b) as b").collect() + } + + assertResult(Array(Row(1))) { + sql("with t2 as (select 1 as b, 2 as c) " + + "select a from (select 1 as a union all select 2 as a) t " + + "where a = (select max(b) from t2) ").collect() + } + + assertResult(Array(Row(3))) { + sql("select (select (select 1) + 1) + 1").collect() + } + + // more than one columns + val error = intercept[AnalysisException] { + sql("select (select 1, 2) as b").collect() + } + assert(error.message contains "Scalar subquery must return only one column, but got 2") + + // more than one rows + val error2 = intercept[RuntimeException] { + sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect() + } + assert(error2.getMessage contains + "more than one row returned by a subquery used as an expression") + + // string type + assertResult(Array(Row("s"))) { + sql("select (select 's' as s) as b").collect() + } + + // zero rows + assertResult(Array(Row(null))) { + sql("select (select 's' as s limit 0) as b").collect() + } + } + + test("uncorrelated scalar subquery on testData") { + // initialize test Data + testData + + assertResult(Array(Row(5))) { + sql("select (select key from testData where key > 3 limit 1) + 1").collect() + } + + assertResult(Array(Row(-100))) { + sql("select -(select max(key) from testData)").collect() + } + + assertResult(Array(Row(null))) { + sql("select (select value from testData limit 0)").collect() + } + + assertResult(Array(Row("99"))) { + sql("select (select min(value) from testData" + + " where key = (select max(key) from testData) - 1)").collect() + } + } +}
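Usage notes (appended for reference, not part of the patch): the snippets below restate the behavior exercised by SubquerySuite above; `t`, `s`, and `testData` are the test fixtures used there, not a public API.

    // An uncorrelated scalar subquery is allowed wherever an expression is allowed:
    sql("select (select max(b) from s) ss from t")                    // select list
    sql("select * from t where a = (select b from s)")                // WHERE predicate
    sql("select * from t group by g having a > (select b from s)")    // HAVING clause

    // The subquery must produce exactly one column: two columns fail analysis with
    // "Scalar subquery must return only one column", more than one row fails at
    // runtime, and zero rows evaluate to NULL:
    sql("select (select 's' as s limit 0) as b").collect()            // Array(Row(null))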
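The execution model added to SparkPlan above starts every uncorrelated subquery on a dedicated thread pool during prepare() and blocks on the results in waitForSubqueries() just before the parent operator runs. A minimal standalone sketch of that pattern, using only the Scala standard library (the names and the Seq-based "row" type are illustrative, not Spark APIs):

    import java.util.concurrent.Executors
    import scala.collection.mutable.ArrayBuffer
    import scala.concurrent.{Await, ExecutionContext, Future}
    import scala.concurrent.duration.Duration

    object ScalarSubqueryPattern {
      private val pool = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(16))
      private val pending = ArrayBuffer[(String, Future[Seq[Any]])]()

      // prepare(): kick off each subquery asynchronously, keeping at most two rows so
      // that a "more than one row" violation can still be detected when we block later.
      def prepare(subqueries: Map[String, () => Seq[Any]]): Unit =
        subqueries.foreach { case (name, run) =>
          pending += name -> Future(run().take(2))(pool)
        }

      // waitForSubqueries(): block until every subquery finishes and reduce its rows to
      // a single scalar: the only row, null for zero rows, an error for two or more.
      def waitForSubqueries(): Map[String, Any] = {
        val results = pending.map { case (name, future) =>
          Await.result(future, Duration.Inf) match {
            case Seq(only) => name -> only
            case Seq()     => name -> (null: Any)
            case _         => sys.error(s"more than one row returned by subquery $name")
          }
        }.toMap
        pending.clear()
        results
      }
    }

    // Example: ScalarSubqueryPattern.prepare(Map("q1" -> (() => Seq(42))))
    //          ScalarSubqueryPattern.waitForSubqueries()   // Map(q1 -> 42)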