diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 72ac80e0a0a1..d4d45eb9fa10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -493,7 +493,7 @@ object TypeCoercion { i } - case i @ In(a, b) if b.exists(_.dataType != a.dataType) => + case i @ In(a, b) if b.exists(_.dataType != i.value.dataType) => findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 176ea823b1fc..70785423f513 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -94,7 +94,10 @@ package object dsl { case c: CreateNamedStruct => InSubquery(c.valExprs, l) case other => InSubquery(Seq(other), l) } - case _ => In(expr, list) + case _ => expr match { + case c: CreateNamedStruct => In(c.valExprs, list) + case other => In(Seq(other), list) + } } def like(other: Expression): Expression = Like(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index fe6db8b344d3..527e016d5571 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -87,7 +87,7 @@ object Canonicalize { case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) // order the list in the In operator - case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode())) + case In(values, list) if list.length > 1 => In(values, list.sortBy(_.hashCode())) case _ => e } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2125340f38ee..7a4672731b4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -22,9 +22,11 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.Block import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -138,13 +140,12 @@ case class Not(child: Expression) override def sql: String = s"(NOT ${child.sql})" } -/** - * Evaluates to `true` if `values` are returned in `query`'s result set. 
- */
-case class InSubquery(values: Seq[Expression], query: ListQuery)
-  extends Predicate with Unevaluable {
+trait InBase extends Predicate {
+  def values: Seq[Expression]
+
+  @transient protected lazy val isMultiValued = values.length > 1
 
-  @transient private lazy val value: Expression = if (values.length > 1) {
+  @transient lazy val value: Expression = if (isMultiValued) {
     CreateNamedStruct(values.zipWithIndex.flatMap {
       case (v: NamedExpression, _) => Seq(Literal(v.name), v)
       case (v, idx) => Seq(Literal(s"_$idx"), v)
@@ -153,6 +154,28 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
     values.head
   }
 
+  @transient lazy val checkNullGenCode: (ExprCode) => Block = {
+    if (isMultiValued && !SQLConf.get.inFalseForNullField) {
+      e => code"${e.isNull} || ${e.value}.anyNull()"
+    } else {
+      e => code"${e.isNull}"
+    }
+  }
+
+  @transient lazy val checkNullEval: (Any) => Boolean = {
+    if (isMultiValued && !SQLConf.get.inFalseForNullField) {
+      input => input == null || input.asInstanceOf[InternalRow].anyNull
+    } else {
+      input => input == null
+    }
+  }
+}
+
+/**
+ * Evaluates to `true` if `values` are returned in `query`'s result set.
+ */
+case class InSubquery(values: Seq[Expression], query: ListQuery)
+  extends InBase with Unevaluable {
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (values.length != query.childOutputs.length) {
@@ -202,7 +225,12 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
  */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
-  usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.",
+  usage = """
+    expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr1` equals any exprN. Otherwise, it
+    returns null if `expr1` is a single value and it or any exprN is null, or if `expr1` holds
+    multiple values, spark.sql.legacy.inOperator.falseForNullField is false, and any exprN or
+    any of its fields is null; in all remaining cases it returns false.
+  """,
   arguments = """
     Arguments:
       * expr1, expr2, expr3, ... - the arguments must be same type.
@@ -219,7 +247,7 @@ case class InSubquery(values: Seq[Expression], query: ListQuery) true """) // scalastyle:on line.size.limit -case class In(value: Expression, list: Seq[Expression]) extends Predicate { +case class In(values: Seq[Expression], list: Seq[Expression]) extends InBase { require(list != null, "list should not be null") @@ -234,24 +262,29 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } - override def children: Seq[Expression] = value +: list + override def children: Seq[Expression] = values ++ list lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal]) private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType) - override def nullable: Boolean = children.exists(_.nullable) + override def nullable: Boolean = if (isMultiValued && !SQLConf.get.inFalseForNullField) { + children.exists(_.nullable) || + list.exists(_.dataType.asInstanceOf[StructType].exists(_.nullable)) + } else { + value.nullable || list.exists(_.nullable) + } override def foldable: Boolean = children.forall(_.foldable) override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { val evaluatedValue = value.eval(input) - if (evaluatedValue == null) { + if (checkNullEval(evaluatedValue)) { null } else { var hasNull = false list.foreach { e => val v = e.eval(input) - if (v == null) { + if (checkNullEval(v)) { hasNull = true } else if (ordering.equiv(v, evaluatedValue)) { return true @@ -283,7 +316,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val listCode = listGen.map(x => s""" |${x.code} - |if (${x.isNull}) { + |if (${checkNullGenCode(x)}) { | $tmpResult = $HAS_NULL; // ${ev.isNull} = true; |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { | $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true; @@ -316,7 +349,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { code""" |${valueGen.code} |byte $tmpResult = $HAS_NULL; - |if (!${valueGen.isNull}) { + |if (!(${checkNullGenCode(valueGen)})) { | $tmpResult = $NOT_MATCHED; | $javaDataType $valueArg = ${valueGen.value}; | do { @@ -339,37 +372,57 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { * Optimized version of In clause, when all filter values of In clause are * static. 
*/ -case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate { +case class InSet(values: Seq[Expression], hset: Set[Any]) extends InBase { require(hset != null, "hset could not be null") - override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" + override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}" - @transient private[this] lazy val hasNull: Boolean = hset.contains(null) + override def children: Seq[Expression] = values - override def nullable: Boolean = child.nullable || hasNull + @transient private[this] lazy val hasNull: Boolean = { + if (isMultiValued && !SQLConf.get.inFalseForNullField) { + hset.exists(checkNullEval) + } else { + hset.contains(null) + } + } - protected override def nullSafeEval(value: Any): Any = { - if (set.contains(value)) { - true - } else if (hasNull) { + override def nullable: Boolean = { + val isValueNullable = if (isMultiValued && !SQLConf.get.inFalseForNullField) { + values.exists(_.nullable) + } else { + value.nullable + } + isValueNullable || hasNull + } + + override def eval(input: InternalRow): Any = { + val inputValue = value.eval(input) + if (checkNullEval(inputValue)) { null } else { - false + if (set.contains(inputValue)) { + true + } else if (hasNull) { + null + } else { + false + } } } - @transient lazy val set: Set[Any] = child.dataType match { + @transient lazy val set: Set[Any] = value.dataType match { case _: AtomicType => hset case _: NullType => hset case _ => // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows - TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset + TreeSet.empty(TypeUtils.getInterpretedOrdering(value.dataType)) ++ hset } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val setTerm = ctx.addReferenceObj("set", set) - val childGen = child.genCode(ctx) + val childGen = value.genCode(ctx) val setIsNull = if (hasNull) { s"${ev.isNull} = !${ev.value};" } else { @@ -378,7 +431,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with ev.copy(code = code""" |${childGen.code} - |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${checkNullGenCode(childGen)}; |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; |if (!${ev.isNull}) { | ${ev.value} = $setTerm.contains(${childGen.value}); @@ -388,7 +441,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } override def sql: String = { - val valueSQL = child.sql + val valueSQL = value.sql val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ") s"($valueSQL IN ($listSQL))" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index f8037588fa71..77c4de2e08f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -212,27 +212,33 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { * 1. Converts the predicate to false when the list is empty and * the value is not nullable. * 2. Removes literal repetitions. - * 3. Replaces [[In (value, seq[Literal])]] with optimized version - * [[InSet (value, HashSet[Literal])]] which is much faster. + * 3. 
Replaces [[In (values, seq[Literal])]] with optimized version
+ *    [[InSet (values, HashSet[Literal])]] which is much faster.
  */
 object OptimizeIn extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsDown {
-      case In(v, list) if list.isEmpty =>
-        // When v is not nullable, the following expression will be optimized
+      case i @ In(values, list) if list.isEmpty =>
+        // When values are not nullable, the following expression will be optimized
         // to FalseLiteral which is tested in OptimizeInSuite.scala
-        If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType))
-      case expr @ In(v, list) if expr.inSetConvertible =>
+        val isNotNull = if (SQLConf.get.inFalseForNullField) {
+          IsNotNull(i.value)
+        } else {
+          values.map(IsNotNull).reduce(And)
+        }
+        If(isNotNull, FalseLiteral, Literal(null, BooleanType))
+      case expr @ In(values, list) if expr.inSetConvertible =>
+        // with more than one value, the EqualTo conversion below is skipped
         val newList = ExpressionSet(list).toSeq
         if (newList.length == 1
           // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
           // TODO: we exclude them in this rule.
-          && !v.isInstanceOf[CreateNamedStructLike]
+          && !expr.value.isInstanceOf[CreateNamedStructLike]
           && !newList.head.isInstanceOf[CreateNamedStructLike]) {
-          EqualTo(v, newList.head)
+          EqualTo(expr.value, newList.head)
         } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
           val hSet = newList.map(e => e.eval(EmptyRow))
-          InSet(v, HashSet() ++ hSet)
+          InSet(values, HashSet() ++ hSet)
         } else if (newList.length < list.length) {
           expr.copy(list = newList)
         } else { // newList.length == list.length && newList.length > 1
@@ -527,7 +533,7 @@ object NullPropagation extends Rule[LogicalPlan] {
       }
 
       // If the value expression is NULL then transform the In expression to null literal.
-      case In(Literal(null, _), _) => Literal.create(null, BooleanType)
+      case In(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType)
       case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType)
 
       // Non-leaf NullIntolerant expressions will return null, if at least one of its children is
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index ba0b72e747fc..7d313c11c43f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1120,7 +1120,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
       case SqlBaseParser.IN if ctx.query != null =>
         invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query))))
       case SqlBaseParser.IN =>
-        invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
+        invertIfNotDefined(In(getValueExpressions(e), ctx.expression.asScala.map(expression)))
       case SqlBaseParser.LIKE =>
         invertIfNotDefined(Like(e, expression(ctx.pattern)))
       case SqlBaseParser.RLIKE =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 5a3eeefaedb1..f302952574ae 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -164,15 +164,14 @@ case class FilterEstimation(plan: Filter) extends Logging {
     case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) =>
       evaluateBinary(LessThanOrEqual(ar, l), ar, l, update)
 
-    case In(ar: Attribute, expList)
-      if expList.forall(e => e.isInstanceOf[Literal]) =>
+    case In(Seq(ar: Attribute), expList) if expList.forall(e => e.isInstanceOf[Literal]) =>
       // Expression [In (value, seq[Literal])] will be replaced with optimized version
       // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10.
       // Here we convert In into InSet anyway, because they share the same processing logic.
       val hSet = expList.map(e => e.eval())
       evaluateInSet(ar, HashSet() ++ hSet, update)
 
-    case InSet(ar: Attribute, set) =>
+    case InSet(Seq(ar: Attribute), set) =>
       evaluateInSet(ar, set, update)
 
     // In current stage, we don't have advanced statistics such as sketches or histograms.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b699707d8523..ca8afe078ac6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1561,6 +1561,16 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val LEGACY_IN_FALSE_FOR_NULL_FIELD =
+    buildConf("spark.sql.legacy.inOperator.falseForNullField")
+      .internal()
+      .doc("When set to true, the IN operator returns false when comparing multiple values " +
+        "containing a null. When set to false (default), it returns null instead. This is " +
+        "especially important for NOT IN: with the default, rows containing a null field are " +
+        "filtered out, while with the legacy behavior those rows are returned.")
+      .booleanConf
+      .createWithDefault(false)
+
   val LEGACY_INTEGRALDIVIDE_RETURN_LONG = buildConf("spark.sql.legacy.integralDivide.returnBigint")
     .doc("If it is set to true, the div operator returns always a bigint. This behavior was " +
       "inherited from Hive. Otherwise, the return type is the data type of the operands.")
@@ -1978,6 +1988,8 @@ class SQLConf extends Serializable with Logging {
 
   def setOpsPrecedenceEnforced: Boolean = getConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED)
 
+  def inFalseForNullField: Boolean = getConf(SQLConf.LEGACY_IN_FALSE_FOR_NULL_FIELD)
+
   def integralDivideReturnLong: Boolean = getConf(SQLConf.LEGACY_INTEGRALDIVIDE_RETURN_LONG)
 
   /** ********************** SQLConf functionality methods ************ */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index f9facbb71a4e..24684ff57f84 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -280,21 +280,22 @@ class AnalysisSuite extends AnalysisTest with Matchers {
   }
 
   test("SPARK-8654: invalid CAST in NULL IN(...) expression") {
-    val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil,
+    val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(2))), "a")() :: Nil,
       LocalRelation()
     )
     assertAnalysisSuccess(plan)
   }
 
   test("SPARK-8654: different types in inlist but can be converted to a common type") {
-    val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil,
+    val plan = Project(
+      Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil,
       LocalRelation()
     )
     assertAnalysisSuccess(plan)
   }
 
   test("SPARK-8654: check type compatibility error") {
-    val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil,
+    val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(true), Literal(1))), "a")() :: Nil,
       LocalRelation()
     )
     assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type"))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 0eba1c537d67..048798d8aef6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -1432,16 +1432,16 @@ class TypeCoercionSuite extends AnalysisTest {
     // InConversion
     val inConversion = TypeCoercion.InConversion(conf)
     ruleTest(inConversion,
-      In(UnresolvedAttribute("a"), Seq(Literal(1))),
-      In(UnresolvedAttribute("a"), Seq(Literal(1)))
+      In(Seq(UnresolvedAttribute("a")), Seq(Literal(1))),
+      In(Seq(UnresolvedAttribute("a")), Seq(Literal(1)))
     )
     ruleTest(inConversion,
-      In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))),
-      In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1)))
+      In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1))),
+      In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1)))
     )
     ruleTest(inConversion,
-      In(Literal("a"), Seq(Literal(1),
Literal("b"))), - In(Cast(Literal("a"), StringType), + In(Seq(Literal("a")), Seq(Literal(1), Literal("b"))), + In(Seq(Cast(Literal("a"), StringType)), Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType))) ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index b376108399c1..51464332dbc8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -477,8 +477,8 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac checkAnswer(tbl2, Seq.empty, Set(part1, part2)) checkAnswer(tbl2, Seq('a.int <= 1), Set(part1)) checkAnswer(tbl2, Seq('a.int === 2), Set.empty) - checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2)) - checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2)) + checkAnswer(tbl2, Seq(In(Seq('a.int * 10), Seq(30))), Set(part2)) + checkAnswer(tbl2, Seq(Not(In(Seq('a.int), Seq(4)))), Set(part1, part2)) checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1)) checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1)) checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 28e6940f3cca..b78d23e3472e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -27,9 +27,9 @@ class CanonicalizeSuite extends SparkFunSuite { val range = Range(1, 1, 1, 1) val idAttr = range.output.head - val in1 = In(idAttr, Seq(Literal(1), Literal(2))) - val in2 = In(idAttr, Seq(Literal(2), Literal(1))) - val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3))) + val in1 = In(Seq(idAttr), Seq(Literal(1), Literal(2))) + val in2 = In(Seq(idAttr), Seq(Literal(2), Literal(1))) + val in3 = In(Seq(idAttr), Seq(Literal(1), Literal(2), Literal(3))) assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash()) assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash()) @@ -37,11 +37,11 @@ class CanonicalizeSuite extends SparkFunSuite { assert(range.where(in1).sameResult(range.where(in2))) assert(!range.where(in1).sameResult(range.where(in3))) - val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + val arrays1 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))), CreateArray(Seq(Literal(2), Literal(1))))) - val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))), + val arrays2 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(2), Literal(1))), CreateArray(Seq(Literal(1), Literal(2))))) - val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + val arrays3 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))), CreateArray(Seq(Literal(3), Literal(1))))) assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index ac76b17ef476..09e91d36d90b 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -22,11 +22,12 @@ import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -125,32 +126,34 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, null, null) :: Nil) test("basic IN predicate test") { - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1), + checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq(Literal(1), Literal(2))), null) - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), + checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null) - checkEvaluation(In(Literal(1), Seq.empty), false) - checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq.empty), null) + checkEvaluation(In(Seq(Literal(1)), Seq.empty), false) + checkEvaluation(In(Seq(Literal(1)), Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkEvaluation( + In(Seq(Literal(1)), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + checkEvaluation( + In(Seq(Literal(2)), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkEvaluation(In(Seq(Literal(1)), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Seq(Literal(2)), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Seq(Literal(3)), Seq(Literal(1), Literal(2))), false) checkEvaluation( - And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), + And(In(Seq(Literal(1)), Seq(Literal(1), Literal(2))), In(Seq(Literal(2)), Seq(Literal(1), Literal(2)))), true) val ns = NonFoldableLiteral.create(null, StringType) - checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null) - checkEvaluation(In(ns, Seq(ns)), null) - checkEvaluation(In(Literal("a"), Seq(ns)), null) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) + checkEvaluation(In(Seq(ns), Seq(Literal("1"), Literal("2"))), null) + checkEvaluation(In(Seq(ns), Seq(ns)), null) + checkEvaluation(In(Seq(Literal("a")), Seq(ns)), null) + checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("^Ba*n"), ns)), true) + 
checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("aa"), Literal("^Ba*n"))), true) + checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("aa"), Literal("^n"))), false) } @@ -178,7 +181,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } } val input = inputData.map(NonFoldableLiteral.create(_, dataType)) - val expected = if (inputData(0) == null) { + val expected = if (inputData.head == null) { null } else if (inputData.slice(1, 10).contains(inputData(0))) { true @@ -187,7 +190,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } else { false } - checkEvaluation(In(input(0), input.slice(1, 10)), expected) + checkEvaluation(In(Seq(input(0)), input.slice(1, 10)), expected) } val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t => @@ -212,9 +215,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } // Struct types: + val atomicAndComplexTypes = atomicTypes ++ atomicTypes.map { t => + StructType(StructField("f1", t) :: StructField("f2", t) :: Nil) + } for ( - colOneType <- atomicTypes; - colTwoType <- atomicTypes; + colOneType <- atomicAndComplexTypes; + colTwoType <- atomicAndComplexTypes; nullable <- Seq(true, false)) { val structType = StructType( StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil) @@ -243,12 +249,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22501: In should not generate codes beyond 64KB") { val N = 3000 val sets = (1 to N).map(i => Literal(i.toDouble)) - checkEvaluation(In(Literal(1.0D), sets), true) + checkEvaluation(In(Seq(Literal(1.0D)), sets), true) } test("SPARK-22705: In should use less global variables") { val ctx = new CodegenContext() - In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) + In(Seq(Literal(1.0D)), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } @@ -259,13 +265,13 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { val two = Literal(2) val three = Literal(3) val nl = Literal(null) - checkEvaluation(InSet(one, hS), true) - checkEvaluation(InSet(two, hS), true) - checkEvaluation(InSet(two, nS), true) - checkEvaluation(InSet(three, hS), false) - checkEvaluation(InSet(three, nS), null) - checkEvaluation(InSet(nl, hS), null) - checkEvaluation(InSet(nl, nS), null) + checkEvaluation(InSet(Seq(one), hS), true) + checkEvaluation(InSet(Seq(two), hS), true) + checkEvaluation(InSet(Seq(two), nS), true) + checkEvaluation(InSet(Seq(three), hS), false) + checkEvaluation(InSet(Seq(three), nS), null) + checkEvaluation(InSet(Seq(nl), hS), null) + checkEvaluation(InSet(Seq(nl), nS), null) val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) @@ -289,7 +295,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } else { false } - checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), expected) + checkEvaluation(InSet(Seq(input(0)), inputData.slice(1, 10).toSet), expected) } } @@ -439,7 +445,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22693: InSet should not use global variables") { val ctx = new CodegenContext - InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) + InSet(Seq(Literal(1)), Set(1, 2, 3, 4)).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 641c89873dcc..09d8f2e4d1ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -251,7 +251,7 @@ class ConstantFoldingSuite extends PlanTest { val originalQuery = testRelation .select('a) - .where(In(Literal(1), Seq(Literal(1), Literal(2)))) + .where(In(Seq(Literal(1)), Seq(Literal(1), Literal(2)))) val optimized = Optimize.execute(originalQuery.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index a36083b84704..65cf2db6e684 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -24,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD import org.apache.spark.sql.types._ @@ -45,9 +47,9 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: Remove deterministic repetitions") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(1), Literal(2), Literal(2), Literal(1), Literal(2)))) - .where(In(UnresolvedAttribute("b"), + .where(In(Seq(UnresolvedAttribute("b")), Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"), Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0), Rand(0), Rand(0)))) @@ -56,8 +58,8 @@ class OptimizeInSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) - .where(In(UnresolvedAttribute("b"), + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2)))) + .where(In(Seq(UnresolvedAttribute("b")), Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"), Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0), Rand(0), Rand(0)))) @@ -69,7 +71,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -79,13 +81,13 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause optimized to InSet when more than 10 items") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_)))) + .where(In(Seq(UnresolvedAttribute("a")), (1 to 11).map(Literal(_)))) .analyze 
val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet)) + .where(InSet(Seq(UnresolvedAttribute("a")), (1 to 11).toSet)) .analyze comparePlans(optimized, correctAnswer) @@ -94,13 +96,15 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized in case filter has attributes") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) + .where(In( + Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) + .where(In( + Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze comparePlans(optimized, correctAnswer) @@ -109,7 +113,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: NULL IN (expr1, ..., exprN) gets transformed to Filter(null)") { val originalQuery = testRelation - .where(In(Literal.create(null, NullType), Seq(Literal(1), Literal(2)))) + .where(In(Seq(Literal.create(null, NullType)), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -136,11 +140,55 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("OptimizedIn test: (notNull, NULL) IN (expr1, ..., exprN) transformed to Filter(null)") { + Seq(("true", false), ("false", null)).foreach { case (legacyInFalseForNull, expected) => + withSQLConf(SQLConf.LEGACY_IN_FALSE_FOR_NULL_FIELD.key -> legacyInFalseForNull) { + val originalQuery = + testRelation + .where(In(Seq(Literal.create(null, IntegerType), Literal(1)), + Seq(Literal.create( + InternalRow.fromSeq(Seq(1, 2)), + StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType))))))) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(expected, BooleanType)) + .analyze + + comparePlans(optimized, correctAnswer) + } + } + } + + test("OptimizedIn test: (notNull, NULL) IN () transformed to IsNotNull filters") { + val legacyExpected = If( + Literal(true, BooleanType), Literal(false, BooleanType), Literal(null, BooleanType)) + val newExpected = If( + Literal(false, BooleanType), Literal(false, BooleanType), Literal(null, BooleanType)) + Seq(("true", legacyExpected), ("false", newExpected)).foreach { + case (legacyInFalseForNull, expected) => + withSQLConf(SQLConf.LEGACY_IN_FALSE_FOR_NULL_FIELD.key -> legacyInFalseForNull) { + val originalQuery = + testRelation + .where(In( + Seq(Literal.create(null, IntegerType), UnresolvedAttribute("b")), Seq.empty)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(expected) + .analyze + + comparePlans(optimized, correctAnswer) + } + } + } + test("OptimizedIn test: Inset optimization disabled as " + "list expression contains attribute)") { val originalQuery = testRelation - .where(In(Literal.create(null, StringType), Seq(Literal(1), UnresolvedAttribute("b")))) + .where(In( + Seq(Literal.create(null, StringType)), Seq(Literal(1), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -156,7 +204,7 @@ class OptimizeInSuite extends PlanTest { "list expression contains attribute - select)") { val originalQuery = testRelation - 
.select(In(Literal.create(null, StringType), + .select(In(Seq(Literal.create(null, StringType)), Seq(Literal(1), UnresolvedAttribute("b"))).as("a")).analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -171,7 +219,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: Setting the threshold for turning Set into InSet.") { val plan = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), Literal(3)))) + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), Literal(3)))) .analyze withSQLConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "10") { @@ -194,7 +242,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: one element in list gets transformed to EqualTo.") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1)))) .analyze val optimized = Optimize.execute(originalQuery) @@ -210,7 +258,7 @@ class OptimizeInSuite extends PlanTest { "when value is not nullable") { val originalQuery = testRelation - .where(In(Literal("a"), Nil)) + .where(In(Seq(Literal("a")), Nil)) .analyze val optimized = Optimize.execute(originalQuery) @@ -226,7 +274,7 @@ class OptimizeInSuite extends PlanTest { "when value is nullable") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Nil)) + .where(In(Seq(UnresolvedAttribute("a")), Nil)) .analyze val optimized = Optimize.execute(originalQuery) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 47bfa6256958..06d36716f022 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -351,7 +351,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint IN (3, 4, 5)") { validateEstimatedStats( - Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), + Filter(InSet(Seq(attrInt), Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(3), max = Some(5), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) @@ -359,7 +359,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("evaluateInSet with all zeros") { validateEstimatedStats( - Filter(InSet(attrString, Set(3, 4, 5)), + Filter(InSet(Seq(attrString), Set(3, 4, 5)), StatsTestPlan(Seq(attrString), 0, AttributeMap(Seq(attrString -> ColumnStat(distinctCount = Some(0), min = None, max = None, @@ -370,7 +370,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("evaluateInSet with string") { validateEstimatedStats( - Filter(InSet(attrString, Set("A0")), + Filter(InSet(Seq(attrString), Set("A0")), StatsTestPlan(Seq(attrString), 10, AttributeMap(Seq(attrString -> ColumnStat(distinctCount = Some(10), min = None, max = None, @@ -382,14 +382,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( - Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), + Filter(Not(InSet(Seq(attrInt), Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), Seq(attrInt -> colStatInt.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cbool IN (true)") { 
validateEstimatedStats( - Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), + Filter(InSet(Seq(attrBool), Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) @@ -440,7 +440,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04")) val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05")) validateEstimatedStats( - Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), + Filter(In(Seq(attrDate), Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = Some(3), min = Some(d20170103), max = Some(d20170105), @@ -509,7 +509,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap(Seq(attrInt -> cornerChildColStatInt)) ) validateEstimatedStats( - Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), + Filter(InSet(Seq(attrInt), Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), Seq(attrInt -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(5), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ae27690f2e5b..2e2f28fcd5be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -793,7 +793,13 @@ class Column(val expr: Expression) extends Logging { * @since 1.5.0 */ @scala.annotation.varargs - def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } + def isin(list: Any*): Column = withExpr { + val inValues = expr match { + case c: CreateNamedStruct => c.valExprs + case other => Seq(other) + } + In(inValues, list.map(lit(_).expr)) + } /** * A boolean expression that is evaluated to true if the value of this expression is contained diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 196d057c2de1..efb772c6e123 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -235,7 +235,7 @@ case class InMemoryTableScanExec( case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 - case In(a: AttributeReference, list: Seq[Expression]) + case In(Seq(a: AttributeReference), list: Seq[Expression]) if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c6000442fae7..13727b6cec6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -468,14 +468,14 @@ object DataSourceStrategy { case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) => Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) - case expressions.InSet(a: Attribute, set) => + case expressions.InSet(Seq(a: Attribute), set) => val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) Some(sources.In(a.name, set.toArray.map(toScala))) // Because we only convert In to InSet in Optimizer when there are more than certain // items. So it is possible we still get an In expression here that needs to be pushed // down. - case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) => + case expressions.In(Seq(a: Attribute), list) if !list.exists(!_.isInstanceOf[Literal]) => val hSet = list.map(e => e.eval(EmptyRow)) val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) Some(sources.In(a.name, hSet.toArray.map(toScala))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index fe27b78bf336..4812fa97464c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -86,10 +86,10 @@ object FileSourceStrategy extends Strategy with Logging { expr match { case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => getBucketSetFromValue(a, v) - case expressions.In(a: Attribute, list) + case expressions.In(Seq(a: Attribute), list) if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow))) - case expressions.InSet(a: Attribute, hset) + case expressions.InSet(Seq(a: Attribute), hset) if hset.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => getBucketSetFromIterable(a, hset.map(e => expressions.Literal(e).eval(EmptyRow))) case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql index 8eea84f4f527..9e01c950b9ef 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql @@ -14,8 +14,27 @@ CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES -- Case 1 (not possible to write a literal with no rows, so we ignore it.) -- (subquery is empty -> row is returned) --- Cases 2, 3 and 4 are currently broken, so I have commented them out here. --- Filed https://issues.apache.org/jira/browse/SPARK-24395 to fix and restore these test cases. 
+set spark.sql.legacy.inOperator.falseForNullField=false; + + -- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN ((CAST (null AS INT), CAST (null AS DECIMAL(2, 1)))); + + -- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))); + + -- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))); -- Case 5 -- (one null column with no match -> row is returned) @@ -37,3 +56,26 @@ SELECT * FROM m WHERE b = 5.0 -- Matches (4, 5.0) AND (a, b) NOT IN ((2, 3.0)); + + +set spark.sql.legacy.inOperator.falseForNullField=true; + + -- Case 2 (old behavior) + -- (subquery contains a row with null in all columns -> rows returned) +SELECT * +FROM m +WHERE (a, b) NOT IN ((CAST (null AS INT), CAST (null AS DECIMAL(2, 1)))); + + -- Case 3 (old behavior) + -- (probe-side columns are all null -> row returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))); + + -- Case 4 (old behavior) + -- (one column null, other column matches a row in the subquery result -> row returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))); diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out index a16e98af9a41..ef6815096aaf 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 12 -- !query 0 @@ -16,39 +16,133 @@ struct<> -- !query 1 +set spark.sql.legacy.inOperator.falseForNullField=false +-- !query 1 schema +struct +-- !query 1 output +spark.sql.legacy.inOperator.falseForNullField false + + +-- !query 2 +-- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN ((CAST (null AS INT), CAST (null AS DECIMAL(2, 1)))) +-- !query 2 schema +struct +-- !query 2 output + + + +-- !query 3 +-- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))) +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 -- Case 5 -- (one null column with no match -> row is returned) SELECT * FROM m WHERE b = 1.0 -- Matches (null, 1.0) AND (a, b) NOT IN ((2, 3.0)) --- !query 1 schema +-- !query 5 schema struct --- !query 1 output -NULL 1 +-- !query 
5 output --- !query 2 + +-- !query 6 -- Case 6 -- (no null columns with match -> row not returned) SELECT * FROM m WHERE b = 3.0 -- Matches (2, 3.0) AND (a, b) NOT IN ((2, 3.0)) --- !query 2 schema +-- !query 6 schema struct --- !query 2 output +-- !query 6 output --- !query 3 +-- !query 7 -- Case 7 -- (no null columns with no match -> row is returned) SELECT * FROM m WHERE b = 5.0 -- Matches (4, 5.0) AND (a, b) NOT IN ((2, 3.0)) --- !query 3 schema +-- !query 7 schema struct --- !query 3 output +-- !query 7 output +4 5 + + +-- !query 8 +set spark.sql.legacy.inOperator.falseForNullField=true +-- !query 8 schema +struct +-- !query 8 output +spark.sql.legacy.inOperator.falseForNullField true + + +-- !query 9 +-- Case 2 (old behavior) + -- (subquery contains a row with null in all columns -> rows returned) +SELECT * +FROM m +WHERE (a, b) NOT IN ((CAST (null AS INT), CAST (null AS DECIMAL(2, 1)))) +-- !query 9 schema +struct +-- !query 9 output +2 3 4 5 +NULL 1 + + +-- !query 10 +-- Case 3 (old behavior) + -- (probe-side columns are all null -> row returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))) +-- !query 10 schema +struct +-- !query 10 output +NULL NULL + + +-- !query 11 +-- Case 4 (old behavior) + -- (one column null, other column matches a row in the subquery result -> row returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))) +-- !query 11 schema +struct +-- !query 11 output +NULL 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index efc2f20a907f..0610f51b721a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -459,7 +459,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, LocalRelation(Seq(attribute), Nil)) val tableScanExec = InMemoryTableScanExec(Seq(attribute), - Seq(In(attribute, Nil)), testRelation) + Seq(In(Seq(attribute), Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index f20aded169e4..1c9e84eefe47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -49,9 +49,10 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext { testTranslateFilter(LessThanOrEqual(attrInt, 1), Some(sources.LessThanOrEqual("cint", 1))) testTranslateFilter(LessThanOrEqual(1, attrInt), Some(sources.GreaterThanOrEqual("cint", 1))) - testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) + testTranslateFilter( + InSet(Seq(attrInt), Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) - testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) + testTranslateFilter(In(Seq(attrInt), Seq(1, 2, 3)), 
Some(sources.In("cint", Array(1, 2, 3)))) testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint"))) testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull("cint"))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index a2bc651bb2bd..90ead5de5a0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -171,7 +171,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { df) // Case 4: InSet - val inSetExpr = expressions.InSet($"j".expr, + val inSetExpr = expressions.InSet(Seq($"j".expr), Set(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3).map(lit(_).expr)) checkPrunedAnswers( bucketSpec, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index bc9d4cd7f418..daa242949e8b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -685,11 +685,11 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } def convert(expr: Expression): Option[String] = expr match { - case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values)) + case In(Seq(ExtractAttribute(NonVarcharAttribute(name))), ExtractableLiterals(values)) if useAdvanced => Some(convertInToOr(name, values)) - case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values)) + case InSet(Seq(ExtractAttribute(NonVarcharAttribute(name))), ExtractableValues(values)) if useAdvanced => Some(convertInToOr(name, values))
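
---

For reviewers: below is a minimal, illustrative sketch of the semantic change, not part of the patch. It assumes a spark-shell session (`spark`) with this change applied, and builds a small view like the one in not-in-unit-tests-multi-column-literal.sql above; the expected results follow the golden file updated in this diff.

// Build a test view similar to the one used by the SQL tests above.
spark.sql("""
  CREATE OR REPLACE TEMPORARY VIEW m AS SELECT * FROM VALUES
    (CAST(null AS INT), CAST(1.0 AS DECIMAL(2, 1))),
    (2, CAST(3.0 AS DECIMAL(2, 1))),
    (4, CAST(5.0 AS DECIMAL(2, 1)))
  AS m(a, b)""")

// New default: a null field makes the multi-column IN evaluate to null, so
// NOT IN is also null and the row (null, 1.0) is filtered out.
spark.conf.set("spark.sql.legacy.inOperator.falseForNullField", "false")
spark.sql("SELECT * FROM m WHERE (a, b) NOT IN ((2, 3.0))").show()
// expected: only (4, 5.0)

// Legacy behavior: the same IN evaluates to false instead, so NOT IN is
// true and the row with the null field is kept.
spark.conf.set("spark.sql.legacy.inOperator.falseForNullField", "true")
spark.sql("SELECT * FROM m WHERE (a, b) NOT IN ((2, 3.0))").show()
// expected: (null, 1.0) and (4, 5.0)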