Skip to content
1 change: 1 addition & 0 deletions docs/sql-ref-ansi-compliance.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ Below is a list of all the keywords in Spark SQL.
|RECOVER|non-reserved|non-reserved|non-reserved|
|REDUCE|non-reserved|non-reserved|non-reserved|
|REFERENCES|reserved|non-reserved|reserved|
|RECURSIVE|reserved|non-reserved|reserved|
|REFRESH|non-reserved|non-reserved|non-reserved|
|RENAME|non-reserved|non-reserved|non-reserved|
|REPAIR|non-reserved|non-reserved|non-reserved|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ describeColName
;

ctes
: WITH namedQuery (',' namedQuery)*
: WITH RECURSIVE? namedQuery (',' namedQuery)*
;

namedQuery
Expand Down Expand Up @@ -1386,6 +1386,7 @@ nonReserved
| RECORDREADER
| RECORDWRITER
| RECOVER
| RECURSIVE
| REDUCE
| REFERENCES
| REFRESH
Expand Down Expand Up @@ -1643,6 +1644,7 @@ RANGE: 'RANGE';
RECORDREADER: 'RECORDREADER';
RECORDWRITER: 'RECORDWRITER';
RECOVER: 'RECOVER';
RECURSIVE: 'RECURSIVE';
REDUCE: 'REDUCE';
REFERENCES: 'REFERENCES';
REFRESH: 'REFRESH';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ class Analyzer(
ResolveRelations ::
ResolveTables ::
ResolveReferences ::
ResolveRecursiveReferences ::
ResolveCreateNamedStruct ::
ResolveDeserializer ::
ResolveNewInstance ::
Expand Down Expand Up @@ -1657,6 +1658,23 @@ class Analyzer(
}
}

/**
 * Resolves the [[UnresolvedRecursiveReference]]s in the recursive term of a
 * [[RecursiveRelation]] once the anchor term of that relation is resolved, i.e. once the output
 * of the recursive relation is known.
 *
 * Only references whose name equals the name of the enclosing [[RecursiveRelation]] are
 * replaced. Each of them becomes a [[RecursiveReference]] whose output is derived from the
 * output of the anchor term.
 */
object ResolveRecursiveReferences extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
    case relation @ RecursiveRelation(cteName, anchorTerm, recursiveTerm)
        if anchorTerm.resolved && !recursiveTerm.resolved =>
      relation.copy(recursiveTerm = recursiveTerm.transform {
        case UnresolvedRecursiveReference(name, accumulated) if name == cteName =>
          // `newInstance()` is invoked once per reference, so the attributes are not shared
          // between two references of the same recursive term.
          val referenceOutput = anchorTerm.output.map(_.newInstance())
          RecursiveReference(name, referenceOutput, accumulated)
      })
  }
}

/**
* In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by
* clauses. This rule is to convert ordinal positions to the corresponding expressions in the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Except, LogicalPlan, RecursiveRelation, SubqueryAlias, Union, With}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, With}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -59,7 +60,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
startOfQuery: Boolean = true): Unit = {
val resolver = SQLConf.get.resolver
plan match {
case With(child, relations) =>
case With(child, relations, _) =>
val newNames = mutable.ArrayBuffer.empty[String]
newNames ++= outerCTERelationNames
relations.foreach {
Expand All @@ -86,7 +87,7 @@ object CTESubstitution extends Rule[LogicalPlan] {

private def legacyTraverseAndSubstituteCTE(plan: LogicalPlan): LogicalPlan = {
plan.resolveOperatorsUp {
case With(child, relations) =>
case With(child, relations, _) =>
val resolvedCTERelations = resolveCTERelations(relations, isLegacy = true)
substituteCTE(child, resolvedCTERelations)
}
Expand Down Expand Up @@ -135,20 +136,22 @@ object CTESubstitution extends Rule[LogicalPlan] {
*/
private def traverseAndSubstituteCTE(plan: LogicalPlan): LogicalPlan = {
plan.resolveOperatorsUp {
case With(child: LogicalPlan, relations) =>
val resolvedCTERelations = resolveCTERelations(relations, isLegacy = false)
case With(child: LogicalPlan, relations, allowRecursion) =>
val resolvedCTERelations = resolveCTERelations(relations, isLegacy = false, allowRecursion)
substituteCTE(child, resolvedCTERelations)

case other =>
other.transformExpressions {
case e: SubqueryExpression => e.withNewPlan(traverseAndSubstituteCTE(e.plan))
case e: SubqueryExpression =>
e.withNewPlan(traverseAndSubstituteCTE(e.plan))
}
}
}

