Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1422,11 +1422,26 @@ class Analyzer(
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(e, plans)(Exists(_, _, exprId))
case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved =>
case InSubquery(values, l @ ListQuery(_, _, exprId, _))
if values.forall(_.resolved) && !l.resolved =>
val expr = resolveSubQuery(l, plans)((plan, exprs) => {
ListQuery(plan, exprs, exprId, plan.output)
})
In(value, Seq(expr))
val subqueryOutput = expr.plan.output
val resolvedIn = InSubquery(values, expr.asInstanceOf[ListQuery])
if (values.length != subqueryOutput.length) {
throw new AnalysisException(
s"""Cannot analyze ${resolvedIn.sql}.
|The number of columns in the left hand side of an IN subquery does not match the
|number of columns in the output of subquery.
|#columns in left hand side: ${values.length}
|#columns in right hand side: ${subqueryOutput.length}
|Left side columns:
|[${values.map(_.sql).mkString(", ")}]
|Right side columns:
|[${subqueryOutput.map(_.sql).mkString(", ")}]""".stripMargin)
}
resolvedIn
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -449,27 +449,16 @@ object TypeCoercion {
* Analysis Exception will be raised at the type checking phase.
*/
case class InConversion(conf: SQLConf) extends TypeCoercionRule {
// Returns the individual value expressions found on the left hand side of an IN clause.
private def flattenExpr(expr: Expression): Seq[Expression] = expr match {
  // Multiple columns in an IN clause are represented as a CreateNamedStruct;
  // its value expressions are the individual columns.
  case struct: CreateNamedStruct => struct.valExprs
  // Anything else is treated as a single value.
  case single => Seq(single)
}

override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e

// Handle type casting required between value expression and subquery output
// in IN subquery.
case i @ In(a, Seq(ListQuery(sub, children, exprId, _)))
if !i.resolved && flattenExpr(a).length == sub.output.length =>
// LHS is the value expression of IN subquery.
val lhs = flattenExpr(a)

case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _))
if !i.resolved && lhs.length == sub.output.length =>
// LHS is the value expressions of IN subquery.
// RHS is the subquery output.
val rhs = sub.output

Expand All @@ -485,20 +474,13 @@ object TypeCoercion {
case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
case (e, _) => e
}
val castedLhs = lhs.zip(commonTypes).map {
val newLhs = lhs.zip(commonTypes).map {
case (e, dt) if e.dataType != dt => Cast(e, dt)
case (e, _) => e
}

// Before constructing the In expression, wrap the multi values in LHS
// in a CreatedNamedStruct.
val newLhs = castedLhs match {
case Seq(lhs) => lhs
case _ => CreateStruct(castedLhs)
}

val newSub = Project(castedRhs, sub)
In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output)))
InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output))
} else {
i
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,13 @@ package object dsl {
def <=> (other: Expression): Predicate = EqualNullSafe(expr, other)
def =!= (other: Expression): Predicate = Not(EqualTo(expr, other))

def in(list: Expression*): Expression = In(expr, list)
// Builds an IN predicate. A single ListQuery argument produces an InSubquery, with a
// CreateNamedStruct receiver expanded into its value expressions; any other argument
// list produces a plain In over the given expressions.
def in(list: Expression*): Expression = list match {
  case Seq(subquery: ListQuery) =>
    val lhs = expr match {
      case struct: CreateNamedStruct => struct.valExprs
      case single => Seq(single)
    }
    InSubquery(lhs, subquery)
  case _ =>
    In(expr, list)
}

def like(other: Expression): Expression = Like(expr, other)
def rlike(other: Expression): Expression = RLike(expr, other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ object Canonicalize {
case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r)

// order the list in the In operator
// In subqueries contain only one element of type ListQuery. So checking that the length > 1
// we are not reordering In subqueries.
case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode()))

case _ => e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,66 @@ case class Not(child: Expression)
override def sql: String = s"(NOT ${child.sql})"
}

/**
 * Evaluates to `true` if `values` are returned in `query`'s result set.
 *
 * @param values the expressions on the left hand side of the IN predicate; more than one
 *               expression denotes a multi-column IN
 * @param query the subquery whose output columns are compared against `values`
 */
