diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1d363b8146e3f..cf17f59599968 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -218,15 +218,24 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral + case In(v, list) if list.isEmpty => + // When v is not nullable, the following expression will be optimized + // to FalseLiteral which is tested in OptimizeInSuite.scala + If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq - if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { + if (newList.length == 1 + // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed, + // TODO: we exclude them in this rule. + && !v.isInstanceOf[CreateNamedStructLike] + && !newList.head.isInstanceOf[CreateNamedStructLike]) { + EqualTo(v, newList.head) + } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) - } else if (newList.size < list.size) { + } else if (newList.length < list.length) { expr.copy(list = newList) - } else { // newList.length == list.length + } else { // newList.length == list.length && newList.length > 1 expr } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 478118ed709f7..86522a6a54ed5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -176,6 +176,21 @@ class OptimizeInSuite extends PlanTest { } } + test("OptimizedIn test: one element in list gets transformed to EqualTo.") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(EqualTo(UnresolvedAttribute("a"), Literal(1))) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + "when value is not nullable") { val originalQuery = @@ -191,4 +206,21 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("OptimizedIn test: In empty list gets transformed to `If` expression " + + "when value is nullable") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(If(IsNotNull(UnresolvedAttribute("a")), + Literal(false), Literal.create(null, BooleanType))) + .analyze + + comparePlans(optimized, correctAnswer) + } }