From c9a36e0fb6301b6807ea6ca9e3415e899f7a83ac Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 23 May 2018 00:08:22 +0200 Subject: [PATCH 01/17] [SPARK-24313][SQL] Support In subqueries which are valid in other RDBMS --- .../sql/catalyst/analysis/Analyzer.scala | 10 ++++- .../sql/catalyst/analysis/TypeCoercion.scala | 38 ++++++---------- .../spark/sql/catalyst/dsl/package.scala | 5 ++- .../sql/catalyst/expressions/predicates.scala | 32 ++++++++------ .../sql/catalyst/optimizer/expressions.scala | 10 ++--- .../sql/catalyst/optimizer/subquery.scala | 20 +++------ .../sql/catalyst/parser/AstBuilder.scala | 9 +++- .../statsEstimation/FilterEstimation.scala | 2 +- .../analysis/AnalysisErrorSuite.scala | 7 +-- .../sql/catalyst/analysis/AnalysisSuite.scala | 10 ++--- .../analysis/ResolveSubquerySuite.scala | 2 +- .../catalyst/analysis/TypeCoercionSuite.scala | 12 ++--- .../catalog/ExternalCatalogSuite.scala | 4 +- .../expressions/CanonicalizeSuite.scala | 12 ++--- .../catalyst/expressions/PredicateSuite.scala | 44 +++++++++---------- .../optimizer/ConstantFoldingSuite.scala | 2 +- .../catalyst/optimizer/OptimizeInSuite.scala | 29 ++++++------ .../PullupCorrelatedPredicatesSuite.scala | 2 +- .../parser/ExpressionParserSuite.scala | 2 +- .../FilterEstimationSuite.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 7 ++- .../columnar/InMemoryTableScanExec.scala | 2 +- .../datasources/DataSourceStrategy.scala | 2 +- .../datasources/FileSourceStrategy.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 21 +++++++++ .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../datasources/DataSourceStrategySuite.scala | 2 +- .../spark/sql/hive/client/HiveShim.scala | 2 +- .../sql/hive/client/HiveClientSuite.scala | 12 ++--- 29 files changed, 168 insertions(+), 138 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6e3107f1c6f75..44ad30d07fe3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1347,11 +1347,17 @@ class Analyzer( resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) - case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved => + case i @ In(values, Seq(l @ ListQuery(_, _, exprId, _))) + if values.forall(_.resolved) && !l.resolved => val expr = resolveSubQuery(l, plans)((plan, exprs) => { ListQuery(plan, exprs, exprId, plan.output) }) - In(value, Seq(expr)) + val subqueryOutputNum = expr.asInstanceOf[ListQuery].childOutputs.length + if (values.length != subqueryOutputNum) { + throw new AnalysisException(s"${i.sql} has ${values.length} values, but the " + + s"subquery has $subqueryOutputNum output values.") + } + In(values, Seq(expr)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b2817b0538a7f..c1242f537d029 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -406,15 +406,6 @@ object TypeCoercion { * Analysis Exception will be raised at the type checking phase. 
*/ case class InConversion(conf: SQLConf) extends TypeCoercionRule { - private def flattenExpr(expr: Expression): Seq[Expression] = { - expr match { - // Multi columns in IN clause is represented as a CreateNamedStruct. - // flatten the named struct to get the list of expressions. - case cns: CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - } - override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. @@ -422,11 +413,9 @@ object TypeCoercion { // Handle type casting required between value expression and subquery output // in IN subquery. - case i @ In(a, Seq(ListQuery(sub, children, exprId, _))) - if !i.resolved && flattenExpr(a).length == sub.output.length => - // LHS is the value expression of IN subquery. - val lhs = flattenExpr(a) - + // LHS is the value expressions of IN subquery. + case i @ In(lhs, Seq(ListQuery(sub, children, exprId, _))) + if !i.resolved && lhs.length == sub.output.length => // RHS is the subquery output. val rhs = sub.output @@ -442,27 +431,26 @@ object TypeCoercion { case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() case (e, _) => e } - val castedLhs = lhs.zip(commonTypes).map { + val newLhs = lhs.zip(commonTypes).map { case (e, dt) if e.dataType != dt => Cast(e, dt) case (e, _) => e } - // Before constructing the In expression, wrap the multi values in LHS - // in a CreatedNamedStruct. - val newLhs = castedLhs match { - case Seq(lhs) => lhs - case _ => CreateStruct(castedLhs) - } - val newSub = Project(castedRhs, sub) In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output))) } else { i } - case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType)) match { - case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) + case i @ In(a, b) if b.exists(_.dataType != i.value.dataType) => + findWiderCommonType(i.value.dataType +: b.map(_.dataType)) match { + case Some(finalDataType: StructType) if i.values.length > 1 => + val newValues = a.zip(finalDataType.fields.map(_.dataType)).map { + case (expr, dataType) => Cast(expr, dataType) + } + In(newValues, b.map(Cast(_, finalDataType))) + case Some(finalDataType) => + In(a.map(Cast(_, finalDataType)), b.map(Cast(_, finalDataType))) case None => i } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index efb2eba655e15..a3ec2713fb7b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -88,7 +88,10 @@ package object dsl { def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) def =!= (other: Expression): Predicate = Not(EqualTo(expr, other)) - def in(list: Expression*): Expression = In(expr, list) + def in(list: Expression*): Expression = expr match { + case c: CreateNamedStruct => In(c.valExprs, list) + case other => In(Seq(other), list) + } def like(other: Expression): Expression = Like(expr, other) def rlike(other: Expression): Expression = RLike(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f54103c4fbfba..527201ad2174e 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -161,33 +161,38 @@ case class Not(child: Expression) true """) // scalastyle:on line.size.limit -case class In(value: Expression, list: Seq[Expression]) extends Predicate { +case class In(values: Seq[Expression], list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") + @transient lazy val value = if (values.length > 1) { + CreateNamedStruct(values.zipWithIndex.flatMap { + case (v: NamedExpression, _) => Seq(Literal(v.name), v) + case (v, idx) => Seq(Literal(s"_$idx"), v) + }) + } else { + values.head + } + override def checkInputDataTypes(): TypeCheckResult = { val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType, ignoreNullability = true)) if (mismatchOpt.isDefined) { list match { case ListQuery(_, _, _, childOutputs) :: Nil => - val valExprs = value match { - case cns: CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - if (valExprs.length != childOutputs.length) { + if (values.length != childOutputs.length) { TypeCheckResult.TypeCheckFailure( s""" |The number of columns in the left hand side of an IN subquery does not match the |number of columns in the output of subquery. - |#columns in left hand side: ${valExprs.length}. + |#columns in left hand side: ${values.length}. |#columns in right hand side: ${childOutputs.length}. |Left side columns: - |[${valExprs.map(_.sql).mkString(", ")}]. + |[${values.map(_.sql).mkString(", ")}]. |Right side columns: |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) } else { - val mismatchedColumns = valExprs.zip(childOutputs).flatMap { + val mismatchedColumns = values.zip(childOutputs).flatMap { case (l, r) if l.dataType != r.dataType => s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" case _ => None @@ -199,7 +204,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { |Mismatched columns: |[${mismatchedColumns.mkString(", ")}] |Left side: - |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. + |[${values.map(_.dataType.catalogString).mkString(", ")}]. 
|Right side: |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) } @@ -212,7 +217,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } - override def children: Seq[Expression] = value +: list + override def children: Seq[Expression] = values ++: list lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal]) private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType) @@ -307,9 +312,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def sql: String = { - val childrenSQL = children.map(_.sql) - val valueSQL = childrenSQL.head - val listSQL = childrenSQL.tail.mkString(", ") + val valueSQL = value.sql + val listSQL = list.map(_.sql).mkString(", ") s"($valueSQL IN ($listSQL))" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1d363b8146e3f..11a231d6e9c32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -212,18 +212,18 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { * 1. Converts the predicate to false when the list is empty and * the value is not nullable. * 2. Removes literal repetitions. - * 3. Replaces [[In (value, seq[Literal])]] with optimized version + * 3. Replaces [[In (values, seq[Literal])]] with optimized version * [[InSet (value, HashSet[Literal])]] which is much faster. */ object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral - case expr @ In(v, list) if expr.inSetConvertible => + case i @ In(_, list) if list.isEmpty && !i.value.nullable => FalseLiteral + case expr @ In(_, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) - InSet(v, HashSet() ++ hSet) + InSet(expr.value, HashSet() ++ hSet) } else if (newList.size < list.size) { expr.copy(list = newList) } else { // newList.length == list.length @@ -493,7 +493,7 @@ object NullPropagation extends Rule[LogicalPlan] { } // If the value expression is NULL then transform the In expression to null literal. - case In(Literal(null, _), _) => Literal.create(null, BooleanType) + case In(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType) // Non-leaf NullIntolerant expressions will return null, if at least one of its children is // a null literal. 
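A note on the semantics the hunks above must preserve: In keeps SQL's three-valued NULL logic while its LHS becomes a Seq. Below is a minimal, self-contained Scala sketch of that logic — not Spark code; Option[Boolean] stands in for SQL TRUE/FALSE/NULL, and evalIn is an illustrative helper mirroring the evaluation order of In.eval.

object InEvalSketch {
  // Mirrors In.eval: a NULL left-hand value short-circuits to NULL; otherwise
  // the list is scanned, remembering whether any NULL was seen among the
  // non-matching items.
  def evalIn[T](value: Option[T], list: Seq[Option[T]]): Option[Boolean] =
    value match {
      case None => None
      case Some(v) =>
        var hasNull = false
        list.foreach {
          case None => hasNull = true
          case Some(x) if x == v => return Some(true)
          case _ => ()
        }
        if (hasNull) None else Some(false)
    }

  def main(args: Array[String]): Unit = {
    println(evalIn(Some(1), Seq(Some(1), None)))    // Some(true): a match wins
    println(evalIn(Some(2), Seq(Some(1), None)))    // None: no match, NULL seen
    println(evalIn(Some(3), Seq(Some(1), Some(2)))) // Some(false)
    println(evalIn(None, Seq(Some(1))))             // None: NULL value
    println(evalIn(Some(1), Nil))                   // Some(false): empty list
  }
}

This is why OptimizeIn may fold an empty list to FalseLiteral only when the value is not nullable, and why NullPropagation can rewrite In(Seq(Literal(null, _)), _) straight to a null literal: a NULL value makes the whole predicate NULL regardless of the list.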
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index de89e17e51f1b..d2954d20ceda2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -42,13 +43,6 @@ import org.apache.spark.sql.types._ * condition. */ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { - private def getValueExpression(e: Expression): Seq[Expression] = { - e match { - case cns : CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - } - private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match { // SPARK-21835: It is possibly that the two sides of the join have conflicting attributes, // the produced join then becomes unresolved and break structural integrity. We should @@ -97,19 +91,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond)) - case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) => - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + case (p, In(values, Seq(ListQuery(sub, conditions, _, _)))) => + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond)) - case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) => + case (p, Not(In(values, Seq(ListQuery(sub, conditions, _, _))))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: @@ -150,9 +144,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { newPlan = dedupJoin( Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))) exists - case In(value, Seq(ListQuery(sub, conditions, _, _))) => + case In(values, Seq(ListQuery(sub, conditions, _, _))) => val exists = AttributeReference("exists", BooleanType, nullable = false)() - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val newConditions = (inConditions ++ conditions).reduceLeftOption(And) // Deduplicate conflicting attributes if any. 
newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 383ebde3229d6..e7e382b8d45ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1086,6 +1086,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case not => Not(e) } + def getValueExpressions(e: Expression): Seq[Expression] = e match { + case c: CreateNamedStruct => c.valExprs + case other => Seq(other) + } + // Create the predicate. ctx.kind.getType match { case SqlBaseParser.BETWEEN => @@ -1094,9 +1099,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging GreaterThanOrEqual(e, expression(ctx.lower)), LessThanOrEqual(e, expression(ctx.upper)))) case SqlBaseParser.IN if ctx.query != null => - invertIfNotDefined(In(e, Seq(ListQuery(plan(ctx.query))))) + invertIfNotDefined(In(getValueExpressions(e), Seq(ListQuery(plan(ctx.query))))) case SqlBaseParser.IN => - invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) + invertIfNotDefined(In(getValueExpressions(e), ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => invertIfNotDefined(Like(e, expression(ctx.pattern))) case SqlBaseParser.RLIKE => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 5a3eeefaedb18..a14123d0a6c12 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -164,7 +164,7 @@ case class FilterEstimation(plan: Filter) extends Logging { case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) => evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) - case In(ar: Attribute, expList) + case In(Seq(ar: Attribute), expList) if expList.forall(e => e.isInstanceOf[Literal]) => // Expression [In (value, seq[Literal])] will be replaced with optimized version // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. 
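Worth spelling out how the rewrite above exploits the new Seq-shaped LHS: the semi/anti join condition for a multi-column IN is just the pairwise equality of LHS values with the subquery output, ANDed together. A small self-contained sketch with stand-in types (Col, EqualTo and And here are simplified placeholders, not the Catalyst classes):

object InToJoinCondition {
  sealed trait Expr
  case class Col(name: String) extends Expr
  case class EqualTo(left: Expr, right: Expr) extends Expr
  case class And(left: Expr, right: Expr) extends Expr

  // Same shape as `values.zip(sub.output).map(EqualTo.tupled)` followed by
  // reduceLeftOption(And) in RewritePredicateSubquery above.
  def inJoinCondition(values: Seq[Expr], subOutput: Seq[Expr]): Option[Expr] = {
    require(values.length == subOutput.length,
      "arity is already checked by the analyzer before this point")
    values.zip(subOutput)
      .map { case (l, r) => EqualTo(l, r): Expr }
      .reduceLeftOption(And)
  }

  def main(args: Array[String]): Unit = {
    // (a, b) IN (SELECT x, y FROM t) joins on a = x AND b = y.
    println(inJoinCondition(Seq(Col("a"), Col("b")), Seq(Col("x"), Col("y"))))
  }
}

For NOT IN, each of these pairwise equalities e is further expanded to (e OR ISNULL(e)) per the null-aware semantics described in the comment above, which is what tends to force the Broadcast Nested Loop plan that comment warns about.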
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5d2f8e735e3d4..1030e27e4bd38 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -521,7 +521,7 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val plan = Project( - Seq(a, Alias(In(a, Seq(ListQuery(LocalRelation(b)))), "c")()), + Seq(a, Alias(In(Seq(a), Seq(ListQuery(LocalRelation(b)))), "c")()), LocalRelation(a)) assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil) } @@ -530,12 +530,13 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", BooleanType)() - val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType), + val plan1 = Filter(Cast(Not(In(Seq(a), Seq(ListQuery(LocalRelation(b))))), BooleanType), LocalRelation(a)) assertAnalysisError(plan1, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) - val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) + val plan2 = Filter( + Or(Not(In(Seq(a), Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) assertAnalysisError(plan2, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index cd8579584eada..a45d396abf173 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -275,21 +275,21 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-8654: invalid CAST in NULL IN(...) 
expression") { - val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil, + val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(2))), "a")() :: Nil, LocalRelation() ) assertAnalysisSuccess(plan) } test("SPARK-8654: different types in inlist but can be converted to a common type") { - val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, - LocalRelation() - ) + val plan = Project( + Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, + LocalRelation()) assertAnalysisSuccess(plan) } test("SPARK-8654: check type compatibility error") { - val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil, + val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(true), Literal(1))), "a")() :: Nil, LocalRelation() ) assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 1bf8d76da04d8..03129b9e86234 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -33,7 +33,7 @@ class ResolveSubquerySuite extends AnalysisTest { val t2 = LocalRelation(b) test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { - val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) + val expr = Filter(In(Seq(a), Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr)) }.getMessage diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0acd3b490447d..b65e65d671c22 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1240,16 +1240,16 @@ class TypeCoercionSuite extends AnalysisTest { // InConversion val inConversion = TypeCoercion.InConversion(conf) ruleTest(inConversion, - In(UnresolvedAttribute("a"), Seq(Literal(1))), - In(UnresolvedAttribute("a"), Seq(Literal(1))) + In(Seq(UnresolvedAttribute("a")), Seq(Literal(1))), + In(Seq(UnresolvedAttribute("a")), Seq(Literal(1))) ) ruleTest(inConversion, - In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))), - In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))) + In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1))), + In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1))) ) ruleTest(inConversion, - In(Literal("a"), Seq(Literal(1), Literal("b"))), - In(Cast(Literal("a"), StringType), + In(Seq(Literal("a")), Seq(Literal(1), Literal("b"))), + In(Seq(Cast(Literal("a"), StringType)), Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType))) ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index b376108399c1c..51464332dbc80 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -477,8 +477,8 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac checkAnswer(tbl2, Seq.empty, Set(part1, part2)) checkAnswer(tbl2, Seq('a.int <= 1), Set(part1)) checkAnswer(tbl2, Seq('a.int === 2), Set.empty) - checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2)) - checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2)) + checkAnswer(tbl2, Seq(In(Seq('a.int * 10), Seq(30))), Set(part2)) + checkAnswer(tbl2, Seq(Not(In(Seq('a.int), Seq(4)))), Set(part1, part2)) checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1)) checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1)) checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 28e6940f3cca3..b78d23e3472e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -27,9 +27,9 @@ class CanonicalizeSuite extends SparkFunSuite { val range = Range(1, 1, 1, 1) val idAttr = range.output.head - val in1 = In(idAttr, Seq(Literal(1), Literal(2))) - val in2 = In(idAttr, Seq(Literal(2), Literal(1))) - val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3))) + val in1 = In(Seq(idAttr), Seq(Literal(1), Literal(2))) + val in2 = In(Seq(idAttr), Seq(Literal(2), Literal(1))) + val in3 = In(Seq(idAttr), Seq(Literal(1), Literal(2), Literal(3))) assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash()) assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash()) @@ -37,11 +37,11 @@ class CanonicalizeSuite extends SparkFunSuite { assert(range.where(in1).sameResult(range.where(in2))) assert(!range.where(in1).sameResult(range.where(in3))) - val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + val arrays1 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))), CreateArray(Seq(Literal(2), Literal(1))))) - val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))), + val arrays2 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(2), Literal(1))), CreateArray(Seq(Literal(1), Literal(2))))) - val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + val arrays3 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))), CreateArray(Seq(Literal(3), Literal(1))))) assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index ac76b17ef4761..0aec614748324 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -125,32 +125,32 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, null, null) :: Nil) test("basic IN predicate test") { - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1), + 
checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq(Literal(1), Literal(2))), null) - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), + checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null) - checkEvaluation(In(Literal(1), Seq.empty), false) - checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), - true) - checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), - null) - checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq.empty), null) + checkEvaluation(In(Seq(Literal(1)), Seq.empty), false) + checkEvaluation(In(Seq(Literal(1)), Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkEvaluation(In(Seq(Literal(1)), + Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), true) + checkEvaluation(In(Seq(Literal(2)), + Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), null) + checkEvaluation(In(Seq(Literal(1)), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Seq(Literal(2)), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Seq(Literal(3)), Seq(Literal(1), Literal(2))), false) checkEvaluation( - And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), + And(In(Seq(Literal(1)), Seq(Literal(1), Literal(2))), In(Seq(Literal(2)), Seq(Literal(1), Literal(2)))), true) val ns = NonFoldableLiteral.create(null, StringType) - checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null) - checkEvaluation(In(ns, Seq(ns)), null) - checkEvaluation(In(Literal("a"), Seq(ns)), null) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) + checkEvaluation(In(Seq(ns), Seq(Literal("1"), Literal("2"))), null) + checkEvaluation(In(Seq(ns), Seq(ns)), null) + checkEvaluation(In(Seq(Literal("a")), Seq(ns)), null) + checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("^Ba*n"), ns)), true) + checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("aa"), Literal("^Ba*n"))), true) + checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("aa"), Literal("^n"))), false) } @@ -187,7 +187,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } else { false } - checkEvaluation(In(input(0), input.slice(1, 10)), expected) + checkEvaluation(In(Seq(input.head), input.slice(1, 10)), expected) } val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t => @@ -243,12 +243,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22501: In should not generate codes beyond 64KB") { val N = 3000 val sets = (1 to N).map(i => Literal(i.toDouble)) - checkEvaluation(In(Literal(1.0D), sets), true) + checkEvaluation(In(Seq(Literal(1.0D)), sets), true) } test("SPARK-22705: In should use less global variables") { val ctx = new CodegenContext() - In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) + In(Seq(Literal(1.0D)), Seq(Literal(1.0D), 
Literal(2.0D))).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 641c89873dcc4..09d8f2e4d1ad4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -251,7 +251,7 @@ class ConstantFoldingSuite extends PlanTest { val originalQuery = testRelation .select('a) - .where(In(Literal(1), Seq(Literal(1), Literal(2)))) + .where(In(Seq(Literal(1)), Seq(Literal(1), Literal(2)))) val optimized = Optimize.execute(originalQuery.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 478118ed709f7..0f38ebcd4a140 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -45,9 +45,9 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: Remove deterministic repetitions") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(1), Literal(2), Literal(2), Literal(1), Literal(2)))) - .where(In(UnresolvedAttribute("b"), + .where(In(Seq(UnresolvedAttribute("b")), Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"), Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0), Rand(0), Rand(0)))) @@ -56,8 +56,8 @@ class OptimizeInSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) - .where(In(UnresolvedAttribute("b"), + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2)))) + .where(In(Seq(UnresolvedAttribute("b")), Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"), Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0), Rand(0), Rand(0)))) @@ -69,7 +69,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -79,7 +79,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause optimized to InSet when more than 10 items") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_)))) + .where(In(Seq(UnresolvedAttribute("a")), (1 to 11).map(Literal(_)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -94,13 +94,15 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized in case filter has attributes") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) + .where(In(Seq(UnresolvedAttribute("a")), + Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(UnresolvedAttribute("a"), 
Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) + .where(In(Seq(UnresolvedAttribute("a")), + Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze comparePlans(optimized, correctAnswer) @@ -109,7 +111,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: NULL IN (expr1, ..., exprN) gets transformed to Filter(null)") { val originalQuery = testRelation - .where(In(Literal.create(null, NullType), Seq(Literal(1), Literal(2)))) + .where(In(Seq(Literal.create(null, NullType)), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -125,7 +127,8 @@ class OptimizeInSuite extends PlanTest { "list expression contains attribute)") { val originalQuery = testRelation - .where(In(Literal.create(null, StringType), Seq(Literal(1), UnresolvedAttribute("b")))) + .where(In(Seq(Literal.create(null, StringType)), + Seq(Literal(1), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -141,7 +144,7 @@ class OptimizeInSuite extends PlanTest { "list expression contains attribute - select)") { val originalQuery = testRelation - .select(In(Literal.create(null, StringType), + .select(In(Seq(Literal.create(null, StringType)), Seq(Literal(1), UnresolvedAttribute("b"))).as("a")).analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -156,7 +159,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: Setting the threshold for turning Set into InSet.") { val plan = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), Literal(3)))) + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), Literal(3)))) .analyze withSQLConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "10") { @@ -180,7 +183,7 @@ class OptimizeInSuite extends PlanTest { "when value is not nullable") { val originalQuery = testRelation - .where(In(Literal("a"), Nil)) + .where(In(Seq(Literal("a")), Nil)) .analyze val optimized = Optimize.execute(originalQuery) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index 169b8737d808b..02d03fb8b8d57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -42,7 +42,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { .select('c) val outerQuery = testRelation - .where(In('a, Seq(ListQuery(correlatedSubquery)))) + .where(In(Seq('a), Seq(ListQuery(correlatedSubquery)))) .select('a).analyze assert(outerQuery.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index cb8a1fecb80a7..4e4928fd4c298 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -154,7 +154,7 @@ class ExpressionParserSuite extends PlanTest { test("in sub-query") { assertEqual( "a in (select b from c)", - In('a, Seq(ListQuery(table("c").select('b))))) + In(Seq('a), Seq(ListQuery(table("c").select('b))))) } test("like expressions") { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 47bfa62569583..b84334eaf49bb 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -440,7 +440,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04")) val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05")) validateEstimatedStats( - Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), + Filter(In(Seq(attrDate), Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = Some(3), min = Some(d20170103), max = Some(d20170105), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 4eee3de5f7d4e..3893eae1b8a71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -793,7 +793,12 @@ class Column(val expr: Expression) extends Logging { * @since 1.5.0 */ @scala.annotation.varargs - def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } + def isin(list: Any*): Column = withExpr { + expr match { + case c: CreateNamedStruct => In(c.valExprs, list.map(lit(_).expr)) + case other => In(Seq(other), list.map(lit(_).expr)) + } + } /** * A boolean expression that is evaluated to true if the value of this expression is contained diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 0b4dd76c7d860..407509911d381 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -219,7 +219,7 @@ case class InMemoryTableScanExec( case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 - case In(a: AttributeReference, list: Seq[Expression]) + case In(Seq(a: AttributeReference), list: Seq[Expression]) if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7b129435c45db..cb0ff3c2bf7f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -475,7 +475,7 @@ object DataSourceStrategy { // Because we only convert In to InSet in Optimizer when there are more than certain // items. So it is possible we still get an In expression here that needs to be pushed // down. 
- case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) => + case expressions.In(Seq(a: Attribute), list) if !list.exists(!_.isInstanceOf[Literal]) => val hSet = list.map(e => e.eval(EmptyRow)) val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) Some(sources.In(a.name, hSet.toArray.map(toScala))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index fe27b78bf3360..68e7bb07b794a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -86,7 +86,7 @@ object FileSourceStrategy extends Strategy with Logging { expr match { case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => getBucketSetFromValue(a, v) - case expressions.In(a: Attribute, list) + case expressions.In(Seq(a: Attribute), list) if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow))) case expressions.InSet(a: Attribute, hset) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1cc8cb3874c9b..77a7a5679b3bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2261,6 +2261,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } + test("SPARK-24341: IN subqueries with struct fields") { + Seq((1, 1)).toDF("a", "b").createOrReplaceTempView("tab") + checkAnswer(sql("select 1 from range(1) where (1, 1) in (select a, b from tab)"), Row(1)) + + Seq((1, 1)).toDF("a", "b").createOrReplaceTempView("tab_a") + Seq((1, 1)).toDF("na", "nb").createOrReplaceTempView("tab_b") + intercept[AnalysisException] { + sql("select 1 from tab_a where (a, b) not in (select (na, nb) from tab_b)").collect() + } + + testData2.select(struct("a", "b").as("record")).createOrReplaceTempView("struct_tab") + checkAnswer( + sql("select count(*) from struct_tab where record in " + + "(select (na as a, nb as b) from tab_b)"), + Row(1)) + checkAnswer( + sql("select count(*) from struct_tab where record not in " + + "(select (na as a, nb as b) from tab_b)"), + Row(5)) + } + test("Uuid expressions should produce same results at retries in the same DataFrame") { val df = spark.range(1).select($"id", new Column(Uuid())) checkAnswer(df, df.collect()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index efc2f20a907f1..0610f51b721a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -459,7 +459,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, LocalRelation(Seq(attribute), Nil)) val tableScanExec = InMemoryTableScanExec(Seq(attribute), - Seq(In(attribute, Nil)), testRelation) + Seq(In(Seq(attribute), Nil)), 
testRelation) assert(tableScanExec.partitionFilters.isEmpty) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index f20aded169e44..1177bfc7a8586 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -51,7 +51,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext { testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) - testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) + testTranslateFilter(In(Seq(attrInt), Seq(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint"))) testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull("cint"))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 130e258e78ca2..f7b27a2e8fff3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -667,7 +667,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } def convert(expr: Expression): Option[String] = expr match { - case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values)) + case In(Seq(ExtractAttribute(NonVarcharAttribute(name))), ExtractableLiterals(values)) if useAdvanced => Some(convertInToOr(name, values)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index 55275f6b37945..97242df6247c3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -176,8 +176,8 @@ class HiveClientSuite(version: String) 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { - case expr @ In(v, list) if expr.inSetConvertible => - InSet(v, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(_, list) if expr.inSetConvertible => + InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) }) } @@ -188,8 +188,8 @@ class HiveClientSuite(version: String) 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { - case expr @ In(v, list) if expr.inSetConvertible => - InSet(v, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(_, list) if expr.inSetConvertible => + InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) }) } @@ -207,8 +207,8 @@ class HiveClientSuite(version: String) 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil, { - case expr @ In(v, list) if expr.inSetConvertible => - InSet(v, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(_, list) if expr.inSetConvertible => + InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) }) } From 65ff49af087f0efae50c20372e3af227727d9f21 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 29 Jun 2018 16:05:34 +0200 Subject: [PATCH 02/17] introduce InValues --- .../sql/catalyst/analysis/Analyzer.scala | 9 ++-- .../sql/catalyst/analysis/TypeCoercion.scala | 22 ++++----- .../spark/sql/catalyst/dsl/package.scala | 4 +- .../sql/catalyst/expressions/predicates.scala | 48 +++++++++++-------- 
.../sql/catalyst/optimizer/expressions.scala | 10 ++-- .../sql/catalyst/optimizer/subquery.scala | 12 ++--- .../sql/catalyst/parser/AstBuilder.scala | 6 +-- .../statsEstimation/FilterEstimation.scala | 2 +- .../analysis/AnalysisErrorSuite.scala | 7 +-- .../sql/catalyst/analysis/AnalysisSuite.scala | 8 ++-- .../analysis/ResolveSubquerySuite.scala | 5 +- .../catalyst/analysis/TypeCoercionSuite.scala | 12 ++--- .../catalog/ExternalCatalogSuite.scala | 4 +- .../expressions/CanonicalizeSuite.scala | 12 ++--- .../catalyst/expressions/PredicateSuite.scala | 46 +++++++++--------- .../optimizer/ConstantFoldingSuite.scala | 2 +- .../catalyst/optimizer/OptimizeInSuite.scala | 26 +++++----- .../PullupCorrelatedPredicatesSuite.scala | 4 +- .../parser/ExpressionParserSuite.scala | 2 +- .../FilterEstimationSuite.scala | 3 +- .../scala/org/apache/spark/sql/Column.scala | 4 +- .../columnar/InMemoryTableScanExec.scala | 2 +- .../datasources/DataSourceStrategy.scala | 3 +- .../datasources/FileSourceStrategy.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 4 +- .../datasources/DataSourceStrategySuite.scala | 3 +- .../spark/sql/hive/client/HiveShim.scala | 4 +- .../sql/hive/client/HiveClientSuite.scala | 12 ++--- 28 files changed, 146 insertions(+), 132 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 44ad30d07fe3b..5bba7797bb42d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1347,17 +1347,16 @@ class Analyzer( resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) - case i @ In(values, Seq(l @ ListQuery(_, _, exprId, _))) - if values.forall(_.resolved) && !l.resolved => + case i @ In(value, Seq(l @ ListQuery(_, _, exprId, _))) if value.resolved && !l.resolved => val expr = resolveSubQuery(l, plans)((plan, exprs) => { ListQuery(plan, exprs, exprId, plan.output) }) val subqueryOutputNum = expr.asInstanceOf[ListQuery].childOutputs.length - if (values.length != subqueryOutputNum) { - throw new AnalysisException(s"${i.sql} has ${values.length} values, but the " + + if (value.numValues != subqueryOutputNum) { + throw new AnalysisException(s"${i.sql} has ${value.numValues} values, but the " + s"subquery has $subqueryOutputNum output values.") } - In(values, Seq(expr)) + In(value, Seq(expr)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index c1242f537d029..c75bf645dc357 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -415,42 +415,42 @@ object TypeCoercion { // in IN subquery. // LHS is the value expressions of IN subquery. case i @ In(lhs, Seq(ListQuery(sub, children, exprId, _))) - if !i.resolved && lhs.length == sub.output.length => + if !i.resolved && lhs.numValues == sub.output.length => // RHS is the subquery output. 
val rhs = sub.output - val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => + val commonTypes = lhs.children.zip(rhs).flatMap { case (l, r) => findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) .orElse(findTightestCommonType(l.dataType, r.dataType)) } // The number of columns/expressions must match between LHS and RHS of an // IN subquery expression. - if (commonTypes.length == lhs.length) { + if (commonTypes.length == lhs.numValues) { val castedRhs = rhs.zip(commonTypes).map { case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() case (e, _) => e } - val newLhs = lhs.zip(commonTypes).map { + val newLhsChildren = lhs.children.zip(commonTypes).map { case (e, dt) if e.dataType != dt => Cast(e, dt) case (e, _) => e } val newSub = Project(castedRhs, sub) - In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output))) + In(InValues(newLhsChildren), Seq(ListQuery(newSub, children, exprId, newSub.output))) } else { i } - case i @ In(a, b) if b.exists(_.dataType != i.value.dataType) => - findWiderCommonType(i.value.dataType +: b.map(_.dataType)) match { - case Some(finalDataType: StructType) if i.values.length > 1 => - val newValues = a.zip(finalDataType.fields.map(_.dataType)).map { + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => + findWiderCommonType(a.dataType +: b.map(_.dataType)) match { + case Some(finalDataType: StructType) if a.numValues > 1 => + val newValues = a.children.zip(finalDataType.fields.map(_.dataType)).map { case (expr, dataType) => Cast(expr, dataType) } - In(newValues, b.map(Cast(_, finalDataType))) + In(InValues(newValues), b.map(Cast(_, finalDataType))) case Some(finalDataType) => - In(a.map(Cast(_, finalDataType)), b.map(Cast(_, finalDataType))) + In(InValues(a.children.map(Cast(_, finalDataType))), b.map(Cast(_, finalDataType))) case None => i } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index a3ec2713fb7b4..e127331c44332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -89,8 +89,8 @@ package object dsl { def =!= (other: Expression): Predicate = Not(EqualTo(expr, other)) def in(list: Expression*): Expression = expr match { - case c: CreateNamedStruct => In(c.valExprs, list) - case other => In(Seq(other), list) + case c: CreateNamedStruct => In(InValues(c.valExprs), list) + case other => In(InValues(Seq(other)), list) } def like(other: Expression): Expression = Like(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 527201ad2174e..d1c5685047740 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -138,6 +138,22 @@ case class Not(child: Expression) override def sql: String = s"(NOT ${child.sql})" } +case class InValues(children: Seq[Expression]) extends Unevaluable { + @transient lazy val numValues: Int = children.length + @transient lazy val valueExpression: Expression = if (children.length > 1) { + CreateNamedStruct(children.zipWithIndex.flatMap { + case (v: NamedExpression, _) => Seq(Literal(v.name), v) + case (v, idx) => Seq(Literal(s"_$idx"), v) + }) + } else { + children.head + } + 
override def nullable: Boolean = children.exists(_.nullable) + override def dataType: DataType = valueExpression.dataType + override def sql: String = valueExpression.sql + override def toString: String = valueExpression.toString +} + /** * Evaluates to `true` if `list` contains `value`. @@ -161,38 +177,29 @@ case class Not(child: Expression) true """) // scalastyle:on line.size.limit -case class In(values: Seq[Expression], list: Seq[Expression]) extends Predicate { +case class In(value: InValues, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") - @transient lazy val value = if (values.length > 1) { - CreateNamedStruct(values.zipWithIndex.flatMap { - case (v: NamedExpression, _) => Seq(Literal(v.name), v) - case (v, idx) => Seq(Literal(s"_$idx"), v) - }) - } else { - values.head - } - override def checkInputDataTypes(): TypeCheckResult = { val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType, ignoreNullability = true)) if (mismatchOpt.isDefined) { list match { case ListQuery(_, _, _, childOutputs) :: Nil => - if (values.length != childOutputs.length) { + if (value.numValues != childOutputs.length) { TypeCheckResult.TypeCheckFailure( s""" |The number of columns in the left hand side of an IN subquery does not match the |number of columns in the output of subquery. - |#columns in left hand side: ${values.length}. + |#columns in left hand side: ${value.numValues}. |#columns in right hand side: ${childOutputs.length}. |Left side columns: - |[${values.map(_.sql).mkString(", ")}]. + |[${value.children.map(_.sql).mkString(", ")}]. |Right side columns: |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) } else { - val mismatchedColumns = values.zip(childOutputs).flatMap { + val mismatchedColumns = value.children.zip(childOutputs).flatMap { case (l, r) if l.dataType != r.dataType => s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" case _ => None @@ -204,7 +211,7 @@ case class In(values: Seq[Expression], list: Seq[Expression]) extends Predicate |Mismatched columns: |[${mismatchedColumns.mkString(", ")}] |Left side: - |[${values.map(_.dataType.catalogString).mkString(", ")}]. + |[${value.children.map(_.dataType.catalogString).mkString(", ")}]. 
|Right side: |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) } @@ -217,7 +224,7 @@ case class In(values: Seq[Expression], list: Seq[Expression]) extends Predicate } } - override def children: Seq[Expression] = values ++: list + override def children: Seq[Expression] = value +: list lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal]) private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType) @@ -227,7 +234,7 @@ case class In(values: Seq[Expression], list: Seq[Expression]) extends Predicate override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { - val evaluatedValue = value.eval(input) + val evaluatedValue = value.valueExpression.eval(input) if (evaluatedValue == null) { null } else { @@ -250,7 +257,7 @@ case class In(values: Seq[Expression], list: Seq[Expression]) extends Predicate override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaDataType = CodeGenerator.javaType(value.dataType) - val valueGen = value.genCode(ctx) + val valueGen = value.valueExpression.genCode(ctx) val listGen = list.map(_.genCode(ctx)) // inTmpResult has 3 possible values: // -1 means no matches found and there is at least one value in the list evaluated to null @@ -312,8 +319,9 @@ case class In(values: Seq[Expression], list: Seq[Expression]) extends Predicate } override def sql: String = { - val valueSQL = value.sql - val listSQL = list.map(_.sql).mkString(", ") + val childrenSQL = children.map(_.sql) + val valueSQL = childrenSQL.head + val listSQL = childrenSQL.tail.mkString(", ") s"($valueSQL IN ($listSQL))" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 11a231d6e9c32..226244c76a704 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -212,18 +212,18 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { * 1. Converts the predicate to false when the list is empty and * the value is not nullable. * 2. Removes literal repetitions. - * 3. Replaces [[In (values, seq[Literal])]] with optimized version + * 3. Replaces [[In (value, seq[Literal])]] with optimized version * [[InSet (value, HashSet[Literal])]] which is much faster. */ object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case i @ In(_, list) if list.isEmpty && !i.value.nullable => FalseLiteral - case expr @ In(_, list) if expr.inSetConvertible => + case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral + case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) - InSet(expr.value, HashSet() ++ hSet) + InSet(v.valueExpression, HashSet() ++ hSet) } else if (newList.size < list.size) { expr.copy(list = newList) } else { // newList.length == list.length @@ -493,7 +493,7 @@ object NullPropagation extends Rule[LogicalPlan] { } // If the value expression is NULL then transform the In expression to null literal. 
-      case In(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType)
+      case In(InValues(Seq(Literal(null, _))), _) => Literal.create(null, BooleanType)
 
       // Non-leaf NullIntolerant expressions will return null, if at least one of its children is
       // a null literal.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index d2954d20ceda2..3ee01705fcdf3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -91,19 +91,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
         val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
         // Deduplicate conflicting attributes if any.
         dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
-      case (p, In(values, Seq(ListQuery(sub, conditions, _, _)))) =>
-        val inConditions = values.zip(sub.output).map(EqualTo.tupled)
+      case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
+        val inConditions = value.children.zip(sub.output).map(EqualTo.tupled)
         val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
         // Deduplicate conflicting attributes if any.
         dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
-      case (p, Not(In(values, Seq(ListQuery(sub, conditions, _, _))))) =>
+      case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
         // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
         // Construct the condition. A NULL in one of the conditions is regarded as a positive
         // result; such a row will be filtered out by the Anti-Join operator.
 
         // Note that this will almost certainly be planned as a Broadcast Nested Loop join.
         // Use EXISTS if performance matters to you.
-        val inConditions = values.zip(sub.output).map(EqualTo.tupled)
+        val inConditions = value.children.zip(sub.output).map(EqualTo.tupled)
         val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
         // Expand the NOT IN expression with the NULL-aware semantic
         // to its full form. That is from:
@@ -144,9 +144,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
         newPlan = dedupJoin(
           Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)))
         exists
-      case In(values, Seq(ListQuery(sub, conditions, _, _))) =>
+      case In(value, Seq(ListQuery(sub, conditions, _, _))) =>
         val exists = AttributeReference("exists", BooleanType, nullable = false)()
-        val inConditions = values.zip(sub.output).map(EqualTo.tupled)
+        val inConditions = value.children.zip(sub.output).map(EqualTo.tupled)
         val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
         // Deduplicate conflicting attributes if any.
newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e7e382b8d45ff..19c152746a5e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1086,9 +1086,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case not => Not(e) } - def getValueExpressions(e: Expression): Seq[Expression] = e match { - case c: CreateNamedStruct => c.valExprs - case other => Seq(other) + def getValueExpressions(e: Expression): InValues = e match { + case c: CreateNamedStruct => InValues(c.valExprs) + case other => InValues(Seq(other)) } // Create the predicate. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index a14123d0a6c12..4a1bc32ae7798 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -164,7 +164,7 @@ case class FilterEstimation(plan: Filter) extends Logging { case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) => evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) - case In(Seq(ar: Attribute), expList) + case In(InValues(Seq(ar: Attribute)), expList) if expList.forall(e => e.isInstanceOf[Literal]) => // Expression [In (value, seq[Literal])] will be replaced with optimized version // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. 
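(Aside, not part of the diff: the hunks above switch every construction site of In over to the InValues wrapper. As a quick orientation for the test changes that follow, here is a minimal sketch of the resulting expression shapes, using only constructors that appear in this patch; the attribute names and subquery plan in the second example are illustrative placeholders, not taken from the diff.

    import org.apache.spark.sql.catalyst.expressions.{In, InValues, ListQuery, Literal}

    // Single-column IN list, i.e. the SQL predicate 1 IN (1, 2):
    val simpleIn = In(InValues(Seq(Literal(1))), Seq(Literal(1), Literal(2)))

    // Multi-column IN subquery, i.e. (a, b) IN (SELECT x, y FROM t), assuming
    // a and b are resolved attributes and subqueryPlan is the subquery's plan:
    // val multiIn = In(InValues(Seq(a, b)), Seq(ListQuery(subqueryPlan)))

When the LHS has more than one column, InValues.valueExpression packs the children into a CreateNamedStruct, so the eval and doGenCode changes above still compare a single struct value against each list element, or against the subquery's projected struct output.)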
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 1030e27e4bd38..8fed9836fd611 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -521,7 +521,7 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val plan = Project( - Seq(a, Alias(In(Seq(a), Seq(ListQuery(LocalRelation(b)))), "c")()), + Seq(a, Alias(In(InValues(Seq(a)), Seq(ListQuery(LocalRelation(b)))), "c")()), LocalRelation(a)) assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil) } @@ -530,13 +530,14 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", BooleanType)() - val plan1 = Filter(Cast(Not(In(Seq(a), Seq(ListQuery(LocalRelation(b))))), BooleanType), + val plan1 = Filter( + Cast(Not(In(InValues(Seq(a)), Seq(ListQuery(LocalRelation(b))))), BooleanType), LocalRelation(a)) assertAnalysisError(plan1, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) val plan2 = Filter( - Or(Not(In(Seq(a), Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) + Or(Not(In(InValues(Seq(a)), Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) assertAnalysisError(plan2, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index a45d396abf173..c71bee47ffd1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -275,7 +275,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-8654: invalid CAST in NULL IN(...) 
expression") { - val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(2))), "a")() :: Nil, + val plan = Project( + Alias(In(InValues(Seq(Literal(null))), Seq(Literal(1), Literal(2))), "a")() :: Nil, LocalRelation() ) assertAnalysisSuccess(plan) @@ -283,13 +284,14 @@ class AnalysisSuite extends AnalysisTest with Matchers { test("SPARK-8654: different types in inlist but can be converted to a common type") { val plan = Project( - Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, + Alias(In(InValues(Seq(Literal(null))), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, LocalRelation()) assertAnalysisSuccess(plan) } test("SPARK-8654: check type compatibility error") { - val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(true), Literal(1))), "a")() :: Nil, + val plan = Project( + Alias(In(InValues(Seq(Literal(null))), Seq(Literal(true), Literal(1))), "a")() :: Nil, LocalRelation() ) assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 03129b9e86234..71de693909c98 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{In, ListQuery, OuterReference} +import org.apache.spark.sql.catalyst.expressions.{In, InValues, ListQuery, OuterReference} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project} /** @@ -33,7 +33,8 @@ class ResolveSubquerySuite extends AnalysisTest { val t2 = LocalRelation(b) test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { - val expr = Filter(In(Seq(a), Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) + val expr = Filter( + In(InValues(Seq(a)), Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr)) }.getMessage diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index b65e65d671c22..7046862960fe7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1240,16 +1240,16 @@ class TypeCoercionSuite extends AnalysisTest { // InConversion val inConversion = TypeCoercion.InConversion(conf) ruleTest(inConversion, - In(Seq(UnresolvedAttribute("a")), Seq(Literal(1))), - In(Seq(UnresolvedAttribute("a")), Seq(Literal(1))) + In(InValues(Seq(UnresolvedAttribute("a"))), Seq(Literal(1))), + In(InValues(Seq(UnresolvedAttribute("a"))), Seq(Literal(1))) ) ruleTest(inConversion, - In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1))), - In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1))) + In(InValues(Seq(Literal("test"))), Seq(UnresolvedAttribute("a"), Literal(1))), + In(InValues(Seq(Literal("test"))), Seq(UnresolvedAttribute("a"), Literal(1))) ) 
ruleTest(inConversion, - In(Seq(Literal("a")), Seq(Literal(1), Literal("b"))), - In(Seq(Cast(Literal("a"), StringType)), + In(InValues(Seq(Literal("a"))), Seq(Literal(1), Literal("b"))), + In(InValues(Seq(Cast(Literal("a"), StringType))), Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType))) ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 51464332dbc80..a0f892bb4eb14 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -477,8 +477,8 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac checkAnswer(tbl2, Seq.empty, Set(part1, part2)) checkAnswer(tbl2, Seq('a.int <= 1), Set(part1)) checkAnswer(tbl2, Seq('a.int === 2), Set.empty) - checkAnswer(tbl2, Seq(In(Seq('a.int * 10), Seq(30))), Set(part2)) - checkAnswer(tbl2, Seq(Not(In(Seq('a.int), Seq(4)))), Set(part1, part2)) + checkAnswer(tbl2, Seq(In(InValues(Seq('a.int * 10)), Seq(30))), Set(part2)) + checkAnswer(tbl2, Seq(Not(In(InValues(Seq('a.int)), Seq(4)))), Set(part1, part2)) checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1)) checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1)) checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index b78d23e3472e8..4181c6b9aa568 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -27,9 +27,9 @@ class CanonicalizeSuite extends SparkFunSuite { val range = Range(1, 1, 1, 1) val idAttr = range.output.head - val in1 = In(Seq(idAttr), Seq(Literal(1), Literal(2))) - val in2 = In(Seq(idAttr), Seq(Literal(2), Literal(1))) - val in3 = In(Seq(idAttr), Seq(Literal(1), Literal(2), Literal(3))) + val in1 = In(InValues(Seq(idAttr)), Seq(Literal(1), Literal(2))) + val in2 = In(InValues(Seq(idAttr)), Seq(Literal(2), Literal(1))) + val in3 = In(InValues(Seq(idAttr)), Seq(Literal(1), Literal(2), Literal(3))) assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash()) assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash()) @@ -37,11 +37,11 @@ class CanonicalizeSuite extends SparkFunSuite { assert(range.where(in1).sameResult(range.where(in2))) assert(!range.where(in1).sameResult(range.where(in3))) - val arrays1 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))), + val arrays1 = In(InValues(Seq(idAttr)), Seq(CreateArray(Seq(Literal(1), Literal(2))), CreateArray(Seq(Literal(2), Literal(1))))) - val arrays2 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(2), Literal(1))), + val arrays2 = In(InValues(Seq(idAttr)), Seq(CreateArray(Seq(Literal(2), Literal(1))), CreateArray(Seq(Literal(1), Literal(2))))) - val arrays3 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))), + val arrays3 = In(InValues(Seq(idAttr)), Seq(CreateArray(Seq(Literal(1), Literal(2))), CreateArray(Seq(Literal(3), Literal(1))))) assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash()) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 0aec614748324..f6a9bb6ba655b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -125,32 +125,32 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, null, null) :: Nil) test("basic IN predicate test") { - checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq(Literal(1), + checkEvaluation(In(InValues(Seq(NonFoldableLiteral.create(null, IntegerType))), Seq(Literal(1), Literal(2))), null) - checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), + checkEvaluation(In(InValues(Seq(NonFoldableLiteral.create(null, IntegerType))), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq.empty), null) - checkEvaluation(In(Seq(Literal(1)), Seq.empty), false) - checkEvaluation(In(Seq(Literal(1)), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(Seq(Literal(1)), + checkEvaluation(In(InValues(Seq(NonFoldableLiteral.create(null, IntegerType))), Seq.empty), + null) + checkEvaluation(In(InValues(Seq(Literal(1))), Seq.empty), false) + checkEvaluation( + In(InValues(Seq(Literal(1))), Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkEvaluation(In(InValues(Seq(Literal(1))), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), true) - checkEvaluation(In(Seq(Literal(2)), + checkEvaluation(In(InValues(Seq(Literal(2))), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(Seq(Literal(1)), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Seq(Literal(2)), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Seq(Literal(3)), Seq(Literal(1), Literal(2))), false) - checkEvaluation( - And(In(Seq(Literal(1)), Seq(Literal(1), Literal(2))), In(Seq(Literal(2)), Seq(Literal(1), - Literal(2)))), - true) + checkEvaluation(In(InValues(Seq(Literal(1))), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(InValues(Seq(Literal(2))), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(InValues(Seq(Literal(3))), Seq(Literal(1), Literal(2))), false) + checkEvaluation(And(In(InValues(Seq(Literal(1))), Seq(Literal(1), Literal(2))), + In(InValues(Seq(Literal(2))), Seq(Literal(1), Literal(2)))), true) val ns = NonFoldableLiteral.create(null, StringType) - checkEvaluation(In(Seq(ns), Seq(Literal("1"), Literal("2"))), null) - checkEvaluation(In(Seq(ns), Seq(ns)), null) - checkEvaluation(In(Seq(Literal("a")), Seq(ns)), null) - checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("^Ba*n"), ns)), true) - checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("aa"), Literal("^Ba*n"))), true) - checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("aa"), Literal("^n"))), false) + checkEvaluation(In(InValues(Seq(ns)), Seq(Literal("1"), Literal("2"))), null) + checkEvaluation(In(InValues(Seq(ns)), Seq(ns)), null) + checkEvaluation(In(InValues(Seq(Literal("a"))), Seq(ns)), null) + checkEvaluation(In(InValues(Seq(Literal("^Ba*n"))), Seq(Literal("^Ba*n"), ns)), true) + checkEvaluation(In(InValues(Seq(Literal("^Ba*n"))), Seq(Literal("aa"), Literal("^Ba*n"))), true) + checkEvaluation(In(InValues(Seq(Literal("^Ba*n"))), Seq(Literal("aa"), Literal("^n"))), false) 
} @@ -187,7 +187,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } else { false } - checkEvaluation(In(Seq(input.head), input.slice(1, 10)), expected) + checkEvaluation(In(InValues(Seq(input.head)), input.slice(1, 10)), expected) } val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t => @@ -243,12 +243,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22501: In should not generate codes beyond 64KB") { val N = 3000 val sets = (1 to N).map(i => Literal(i.toDouble)) - checkEvaluation(In(Seq(Literal(1.0D)), sets), true) + checkEvaluation(In(InValues(Seq(Literal(1.0D))), sets), true) } test("SPARK-22705: In should use less global variables") { val ctx = new CodegenContext() - In(Seq(Literal(1.0D)), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) + In(InValues(Seq(Literal(1.0D))), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 09d8f2e4d1ad4..80e2bf069630e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -251,7 +251,7 @@ class ConstantFoldingSuite extends PlanTest { val originalQuery = testRelation .select('a) - .where(In(Seq(Literal(1)), Seq(Literal(1), Literal(2)))) + .where(In(InValues(Seq(Literal(1))), Seq(Literal(1), Literal(2)))) val optimized = Optimize.execute(originalQuery.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 0f38ebcd4a140..d036ca5e8d63e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -45,9 +45,9 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: Remove deterministic repetitions") { val originalQuery = testRelation - .where(In(Seq(UnresolvedAttribute("a")), + .where(In(InValues(Seq(UnresolvedAttribute("a"))), Seq(Literal(1), Literal(1), Literal(2), Literal(2), Literal(1), Literal(2)))) - .where(In(Seq(UnresolvedAttribute("b")), + .where(In(InValues(Seq(UnresolvedAttribute("b"))), Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"), Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0), Rand(0), Rand(0)))) @@ -56,8 +56,8 @@ class OptimizeInSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2)))) - .where(In(Seq(UnresolvedAttribute("b")), + .where(In(InValues(Seq(UnresolvedAttribute("a"))), Seq(Literal(1), Literal(2)))) + .where(In(InValues(Seq(UnresolvedAttribute("b"))), Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"), Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0), Rand(0), Rand(0)))) @@ -69,7 +69,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") { val originalQuery = testRelation - .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2)))) + .where(In(InValues(Seq(UnresolvedAttribute("a"))), 
Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -79,7 +79,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause optimized to InSet when more than 10 items") { val originalQuery = testRelation - .where(In(Seq(UnresolvedAttribute("a")), (1 to 11).map(Literal(_)))) + .where(In(InValues(Seq(UnresolvedAttribute("a"))), (1 to 11).map(Literal(_)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -94,14 +94,14 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized in case filter has attributes") { val originalQuery = testRelation - .where(In(Seq(UnresolvedAttribute("a")), + .where(In(InValues(Seq(UnresolvedAttribute("a"))), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(Seq(UnresolvedAttribute("a")), + .where(In(InValues(Seq(UnresolvedAttribute("a"))), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze @@ -111,7 +111,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: NULL IN (expr1, ..., exprN) gets transformed to Filter(null)") { val originalQuery = testRelation - .where(In(Seq(Literal.create(null, NullType)), Seq(Literal(1), Literal(2)))) + .where(In(InValues(Seq(Literal.create(null, NullType))), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -127,7 +127,7 @@ class OptimizeInSuite extends PlanTest { "list expression contains attribute)") { val originalQuery = testRelation - .where(In(Seq(Literal.create(null, StringType)), + .where(In(InValues(Seq(Literal.create(null, StringType))), Seq(Literal(1), UnresolvedAttribute("b")))) .analyze @@ -144,7 +144,7 @@ class OptimizeInSuite extends PlanTest { "list expression contains attribute - select)") { val originalQuery = testRelation - .select(In(Seq(Literal.create(null, StringType)), + .select(In(InValues(Seq(Literal.create(null, StringType))), Seq(Literal(1), UnresolvedAttribute("b"))).as("a")).analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -159,7 +159,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: Setting the threshold for turning Set into InSet.") { val plan = testRelation - .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), Literal(3)))) + .where(In(InValues(Seq(UnresolvedAttribute("a"))), Seq(Literal(1), Literal(2), Literal(3)))) .analyze withSQLConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "10") { @@ -183,7 +183,7 @@ class OptimizeInSuite extends PlanTest { "when value is not nullable") { val originalQuery = testRelation - .where(In(Seq(Literal("a")), Nil)) + .where(In(InValues(Seq(Literal("a"))), Nil)) .analyze val optimized = Optimize.execute(originalQuery) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index 02d03fb8b8d57..a34099a8057d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import 
org.apache.spark.sql.catalyst.expressions.{In, ListQuery} +import org.apache.spark.sql.catalyst.expressions.{In, InValues, ListQuery} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -42,7 +42,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { .select('c) val outerQuery = testRelation - .where(In(Seq('a), Seq(ListQuery(correlatedSubquery)))) + .where(In(InValues(Seq('a)), Seq(ListQuery(correlatedSubquery)))) .select('a).analyze assert(outerQuery.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 4e4928fd4c298..ce4c1d8c46584 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -154,7 +154,7 @@ class ExpressionParserSuite extends PlanTest { test("in sub-query") { assertEqual( "a in (select b from c)", - In(Seq('a), Seq(ListQuery(table("c").select('b))))) + In(InValues(Seq('a)), Seq(ListQuery(table("c").select('b))))) } test("like expressions") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index b84334eaf49bb..6bcee2acb641b 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -440,7 +440,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04")) val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05")) validateEstimatedStats( - Filter(In(Seq(attrDate), Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), + Filter(In(InValues(Seq(attrDate)), + Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = Some(3), min = Some(d20170103), max = Some(d20170105), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 3893eae1b8a71..41958f51ed2a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -795,8 +795,8 @@ class Column(val expr: Expression) extends Logging { @scala.annotation.varargs def isin(list: Any*): Column = withExpr { expr match { - case c: CreateNamedStruct => In(c.valExprs, list.map(lit(_).expr)) - case other => In(Seq(other), list.map(lit(_).expr)) + case c: CreateNamedStruct => In(InValues(c.valExprs), list.map(lit(_).expr)) + case other => In(InValues(Seq(other)), list.map(lit(_).expr)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 407509911d381..48ae523918a41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -219,7 +219,7 @@ case class InMemoryTableScanExec( case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 - case In(Seq(a: AttributeReference), list: Seq[Expression]) + case In(InValues(Seq(a: AttributeReference)), list: Seq[Expression]) if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index cb0ff3c2bf7f3..93b1309bc119a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -475,7 +475,8 @@ object DataSourceStrategy { // Because we only convert In to InSet in Optimizer when there are more than certain // items. So it is possible we still get an In expression here that needs to be pushed // down. - case expressions.In(Seq(a: Attribute), list) if !list.exists(!_.isInstanceOf[Literal]) => + case expressions.In(InValues(Seq(a: Attribute)), list) + if list.forall(_.isInstanceOf[Literal]) => val hSet = list.map(e => e.eval(EmptyRow)) val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) Some(sources.In(a.name, hSet.toArray.map(toScala))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 68e7bb07b794a..0b883bcea1666 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -86,7 +86,7 @@ object FileSourceStrategy extends Strategy with Logging { expr match { case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => getBucketSetFromValue(a, v) - case expressions.In(Seq(a: Attribute), list) + case expressions.In(InValues(Seq(a: Attribute)), list) if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow))) case expressions.InSet(a: Attribute, hset) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 0610f51b721a0..89333178dd4f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -21,7 +21,7 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In, InValues} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, 
WholeStageCodegenExec} @@ -459,7 +459,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, LocalRelation(Seq(attribute), Nil)) val tableScanExec = InMemoryTableScanExec(Seq(attribute), - Seq(In(Seq(attribute), Nil)), testRelation) + Seq(In(InValues(Seq(attribute)), Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index 1177bfc7a8586..f40a7c5c9bae0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -51,7 +51,8 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext { testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) - testTranslateFilter(In(Seq(attrInt), Seq(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) + testTranslateFilter(In(InValues(Seq(attrInt)), Seq(1, 2, 3)), + Some(sources.In("cint", Array(1, 2, 3)))) testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint"))) testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull("cint"))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index f7b27a2e8fff3..8cf63d4adbee5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -667,8 +667,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } def convert(expr: Expression): Option[String] = expr match { - case In(Seq(ExtractAttribute(NonVarcharAttribute(name))), ExtractableLiterals(values)) - if useAdvanced => + case In(InValues(Seq(ExtractAttribute(NonVarcharAttribute(name)))), + ExtractableLiterals(values)) if useAdvanced => Some(convertInToOr(name, values)) case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index 97242df6247c3..e0fb4c1a3dfa5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -176,8 +176,8 @@ class HiveClientSuite(version: String) 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { - case expr @ In(_, list) if expr.inSetConvertible => - InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v.valueExpression, list.map(_.eval(EmptyRow)).toSet) }) } @@ -188,8 +188,8 @@ class HiveClientSuite(version: String) 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { - case expr @ In(_, list) if expr.inSetConvertible => - InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v.valueExpression, list.map(_.eval(EmptyRow)).toSet) }) } @@ -207,8 +207,8 @@ class HiveClientSuite(version: String) 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil, { - case expr @ In(_, list) if expr.inSetConvertible => - InSet(expr.value, 
list.map(_.eval(EmptyRow)).toSet) + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v.valueExpression, list.map(_.eval(EmptyRow)).toSet) }) } From 268307f52248d6408862cc76ccb54612ef9ef216 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 29 Jun 2018 19:32:59 +0200 Subject: [PATCH 03/17] add analyzer rule to add InValues --- .../sql/catalyst/analysis/Analyzer.scala | 22 ++++++++-- .../sql/catalyst/analysis/TypeCoercion.scala | 9 +++-- .../spark/sql/catalyst/dsl/package.scala | 5 +-- .../sql/catalyst/expressions/predicates.scala | 40 ++++++++++++++----- .../sql/catalyst/optimizer/expressions.scala | 4 +- .../sql/catalyst/parser/AstBuilder.scala | 9 +---- .../analysis/AnalysisErrorSuite.scala | 8 ++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 9 ++--- .../expressions/CanonicalizeSuite.scala | 12 +++--- .../optimizer/ConstantFoldingSuite.scala | 2 +- .../catalyst/optimizer/OptimizeInSuite.scala | 29 ++++++-------- .../PullupCorrelatedPredicatesSuite.scala | 2 +- .../parser/ExpressionParserSuite.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 7 +--- .../subq-input-typecheck.sql.out | 20 +--------- .../sql/hive/client/HiveClientSuite.scala | 12 +++--- 16 files changed, 96 insertions(+), 96 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5bba7797bb42d..71fb2cbcd4534 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -145,6 +145,8 @@ class Analyzer( ResolveHints.RemoveAllHints), Batch("Simple Sanity Check", Once, LookupFunctions), + Batch("Resolve IN values", Once, + ResolveInValues), Batch("Substitution", fixedPoint, CTESubstitution, WindowsSubstitution, @@ -244,6 +246,20 @@ class Analyzer( } } + /** + * Substitutes In values with an instance of [[InValues]]. + */ + object ResolveInValues extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case q => q transformExpressions { + case In(value, list) if !value.isInstanceOf[InValues] => value match { + case c: CreateNamedStruct => In(InValues(c.valExprs), list) + case other => In(InValues(Seq(other)), list) + } + } + } + } + /** * Replaces [[UnresolvedAlias]]s with concrete aliases. 
*/ @@ -1351,9 +1367,9 @@ class Analyzer( val expr = resolveSubQuery(l, plans)((plan, exprs) => { ListQuery(plan, exprs, exprId, plan.output) }) - val subqueryOutputNum = expr.asInstanceOf[ListQuery].childOutputs.length - if (value.numValues != subqueryOutputNum) { - throw new AnalysisException(s"${i.sql} has ${value.numValues} values, but the " + + val subqueryOutputNum = expr.plan.output.length + if (i.inValues.numValues != subqueryOutputNum) { + throw new AnalysisException(s"${i.sql} has ${i.inValues.numValues} values, but the " + s"subquery has $subqueryOutputNum output values.") } In(value, Seq(expr)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index c75bf645dc357..85e2e89d27c1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -413,9 +413,10 @@ object TypeCoercion { // Handle type casting required between value expression and subquery output // in IN subquery. - // LHS is the value expressions of IN subquery. - case i @ In(lhs, Seq(ListQuery(sub, children, exprId, _))) - if !i.resolved && lhs.numValues == sub.output.length => + case i @ In(_, Seq(ListQuery(sub, children, exprId, _))) + if !i.resolved && i.inValues.numValues == sub.output.length => + // LHS is the value expressions of IN subquery. + val lhs = i.inValues // RHS is the subquery output. val rhs = sub.output @@ -444,7 +445,7 @@ object TypeCoercion { case i @ In(a, b) if b.exists(_.dataType != a.dataType) => findWiderCommonType(a.dataType +: b.map(_.dataType)) match { - case Some(finalDataType: StructType) if a.numValues > 1 => + case Some(finalDataType: StructType) if i.inValues.numValues > 1 => val newValues = a.children.zip(finalDataType.fields.map(_.dataType)).map { case (expr, dataType) => Cast(expr, dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index e127331c44332..efb2eba655e15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -88,10 +88,7 @@ package object dsl { def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) def =!= (other: Expression): Predicate = Not(EqualTo(expr, other)) - def in(list: Expression*): Expression = expr match { - case c: CreateNamedStruct => In(InValues(c.valExprs), list) - case other => In(InValues(Seq(other)), list) - } + def in(list: Expression*): Expression = In(expr, list) def like(other: Expression): Expression = Like(expr, other) def rlike(other: Expression): Expression = RLike(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index d1c5685047740..0a2199a6a1287 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -138,7 +138,9 @@ case class Not(child: Expression) override def sql: String = s"(NOT ${child.sql})" } -case class InValues(children: Seq[Expression]) extends Unevaluable { +case class InValues(children: Seq[Expression]) extends Expression { 
+  require(children.nonEmpty, "Value of IN expression cannot be empty")
+
   @transient lazy val numValues: Int = children.length
   @transient lazy val valueExpression: Expression = if (children.length > 1) {
     CreateNamedStruct(children.zipWithIndex.flatMap {
@@ -152,6 +154,12 @@ case class InValues(children: Seq[Expression]) extends Unevaluable {
   override def dataType: DataType = valueExpression.dataType
   override def sql: String = valueExpression.sql
   override def toString: String = valueExpression.toString
+
+  override def eval(input: InternalRow): Any =
+    throw new RuntimeException("InValues cannot be evaluated.")
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode) =
+    throw new RuntimeException("InValues cannot generate code.")
 }
 
 
@@ -177,29 +185,38 @@ case class InValues(children: Seq[Expression]) extends Unevaluable {
        true
   """)
 // scalastyle:on line.size.limit
-case class In(values: Seq[Expression], list: Seq[Expression]) extends Predicate {
+case class In(value: Expression, list: Seq[Expression]) extends Predicate {
   require(list != null, "list should not be null")
 
+  // During analysis we replace any Expression set as value with an InValues expression,
+  // so we are sure it is an instance of InValues
+  @transient lazy val inValues = value.asInstanceOf[InValues]
+
   override def checkInputDataTypes(): TypeCheckResult = {
     val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType,
       ignoreNullability = true))
     if (mismatchOpt.isDefined) {
       list match {
         case ListQuery(_, _, _, childOutputs) :: Nil =>
-          if (value.numValues != childOutputs.length) {
+          val valExprs = value match {
+            case cns: CreateNamedStruct => cns.valExprs
+            case inValues: InValues => inValues.children
+            case expr => Seq(expr)
+          }
+          if (valExprs.length != childOutputs.length) {
             TypeCheckResult.TypeCheckFailure(
               s"""
                  |The number of columns in the left hand side of an IN subquery does not match the
                  |number of columns in the output of subquery.
-                 |#columns in left hand side: ${value.numValues}.
+                 |#columns in left hand side: ${valExprs.length}.
                  |#columns in right hand side: ${childOutputs.length}.
                  |Left side columns:
-                 |[${value.children.map(_.sql).mkString(", ")}].
+                 |[${valExprs.map(_.sql).mkString(", ")}].
                  |Right side columns:
                  |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
           } else {
-            val mismatchedColumns = value.children.zip(childOutputs).flatMap {
+            val mismatchedColumns = valExprs.zip(childOutputs).flatMap {
               case (l, r) if l.dataType != r.dataType =>
                 s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
               case _ => None
@@ -211,7 +228,7 @@ case class In(values: Seq[Expression], list: Seq[Expression]) extends Predicate {
                  |Mismatched columns:
                  |[${mismatchedColumns.mkString(", ")}]
                  |Left side:
-                 |[${value.children.map(_.dataType.catalogString).mkString(", ")}].
+                 |[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
|Right side: |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) } @@ -229,12 +246,15 @@ case class In(value: InValues, list: Seq[Expression]) extends Predicate { private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType) override def nullable: Boolean = children.exists(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) + override def foldable: Boolean = value match { + case i: InValues => i.valueExpression.foldable && list.forall(_.foldable) + case _ => children.forall(_.foldable) + } override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { - val evaluatedValue = value.valueExpression.eval(input) + val evaluatedValue = inValues.valueExpression.eval(input) if (evaluatedValue == null) { null } else { @@ -257,7 +277,7 @@ case class In(value: InValues, list: Seq[Expression]) extends Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaDataType = CodeGenerator.javaType(value.dataType) - val valueGen = value.valueExpression.genCode(ctx) + val valueGen = inValues.valueExpression.genCode(ctx) val listGen = list.map(_.genCode(ctx)) // inTmpResult has 3 possible values: // -1 means no matches found and there is at least one value in the list evaluated to null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 226244c76a704..52f8033e1db5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -219,11 +219,11 @@ object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral - case expr @ In(v, list) if expr.inSetConvertible => + case expr @ In(_, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) - InSet(v.valueExpression, HashSet() ++ hSet) + InSet(expr.inValues.valueExpression, HashSet() ++ hSet) } else if (newList.size < list.size) { expr.copy(list = newList) } else { // newList.length == list.length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 19c152746a5e8..383ebde3229d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1086,11 +1086,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case not => Not(e) } - def getValueExpressions(e: Expression): InValues = e match { - case c: CreateNamedStruct => InValues(c.valExprs) - case other => InValues(Seq(other)) - } - // Create the predicate. 
ctx.kind.getType match { case SqlBaseParser.BETWEEN => @@ -1099,9 +1094,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging GreaterThanOrEqual(e, expression(ctx.lower)), LessThanOrEqual(e, expression(ctx.upper)))) case SqlBaseParser.IN if ctx.query != null => - invertIfNotDefined(In(getValueExpressions(e), Seq(ListQuery(plan(ctx.query))))) + invertIfNotDefined(In(e, Seq(ListQuery(plan(ctx.query))))) case SqlBaseParser.IN => - invertIfNotDefined(In(getValueExpressions(e), ctx.expression.asScala.map(expression))) + invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => invertIfNotDefined(Like(e, expression(ctx.pattern))) case SqlBaseParser.RLIKE => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 8fed9836fd611..5d2f8e735e3d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -521,7 +521,7 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val plan = Project( - Seq(a, Alias(In(InValues(Seq(a)), Seq(ListQuery(LocalRelation(b)))), "c")()), + Seq(a, Alias(In(a, Seq(ListQuery(LocalRelation(b)))), "c")()), LocalRelation(a)) assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil) } @@ -530,14 +530,12 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", BooleanType)() - val plan1 = Filter( - Cast(Not(In(InValues(Seq(a)), Seq(ListQuery(LocalRelation(b))))), BooleanType), + val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType), LocalRelation(a)) assertAnalysisError(plan1, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) - val plan2 = Filter( - Or(Not(In(InValues(Seq(a)), Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) + val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) assertAnalysisError(plan2, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index c71bee47ffd1c..42ca9373c918f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -275,23 +275,20 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-8654: invalid CAST in NULL IN(...) 
expression") { - val plan = Project( - Alias(In(InValues(Seq(Literal(null))), Seq(Literal(1), Literal(2))), "a")() :: Nil, + val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil, LocalRelation() ) assertAnalysisSuccess(plan) } test("SPARK-8654: different types in inlist but can be converted to a common type") { - val plan = Project( - Alias(In(InValues(Seq(Literal(null))), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, + val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, LocalRelation()) assertAnalysisSuccess(plan) } test("SPARK-8654: check type compatibility error") { - val plan = Project( - Alias(In(InValues(Seq(Literal(null))), Seq(Literal(true), Literal(1))), "a")() :: Nil, + val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil, LocalRelation() ) assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 4181c6b9aa568..28e6940f3cca3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -27,9 +27,9 @@ class CanonicalizeSuite extends SparkFunSuite { val range = Range(1, 1, 1, 1) val idAttr = range.output.head - val in1 = In(InValues(Seq(idAttr)), Seq(Literal(1), Literal(2))) - val in2 = In(InValues(Seq(idAttr)), Seq(Literal(2), Literal(1))) - val in3 = In(InValues(Seq(idAttr)), Seq(Literal(1), Literal(2), Literal(3))) + val in1 = In(idAttr, Seq(Literal(1), Literal(2))) + val in2 = In(idAttr, Seq(Literal(2), Literal(1))) + val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3))) assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash()) assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash()) @@ -37,11 +37,11 @@ class CanonicalizeSuite extends SparkFunSuite { assert(range.where(in1).sameResult(range.where(in2))) assert(!range.where(in1).sameResult(range.where(in3))) - val arrays1 = In(InValues(Seq(idAttr)), Seq(CreateArray(Seq(Literal(1), Literal(2))), + val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), CreateArray(Seq(Literal(2), Literal(1))))) - val arrays2 = In(InValues(Seq(idAttr)), Seq(CreateArray(Seq(Literal(2), Literal(1))), + val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))), CreateArray(Seq(Literal(1), Literal(2))))) - val arrays3 = In(InValues(Seq(idAttr)), Seq(CreateArray(Seq(Literal(1), Literal(2))), + val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), CreateArray(Seq(Literal(3), Literal(1))))) assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 80e2bf069630e..641c89873dcc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -251,7 +251,7 @@ class ConstantFoldingSuite extends PlanTest { val originalQuery = testRelation .select('a) - .where(In(InValues(Seq(Literal(1))), Seq(Literal(1), 
Literal(2)))) + .where(In(Literal(1), Seq(Literal(1), Literal(2)))) val optimized = Optimize.execute(originalQuery.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index d036ca5e8d63e..478118ed709f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -45,9 +45,9 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: Remove deterministic repetitions") { val originalQuery = testRelation - .where(In(InValues(Seq(UnresolvedAttribute("a"))), + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(1), Literal(2), Literal(2), Literal(1), Literal(2)))) - .where(In(InValues(Seq(UnresolvedAttribute("b"))), + .where(In(UnresolvedAttribute("b"), Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"), Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0), Rand(0), Rand(0)))) @@ -56,8 +56,8 @@ class OptimizeInSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(InValues(Seq(UnresolvedAttribute("a"))), Seq(Literal(1), Literal(2)))) - .where(In(InValues(Seq(UnresolvedAttribute("b"))), + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) + .where(In(UnresolvedAttribute("b"), Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"), Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0), Rand(0), Rand(0)))) @@ -69,7 +69,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") { val originalQuery = testRelation - .where(In(InValues(Seq(UnresolvedAttribute("a"))), Seq(Literal(1), Literal(2)))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -79,7 +79,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause optimized to InSet when more than 10 items") { val originalQuery = testRelation - .where(In(InValues(Seq(UnresolvedAttribute("a"))), (1 to 11).map(Literal(_)))) + .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -94,15 +94,13 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized in case filter has attributes") { val originalQuery = testRelation - .where(In(InValues(Seq(UnresolvedAttribute("a"))), - Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(InValues(Seq(UnresolvedAttribute("a"))), - Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze comparePlans(optimized, correctAnswer) @@ -111,7 +109,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: NULL IN (expr1, ..., exprN) gets transformed to Filter(null)") { val originalQuery = testRelation - .where(In(InValues(Seq(Literal.create(null, NullType))), Seq(Literal(1), Literal(2)))) + .where(In(Literal.create(null, NullType), Seq(Literal(1), Literal(2)))) .analyze val optimized = 
Optimize.execute(originalQuery.analyze) @@ -127,8 +125,7 @@ class OptimizeInSuite extends PlanTest { "list expression contains attribute)") { val originalQuery = testRelation - .where(In(InValues(Seq(Literal.create(null, StringType))), - Seq(Literal(1), UnresolvedAttribute("b")))) + .where(In(Literal.create(null, StringType), Seq(Literal(1), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -144,7 +141,7 @@ class OptimizeInSuite extends PlanTest { "list expression contains attribute - select)") { val originalQuery = testRelation - .select(In(InValues(Seq(Literal.create(null, StringType))), + .select(In(Literal.create(null, StringType), Seq(Literal(1), UnresolvedAttribute("b"))).as("a")).analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -159,7 +156,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: Setting the threshold for turning Set into InSet.") { val plan = testRelation - .where(In(InValues(Seq(UnresolvedAttribute("a"))), Seq(Literal(1), Literal(2), Literal(3)))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), Literal(3)))) .analyze withSQLConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "10") { @@ -183,7 +180,7 @@ class OptimizeInSuite extends PlanTest { "when value is not nullable") { val originalQuery = testRelation - .where(In(InValues(Seq(Literal("a"))), Nil)) + .where(In(Literal("a"), Nil)) .analyze val optimized = Optimize.execute(originalQuery) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index a34099a8057d7..e2b96ece73fd3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -42,7 +42,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { .select('c) val outerQuery = testRelation - .where(In(InValues(Seq('a)), Seq(ListQuery(correlatedSubquery)))) + .where(In('a, Seq(ListQuery(correlatedSubquery)))) .select('a).analyze assert(outerQuery.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index ce4c1d8c46584..cb8a1fecb80a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -154,7 +154,7 @@ class ExpressionParserSuite extends PlanTest { test("in sub-query") { assertEqual( "a in (select b from c)", - In(InValues(Seq('a)), Seq(ListQuery(table("c").select('b))))) + In('a, Seq(ListQuery(table("c").select('b))))) } test("like expressions") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 41958f51ed2a0..4eee3de5f7d4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -793,12 +793,7 @@ class Column(val expr: Expression) extends Logging { * @since 1.5.0 */ @scala.annotation.varargs - def isin(list: Any*): Column = withExpr { - expr match { - case c: CreateNamedStruct => In(InValues(c.valExprs), list.map(lit(_).expr)) - case other => 
In(InValues(Seq(other)), list.map(lit(_).expr)) - } - } + def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } /** * A boolean expression that is evaluated to true if the value of this expression is contained diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out index 70aeb9373f3c7..ea7b529c82d5e 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -72,15 +72,7 @@ t1a IN (SELECT t2a, t2b struct<> -- !query 5 output org.apache.spark.sql.AnalysisException -cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: -The number of columns in the left hand side of an IN subquery does not match the -number of columns in the output of subquery. -#columns in left hand side: 1. -#columns in right hand side: 2. -Left side columns: -[t1.`t1a`]. -Right side columns: -[t2.`t2a`, t2.`t2b`].; +(t1.`t1a` IN (listquery())) has 1 values, but the subquery has 2 output values.; -- !query 6 @@ -93,12 +85,4 @@ WHERE struct<> -- !query 6 output org.apache.spark.sql.AnalysisException -cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: -The number of columns in the left hand side of an IN subquery does not match the -number of columns in the output of subquery. -#columns in left hand side: 2. -#columns in right hand side: 1. -Left side columns: -[t1.`t1a`, t1.`t1b`]. -Right side columns: -[t2.`t2a`].; +(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery())) has 2 values, but the subquery has 1 output values.; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index e0fb4c1a3dfa5..fe74c44dd5691 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -176,8 +176,8 @@ class HiveClientSuite(version: String) 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { - case expr @ In(v, list) if expr.inSetConvertible => - InSet(v.valueExpression, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(_, list) if expr.inSetConvertible => + InSet(expr.inValues.valueExpression, list.map(_.eval(EmptyRow)).toSet) }) } @@ -188,8 +188,8 @@ class HiveClientSuite(version: String) 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { - case expr @ In(v, list) if expr.inSetConvertible => - InSet(v.valueExpression, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(_, list) if expr.inSetConvertible => + InSet(expr.inValues.valueExpression, list.map(_.eval(EmptyRow)).toSet) }) } @@ -207,8 +207,8 @@ class HiveClientSuite(version: String) 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil, { - case expr @ In(v, list) if expr.inSetConvertible => - InSet(v.valueExpression, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(_, list) if expr.inSetConvertible => + InSet(expr.inValues.valueExpression, list.map(_.eval(EmptyRow)).toSet) }) } From d3e39ed3f442958cfaaa1ef056cb72fedf0fce1c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 3 Jul 2018 23:44:01 +0200 Subject: [PATCH 04/17] fix ut failures --- 
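Note on the failures: at this point in the series the left-hand side of every In is wrapped in an InValues node, but only the analyzer performs that wrapping, and HiveClientSuite hands DSL-built predicates to the metastore filter path unanalyzed. A minimal sketch of the normalization the new analyzeIn helper applies (attr and the implicit literal conversions are the suite's existing fixtures):

    import org.apache.spark.sql.catalyst.expressions.{In, InValues}

    // DSL form as built in the tests: the LHS is a bare attribute.
    val raw = attr("ds").in(20170102, 20170103)
    // raw == In(ds, Seq(Literal(20170102), Literal(20170103)))

    // analyzeIn mimics the analyzer's ResolveInValues rule, producing the
    // wrapped shape the filter converter now expects:
    val analyzed = analyzeIn(raw)
    // analyzed == In(InValues(Seq(ds)), Seq(Literal(20170102), Literal(20170103)))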
.../sql/hive/client/HiveClientSuite.scala | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index fe74c44dd5691..c9f8fcac831cb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -74,6 +74,12 @@ class HiveClientSuite(version: String) } } + private def analyzeIn(expr: Expression): Expression = expr match { + case In(CreateNamedStruct(children), list) => In(InValues(children), list) + case In(v, list) => In(InValues(Seq(v)), list) + case other => other + } + override def beforeAll() { super.beforeAll() client = init(true) @@ -156,7 +162,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") { testMetastorePartitionFiltering( - attr("ds").in(20170102, 20170103), + analyzeIn(attr("ds").in(20170102, 20170103)), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -164,7 +170,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using IN expression)") { testMetastorePartitionFiltering( - attr("ds").cast(LongType).in(20170102L, 20170103L), + analyzeIn(attr("ds").cast(LongType).in(20170102L, 20170103L)), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -172,7 +178,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") { testMetastorePartitionFiltering( - attr("ds").in(20170102, 20170103), + analyzeIn(attr("ds").in(20170102, 20170103)), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { @@ -184,7 +190,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using INSET expression)") { testMetastorePartitionFiltering( - attr("ds").cast(LongType).in(20170102L, 20170103L), + analyzeIn(attr("ds").cast(LongType).in(20170102L, 20170103L)), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { @@ -195,7 +201,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") { testMetastorePartitionFiltering( - attr("chunk").in("ab", "ba"), + analyzeIn(attr("chunk").in("ab", "ba")), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil) @@ -203,7 +209,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") { testMetastorePartitionFiltering( - attr("chunk").in("ab", "ba"), + analyzeIn(attr("chunk").in("ab", "ba")), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil, { @@ -231,7 +237,7 @@ class HiveClientSuite(version: String) "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") { val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba")) val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba")) - testMetastorePartitionFiltering(attr("chunk").in("ab", "ba") && + testMetastorePartitionFiltering(analyzeIn(attr("chunk").in("ab", "ba")) && ((attr("ds") === 20170101 && attr("h") >= 8) || (attr("ds") === 20170102 && attr("h") < 8)), day1 :: day2 :: Nil) } From 60b57d2cde0ec1984a99c3db338df74b730623d0 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 24 Jul 2018 14:29:41 +0200 Subject: [PATCH 05/17] fix merging --- 
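The merge left OptimizeIn matching on the pre-InValues shape; the fix below reads the value through expr.inValues.valueExpression instead. A sketch of the behavior it preserves, in the style of OptimizeInSuite, assuming its testRelation fixture and the default spark.sql.optimizer.inSetConversionThreshold of 10:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.{In, InValues, Literal}

    // More than 10 distinct literals: rewritten to InSet on the unwrapped value.
    val query = testRelation
      .where(In(InValues(Seq(UnresolvedAttribute("a"))), (1 to 11).map(Literal(_))))
      .analyze
    val optimized = Optimize.execute(query)
    // expected: Filter(InSet(a, Set(1, ..., 11)), testRelation)

    // A single distinct literal is rewritten to EqualTo(a, 1) instead, and an
    // empty list folds to If(IsNotNull(a), false, null).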
.../apache/spark/sql/catalyst/optimizer/expressions.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index a3a9eca1dccad..447346dea7007 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -227,12 +227,12 @@ object OptimizeIn extends Rule[LogicalPlan] { if (newList.length == 1 // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed, // TODO: we exclude them in this rule. - && !v.isInstanceOf[CreateNamedStructLike] + && !expr.inValues.valueExpression.isInstanceOf[CreateNamedStructLike] && !newList.head.isInstanceOf[CreateNamedStructLike]) { - EqualTo(v, newList.head) + EqualTo(expr.inValues.valueExpression, newList.head) } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) - InSet(v, HashSet() ++ hSet) + InSet(expr.inValues.valueExpression, HashSet() ++ hSet) } else if (newList.length < list.length) { expr.copy(list = newList) } else { // newList.length == list.length && newList.length > 1 From 22f77ae5fff52c4a9c0900c0246b34782cb76652 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 24 Jul 2018 19:12:03 +0200 Subject: [PATCH 06/17] fix UT error --- .../scala/org/apache/spark/sql/hive/client/FiltersSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 2a4efd0cce6e0..d454c2d15ae95 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -73,13 +73,13 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""") filterTest("SPARK-24879 null literals should be ignored for IN constructs", - (a("intcol", IntegerType) in (Literal(1), Literal(null))) :: Nil, + (InValues(Seq(a("intcol", IntegerType))) in (Literal(1), Literal(null))) :: Nil, "(intcol = 1)") // Applying the predicate `x IN (NULL)` should return an empty set, but since this optimization // will be applied by Catalyst, this filter converter does not need to account for this. 
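  // For illustration: a predicate such as `intcol IN (NULL)` evaluates to NULL
  // for every row, so no row can satisfy it; Catalyst simplifies such
  // predicates before pushdown, which is why the expected Hive filter string
  // in the test below is empty.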
filterTest("SPARK-24879 IN predicates with only NULLs will not cause a NPE", - (a("intcol", IntegerType) in Literal(null)) :: Nil, + (InValues(Seq(a("intcol", IntegerType))) in Literal(null)) :: Nil, "") filterTest("typecast null literals should not be pushed down in simple predicates", From 04128292e6d145ec608166b532c960cac72a500c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 25 Jul 2018 11:46:58 +0200 Subject: [PATCH 07/17] fix OptimizeIn merge --- .../apache/spark/sql/catalyst/optimizer/expressions.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 447346dea7007..742b1793a8973 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -218,11 +218,11 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty => + case expr @ In(_, list) if list.isEmpty => // When v is not nullable, the following expression will be optimized // to FalseLiteral which is tested in OptimizeInSuite.scala - If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) - case expr @ In(v, list) if expr.inSetConvertible => + If(IsNotNull(expr.inValues.valueExpression), FalseLiteral, Literal(null, BooleanType)) + case expr @ In(_, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.length == 1 // TODO: `EqualTo` for structural types are not working. 
Until SPARK-24443 is addressed,

From f9b753666c4a3bceeacb90716b11babd0fc4ed3a Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Thu, 26 Jul 2018 13:27:54 +0200
Subject: [PATCH 08/17] move tests

---
 .../inputs/subquery/in-subquery/in-basic.sql  | 14 +++++
 .../subquery/in-subquery/in-basic.sql.out     | 62 +++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala | 21 -------
 3 files changed, 76 insertions(+), 21 deletions(-)
 create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql
 create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out

diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql
new file mode 100644
index 0000000000000..f4ffc20086386
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql
@@ -0,0 +1,14 @@
+create temporary view tab_a as select * from values (1, 1) as tab_a(a1, b1);
+create temporary view tab_b as select * from values (1, 1) as tab_b(a2, b2);
+create temporary view struct_tab as select struct(col1 as a, col2 as b) as record from
+  values (1, 1), (1, 2), (2, 1), (2, 2);
+
+select 1 from tab_a where (a1, b1) not in (select a2, b2 from tab_b);
+-- Invalid query, see SPARK-24341
+select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b);
+
+-- Aliasing is needed as a workaround for SPARK-24443
+select count(*) from struct_tab where record in
+  (select (a2 as a, b2 as b) from tab_b);
+select count(*) from struct_tab where record not in
+  (select (a2 as a, b2 as b) from tab_b);
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out
new file mode 100644
index 0000000000000..6e26426a5e39a
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out
@@ -0,0 +1,62 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 7
+
+
+-- !query 0
+create temporary view tab_a as select * from values (1, 1) as tab_a(a1, b1)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+create temporary view tab_b as select * from values (1, 1) as tab_b(a2, b2)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+create temporary view struct_tab as select struct(col1 as a, col2 as b) as record from
+  values (1, 1), (1, 2), (2, 1), (2, 2)
+-- !query 2 schema
+struct<>
+-- !query 2 output
+
+
+
+-- !query 3
+select 1 from tab_a where (a1, b1) not in (select a2, b2 from tab_b)
+-- !query 3 schema
+struct<1:int>
+-- !query 3 output
+
+
+
+-- !query 4
+select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b)
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+(named_struct('a1', tab_a.`a1`, 'b1', tab_a.`b1`) IN (listquery())) has 2 values, but the subquery has 1 output values.;
+
+
+-- !query 5
+select count(*) from struct_tab where record in
+  (select (a2 as a, b2 as b) from tab_b)
+-- !query 5 schema
+struct<count(1):bigint>
+-- !query 5 output
+1
+
+
+-- !query 6
+select count(*) from struct_tab where record not in
+  (select (a2 as a, b2 as b) from tab_b)
+-- !query 6 schema
+struct<count(1):bigint>
+-- !query 6 output
+3
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 2f109e00d7346..9cf8c47fa6cf1
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2320,27 +2320,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } - test("SPARK-24341: IN subqueries with struct fields") { - Seq((1, 1)).toDF("a", "b").createOrReplaceTempView("tab") - checkAnswer(sql("select 1 from range(1) where (1, 1) in (select a, b from tab)"), Row(1)) - - Seq((1, 1)).toDF("a", "b").createOrReplaceTempView("tab_a") - Seq((1, 1)).toDF("na", "nb").createOrReplaceTempView("tab_b") - intercept[AnalysisException] { - sql("select 1 from tab_a where (a, b) not in (select (na, nb) from tab_b)").collect() - } - - testData2.select(struct("a", "b").as("record")).createOrReplaceTempView("struct_tab") - checkAnswer( - sql("select count(*) from struct_tab where record in " + - "(select (na as a, nb as b) from tab_b)"), - Row(1)) - checkAnswer( - sql("select count(*) from struct_tab where record not in " + - "(select (na as a, nb as b) from tab_b)"), - Row(5)) - } - test("SPARK-24165: CaseWhen/If - nullability of nested types") { val rows = new java.util.ArrayList[Row]() rows.add(Row(true, ("x", 1), Seq("x", "y"), Map(0 -> "x"))) From f5fa2c4b99a810c25a02e6d32550135d429c70c2 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 26 Jul 2018 15:17:08 +0200 Subject: [PATCH 09/17] fix error message according to comment --- .../sql/catalyst/analysis/Analyzer.scala | 21 +++++++++++++------ .../subquery/in-subquery/in-basic.sql.out | 10 ++++++++- .../subq-input-typecheck.sql.out | 20 ++++++++++++++++-- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 74bf1c167c4f8..ecfafdaa57d77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1451,16 +1451,25 @@ class Analyzer( resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) - case i @ In(value, Seq(l @ ListQuery(_, _, exprId, _))) if value.resolved && !l.resolved => + case In(value, Seq(l @ ListQuery(_, _, exprId, _))) if value.resolved && !l.resolved => val expr = resolveSubQuery(l, plans)((plan, exprs) => { ListQuery(plan, exprs, exprId, plan.output) }) - val subqueryOutputNum = expr.plan.output.length - if (i.inValues.numValues != subqueryOutputNum) { - throw new AnalysisException(s"${i.sql} has ${i.inValues.numValues} values, but the " + - s"subquery has $subqueryOutputNum output values.") + val subqueryOutput = expr.plan.output + val resolvedIn = In(value, Seq(expr)) + if (resolvedIn.inValues.numValues != subqueryOutput.length) { + throw new AnalysisException( + s"""Cannot analyze ${resolvedIn.sql}. + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${resolvedIn.inValues.numValues}. + |#columns in right hand side: ${subqueryOutput.length}. + |Left side columns: + |[${resolvedIn.inValues.children.map(_.sql).mkString(", ")}]. 
+ |Right side columns: + |[${subqueryOutput.map(_.sql).mkString(", ")}].""".stripMargin) } - In(value, Seq(expr)) + resolvedIn } } diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out index 6e26426a5e39a..05dcd0f311cab 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out @@ -41,7 +41,15 @@ select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b) struct<> -- !query 4 output org.apache.spark.sql.AnalysisException -(named_struct('a1', tab_a.`a1`, 'b1', tab_a.`b1`) IN (listquery())) has 2 values, but the subquery has 1 output values.; +Cannot analyze (named_struct('a1', tab_a.`a1`, 'b1', tab_a.`b1`) IN (listquery())). +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 2. +#columns in right hand side: 1. +Left side columns: +[tab_a.`a1`, tab_a.`b1`]. +Right side columns: +[`named_struct(a2, a2, b2, b2)`].; -- !query 5 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out index acc7a91a0509c..e11d5a336615f 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -92,7 +92,15 @@ t1a IN (SELECT t2a, t2b struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -(t1.`t1a` IN (listquery())) has 1 values, but the subquery has 2 output values.; +Cannot analyze (t1.`t1a` IN (listquery(t1.`t1a`))). +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 1. +#columns in right hand side: 2. +Left side columns: +[t1.`t1a`]. +Right side columns: +[t2.`t2a`, t2.`t2b`].; -- !query 8 @@ -105,7 +113,15 @@ WHERE struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery())) has 2 values, but the subquery has 1 output values.; +Cannot analyze (named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`))). +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 2. +#columns in right hand side: 1. +Left side columns: +[t1.`t1a`, t1.`t1b`]. 
+Right side columns: +[t2.`t2a`].; -- !query 9 From 571b2733a229d2271472cf60ede2f9072d437256 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 27 Jul 2018 16:48:04 +0200 Subject: [PATCH 10/17] revert to Seq[Expression] --- .../sql/catalyst/analysis/Analyzer.scala | 27 ++------ .../sql/catalyst/analysis/TypeCoercion.scala | 25 ++++--- .../spark/sql/catalyst/dsl/package.scala | 5 +- .../sql/catalyst/expressions/predicates.scala | 68 ++++++------------- .../sql/catalyst/optimizer/expressions.scala | 12 ++-- .../sql/catalyst/optimizer/subquery.scala | 12 ++-- .../sql/catalyst/parser/AstBuilder.scala | 9 ++- .../statsEstimation/FilterEstimation.scala | 2 +- .../analysis/AnalysisErrorSuite.scala | 7 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 7 +- .../analysis/ResolveSubquerySuite.scala | 5 +- .../catalyst/analysis/TypeCoercionSuite.scala | 12 ++-- .../catalog/ExternalCatalogSuite.scala | 4 +- .../expressions/CanonicalizeSuite.scala | 12 ++-- .../catalyst/expressions/PredicateSuite.scala | 42 ++++++------ .../optimizer/ConstantFoldingSuite.scala | 2 +- .../catalyst/optimizer/OptimizeInSuite.scala | 32 +++++---- .../PullupCorrelatedPredicatesSuite.scala | 4 +- .../parser/ExpressionParserSuite.scala | 2 +- .../FilterEstimationSuite.scala | 3 +- .../scala/org/apache/spark/sql/Column.scala | 7 +- .../columnar/InMemoryTableScanExec.scala | 2 +- .../datasources/DataSourceStrategy.scala | 2 +- .../datasources/FileSourceStrategy.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 4 +- .../datasources/DataSourceStrategySuite.scala | 2 +- .../spark/sql/hive/client/HiveShim.scala | 2 +- .../spark/sql/hive/client/FiltersSuite.scala | 4 +- .../sql/hive/client/HiveClientSuite.scala | 26 +++---- 29 files changed, 154 insertions(+), 189 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ecfafdaa57d77..52040c4853e27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -148,8 +148,6 @@ class Analyzer( ResolveHints.RemoveAllHints), Batch("Simple Sanity Check", Once, LookupFunctions), - Batch("Resolve IN values", Once, - ResolveInValues), Batch("Substitution", fixedPoint, CTESubstitution, WindowsSubstitution, @@ -249,20 +247,6 @@ class Analyzer( } } - /** - * Substitutes In values with an instance of [[InValues]]. - */ - object ResolveInValues extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case q => q transformExpressions { - case In(value, list) if !value.isInstanceOf[InValues] => value match { - case c: CreateNamedStruct => In(InValues(c.valExprs), list) - case other => In(InValues(Seq(other)), list) - } - } - } - } - /** * Replaces [[UnresolvedAlias]]s with concrete aliases. 
*/ @@ -1451,21 +1435,22 @@ class Analyzer( resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) - case In(value, Seq(l @ ListQuery(_, _, exprId, _))) if value.resolved && !l.resolved => + case In(values, Seq(l @ ListQuery(_, _, exprId, _))) + if values.forall(_.resolved) && !l.resolved => val expr = resolveSubQuery(l, plans)((plan, exprs) => { ListQuery(plan, exprs, exprId, plan.output) }) val subqueryOutput = expr.plan.output - val resolvedIn = In(value, Seq(expr)) - if (resolvedIn.inValues.numValues != subqueryOutput.length) { + val resolvedIn = In(values, Seq(expr)) + if (values.length != subqueryOutput.length) { throw new AnalysisException( s"""Cannot analyze ${resolvedIn.sql}. |The number of columns in the left hand side of an IN subquery does not match the |number of columns in the output of subquery. - |#columns in left hand side: ${resolvedIn.inValues.numValues}. + |#columns in left hand side: ${values.length}. |#columns in right hand side: ${subqueryOutput.length}. |Left side columns: - |[${resolvedIn.inValues.children.map(_.sql).mkString(", ")}]. + |[${values.map(_.sql).mkString(", ")}]. |Right side columns: |[${subqueryOutput.map(_.sql).mkString(", ")}].""".stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index bdbb4d4ed2639..dec1a94e4f866 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -450,45 +450,44 @@ object TypeCoercion { // Handle type casting required between value expression and subquery output // in IN subquery. - case i @ In(_, Seq(ListQuery(sub, children, exprId, _))) - if !i.resolved && i.inValues.numValues == sub.output.length => + case i @ In(lhs, Seq(ListQuery(sub, children, exprId, _))) + if !i.resolved && lhs.length == sub.output.length => // LHS is the value expressions of IN subquery. - val lhs = i.inValues // RHS is the subquery output. val rhs = sub.output - val commonTypes = lhs.children.zip(rhs).flatMap { case (l, r) => + val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) .orElse(findTightestCommonType(l.dataType, r.dataType)) } // The number of columns/expressions must match between LHS and RHS of an // IN subquery expression. 
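  // For illustration, a sketch of the coercion with hypothetical column types:
  // for (a1, b1) IN (select a2, b2 ...) with a1: int, a2: bigint and
  // b1, b2: string, commonTypes is Seq(bigint, string), so the rule casts
  // a1 to bigint on the LHS and wraps the subquery in a Project whose first
  // column is cast(a2 as bigint), aliased back to a2.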
- if (commonTypes.length == lhs.numValues) { + if (commonTypes.length == lhs.length) { val castedRhs = rhs.zip(commonTypes).map { case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() case (e, _) => e } - val newLhsChildren = lhs.children.zip(commonTypes).map { + val newLhs = lhs.zip(commonTypes).map { case (e, dt) if e.dataType != dt => Cast(e, dt) case (e, _) => e } val newSub = Project(castedRhs, sub) - In(InValues(newLhsChildren), Seq(ListQuery(newSub, children, exprId, newSub.output))) + In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output))) } else { i } - case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(a.dataType +: b.map(_.dataType)) match { - case Some(finalDataType: StructType) if i.inValues.numValues > 1 => - val newValues = a.children.zip(finalDataType.fields.map(_.dataType)).map { + case i @ In(a, b) if b.exists(_.dataType != i.value.dataType) => + findWiderCommonType(i.value.dataType +: b.map(_.dataType)) match { + case Some(finalDataType: StructType) if a.length > 1 => + val newValues = a.zip(finalDataType.fields.map(_.dataType)).map { case (expr, dataType) => Cast(expr, dataType) } - In(InValues(newValues), b.map(Cast(_, finalDataType))) + In(newValues, b.map(Cast(_, finalDataType))) case Some(finalDataType) => - In(InValues(a.children.map(Cast(_, finalDataType))), b.map(Cast(_, finalDataType))) + In(a.map(Cast(_, finalDataType)), b.map(Cast(_, finalDataType))) case None => i } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 89e8c998f740d..6a65d18c5cfd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -88,7 +88,10 @@ package object dsl { def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) def =!= (other: Expression): Predicate = Not(EqualTo(expr, other)) - def in(list: Expression*): Expression = In(expr, list) + def in(list: Expression*): Expression = expr match { + case c: CreateNamedStruct => In(c.valExprs, list) + case other => In(Seq(other), list) + } def like(other: Expression): Expression = Like(expr, other) def rlike(other: Expression): Expression = RLike(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e2605f89329ec..a1b0292755564 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -138,30 +138,6 @@ case class Not(child: Expression) override def sql: String = s"(NOT ${child.sql})" } -case class InValues(children: Seq[Expression]) extends Expression { - require(children.nonEmpty, "Value of IN expression cannot be empty") - - @transient lazy val numValues: Int = children.length - @transient lazy val valueExpression: Expression = if (children.length > 1) { - CreateNamedStruct(children.zipWithIndex.flatMap { - case (v: NamedExpression, _) => Seq(Literal(v.name), v) - case (v, idx) => Seq(Literal(s"_$idx"), v) - }) - } else { - children.head - } - override def nullable: Boolean = children.exists(_.nullable) - override def dataType: DataType = valueExpression.dataType - override def sql: String = valueExpression.sql - override def toString: String = 
valueExpression.toString - - override def eval(input: InternalRow): Any = - throw new RuntimeException("InValues cannot be evaluated.") - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode) = - throw new RuntimeException("InValues cannot generate code.") -} - /** * Evaluates to `true` if `list` contains `value`. @@ -185,13 +161,18 @@ case class InValues(children: Seq[Expression]) extends Expression { true """) // scalastyle:on line.size.limit -case class In(value: Expression, list: Seq[Expression]) extends Predicate { +case class In(values: Seq[Expression], list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") - // During analysis we replace any Expression set as value with a InValues expression so we are - // sure it is an instance of InValues - @transient lazy val inValues = value.asInstanceOf[InValues] + @transient lazy val value: Expression = if (values.length > 1) { + CreateNamedStruct(values.zipWithIndex.flatMap { + case (v: NamedExpression, _) => Seq(Literal(v.name), v) + case (v, idx) => Seq(Literal(s"_$idx"), v) + }) + } else { + values.head + } override def checkInputDataTypes(): TypeCheckResult = { val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType, @@ -199,24 +180,19 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { if (mismatchOpt.isDefined) { list match { case ListQuery(_, _, _, childOutputs) :: Nil => - val valExprs = value match { - case cns: CreateNamedStruct => cns.valExprs - case inValues: InValues => inValues.children - case expr => Seq(expr) - } - if (valExprs.length != childOutputs.length) { + if (values.length != childOutputs.length) { TypeCheckResult.TypeCheckFailure( s""" |The number of columns in the left hand side of an IN subquery does not match the |number of columns in the output of subquery. - |#columns in left hand side: ${valExprs.length}. + |#columns in left hand side: ${values.length}. |#columns in right hand side: ${childOutputs.length}. |Left side columns: - |[${valExprs.map(_.sql).mkString(", ")}]. + |[${values.map(_.sql).mkString(", ")}]. |Right side columns: |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) } else { - val mismatchedColumns = valExprs.zip(childOutputs).flatMap { + val mismatchedColumns = values.zip(childOutputs).flatMap { case (l, r) if l.dataType != r.dataType => Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") case _ => None @@ -228,7 +204,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { |Mismatched columns: |[${mismatchedColumns.mkString(", ")}] |Left side: - |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. + |[${values.map(_.dataType.catalogString).mkString(", ")}]. 
|Right side: |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) } @@ -241,20 +217,17 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } - override def children: Seq[Expression] = value +: list + override def children: Seq[Expression] = values ++: list lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal]) private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType) override def nullable: Boolean = children.exists(_.nullable) - override def foldable: Boolean = value match { - case i: InValues => i.valueExpression.foldable && list.forall(_.foldable) - case _ => children.forall(_.foldable) - } + override def foldable: Boolean = children.forall(_.foldable) override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { - val evaluatedValue = inValues.valueExpression.eval(input) + val evaluatedValue = value.eval(input) if (evaluatedValue == null) { null } else { @@ -277,7 +250,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaDataType = CodeGenerator.javaType(value.dataType) - val valueGen = inValues.valueExpression.genCode(ctx) + val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) // inTmpResult has 3 possible values: // -1 means no matches found and there is at least one value in the list evaluated to null @@ -339,9 +312,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def sql: String = { - val childrenSQL = children.map(_.sql) - val valueSQL = childrenSQL.head - val listSQL = childrenSQL.tail.mkString(", ") + val valueSQL = value.sql + val listSQL = list.map(_.sql).mkString(", ") s"($valueSQL IN ($listSQL))" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index e0717e8ef8933..845078e30f6ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -212,7 +212,7 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { * 1. Converts the predicate to false when the list is empty and * the value is not nullable. * 2. Removes literal repetitions. - * 3. Replaces [[In (value, seq[Literal])]] with optimized version + * 3. Replaces [[In (values, seq[Literal])]] with optimized version * [[InSet (value, HashSet[Literal])]] which is much faster. */ object OptimizeIn extends Rule[LogicalPlan] { @@ -221,18 +221,18 @@ object OptimizeIn extends Rule[LogicalPlan] { case expr @ In(_, list) if list.isEmpty => // When v is not nullable, the following expression will be optimized // to FalseLiteral which is tested in OptimizeInSuite.scala - If(IsNotNull(expr.inValues.valueExpression), FalseLiteral, Literal(null, BooleanType)) + If(IsNotNull(expr.value), FalseLiteral, Literal(null, BooleanType)) case expr @ In(_, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.length == 1 // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed, // TODO: we exclude them in this rule. 
- && !expr.inValues.valueExpression.isInstanceOf[CreateNamedStructLike] + && !expr.value.isInstanceOf[CreateNamedStructLike] && !newList.head.isInstanceOf[CreateNamedStructLike]) { - EqualTo(expr.inValues.valueExpression, newList.head) + EqualTo(expr.value, newList.head) } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) - InSet(expr.inValues.valueExpression, HashSet() ++ hSet) + InSet(expr.value, HashSet() ++ hSet) } else if (newList.length < list.length) { expr.copy(list = newList) } else { // newList.length == list.length && newList.length > 1 @@ -504,7 +504,7 @@ object NullPropagation extends Rule[LogicalPlan] { } // If the value expression is NULL then transform the In expression to null literal. - case In(InValues(Seq(Literal(null, _))), _) => Literal.create(null, BooleanType) + case In(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType) // Non-leaf NullIntolerant expressions will return null, if at least one of its children is // a null literal. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 3ee01705fcdf3..d2954d20ceda2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -91,19 +91,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond)) - case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) => - val inConditions = value.children.zip(sub.output).map(EqualTo.tupled) + case (p, In(values, Seq(ListQuery(sub, conditions, _, _)))) => + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond)) - case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) => + case (p, Not(In(values, Seq(ListQuery(sub, conditions, _, _))))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. - val inConditions = value.children.zip(sub.output).map(EqualTo.tupled) + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. 
That is from: @@ -144,9 +144,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { newPlan = dedupJoin( Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))) exists - case In(value, Seq(ListQuery(sub, conditions, _, _))) => + case In(values, Seq(ListQuery(sub, conditions, _, _))) => val exists = AttributeReference("exists", BooleanType, nullable = false)() - val inConditions = value.children.zip(sub.output).map(EqualTo.tupled) + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val newConditions = (inConditions ++ conditions).reduceLeftOption(And) // Deduplicate conflicting attributes if any. newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 49f578a24aaeb..05c27b2350c3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1104,6 +1104,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case not => Not(e) } + def getValueExpressions(e: Expression): Seq[Expression] = e match { + case c: CreateNamedStruct => c.valExprs + case other => Seq(other) + } + // Create the predicate. ctx.kind.getType match { case SqlBaseParser.BETWEEN => @@ -1112,9 +1117,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging GreaterThanOrEqual(e, expression(ctx.lower)), LessThanOrEqual(e, expression(ctx.upper)))) case SqlBaseParser.IN if ctx.query != null => - invertIfNotDefined(In(e, Seq(ListQuery(plan(ctx.query))))) + invertIfNotDefined(In(getValueExpressions(e), Seq(ListQuery(plan(ctx.query))))) case SqlBaseParser.IN => - invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) + invertIfNotDefined(In(getValueExpressions(e), ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => invertIfNotDefined(Like(e, expression(ctx.pattern))) case SqlBaseParser.RLIKE => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 4a1bc32ae7798..a14123d0a6c12 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -164,7 +164,7 @@ case class FilterEstimation(plan: Filter) extends Logging { case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) => evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) - case In(InValues(Seq(ar: Attribute)), expList) + case In(Seq(ar: Attribute), expList) if expList.forall(e => e.isInstanceOf[Literal]) => // Expression [In (value, seq[Literal])] will be replaced with optimized version // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. 
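With the revert, every producer agrees on the bare Seq[Expression] left-hand side and the struct packing moves inside In itself. A sketch of the resulting shapes (subplan stands in for an analyzed subquery plan and is hypothetical here):

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.{In, ListQuery, Literal}

    // Single column: a one-element Seq on the LHS.
    val single = In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2)))

    // Multi column: the parser and DSL flatten a CreateNamedStruct into its
    // value expressions, so (a, b) IN (subquery) becomes:
    val multi = In(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")),
      Seq(ListQuery(subplan)))
    // In.value then lazily packs the values back into
    // CreateNamedStruct(Seq(Literal("a"), a, Literal("b"), b)) when a single
    // struct-valued expression is needed for type checking, eval and codegen.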
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index f4cfed4a91594..2adaecc625595 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -528,7 +528,7 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val plan = Project( - Seq(a, Alias(In(a, Seq(ListQuery(LocalRelation(b)))), "c")()), + Seq(a, Alias(In(Seq(a), Seq(ListQuery(LocalRelation(b)))), "c")()), LocalRelation(a)) assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil) } @@ -537,12 +537,13 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", BooleanType)() - val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType), + val plan1 = Filter(Cast(Not(In(Seq(a), Seq(ListQuery(LocalRelation(b))))), BooleanType), LocalRelation(a)) assertAnalysisError(plan1, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) - val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) + val plan2 = Filter( + Or(Not(In(Seq(a), Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) assertAnalysisError(plan2, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index d23ba000b6e8a..72b428acc277f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -276,20 +276,21 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-8654: invalid CAST in NULL IN(...) 
expression") { - val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil, + val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(2))), "a")() :: Nil, LocalRelation() ) assertAnalysisSuccess(plan) } test("SPARK-8654: different types in inlist but can be converted to a common type") { - val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, + val plan = Project( + Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, LocalRelation()) assertAnalysisSuccess(plan) } test("SPARK-8654: check type compatibility error") { - val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil, + val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(true), Literal(1))), "a")() :: Nil, LocalRelation() ) assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 71de693909c98..03129b9e86234 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{In, InValues, ListQuery, OuterReference} +import org.apache.spark.sql.catalyst.expressions.{In, ListQuery, OuterReference} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project} /** @@ -33,8 +33,7 @@ class ResolveSubquerySuite extends AnalysisTest { val t2 = LocalRelation(b) test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { - val expr = Filter( - In(InValues(Seq(a)), Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) + val expr = Filter(In(Seq(a), Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr)) }.getMessage diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index bac750f2b5f6d..957a95aaf9082 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1361,16 +1361,16 @@ class TypeCoercionSuite extends AnalysisTest { // InConversion val inConversion = TypeCoercion.InConversion(conf) ruleTest(inConversion, - In(InValues(Seq(UnresolvedAttribute("a"))), Seq(Literal(1))), - In(InValues(Seq(UnresolvedAttribute("a"))), Seq(Literal(1))) + In(Seq(UnresolvedAttribute("a")), Seq(Literal(1))), + In(Seq(UnresolvedAttribute("a")), Seq(Literal(1))) ) ruleTest(inConversion, - In(InValues(Seq(Literal("test"))), Seq(UnresolvedAttribute("a"), Literal(1))), - In(InValues(Seq(Literal("test"))), Seq(UnresolvedAttribute("a"), Literal(1))) + In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1))), + In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1))) ) ruleTest(inConversion, - In(InValues(Seq(Literal("a"))), Seq(Literal(1), Literal("b"))), - 
In(InValues(Seq(Cast(Literal("a"), StringType))), + In(Seq(Literal("a")), Seq(Literal(1), Literal("b"))), + In(Seq(Cast(Literal("a"), StringType)), Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType))) ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index a0f892bb4eb14..51464332dbc80 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -477,8 +477,8 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac checkAnswer(tbl2, Seq.empty, Set(part1, part2)) checkAnswer(tbl2, Seq('a.int <= 1), Set(part1)) checkAnswer(tbl2, Seq('a.int === 2), Set.empty) - checkAnswer(tbl2, Seq(In(InValues(Seq('a.int * 10)), Seq(30))), Set(part2)) - checkAnswer(tbl2, Seq(Not(In(InValues(Seq('a.int)), Seq(4)))), Set(part1, part2)) + checkAnswer(tbl2, Seq(In(Seq('a.int * 10), Seq(30))), Set(part2)) + checkAnswer(tbl2, Seq(Not(In(Seq('a.int), Seq(4)))), Set(part1, part2)) checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1)) checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1)) checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 28e6940f3cca3..b78d23e3472e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -27,9 +27,9 @@ class CanonicalizeSuite extends SparkFunSuite { val range = Range(1, 1, 1, 1) val idAttr = range.output.head - val in1 = In(idAttr, Seq(Literal(1), Literal(2))) - val in2 = In(idAttr, Seq(Literal(2), Literal(1))) - val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3))) + val in1 = In(Seq(idAttr), Seq(Literal(1), Literal(2))) + val in2 = In(Seq(idAttr), Seq(Literal(2), Literal(1))) + val in3 = In(Seq(idAttr), Seq(Literal(1), Literal(2), Literal(3))) assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash()) assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash()) @@ -37,11 +37,11 @@ class CanonicalizeSuite extends SparkFunSuite { assert(range.where(in1).sameResult(range.where(in2))) assert(!range.where(in1).sameResult(range.where(in3))) - val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + val arrays1 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))), CreateArray(Seq(Literal(2), Literal(1))))) - val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))), + val arrays2 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(2), Literal(1))), CreateArray(Seq(Literal(1), Literal(2))))) - val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + val arrays3 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))), CreateArray(Seq(Literal(3), Literal(1))))) assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 
f6a9bb6ba655b..6ae4e7ccac129 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -125,32 +125,32 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, null, null) :: Nil) test("basic IN predicate test") { - checkEvaluation(In(InValues(Seq(NonFoldableLiteral.create(null, IntegerType))), Seq(Literal(1), + checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq(Literal(1), Literal(2))), null) - checkEvaluation(In(InValues(Seq(NonFoldableLiteral.create(null, IntegerType))), + checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(InValues(Seq(NonFoldableLiteral.create(null, IntegerType))), Seq.empty), + checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq.empty), null) - checkEvaluation(In(InValues(Seq(Literal(1))), Seq.empty), false) + checkEvaluation(In(Seq(Literal(1)), Seq.empty), false) checkEvaluation( - In(InValues(Seq(Literal(1))), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(InValues(Seq(Literal(1))), + In(Seq(Literal(1)), Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkEvaluation(In(Seq(Literal(1)), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), true) - checkEvaluation(In(InValues(Seq(Literal(2))), + checkEvaluation(In(Seq(Literal(2)), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(InValues(Seq(Literal(1))), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(InValues(Seq(Literal(2))), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(InValues(Seq(Literal(3))), Seq(Literal(1), Literal(2))), false) - checkEvaluation(And(In(InValues(Seq(Literal(1))), Seq(Literal(1), Literal(2))), - In(InValues(Seq(Literal(2))), Seq(Literal(1), Literal(2)))), true) + checkEvaluation(In(Seq(Literal(1)), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Seq(Literal(2)), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Seq(Literal(3)), Seq(Literal(1), Literal(2))), false) + checkEvaluation(And(In(Seq(Literal(1)), Seq(Literal(1), Literal(2))), + In(Seq(Literal(2)), Seq(Literal(1), Literal(2)))), true) val ns = NonFoldableLiteral.create(null, StringType) - checkEvaluation(In(InValues(Seq(ns)), Seq(Literal("1"), Literal("2"))), null) - checkEvaluation(In(InValues(Seq(ns)), Seq(ns)), null) - checkEvaluation(In(InValues(Seq(Literal("a"))), Seq(ns)), null) - checkEvaluation(In(InValues(Seq(Literal("^Ba*n"))), Seq(Literal("^Ba*n"), ns)), true) - checkEvaluation(In(InValues(Seq(Literal("^Ba*n"))), Seq(Literal("aa"), Literal("^Ba*n"))), true) - checkEvaluation(In(InValues(Seq(Literal("^Ba*n"))), Seq(Literal("aa"), Literal("^n"))), false) + checkEvaluation(In(Seq(ns), Seq(Literal("1"), Literal("2"))), null) + checkEvaluation(In(Seq(ns), Seq(ns)), null) + checkEvaluation(In(Seq(Literal("a")), Seq(ns)), null) + checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("^Ba*n"), ns)), true) + checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("aa"), Literal("^Ba*n"))), true) + checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("aa"), Literal("^n"))), false) } @@ -187,7 +187,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } else { false } - checkEvaluation(In(InValues(Seq(input.head)), input.slice(1, 10)), expected) + 
checkEvaluation(In(Seq(input.head), input.slice(1, 10)), expected) } val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t => @@ -243,12 +243,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22501: In should not generate codes beyond 64KB") { val N = 3000 val sets = (1 to N).map(i => Literal(i.toDouble)) - checkEvaluation(In(InValues(Seq(Literal(1.0D))), sets), true) + checkEvaluation(In(Seq(Literal(1.0D)), sets), true) } test("SPARK-22705: In should use less global variables") { val ctx = new CodegenContext() - In(InValues(Seq(Literal(1.0D))), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) + In(Seq(Literal(1.0D)), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 641c89873dcc4..09d8f2e4d1ad4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -251,7 +251,7 @@ class ConstantFoldingSuite extends PlanTest { val originalQuery = testRelation .select('a) - .where(In(Literal(1), Seq(Literal(1), Literal(2)))) + .where(In(Seq(Literal(1)), Seq(Literal(1), Literal(2)))) val optimized = Optimize.execute(originalQuery.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 86522a6a54ed5..eacf9746e8607 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -45,9 +45,9 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: Remove deterministic repetitions") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(1), Literal(2), Literal(2), Literal(1), Literal(2)))) - .where(In(UnresolvedAttribute("b"), + .where(In(Seq(UnresolvedAttribute("b")), Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"), Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0), Rand(0), Rand(0)))) @@ -56,8 +56,8 @@ class OptimizeInSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) - .where(In(UnresolvedAttribute("b"), + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2)))) + .where(In(Seq(UnresolvedAttribute("b")), Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"), Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0), Rand(0), Rand(0)))) @@ -69,7 +69,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -79,7 +79,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause optimized to InSet when more than 10 items") { val originalQuery = testRelation - 
.where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_)))) + .where(In(Seq(UnresolvedAttribute("a")), (1 to 11).map(Literal(_)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -94,13 +94,15 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized in case filter has attributes") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) + .where( + In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) + .where( + In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze comparePlans(optimized, correctAnswer) @@ -109,7 +111,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: NULL IN (expr1, ..., exprN) gets transformed to Filter(null)") { val originalQuery = testRelation - .where(In(Literal.create(null, NullType), Seq(Literal(1), Literal(2)))) + .where(In(Seq(Literal.create(null, NullType)), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -125,7 +127,7 @@ class OptimizeInSuite extends PlanTest { "list expression contains attribute)") { val originalQuery = testRelation - .where(In(Literal.create(null, StringType), Seq(Literal(1), UnresolvedAttribute("b")))) + .where(In(Seq(Literal.create(null, StringType)), Seq(Literal(1), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -141,7 +143,7 @@ class OptimizeInSuite extends PlanTest { "list expression contains attribute - select)") { val originalQuery = testRelation - .select(In(Literal.create(null, StringType), + .select(In(Seq(Literal.create(null, StringType)), Seq(Literal(1), UnresolvedAttribute("b"))).as("a")).analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -156,7 +158,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: Setting the threshold for turning Set into InSet.") { val plan = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), Literal(3)))) + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), Literal(3)))) .analyze withSQLConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "10") { @@ -179,7 +181,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: one element in list gets transformed to EqualTo.") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1)))) .analyze val optimized = Optimize.execute(originalQuery) @@ -195,7 +197,7 @@ class OptimizeInSuite extends PlanTest { "when value is not nullable") { val originalQuery = testRelation - .where(In(Literal("a"), Nil)) + .where(In(Seq(Literal("a")), Nil)) .analyze val optimized = Optimize.execute(originalQuery) @@ -211,7 +213,7 @@ class OptimizeInSuite extends PlanTest { "when value is nullable") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Nil)) + .where(In(Seq(UnresolvedAttribute("a")), Nil)) .analyze val optimized = Optimize.execute(originalQuery) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index e2b96ece73fd3..02d03fb8b8d57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{In, InValues, ListQuery} +import org.apache.spark.sql.catalyst.expressions.{In, ListQuery} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -42,7 +42,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { .select('c) val outerQuery = testRelation - .where(In('a, Seq(ListQuery(correlatedSubquery)))) + .where(In(Seq('a), Seq(ListQuery(correlatedSubquery)))) .select('a).analyze assert(outerQuery.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index b4d422d8506fc..af2d0c2eabb7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -154,7 +154,7 @@ class ExpressionParserSuite extends PlanTest { test("in sub-query") { assertEqual( "a in (select b from c)", - In('a, Seq(ListQuery(table("c").select('b))))) + In(Seq('a), Seq(ListQuery(table("c").select('b))))) } test("like expressions") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 6bcee2acb641b..b84334eaf49bb 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -440,8 +440,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04")) val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05")) validateEstimatedStats( - Filter(In(InValues(Seq(attrDate)), - Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), + Filter(In(Seq(attrDate), Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = Some(3), min = Some(d20170103), max = Some(d20170105), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 4eee3de5f7d4e..3893eae1b8a71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -793,7 +793,12 @@ class Column(val expr: Expression) extends Logging { * @since 1.5.0 */ @scala.annotation.varargs - def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } + def isin(list: Any*): Column = withExpr { + expr match { + case c: CreateNamedStruct => In(c.valExprs, 
list.map(lit(_).expr)) + case other => In(Seq(other), list.map(lit(_).expr)) + } + } /** * A boolean expression that is evaluated to true if the value of this expression is contained diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 1f8922d32e02b..e0e817d98ef9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -219,7 +219,7 @@ case class InMemoryTableScanExec( case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 - case In(InValues(Seq(a: AttributeReference)), list: Seq[Expression]) + case In(Seq(a: AttributeReference), list: Seq[Expression]) if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 93b1309bc119a..92dc1714618eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -475,7 +475,7 @@ object DataSourceStrategy { // Because we only convert In to InSet in Optimizer when there are more than certain // items. So it is possible we still get an In expression here that needs to be pushed // down. 
-      case expressions.In(InValues(Seq(a: Attribute)), list)
+      case expressions.In(Seq(a: Attribute), list)
         if list.forall(_.isInstanceOf[Literal]) =>
         val hSet = list.map(e => e.eval(EmptyRow))
         val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 0b883bcea1666..68e7bb07b794a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -86,7 +86,7 @@ object FileSourceStrategy extends Strategy with Logging {
       expr match {
         case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
           getBucketSetFromValue(a, v)
-        case expressions.In(InValues(Seq(a: Attribute)), list)
+        case expressions.In(Seq(a: Attribute), list)
           if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
           getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow)))
         case expressions.InSet(a: Attribute, hset)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 89333178dd4f3..0610f51b721a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -21,7 +21,7 @@ import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In, InValues}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec}
@@ -459,7 +459,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None,
       LocalRelation(Seq(attribute), Nil))
     val tableScanExec = InMemoryTableScanExec(Seq(attribute),
-      Seq(In(InValues(Seq(attribute)), Nil)), testRelation)
+      Seq(In(Seq(attribute), Nil)), testRelation)
     assert(tableScanExec.partitionFilters.isEmpty)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
index f40a7c5c9bae0..50f0b54790b4c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
@@ -51,7 +51,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext {
     testTranslateFilter(InSet(attrInt, Set(1, 2, 3)),
       Some(sources.In("cint", Array(1, 2, 3))))
 
-    testTranslateFilter(In(InValues(Seq(attrInt)), Seq(1, 2, 3)),
+    testTranslateFilter(In(Seq(attrInt), Seq(1, 2, 3)),
       Some(sources.In("cint", Array(1, 2, 3))))
 
     testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint")))
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 8bcfabad78cb9..9df7a9d86c411 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -685,7 +685,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     }
 
     def convert(expr: Expression): Option[String] = expr match {
-      case In(InValues(Seq(ExtractAttribute(NonVarcharAttribute(name)))),
+      case In(Seq(ExtractAttribute(NonVarcharAttribute(name))),
         ExtractableLiterals(values)) if useAdvanced =>
         Some(convertInToOr(name, values))
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
index d454c2d15ae95..2a4efd0cce6e0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
@@ -73,13 +73,13 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
     """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""")
 
   filterTest("SPARK-24879 null literals should be ignored for IN constructs",
-    (InValues(Seq(a("intcol", IntegerType))) in (Literal(1), Literal(null))) :: Nil,
+    (a("intcol", IntegerType) in (Literal(1), Literal(null))) :: Nil,
     "(intcol = 1)")
 
   // Applying the predicate `x IN (NULL)` should return an empty set, but since this optimization
   // will be applied by Catalyst, this filter converter does not need to account for this.
   filterTest("SPARK-24879 IN predicates with only NULLs will not cause a NPE",
-    (InValues(Seq(a("intcol", IntegerType))) in Literal(null)) :: Nil,
+    (a("intcol", IntegerType) in Literal(null)) :: Nil,
     "")
 
   filterTest("typecast null literals should not be pushed down in simple predicates",
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
index 91b4578a75cb4..b66beb88a329c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
@@ -74,12 +74,6 @@ class HiveClientSuite(version: String)
     }
   }
 
-  private def analyzeIn(expr: Expression): Expression = expr match {
-    case In(CreateNamedStruct(children), list) => In(InValues(children), list)
-    case In(v, list) => In(InValues(Seq(v)), list)
-    case other => other
-  }
-
   override def beforeAll() {
     super.beforeAll()
     client = init(true)
@@ -178,7 +172,7 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") {
     testMetastorePartitionFiltering(
-      analyzeIn(attr("ds").in(20170102, 20170103)),
+      attr("ds").in(20170102, 20170103),
       20170102 to 20170103,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -186,7 +180,7 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using IN expression)") {
     testMetastorePartitionFiltering(
-      analyzeIn(attr("ds").cast(LongType).in(20170102L, 20170103L)),
+      attr("ds").cast(LongType).in(20170102L, 20170103L),
       20170102 to 20170103,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -194,30 +188,30 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") {
    testMetastorePartitionFiltering(
-      analyzeIn(attr("ds").in(20170102, 20170103)),
+      attr("ds").in(20170102, 20170103),
       20170102 to 20170103,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil, {
         case expr @ In(_, list) if expr.inSetConvertible =>
-          InSet(expr.inValues.valueExpression, list.map(_.eval(EmptyRow)).toSet)
+          InSet(expr.value, list.map(_.eval(EmptyRow)).toSet)
       })
   }
 
   test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using INSET expression)") {
     testMetastorePartitionFiltering(
-      analyzeIn(attr("ds").cast(LongType).in(20170102L, 20170103L)),
+      attr("ds").cast(LongType).in(20170102L, 20170103L),
       20170102 to 20170103,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil, {
         case expr @ In(_, list) if expr.inSetConvertible =>
-          InSet(expr.inValues.valueExpression, list.map(_.eval(EmptyRow)).toSet)
+          InSet(expr.value, list.map(_.eval(EmptyRow)).toSet)
       })
   }
 
   test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") {
     testMetastorePartitionFiltering(
-      analyzeIn(attr("chunk").in("ab", "ba")),
+      attr("chunk").in("ab", "ba"),
       20170101 to 20170103,
       0 to 23,
       "ab" :: "ba" :: Nil)
@@ -225,12 +219,12 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") {
     testMetastorePartitionFiltering(
-      analyzeIn(attr("chunk").in("ab", "ba")),
+      attr("chunk").in("ab", "ba"),
       20170101 to 20170103,
       0 to 23,
       "ab" :: "ba" :: Nil, {
         case expr @ In(_, list) if expr.inSetConvertible =>
-          InSet(expr.inValues.valueExpression, list.map(_.eval(EmptyRow)).toSet)
+          InSet(expr.value, list.map(_.eval(EmptyRow)).toSet)
       })
   }
 
@@ -253,7 +247,7 @@ class HiveClientSuite(version: String)
       "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") {
     val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba"))
     val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba"))
-    testMetastorePartitionFiltering(analyzeIn(attr("chunk").in("ab", "ba")) &&
+    testMetastorePartitionFiltering(attr("chunk").in("ab", "ba") &&
       ((attr("ds") === 20170101 && attr("h") >= 8) || (attr("ds") === 20170102 && attr("h") < 8)),
       day1 :: day2 :: Nil)
   }

From 3af5b78bdd403aa87bb25819296c537c0cbd5260 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Thu, 2 Aug 2018 14:35:44 +0200
Subject: [PATCH 11/17] Introduce InSubquery

---
 .../sql/catalyst/analysis/Analyzer.scala      |  12 +-
 .../sql/catalyst/analysis/TypeCoercion.scala  |  16 +--
 .../spark/sql/catalyst/dsl/package.scala      |   9 +-
 .../sql/catalyst/expressions/predicates.scala | 109 ++++++++++--------
 .../sql/catalyst/optimizer/expressions.scala  |  17 +--
 .../sql/catalyst/optimizer/subquery.scala     |   6 +-
 .../sql/catalyst/parser/AstBuilder.scala      |   4 +-
 .../statsEstimation/FilterEstimation.scala    |   2 +-
 .../analysis/AnalysisErrorSuite.scala         |   6 +-
 .../sql/catalyst/analysis/AnalysisSuite.scala |  10 +-
 .../analysis/ResolveSubquerySuite.scala       |   5 +-
 .../catalyst/analysis/TypeCoercionSuite.scala |  12 +-
 .../catalog/ExternalCatalogSuite.scala        |   4 +-
 .../expressions/CanonicalizeSuite.scala       |  12 +-
 .../catalyst/expressions/PredicateSuite.scala |  42 +++----
 .../optimizer/ConstantFoldingSuite.scala      |   2 +-
 .../catalyst/optimizer/OptimizeInSuite.scala  |  30 ++---
 .../PullupCorrelatedPredicatesSuite.scala     |   4 +-
 .../parser/ExpressionParserSuite.scala        |   2 +-
 .../FilterEstimationSuite.scala               |   2 +-
 .../scala/org/apache/spark/sql/Column.scala   |   7 +-
 .../columnar/InMemoryTableScanExec.scala      |   2 +-
 .../datasources/DataSourceStrategy.scala      |   3 +-
 .../datasources/FileSourceStrategy.scala      |   2 +-
 .../subquery/in-subquery/in-basic.sql.out     |   8 +-
 .../subq-input-typecheck.sql.out              |  16 +-
 .../columnar/InMemoryColumnarQuerySuite.scala |   2 +-
 .../datasources/DataSourceStrategySuite.scala |   3 +-
 .../spark/sql/hive/client/HiveShim.scala      |   2 +-
 .../sql/hive/client/HiveClientSuite.scala     |  12 +-
 30 files changed, 187 insertions(+), 176 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3a3d00fd4ddf3..fc442c7fb5383 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1422,24 +1422,24 @@ class Analyzer(
           resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
         case e @ Exists(sub, _, exprId) if !sub.resolved =>
           resolveSubQuery(e, plans)(Exists(_, _, exprId))
-        case In(values, Seq(l @ ListQuery(_, _, exprId, _)))
+        case InSubquery(values, l @ ListQuery(_, _, exprId, _))
            if values.forall(_.resolved) && !l.resolved =>
          val expr = resolveSubQuery(l, plans)((plan, exprs) => {
            ListQuery(plan, exprs, exprId, plan.output)
          })
          val subqueryOutput = expr.plan.output
-          val resolvedIn = In(values, Seq(expr))
+          val resolvedIn = InSubquery(values, expr.asInstanceOf[ListQuery])
          if (values.length != subqueryOutput.length) {
            throw new AnalysisException(
              s"""Cannot analyze ${resolvedIn.sql}.
                 |The number of columns in the left hand side of an IN subquery does not match the
                 |number of columns in the output of subquery.
-                |#columns in left hand side: ${values.length}.
-                |#columns in right hand side: ${subqueryOutput.length}.
+                |#columns in left hand side: ${values.length}
+                |#columns in right hand side: ${subqueryOutput.length}
                 |Left side columns:
-                |[${values.map(_.sql).mkString(", ")}].
+                |[${values.map(_.sql).mkString(", ")}]
                 |Right side columns:
-                |[${subqueryOutput.map(_.sql).mkString(", ")}].""".stripMargin)
+                |[${subqueryOutput.map(_.sql).mkString(", ")}]""".stripMargin)
          }
          resolvedIn
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index da809b6dd1bcc..648aa9ee8fa0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -456,7 +456,7 @@ object TypeCoercion {
 
       // Handle type casting required between value expression and subquery output
       // in IN subquery.
-      case i @ In(lhs, Seq(ListQuery(sub, children, exprId, _)))
+      case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _))
         if !i.resolved && lhs.length == sub.output.length =>
        // LHS is the value expressions of IN subquery.
        // RHS is the subquery output.
@@ -480,20 +480,14 @@ object TypeCoercion {
          }
 
          val newSub = Project(castedRhs, sub)
-          In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output)))
+          InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output))
        } else {
          i
        }
 
-      case i @ In(a, b) if b.exists(_.dataType != i.value.dataType) =>
-        findWiderCommonType(i.value.dataType +: b.map(_.dataType)) match {
-          case Some(finalDataType: StructType) if a.length > 1 =>
-            val newValues = a.zip(finalDataType.fields.map(_.dataType)).map {
-              case (expr, dataType) => Cast(expr, dataType)
-            }
-            In(newValues, b.map(Cast(_, finalDataType)))
-          case Some(finalDataType) =>
-            In(a.map(Cast(_, finalDataType)), b.map(Cast(_, finalDataType)))
+      case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
+        findWiderCommonType(i.children.map(_.dataType)) match {
+          case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
          case None => i
        }
    }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 6a65d18c5cfd8..eb7907e5abe35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -88,9 +88,12 @@ package object dsl {
     def <=> (other: Expression): Predicate = EqualNullSafe(expr, other)
     def =!= (other: Expression): Predicate = Not(EqualTo(expr, other))
 
-    def in(list: Expression*): Expression = expr match {
-      case c: CreateNamedStruct => In(c.valExprs, list)
-      case other => In(Seq(other), list)
+    def in(list: Expression*): Expression = list match {
+      case Seq(l: ListQuery) => expr match {
+        case c: CreateNamedStruct => InSubquery(c.valExprs, l)
+        case other => InSubquery(Seq(other), l)
+      }
+      case _ => In(expr, list)
     }
 
     def like(other: Expression): Expression = Like(expr, other)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index a1b0292755564..149bd79278a54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -138,6 +138,66 @@ case class Not(child: Expression)
   override def sql: String = s"(NOT ${child.sql})"
 }
 
+/**
+ * Evaluates to `true` if `values` are returned in `query`'s result set.
+ */
+case class InSubquery(values: Seq[Expression], query: ListQuery)
+  extends Predicate with Unevaluable {
+
+  @transient lazy val value: Expression = if (values.length > 1) {
+    CreateNamedStruct(values.zipWithIndex.flatMap {
+      case (v: NamedExpression, _) => Seq(Literal(v.name), v)
+      case (v, idx) => Seq(Literal(s"_$idx"), v)
+    })
+  } else {
+    values.head
+  }
+
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val mismatchOpt = !DataType.equalsStructurally(query.dataType, value.dataType,
+      ignoreNullability = true)
+    if (mismatchOpt) {
+      if (values.length != query.childOutputs.length) {
+        TypeCheckResult.TypeCheckFailure(
+          s"""
+             |The number of columns in the left hand side of an IN subquery does not match the
+             |number of columns in the output of subquery.
+             |#columns in left hand side: ${values.length}.
+             |#columns in right hand side: ${query.childOutputs.length}.
+             |Left side columns:
+             |[${values.map(_.sql).mkString(", ")}].
+             |Right side columns:
+             |[${query.childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
+      } else {
+        val mismatchedColumns = values.zip(query.childOutputs).flatMap {
+          case (l, r) if l.dataType != r.dataType =>
+            Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})")
+          case _ => None
+        }
+        TypeCheckResult.TypeCheckFailure(
+          s"""
+             |The data type of one or more elements in the left hand side of an IN subquery
+             |is not compatible with the data type of the output of the subquery
+             |Mismatched columns:
+             |[${mismatchedColumns.mkString(", ")}]
+             |Left side:
+             |[${values.map(_.dataType.catalogString).mkString(", ")}].
+             |Right side:
+             |[${query.childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin)
+      }
+    } else {
+      TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
+    }
+  }
+
+  override def children: Seq[Expression] = values :+ query
+  override def nullable: Boolean = children.exists(_.nullable)
+  override def foldable: Boolean = children.forall(_.foldable)
+  override def toString: String = s"$value IN ($query)"
+  override def sql: String = s"(${value.sql} IN (${query.sql}))"
+}
+
 
 /**
  * Evaluates to `true` if `list` contains `value`.
@@ -161,63 +221,22 @@
       true
   """)
 // scalastyle:on line.size.limit
-case class In(values: Seq[Expression], list: Seq[Expression]) extends Predicate {
+case class In(value: Expression, list: Seq[Expression]) extends Predicate {
 
   require(list != null, "list should not be null")
 
-  @transient lazy val value: Expression = if (values.length > 1) {
-    CreateNamedStruct(values.zipWithIndex.flatMap {
-      case (v: NamedExpression, _) => Seq(Literal(v.name), v)
-      case (v, idx) => Seq(Literal(s"_$idx"), v)
-    })
-  } else {
-    values.head
-  }
-
   override def checkInputDataTypes(): TypeCheckResult = {
     val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType,
       ignoreNullability = true))
     if (mismatchOpt.isDefined) {
-      list match {
-        case ListQuery(_, _, _, childOutputs) :: Nil =>
-          if (values.length != childOutputs.length) {
-            TypeCheckResult.TypeCheckFailure(
-              s"""
-                 |The number of columns in the left hand side of an IN subquery does not match the
-                 |number of columns in the output of subquery.
-                 |#columns in left hand side: ${values.length}.
-                 |#columns in right hand side: ${childOutputs.length}.
-                 |Left side columns:
-                 |[${values.map(_.sql).mkString(", ")}].
-                 |Right side columns:
-                 |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
-          } else {
-            val mismatchedColumns = values.zip(childOutputs).flatMap {
-              case (l, r) if l.dataType != r.dataType =>
-                Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})")
-              case _ => None
-            }
-            TypeCheckResult.TypeCheckFailure(
-              s"""
-                 |The data type of one or more elements in the left hand side of an IN subquery
-                 |is not compatible with the data type of the output of the subquery
-                 |Mismatched columns:
-                 |[${mismatchedColumns.mkString(", ")}]
-                 |Left side:
-                 |[${values.map(_.dataType.catalogString).mkString(", ")}].
-                 |Right side:
-                 |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin)
-          }
-        case _ =>
-          TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
-            s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}")
-      }
+      TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
+        s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}")
     } else {
       TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
     }
   }
 
-  override def children: Seq[Expression] = values ++: list
+  override def children: Seq[Expression] = value +: list
   lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal])
   private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 845078e30f6ca..251cadda02c2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -212,27 +212,27 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
  * 1. Converts the predicate to false when the list is empty and
  *    the value is not nullable.
  * 2. Removes literal repetitions.
- * 3. Replaces [[In (values, seq[Literal])]] with optimized version
+ * 3. Replaces [[In (value, seq[Literal])]] with optimized version
  *    [[InSet (value, HashSet[Literal])]] which is much faster.
  */
 object OptimizeIn extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsDown {
-      case expr @ In(_, list) if list.isEmpty =>
+      case In(v, list) if list.isEmpty =>
         // When v is not nullable, the following expression will be optimized
         // to FalseLiteral which is tested in OptimizeInSuite.scala
-        If(IsNotNull(expr.value), FalseLiteral, Literal(null, BooleanType))
-      case expr @ In(_, list) if expr.inSetConvertible =>
+        If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType))
+      case expr @ In(v, list) if expr.inSetConvertible =>
         val newList = ExpressionSet(list).toSeq
         if (newList.length == 1
           // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
           // TODO: we exclude them in this rule.
-          && !expr.value.isInstanceOf[CreateNamedStructLike]
+          && !v.isInstanceOf[CreateNamedStructLike]
           && !newList.head.isInstanceOf[CreateNamedStructLike]) {
-          EqualTo(expr.value, newList.head)
+          EqualTo(v, newList.head)
         } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
           val hSet = newList.map(e => e.eval(EmptyRow))
-          InSet(expr.value, HashSet() ++ hSet)
+          InSet(v, HashSet() ++ hSet)
         } else if (newList.length < list.length) {
           expr.copy(list = newList)
         } else { // newList.length == list.length && newList.length > 1
@@ -504,7 +504,8 @@ object NullPropagation extends Rule[LogicalPlan] {
       }
 
       // If the value expression is NULL then transform the In expression to null literal.
-      case In(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType)
+      case In(Literal(null, _), _) => Literal.create(null, BooleanType)
+      case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType)
 
      // Non-leaf NullIntolerant expressions will return null, if at least one of its children is
      // a null literal.
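As a quick aside for reviewers, the following is a minimal sketch (not part of the patch) of the behaviour the optimizer rules above rely on after this commit: `In` is once again a plain value-against-literal-list predicate, while IN-subqueries travel through the new `InSubquery` node. It only assumes spark-catalyst on the classpath; the object name `InSketch` is invented for illustration.

import org.apache.spark.sql.catalyst.expressions.{In, Literal}

object InSketch extends App {
  // children is value +: list under the new definition: 1 value plus 11 list items.
  val in = In(Literal(1), (1 to 11).map(i => Literal(i)))
  assert(in.children.length == 12)
  // Every list element is a Literal, so inSetConvertible holds and OptimizeIn
  // may rewrite the predicate to the hash-based InSet once the deduplicated
  // list exceeds the configured inSetConversionThreshold (10 by default).
  assert(in.inSetConvertible)
}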
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index d2954d20ceda2..e9b7a8b76e683 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -91,12 +91,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
         val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
         // Deduplicate conflicting attributes if any.
         dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
-      case (p, In(values, Seq(ListQuery(sub, conditions, _, _)))) =>
+      case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
         val inConditions = values.zip(sub.output).map(EqualTo.tupled)
         val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
         // Deduplicate conflicting attributes if any.
         dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
-      case (p, Not(In(values, Seq(ListQuery(sub, conditions, _, _))))) =>
+      case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
         // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
         // Construct the condition. A NULL in one of the conditions is regarded as a positive
         // result; such a row will be filtered out by the Anti-Join operator.
@@ -144,7 +144,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           newPlan = dedupJoin(
             Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)))
           exists
-        case In(values, Seq(ListQuery(sub, conditions, _, _))) =>
+        case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
           val inConditions = values.zip(sub.output).map(EqualTo.tupled)
           val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index eecb8746c8f8b..db504505c0ffc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1117,9 +1117,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
           GreaterThanOrEqual(e, expression(ctx.lower)),
           LessThanOrEqual(e, expression(ctx.upper))))
       case SqlBaseParser.IN if ctx.query != null =>
-        invertIfNotDefined(In(getValueExpressions(e), Seq(ListQuery(plan(ctx.query)))))
+        invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query))))
       case SqlBaseParser.IN =>
-        invertIfNotDefined(In(getValueExpressions(e), ctx.expression.asScala.map(expression)))
+        invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
       case SqlBaseParser.LIKE =>
         invertIfNotDefined(Like(e, expression(ctx.pattern)))
       case SqlBaseParser.RLIKE =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index a14123d0a6c12..5a3eeefaedb18 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -164,7 +164,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
     case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) =>
       evaluateBinary(LessThanOrEqual(ar, l), ar, l, update)
 
-    case In(Seq(ar: Attribute), expList)
+    case In(ar: Attribute, expList)
       if expList.forall(e => e.isInstanceOf[Literal]) =>
       // Expression [In (value, seq[Literal])] will be replaced with optimized version
       // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 2adaecc625595..b9ae3d7c6d168 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -528,7 +528,7 @@ class AnalysisErrorSuite extends AnalysisTest {
     val a = AttributeReference("a", IntegerType)()
     val b = AttributeReference("b", IntegerType)()
     val plan = Project(
-      Seq(a, Alias(In(Seq(a), Seq(ListQuery(LocalRelation(b)))), "c")()),
+      Seq(a, Alias(InSubquery(Seq(a), ListQuery(LocalRelation(b))), "c")()),
       LocalRelation(a))
     assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil)
   }
@@ -537,13 +537,13 @@ class AnalysisErrorSuite extends AnalysisTest {
     val a = AttributeReference("a", IntegerType)()
     val b = AttributeReference("b", IntegerType)()
     val c = AttributeReference("c", BooleanType)()
-    val plan1 = Filter(Cast(Not(In(Seq(a), Seq(ListQuery(LocalRelation(b))))), BooleanType),
+    val plan1 = Filter(Cast(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), BooleanType),
       LocalRelation(a))
     assertAnalysisError(plan1,
       "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil)
 
     val plan2 = Filter(
-      Or(Not(In(Seq(a), Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c))
+      Or(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), c), LocalRelation(a, c))
     assertAnalysisError(plan2,
       "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 9967a9656af31..9fb50a5e565e0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -276,21 +276,21 @@ class AnalysisSuite extends AnalysisTest with Matchers {
   }
 
   test("SPARK-8654: invalid CAST in NULL IN(...) expression") {
-    val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(2))), "a")() :: Nil,
+    val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil,
       LocalRelation()
     )
     assertAnalysisSuccess(plan)
   }
 
   test("SPARK-8654: different types in inlist but can be converted to a common type") {
-    val plan = Project(
-      Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil,
-      LocalRelation())
+    val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil,
+      LocalRelation()
+    )
     assertAnalysisSuccess(plan)
   }
 
   test("SPARK-8654: check type compatibility error") {
-    val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(true), Literal(1))), "a")() :: Nil,
+    val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil,
       LocalRelation()
     )
     assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type"))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
index 03129b9e86234..74a8590b5eefe 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{In, ListQuery, OuterReference}
+import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project}
 
 /**
@@ -33,7 +33,8 @@ class ResolveSubquerySuite extends AnalysisTest {
   val t2 = LocalRelation(b)
 
   test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") {
-    val expr = Filter(In(Seq(a), Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1)
+    val expr = Filter(
+      InSubquery(Seq(a), ListQuery(Project(Seq(UnresolvedAttribute("a")), t2))), t1)
     val m = intercept[AnalysisException] {
       SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr))
     }.getMessage
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 957a95aaf9082..4161f09c63190 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -1361,16 +1361,16 @@ class TypeCoercionSuite extends AnalysisTest {
     // InConversion
     val inConversion = TypeCoercion.InConversion(conf)
     ruleTest(inConversion,
-      In(Seq(UnresolvedAttribute("a")), Seq(Literal(1))),
-      In(Seq(UnresolvedAttribute("a")), Seq(Literal(1)))
+      In(UnresolvedAttribute("a"), Seq(Literal(1))),
+      In(UnresolvedAttribute("a"), Seq(Literal(1)))
     )
     ruleTest(inConversion,
-      In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1))),
-      In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1)))
+      In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))),
+      In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1)))
    )
    ruleTest(inConversion,
-      In(Seq(Literal("a")), Seq(Literal(1), Literal("b"))),
-      In(Seq(Cast(Literal("a"), StringType)),
+      In(Literal("a"), Seq(Literal(1), Literal("b"))),
+      In(Cast(Literal("a"), StringType),
        Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType)))
    )
  }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
index 51464332dbc80..b376108399c1c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
@@ -477,8 +477,8 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
     checkAnswer(tbl2, Seq.empty, Set(part1, part2))
     checkAnswer(tbl2, Seq('a.int <= 1), Set(part1))
     checkAnswer(tbl2, Seq('a.int === 2), Set.empty)
-    checkAnswer(tbl2, Seq(In(Seq('a.int * 10), Seq(30))), Set(part2))
-    checkAnswer(tbl2, Seq(Not(In(Seq('a.int), Seq(4)))), Set(part1, part2))
+    checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2))
+    checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2))
     checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1))
     checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1))
     checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
index b78d23e3472e8..28e6940f3cca3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
@@ -27,9 +27,9 @@ class CanonicalizeSuite extends SparkFunSuite {
     val range = Range(1, 1, 1, 1)
     val idAttr = range.output.head
 
-    val in1 = In(Seq(idAttr), Seq(Literal(1), Literal(2)))
-    val in2 = In(Seq(idAttr), Seq(Literal(2), Literal(1)))
-    val in3 = In(Seq(idAttr), Seq(Literal(1), Literal(2), Literal(3)))
+    val in1 = In(idAttr, Seq(Literal(1), Literal(2)))
+    val in2 = In(idAttr, Seq(Literal(2), Literal(1)))
+    val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3)))
 
     assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash())
     assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash())
@@ -37,11 +37,11 @@ class CanonicalizeSuite extends SparkFunSuite {
     assert(range.where(in1).sameResult(range.where(in2)))
     assert(!range.where(in1).sameResult(range.where(in3)))
 
-    val arrays1 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))),
+    val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))),
       CreateArray(Seq(Literal(2), Literal(1)))))
-    val arrays2 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(2), Literal(1))),
+    val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))),
       CreateArray(Seq(Literal(1), Literal(2)))))
-    val arrays3 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))),
+    val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))),
       CreateArray(Seq(Literal(3), Literal(1)))))
 
     assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash())
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 6ae4e7ccac129..ee215d3903629 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -125,32 +125,32 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
       (null, null, null) :: Nil)
 
   test("basic IN predicate test") {
-    checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq(Literal(1),
+    checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
       Literal(2))), null)
-    checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)),
+    checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType),
       Seq(NonFoldableLiteral.create(null, IntegerType))), null)
-    checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq.empty),
+    checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty),
       null)
-    checkEvaluation(In(Seq(Literal(1)), Seq.empty), false)
+    checkEvaluation(In(Literal(1), Seq.empty), false)
     checkEvaluation(
-      In(Seq(Literal(1)), Seq(NonFoldableLiteral.create(null, IntegerType))), null)
-    checkEvaluation(In(Seq(Literal(1)),
+      In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null)
+    checkEvaluation(In(Literal(1),
       Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), true)
-    checkEvaluation(In(Seq(Literal(2)),
+    checkEvaluation(In(Literal(2),
       Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), null)
-    checkEvaluation(In(Seq(Literal(1)), Seq(Literal(1), Literal(2))), true)
-    checkEvaluation(In(Seq(Literal(2)), Seq(Literal(1), Literal(2))), true)
-    checkEvaluation(In(Seq(Literal(3)), Seq(Literal(1), Literal(2))), false)
-    checkEvaluation(And(In(Seq(Literal(1)), Seq(Literal(1), Literal(2))),
-      In(Seq(Literal(2)), Seq(Literal(1), Literal(2)))), true)
+    checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
+    checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
+    checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
+    checkEvaluation(And(In(Literal(1), Seq(Literal(1), Literal(2))),
+      In(Literal(2), Seq(Literal(1), Literal(2)))), true)
 
     val ns = NonFoldableLiteral.create(null, StringType)
-    checkEvaluation(In(Seq(ns), Seq(Literal("1"), Literal("2"))), null)
-    checkEvaluation(In(Seq(ns), Seq(ns)), null)
-    checkEvaluation(In(Seq(Literal("a")), Seq(ns)), null)
-    checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("^Ba*n"), ns)), true)
-    checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("aa"), Literal("^Ba*n"))), true)
-    checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("aa"), Literal("^n"))), false)
+    checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null)
+    checkEvaluation(In(ns, Seq(ns)), null)
+    checkEvaluation(In(Literal("a"), Seq(ns)), null)
+    checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true)
+    checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
+    checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
   }
 
@@ -187,7 +187,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
       } else {
         false
       }
-      checkEvaluation(In(Seq(input.head), input.slice(1, 10)), expected)
+      checkEvaluation(In(input.head, input.slice(1, 10)), expected)
     }
 
     val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t =>
@@ -243,12 +243,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("SPARK-22501: In should not generate codes beyond 64KB") {
     val N = 3000
     val sets = (1 to N).map(i => Literal(i.toDouble))
-    checkEvaluation(In(Seq(Literal(1.0D)), sets), true)
+    checkEvaluation(In(Literal(1.0D), sets), true)
  }
 
   test("SPARK-22705: In should use less global variables") {
     val ctx = new CodegenContext()
-    In(Seq(Literal(1.0D)), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx)
+    In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx)
     assert(ctx.inlinedMutableStates.isEmpty)
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 09d8f2e4d1ad4..641c89873dcc4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -251,7 +251,7 @@ class ConstantFoldingSuite extends PlanTest {
     val originalQuery =
       testRelation
         .select('a)
-        .where(In(Seq(Literal(1)), Seq(Literal(1), Literal(2))))
+        .where(In(Literal(1), Seq(Literal(1), Literal(2))))
 
     val optimized = Optimize.execute(originalQuery.analyze)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index eacf9746e8607..230326d7df241 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -45,9 +45,9 @@ class OptimizeInSuite extends PlanTest {
   test("OptimizedIn test: Remove deterministic repetitions") {
     val originalQuery =
       testRelation
-        .where(In(Seq(UnresolvedAttribute("a")),
+        .where(In(UnresolvedAttribute("a"),
           Seq(Literal(1), Literal(1), Literal(2), Literal(2), Literal(1), Literal(2))))
-        .where(In(Seq(UnresolvedAttribute("b")),
+        .where(In(UnresolvedAttribute("b"),
           Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"),
             Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0),
             Rand(0), Rand(0))))
@@ -56,8 +56,8 @@ class OptimizeInSuite extends PlanTest {
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer =
       testRelation
-        .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2))))
-        .where(In(Seq(UnresolvedAttribute("b")),
+        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
+        .where(In(UnresolvedAttribute("b"),
           Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"),
             Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0),
             Rand(0), Rand(0))))
@@ -69,7 +69,7 @@ class OptimizeInSuite extends PlanTest {
   test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") {
     val originalQuery =
       testRelation
-        .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2))))
+        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
         .analyze
 
     val optimized = Optimize.execute(originalQuery.analyze)
@@ -79,7 +79,7 @@ class OptimizeInSuite extends PlanTest {
   test("OptimizedIn test: In clause optimized to InSet when more than 10 items") {
     val originalQuery =
       testRelation
-        .where(In(Seq(UnresolvedAttribute("a")), (1 to 11).map(Literal(_))))
+        .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_))))
         .analyze
 
     val optimized = Optimize.execute(originalQuery.analyze)
@@ -95,14 +95,14 @@ class OptimizeInSuite extends PlanTest {
     val originalQuery =
       testRelation
         .where(
-          In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
+          In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
         .analyze
 
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer =
       testRelation
         .where(
-          In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
+          In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
         .analyze
 
     comparePlans(optimized, correctAnswer)
@@ -111,7 +111,7 @@ class OptimizeInSuite extends PlanTest {
   test("OptimizedIn test: NULL IN (expr1, ..., exprN) gets transformed to Filter(null)") {
     val originalQuery =
       testRelation
-        .where(In(Seq(Literal.create(null, NullType)), Seq(Literal(1), Literal(2))))
+        .where(In(Literal.create(null, NullType), Seq(Literal(1), Literal(2))))
         .analyze
 
     val optimized = Optimize.execute(originalQuery.analyze)
@@ -127,7 +127,7 @@ class OptimizeInSuite extends PlanTest {
       "list expression contains attribute)") {
     val originalQuery =
       testRelation
-        .where(In(Seq(Literal.create(null, StringType)), Seq(Literal(1), UnresolvedAttribute("b"))))
+        .where(In(Literal.create(null, StringType), Seq(Literal(1), UnresolvedAttribute("b"))))
         .analyze
 
     val optimized = Optimize.execute(originalQuery.analyze)
@@ -143,7 +143,7 @@ class OptimizeInSuite extends PlanTest {
      "list expression contains attribute - select)") {
     val originalQuery =
       testRelation
-        .select(In(Seq(Literal.create(null, StringType)),
+        .select(In(Literal.create(null, StringType),
           Seq(Literal(1), UnresolvedAttribute("b"))).as("a")).analyze
 
     val optimized = Optimize.execute(originalQuery.analyze)
@@ -158,7 +158,7 @@ class OptimizeInSuite extends PlanTest {
   test("OptimizedIn test: Setting the threshold for turning Set into InSet.") {
     val plan =
       testRelation
-        .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), Literal(3))))
+        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), Literal(3))))
         .analyze
 
     withSQLConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "10") {
@@ -181,7 +181,7 @@ class OptimizeInSuite extends PlanTest {
   test("OptimizedIn test: one element in list gets transformed to EqualTo.") {
     val originalQuery =
       testRelation
-        .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1))))
+        .where(In(UnresolvedAttribute("a"), Seq(Literal(1))))
         .analyze
 
     val optimized = Optimize.execute(originalQuery)
@@ -197,7 +197,7 @@ class OptimizeInSuite extends PlanTest {
       "when value is not nullable") {
     val originalQuery =
       testRelation
-        .where(In(Seq(Literal("a")), Nil))
+        .where(In(Literal("a"), Nil))
         .analyze
 
     val optimized = Optimize.execute(originalQuery)
@@ -213,7 +213,7 @@ class OptimizeInSuite extends PlanTest {
       "when value is nullable") {
     val originalQuery =
       testRelation
-        .where(In(Seq(UnresolvedAttribute("a")), Nil))
+        .where(In(UnresolvedAttribute("a"), Nil))
         .analyze
 
     val optimized = Optimize.execute(originalQuery)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
index 02d03fb8b8d57..8a5a55146726e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
+import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -42,7 +42,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest {
       .select('c)
     val outerQuery =
       testRelation
-        .where(In(Seq('a), Seq(ListQuery(correlatedSubquery))))
+        .where(InSubquery(Seq('a), ListQuery(correlatedSubquery)))
         .select('a).analyze
     assert(outerQuery.resolved)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index af2d0c2eabb7e..8f92d4c90dc49 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -154,7 +154,7 @@ class ExpressionParserSuite extends PlanTest {
   test("in sub-query") {
     assertEqual(
       "a in (select b from c)",
-      In(Seq('a), Seq(ListQuery(table("c").select('b)))))
+      InSubquery(Seq('a), ListQuery(table("c").select('b))))
   }
 
   test("like expressions") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index b84334eaf49bb..47bfa62569583 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -440,7 +440,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
     val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04"))
     val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05"))
     validateEstimatedStats(
-      Filter(In(Seq(attrDate), Seq(Literal(d20170103, DateType), Literal(d20170104, DateType),
+      Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType),
         Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)),
       Seq(attrDate -> ColumnStat(distinctCount = Some(3), min = Some(d20170103), max = Some(d20170105),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 3893eae1b8a71..4eee3de5f7d4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -793,12 +793,7 @@ class Column(val expr: Expression) extends Logging {
    * @since 1.5.0
    */
   @scala.annotation.varargs
-  def isin(list: Any*): Column = withExpr {
-    expr match {
-      case c: CreateNamedStruct => In(c.valExprs, list.map(lit(_).expr))
-      case other => In(Seq(other), list.map(lit(_).expr))
-    }
-  }
+  def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) }
 
   /**
    * A boolean expression that is evaluated to true if the value of this expression is contained
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 414e35270f94a..6012aba1acbca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -235,7 +235,7 @@ case class InMemoryTableScanExec(
     case IsNull(a: Attribute) =>
statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 - case In(Seq(a: AttributeReference), list: Seq[Expression]) + case In(a: AttributeReference, list: Seq[Expression]) if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index bf58e775a117c..f2a49423d4ec7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -475,8 +475,7 @@ object DataSourceStrategy { // Because we only convert In to InSet in Optimizer when there are more than certain // items. So it is possible we still get an In expression here that needs to be pushed // down. - case expressions.In(Seq(a: Attribute), list) - if list.forall(_.isInstanceOf[Literal]) => + case expressions.In(a: Attribute, list) if list.forall(_.isInstanceOf[Literal]) => val hSet = list.map(e => e.eval(EmptyRow)) val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) Some(sources.In(a.name, hSet.toArray.map(toScala))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 68e7bb07b794a..fe27b78bf3360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -86,7 +86,7 @@ object FileSourceStrategy extends Strategy with Logging { expr match { case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => getBucketSetFromValue(a, v) - case expressions.In(Seq(a: Attribute), list) + case expressions.In(a: Attribute, list) if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow))) case expressions.InSet(a: Attribute, hset) diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out index 05dcd0f311cab..088db55d66406 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out @@ -44,12 +44,12 @@ org.apache.spark.sql.AnalysisException Cannot analyze (named_struct('a1', tab_a.`a1`, 'b1', tab_a.`b1`) IN (listquery())). The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 2. -#columns in right hand side: 1. +#columns in left hand side: 2 +#columns in right hand side: 1 Left side columns: -[tab_a.`a1`, tab_a.`b1`]. 
+[tab_a.`a1`, tab_a.`b1`] Right side columns: -[`named_struct(a2, a2, b2, b2)`].; +[`named_struct(a2, a2, b2, b2)`]; -- !query 5 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out index e11d5a336615f..c52e5706deeee 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -95,12 +95,12 @@ org.apache.spark.sql.AnalysisException Cannot analyze (t1.`t1a` IN (listquery(t1.`t1a`))). The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 1. -#columns in right hand side: 2. +#columns in left hand side: 1 +#columns in right hand side: 2 Left side columns: -[t1.`t1a`]. +[t1.`t1a`] Right side columns: -[t2.`t2a`, t2.`t2b`].; +[t2.`t2a`, t2.`t2b`]; -- !query 8 @@ -116,12 +116,12 @@ org.apache.spark.sql.AnalysisException Cannot analyze (named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`))). The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 2. -#columns in right hand side: 1. +#columns in left hand side: 2 +#columns in right hand side: 1 Left side columns: -[t1.`t1a`, t1.`t1b`]. +[t1.`t1a`, t1.`t1b`] Right side columns: -[t2.`t2a`].; +[t2.`t2a`]; -- !query 9 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 0610f51b721a0..efc2f20a907f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -459,7 +459,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, LocalRelation(Seq(attribute), Nil)) val tableScanExec = InMemoryTableScanExec(Seq(attribute), - Seq(In(Seq(attribute), Nil)), testRelation) + Seq(In(attribute, Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index 50f0b54790b4c..f20aded169e44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -51,8 +51,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext { testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) - testTranslateFilter(In(Seq(attrInt), Seq(1, 2, 3)), - Some(sources.In("cint", Array(1, 2, 3)))) + testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint"))) testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull("cint"))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 9df7a9d86c411..66c4f48ad19a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -685,7 +685,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } def convert(expr: Expression): Option[String] = expr match { - case In(Seq(ExtractAttribute(NonVarcharAttribute(name))), + case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values)) if useAdvanced => Some(convertInToOr(name, values)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index b66beb88a329c..fa9f753795f65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -192,8 +192,8 @@ class HiveClientSuite(version: String) 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { - case expr @ In(_, list) if expr.inSetConvertible => - InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, list.map(_.eval(EmptyRow)).toSet) }) } @@ -204,8 +204,8 @@ class HiveClientSuite(version: String) 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { - case expr @ In(_, list) if expr.inSetConvertible => - InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, list.map(_.eval(EmptyRow)).toSet) }) } @@ -223,8 +223,8 @@ class HiveClientSuite(version: String) 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil, { - case expr @ In(_, list) if expr.inSetConvertible => - InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, list.map(_.eval(EmptyRow)).toSet) }) } From 0f00a06a1853cb13d1d156bafcb85973c92e2b8e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 2 Aug 2018 14:40:28 +0200 Subject: [PATCH 12/17] simplify diff --- .../sql/catalyst/expressions/PredicateSuite.scala | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index ee215d3903629..5d8996f016b2d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -129,15 +129,13 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { Literal(2))), null) checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), - null) + checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null) checkEvaluation(In(Literal(1), Seq.empty), false) - checkEvaluation( - In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(Literal(1), - Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), true) - checkEvaluation(In(Literal(2), - Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), null) + checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, 
IntegerType))), null) + checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + true) + checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + null) checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) From 53e3d961a0cde6d6ab6b4c8b86b9134b9532f776 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 2 Aug 2018 14:43:42 +0200 Subject: [PATCH 13/17] remove unneeded changes --- .../spark/sql/catalyst/expressions/PredicateSuite.scala | 8 +++++--- .../spark/sql/catalyst/optimizer/OptimizeInSuite.scala | 6 ++---- .../sql/execution/datasources/DataSourceStrategy.scala | 2 +- .../scala/org/apache/spark/sql/hive/client/HiveShim.scala | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 5d8996f016b2d..ac76b17ef4761 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -139,8 +139,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) - checkEvaluation(And(In(Literal(1), Seq(Literal(1), Literal(2))), - In(Literal(2), Seq(Literal(1), Literal(2)))), true) + checkEvaluation( + And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), + Literal(2)))), + true) val ns = NonFoldableLiteral.create(null, StringType) checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null) @@ -185,7 +187,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } else { false } - checkEvaluation(In(input.head, input.slice(1, 10)), expected) + checkEvaluation(In(input(0), input.slice(1, 10)), expected) } val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 230326d7df241..86522a6a54ed5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -94,15 +94,13 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized in case filter has attributes") { val originalQuery = testRelation - .where( - In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where( - In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index f2a49423d4ec7..e1b049b6ceaba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -475,7 +475,7 @@ object DataSourceStrategy { // Because we only convert In to InSet in Optimizer when there are more than certain // items. So it is possible we still get an In expression here that needs to be pushed // down. - case expressions.In(a: Attribute, list) if list.forall(_.isInstanceOf[Literal]) => + case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) => val hSet = list.map(e => e.eval(EmptyRow)) val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) Some(sources.In(a.name, hSet.toArray.map(toScala))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 66c4f48ad19a6..bc9d4cd7f4181 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -685,8 +685,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } def convert(expr: Expression): Option[String] = expr match { - case In(ExtractAttribute(NonVarcharAttribute(name)), - ExtractableLiterals(values)) if useAdvanced => + case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values)) + if useAdvanced => Some(convertInToOr(name, values)) case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values)) From 45a91fc4b252967cf99c88331f51a702edadbaa2 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 3 Aug 2018 13:21:51 +0200 Subject: [PATCH 14/17] fix test error --- .../org/apache/spark/sql/catalyst/expressions/subquery.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 6acc87a3e7367..fc1caed84e272 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -117,10 +117,10 @@ object SubExprUtils extends PredicateHelper { def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = { splitConjunctivePredicates(condition).exists { case _: Exists | Not(_: Exists) => false - case In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => false + case _: InSubquery | Not(_: InSubquery) => false case e => e.find { x => x.isInstanceOf[Not] && e.find { - case In(_, Seq(_: ListQuery)) => true + case _: InSubquery => true case _ => false }.isDefined }.isDefined From cb3467be92c1f7c8ed313ff1b37a00f82d59eda6 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 3 Aug 2018 18:35:42 +0200 Subject: [PATCH 15/17] remove ListQuery --- .../sql/catalyst/analysis/Analyzer.scala | 13 +- .../sql/catalyst/analysis/TypeCoercion.scala | 4 +- .../catalyst/analysis/timeZoneAnalysis.scala | 4 +- .../spark/sql/catalyst/dsl/package.scala | 11 +- .../sql/catalyst/expressions/predicates.scala | 63 +-------- .../sql/catalyst/expressions/subquery.scala | 126 +++++++++++++----- .../sql/catalyst/optimizer/expressions.scala | 2 +- .../sql/catalyst/optimizer/subquery.scala | 10 +- 
.../sql/catalyst/parser/AstBuilder.scala | 2 +- .../analysis/AnalysisErrorSuite.scala | 6 +- .../analysis/ResolveSubquerySuite.scala | 4 +- .../optimizer/FilterPushdownSuite.scala | 4 +- .../PullupCorrelatedPredicatesSuite.scala | 4 +- .../optimizer/RewriteSubquerySuite.scala | 3 +- .../parser/ExpressionParserSuite.scala | 2 +- .../spark/sql/catalyst/plans/PlanTest.scala | 4 +- 16 files changed, 128 insertions(+), 134 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fc442c7fb5383..bd0da3029fc07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1422,16 +1422,15 @@ class Analyzer( resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) - case InSubquery(values, l @ ListQuery(_, _, exprId, _)) - if values.forall(_.resolved) && !l.resolved => - val expr = resolveSubQuery(l, plans)((plan, exprs) => { - ListQuery(plan, exprs, exprId, plan.output) + case i @ InSubquery(values, _, _, exprId, _) + if values.forall(_.resolved) && !i.resolved => + val expr = resolveSubQuery(i, plans)((plan, exprs) => { + InSubquery(values, plan, exprs, exprId, plan.output) }) val subqueryOutput = expr.plan.output - val resolvedIn = InSubquery(values, expr.asInstanceOf[ListQuery]) if (values.length != subqueryOutput.length) { throw new AnalysisException( - s"""Cannot analyze ${resolvedIn.sql}. + s"""Cannot analyze ${expr.sql}. |The number of columns in the left hand side of an IN subquery does not match the |number of columns in the output of subquery. |#columns in left hand side: ${values.length} @@ -1441,7 +1440,7 @@ class Analyzer( |Right side columns: |[${subqueryOutput.map(_.sql).mkString(", ")}]""".stripMargin) } - resolvedIn + expr } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 648aa9ee8fa0e..41e2cede4645d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -456,7 +456,7 @@ object TypeCoercion { // Handle type casting required between value expression and subquery output // in IN subquery. - case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _)) + case i @ InSubquery(lhs, sub, children, exprId, _) if !i.resolved && lhs.length == sub.output.length => // LHS is the value expressions of IN subquery. // RHS is the subquery output. 
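
A minimal sketch of what this coercion hunk does (illustrative only, not part of the patch; the relations below are hypothetical and use the catalyst test DSL seen in this series' suites):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val t1 = LocalRelation('a.short)
val t2 = LocalRelation('b.int)
// 'a IN (SELECT b FROM t2): the LHS value is smallint, the subquery output is int.
val query = t1.where('a.in(t2.select('b)))
// After InConversion the predicate becomes roughly
//   InSubquery(Seq(Cast('a, IntegerType)), Project(Seq('b), t2), ...)
// i.e. both sides are brought to the wider common type (int here) before resolution.
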
@@ -480,7 +480,7 @@ object TypeCoercion { } val newSub = Project(castedRhs, sub) - InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output)) + InSubquery(newLhs, newSub, children, exprId, newSub.output) } else { i } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala index a27aa845bf0ae..cb9d1c67e5ec9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, InSubquery, TimeZoneAwareExpression} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -34,7 +34,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { // the types between the value expression and list query expression of IN expression. // We need to subject the subquery plan through ResolveTimeZone again to setup timezone // information for time zone aware expressions. - case e: ListQuery => e.withNewPlan(apply(e.plan)) + case e: InSubquery => e.withNewPlan(apply(e.plan)) } override def apply(plan: LogicalPlan): LogicalPlan = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index eb7907e5abe35..69f1ee8e4538f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -88,12 +88,11 @@ package object dsl { def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) def =!= (other: Expression): Predicate = Not(EqualTo(expr, other)) - def in(list: Expression*): Expression = list match { - case Seq(l: ListQuery) => expr match { - case c: CreateNamedStruct => InSubquery(c.valExprs, l) - case other => InSubquery(Seq(other), l) - } - case _ => In(expr, list) + def in(list: Expression*): Expression = In(expr, list) + + def in(plan: LogicalPlan): Expression = expr match { + case c: CreateNamedStruct => InSubquery(c.valExprs, plan) + case other => InSubquery(Seq(other), plan) } def like(other: Expression): Expression = Like(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 149bd79278a54..dadfc7735c388 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -104,8 +104,7 @@ trait PredicateHelper { protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match { // Non-deterministic expressions are not allowed as join conditions. case e if !e.deterministic => false - case _: ListQuery | _: Exists => - // A ListQuery defines the query which we want to search in an IN subquery expression. 
+ case _: InSubquery | _: Exists => // Currently the only way to evaluate an IN subquery is to convert it to a // LeftSemi/LeftAnti/ExistenceJoin by `RewritePredicateSubquery` rule. // It cannot be evaluated as part of a Join operator. @@ -138,66 +137,6 @@ case class Not(child: Expression) override def sql: String = s"(NOT ${child.sql})" } -/** - * Evaluates to `true` if `values` are returned in `query`'s result set. - */ -case class InSubquery(values: Seq[Expression], query: ListQuery) - extends Predicate with Unevaluable { - - @transient lazy val value: Expression = if (values.length > 1) { - CreateNamedStruct(values.zipWithIndex.flatMap { - case (v: NamedExpression, _) => Seq(Literal(v.name), v) - case (v, idx) => Seq(Literal(s"_$idx"), v) - }) - } else { - values.head - } - - - override def checkInputDataTypes(): TypeCheckResult = { - val mismatchOpt = !DataType.equalsStructurally(query.dataType, value.dataType, - ignoreNullability = true) - if (mismatchOpt) { - if (values.length != query.childOutputs.length) { - TypeCheckResult.TypeCheckFailure( - s""" - |The number of columns in the left hand side of an IN subquery does not match the - |number of columns in the output of subquery. - |#columns in left hand side: ${values.length}. - |#columns in right hand side: ${query.childOutputs.length}. - |Left side columns: - |[${values.map(_.sql).mkString(", ")}]. - |Right side columns: - |[${query.childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) - } else { - val mismatchedColumns = values.zip(query.childOutputs).flatMap { - case (l, r) if l.dataType != r.dataType => - Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") - case _ => None - } - TypeCheckResult.TypeCheckFailure( - s""" - |The data type of one or more elements in the left hand side of an IN subquery - |is not compatible with the data type of the output of the subquery - |Mismatched columns: - |[${mismatchedColumns.mkString(", ")}] - |Left side: - |[${values.map(_.dataType.catalogString).mkString(", ")}]. - |Right side: - |[${query.childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) - } - } else { - TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") - } - } - - override def children: Seq[Expression] = values :+ query - override def nullable: Boolean = children.exists(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"$value IN ($query)" - override def sql: String = s"(${value.sql} IN (${query.sql}))" -} - /** * Evaluates to `true` if `list` contains `value`. 
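
To make the dsl/package.scala change above concrete (a sketch with hypothetical names): a varargs argument list still builds a plain In, while passing a LogicalPlan now builds the new InSubquery directly:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val t2 = LocalRelation('b.int)
val inValueList = 'a.in(Literal(1), Literal(2))  // varargs -> In('a, Seq(1, 2))
val inSubquery = 'a.in(t2.select('b))            // LogicalPlan -> InSubquery(Seq('a), t2.select('b))
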
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index fc1caed84e272..5eb7c158a7b27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -73,7 +75,7 @@ object SubqueryExpression { */ def hasInOrExistsSubquery(e: Expression): Boolean = { e.find { - case _: ListQuery | _: Exists => true + case _: InSubquery | _: Exists => true case _ => false }.isDefined } @@ -270,65 +272,121 @@ object ScalarSubquery { } /** - * A [[ListQuery]] expression defines the query which we want to search in an IN subquery - * expression. It should and can only be used in conjunction with an IN expression. + * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition. * * For example (SQL): * {{{ * SELECT * * FROM a - * WHERE a.id IN (SELECT id - * FROM b) + * WHERE EXISTS (SELECT * + * FROM b + * WHERE b.id = a.id) * }}} */ -case class ListQuery( +case class Exists( plan: LogicalPlan, children: Seq[Expression] = Seq.empty, - exprId: ExprId = NamedExpression.newExprId, - childOutputs: Seq[Attribute] = Seq.empty) - extends SubqueryExpression(plan, children, exprId) with Unevaluable { - override def dataType: DataType = if (childOutputs.length > 1) { - childOutputs.toStructType - } else { - childOutputs.head.dataType - } - override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable { override def nullable: Boolean = false - override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan) - override def toString: String = s"list#${exprId.id} $conditionString" + override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) + override def toString: String = s"exists#${exprId.id} $conditionString" override lazy val canonicalized: Expression = { - ListQuery( + Exists( plan.canonicalized, children.map(_.canonicalized), - ExprId(0), - childOutputs.map(_.canonicalized.asInstanceOf[Attribute])) + ExprId(0)) } } /** - * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition. + * An [[InSubquery]] expression defines an IN expression where the values are searched in the output + * of a subquery.
* * For example (SQL): * {{{ * SELECT * * FROM a - * WHERE EXISTS (SELECT * - * FROM b - * WHERE b.id = a.id) + * WHERE a.id IN (SELECT id + * FROM b) * }}} */ -case class Exists( +case class InSubquery(values: Seq[Expression], plan: LogicalPlan, - children: Seq[Expression] = Seq.empty, - exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable { - override def nullable: Boolean = false - override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) - override def toString: String = s"exists#${exprId.id} $conditionString" + conditions: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId, + queryOutputs: Seq[Attribute] = Seq.empty) + extends SubqueryExpression(plan, conditions, exprId) with Predicate with Unevaluable { + + @transient lazy val value: Expression = if (values.length > 1) { + CreateNamedStruct(values.zipWithIndex.flatMap { + case (v: NamedExpression, _) => Seq(Literal(v.name), v) + case (v, idx) => Seq(Literal(s"_$idx"), v) + }) + } else { + values.head + } + + @transient lazy val queryResultDataType = if (queryOutputs.length > 1) { + queryOutputs.toStructType + } else { + queryOutputs.head.dataType + } + + override def checkInputDataTypes(): TypeCheckResult = { + val mismatchOpt = !DataType.equalsStructurally(queryResultDataType, value.dataType, + ignoreNullability = true) + if (mismatchOpt) { + if (values.length != queryOutputs.length) { + TypeCheckResult.TypeCheckFailure( + s""" + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${values.length}. + |#columns in right hand side: ${queryOutputs.length}. + |Left side columns: + |[${values.map(_.sql).mkString(", ")}]. + |Right side columns: + |[${queryOutputs.map(_.sql).mkString(", ")}].""".stripMargin) + } else { + val mismatchedColumns = values.zip(queryOutputs).flatMap { + case (l, r) if l.dataType != r.dataType => + Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") + case _ => None + } + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${values.map(_.dataType.catalogString).mkString(", ")}]. 
+ |Right side: + |[${queryOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) + } + } else { + TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") + } + } + + override lazy val resolved: Boolean = childrenResolved && plan.resolved && queryOutputs.nonEmpty + override def children: Seq[Expression] = values ++ conditions + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"$value IN (list#${exprId.id} $conditionString)" + override def sql: String = + s"(${value.sql} IN (listquery(${conditions.map(_.sql).mkString(", ")})))" + override def withNewPlan(plan: LogicalPlan): InSubquery = copy(plan = plan) + override lazy val canonicalized: Expression = { - Exists( + InSubquery( + values.map(_.canonicalized), plan.canonicalized, - children.map(_.canonicalized), - ExprId(0)) + conditions.map(_.canonicalized), + ExprId(0), + queryOutputs.map(_.canonicalized.asInstanceOf[Attribute])) } + + override def prettyName: String = "in" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 251cadda02c2c..139d61dd359ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -505,7 +505,7 @@ object NullPropagation extends Rule[LogicalPlan] { // If the value expression is NULL then transform the In expression to null literal. case In(Literal(null, _), _) => Literal.create(null, BooleanType) - case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType) + case InSubquery(Seq(Literal(null, _)), _, _, _, _) => Literal.create(null, BooleanType) // Non-leaf NullIntolerant expressions will return null, if at least one of its children is // a null literal. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index e9b7a8b76e683..17e07af743333 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -91,12 +91,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond)) - case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) => + case (p, InSubquery(values, sub, conditions, _, _)) => val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond)) - case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) => + case (p, Not(InSubquery(values, sub, conditions, _, _))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. 
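
As a concrete sketch of the LeftSemi rewrite above (hypothetical relations, mirroring the RewriteSubquerySuite update later in this commit):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val relation = LocalRelation('a.int, 'b.int)
val relInSubquery = LocalRelation('x.int, 'y.int, 'z.int)
// 'a IN (SELECT x FROM relInSubquery) used as a filter condition ...
val query = relation.where('a.in(relInSubquery.select('x))).select('a)
// ... is rewritten by RewritePredicateSubquery to roughly
//   relation.join(relInSubquery.select('x), LeftSemi, Some('a === 'x)).select('a)
// while NOT IN takes the LeftAnti branch with the null-aware condition built above.
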
@@ -144,7 +144,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { newPlan = dedupJoin( Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))) exists - case InSubquery(values, ListQuery(sub, conditions, _, _)) => + case InSubquery(values, sub, conditions, _, _) => val exists = AttributeReference("exists", BooleanType, nullable = false)() val inConditions = values.zip(sub.output).map(EqualTo.tupled) val newConditions = (inConditions ++ conditions).reduceLeftOption(And) @@ -256,9 +256,9 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper case Exists(sub, children, exprId) if children.nonEmpty => val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) Exists(newPlan, newCond, exprId) - case ListQuery(sub, _, exprId, childOutputs) => + case InSubquery(values, sub, _, exprId, childOutputs) => val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) - ListQuery(newPlan, newCond, exprId, childOutputs) + InSubquery(values, newPlan, newCond, exprId, childOutputs) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index db504505c0ffc..b71454a6c6969 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1117,7 +1117,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging GreaterThanOrEqual(e, expression(ctx.lower)), LessThanOrEqual(e, expression(ctx.upper)))) case SqlBaseParser.IN if ctx.query != null => - invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) + invertIfNotDefined(InSubquery(getValueExpressions(e), plan(ctx.query))) case SqlBaseParser.IN => invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index b9ae3d7c6d168..40282f4b3c500 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -528,7 +528,7 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val plan = Project( - Seq(a, Alias(InSubquery(Seq(a), ListQuery(LocalRelation(b))), "c")()), + Seq(a, Alias(InSubquery(Seq(a), LocalRelation(b)), "c")()), LocalRelation(a)) assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil) } @@ -537,13 +537,13 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", BooleanType)() - val plan1 = Filter(Cast(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), BooleanType), + val plan1 = Filter(Cast(Not(InSubquery(Seq(a), LocalRelation(b))), BooleanType), LocalRelation(a)) assertAnalysisError(plan1, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) val plan2 = Filter( - Or(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), c), LocalRelation(a, c)) + Or(Not(InSubquery(Seq(a), LocalRelation(b))), c), LocalRelation(a, c)) 
assertAnalysisError(plan2, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 74a8590b5eefe..38660ab1d4323 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery} +import org.apache.spark.sql.catalyst.expressions.InSubquery import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project} /** @@ -34,7 +34,7 @@ class ResolveSubquerySuite extends AnalysisTest { test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { val expr = Filter( - InSubquery(Seq(a), ListQuery(Project(Seq(UnresolvedAttribute("a")), t2))), t1) + InSubquery(Seq(a), Project(Seq(UnresolvedAttribute("a")), t2)), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr)) }.getMessage diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 82a10254d846d..a01defeba11cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -912,12 +912,12 @@ class FilterPushdownSuite extends PlanTest { val queryPlan = x .join(z) .where(("x.b".attr === "z.b".attr) && - ("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr))))) + ("x.a".attr > 1 || "z.c".attr.in(w.select("w.d".attr)))) .analyze val expectedPlan = x .join(z, Inner, Some("x.b".attr === "z.b".attr)) - .where("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr)))) + .where("x.a".attr > 1 || "z.c".attr.in(w.select("w.d".attr))) .analyze val optimized = Optimize.execute(queryPlan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index 8a5a55146726e..e269f67d47e4a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery} +import org.apache.spark.sql.catalyst.expressions.InSubquery import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -42,7 +42,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { .select('c) val outerQuery = testRelation - .where(InSubquery(Seq('a), ListQuery(correlatedSubquery))) + .where(InSubquery(Seq('a), 
correlatedSubquery)) .select('a).analyze assert(outerQuery.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala index 6b3739c372c3a..b24b02fe5f670 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.ListQuery import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -41,7 +40,7 @@ class RewriteSubquerySuite extends PlanTest { val relation = LocalRelation('a.int, 'b.int) val relInSubquery = LocalRelation('x.int, 'y.int, 'z.int) - val query = relation.where('a.in(ListQuery(relInSubquery.select('x)))).select('a) + val query = relation.where('a.in(relInSubquery.select('x))).select('a) val optimized = Optimize.execute(query.analyze) val correctAnswer = relation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 8f92d4c90dc49..f285b62bc076f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -154,7 +154,7 @@ class ExpressionParserSuite extends PlanTest { test("in sub-query") { assertEqual( "a in (select b from c)", - InSubquery(Seq('a), ListQuery(table("c").select('b)))) + InSubquery(Seq('a), table("c").select('b))) } test("like expressions") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 6241d5cbb1d25..ae3e81a408dbf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -52,8 +52,8 @@ trait PlanTestBase extends PredicateHelper { self: Suite => s.copy(exprId = ExprId(0)) case e: Exists => e.copy(exprId = ExprId(0)) - case l: ListQuery => - l.copy(exprId = ExprId(0)) + case i: InSubquery => + i.copy(exprId = ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => From a6114a655305f318230bf1bbd25394e952793a94 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 3 Aug 2018 23:37:30 +0200 Subject: [PATCH 16/17] Revert "remove ListQuery" This reverts commit cb3467be92c1f7c8ed313ff1b37a00f82d59eda6. 
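
With this revert the series settles on the final design: ListQuery remains the expression carrying the subquery plan, and InSubquery is the predicate pairing it with the left-hand-side values. A minimal sketch of a multi-column IN subquery in that shape (hypothetical relations, same test DSL as the suites):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val inner = LocalRelation('x.int, 'y.int)
// (a, b) IN (SELECT x, y FROM inner):
val pred = InSubquery(Seq('a.attr, 'b.attr), ListQuery(inner.select('x, 'y)))
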
--- .../sql/catalyst/analysis/Analyzer.scala | 13 +- .../sql/catalyst/analysis/TypeCoercion.scala | 4 +- .../catalyst/analysis/timeZoneAnalysis.scala | 4 +- .../spark/sql/catalyst/dsl/package.scala | 11 +- .../sql/catalyst/expressions/predicates.scala | 63 ++++++++- .../sql/catalyst/expressions/subquery.scala | 126 +++++------------- .../sql/catalyst/optimizer/expressions.scala | 2 +- .../sql/catalyst/optimizer/subquery.scala | 10 +- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../analysis/AnalysisErrorSuite.scala | 6 +- .../analysis/ResolveSubquerySuite.scala | 4 +- .../optimizer/FilterPushdownSuite.scala | 4 +- .../PullupCorrelatedPredicatesSuite.scala | 4 +- .../optimizer/RewriteSubquerySuite.scala | 3 +- .../parser/ExpressionParserSuite.scala | 2 +- .../spark/sql/catalyst/plans/PlanTest.scala | 4 +- 16 files changed, 134 insertions(+), 128 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bd0da3029fc07..fc442c7fb5383 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1422,15 +1422,16 @@ class Analyzer( resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) - case i @ InSubquery(values, _, _, exprId, _) - if values.forall(_.resolved) && !i.resolved => - val expr = resolveSubQuery(i, plans)((plan, exprs) => { - InSubquery(values, plan, exprs, exprId, plan.output) + case InSubquery(values, l @ ListQuery(_, _, exprId, _)) + if values.forall(_.resolved) && !l.resolved => + val expr = resolveSubQuery(l, plans)((plan, exprs) => { + ListQuery(plan, exprs, exprId, plan.output) }) val subqueryOutput = expr.plan.output + val resolvedIn = InSubquery(values, expr.asInstanceOf[ListQuery]) if (values.length != subqueryOutput.length) { throw new AnalysisException( - s"""Cannot analyze ${expr.sql}. + s"""Cannot analyze ${resolvedIn.sql}. |The number of columns in the left hand side of an IN subquery does not match the |number of columns in the output of subquery. |#columns in left hand side: ${values.length} @@ -1440,7 +1441,7 @@ class Analyzer( |Right side columns: |[${subqueryOutput.map(_.sql).mkString(", ")}]""".stripMargin) } - expr + resolvedIn } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 41e2cede4645d..648aa9ee8fa0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -456,7 +456,7 @@ object TypeCoercion { // Handle type casting required between value expression and subquery output // in IN subquery. - case i @ InSubquery(lhs, sub, children, exprId, _) + case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _)) if !i.resolved && lhs.length == sub.output.length => // LHS is the value expressions of IN subquery. // RHS is the subquery output. 
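
For reference (a hypothetical spark-shell session; the `spark` SparkSession and the table names are assumed), the restored analyzer check above is what produces the message format updated earlier in in-basic.sql.out:

spark.sql("SELECT * FROM tab_a WHERE (a1, b1) IN (SELECT a2 FROM tab_b)")
// org.apache.spark.sql.AnalysisException: Cannot analyze
// (named_struct('a1', tab_a.`a1`, 'b1', tab_a.`b1`) IN (listquery())).
// The number of columns in the left hand side of an IN subquery does not match the
// number of columns in the output of subquery.
// #columns in left hand side: 2
// #columns in right hand side: 1
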
@@ -480,7 +480,7 @@ object TypeCoercion { } val newSub = Project(castedRhs, sub) - InSubquery(newLhs, newSub, children, exprId, newSub.output) + InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output)) } else { i } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala index cb9d1c67e5ec9..a27aa845bf0ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, InSubquery, TimeZoneAwareExpression} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -34,7 +34,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { // the types between the value expression and list query expression of IN expression. // We need to subject the subquery plan through ResolveTimeZone again to setup timezone // information for time zone aware expressions. - case e: InSubquery => e.withNewPlan(apply(e.plan)) + case e: ListQuery => e.withNewPlan(apply(e.plan)) } override def apply(plan: LogicalPlan): LogicalPlan = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 69f1ee8e4538f..eb7907e5abe35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -88,11 +88,12 @@ package object dsl { def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) def =!= (other: Expression): Predicate = Not(EqualTo(expr, other)) - def in(list: Expression*): Expression = In(expr, list) - - def in(plan: LogicalPlan): Expression = expr match { - case c: CreateNamedStruct => InSubquery(c.valExprs, plan) - case other => InSubquery(Seq(other), plan) + def in(list: Expression*): Expression = list match { + case Seq(l: ListQuery) => expr match { + case c: CreateNamedStruct => InSubquery(c.valExprs, l) + case other => InSubquery(Seq(other), l) + } + case _ => In(expr, list) } def like(other: Expression): Expression = Like(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index dadfc7735c388..149bd79278a54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -104,7 +104,8 @@ trait PredicateHelper { protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match { // Non-deterministic expressions are not allowed as join conditions. case e if !e.deterministic => false - case _: InSubquery | _: Exists => + case _: ListQuery | _: Exists => + // A ListQuery defines the query which we want to search in an IN subquery expression. 
// Currently the only way to evaluate an IN subquery is to convert it to a // LeftSemi/LeftAnti/ExistenceJoin by `RewritePredicateSubquery` rule. // It cannot be evaluated as part of a Join operator. @@ -137,6 +138,66 @@ case class Not(child: Expression) override def sql: String = s"(NOT ${child.sql})" } +/** + * Evaluates to `true` if `values` are returned in `query`'s result set. + */ +case class InSubquery(values: Seq[Expression], query: ListQuery) + extends Predicate with Unevaluable { + + @transient lazy val value: Expression = if (values.length > 1) { + CreateNamedStruct(values.zipWithIndex.flatMap { + case (v: NamedExpression, _) => Seq(Literal(v.name), v) + case (v, idx) => Seq(Literal(s"_$idx"), v) + }) + } else { + values.head + } + + + override def checkInputDataTypes(): TypeCheckResult = { + val mismatchOpt = !DataType.equalsStructurally(query.dataType, value.dataType, + ignoreNullability = true) + if (mismatchOpt) { + if (values.length != query.childOutputs.length) { + TypeCheckResult.TypeCheckFailure( + s""" + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${values.length}. + |#columns in right hand side: ${query.childOutputs.length}. + |Left side columns: + |[${values.map(_.sql).mkString(", ")}]. + |Right side columns: + |[${query.childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) + } else { + val mismatchedColumns = values.zip(query.childOutputs).flatMap { + case (l, r) if l.dataType != r.dataType => + Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") + case _ => None + } + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${values.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${query.childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) + } + } else { + TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") + } + } + + override def children: Seq[Expression] = values :+ query + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"$value IN ($query)" + override def sql: String = s"(${value.sql} IN (${query.sql}))" +} + /** * Evaluates to `true` if `list` contains `value`. 
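To make the new node concrete, here is a minimal sketch of constructing and type-checking an InSubquery; the attributes are invented, and the snippet is not part of the patch, it only exercises the definitions above:

{{{
  import org.apache.spark.sql.catalyst.expressions._
  import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
  import org.apache.spark.sql.types._

  val a = AttributeReference("a", IntegerType)()
  val b = AttributeReference("b", IntegerType)()

  // a IN (SELECT b FROM ...): a single LHS value, so `value` is just `a`.
  val ok = InSubquery(Seq(a), ListQuery(LocalRelation(b), childOutputs = Seq(b)))
  ok.checkInputDataTypes()   // TypeCheckSuccess: INT is order-comparable with INT

  // (a, a) IN (SELECT b FROM ...): arity mismatch, so checkInputDataTypes() returns
  // the "number of columns ... does not match" TypeCheckFailure defined above.
  val bad = InSubquery(Seq(a, a), ListQuery(LocalRelation(b), childOutputs = Seq(b)))
}}}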
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 5eb7c158a7b27..fc1caed84e272 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -19,11 +19,9 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -75,7 +73,7 @@ object SubqueryExpression { */ def hasInOrExistsSubquery(e: Expression): Boolean = { e.find { - case _: InSubquery | _: Exists => true + case _: ListQuery | _: Exists => true case _ => false }.isDefined } @@ -272,121 +270,65 @@ object ScalarSubquery { } /** - * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition. + * A [[ListQuery]] expression defines the query which we want to search in an IN subquery + * expression. It should and can only be used in conjunction with an IN expression. * * For example (SQL): * {{{ * SELECT * * FROM a - * WHERE EXISTS (SELECT * - * FROM b - * WHERE b.id = a.id) + * WHERE a.id IN (SELECT id + * FROM b) * }}} */ -case class Exists( +case class ListQuery( plan: LogicalPlan, children: Seq[Expression] = Seq.empty, - exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable { + exprId: ExprId = NamedExpression.newExprId, + childOutputs: Seq[Attribute] = Seq.empty) + extends SubqueryExpression(plan, children, exprId) with Unevaluable { + override def dataType: DataType = if (childOutputs.length > 1) { + childOutputs.toStructType + } else { + childOutputs.head.dataType + } + override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty override def nullable: Boolean = false - override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) - override def toString: String = s"exists#${exprId.id} $conditionString" + override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan) + override def toString: String = s"list#${exprId.id} $conditionString" override lazy val canonicalized: Expression = { - Exists( + ListQuery( plan.canonicalized, children.map(_.canonicalized), - ExprId(0)) + ExprId(0), + childOutputs.map(_.canonicalized.asInstanceOf[Attribute])) } } /** - * A [[InSubquery]] expression defines a IN expression where the values are searched in the output - * of a subquery. + * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition. 
* * For example (SQL): * {{{ * SELECT * * FROM a - * WHERE a.id IN (SELECT id - * FROM b) + * WHERE EXISTS (SELECT * + * FROM b + * WHERE b.id = a.id) * }}} */ -case class InSubquery(values: Seq[Expression], +case class Exists( plan: LogicalPlan, - conditions: Seq[Expression] = Seq.empty, - exprId: ExprId = NamedExpression.newExprId, - queryOutputs: Seq[Attribute] = Seq.empty) - extends SubqueryExpression(plan, conditions, exprId) with Predicate with Unevaluable { - - @transient lazy val value: Expression = if (values.length > 1) { - CreateNamedStruct(values.zipWithIndex.flatMap { - case (v: NamedExpression, _) => Seq(Literal(v.name), v) - case (v, idx) => Seq(Literal(s"_$idx"), v) - }) - } else { - values.head - } - - @transient lazy val queryResultDataType = if (queryOutputs.length > 1) { - queryOutputs.toStructType - } else { - queryOutputs.head.dataType - } - - override def checkInputDataTypes(): TypeCheckResult = { - val mismatchOpt = !DataType.equalsStructurally(queryResultDataType, value.dataType, - ignoreNullability = true) - if (mismatchOpt) { - if (values.length != queryOutputs.length) { - TypeCheckResult.TypeCheckFailure( - s""" - |The number of columns in the left hand side of an IN subquery does not match the - |number of columns in the output of subquery. - |#columns in left hand side: ${values.length}. - |#columns in right hand side: ${queryOutputs.length}. - |Left side columns: - |[${values.map(_.sql).mkString(", ")}]. - |Right side columns: - |[${queryOutputs.map(_.sql).mkString(", ")}].""".stripMargin) - } else { - val mismatchedColumns = values.zip(queryOutputs).flatMap { - case (l, r) if l.dataType != r.dataType => - Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") - case _ => None - } - TypeCheckResult.TypeCheckFailure( - s""" - |The data type of one or more elements in the left hand side of an IN subquery - |is not compatible with the data type of the output of the subquery - |Mismatched columns: - |[${mismatchedColumns.mkString(", ")}] - |Left side: - |[${values.map(_.dataType.catalogString).mkString(", ")}]. 
- |Right side: - |[${queryOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) - } - } else { - TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") - } - } - - override lazy val resolved: Boolean = childrenResolved && plan.resolved && queryOutputs.nonEmpty - override def children: Seq[Expression] = values ++ conditions - override def nullable: Boolean = children.exists(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"$value IN (list#${exprId.id} $conditionString)" - override def sql: String = - s"(${value.sql} IN (listquery(${conditions.map(_.sql).mkString(", ")})))" - override def withNewPlan(plan: LogicalPlan): InSubquery = copy(plan = plan) - + children: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable { + override def nullable: Boolean = false + override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) + override def toString: String = s"exists#${exprId.id} $conditionString" override lazy val canonicalized: Expression = { - InSubquery( - values.map(_.canonicalized), + Exists( plan.canonicalized, - conditions.map(_.canonicalized), - ExprId(0), - queryOutputs.map(_.canonicalized.asInstanceOf[Attribute])) + children.map(_.canonicalized), + ExprId(0)) } - - override def prettyName: String = "in" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 139d61dd359ba..251cadda02c2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -505,7 +505,7 @@ object NullPropagation extends Rule[LogicalPlan] { // If the value expression is NULL then transform the In expression to null literal. case In(Literal(null, _), _) => Literal.create(null, BooleanType) - case InSubquery(Seq(Literal(null, _)), _, _, _, _) => Literal.create(null, BooleanType) + case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType) // Non-leaf NullIntolerant expressions will return null, if at least one of its children is // a null literal. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 17e07af743333..e9b7a8b76e683 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -91,12 +91,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond)) - case (p, InSubquery(values, sub, conditions, _, _)) => + case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) => val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Deduplicate conflicting attributes if any. 
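         // (Illustration, not from the source: for `WHERE (x, y) IN (SELECT a, b FROM t)`
         // the inConditions are `x = a` and `y = b`, and their conjunction with any
         // correlated conditions becomes the LeftSemi join condition below.)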
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond)) - case (p, Not(InSubquery(values, sub, conditions, _, _))) => + case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. @@ -144,7 +144,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { newPlan = dedupJoin( Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))) exists - case InSubquery(values, sub, conditions, _, _) => + case InSubquery(values, ListQuery(sub, conditions, _, _)) => val exists = AttributeReference("exists", BooleanType, nullable = false)() val inConditions = values.zip(sub.output).map(EqualTo.tupled) val newConditions = (inConditions ++ conditions).reduceLeftOption(And) @@ -256,9 +256,9 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper case Exists(sub, children, exprId) if children.nonEmpty => val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) Exists(newPlan, newCond, exprId) - case InSubquery(values, sub, _, exprId, childOutputs) => + case ListQuery(sub, _, exprId, childOutputs) => val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) - InSubquery(values, newPlan, newCond, exprId, childOutputs) + ListQuery(newPlan, newCond, exprId, childOutputs) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b71454a6c6969..db504505c0ffc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1117,7 +1117,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging GreaterThanOrEqual(e, expression(ctx.lower)), LessThanOrEqual(e, expression(ctx.upper)))) case SqlBaseParser.IN if ctx.query != null => - invertIfNotDefined(InSubquery(getValueExpressions(e), plan(ctx.query))) + invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) case SqlBaseParser.IN => invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 40282f4b3c500..b9ae3d7c6d168 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -528,7 +528,7 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val plan = Project( - Seq(a, Alias(InSubquery(Seq(a), LocalRelation(b)), "c")()), + Seq(a, Alias(InSubquery(Seq(a), ListQuery(LocalRelation(b))), "c")()), LocalRelation(a)) assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil) } @@ -537,13 +537,13 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", BooleanType)() - val plan1 = Filter(Cast(Not(InSubquery(Seq(a), LocalRelation(b))), 
BooleanType), + val plan1 = Filter(Cast(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), BooleanType), LocalRelation(a)) assertAnalysisError(plan1, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) val plan2 = Filter( - Or(Not(InSubquery(Seq(a), LocalRelation(b))), c), LocalRelation(a, c)) + Or(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), c), LocalRelation(a, c)) assertAnalysisError(plan2, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 38660ab1d4323..74a8590b5eefe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.InSubquery +import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project} /** @@ -34,7 +34,7 @@ class ResolveSubquerySuite extends AnalysisTest { test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { val expr = Filter( - InSubquery(Seq(a), Project(Seq(UnresolvedAttribute("a")), t2)), t1) + InSubquery(Seq(a), ListQuery(Project(Seq(UnresolvedAttribute("a")), t2))), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr)) }.getMessage diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index a01defeba11cd..82a10254d846d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -912,12 +912,12 @@ class FilterPushdownSuite extends PlanTest { val queryPlan = x .join(z) .where(("x.b".attr === "z.b".attr) && - ("x.a".attr > 1 || "z.c".attr.in(w.select("w.d".attr)))) + ("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr))))) .analyze val expectedPlan = x .join(z, Inner, Some("x.b".attr === "z.b".attr)) - .where("x.a".attr > 1 || "z.c".attr.in(w.select("w.d".attr))) + .where("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr)))) .analyze val optimized = Optimize.execute(queryPlan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index e269f67d47e4a..8a5a55146726e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.InSubquery +import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery} 
import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -42,7 +42,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { .select('c) val outerQuery = testRelation - .where(InSubquery(Seq('a), correlatedSubquery)) + .where(InSubquery(Seq('a), ListQuery(correlatedSubquery))) .select('a).analyze assert(outerQuery.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala index b24b02fe5f670..6b3739c372c3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.ListQuery import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -40,7 +41,7 @@ class RewriteSubquerySuite extends PlanTest { val relation = LocalRelation('a.int, 'b.int) val relInSubquery = LocalRelation('x.int, 'y.int, 'z.int) - val query = relation.where('a.in(relInSubquery.select('x))).select('a) + val query = relation.where('a.in(ListQuery(relInSubquery.select('x)))).select('a) val optimized = Optimize.execute(query.analyze) val correctAnswer = relation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index f285b62bc076f..8f92d4c90dc49 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -154,7 +154,7 @@ class ExpressionParserSuite extends PlanTest { test("in sub-query") { assertEqual( "a in (select b from c)", - InSubquery(Seq('a), table("c").select('b))) + InSubquery(Seq('a), ListQuery(table("c").select('b)))) } test("like expressions") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index ae3e81a408dbf..6241d5cbb1d25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -52,8 +52,8 @@ trait PlanTestBase extends PredicateHelper { self: Suite => s.copy(exprId = ExprId(0)) case e: Exists => e.copy(exprId = ExprId(0)) - case i: InSubquery => - i.copy(exprId = ExprId(0)) + case l: ListQuery => + l.copy(exprId = ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => From eb1dfb7e0873b8479ea54d223b7fde3dcefa4834 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 6 Aug 2018 10:11:56 +0200 Subject: [PATCH 17/17] address comment --- .../sql/catalyst/expressions/Canonicalize.scala | 2 -- .../sql/catalyst/optimizer/OptimizeInSuite.scala | 15 +++++++++++++++ .../catalyst/parser/ExpressionParserSuite.scala | 12 
++++++++++++ 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index 7541f527a52a8..fe6db8b344d3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -87,8 +87,6 @@ object Canonicalize { case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) // order the list in the In operator - // In subqueries contain only one element of type ListQuery. So checking that the length > 1 - // we are not reordering In subqueries. case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode())) case _ => e diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 86522a6a54ed5..a36083b847043 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -121,6 +121,21 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("OptimizedIn test: NULL IN (subquery) gets transformed to Filter(null)") { + val subquery = ListQuery(testRelation.select(UnresolvedAttribute("a"))) + val originalQuery = + testRelation + .where(InSubquery(Seq(Literal.create(null, NullType)), subquery)) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(null, BooleanType)) + .analyze + comparePlans(optimized, correctAnswer) + } + test("OptimizedIn test: Inset optimization disabled as " + "list expression contains attribute)") { val originalQuery = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 8f92d4c90dc49..cdd734ab3df37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -155,6 +155,18 @@ class ExpressionParserSuite extends PlanTest { assertEqual( "a in (select b from c)", InSubquery(Seq('a), ListQuery(table("c").select('b)))) + + assertEqual( + "(a, b, c) in (select d, e, f from g)", + InSubquery(Seq('a, 'b, 'c), ListQuery(table("g").select('d, 'e, 'f)))) + + assertEqual( + "(a, b) in (select c from d)", + InSubquery(Seq('a, 'b), ListQuery(table("d").select('c)))) + + assertEqual( + "(a) in (select b from c)", + InSubquery(Seq('a), ListQuery(table("c").select('b)))) } test("like expressions") {