case class InSubquery(values: Seq[Expression], query: ListQuery)
  extends Predicate with Unevaluable {

  // The left hand side as a single expression: several values are wrapped in a
  // CreateNamedStruct (named expressions keep their name, the others are named `_<index>`),
  // while a single value is used as is. Evaluated lazily; requires `values` to be non-empty.
  @transient lazy val value: Expression = if (values.length > 1) {
    CreateNamedStruct(values.zipWithIndex.flatMap {
      case (v: NamedExpression, _) => Seq(Literal(v.name), v)
      case (v, idx) => Seq(Literal(s"_$idx"), v)
    })
  } else {
    values.head
  }

  /**
   * Checks that the left hand side matches the subquery output, first in number of columns
   * and then in data types (ignoring nullability), and finally that the resulting type is
   * accepted by `TypeUtils.checkForOrderingExpr`.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    // The column count is checked before anything else: it does not depend on `value`
    // (which would throw from `values.head` on an empty `values`), and the structural type
    // comparison below is not guaranteed to notice a count mismatch.
    // NOTE(review): presumably a single struct-typed value can be structurally equal to the
    // type of a multi-column subquery output - confirm against ListQuery.dataType.
    if (values.length != query.childOutputs.length) {
      TypeCheckResult.TypeCheckFailure(
        s"""
           |The number of columns in the left hand side of an IN subquery does not match the
           |number of columns in the output of subquery.
           |#columns in left hand side: ${values.length}.
           |#columns in right hand side: ${query.childOutputs.length}.
           |Left side columns:
           |[${values.map(_.sql).mkString(", ")}].
           |Right side columns:
           |[${query.childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
    } else if (!DataType.equalsStructurally(
        query.dataType, value.dataType, ignoreNullability = true)) {
      // Report every (left, right) pair whose data types are not equal.
      val mismatchedColumns = values.zip(query.childOutputs).collect {
        case (l, r) if l.dataType != r.dataType =>
          s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
      }
      TypeCheckResult.TypeCheckFailure(
        s"""
           |The data type of one or more elements in the left hand side of an IN subquery
           |is not compatible with the data type of the output of the subquery
           |Mismatched columns:
           |[${mismatchedColumns.mkString(", ")}]
           |Left side:
           |[${values.map(_.dataType.catalogString).mkString(", ")}].
           |Right side:
           |[${query.childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin)
    } else {
      TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
    }
  }

  // The children are the left hand side values followed by the subquery itself.
  override def children: Seq[Expression] = values :+ query
  override def nullable: Boolean = children.exists(_.nullable)
  override def foldable: Boolean = children.forall(_.foldable)
  override def toString: String = s"$value IN ($query)"
  override def sql: String = s"(${value.sql} IN (${query.sql}))"
}


/**
* Evaluates to `true` if `list` contains `value`.
Expand Down Expand Up @@ -169,44 +229,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType,
ignoreNullability = true))
if (mismatchOpt.isDefined) {
list match {
case ListQuery(_, _, _, childOutputs) :: Nil =>
val valExprs = value match {
case cns: CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
if (valExprs.length != childOutputs.length) {
TypeCheckResult.TypeCheckFailure(
s"""
|The number of columns in the left hand side of an IN subquery does not match the
|number of columns in the output of subquery.
|#columns in left hand side: ${valExprs.length}.
|#columns in right hand side: ${childOutputs.length}.
|Left side columns:
|[${valExprs.map(_.sql).mkString(", ")}].
|Right side columns:
|[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
} else {
val mismatchedColumns = valExprs.zip(childOutputs).flatMap {
case (l, r) if l.dataType != r.dataType =>
Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})")
case _ => None
}
TypeCheckResult.TypeCheckFailure(
s"""
|The data type of one or more elements in the left hand side of an IN subquery
|is not compatible with the data type of the output of the subquery
|Mismatched columns:
|[${mismatchedColumns.mkString(", ")}]
|Left side:
|[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
|Right side:
|[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin)
}
case _ =>
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}")
}
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}")
} else {
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
Expand Down Expand Up @@ -307,9 +331,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}

override def sql: String = {
val childrenSQL = children.map(_.sql)
val valueSQL = childrenSQL.head
val listSQL = childrenSQL.tail.mkString(", ")
val valueSQL = value.sql
val listSQL = list.map(_.sql).mkString(", ")
s"($valueSQL IN ($listSQL))"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ object SubExprUtils extends PredicateHelper {
def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = {
splitConjunctivePredicates(condition).exists {
case _: Exists | Not(_: Exists) => false
case In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => false
case _: InSubquery | Not(_: InSubquery) => false
case e => e.find { x =>
x.isInstanceOf[Not] && e.find {
case In(_, Seq(_: ListQuery)) => true
case _: InSubquery => true
case _ => false
}.isDefined
}.isDefined
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ object NullPropagation extends Rule[LogicalPlan] {

// If the value expression is NULL then transform the In expression to null literal.
case In(Literal(null, _), _) => Literal.create(null, BooleanType)
case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this! Please double check all the cases of IN in all the optimizer rules. We are afraid this new expression might introduce a regression.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test case in OptimizeInSuite

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comment. I checked it again and I am pretty sure no regression is introduced. We don't have many optimizer rules using In and all the others were and are applied only to In with a list of literals. I am adding this and the other tests. Thanks.


// Non-leaf NullIntolerant expressions will return null, if at least one of its children is
// a null literal.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand All @@ -42,13 +43,6 @@ import org.apache.spark.sql.types._
* condition.
*/
object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Returns the value expressions of `e`: the value expressions of a CreateNamedStruct,
// or `e` itself as a one-element sequence for any other expression.
private def getValueExpression(e: Expression): Seq[Expression] = e match {
  case struct: CreateNamedStruct => struct.valExprs
  case single => Seq(single)
}

private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match {
// SPARK-21835: It is possible that the two sides of the join have conflicting attributes,
// the produced join then becomes unresolved and breaks structural integrity. We should
Expand Down Expand Up @@ -97,19 +91,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.

// Note that this will almost certainly be planned as a Broadcast Nested Loop join.
// Use EXISTS if performance matters to you.
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
// Expand the NOT IN expression with the NULL-aware semantic
// to its full form. That is from:
Expand Down Expand Up @@ -150,9 +144,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
newPlan = dedupJoin(
Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)))
exists
case In(value, Seq(ListQuery(sub, conditions, _, _))) =>
case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
// Deduplicate conflicting attributes if any.
newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case not => Not(e)
}

