From fdb8d2a5215247b93b777dda2ea0ab6c1dad0efd Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 13 Feb 2017 15:25:39 -0800 Subject: [PATCH 1/9] [SPARK-18874] First phase: Deferring the correlated predicate pull up to Optimizer phase --- .../sql/catalyst/analysis/Analyzer.scala | 290 +++++++++++------- .../sql/catalyst/analysis/CheckAnalysis.scala | 27 +- .../sql/catalyst/analysis/TypeCoercion.scala | 79 ++++- .../sql/catalyst/expressions/predicates.scala | 37 ++- .../sql/catalyst/expressions/subquery.scala | 250 ++++++++++----- .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../sql/catalyst/optimizer/subquery.scala | 161 +++++++++- .../analysis/AnalysisErrorSuite.scala | 11 +- .../analysis/ResolveSubquerySuite.scala | 2 +- .../spark/sql/catalyst/plans/PlanTest.scala | 2 - .../apache/spark/sql/execution/subquery.scala | 3 - .../invalid-correlation.sql.out | 4 +- .../org/apache/spark/sql/SubquerySuite.scala | 7 +- 13 files changed, 642 insertions(+), 235 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 93666f14958e..5b26f5ea5b00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -21,7 +21,7 @@ import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf, TableIdentifier} +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ @@ -162,6 +162,8 @@ class Analyzer( FixNullability), Batch("ResolveTimeZone", Once, ResolveTimeZone), + Batch("Subquery", Once, + UpdateOuterReferences), Batch("Cleanup", fixedPoint, CleanupAliases) ) @@ -710,13 +712,85 @@ class Analyzer( } transformUp { case other => other transformExpressions { case a: Attribute => - attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier) + dedupAttr(a, attributeRewrites) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } newRight } } + private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { + attrMap.get(attr).getOrElse(attr).withQualifier(attr.qualifier) + } + + /** + * The outer plan may have been de-duplicated and the function below updates the + * outer references to refer to the de-duplicated attributes. + * + * For example (SQL): + * {{{ + * SELECT * FROM t1 + * WHERE EXISTS (SELECT 1 + * FROM t2 + * WHERE t1.c1 = t2.c1) + * INTERSECT + * SELECT * FROM t1 + * WHERE EXISTS (SELECT 1 + * FROM t2 + * WHERE t1.c1 = t2.c1) + * }}} + * Plan before resolveReference rule. + * 'Intersect + * :- 'Project [*] + * : +- Filter exists#271 [c1#250] + * : : +- Project [1 AS 1#295] + * : : +- Filter (outer(c1#250) = c1#263) + * : : +- SubqueryAlias t2 + * : : +- Relation[c1#263,c2#264] parquet + * : +- SubqueryAlias t1 + * : +- Relation[c1#250,c2#251] parquet + * +- 'Project [*] + * +- Filter exists#272 [c1#250] + * : +- Project [1 AS 1#298] + * : +- Filter (outer(c1#250) = c1#263) + * : +- SubqueryAlias t2 + * : +- Relation[c1#263,c2#264] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#250,c2#251] parquet + * Plan after the resolveReference rule. + * Intersect + * :- Project [c1#250, c2#251] + * : +- Filter exists#271 [c1#250] + * : : +- Project [1 AS 1#295] + * : : +- Filter (outer(c1#250) = c1#263) + * : : +- SubqueryAlias t2 + * : : +- Relation[c1#263,c2#264] parquet + * : +- SubqueryAlias t1 + * : +- Relation[c1#250,c2#251] parquet + * +- Project [c1#299, c2#300] + * +- Filter exists#272 [c1#299] + * : +- Project [1 AS 1#298] + * : +- Filter (outer(c1#299) = c1#263) ==> Updated + * : +- SubqueryAlias t2 + * : +- Relation[c1#263,c2#264] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#299,c2#300] parquet => Outer plan's attributes are de-duplicated. + */ + private def dedupOuterReferencesInSubquery( + plan: LogicalPlan, + attrMap: AttributeMap[Attribute]): LogicalPlan = { + plan transformDown { case currentFragment => + currentFragment transformExpressions { + case OuterReference(a: Attribute) => + OuterReference(dedupAttr(a, attrMap)) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap)) + } + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -1132,31 +1206,24 @@ class Analyzer( } /** - * Pull out all (outer) correlated predicates from a given subquery. This method removes the - * correlated predicates from subquery [[Filter]]s and adds the references of these predicates - * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to - * be able to evaluate the predicates at the top level. - * - * This method returns the rewritten subquery and correlated predicates. + * Validates to make sure the outer references appearing inside the subquery + * are legal. This function also returns the list of expressions + * that contain outer references. These outer references would be kept as children + * of subquery expressions by the caller of this function. */ - private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = { - val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] + private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = { + val outerReferences = scala.collection.mutable.ArrayBuffer.empty[Seq[Expression]] // Make sure a plan's subtree does not contain outer references def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { - if (p.collectFirst(predicateMap).nonEmpty) { + if (p.find(SubExprUtils.getOuterReferences(_).nonEmpty).nonEmpty) { failAnalysis(s"Accessing outer query column is not allowed in:\n$p") } } - // Helper function for locating outer references. - def containsOuter(e: Expression): Boolean = { - e.find(_.isInstanceOf[OuterReference]).isDefined - } - // Make sure a plan's expressions do not contain outer references def failOnOuterReference(p: LogicalPlan): Unit = { - if (p.expressions.exists(containsOuter)) { + if (p.expressions.exists(SubExprUtils.containsOuter)) { failAnalysis( "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + s"clauses:\n$p") @@ -1194,19 +1261,10 @@ class Analyzer( } } - /** Determine which correlated predicate references are missing from this plan. */ - def missingReferences(p: LogicalPlan): AttributeSet = { - val localPredicateReferences = p.collect(predicateMap) - .flatten - .map(_.references) - .reduceOption(_ ++ _) - .getOrElse(AttributeSet.empty) - localPredicateReferences -- p.outputSet - } - var foundNonEqualCorrelatedPred : Boolean = false - // Simplify the predicates before pulling them out. + // Simplify the predicates before validating any unsupported correlation patterns + // in the plan. val transformed = BooleanSimplification(sub) transformUp { // Whitelist operators allowed in a correlated subquery @@ -1255,37 +1313,22 @@ class Analyzer( // The other operator is Join. Filter can be anywhere in a correlated subquery. case f @ Filter(cond, child) => // Find all predicates with an outer reference. - val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) + val (correlated, local) = + splitConjunctivePredicates(cond).partition(SubExprUtils.containsOuter) // Find any non-equality correlated predicates foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { case _: EqualTo | _: EqualNullSafe => false case _ => true } - - // Rewrite the filter without the correlated predicates if any. - correlated match { - case Nil => f - case xs if local.nonEmpty => - val newFilter = Filter(local.reduce(And), child) - predicateMap += newFilter -> xs - newFilter - case xs => - predicateMap += child -> xs - child - } + outerReferences += SubExprUtils.getOuterReferences(correlated) + f // Project cannot host any correlated expressions // but can be anywhere in a correlated subquery. case p @ Project(expressions, child) => failOnOuterReference(p) - - val referencesToAdd = missingReferences(p) - if (referencesToAdd.nonEmpty) { - Project(expressions ++ referencesToAdd, child) - } else { - p - } + p // Aggregate cannot host any correlated expressions // It can be on a correlation path if the correlation contains @@ -1295,13 +1338,7 @@ class Analyzer( case a @ Aggregate(grouping, expressions, child) => failOnOuterReference(a) failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) - - val referencesToAdd = missingReferences(a) - if (referencesToAdd.nonEmpty) { - Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) - } else { - a - } + a // Join can host correlated expressions. case j @ Join(left, right, joinType, _) => @@ -1352,52 +1389,16 @@ class Analyzer( failOnOuterReferenceInSubTree(p) p } - (transformed, predicateMap.values.flatten.toSeq) + outerReferences.flatten } /** - * Rewrite the subquery in a safe way by preventing that the subquery and the outer use the same - * attributes. - */ - private def rewriteSubQuery( - sub: LogicalPlan, - outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { - // Pull out the tagged predicates and rewrite the subquery in the process. - val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub) - - // Make sure the inner and the outer query attributes do not collide. - val outputSet = outer.map(_.outputSet).reduce(_ ++ _) - val duplicates = basePlan.outputSet.intersect(outputSet) - val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { - val aliasMap = AttributeMap(duplicates.map { dup => - dup -> Alias(dup, dup.toString)() - }.toSeq) - val aliasedExpressions = basePlan.output.map { ref => - aliasMap.getOrElse(ref, ref) - } - val aliasedProjection = Project(aliasedExpressions, basePlan) - val aliasedConditions = baseConditions.map(_.transform { - case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute - }) - (aliasedProjection, aliasedConditions) - } else { - (basePlan, baseConditions) - } - // Remove outer references from the correlated predicates. We wait with extracting - // these until collisions between the inner and outer query attributes have been - // solved. - val conditions = deDuplicatedConditions.map(_.transform { - case OuterReference(ref) => ref - }) - (plan, conditions) - } - - /** - * Resolve and rewrite a subquery. The subquery is resolved using its outer plans. This method + * Resolves the subquery. The subquery is resolved using its outer plans. This method * will resolve the subquery by alternating between the regular analyzer and by applying the * resolveOuterReferences rule. * - * All correlated conditions are pulled out of the subquery as soon as the subquery is resolved. + * Outer references from the correlated predicates are updated as children of + * Subquery expression. */ private def resolveSubQuery( e: SubqueryExpression, @@ -1420,7 +1421,7 @@ class Analyzer( } } while (!current.resolved && !current.fastEquals(previous)) - // Step 2: Pull out the predicates if the plan is resolved. + // Step 2: pull the outer references and record them as children of SubqueryExpression if (current.resolved) { // Make sure the resolved query has the required number of output columns. This is only // needed for Scalar and IN subqueries. @@ -1428,34 +1429,38 @@ class Analyzer( failAnalysis(s"The number of columns in the subquery (${current.output.size}) " + s"does not match the required number of columns ($requiredColumns)") } - // Pullout predicates and construct a new plan. - f.tupled(rewriteSubQuery(current, plans)) + // Validate the outer reference and record the outer references as children of + // subquery expression. + f.tupled(current, checkAndGetOuterReferences(current)) + } else { e.withNewPlan(current) } } /** - * Resolve and rewrite all subqueries in a LogicalPlan. This method transforms IN and EXISTS - * expressions into PredicateSubquery expression once the are resolved. + * Resolves the subquery. Apart of resolving the subquery and outer references (if any) + * in the subquery plan, the children of subquery expression are updated to record the + * outer references. This is needed to make sure + * (1) The column(s) referred from the outer query are not pruned from the plan during + * optimization. + * (2) Any aggregate expression(s) that reference outer attributes are pushed down to + * outer plan to get evaluated. */ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) - case e @ Exists(sub, exprId) => - resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId)) - case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => + case e @ Exists(sub, _, exprId) if !sub.resolved => + resolveSubQuery(e, plans)(Exists(_, _, exprId)) + case In(e, Seq(l @ ListQuery(sub, _, exprId))) if e.resolved && !sub.resolved => // Get the left hand side expressions. val expressions = e match { case cns : CreateNamedStruct => cns.valExprs case expr => Seq(expr) } - resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => - // Construct the IN conditions. - val inConditions = expressions.zip(rewrite.output).map(EqualTo.tupled) - PredicateSubquery(rewrite, inConditions ++ conditions, nullAware = true, exprId) - } + val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId)) + In(e, Seq(expr)) } } @@ -2353,6 +2358,11 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => e.withTimeZone(conf.sessionLocalTimeZone) + // Casts could be added in the subquery plan through the rule TypeCoercion while coercing + // the types between the value expression and list query expression of IN expression. + // We need to subject the subquery plan through ResolveTimeZone again to setup timezone + // information for time zone aware expressions. + case e: ListQuery => e.withNewPlan(ResolveTimeZone.apply(e.plan)) } } } @@ -2533,3 +2543,67 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { CreateNamedStruct(children.toList) } } + +/** + * The aggregate expressions from subquery referencing outer query block are pushed + * down to the outer query block for evaluation. This rule below updates such outer references + * as AttributeReference referring attributes from the parent/outer query block. + * + * For example (SQL): + * {{{ + * SELECT l.a FROM l GROUP BY 1 HAVING EXISTS (SELECT 1 FROM r WHERE r.d < min(l.b)) + * }}} + * Plan before the rule. + * Project [a#226] + * +- Filter exists#245 [min(b#227)#249] + * : +- Project [1 AS 1#247] + * : +- Filter (d#238 < min(outer(b#227))) <----- + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] + * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] + * Plan after the rule. + * Project [a#226] + * +- Filter exists#245 [min(b#227)#249] + * : +- Project [1 AS 1#247] + * : +- Filter (d#238 < outer(min(b#227)#249)) <----- + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] + * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] + */ +object UpdateOuterReferences extends Rule[LogicalPlan] { + private def stripAlias(expr: Expression): Expression = expr match { case a: Alias => a.child } + + private def updateOuterReferenceInSubquery( + plan: LogicalPlan, + refExprs: Seq[Expression]): LogicalPlan = { + plan transformAllExpressions { case e => + val outerAlias = + refExprs.find(stripAlias(_).semanticEquals(SubExprUtils.stripOuterReference(e))) + outerAlias match { + case Some(a: Alias) => OuterReference(a.toAttribute) + case _ => e + } + } + } + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transform { + case f @ Filter(_, a: Aggregate) if f.resolved => + f transformExpressions { + case s: SubqueryExpression if s.children.nonEmpty => + // Collect the aliases from output of aggregate. + val outerAliases = a.aggregateExpressions collect { case a: Alias => a } + // Update the subquery plan to record the OuterReference to point to outer query plan. + s.withNewPlan(updateOuterReferenceInSubquery(s.plan, outerAliases)) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d32fbeb4e91e..2481ad18721b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -133,6 +133,8 @@ trait CheckAnalysis extends PredicateHelper { if (conditions.isEmpty && query.output.size != 1) { failAnalysis( s"Scalar subquery must return only one column, but got ${query.output.size}") + } + else if (conditions.nonEmpty) { } else if (conditions.nonEmpty) { // Collect the columns from the subquery for further checking. var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains) @@ -152,6 +154,11 @@ trait CheckAnalysis extends PredicateHelper { // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns // are not part of the correlated columns. val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) + // Collect the local references from the correlated predicate in the subquery. + val subqueryColumns = + SubExprUtils.getCorrelatedPredicates(query) + .flatMap(_.references) + .filterNot(conditions.flatMap(_.references).contains) val correlatedCols = AttributeSet(subqueryColumns) val invalidCols = groupByCols -- correlatedCols // GROUP BY columns must be a subset of columns in the predicates @@ -167,17 +174,7 @@ trait CheckAnalysis extends PredicateHelper { // For projects, do the necessary mapping and skip to its child. def cleanQuery(p: LogicalPlan): LogicalPlan = p match { case s: SubqueryAlias => cleanQuery(s.child) - case p: Project => - // SPARK-18814: Map any aliases to their AttributeReference children - // for the checking in the Aggregate operators below this Project. - subqueryColumns = subqueryColumns.map { - xs => p.projectList.collectFirst { - case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId => - child - }.getOrElse(xs) - } - - cleanQuery(p.child) + case p: Project => cleanQuery(p.child) case child => child } @@ -213,8 +210,10 @@ trait CheckAnalysis extends PredicateHelper { case Filter(condition, _) => splitConjunctivePredicates(condition).foreach { - case _: PredicateSubquery | Not(_: PredicateSubquery) => - case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) => + case _: Exists | Not(_: Exists) => + case In(_, Seq(_: ListQuery)) => + case Not(In(_, Seq(_: ListQuery))) => + case e if SubExprUtils.hasNullAwarePredicateWithinNot(e) => failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" + s" conditions: $e") case e => @@ -306,7 +305,7 @@ trait CheckAnalysis extends PredicateHelper { s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p") } - case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) => + case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") case _: Union | _: SetOperation if operator.children.length > 1 => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2c00957bd6af..4ab2e65c7cc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -108,6 +108,26 @@ object TypeCoercion { case _ => None } + /** + * This function determines the target type of a comparison operator when one operand + * is a String and the other is not. It also handles when one op is a Date and the + * other is a Timestamp by making the target type to be String. Currently this is used + * to coerce types between LHS and RHS of the IN expression. + */ + val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = { + case (StringType, DateType) => Some(StringType) + case (DateType, StringType) => Some(StringType) + case (StringType, TimestampType) => Some(StringType) + case (TimestampType, StringType) => Some(StringType) + case (TimestampType, DateType) => Some(StringType) + case (DateType, TimestampType) => Some(StringType) + case (StringType, NullType) => Some(StringType) + case (NullType, StringType) => Some(StringType) + case (l: StringType, r: AtomicType) if r != StringType => Some(r) + case (l: AtomicType, r: StringType) if (l != StringType) => Some(l) + case (l, r) => findTightestCommonType(l, r) + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * @@ -365,17 +385,66 @@ object TypeCoercion { } /** - * Convert the value and in list expressions to the common operator type - * by looking at all the argument types and finding the closest one that - * all the arguments can be cast to. When no common operator type is found - * the original expression will be returned and an Analysis Exception will - * be raised at type checking phase. + * Handles type coercion for both IN expression with subquery and IN + * expressions without subquery. + * 1. In the first case, find the common type by comparing the left hand side + * expression types against corresponding right hand side expression derived + * from the subquery expression's plan output. Inject appropriate casts in the + * LHS and RHS side of IN expression. + * + * 2. In the second case, convert the value and in list expressions to the + * common operator type by looking at all the argument types and finding + * the closest one that all the arguments can be cast to. When no common + * operator type is found the original expression will be returned and an + * Analysis Exception will be raised at the type checking phase. */ object InConversion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + // Handle type casting required between value expression and subquery output + // in IN subquery. + case i @ In(a, Seq(ListQuery(sub, children, exprId))) if !i.resolved => + // lhs is the value expression of IN subquery. + val lhs = a match { + // Multi columns in IN clause is represented as a CreateNamedStruct. + // flatten the named struct to get the list of expressions. + case cns: CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } + + // rhs is the subquery output. + val rhs = sub.output + require(lhs.length == rhs.length) + + val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => + findCommonTypeForBinaryComparison(l.dataType, r.dataType) + } + + if (commonTypes.length == lhs.length) { + val castedRhs = rhs.zip(commonTypes).map { + case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() + case (e, _) => e + } + val castedLhs = lhs.zip(commonTypes).map { + case (e, dt) if e.dataType != dt => Cast(e, dt) + case (e, _) => e + } + + // Before constructing the In expression, wrap the multi values in lhs + // in a CreatedNamedStruct. + val newLhs = a match { + case cns: CreateNamedStruct => + val nameValue = cns.nameExprs.zip(castedLhs).flatMap(pair => Seq(pair._1, pair._2)) + CreateNamedStruct(nameValue) + case _ => castedLhs.head + } + In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId))) + } else { + i + } + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ac56ff13fa5b..0ab62c8d701c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -123,19 +123,36 @@ case class Not(child: Expression) */ @ExpressionDescription( usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.") -case class In(value: Expression, list: Seq[Expression]) extends Predicate - with ImplicitCastInputTypes { +case class In(value: Expression, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") - - override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType) - override def checkInputDataTypes(): TypeCheckResult = { - if (list.exists(l => l.dataType != value.dataType)) { - TypeCheckResult.TypeCheckFailure( - "Arguments must be same type") - } else { - TypeCheckResult.TypeCheckSuccess + list match { + case ListQuery(sub, _, _) :: Nil => + val valExprs = value match { + case cns: CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } + val isTypeMismatched = valExprs.zip(sub.output).exists { + case (l, r) => l.dataType != r.dataType + } + if (isTypeMismatched) { + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the LHS of an IN subquery + |[${valExprs.map(_.dataType).mkString(", ")}] + |is not compatible with the data type of the output of the subquery + |[${sub.output.map(_.dataType).mkString(", ")}]. + """.stripMargin) + } else { + TypeCheckResult.TypeCheckSuccess + } + case _ => + if (list.exists(l => l.dataType != value.dataType)) { + TypeCheckResult.TypeCheckFailure("Arguments must be same type") + } else { + TypeCheckResult.TypeCheckSuccess + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index e2e7d98e3345..423c95d07c9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} import org.apache.spark.sql.types._ /** @@ -40,19 +42,179 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression { /** * A base interface for expressions that contain a [[LogicalPlan]]. */ -abstract class SubqueryExpression extends PlanExpression[LogicalPlan] { +abstract class SubqueryExpression( + plan: LogicalPlan, + children: Seq[Expression], + exprId: ExprId) extends PlanExpression[LogicalPlan] { + + override lazy val resolved: Boolean = childrenResolved && plan.resolved + override lazy val references: AttributeSet = + if (plan.resolved) super.references -- plan.outputSet else super.references override def withNewPlan(plan: LogicalPlan): SubqueryExpression + override def semanticEquals(o: Expression): Boolean = o match { + case p: SubqueryExpression => + this.getClass.getName.equals(p.getClass.getName) && plan.sameResult(p.plan) && + children.length == p.children.length && + children.zip(p.children).forall(p => p._1.semanticEquals(p._2)) + case _ => false + } } object SubqueryExpression { + /** + * Returns true when an expression contains an IN or EXISTS subquery and false otherwise. + */ + def hasInOrExistsSubquery(e: Expression): Boolean = { + e.find { + case _: ListQuery | _: Exists => true + case _ => false + }.isDefined + } + + /** + * Returns true when an expression contains a subquery that has outer reference(s). The outer + * reference attributes are kept as children of subquery expression by + * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveSubquery]] + */ def hasCorrelatedSubquery(e: Expression): Boolean = { e.find { - case e: SubqueryExpression if e.children.nonEmpty => true + case s: SubqueryExpression if s.children.nonEmpty => true case _ => false }.isDefined } } +object SubExprUtils extends PredicateHelper { + /** + * Returns true when an expression contains correlated predicates i.e outer references and + * returns false otherwise. + */ + def containsOuter(e: Expression): Boolean = { + e.find(_.isInstanceOf[OuterReference]).isDefined + } + + /** + * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could + * turn the null-aware predicate into not-null-aware predicate. + */ + def hasNullAwarePredicateWithinNot(e: Expression): Boolean = { + e.find{ x => + x.isInstanceOf[Not] && e.find { + case In(_, Seq(_: ListQuery)) => true + case _ => false + }.isDefined + }.isDefined + } + + /** + * Returns an expression after removing the OuterReference shell. + */ + def stripOuterReference(e: Expression): Expression = { + e.transform { + case OuterReference(r) => r + } + } + + /** + * Returns the list of expressions after removing the OuterReference shell from each of + * the expression. + */ + def stripOuterReferences(e: Seq[Expression]): Seq[Expression] = e.map(stripOuterReference) + + /** + * Returns the logical plan after removing the OuterReference shell from all the expressions + * of the input logical plan. + */ + def stripOuterReferences(p: LogicalPlan): LogicalPlan = { + p.transformAllExpressions { + case OuterReference(a) => a + } + } + + /** + * Given a list of expressions, returns the expressions which have outer references. Aggregate + * expressions are treated in a special way. If the children of aggregate expression contains an + * outer reference, then the entire aggregate expression is marked as an outer reference. + * Example (SQL): + * {{{ + * SELECT a FROM l GROUP by 1 HAVING EXISTS (SELECT 1 FROM r WHERE d < min(b)) + * }}} + * In the above case, we want to mark the entire min(b) as an outer reference + * OuterReference(min(b)) instead of min(OuterReference(b)). + * TODO: Currently we don't allow deep correlation. Also, we don't allow mixing of + * outer references and local references under an aggregate expression. + * For example (SQL): + * {{{ + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a + p2.b) = sq.c)) + * + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a) + max(p2.b) = sq.c)) + * + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a + sq.c) > 1)) + * }}} + * The code below needs to change when we support the above cases. + */ + def getOuterReferences(conditions: Seq[Expression]): Seq[Expression] = { + val outerExpressions = scala.collection.mutable.ArrayBuffer.empty[Expression] + conditions foreach { expr => + expr transformDown { + case a: AggregateExpression if containsOuter(a) => + val newExpr = stripOuterReference(a) + outerExpressions += newExpr + newExpr + case OuterReference(e) => + outerExpressions += e + e + } + } + outerExpressions + } + + /** + * Returns all the expressions that have outer references from a logical plan. Currently only + * Filter operator can host outer references. + */ + def getOuterReferences(plan: LogicalPlan): Seq[Expression] = { + val conditions = BooleanSimplification(plan) collect { case Filter(cond, _) => cond } + getOuterReferences(conditions) + } + + /** + * Returns the correlated predicates from a logical plan. The OuterReference wrapper + * is removed before returning the predicate to the caller. + */ + def getCorrelatedPredicates(plan: LogicalPlan): Seq[Expression] = { + val correlatedPredicates = scala.collection.mutable.ArrayBuffer.empty[Seq[Expression]] + val conditions = BooleanSimplification(plan) collect { case Filter(cond, _) => cond } + + // Collect all the expressions that have outer references. + conditions foreach { e => + val (corr, _) = splitConjunctivePredicates(e).partition(containsOuter) + val correlated = stripOuterReferences(corr) + correlated match { + case Nil => + case xs => + correlatedPredicates += xs + } + } + correlatedPredicates.flatten + } +} + /** * A subquery that will return only one row and one column. This will be converted into a physical * scalar subquery during planning. @@ -63,14 +225,8 @@ case class ScalarSubquery( plan: LogicalPlan, children: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Unevaluable { - override lazy val resolved: Boolean = childrenResolved && plan.resolved - override lazy val references: AttributeSet = { - if (plan.resolved) super.references -- plan.outputSet - else super.references - } + extends SubqueryExpression(plan, children, exprId) with Unevaluable { override def dataType: DataType = plan.schema.fields.head.dataType - override def foldable: Boolean = false override def nullable: Boolean = true override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan) override def toString: String = s"scalar-subquery#${exprId.id} $conditionString" @@ -79,59 +235,12 @@ case class ScalarSubquery( object ScalarSubquery { def hasCorrelatedScalarSubquery(e: Expression): Boolean = { e.find { - case e: ScalarSubquery if e.children.nonEmpty => true + case s: ScalarSubquery if s.children.nonEmpty => true case _ => false }.isDefined } } -/** - * A predicate subquery checks the existence of a value in a sub-query. We currently only allow - * [[PredicateSubquery]] expressions within a Filter plan (i.e. WHERE or a HAVING clause). This will - * be rewritten into a left semi/anti join during analysis. - */ -case class PredicateSubquery( - plan: LogicalPlan, - children: Seq[Expression] = Seq.empty, - nullAware: Boolean = false, - exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Predicate with Unevaluable { - override lazy val resolved = childrenResolved && plan.resolved - override lazy val references: AttributeSet = super.references -- plan.outputSet - override def nullable: Boolean = nullAware - override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(plan = plan) - override def semanticEquals(o: Expression): Boolean = o match { - case p: PredicateSubquery => - plan.sameResult(p.plan) && nullAware == p.nullAware && - children.length == p.children.length && - children.zip(p.children).forall(p => p._1.semanticEquals(p._2)) - case _ => false - } - override def toString: String = s"predicate-subquery#${exprId.id} $conditionString" -} - -object PredicateSubquery { - def hasPredicateSubquery(e: Expression): Boolean = { - e.find { - case _: PredicateSubquery | _: ListQuery | _: Exists => true - case _ => false - }.isDefined - } - - /** - * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could - * turn the null-aware predicate into not-null-aware predicate. - */ - def hasNullAwarePredicateWithinNot(e: Expression): Boolean = { - e.find{ x => - x.isInstanceOf[Not] && e.find { - case p: PredicateSubquery => p.nullAware - case _ => false - }.isDefined - }.isDefined - } -} - /** * A [[ListQuery]] expression defines the query which we want to search in an IN subquery * expression. It should and can only be used in conjunction with an IN expression. @@ -144,18 +253,20 @@ object PredicateSubquery { * FROM b) * }}} */ -case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Unevaluable { - override lazy val resolved = false - override def children: Seq[Expression] = Seq.empty - override def dataType: DataType = ArrayType(NullType) +case class ListQuery( + plan: LogicalPlan, + children: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, children, exprId) with Unevaluable { + override def dataType: DataType = plan.schema.fields.head.dataType override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan) - override def toString: String = s"list#${exprId.id}" + override def toString: String = s"list#${exprId.id} $conditionString" } /** * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition. + * * For example (SQL): * {{{ * SELECT * @@ -165,11 +276,12 @@ case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExpr * WHERE b.id = a.id) * }}} */ -case class Exists(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Predicate with Unevaluable { - override lazy val resolved = false - override def children: Seq[Expression] = Seq.empty +case class Exists( + plan: LogicalPlan, + children: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable { override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) - override def toString: String = s"exists#${exprId.id}" + override def toString: String = s"exists#${exprId.id} $conditionString" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index caafa1c134cd..e9dbded3d4d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -68,6 +68,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: + Batch("Pullup Correlated Expressions", Once, + PullupCorrelatedPredicates) :: Batch("Subquery", Once, OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, @@ -885,7 +887,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { private def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean = { val attributes = plan.outputSet val matched = condition.find { - case PredicateSubquery(p, _, _, _) => p.outputSet.intersect(attributes).nonEmpty + case s: SubqueryExpression => s.plan.outputSet.intersect(attributes).nonEmpty case _ => false } matched.isEmpty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index fb7ce6aecea5..f3820f298092 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -41,10 +41,17 @@ import org.apache.spark.sql.types._ * condition. */ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { + private def getValueExpression(e: Expression): Seq[Expression] = { + e match { + case cns : CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Filter(condition, child) => val (withSubquery, withoutSubquery) = - splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery) + splitConjunctivePredicates(condition).partition(SubqueryExpression.hasInOrExistsSubquery) // Construct the pruned filter condition. val newFilter: LogicalPlan = withoutSubquery match { @@ -54,20 +61,25 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Filter the plan by applying left semi and left anti joins. withSubquery.foldLeft(newFilter) { - case (p, PredicateSubquery(sub, conditions, _, _)) => + case (p, Exists(sub, conditions, _)) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftSemi, joinCond) - case (p, Not(PredicateSubquery(sub, conditions, false, _))) => + case (p, Not(Exists(sub, conditions, _))) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftAnti, joinCond) - case (p, Not(PredicateSubquery(sub, conditions, true, _))) => + case (p, In(e, Seq(l @ ListQuery(sub, conditions, _)))) => + val inConditions = getValueExpression(e).zip(sub.output).map(EqualTo.tupled) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) + Join(outerPlan, sub, LeftSemi, joinCond) + case (p, Not(In(e, Seq(l @ ListQuery(sub, conditions, _))))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. - val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) + val inConditions = getValueExpression(e).zip(sub.output).map(EqualTo.tupled) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: // (a1,b1,...) = (a2,b2,...) @@ -83,11 +95,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } /** - * Given a predicate expression and an input plan, it rewrites - * any embedded existential sub-query into an existential join. - * It returns the rewritten expression together with the updated plan. - * Currently, it does not support null-aware joins. Embedded NOT IN predicates - * are blocked in the Analyzer. + * Given a predicate expression and an input plan, it rewrites any embedded existential sub-query + * into an existential join. It returns the rewritten expression together with the updated plan. + * Currently, it does not support NOT IN nested inside a NOT expression. This case is blocked in + * the Analyzer. */ private def rewriteExistentialExpr( exprs: Seq[Expression], @@ -95,17 +106,139 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { var newPlan = plan val newExprs = exprs.map { e => e transformUp { - case PredicateSubquery(sub, conditions, nullAware, _) => - // TODO: support null-aware join + case Exists(sub, conditions, exprId) => val exists = AttributeReference("exists", BooleanType, nullable = false)() - newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)) + newPlan = Join(newPlan, sub, + ExistenceJoin(exists), conditions.reduceLeftOption(And)) exists - } + case In(e, Seq(l@ ListQuery(sub, conditions, exprId))) => + val exists = AttributeReference("exists", BooleanType, nullable = false)() + val inConditions = getValueExpression(e).zip(sub.output).map(EqualTo.tupled) + newPlan = Join(newPlan, sub, + ExistenceJoin(exists), (inConditions ++ conditions).reduceLeftOption(And)) + exists + } } (newExprs.reduceOption(And), newPlan) } } + /** + * Pull out all (outer) correlated predicates from a given subquery. This method removes the + * correlated predicates from subquery [[Filter]]s and adds the references of these predicates + * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to + * be able to evaluate the predicates at the top level. + * + * TODO: Look to merge this rule with RewritePredicateSubquery. + */ +object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper { + /** + * Returns the correlated predicates and a updated plan that removes the outer references. + */ + private def pullOutCorrelatedPredicates( + sub: LogicalPlan, + outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] + + /** Determine which correlated predicate references are missing from this plan. */ + def missingReferences(p: LogicalPlan): AttributeSet = { + val localPredicateReferences = p.collect(predicateMap) + .flatten + .map(_.references) + .reduceOption(_ ++ _) + .getOrElse(AttributeSet.empty) + localPredicateReferences -- p.outputSet + } + + // Simplify the predicates before pulling them out. + val transformed = BooleanSimplification(sub) transformUp { + case f @ Filter(cond, child) => + val (correlated, local) = + splitConjunctivePredicates(cond).partition(SubExprUtils.containsOuter) + + // Rewrite the filter without the correlated predicates if any. + correlated match { + case Nil => f + case xs if local.nonEmpty => + val newFilter = Filter(local.reduce(And), child) + predicateMap += newFilter -> xs + newFilter + case xs => + predicateMap += child -> xs + child + } + case p @ Project(expressions, child) => + val referencesToAdd = missingReferences(p) + if (referencesToAdd.nonEmpty) { + Project(expressions ++ referencesToAdd, child) + } else { + p + } + case a @ Aggregate(grouping, expressions, child) => + val referencesToAdd = missingReferences(a) + if (referencesToAdd.nonEmpty) { + Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) + } else { + a + } + case p => + p + } + + // Make sure the inner and the outer query attributes do not collide. + // In case of a collision, change the subquery plan's output to use + // different attribute by creating alias(s). + val baseConditions = predicateMap.values.flatten.toSeq + val (newplan: LogicalPlan, newcond: Seq[Expression]) = if (outer.nonEmpty) { + val outputSet = outer.map(_.outputSet).reduce(_ ++ _) + val duplicates = transformed.outputSet.intersect(outputSet) + val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { + val aliasMap = AttributeMap(duplicates.map { dup => + dup -> Alias(dup, dup.toString)() + }.toSeq) + val aliasedExpressions = transformed.output.map { ref => + aliasMap.getOrElse(ref, ref) + } + val aliasedProjection = Project(aliasedExpressions, transformed) + val aliasedConditions = baseConditions.map(_.transform { + case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute + }) + (aliasedProjection, aliasedConditions) + } else { + (transformed, baseConditions) + } + (plan, SubExprUtils.stripOuterReferences(deDuplicatedConditions)) + } else { + (transformed, SubExprUtils.stripOuterReferences(baseConditions)) + } + (newplan, newcond) + } + + private def rewriteSubQueries(plan: LogicalPlan, outerPlans: Seq[LogicalPlan]): LogicalPlan = { + plan transformExpressions { + case s @ ScalarSubquery(sub, cond, exprId) if s.children.nonEmpty => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + ScalarSubquery(newPlan, newCond, exprId) + case e @ Exists(sub, cond, exprId) if e.children.nonEmpty => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + Exists(newPlan, newCond, exprId) + case l @ ListQuery(sub, cond, exprId) => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + ListQuery(newPlan, newCond, exprId) + } + } + + /** + * Pull up the correlated predicates and rewrite all subqueries in an operator tree.. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case f @ Filter(_, a: Aggregate) => + rewriteSubQueries(f, Seq(a, a.child)) + // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. + case q: UnaryNode => + rewriteSubQueries(q, q.children) + } +} /** * This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index c5e877d12811..d2ebca5a83dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -530,7 +530,7 @@ class AnalysisErrorSuite extends AnalysisTest { Exists( Join( LocalRelation(b), - Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LeftOuter, Option(EqualTo(b, c)))), LocalRelation(a)) @@ -539,7 +539,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan2 = Filter( Exists( Join( - Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LocalRelation(b), RightOuter, Option(EqualTo(b, c)))), @@ -547,14 +547,15 @@ class AnalysisErrorSuite extends AnalysisTest { assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil) val plan3 = Filter( - Exists(Union(LocalRelation(b), Filter(EqualTo(OuterReference(a), c), LocalRelation(c)))), + Exists(Union(LocalRelation(b), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)))), LocalRelation(a)) assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil) val plan4 = Filter( Exists( Limit(1, - Filter(EqualTo(OuterReference(a), b), LocalRelation(b))) + Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b))) ), LocalRelation(a)) assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil) @@ -562,7 +563,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan5 = Filter( Exists( Sample(0.0, 0.5, false, 1L, - Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b) + Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))().select('b) ), LocalRelation(a)) assertAnalysisError(plan5, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 4aafb2b83fb6..55693121431a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -33,7 +33,7 @@ class ResolveSubquerySuite extends AnalysisTest { val t2 = LocalRelation(b) test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { - val expr = Filter(In(a, Seq(ListQuery(Project(Seq(OuterReference(a)), t2)))), t1) + val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.ResolveSubquery(expr) }.getMessage diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index e9b7a0c6ad67..5eb31413ad70 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -43,8 +43,6 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { e.copy(exprId = ExprId(0)) case l: ListQuery => l.copy(exprId = ExprId(0)) - case p: PredicateSubquery => - p.copy(exprId = ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 730ca27f82ba..58be2d1da281 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -144,9 +144,6 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] { ScalarSubquery( SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan), subquery.exprId) - case expressions.PredicateSubquery(query, Seq(e: Expression), _, exprId) => - val executedPlan = new QueryExecution(sparkSession, query).executedPlan - InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId) } } } diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index 50ae01e181bc..f7bbb35aad6c 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -46,7 +46,7 @@ and t2b = (select max(avg) struct<> -- !query 3 output org.apache.spark.sql.AnalysisException -expression 't2.`t2b`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; +grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate function. Wrap '(avg(CAST(t2.`t2b` AS BIGINT)) AS `avg`)' in windowing function(s) or wrap 't2.`t2b`' in first() (or first_value) if you don't care which value you get.; -- !query 4 @@ -63,4 +63,4 @@ where t1a in (select min(t2a) struct<> -- !query 4 output org.apache.spark.sql.AnalysisException -resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter predicate-subquery#x [(t2c#x = max(t3c)#x) && (t3b#x > t2b#x)]; +resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 25dbecb5894e..6f1cd49c08ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -622,7 +622,12 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("SPARK-15370: COUNT bug with attribute ref in subquery input and output ") { checkAnswer( - sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l"), + sql( + """ + |select l.b, (select (r.c + count(*)) is null + |from r + |where l.a = r.c group by r.c) from l + """.stripMargin), Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) } From 5f62a2cce095268f4cdeb607b3bcfb2d6782d3fd Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 28 Feb 2017 17:45:33 -0800 Subject: [PATCH 2/9] Review comments --- .../sql/catalyst/analysis/Analyzer.scala | 79 ++++++------------- .../sql/catalyst/analysis/CheckAnalysis.scala | 18 ++--- .../sql/catalyst/expressions/subquery.scala | 23 ++++-- 3 files changed, 48 insertions(+), 72 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5b26f5ea5b00..24e6cf93eb13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -732,51 +732,38 @@ class Analyzer( * For example (SQL): * {{{ * SELECT * FROM t1 - * WHERE EXISTS (SELECT 1 - * FROM t2 - * WHERE t1.c1 = t2.c1) * INTERSECT * SELECT * FROM t1 * WHERE EXISTS (SELECT 1 * FROM t2 * WHERE t1.c1 = t2.c1) * }}} - * Plan before resolveReference rule. + * Plan before resolveReference rule. * 'Intersect - * :- 'Project [*] - * : +- Filter exists#271 [c1#250] - * : : +- Project [1 AS 1#295] - * : : +- Filter (outer(c1#250) = c1#263) - * : : +- SubqueryAlias t2 - * : : +- Relation[c1#263,c2#264] parquet - * : +- SubqueryAlias t1 - * : +- Relation[c1#250,c2#251] parquet + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet * +- 'Project [*] - * +- Filter exists#272 [c1#250] - * : +- Project [1 AS 1#298] - * : +- Filter (outer(c1#250) = c1#263) - * : +- SubqueryAlias t2 - * : +- Relation[c1#263,c2#264] parquet - * +- SubqueryAlias t1 - * +- Relation[c1#250,c2#251] parquet + * +- Filter exists#257 [c1#245] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#245) = c1#251) + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#245,c2#246] parquet * Plan after the resolveReference rule. * Intersect - * :- Project [c1#250, c2#251] - * : +- Filter exists#271 [c1#250] - * : : +- Project [1 AS 1#295] - * : : +- Filter (outer(c1#250) = c1#263) - * : : +- SubqueryAlias t2 - * : : +- Relation[c1#263,c2#264] parquet - * : +- SubqueryAlias t1 - * : +- Relation[c1#250,c2#251] parquet - * +- Project [c1#299, c2#300] - * +- Filter exists#272 [c1#299] - * : +- Project [1 AS 1#298] - * : +- Filter (outer(c1#299) = c1#263) ==> Updated + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- Project [c1#259, c2#260] + * +- Filter exists#257 [c1#259] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#259) = c1#251) => Updated * : +- SubqueryAlias t2 - * : +- Relation[c1#263,c2#264] parquet + * : +- Relation[c1#251,c2#252] parquet * +- SubqueryAlias t1 - * +- Relation[c1#299,c2#300] parquet => Outer plan's attributes are de-duplicated. + * +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated. */ private def dedupOuterReferencesInSubquery( plan: LogicalPlan, @@ -1212,7 +1199,7 @@ class Analyzer( * of subquery expressions by the caller of this function. */ private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = { - val outerReferences = scala.collection.mutable.ArrayBuffer.empty[Seq[Expression]] + val outerReferences = ArrayBuffer.empty[Expression] // Make sure a plan's subtree does not contain outer references def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { @@ -1265,7 +1252,7 @@ class Analyzer( // Simplify the predicates before validating any unsupported correlation patterns // in the plan. - val transformed = BooleanSimplification(sub) transformUp { + BooleanSimplification(sub).foreachUp { // Whitelist operators allowed in a correlated subquery // There are 4 categories: @@ -1288,25 +1275,18 @@ class Analyzer( // Category 1: // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias case p: BroadcastHint => - p case p: Distinct => - p case p: LeafNode => - p case p: Repartition => - p case p: SubqueryAlias => - p // Category 2: // These operators can be anywhere in a correlated subquery. // so long as they do not host outer references in the operators. case p: Sort => failOnOuterReference(p) - p case p: RepartitionByExpression => failOnOuterReference(p) - p // Category 3: // Filter is one of the two operators allowed to host correlated expressions. @@ -1321,14 +1301,12 @@ class Analyzer( case _: EqualTo | _: EqualNullSafe => false case _ => true } - outerReferences += SubExprUtils.getOuterReferences(correlated) - f + outerReferences ++= SubExprUtils.getOuterReferences(correlated) // Project cannot host any correlated expressions // but can be anywhere in a correlated subquery. case p @ Project(expressions, child) => failOnOuterReference(p) - p // Aggregate cannot host any correlated expressions // It can be on a correlation path if the correlation contains @@ -1338,7 +1316,6 @@ class Analyzer( case a @ Aggregate(grouping, expressions, child) => failOnOuterReference(a) failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) - a // Join can host correlated expressions. case j @ Join(left, right, joinType, _) => @@ -1369,7 +1346,6 @@ class Analyzer( case _ => failOnOuterReferenceInSubTree(j) } - j // Generator with join=true, i.e., expressed with // LATERAL VIEW [OUTER], similar to inner join, @@ -1379,7 +1355,6 @@ class Analyzer( // Generator with join=false is treated as Category 4. case p @ Generate(generator, true, _, _, _, _) => failOnOuterReference(p) - p // Category 4: Any other operators not in the above 3 categories // cannot be on a correlation path, that is they are allowed only @@ -1387,9 +1362,8 @@ class Analyzer( // are not allowed to have any correlated expressions. case p => failOnOuterReferenceInSubTree(p) - p } - outerReferences.flatten + outerReferences } /** @@ -1431,8 +1405,7 @@ class Analyzer( } // Validate the outer reference and record the outer references as children of // subquery expression. - f.tupled(current, checkAndGetOuterReferences(current)) - + f(current, checkAndGetOuterReferences(current)) } else { e.withNewPlan(current) } @@ -2362,7 +2335,7 @@ class Analyzer( // the types between the value expression and list query expression of IN expression. // We need to subject the subquery plan through ResolveTimeZone again to setup timezone // information for time zone aware expressions. - case e: ListQuery => e.withNewPlan(ResolveTimeZone.apply(e.plan)) + case e: ListQuery => e.withNewPlan(apply(e.plan)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 2481ad18721b..de127ede8595 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -208,16 +208,9 @@ trait CheckAnalysis extends PredicateHelper { s"filter expression '${f.condition.sql}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case Filter(condition, _) => - splitConjunctivePredicates(condition).foreach { - case _: Exists | Not(_: Exists) => - case In(_, Seq(_: ListQuery)) => - case Not(In(_, Seq(_: ListQuery))) => - case e if SubExprUtils.hasNullAwarePredicateWithinNot(e) => - failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" + - s" conditions: $e") - case e => - } + case Filter(condition, _) if SubExprUtils.hasNullAwarePredicateWithinNot(condition) => + failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" + + s" conditions: $condition") case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( @@ -306,7 +299,10 @@ trait CheckAnalysis extends PredicateHelper { } case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) => - failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") + p match { + case _: Filter => // Ok + case other => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") + } case _: Union | _: SetOperation if operator.children.length > 1 => def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 423c95d07c9e..3a20ea8f4e4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans.QueryPlan @@ -97,13 +99,18 @@ object SubExprUtils extends PredicateHelper { * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could * turn the null-aware predicate into not-null-aware predicate. */ - def hasNullAwarePredicateWithinNot(e: Expression): Boolean = { - e.find{ x => - x.isInstanceOf[Not] && e.find { - case In(_, Seq(_: ListQuery)) => true - case _ => false + def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = { + splitConjunctivePredicates(condition).exists { + case _: Exists | Not(_: Exists) | In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => + false + case e => e.find { x => + x.isInstanceOf[Not] && e.find { + case In(_, Seq(_: ListQuery)) => true + case _ => false + }.isDefined }.isDefined - }.isDefined + } + } /** @@ -169,7 +176,7 @@ object SubExprUtils extends PredicateHelper { * The code below needs to change when we support the above cases. */ def getOuterReferences(conditions: Seq[Expression]): Seq[Expression] = { - val outerExpressions = scala.collection.mutable.ArrayBuffer.empty[Expression] + val outerExpressions = ArrayBuffer.empty[Expression] conditions foreach { expr => expr transformDown { case a: AggregateExpression if containsOuter(a) => @@ -198,7 +205,7 @@ object SubExprUtils extends PredicateHelper { * is removed before returning the predicate to the caller. */ def getCorrelatedPredicates(plan: LogicalPlan): Seq[Expression] = { - val correlatedPredicates = scala.collection.mutable.ArrayBuffer.empty[Seq[Expression]] + val correlatedPredicates = ArrayBuffer.empty[Seq[Expression]] val conditions = BooleanSimplification(plan) collect { case Filter(cond, _) => cond } // Collect all the expressions that have outer references. From 8a8a7afcf62d7826e56bc16287546cc1be82d687 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 1 Mar 2017 13:48:18 -0800 Subject: [PATCH 3/9] review comments --- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../spark/sql/catalyst/expressions/subquery.scala | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 24e6cf93eb13..5981470fb3ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1203,7 +1203,7 @@ class Analyzer( // Make sure a plan's subtree does not contain outer references def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { - if (p.find(SubExprUtils.getOuterReferences(_).nonEmpty).nonEmpty) { + if (SubExprUtils.hasOuterReferences(p)) { failAnalysis(s"Accessing outer query column is not allowed in:\n$p") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 3a20ea8f4e4f..567e40650c87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} import org.apache.spark.sql.types._ @@ -138,6 +137,16 @@ object SubExprUtils extends PredicateHelper { } } + /** + * Given a logical plan, returns TRUE if it has an outer reference and false otherwise. + */ + def hasOuterReferences(plan: LogicalPlan): Boolean = { + plan.find { + case f: Filter => containsOuter(f.condition) + case other => false + }.isDefined + } + /** * Given a list of expressions, returns the expressions which have outer references. Aggregate * expressions are treated in a special way. If the children of aggregate expression contains an @@ -196,7 +205,7 @@ object SubExprUtils extends PredicateHelper { * Filter operator can host outer references. */ def getOuterReferences(plan: LogicalPlan): Seq[Expression] = { - val conditions = BooleanSimplification(plan) collect { case Filter(cond, _) => cond } + val conditions = plan.collect { case Filter(cond, _) => cond } getOuterReferences(conditions) } @@ -206,7 +215,7 @@ object SubExprUtils extends PredicateHelper { */ def getCorrelatedPredicates(plan: LogicalPlan): Seq[Expression] = { val correlatedPredicates = ArrayBuffer.empty[Seq[Expression]] - val conditions = BooleanSimplification(plan) collect { case Filter(cond, _) => cond } + val conditions = plan.collect { case Filter(cond, _) => cond } // Collect all the expressions that have outer references. conditions foreach { e => From f0d2e7f752574d7ffb717b9b53c41231a09268cd Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 2 Mar 2017 01:18:40 -0800 Subject: [PATCH 4/9] Review comments --- .../sql/catalyst/analysis/Analyzer.scala | 23 +++++------ .../sql/catalyst/analysis/CheckAnalysis.scala | 13 +++---- .../sql/catalyst/analysis/TypeCoercion.scala | 16 +++++--- .../sql/catalyst/expressions/subquery.scala | 2 +- .../sql/catalyst/optimizer/subquery.scala | 38 +++++++++---------- 5 files changed, 46 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5981470fb3ae..ccd721c88556 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.NewInstance +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} @@ -1203,14 +1204,14 @@ class Analyzer( // Make sure a plan's subtree does not contain outer references def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { - if (SubExprUtils.hasOuterReferences(p)) { + if (hasOuterReferences(p)) { failAnalysis(s"Accessing outer query column is not allowed in:\n$p") } } // Make sure a plan's expressions do not contain outer references def failOnOuterReference(p: LogicalPlan): Unit = { - if (p.expressions.exists(SubExprUtils.containsOuter)) { + if (p.expressions.exists(containsOuter)) { failAnalysis( "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + s"clauses:\n$p") @@ -1274,11 +1275,7 @@ class Analyzer( // Category 1: // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias - case p: BroadcastHint => - case p: Distinct => - case p: LeafNode => - case p: Repartition => - case p: SubqueryAlias => + case _: BroadcastHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => // Category 2: // These operators can be anywhere in a correlated subquery. @@ -1294,14 +1291,14 @@ class Analyzer( case f @ Filter(cond, child) => // Find all predicates with an outer reference. val (correlated, local) = - splitConjunctivePredicates(cond).partition(SubExprUtils.containsOuter) + splitConjunctivePredicates(cond).partition(containsOuter) // Find any non-equality correlated predicates foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { case _: EqualTo | _: EqualNullSafe => false case _ => true } - outerReferences ++= SubExprUtils.getOuterReferences(correlated) + outerReferences ++= getOuterReferences(correlated) // Project cannot host any correlated expressions // but can be anywhere in a correlated subquery. @@ -1426,14 +1423,14 @@ class Analyzer( resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) - case In(e, Seq(l @ ListQuery(sub, _, exprId))) if e.resolved && !sub.resolved => + case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved => // Get the left hand side expressions. - val expressions = e match { + val expressions = value match { case cns : CreateNamedStruct => cns.valExprs case expr => Seq(expr) } val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId)) - In(e, Seq(expr)) + In(value, Seq(expr)) } } @@ -2559,7 +2556,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { refExprs: Seq[Expression]): LogicalPlan = { plan transformAllExpressions { case e => val outerAlias = - refExprs.find(stripAlias(_).semanticEquals(SubExprUtils.stripOuterReference(e))) + refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e))) outerAlias match { case Some(a: Alias) => OuterReference(a.toAttribute) case _ => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index de127ede8595..d2ed35a66e68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -155,10 +156,8 @@ trait CheckAnalysis extends PredicateHelper { // are not part of the correlated columns. val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) // Collect the local references from the correlated predicate in the subquery. - val subqueryColumns = - SubExprUtils.getCorrelatedPredicates(query) - .flatMap(_.references) - .filterNot(conditions.flatMap(_.references).contains) + val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) + .filterNot(conditions.flatMap(_.references).contains) val correlatedCols = AttributeSet(subqueryColumns) val invalidCols = groupByCols -- correlatedCols // GROUP BY columns must be a subset of columns in the predicates @@ -208,9 +207,9 @@ trait CheckAnalysis extends PredicateHelper { s"filter expression '${f.condition.sql}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case Filter(condition, _) if SubExprUtils.hasNullAwarePredicateWithinNot(condition) => + case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) => failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" + - s" conditions: $condition") + " conditions: $condition") case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( @@ -301,7 +300,7 @@ trait CheckAnalysis extends PredicateHelper { case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) => p match { case _: Filter => // Ok - case other => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") + case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") } case _: Union | _: SetOperation if operator.children.length > 1 => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 4ab2e65c7cc7..879a53ce23f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -387,8 +387,8 @@ object TypeCoercion { /** * Handles type coercion for both IN expression with subquery and IN * expressions without subquery. - * 1. In the first case, find the common type by comparing the left hand side - * expression types against corresponding right hand side expression derived + * 1. In the first case, find the common type by comparing the left hand side (LHS) + * expression types against corresponding right hand side (RHS) expression derived * from the subquery expression's plan output. Inject appropriate casts in the * LHS and RHS side of IN expression. * @@ -406,7 +406,7 @@ object TypeCoercion { // Handle type casting required between value expression and subquery output // in IN subquery. case i @ In(a, Seq(ListQuery(sub, children, exprId))) if !i.resolved => - // lhs is the value expression of IN subquery. + // LHS is the value expression of IN subquery. val lhs = a match { // Multi columns in IN clause is represented as a CreateNamedStruct. // flatten the named struct to get the list of expressions. @@ -414,7 +414,7 @@ object TypeCoercion { case expr => Seq(expr) } - // rhs is the subquery output. + // RHS is the subquery output. val rhs = sub.output require(lhs.length == rhs.length) @@ -422,6 +422,8 @@ object TypeCoercion { findCommonTypeForBinaryComparison(l.dataType, r.dataType) } + // The number of columns/expressions must match between LHS and RHS of an + // IN subquery expression. if (commonTypes.length == lhs.length) { val castedRhs = rhs.zip(commonTypes).map { case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() @@ -432,11 +434,13 @@ object TypeCoercion { case (e, _) => e } - // Before constructing the In expression, wrap the multi values in lhs + // Before constructing the In expression, wrap the multi values in LHS // in a CreatedNamedStruct. val newLhs = a match { case cns: CreateNamedStruct => - val nameValue = cns.nameExprs.zip(castedLhs).flatMap(pair => Seq(pair._1, pair._2)) + val nameValue = cns.nameExprs.zip(castedLhs).flatMap { + case (name, value) => Seq(name, value) + } CreateNamedStruct(nameValue) case _ => castedLhs.head } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 567e40650c87..da725b0e0631 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -79,7 +79,7 @@ object SubqueryExpression { */ def hasCorrelatedSubquery(e: Expression): Boolean = { e.find { - case s: SubqueryExpression if s.children.nonEmpty => true + case s: SubqueryExpression => s.children.nonEmpty case _ => false }.isDefined } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index f3820f298092..ba3fd1d5f802 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -67,18 +68,18 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { case (p, Not(Exists(sub, conditions, _))) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftAnti, joinCond) - case (p, In(e, Seq(l @ ListQuery(sub, conditions, _)))) => - val inConditions = getValueExpression(e).zip(sub.output).map(EqualTo.tupled) + case (p, In(value, Seq(ListQuery(sub, conditions, _)))) => + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) Join(outerPlan, sub, LeftSemi, joinCond) - case (p, Not(In(e, Seq(l @ ListQuery(sub, conditions, _))))) => + case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. - val inConditions = getValueExpression(e).zip(sub.output).map(EqualTo.tupled) + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: @@ -106,16 +107,15 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { var newPlan = plan val newExprs = exprs.map { e => e transformUp { - case Exists(sub, conditions, exprId) => + case Exists(sub, conditions, _) => val exists = AttributeReference("exists", BooleanType, nullable = false)() - newPlan = Join(newPlan, sub, - ExistenceJoin(exists), conditions.reduceLeftOption(And)) + newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)) exists - case In(e, Seq(l@ ListQuery(sub, conditions, exprId))) => + case In(value, Seq(ListQuery(sub, conditions, _))) => val exists = AttributeReference("exists", BooleanType, nullable = false)() - val inConditions = getValueExpression(e).zip(sub.output).map(EqualTo.tupled) - newPlan = Join(newPlan, sub, - ExistenceJoin(exists), (inConditions ++ conditions).reduceLeftOption(And)) + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val newConditions = (inConditions ++ conditions).reduceLeftOption(And) + newPlan = Join(newPlan, sub, ExistenceJoin(exists), newConditions) exists } } @@ -154,7 +154,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper val transformed = BooleanSimplification(sub) transformUp { case f @ Filter(cond, child) => val (correlated, local) = - splitConjunctivePredicates(cond).partition(SubExprUtils.containsOuter) + splitConjunctivePredicates(cond).partition(containsOuter) // Rewrite the filter without the correlated predicates if any. correlated match { @@ -189,7 +189,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper // In case of a collision, change the subquery plan's output to use // different attribute by creating alias(s). val baseConditions = predicateMap.values.flatten.toSeq - val (newplan: LogicalPlan, newcond: Seq[Expression]) = if (outer.nonEmpty) { + val (newPlan, newCond) = if (outer.nonEmpty) { val outputSet = outer.map(_.outputSet).reduce(_ ++ _) val duplicates = transformed.outputSet.intersect(outputSet) val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { @@ -207,22 +207,22 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper } else { (transformed, baseConditions) } - (plan, SubExprUtils.stripOuterReferences(deDuplicatedConditions)) + (plan, stripOuterReferences(deDuplicatedConditions)) } else { - (transformed, SubExprUtils.stripOuterReferences(baseConditions)) + (transformed, stripOuterReferences(baseConditions)) } - (newplan, newcond) + (newPlan, newCond) } private def rewriteSubQueries(plan: LogicalPlan, outerPlans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { - case s @ ScalarSubquery(sub, cond, exprId) if s.children.nonEmpty => + case ScalarSubquery(sub, children, exprId) if children.nonEmpty => val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) ScalarSubquery(newPlan, newCond, exprId) - case e @ Exists(sub, cond, exprId) if e.children.nonEmpty => + case Exists(sub, children, exprId) if children.nonEmpty => val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) Exists(newPlan, newCond, exprId) - case l @ ListQuery(sub, cond, exprId) => + case ListQuery(sub, _, exprId) => val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) ListQuery(newPlan, newCond, exprId) } From 55842fa6c0a333c990caa132ea5002ea2e8829f3 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 2 Mar 2017 01:32:27 -0800 Subject: [PATCH 5/9] Review comment --- .../org/apache/spark/sql/catalyst/expressions/subquery.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index da725b0e0631..2bcf8375cb2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -251,7 +251,7 @@ case class ScalarSubquery( object ScalarSubquery { def hasCorrelatedScalarSubquery(e: Expression): Boolean = { e.find { - case s: ScalarSubquery if s.children.nonEmpty => true + case s: ScalarSubquery => s.children.nonEmpty case _ => false }.isDefined } From c677ed845a9c0b0cdf234c5cb106112335493443 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 2 Mar 2017 13:40:21 -0800 Subject: [PATCH 6/9] code review --- .../sql/catalyst/analysis/Analyzer.scala | 24 +++++++++---------- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 ++-- .../sql/catalyst/expressions/subquery.scala | 13 ++++------ 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ccd721c88556..a67cd9c96a5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1280,18 +1280,17 @@ class Analyzer( // Category 2: // These operators can be anywhere in a correlated subquery. // so long as they do not host outer references in the operators. - case p: Sort => - failOnOuterReference(p) - case p: RepartitionByExpression => - failOnOuterReference(p) + case s: Sort => + failOnOuterReference(s) + case r: RepartitionByExpression => + failOnOuterReference(r) // Category 3: // Filter is one of the two operators allowed to host correlated expressions. // The other operator is Join. Filter can be anywhere in a correlated subquery. - case f @ Filter(cond, child) => + case f: Filter => // Find all predicates with an outer reference. - val (correlated, local) = - splitConjunctivePredicates(cond).partition(containsOuter) + val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) // Find any non-equality correlated predicates foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { @@ -1302,7 +1301,7 @@ class Analyzer( // Project cannot host any correlated expressions // but can be anywhere in a correlated subquery. - case p @ Project(expressions, child) => + case p: Project => failOnOuterReference(p) // Aggregate cannot host any correlated expressions @@ -1310,7 +1309,7 @@ class Analyzer( // only equality correlated predicates. // It cannot be on a correlation path if the correlation has // non-equality correlated predicates. - case a @ Aggregate(grouping, expressions, child) => + case a: Aggregate => failOnOuterReference(a) failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) @@ -1350,8 +1349,8 @@ class Analyzer( // but must not host any outer references. // Note: // Generator with join=false is treated as Category 4. - case p @ Generate(generator, true, _, _, _, _) => - failOnOuterReference(p) + case g: Generate if g.join => + failOnOuterReference(g) // Category 4: Any other operators not in the above 3 categories // cannot be on a correlation path, that is they are allowed only @@ -1392,7 +1391,8 @@ class Analyzer( } } while (!current.resolved && !current.fastEquals(previous)) - // Step 2: pull the outer references and record them as children of SubqueryExpression + // Step 2: If the subquery plan is fully resolved, pull the outer references and record + // them as children of SubqueryExpression. if (current.resolved) { // Make sure the resolved query has the required number of output columns. This is only // needed for Scalar and IN subqueries. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d2ed35a66e68..0420f4ae03d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -208,8 +208,8 @@ trait CheckAnalysis extends PredicateHelper { s"of type ${f.condition.dataType.simpleString} is not a boolean.") case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) => - failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" + - " conditions: $condition") + failAnalysis("Null-aware predicate sub-queries cannot be used in nested " + + s"conditions: $condition") case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 2bcf8375cb2c..af660429980b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -115,11 +115,7 @@ object SubExprUtils extends PredicateHelper { /** * Returns an expression after removing the OuterReference shell. */ - def stripOuterReference(e: Expression): Expression = { - e.transform { - case OuterReference(r) => r - } - } + def stripOuterReference(e: Expression): Expression = e.transform { case OuterReference(r) => r } /** * Returns the list of expressions after removing the OuterReference shell from each of @@ -219,10 +215,9 @@ object SubExprUtils extends PredicateHelper { // Collect all the expressions that have outer references. conditions foreach { e => - val (corr, _) = splitConjunctivePredicates(e).partition(containsOuter) - val correlated = stripOuterReferences(corr) - correlated match { - case Nil => + val (correlated, _) = splitConjunctivePredicates(e).partition(containsOuter) + stripOuterReferences(correlated) match { + case Nil => // no-op case xs => correlatedPredicates += xs } From 00c890ef9572ae195b171ad4f724e2157640e3fd Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sat, 4 Mar 2017 16:47:14 -0800 Subject: [PATCH 7/9] Review comments --- .../sql/catalyst/analysis/TypeCoercion.scala | 56 +++++++------------ 1 file changed, 21 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 879a53ce23f9..70e5d034c595 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -111,10 +111,12 @@ object TypeCoercion { /** * This function determines the target type of a comparison operator when one operand * is a String and the other is not. It also handles when one op is a Date and the - * other is a Timestamp by making the target type to be String. Currently this is used - * to coerce types between LHS and RHS of the IN expression. + * other is a Timestamp by making the target type to be String. */ val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = { + // We should cast all relative timestamp/date/string comparison into string comparisons + // This behaves as a user would expect because timestamp strings sort lexicographically. + // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true case (StringType, DateType) => Some(StringType) case (DateType, StringType) => Some(StringType) case (StringType, TimestampType) => Some(StringType) @@ -125,7 +127,7 @@ object TypeCoercion { case (NullType, StringType) => Some(StringType) case (l: StringType, r: AtomicType) if r != StringType => Some(r) case (l: AtomicType, r: StringType) if (l != StringType) => Some(l) - case (l, r) => findTightestCommonType(l, r) + case (l, r) => None } /** @@ -325,6 +327,14 @@ object TypeCoercion { * Promotes strings that appear in arithmetic expressions. */ object PromoteStrings extends Rule[LogicalPlan] { + private def castExpr(expr: Expression, targetType: DataType): Expression = { + (expr.dataType, targetType) match { + case (NullType, dt) => Literal.create(null, targetType) + case (l, dt) if (l != dt) => Cast(expr, targetType) + case _ => expr + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -341,37 +351,10 @@ object TypeCoercion { case p @ Equality(left @ TimestampType(), right @ StringType()) => p.makeCopy(Array(left, Cast(right, TimestampType))) - // We should cast all relative timestamp/date/string comparison into string comparisons - // This behaves as a user would expect because timestamp strings sort lexicographically. - // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true - case p @ BinaryComparison(left @ StringType(), right @ DateType()) => - p.makeCopy(Array(left, Cast(right, StringType))) - case p @ BinaryComparison(left @ DateType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) - case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => - p.makeCopy(Array(left, Cast(right, StringType))) - case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) - - // Comparisons between dates and timestamps. - case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) - case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) - - // Checking NullType - case p @ BinaryComparison(left @ StringType(), right @ NullType()) => - p.makeCopy(Array(left, Literal.create(null, StringType))) - case p @ BinaryComparison(left @ NullType(), right @ StringType()) => - p.makeCopy(Array(Literal.create(null, StringType), right)) - - // When compare string with atomic type, case string to that type. - case p @ BinaryComparison(left @ StringType(), right @ AtomicType()) - if right.dataType != StringType => - p.makeCopy(Array(Cast(left, right.dataType), right)) - case p @ BinaryComparison(left @ AtomicType(), right @ StringType()) - if left.dataType != StringType => - p.makeCopy(Array(left, Cast(right, left.dataType))) + case p @ BinaryComparison(left, right) + if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined => + val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get + p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) @@ -419,7 +402,10 @@ object TypeCoercion { require(lhs.length == rhs.length) val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => - findCommonTypeForBinaryComparison(l.dataType, r.dataType) + findCommonTypeForBinaryComparison(l.dataType, r.dataType) match { + case d @ Some(_) => d + case _ => findTightestCommonType(l.dataType, r.dataType) + } } // The number of columns/expressions must match between LHS and RHS of an From 27cb36a2ddfd13100348f8294155806b8f72fec3 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 13 Mar 2017 01:51:15 -0700 Subject: [PATCH 8/9] rebase --- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 0420f4ae03d7..da0c6b098f5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -136,10 +136,6 @@ trait CheckAnalysis extends PredicateHelper { s"Scalar subquery must return only one column, but got ${query.output.size}") } else if (conditions.nonEmpty) { - } else if (conditions.nonEmpty) { - // Collect the columns from the subquery for further checking. - var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains) - def checkAggregate(agg: Aggregate): Unit = { // Make sure correlated scalar subqueries contain one row for every outer row by // enforcing that they are aggregates containing exactly one aggregate expression. From 19cdbb040ccf2e74e1271ca33e6842607c1e0760 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 13 Mar 2017 22:27:05 -0700 Subject: [PATCH 9/9] Code review --- .../sql/catalyst/analysis/Analyzer.scala | 4 ++ .../sql/catalyst/analysis/TypeCoercion.scala | 37 +++++++++---------- .../sql/catalyst/expressions/predicates.scala | 22 +++++++---- .../sql/catalyst/expressions/subquery.scala | 13 ++----- 4 files changed, 41 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a67cd9c96a5f..a3764d8c843d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1297,6 +1297,10 @@ class Analyzer( case _: EqualTo | _: EqualNullSafe => false case _ => true } + // The aggregate expressions are treated in a special way by getOuterReferences. If the + // aggregate expression contains only outer reference attributes then the entire aggregate + // expression is isolated as an OuterReference. + // i.e min(OuterReference(b)) => OuterReference(min(b)) outerReferences ++= getOuterReferences(correlated) // Project cannot host any correlated expressions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 70e5d034c595..768897dc0713 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -382,30 +382,32 @@ object TypeCoercion { * Analysis Exception will be raised at the type checking phase. */ object InConversion extends Rule[LogicalPlan] { + private def flattenExpr(expr: Expression): Seq[Expression] = { + expr match { + // Multi columns in IN clause is represented as a CreateNamedStruct. + // flatten the named struct to get the list of expressions. + case cns: CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e // Handle type casting required between value expression and subquery output // in IN subquery. - case i @ In(a, Seq(ListQuery(sub, children, exprId))) if !i.resolved => + case i @ In(a, Seq(ListQuery(sub, children, exprId))) + if !i.resolved && flattenExpr(a).length == sub.output.length => // LHS is the value expression of IN subquery. - val lhs = a match { - // Multi columns in IN clause is represented as a CreateNamedStruct. - // flatten the named struct to get the list of expressions. - case cns: CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } + val lhs = flattenExpr(a) // RHS is the subquery output. val rhs = sub.output - require(lhs.length == rhs.length) val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => - findCommonTypeForBinaryComparison(l.dataType, r.dataType) match { - case d @ Some(_) => d - case _ => findTightestCommonType(l.dataType, r.dataType) - } + findCommonTypeForBinaryComparison(l.dataType, r.dataType) + .orElse(findTightestCommonType(l.dataType, r.dataType)) } // The number of columns/expressions must match between LHS and RHS of an @@ -422,14 +424,11 @@ object TypeCoercion { // Before constructing the In expression, wrap the multi values in LHS // in a CreatedNamedStruct. - val newLhs = a match { - case cns: CreateNamedStruct => - val nameValue = cns.nameExprs.zip(castedLhs).flatMap { - case (name, value) => Seq(name, value) - } - CreateNamedStruct(nameValue) - case _ => castedLhs.head + val newLhs = castedLhs match { + case Seq(lhs) => lhs + case _ => CreateStruct(castedLhs) } + In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId))) } else { i diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 0ab62c8d701c..e5d1a1e2996c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -133,16 +133,24 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { case cns: CreateNamedStruct => cns.valExprs case expr => Seq(expr) } - val isTypeMismatched = valExprs.zip(sub.output).exists { - case (l, r) => l.dataType != r.dataType + + val mismatchedColumns = valExprs.zip(sub.output).flatMap { + case (l, r) if l.dataType != r.dataType => + s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" + case _ => None } - if (isTypeMismatched) { + + if (mismatchedColumns.nonEmpty) { TypeCheckResult.TypeCheckFailure( s""" - |The data type of one or more elements in the LHS of an IN subquery - |[${valExprs.map(_.dataType).mkString(", ")}] - |is not compatible with the data type of the output of the subquery - |[${sub.output.map(_.dataType).mkString(", ")}]. + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. """.stripMargin) } else { TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index af660429980b..ad11700fa28d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -184,7 +184,7 @@ object SubExprUtils extends PredicateHelper { val outerExpressions = ArrayBuffer.empty[Expression] conditions foreach { expr => expr transformDown { - case a: AggregateExpression if containsOuter(a) => + case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) => val newExpr = stripOuterReference(a) outerExpressions += newExpr newExpr @@ -210,19 +210,14 @@ object SubExprUtils extends PredicateHelper { * is removed before returning the predicate to the caller. */ def getCorrelatedPredicates(plan: LogicalPlan): Seq[Expression] = { - val correlatedPredicates = ArrayBuffer.empty[Seq[Expression]] val conditions = plan.collect { case Filter(cond, _) => cond } - - // Collect all the expressions that have outer references. - conditions foreach { e => + conditions.flatMap { e => val (correlated, _) = splitConjunctivePredicates(e).partition(containsOuter) stripOuterReferences(correlated) match { - case Nil => // no-op - case xs => - correlatedPredicates += xs + case Nil => None + case xs => xs } } - correlatedPredicates.flatten } }