private def resolveCTERelations(
relations: Seq[(String, SubqueryAlias)],
isLegacy: Boolean): Seq[(String, LogicalPlan)] = {
isLegacy: Boolean,
allowRecursion: Boolean = false): Seq[(String, LogicalPlan)] = {
val resolvedCTERelations = new mutable.ArrayBuffer[(String, LogicalPlan)](relations.size)
for ((name, relation) <- relations) {
val innerCTEResolved = if (isLegacy) {
Expand All @@ -161,8 +164,13 @@ object CTESubstitution extends Rule[LogicalPlan] {
// substitute CTE defined in `relation` first.
traverseAndSubstituteCTE(relation)
}
val recursionHandled = if (allowRecursion) {
handleRecursion(innerCTEResolved, name)
} else {
innerCTEResolved
}
// CTE definition can reference a previous one
resolvedCTERelations += (name -> substituteCTE(innerCTEResolved, resolvedCTERelations))
resolvedCTERelations += (name -> substituteCTE(recursionHandled, resolvedCTERelations))
}
resolvedCTERelations
}
Expand All @@ -172,12 +180,122 @@ object CTESubstitution extends Rule[LogicalPlan] {
cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan =
plan resolveOperatorsUp {
case u @ UnresolvedRelation(Seq(table)) =>
cteRelations.find(r => plan.conf.resolver(r._1, table)).map(_._2).getOrElse(u)
cteRelations.find(r => SQLConf.get.resolver(r._1, table)).map(_._2).getOrElse(u)

case other =>
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
other transformExpressions {
case e: SubqueryExpression => e.withNewPlan(substituteCTE(e.plan, cteRelations))
}
}

/**
 * Entry point of recursion handling, used only when recursion is allowed for the CTE.
 *
 * First every reference to the CTE itself is replaced with an
 * [[UnresolvedRecursiveReference]]. If no such reference exists the original plan is returned
 * untouched. Otherwise the definition has to be a [[SubqueryAlias]] over a UNION ALL
 * ([[Union]]) or a UNION ([[Distinct]] over [[Union]]), optionally wrapped into
 * [[UnresolvedSubqueryColumnAliases]], and it is converted to a [[RecursiveRelation]].
 * Any other shape is reported as an [[AnalysisException]].
 */
private def handleRecursion(plan: LogicalPlan, cteName: String) = {
  val (planWithReferences, referenceCount) = insertRecursiveReferences(plan, cteName)

  def unsupportedStructure(): Nothing =
    throw new AnalysisException(s"Recursive query $cteName should contain UNION or UNION " +
      "ALL statements only. This error can also be caused by ORDER BY or LIMIT keywords " +
      "used on result of UNION or UNION ALL.")

  if (referenceCount <= 0) {
    // No self-reference was found, so the CTE is not recursive.
    plan
  } else {
    planWithReferences match {
      case SubqueryAlias(_, body) =>
        // Peel off the optional column alias list first ...
        val (columnNames, unaliased) = body match {
          case UnresolvedSubqueryColumnAliases(names, aliased) => (names, aliased)
          case notAliased => (Seq.empty[String], notAliased)
        }
        // ... then only a bare or a DISTINCT union is an acceptable definition.
        unaliased match {
          case u: Union =>
            insertRecursiveRelation(cteName, columnNames, distinct = false, union = u)
          case Distinct(u: Union) =>
            insertRecursiveRelation(cteName, columnNames, distinct = true, union = u)
          case _ =>
            unsupportedStructure()
        }
      case _ =>
        unsupportedStructure()
    }
  }
}

/**
 * Replaces every relation that matches the name of the recursive CTE with an
 * [[UnresolvedRecursiveReference]] and counts the replacements.
 *
 * Subqueries of the visited operators are inspected as well: a reference to the CTE inside a
 * subquery is not replaced but reported as an [[AnalysisException]].
 *
 * @return the rewritten plan together with the number of inserted references
 */
private def insertRecursiveReferences(plan: LogicalPlan, cteName: String): (LogicalPlan, Int) = {
  val resolver = SQLConf.get.resolver
  def isSelfReference(table: String): Boolean = resolver(cteName, table)

  // Check handed to `checkAndTraverse` for each subquery: it fails on a self-reference and
  // lets the traversal continue on anything else.
  val rejectSelfReference: LogicalPlan => Boolean = {
    case UnresolvedRelation(Seq(table)) if isSelfReference(table) =>
      throw new AnalysisException(s"Recursive query $cteName should not contain recursive " +
        "references in its subquery.")
    case _ => true
  }

  var referenceCount = 0
  val planWithReferences = plan resolveOperators {
    case UnresolvedRelation(Seq(table)) if isSelfReference(table) =>
      referenceCount += 1
      UnresolvedRecursiveReference(cteName, accumulated = false)

    case node =>
      node.subqueries.foreach(checkAndTraverse(_, rejectSelfReference))
      node
  }

  (planWithReferences, referenceCount)
}

/**
 * Builds the [[RecursiveRelation]] of a recursive CTE from the union that defines it.
 *
 * The first child of the union is the anchor term, the second one is the recursive term.
 *
 * @param cteName name of the recursive CTE
 * @param columnNames column aliases of the CTE; when non-empty they are applied to the anchor
 *                    term only
 * @param distinct true if the terms are combined with UNION, false if with UNION ALL
 * @param union the union of the anchor and the recursive term
 * @throws AnalysisException if the union does not have exactly two children or if the anchor
 *                           term contains a recursive reference to `cteName`
 */
private def insertRecursiveRelation(
    cteName: String,
    columnNames: Seq[String],
    distinct: Boolean,
    union: Union) = {
  // Destructure with the `Seq` extractor: a `::` pattern matches `List` instances only and
  // would fail with a MatchError on any other two-element sequence.
  val (anchorTerm, recursiveTerm) = union.children match {
    case Seq(anchor, recursive) => (anchor, recursive)
    case _ =>
      throw new AnalysisException(s"Recursive query ${cteName} should contain one anchor term " +
        "and one recursive term connected with UNION or UNION ALL.")
  }

  // The anchor term shouldn't contain a recursive reference that matches the name of the CTE,
  // except if it is nested under another RecursiveRelation with the same name (returning false
  // stops the traversal below such a relation).
  checkAndTraverse(anchorTerm, {
    case UnresolvedRecursiveReference(name, _) if name == cteName =>
      throw new AnalysisException(s"Recursive query $cteName should not contain recursive " +
        "references in its anchor (first) term.")
    case RecursiveRelation(name, _, _) if name == cteName => false
    case _ => true
  })

  // The anchor term has a special role, its output columns are aliased if required.
  val aliasedAnchorTerm = SubqueryAlias(cteName,
    if (columnNames.nonEmpty) {
      UnresolvedSubqueryColumnAliases(columnNames, anchorTerm)
    } else {
      anchorTerm
    }
  )

  // If the UNION combinator is used between the terms we extend the anchor with a DISTINCT and
  // the recursive term with an EXCEPT clause and a reference to the so far accumulated result.
  if (distinct) {
    RecursiveRelation(cteName, Distinct(aliasedAnchorTerm),
      Except(recursiveTerm, UnresolvedRecursiveReference(cteName, true), false))
  } else {
    RecursiveRelation(cteName, aliasedAnchorTerm, recursiveTerm)
  }
}

/**
 * Traverses the plan in pre-order, including its subqueries, and runs `check` on every visited
 * node. The nodes below a node (its children first, then its subqueries) are visited only if
 * `check` returned true for that node.
 */
private def checkAndTraverse(plan: LogicalPlan, check: LogicalPlan => Boolean): Unit = {
  val descend = check(plan)
  if (descend) {
    val nested = plan.children ++ plan.subqueries
    nested.foreach(checkAndTraverse(_, check))
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ trait CheckAnalysis extends PredicateHelper {
case _ => // Analysis successful!
}
}
checkRecursion(plan)
checkCollectedMetrics(plan)
extendedCheckRules.foreach(_(plan))
plan.foreachUp {
Expand All @@ -665,6 +666,62 @@ trait CheckAnalysis extends PredicateHelper {
plan.setAnalyzed()
}

/**
 * Recursion according to SQL standard comes with several limitations due to the fact that only
 * those operations are allowed where the new set of rows can be computed from the result of the
 * previous iteration. This implies that a recursive reference can't be used in some kinds of
 * joins and aggregations.
 * A further constraint is that a recursive term can contain one recursive reference only (except
 * for using it on different sides of a UNION).
 *
 * This check returns nothing; it throws an [[AnalysisException]] if one of these restrictions
 * is violated.
 *
 * @param plan the plan to check
 * @param allowedRecursiveReferencesAndCounts names of the recursive relations that may be
 *   referenced at this point of the plan, each with the number of references seen so far;
 *   the map is mutated during the traversal
 */
private def checkRecursion(
plan: LogicalPlan,
allowedRecursiveReferencesAndCounts: mutable.Map[String, Int] = mutable.Map.empty): Unit = {
plan match {
case RecursiveRelation(name, anchorTerm, recursiveTerm) =>
// A relation nested into the recursive term of a relation with the same name is rejected.
if (allowedRecursiveReferencesAndCounts.contains(name)) {
throw new AnalysisException(s"Recursive CTE definition $name is already in use.")
}
// The anchor term is checked before `name` is registered, so a reference to `name` in it
// fails. The recursive term is checked with `name` registered with a zero counter, and the
// name is unregistered afterwards.
checkRecursion(anchorTerm, allowedRecursiveReferencesAndCounts)
checkRecursion(recursiveTerm, allowedRecursiveReferencesAndCounts += name -> 0)
allowedRecursiveReferencesAndCounts -= name
// Only references whose third field is false are checked
// (presumably the `accumulated` flag -- TODO confirm against RecursiveReference).
case RecursiveReference(name, _, false, _, _, _) =>
if (!allowedRecursiveReferencesAndCounts.contains(name)) {
throw new AnalysisException(s"Recursive reference $name cannot be used here. This can " +
"be caused by using it on inner side of an outer join, using it with aggregate in a " +
"subquery or using it multiple times in a recursive term (except for using it on " +
"different sides of an UNION ALL).")
}
if (allowedRecursiveReferencesAndCounts(name) > 0) {
throw new AnalysisException(s"Recursive reference $name cannot be used multiple times " +
"in a recursive term.")
}

// Record this use so that a second reference sharing this map is rejected.
allowedRecursiveReferencesAndCounts +=
name -> (allowedRecursiveReferencesAndCounts(name) + 1)
// Inner join: references are allowed on both sides and share the counters.
case Join(left, right, Inner, _, _) =>
checkRecursion(left, allowedRecursiveReferencesAndCounts)
checkRecursion(right, allowedRecursiveReferencesAndCounts)
// Left outer join: the right side gets an empty map, so no reference is allowed there.
case Join(left, right, LeftOuter, _, _) =>
checkRecursion(left, allowedRecursiveReferencesAndCounts)
checkRecursion(right, mutable.Map.empty)
// Right outer join: the left side gets an empty map, so no reference is allowed there.
case Join(left, right, RightOuter, _, _) =>
checkRecursion(left, mutable.Map.empty)
checkRecursion(right, allowedRecursiveReferencesAndCounts)
// Any other join type: no reference is allowed on either side.
case Join(left, right, _, _, _) =>
checkRecursion(left, mutable.Map.empty)
checkRecursion(right, mutable.Map.empty)
// No reference is allowed below an aggregate.
case Aggregate(_, _, child) => checkRecursion(child, mutable.Map.empty)
// Every child of a union gets its own new map with the same names and zeroed counters.
// NOTE(review): counters already recorded in the caller's map are reset for the children and
// the counts made inside the children are never written back -- confirm this is intended.
case Union(children) =>
children.foreach(checkRecursion(_,
mutable.Map(allowedRecursiveReferencesAndCounts.keys.map(name => name -> 0).toSeq: _*)))
// Every other operator passes the same map on to all of its children.
case o =>
o.children.foreach(checkRecursion(_, allowedRecursiveReferencesAndCounts))
}
}

/**
* Validates subquery expressions in the plan. Upon failure, returns an user facing error.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -550,3 +550,9 @@ case class UnresolvedHaving(
override lazy val resolved: Boolean = false
override def output: Seq[Attribute] = child.output
}

/**
 * Placeholder for a reference to a recursive CTE made from within the definition of that CTE.
 * It has no output and is never resolved; the analyzer is expected to replace it with a
 * [[RecursiveReference]] once the output of the recursive relation is known.
 *
 * @param cteName name of the recursive CTE that is referenced
 * @param accumulated false for a reference that replaces a relation of the query; true for the
 *                    reference that is inserted when the terms are combined with UNION and that
 *                    stands for the so far accumulated result
 */
case class UnresolvedRecursiveReference(cteName: String, accumulated: Boolean) extends LeafNode {
  override def output: Seq[Attribute] = Nil

  override lazy val resolved: Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,9 @@ object ColumnPruning extends Rule[LogicalPlan] {

case NestedColumnAliasing(p) => p

// Don't prune columns of RecursiveRelation
case p @ Project(_, _: RecursiveRelation) => p

// for all other logical plans that inherits the output from it's children
// Project over project is handled by the first case, skip it here.
case p @ Project(_, child) if !child.isInstanceOf[Project] =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
s"CTE definition can't have duplicate names: ${duplicates.mkString("'", "', '", "'")}.",
ctx)
}
With(plan, ctes)
With(plan, ctes, ctx.RECURSIVE() != null)
}

/**
Expand Down
Loading