// Splits the left hand side of an IN predicate into its value expressions: the value
// expressions of a CreateNamedStruct, or the expression itself as a one-element sequence.
def getValueExpressions(e: Expression): Seq[Expression] = e match {
  case struct: CreateNamedStruct => struct.valExprs
  case single => Seq(single)
}

// Create the predicate.
ctx.kind.getType match {
case SqlBaseParser.BETWEEN =>
Expand All @@ -1112,7 +1117,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
GreaterThanOrEqual(e, expression(ctx.lower)),
LessThanOrEqual(e, expression(ctx.upper))))
case SqlBaseParser.IN if ctx.query != null =>
invertIfNotDefined(In(e, Seq(ListQuery(plan(ctx.query)))))
invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query))))
case SqlBaseParser.IN =>
invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
case SqlBaseParser.LIKE =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ class AnalysisErrorSuite extends AnalysisTest {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val plan = Project(
Seq(a, Alias(In(a, Seq(ListQuery(LocalRelation(b)))), "c")()),
Seq(a, Alias(InSubquery(Seq(a), ListQuery(LocalRelation(b))), "c")()),
LocalRelation(a))
assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil)
}
Expand All @@ -537,12 +537,13 @@ class AnalysisErrorSuite extends AnalysisTest {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val c = AttributeReference("c", BooleanType)()
val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType),
val plan1 = Filter(Cast(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), BooleanType),
LocalRelation(a))
assertAnalysisError(plan1,
"Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil)

val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c))
val plan2 = Filter(
Or(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), c), LocalRelation(a, c))
assertAnalysisError(plan2,
"Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{In, ListQuery, OuterReference}
import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project}

/**
Expand All @@ -33,7 +33,8 @@ class ResolveSubquerySuite extends AnalysisTest {
val t2 = LocalRelation(b)

test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") {
val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1)
val expr = Filter(
InSubquery(Seq(a), ListQuery(Project(Seq(UnresolvedAttribute("a")), t2))), t1)
val m = intercept[AnalysisException] {
SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr))
}.getMessage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ class OptimizeInSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("OptimizedIn test: NULL IN (subquery) gets transformed to Filter(null)") {
  // A NULL literal on the left hand side of an IN subquery: the optimizer is expected to
  // replace the whole predicate with a NULL boolean literal.
  val subquery = ListQuery(testRelation.select(UnresolvedAttribute("a")))
  // Analyze the plan exactly once; the optimizer is then run on the analyzed plan.
  val analyzedQuery =
    testRelation
      .where(InSubquery(Seq(Literal.create(null, NullType)), subquery))
      .analyze

  val optimized = Optimize.execute(analyzedQuery)
  val correctAnswer =
    testRelation
      .where(Literal.create(null, BooleanType))
      .analyze
  comparePlans(optimized, correctAnswer)
}

test("OptimizedIn test: Inset optimization disabled as " +
"list expression contains attribute)") {
val originalQuery =
Expand Down
Loading