From 54ee21ad903827baf1117356f692370225c8662a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 7 Aug 2018 17:52:17 +0200 Subject: [PATCH 01/10] [SPARK-24395][SQL] IN operator should return NULL when comparing struct with NULL fields --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/expressions/predicates.scala | 29 ++++- .../apache/spark/sql/internal/SQLConf.scala | 12 ++ ...not-in-unit-tests-multi-column-literal.sql | 44 +++++++- ...in-unit-tests-multi-column-literal.sql.out | 106 ++++++++++++++++-- 5 files changed, 175 insertions(+), 17 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a1e019cbec4d..7dbbbaf90113 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1876,6 +1876,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 + - In version 2.3 and earlier, the IN operator returns `false` when comparing structs with null fields; since 2.4, by default Spark returns `null` in this scenario in compliance with other RDBMS behavior (therefore NOT IN filters out the rows). The previous behavior can be restored by switching `spark.sql.legacy.inOperator.falseForNullField` to `true`. - Since Spark 2.4, Spark will evaluate the set operations referenced in a query by following a precedence rule as per the SQL standard. If the order is not specified by parentheses, set operations are performed from left to right with the exception that all INTERSECT operations are performed before any UNION, EXCEPT or MINUS operations. The old behaviour of giving equal precedence to all the set operations are preserved under a newly added configuaration `spark.sql.legacy.setopsPrecedence.enabled` with a default value of `false`. When this property is set to `true`, spark will evaluate the set operators from left to right as they appear in the query given no explicit ordering is enforced by usage of parenthesis. - Since Spark 2.4, Spark will display table description column Last Access value as UNKNOWN when the value was Jan 01 1970. - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively.
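The semantics described in the upgrade note above can be seen in the multi-column NOT IN tests added later in this patch; the following condensed sketch is illustrative only and assumes the two-column temporary view `m` created in that test file:

-- With the flag set to false, a null field in the compared structs makes IN evaluate to NULL,
-- so the NOT IN predicate filters every row out.
set spark.sql.legacy.inOperator.falseForNullField=false;
SELECT * FROM m WHERE (a, b) NOT IN ((CAST (null AS INT), CAST (null AS DECIMAL(2, 1))));
-- no rows are returned

-- With the flag set to true, IN evaluates to false instead of NULL for the non-matching structs,
-- so NOT IN keeps those rows (the 2.3 behavior).
set spark.sql.legacy.inOperator.falseForNullField=true;
SELECT * FROM m WHERE (a, b) NOT IN ((CAST (null AS INT), CAST (null AS DECIMAL(2, 1))));
-- rows whose (a, b) struct does not compare equal to the literal are returned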
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 149bd79278a5..2c21342c3abe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -22,9 +22,11 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.Block import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -240,20 +242,25 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal]) private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType) - override def nullable: Boolean = children.exists(_.nullable) + override def nullable: Boolean = value.dataType match { + case _: StructType if !SQLConf.get.inFalseForNullField => + children.exists(_.nullable) || + children.exists(_.dataType.asInstanceOf[StructType].exists(_.nullable)) + case _ => children.exists(_.nullable) + } override def foldable: Boolean = children.forall(_.foldable) override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { val evaluatedValue = value.eval(input) - if (evaluatedValue == null) { + if (checkNullEval(evaluatedValue)) { null } else { var hasNull = false list.foreach { e => val v = e.eval(input) - if (v == null) { + if (checkNullEval(v)) { hasNull = true } else if (ordering.equiv(v, evaluatedValue)) { return true @@ -267,6 +274,18 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } + @transient lazy val checkNullGenCode: (ExprCode) => Block = value.dataType match { + case _: StructType if !SQLConf.get.inFalseForNullField => + e => code"${e.isNull} || ${e.value}.anyNull()" + case _ => e => code"${e.isNull}" + } + + @transient lazy val checkNullEval: (Any) => Boolean = value.dataType match { + case _: StructType if !SQLConf.get.inFalseForNullField => + input => input == null || input.asInstanceOf[InternalRow].anyNull + case _ => input => input == null + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaDataType = CodeGenerator.javaType(value.dataType) val valueGen = value.genCode(ctx) @@ -285,7 +304,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val listCode = listGen.map(x => s""" |${x.code} - |if (${x.isNull}) { + |if (${checkNullGenCode(x)}) { | $tmpResult = $HAS_NULL; // ${ev.isNull} = true; |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { | $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true; @@ -318,7 +337,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { code""" |${valueGen.code} |byte $tmpResult = $HAS_NULL; - |if (!${valueGen.isNull}) { + |if (!(${checkNullGenCode(valueGen)})) { | $tmpResult = $NOT_MATCHED; 
| $javaDataType $valueArg = ${valueGen.value}; | do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 67c3abb80c2c..672028286dc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1476,6 +1476,16 @@ object SQLConf { "are performed before any UNION, EXCEPT and MINUS operations.") .booleanConf .createWithDefault(false) + + val LEGACY_IN_FALSE_FOR_NULL_FIELD = + buildConf("spark.sql.legacy.inOperator.falseForNullField") + .internal() + .doc("When set to true, the IN operator returns false when comparing literal structs " + + "containing a null field. When set to false (default), it returns null, instead. This is " + + "important especially when using NOT IN as in the second case, it filters out the rows " + + "when a null is present in a field; while in the first one, those rows are returned.") + .booleanConf + .createWithDefault(false) } /** @@ -1873,6 +1883,8 @@ class SQLConf extends Serializable with Logging { def setOpsPrecedenceEnforced: Boolean = getConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED) + def inFalseForNullField: Boolean = getConf(SQLConf.LEGACY_IN_FALSE_FOR_NULL_FIELD) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql index 8eea84f4f527..bf2fbcaff00a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql @@ -14,8 +14,25 @@ CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES -- Case 1 (not possible to write a literal with no rows, so we ignore it.) -- (subquery is empty -> row is returned) --- Cases 2, 3 and 4 are currently broken, so I have commented them out here. --- Filed https://issues.apache.org/jira/browse/SPARK-24395 to fix and restore these test cases.
+ -- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN ((CAST (null AS INT), CAST (null AS DECIMAL(2, 1)))); + + -- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))); + + -- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))); -- Case 5 -- (one null column with no match -> row is returned) @@ -37,3 +54,26 @@ SELECT * FROM m WHERE b = 5.0 -- Matches (4, 5.0) AND (a, b) NOT IN ((2, 3.0)); + + +set spark.sql.legacy.inOperator.falseForNullField=true; + + -- Case 2 (old behavior) + -- (subquery contains a row with null in all columns -> rows returned) +SELECT * +FROM m +WHERE (a, b) NOT IN ((CAST (null AS INT), CAST (null AS DECIMAL(2, 1)))); + + -- Case 3 (old behavior) + -- (probe-side columns are all null -> row returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))); + + -- Case 4 (old behavior) + -- (one column null, other column matches a row in the subquery result -> row returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))); diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out index a16e98af9a41..67e5dcd0a0b4 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 11 -- !query 0 @@ -16,39 +16,125 @@ struct<> -- !query 1 +-- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN ((CAST (null AS INT), CAST (null AS DECIMAL(2, 1)))) +-- !query 1 schema +struct +-- !query 1 output + + + +-- !query 2 +-- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))) +-- !query 2 schema +struct +-- !query 2 output + + + +-- !query 3 +-- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 -- Case 5 -- (one null column with no match -> row is returned) SELECT * FROM m WHERE b = 1.0 -- Matches (null, 1.0) AND (a, b) NOT IN ((2, 3.0)) --- !query 1 schema +-- !query 4 schema struct --- !query 1 output -NULL 1 +-- !query 4 output --- !query 2 + +-- !query 5 -- Case 6 -- (no null columns with match -> row not returned) SELECT * FROM m WHERE b = 3.0 -- Matches (2, 3.0) AND (a, b) NOT IN ((2, 3.0)) --- !query 2 schema +-- !query 5 schema struct --- !query 
2 output +-- !query 5 output --- !query 3 +-- !query 6 -- Case 7 -- (no null columns with no match -> row is returned) SELECT * FROM m WHERE b = 5.0 -- Matches (4, 5.0) AND (a, b) NOT IN ((2, 3.0)) --- !query 3 schema +-- !query 6 schema struct --- !query 3 output +-- !query 6 output 4 5 + + +-- !query 7 +set spark.sql.legacy.inOperator.falseForNullField=true +-- !query 7 schema +struct +-- !query 7 output +spark.sql.legacy.inOperator.falseForNullField true + + +-- !query 8 +-- Case 2 (old behavior) + -- (subquery contains a row with null in all columns -> rows returned) +SELECT * +FROM m +WHERE (a, b) NOT IN ((CAST (null AS INT), CAST (null AS DECIMAL(2, 1)))) +-- !query 8 schema +struct +-- !query 8 output +2 3 +4 5 +NULL 1 + + +-- !query 9 +-- Case 3 (old behavior) + -- (probe-side columns are all null -> row returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))) +-- !query 9 schema +struct +-- !query 9 output +NULL NULL + + +-- !query 10 +-- Case 4 (old behavior) + -- (one column null, other column matches a row in the subquery result -> row returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))) +-- !query 10 schema +struct +-- !query 10 output +NULL 1 From 5974aeafcaf2500b8c02c122eb8d71911f8de3fd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 8 Aug 2018 12:51:45 +0200 Subject: [PATCH 02/10] fix ut --- .../catalyst/expressions/PredicateSuite.scala | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index ac76b17ef476..1bf45ce94f12 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -22,11 +22,12 @@ import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -155,7 +156,16 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } test("IN with different types") { - def testWithRandomDataGeneration(dataType: DataType, nullable: Boolean): Unit = { + def testWithRandomDataGeneration(dataType: DataType, + nullable: Boolean, + legacyNullHandling: Boolean = false): Unit = { + def isNull(e: Any): Boolean = { + if (!legacyNullHandling && dataType.isInstanceOf[StructType]) { + e == null || e.asInstanceOf[Row].anyNull + } else { + e == null + } + } val maybeDataGen = RandomDataGenerator.forType(dataType, nullable = nullable) // Actually we won't pass in unsupported data types, this is a safety check. 
val dataGen = maybeDataGen.getOrElse( @@ -178,11 +188,11 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } } val input = inputData.map(NonFoldableLiteral.create(_, dataType)) - val expected = if (inputData(0) == null) { + val expected = if (isNull(inputData.head)) { null } else if (inputData.slice(1, 10).contains(inputData(0))) { true - } else if (inputData.slice(1, 10).contains(null)) { + } else if (inputData.slice(1, 10).exists(isNull)) { null } else { false @@ -218,7 +228,13 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { nullable <- Seq(true, false)) { val structType = StructType( StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil) - testWithRandomDataGeneration(structType, nullable) + if (nullable) { + Seq("true", "false").foreach { legacyNullHandling => + withSQLConf((SQLConf.LEGACY_IN_FALSE_FOR_NULL_FIELD.key, legacyNullHandling)) { + testWithRandomDataGeneration(structType, nullable, legacyNullHandling.toBoolean) + } + } + } } // Map types: not supported From 9b2884298784484f7e42f3ffc913fff270b8c641 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 9 Aug 2018 09:59:18 +0200 Subject: [PATCH 03/10] add tests for nested struct --- .../spark/sql/catalyst/expressions/PredicateSuite.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 1bf45ce94f12..4b2f5e1254e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -222,9 +222,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } // Struct types: + val atomicAndComplexTypes = atomicTypes ++ atomicTypes.map { t => + StructType(StructField("f1", t) :: StructField("f2", t) :: Nil) + } for ( - colOneType <- atomicTypes; - colTwoType <- atomicTypes; + colOneType <- atomicAndComplexTypes; + colTwoType <- atomicAndComplexTypes; nullable <- Seq(true, false)) { val structType = StructType( StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil) From 655eaa4a439171beb341f5b1371e36e311594adf Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 9 Aug 2018 10:18:41 +0200 Subject: [PATCH 04/10] update IN doc --- .../apache/spark/sql/catalyst/expressions/predicates.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2c21342c3abe..df5d34e41756 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -206,7 +206,11 @@ case class InSubquery(values: Seq[Expression], query: ListQuery) */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.", + usage = """ + expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals any valN. Otherwise, if + spark.sql.legacy.inOperator.falseForNullField is false and any of the elements, or any field of + the elements, is null, it returns null; else it returns false.
+ """, arguments = """ Arguments: * expr1, expr2, expr3, ... - the arguments must be same type. From cf6b3b005dcd333d1853c0a54de7a03ce6acbd34 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 8 Oct 2018 12:41:51 +0200 Subject: [PATCH 05/10] change default according to comment --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5018c23e81ba..1e51906517d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1569,7 +1569,7 @@ object SQLConf { "important especially when using NOT IN as in the second case, it filters out the rows " + "when a null is present in a filed; while in the first one, those rows are returned.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val LEGACY_INTEGRALDIVIDE_RETURN_LONG = buildConf("spark.sql.legacy.integralDivide.returnBigint") .doc("If it is set to true, the div operator returns always a bigint. This behavior was " + From 59ff46bb1daf49c03343ba7074d0d9798623a89b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 9 Oct 2018 10:38:15 +0200 Subject: [PATCH 06/10] fix for new default value --- docs/sql-programming-guide.md | 1 - .../apache/spark/sql/internal/SQLConf.scala | 4 +- ...not-in-unit-tests-multi-column-literal.sql | 2 + ...in-unit-tests-multi-column-literal.sql.out | 68 +++++++++++-------- 4 files changed, 42 insertions(+), 33 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ed73c5dfce91..a1d7b1108bf7 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1892,7 +1892,6 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 - - In version 2.3 and earlier, the IN operator returns `false` when comparing structs with null fields; since 2.4, by default Spark returns `null` in this scenario in compliance to other RDBMS behavior (therefore NOT IN filters out the rows). The previous behavior can be restored switching `spark.sql.legacy.inOperator.falseForNullField` to `true`. - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1e51906517d4..171fa8334c1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1564,8 +1564,8 @@ object SQLConf { val LEGACY_IN_FALSE_FOR_NULL_FIELD = buildConf("spark.sql.legacy.inOperator.falseForNullField") .internal() - .doc("When set to true, the IN operator returns false when comparing literal structs " + - "containing a null field. When set to false (default), it returns null, instead. 
This is " + + .doc("When set to true (default), the IN operator returns false when comparing literal " + + "structs containing a null field. When set to false, it returns null, instead. This is " + "important especially when using NOT IN as in the second case, it filters out the rows " + "when a null is present in a filed; while in the first one, those rows are returned.") .booleanConf diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql index bf2fbcaff00a..9e01c950b9ef 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql @@ -14,6 +14,8 @@ CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES -- Case 1 (not possible to write a literal with no rows, so we ignore it.) -- (subquery is empty -> row is returned) +set spark.sql.legacy.inOperator.falseForNullField=false; + -- Case 2 -- (subquery contains a row with null in all columns -> row not returned) SELECT * diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out index 67e5dcd0a0b4..ef6815096aaf 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 12 -- !query 0 @@ -16,125 +16,133 @@ struct<> -- !query 1 +set spark.sql.legacy.inOperator.falseForNullField=false +-- !query 1 schema +struct +-- !query 1 output +spark.sql.legacy.inOperator.falseForNullField false + + +-- !query 2 -- Case 2 -- (subquery contains a row with null in all columns -> row not returned) SELECT * FROM m WHERE (a, b) NOT IN ((CAST (null AS INT), CAST (null AS DECIMAL(2, 1)))) --- !query 1 schema +-- !query 2 schema struct --- !query 1 output +-- !query 2 output --- !query 2 +-- !query 3 -- Case 3 -- (probe-side columns are all null -> row not returned) SELECT * FROM m WHERE a IS NULL AND b IS NULL -- Matches only (null, null) AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))) --- !query 2 schema +-- !query 3 schema struct --- !query 2 output +-- !query 3 output --- !query 3 +-- !query 4 -- Case 4 -- (one column null, other column matches a row in the subquery result -> row not returned) SELECT * FROM m WHERE b = 1.0 -- Matches (null, 1.0) AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))) --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output --- !query 4 +-- !query 5 -- Case 5 -- (one null column with no match -> row is returned) SELECT * FROM m WHERE b = 1.0 -- Matches (null, 1.0) AND (a, b) NOT IN ((2, 3.0)) --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output --- !query 5 +-- !query 6 -- Case 6 -- (no null columns with match -> row not returned) SELECT * FROM m WHERE b = 3.0 -- Matches (2, 3.0) AND (a, b) NOT IN ((2, 3.0)) --- !query 5 schema +-- !query 6 schema struct --- !query 5 output +-- !query 6 output --- !query 6 
+-- !query 7 -- Case 7 -- (no null columns with no match -> row is returned) SELECT * FROM m WHERE b = 5.0 -- Matches (4, 5.0) AND (a, b) NOT IN ((2, 3.0)) --- !query 6 schema +-- !query 7 schema struct --- !query 6 output +-- !query 7 output 4 5 --- !query 7 +-- !query 8 set spark.sql.legacy.inOperator.falseForNullField=true --- !query 7 schema +-- !query 8 schema struct --- !query 7 output +-- !query 8 output spark.sql.legacy.inOperator.falseForNullField true --- !query 8 +-- !query 9 -- Case 2 (old behavior) -- (subquery contains a row with null in all columns -> rows returned) SELECT * FROM m WHERE (a, b) NOT IN ((CAST (null AS INT), CAST (null AS DECIMAL(2, 1)))) --- !query 8 schema +-- !query 9 schema struct --- !query 8 output +-- !query 9 output 2 3 4 5 NULL 1 --- !query 9 +-- !query 10 -- Case 3 (old behavior) -- (probe-side columns are all null -> row returned) SELECT * FROM m WHERE a IS NULL AND b IS NULL -- Matches only (null, null) AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))) --- !query 9 schema +-- !query 10 schema struct --- !query 9 output +-- !query 10 output NULL NULL --- !query 10 +-- !query 11 -- Case 4 (old behavior) -- (one column null, other column matches a row in the subquery result -> row returned) SELECT * FROM m WHERE b = 1.0 -- Matches (null, 1.0) AND (a, b) NOT IN ((0, 1.0), (2, 3.0), (4, CAST(null AS DECIMAL(2, 1)))) --- !query 10 schema +-- !query 11 schema struct --- !query 10 output +-- !query 11 output NULL 1 From 321b9a8b9011eacaca0a0c5bd15eef9d0b2277ec Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 25 Oct 2018 18:17:59 +0200 Subject: [PATCH 07/10] differentiate structs and different values for IN with literals --- .../sql/catalyst/analysis/TypeCoercion.scala | 2 +- .../spark/sql/catalyst/dsl/package.scala | 5 +- .../catalyst/expressions/Canonicalize.scala | 2 +- .../sql/catalyst/expressions/predicates.scala | 47 ++++++++------ .../sql/catalyst/optimizer/expressions.scala | 16 ++--- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../statsEstimation/FilterEstimation.scala | 3 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 7 +- .../catalyst/analysis/TypeCoercionSuite.scala | 12 ++-- .../catalog/ExternalCatalogSuite.scala | 4 +- .../expressions/CanonicalizeSuite.scala | 12 ++-- .../catalyst/expressions/PredicateSuite.scala | 65 ++++++++----------- .../optimizer/ConstantFoldingSuite.scala | 2 +- .../catalyst/optimizer/OptimizeInSuite.scala | 33 +++++----- .../FilterEstimationSuite.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 8 ++- .../columnar/InMemoryTableScanExec.scala | 2 +- .../datasources/DataSourceStrategy.scala | 2 +- .../datasources/FileSourceStrategy.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../datasources/DataSourceStrategySuite.scala | 2 +- .../spark/sql/hive/client/HiveShim.scala | 2 +- .../sql/hive/client/HiveClientSuite.scala | 12 ++-- 23 files changed, 127 insertions(+), 119 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 72ac80e0a0a1..d4d45eb9fa10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -493,7 +493,7 @@ object TypeCoercion { i } - case i @ In(a, b) if b.exists(_.dataType != a.dataType) => + case i @ In(a, b) if b.exists(_.dataType != i.value.dataType) => 
findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 176ea823b1fc..70785423f513 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -94,7 +94,10 @@ package object dsl { case c: CreateNamedStruct => InSubquery(c.valExprs, l) case other => InSubquery(Seq(other), l) } - case _ => In(expr, list) + case _ => expr match { + case c: CreateNamedStruct => In(c.valExprs, list) + case other => In(Seq(other), list) + } } def like(other: Expression): Expression = Like(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index fe6db8b344d3..527e016d5571 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -87,7 +87,7 @@ object Canonicalize { case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) // order the list in the In operator - case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode())) + case In(values, list) if list.length > 1 => In(values, list.sortBy(_.hashCode())) case _ => e } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 483a7248ba4a..dbaee61ab60a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -140,13 +140,12 @@ case class Not(child: Expression) override def sql: String = s"(NOT ${child.sql})" } -/** - * Evaluates to `true` if `values` are returned in `query`'s result set. - */ -case class InSubquery(values: Seq[Expression], query: ListQuery) - extends Predicate with Unevaluable { +abstract class InBase extends Predicate { + def values: Seq[Expression] - @transient private lazy val value: Expression = if (values.length > 1) { + @transient protected lazy val isMultiValued = values.length > 1 + + @transient lazy val value: Expression = if (isMultiValued) { CreateNamedStruct(values.zipWithIndex.flatMap { case (v: NamedExpression, _) => Seq(Literal(v.name), v) case (v, idx) => Seq(Literal(s"_$idx"), v) @@ -154,7 +153,13 @@ case class InSubquery(values: Seq[Expression], query: ListQuery) } else { values.head } +} +/** + * Evaluates to `true` if `values` are returned in `query`'s result set. 
+ */ +case class InSubquery(values: Seq[Expression], query: ListQuery) + extends InBase with Unevaluable { override def checkInputDataTypes(): TypeCheckResult = { if (values.length != query.childOutputs.length) { @@ -225,7 +230,7 @@ case class InSubquery(values: Seq[Expression], query: ListQuery) true """) // scalastyle:on line.size.limit -case class In(value: Expression, list: Seq[Expression]) extends Predicate { +case class In(values: Seq[Expression], list: Seq[Expression]) extends InBase { require(list != null, "list should not be null") @@ -240,15 +245,15 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } - override def children: Seq[Expression] = value +: list + override def children: Seq[Expression] = values ++ list lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal]) private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType) - override def nullable: Boolean = value.dataType match { - case _: StructType if !SQLConf.get.inFalseForNullField => - children.exists(_.nullable) || - children.exists(_.dataType.asInstanceOf[StructType].exists(_.nullable)) - case _ => children.exists(_.nullable) + override def nullable: Boolean = if (isMultiValued && !SQLConf.get.inFalseForNullField) { + children.exists(_.nullable) || + children.exists(_.dataType.asInstanceOf[StructType].exists(_.nullable)) + } else { + children.exists(_.nullable) } override def foldable: Boolean = children.forall(_.foldable) @@ -276,16 +281,20 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } - @transient lazy val checkNullGenCode: (ExprCode) => Block = value.dataType match { - case _: StructType if !SQLConf.get.inFalseForNullField => + @transient lazy val checkNullGenCode: (ExprCode) => Block = { + if (isMultiValued && !SQLConf.get.inFalseForNullField) { e => code"${e.isNull} || ${e.value}.anyNull()" - case _ => e => code"${e.isNull}" + } else { + e => code"${e.isNull}" + } } - @transient lazy val checkNullEval: (Any) => Boolean = value.dataType match { - case _: StructType if !SQLConf.get.inFalseForNullField => + @transient lazy val checkNullEval: (Any) => Boolean = { + if (isMultiValued && !SQLConf.get.inFalseForNullField) { input => input == null || input.asInstanceOf[InternalRow].anyNull - case _ => input => input == null + } else { + input => input == null + } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index f8037588fa71..e0eddc5d2bce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -212,27 +212,27 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { * 1. Converts the predicate to false when the list is empty and * the value is not nullable. * 2. Removes literal repetitions. - * 3. Replaces [[In (value, seq[Literal])]] with optimized version + * 3. Replaces [[In (values, seq[Literal])]] with optimized version * [[InSet (value, HashSet[Literal])]] which is much faster. 
*/ object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty => + case i @ In(_, list) if list.isEmpty => // When v is not nullable, the following expression will be optimized // to FalseLiteral which is tested in OptimizeInSuite.scala - If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) - case expr @ In(v, list) if expr.inSetConvertible => + If(IsNotNull(i.value), FalseLiteral, Literal(null, BooleanType)) + case expr @ In(_, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.length == 1 // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed, // TODO: we exclude them in this rule. - && !v.isInstanceOf[CreateNamedStructLike] + && !expr.value.isInstanceOf[CreateNamedStructLike] && !newList.head.isInstanceOf[CreateNamedStructLike]) { - EqualTo(v, newList.head) + EqualTo(expr.value, newList.head) } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) - InSet(v, HashSet() ++ hSet) + InSet(expr.value, HashSet() ++ hSet) } else if (newList.length < list.length) { expr.copy(list = newList) } else { // newList.length == list.length && newList.length > 1 @@ -527,7 +527,7 @@ object NullPropagation extends Rule[LogicalPlan] { } // If the value expression is NULL then transform the In expression to null literal. - case In(Literal(null, _), _) => Literal.create(null, BooleanType) + case In(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType) case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType) // Non-leaf NullIntolerant expressions will return null, if at least one of its children is diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ba0b72e747fc..7d313c11c43f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1120,7 +1120,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.IN if ctx.query != null => invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) case SqlBaseParser.IN => - invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) + invertIfNotDefined(In(getValueExpressions(e), ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => invertIfNotDefined(Like(e, expression(ctx.pattern))) case SqlBaseParser.RLIKE => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 5a3eeefaedb1..b651eaa0fef7 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -164,8 +164,7 @@ case class FilterEstimation(plan: Filter) extends Logging { case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) => evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) - case In(ar: Attribute, expList) - if expList.forall(e => e.isInstanceOf[Literal]) => + case In(Seq(ar: Attribute), 
expList) if expList.forall(e => e.isInstanceOf[Literal]) => // Expression [In (value, seq[Literal])] will be replaced with optimized version // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. // Here we convert In into InSet anyway, because they share the same processing logic. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f9facbb71a4e..24684ff57f84 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -280,21 +280,22 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-8654: invalid CAST in NULL IN(...) expression") { - val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil, + val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(2))), "a")() :: Nil, LocalRelation() ) assertAnalysisSuccess(plan) } test("SPARK-8654: different types in inlist but can be converted to a common type") { - val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, + val plan = Project( + Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, LocalRelation() ) assertAnalysisSuccess(plan) } test("SPARK-8654: check type compatibility error") { - val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil, + val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(true), Literal(1))), "a")() :: Nil, LocalRelation() ) assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0eba1c537d67..048798d8aef6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1432,16 +1432,16 @@ class TypeCoercionSuite extends AnalysisTest { // InConversion val inConversion = TypeCoercion.InConversion(conf) ruleTest(inConversion, - In(UnresolvedAttribute("a"), Seq(Literal(1))), - In(UnresolvedAttribute("a"), Seq(Literal(1))) + In(Seq(UnresolvedAttribute("a")), Seq(Literal(1))), + In(Seq(UnresolvedAttribute("a")), Seq(Literal(1))) ) ruleTest(inConversion, - In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))), - In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))) + In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1))), + In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1))) ) ruleTest(inConversion, - In(Literal("a"), Seq(Literal(1), Literal("b"))), - In(Cast(Literal("a"), StringType), + In(Seq(Literal("a")), Seq(Literal(1), Literal("b"))), + In(Seq(Cast(Literal("a"), StringType)), Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType))) ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index b376108399c1..51464332dbc8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -477,8 +477,8 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac checkAnswer(tbl2, Seq.empty, Set(part1, part2)) checkAnswer(tbl2, Seq('a.int <= 1), Set(part1)) checkAnswer(tbl2, Seq('a.int === 2), Set.empty) - checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2)) - checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2)) + checkAnswer(tbl2, Seq(In(Seq('a.int * 10), Seq(30))), Set(part2)) + checkAnswer(tbl2, Seq(Not(In(Seq('a.int), Seq(4)))), Set(part1, part2)) checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1)) checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1)) checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 28e6940f3cca..b78d23e3472e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -27,9 +27,9 @@ class CanonicalizeSuite extends SparkFunSuite { val range = Range(1, 1, 1, 1) val idAttr = range.output.head - val in1 = In(idAttr, Seq(Literal(1), Literal(2))) - val in2 = In(idAttr, Seq(Literal(2), Literal(1))) - val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3))) + val in1 = In(Seq(idAttr), Seq(Literal(1), Literal(2))) + val in2 = In(Seq(idAttr), Seq(Literal(2), Literal(1))) + val in3 = In(Seq(idAttr), Seq(Literal(1), Literal(2), Literal(3))) assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash()) assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash()) @@ -37,11 +37,11 @@ class CanonicalizeSuite extends SparkFunSuite { assert(range.where(in1).sameResult(range.where(in2))) assert(!range.where(in1).sameResult(range.where(in3))) - val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + val arrays1 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))), CreateArray(Seq(Literal(2), Literal(1))))) - val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))), + val arrays2 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(2), Literal(1))), CreateArray(Seq(Literal(1), Literal(2))))) - val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + val arrays3 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))), CreateArray(Seq(Literal(3), Literal(1))))) assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 4b2f5e1254e2..5f7c5ff6d565 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -126,46 +126,39 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, null, null) :: Nil) test("basic IN predicate test") { - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1), + checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq(Literal(1), Literal(2))), null) - 
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), + checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null) - checkEvaluation(In(Literal(1), Seq.empty), false) - checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + checkEvaluation(In(Seq(NonFoldableLiteral.create(null, IntegerType)), Seq.empty), null) + checkEvaluation(In(Seq(Literal(1)), Seq.empty), false) + checkEvaluation(In(Seq(Literal(1)), Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkEvaluation( + In(Seq(Literal(1)), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + checkEvaluation( + In(Seq(Literal(2)), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkEvaluation(In(Seq(Literal(1)), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Seq(Literal(2)), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Seq(Literal(3)), Seq(Literal(1), Literal(2))), false) checkEvaluation( - And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), + And(In(Seq(Literal(1)), Seq(Literal(1), Literal(2))), In(Seq(Literal(2)), Seq(Literal(1), Literal(2)))), true) val ns = NonFoldableLiteral.create(null, StringType) - checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null) - checkEvaluation(In(ns, Seq(ns)), null) - checkEvaluation(In(Literal("a"), Seq(ns)), null) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) + checkEvaluation(In(Seq(ns), Seq(Literal("1"), Literal("2"))), null) + checkEvaluation(In(Seq(ns), Seq(ns)), null) + checkEvaluation(In(Seq(Literal("a")), Seq(ns)), null) + checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("^Ba*n"), ns)), true) + checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("aa"), Literal("^Ba*n"))), true) + checkEvaluation(In(Seq(Literal("^Ba*n")), Seq(Literal("aa"), Literal("^n"))), false) } test("IN with different types") { - def testWithRandomDataGeneration(dataType: DataType, - nullable: Boolean, - legacyNullHandling: Boolean = false): Unit = { - def isNull(e: Any): Boolean = { - if (!legacyNullHandling && dataType.isInstanceOf[StructType]) { - e == null || e.asInstanceOf[Row].anyNull - } else { - e == null - } - } + def testWithRandomDataGeneration(dataType: DataType, nullable: Boolean): Unit = { val maybeDataGen = RandomDataGenerator.forType(dataType, nullable = nullable) // Actually we won't pass in unsupported data types, this is a safety check. 
val dataGen = maybeDataGen.getOrElse( @@ -188,16 +181,16 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } } val input = inputData.map(NonFoldableLiteral.create(_, dataType)) - val expected = if (isNull(inputData.head)) { + val expected = if (inputData.head == null) { null } else if (inputData.slice(1, 10).contains(inputData(0))) { true - } else if (inputData.slice(1, 10).exists(isNull)) { + } else if (inputData.slice(1, 10).contains(null)) { null } else { false } - checkEvaluation(In(input(0), input.slice(1, 10)), expected) + checkEvaluation(In(Seq(input(0)), input.slice(1, 10)), expected) } val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t => @@ -231,13 +224,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { nullable <- Seq(true, false)) { val structType = StructType( StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil) - if (nullable) { - Seq("true", "false").foreach { legacyNullHandling => - withSQLConf((SQLConf.LEGACY_IN_FALSE_FOR_NULL_FIELD.key, legacyNullHandling)) { - testWithRandomDataGeneration(structType, nullable, legacyNullHandling.toBoolean) - } - } - } + testWithRandomDataGeneration(structType, nullable) } // Map types: not supported @@ -262,12 +249,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22501: In should not generate codes beyond 64KB") { val N = 3000 val sets = (1 to N).map(i => Literal(i.toDouble)) - checkEvaluation(In(Literal(1.0D), sets), true) + checkEvaluation(In(Seq(Literal(1.0D)), sets), true) } test("SPARK-22705: In should use less global variables") { val ctx = new CodegenContext() - In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) + In(Seq(Literal(1.0D)), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 641c89873dcc..09d8f2e4d1ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -251,7 +251,7 @@ class ConstantFoldingSuite extends PlanTest { val originalQuery = testRelation .select('a) - .where(In(Literal(1), Seq(Literal(1), Literal(2)))) + .where(In(Seq(Literal(1)), Seq(Literal(1), Literal(2)))) val optimized = Optimize.execute(originalQuery.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index a36083b84704..11e79edb9928 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -45,9 +45,9 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: Remove deterministic repetitions") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(1), Literal(2), Literal(2), Literal(1), Literal(2)))) - .where(In(UnresolvedAttribute("b"), + .where(In(Seq(UnresolvedAttribute("b")), Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"), Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0), Rand(0), Rand(0)))) @@ -56,8 +56,8 @@ 
class OptimizeInSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) - .where(In(UnresolvedAttribute("b"), + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2)))) + .where(In(Seq(UnresolvedAttribute("b")), Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"), Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0), Rand(0), Rand(0)))) @@ -69,7 +69,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -79,7 +79,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause optimized to InSet when more than 10 items") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_)))) + .where(In(Seq(UnresolvedAttribute("a")), (1 to 11).map(Literal(_)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -94,13 +94,15 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized in case filter has attributes") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) + .where(In( + Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) + .where(In( + Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze comparePlans(optimized, correctAnswer) @@ -109,7 +111,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: NULL IN (expr1, ..., exprN) gets transformed to Filter(null)") { val originalQuery = testRelation - .where(In(Literal.create(null, NullType), Seq(Literal(1), Literal(2)))) + .where(In(Seq(Literal.create(null, NullType)), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -140,7 +142,8 @@ class OptimizeInSuite extends PlanTest { "list expression contains attribute)") { val originalQuery = testRelation - .where(In(Literal.create(null, StringType), Seq(Literal(1), UnresolvedAttribute("b")))) + .where(In( + Seq(Literal.create(null, StringType)), Seq(Literal(1), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -156,7 +159,7 @@ class OptimizeInSuite extends PlanTest { "list expression contains attribute - select)") { val originalQuery = testRelation - .select(In(Literal.create(null, StringType), + .select(In(Seq(Literal.create(null, StringType)), Seq(Literal(1), UnresolvedAttribute("b"))).as("a")).analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -171,7 +174,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: Setting the threshold for turning Set into InSet.") { val plan = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), Literal(3)))) + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1), Literal(2), Literal(3)))) .analyze withSQLConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "10") { @@ -194,7 +197,7 @@ class 
OptimizeInSuite extends PlanTest { test("OptimizedIn test: one element in list gets transformed to EqualTo.") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) + .where(In(Seq(UnresolvedAttribute("a")), Seq(Literal(1)))) .analyze val optimized = Optimize.execute(originalQuery) @@ -210,7 +213,7 @@ class OptimizeInSuite extends PlanTest { "when value is not nullable") { val originalQuery = testRelation - .where(In(Literal("a"), Nil)) + .where(In(Seq(Literal("a")), Nil)) .analyze val optimized = Optimize.execute(originalQuery) @@ -226,7 +229,7 @@ class OptimizeInSuite extends PlanTest { "when value is nullable") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Nil)) + .where(In(Seq(UnresolvedAttribute("a")), Nil)) .analyze val optimized = Optimize.execute(originalQuery) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 47bfa6256958..b84334eaf49b 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -440,7 +440,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04")) val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05")) validateEstimatedStats( - Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), + Filter(In(Seq(attrDate), Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = Some(3), min = Some(d20170103), max = Some(d20170105), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ae27690f2e5b..2e2f28fcd5be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -793,7 +793,13 @@ class Column(val expr: Expression) extends Logging { * @since 1.5.0 */ @scala.annotation.varargs - def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } + def isin(list: Any*): Column = withExpr { + val inValues = expr match { + case c: CreateNamedStruct => c.valExprs + case other => Seq(other) + } + In(inValues, list.map(lit(_).expr)) + } /** * A boolean expression that is evaluated to true if the value of this expression is contained diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 196d057c2de1..efb772c6e123 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -235,7 +235,7 @@ case class InMemoryTableScanExec( case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 - case In(a: AttributeReference, list: Seq[Expression]) + case In(Seq(a: AttributeReference), list: Seq[Expression]) if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty => list.map(l => statsFor(a).lowerBound <= 
l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c6000442fae7..38ad32678037 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -475,7 +475,7 @@ object DataSourceStrategy { // Because we only convert In to InSet in Optimizer when there are more than certain // items. So it is possible we still get an In expression here that needs to be pushed // down. - case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) => + case expressions.In(Seq(a: Attribute), list) if !list.exists(!_.isInstanceOf[Literal]) => val hSet = list.map(e => e.eval(EmptyRow)) val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) Some(sources.In(a.name, hSet.toArray.map(toScala))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index fe27b78bf336..68e7bb07b794 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -86,7 +86,7 @@ object FileSourceStrategy extends Strategy with Logging { expr match { case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => getBucketSetFromValue(a, v) - case expressions.In(a: Attribute, list) + case expressions.In(Seq(a: Attribute), list) if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow))) case expressions.InSet(a: Attribute, hset) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index efc2f20a907f..0610f51b721a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -459,7 +459,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, LocalRelation(Seq(attribute), Nil)) val tableScanExec = InMemoryTableScanExec(Seq(attribute), - Seq(In(attribute, Nil)), testRelation) + Seq(In(Seq(attribute), Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index f20aded169e4..1177bfc7a858 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -51,7 +51,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext { testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) - testTranslateFilter(In(attrInt, Seq(1, 2, 
3)), Some(sources.In("cint", Array(1, 2, 3)))) + testTranslateFilter(In(Seq(attrInt), Seq(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint"))) testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull("cint"))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index bc9d4cd7f418..b335a92a231c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -685,7 +685,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } def convert(expr: Expression): Option[String] = expr match { - case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values)) + case In(Seq(ExtractAttribute(NonVarcharAttribute(name))), ExtractableLiterals(values)) if useAdvanced => Some(convertInToOr(name, values)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index 7a325bf26b4c..42eb67bc1dd4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -192,8 +192,8 @@ class HiveClientSuite(version: String) 20170102 to 20170103, 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { - case expr @ In(v, list) if expr.inSetConvertible => - InSet(v, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(_, list) if expr.inSetConvertible => + InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) }) } @@ -204,8 +204,8 @@ class HiveClientSuite(version: String) 20170102 to 20170103, 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { - case expr @ In(v, list) if expr.inSetConvertible => - InSet(v, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(_, list) if expr.inSetConvertible => + InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) }) } @@ -223,8 +223,8 @@ class HiveClientSuite(version: String) 20170101 to 20170103, 0 to 4, "ab" :: "ba" :: Nil, { - case expr @ In(v, list) if expr.inSetConvertible => - InSet(v, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(_, list) if expr.inSetConvertible => + InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) }) } From 1933e9d487eeab3ab9fd9af8d3164d181d0cd5ea Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 26 Oct 2018 12:47:16 +0200 Subject: [PATCH 08/10] migrate also InSet to new management --- .../sql/catalyst/expressions/predicates.scala | 88 ++++++++++++------- .../sql/catalyst/optimizer/expressions.scala | 19 ++-- .../statsEstimation/FilterEstimation.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 6 +- .../catalyst/expressions/PredicateSuite.scala | 18 ++-- .../catalyst/optimizer/OptimizeInSuite.scala | 47 +++++++++- .../FilterEstimationSuite.scala | 12 +-- .../datasources/DataSourceStrategy.scala | 2 +- .../datasources/FileSourceStrategy.scala | 2 +- .../datasources/DataSourceStrategySuite.scala | 3 +- .../spark/sql/sources/BucketedReadSuite.scala | 2 +- .../spark/sql/hive/client/HiveShim.scala | 2 +- .../sql/hive/client/HiveClientSuite.scala | 12 +-- 13 files changed, 144 insertions(+), 71 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index dbaee61ab60a..62e8cd0b5c5c 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -140,7 +140,7 @@ case class Not(child: Expression) override def sql: String = s"(NOT ${child.sql})" } -abstract class InBase extends Predicate { +trait InBase extends Predicate { def values: Seq[Expression] @transient protected lazy val isMultiValued = values.length > 1 @@ -153,6 +153,22 @@ abstract class InBase extends Predicate { } else { values.head } + + @transient lazy val checkNullGenCode: (ExprCode) => Block = { + if (isMultiValued && !SQLConf.get.inFalseForNullField) { + e => code"${e.isNull} || ${e.value}.anyNull()" + } else { + e => code"${e.isNull}" + } + } + + @transient lazy val checkNullEval: (Any) => Boolean = { + if (isMultiValued && !SQLConf.get.inFalseForNullField) { + input => input == null || input.asInstanceOf[InternalRow].anyNull + } else { + input => input == null + } + } } /** @@ -251,9 +267,9 @@ case class In(values: Seq[Expression], list: Seq[Expression]) extends InBase { override def nullable: Boolean = if (isMultiValued && !SQLConf.get.inFalseForNullField) { children.exists(_.nullable) || - children.exists(_.dataType.asInstanceOf[StructType].exists(_.nullable)) + list.exists(_.dataType.asInstanceOf[StructType].exists(_.nullable)) } else { - children.exists(_.nullable) + value.nullable || list.exists(_.nullable) } override def foldable: Boolean = children.forall(_.foldable) @@ -281,22 +297,6 @@ case class In(values: Seq[Expression], list: Seq[Expression]) extends InBase { } } - @transient lazy val checkNullGenCode: (ExprCode) => Block = { - if (isMultiValued && !SQLConf.get.inFalseForNullField) { - e => code"${e.isNull} || ${e.value}.anyNull()" - } else { - e => code"${e.isNull}" - } - } - - @transient lazy val checkNullEval: (Any) => Boolean = { - if (isMultiValued && !SQLConf.get.inFalseForNullField) { - input => input == null || input.asInstanceOf[InternalRow].anyNull - } else { - input => input == null - } - } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaDataType = CodeGenerator.javaType(value.dataType) val valueGen = value.genCode(ctx) @@ -371,37 +371,57 @@ case class In(values: Seq[Expression], list: Seq[Expression]) extends InBase { * Optimized version of In clause, when all filter values of In clause are * static. 
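The checkNullEval helper introduced above is what makes a multi-valued (struct) input count as null whenever any of its fields is null. A minimal standalone sketch of that logic, assuming only the catalyst InternalRow API (the helper name and the use of GenericInternalRow here are illustrative, not taken from the patch):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

    // Interpreted null check: with multiple values and the legacy flag off, a row
    // whose fields contain a null is treated like a null value for IN purposes.
    def checkNull(multiValued: Boolean, falseForNullField: Boolean)(input: Any): Boolean =
      if (multiValued && !falseForNullField) {
        input == null || input.asInstanceOf[InternalRow].anyNull
      } else {
        input == null
      }

    val rowWithNullField = new GenericInternalRow(Array[Any](1, null))           // (1, NULL)
    checkNull(multiValued = true, falseForNullField = false)(rowWithNullField)   // true  -> IN evaluates to NULL
    checkNull(multiValued = true, falseForNullField = true)(rowWithNullField)    // false -> legacy field-wise comparison proceeds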
*/ -case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate { +case class InSet(values: Seq[Expression], hset: Set[Any]) extends InBase { require(hset != null, "hset could not be null") - override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" + override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}" - @transient private[this] lazy val hasNull: Boolean = hset.contains(null) + override def children: Seq[Expression] = values - override def nullable: Boolean = child.nullable || hasNull + @transient private[this] lazy val hasNull: Boolean = { + if (isMultiValued && !SQLConf.get.inFalseForNullField) { + hset.exists(checkNullEval) + } else { + hset.contains(null) + } + } - protected override def nullSafeEval(value: Any): Any = { - if (set.contains(value)) { - true - } else if (hasNull) { - null + override def nullable: Boolean = { + val isValueNullable = if (isMultiValued && !SQLConf.get.inFalseForNullField) { + values.exists(_.nullable) } else { - false + value.nullable } + isValueNullable || hasNull } - @transient lazy val set: Set[Any] = child.dataType match { + override def eval(input: InternalRow): Any = { + val inputValue = value.eval(input) + if (checkNullEval(inputValue)) { + if (set.contains(inputValue)) { + true + } else if (hasNull) { + null + } else { + false + } + } else { + null + } + } + + @transient lazy val set: Set[Any] = value.dataType match { case _: AtomicType => hset case _: NullType => hset case _ => // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows - TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset + TreeSet.empty(TypeUtils.getInterpretedOrdering(value.dataType)) ++ hset } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val setTerm = ctx.addReferenceObj("set", set) - val childGen = child.genCode(ctx) + val childGen = value.genCode(ctx) val setIsNull = if (hasNull) { s"${ev.isNull} = !${ev.value};" } else { @@ -410,7 +430,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with ev.copy(code = code""" |${childGen.code} - |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${checkNullGenCode(childGen)}; |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; |if (!${ev.isNull}) { | ${ev.value} = $setTerm.contains(${childGen.value}); @@ -420,7 +440,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } override def sql: String = { - val valueSQL = child.sql + val valueSQL = value.sql val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ") s"($valueSQL IN ($listSQL))" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index e0eddc5d2bce..13c2118914cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -213,16 +213,23 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { * the value is not nullable. * 2. Removes literal repetitions. * 3. Replaces [[In (values, seq[Literal])]] with optimized version - * [[InSet (value, HashSet[Literal])]] which is much faster. + * [[InSet (values, HashSet[Literal])]] which is much faster. 
*/ object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case i @ In(_, list) if list.isEmpty => - // When v is not nullable, the following expression will be optimized + case i @ In(values, list) if list.isEmpty => + // When values are not nullable, the following expression will be optimized // to FalseLiteral which is tested in OptimizeInSuite.scala - If(IsNotNull(i.value), FalseLiteral, Literal(null, BooleanType)) - case expr @ In(_, list) if expr.inSetConvertible => + val isNotNull = if (SQLConf.get.inFalseForNullField) { + IsNotNull(i.value) + } else { + val valuesNotNull: Seq[Expression] = values.map(IsNotNull) + valuesNotNull.tail.foldLeft(valuesNotNull.head)(And) + } + If(isNotNull, FalseLiteral, Literal(null, BooleanType)) + case expr @ In(values, list) if expr.inSetConvertible => + // if we have more than one element in the values, we have to skip this optimization val newList = ExpressionSet(list).toSeq if (newList.length == 1 // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed, @@ -232,7 +239,7 @@ object OptimizeIn extends Rule[LogicalPlan] { EqualTo(expr.value, newList.head) } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) - InSet(expr.value, HashSet() ++ hSet) + InSet(values, HashSet() ++ hSet) } else if (newList.length < list.length) { expr.copy(list = newList) } else { // newList.length == list.length && newList.length > 1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index b651eaa0fef7..f302952574ae 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -171,7 +171,7 @@ case class FilterEstimation(plan: Filter) extends Logging { val hSet = expList.map(e => e.eval()) evaluateInSet(ar, HashSet() ++ hSet, update) - case InSet(ar: Attribute, set) => + case InSet(Seq(ar: Attribute), set) => evaluateInSet(ar, set, update) // In current stage, we don't have advanced statistics such as sketches or histograms. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 171fa8334c1c..91f28c71b9a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1564,10 +1564,10 @@ object SQLConf { val LEGACY_IN_FALSE_FOR_NULL_FIELD = buildConf("spark.sql.legacy.inOperator.falseForNullField") .internal() - .doc("When set to true (default), the IN operator returns false when comparing literal " + - "structs containing a null field. When set to false, it returns null, instead. This is " + + .doc("When set to true (default), the IN operator returns false when comparing multiple " + + "values containing a null. When set to false, it returns null, instead. 
This is " + "important especially when using NOT IN as in the second case, it filters out the rows " + - "when a null is present in a filed; while in the first one, those rows are returned.") + "when a null is present in a field; while in the first one, those rows are returned.") .booleanConf .createWithDefault(true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 5f7c5ff6d565..09e91d36d90b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -265,13 +265,13 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { val two = Literal(2) val three = Literal(3) val nl = Literal(null) - checkEvaluation(InSet(one, hS), true) - checkEvaluation(InSet(two, hS), true) - checkEvaluation(InSet(two, nS), true) - checkEvaluation(InSet(three, hS), false) - checkEvaluation(InSet(three, nS), null) - checkEvaluation(InSet(nl, hS), null) - checkEvaluation(InSet(nl, nS), null) + checkEvaluation(InSet(Seq(one), hS), true) + checkEvaluation(InSet(Seq(two), hS), true) + checkEvaluation(InSet(Seq(two), nS), true) + checkEvaluation(InSet(Seq(three), hS), false) + checkEvaluation(InSet(Seq(three), nS), null) + checkEvaluation(InSet(Seq(nl), hS), null) + checkEvaluation(InSet(Seq(nl), nS), null) val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) @@ -295,7 +295,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } else { false } - checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), expected) + checkEvaluation(InSet(Seq(input(0)), inputData.slice(1, 10).toSet), expected) } } @@ -445,7 +445,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22693: InSet should not use global variables") { val ctx = new CodegenContext - InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) + InSet(Seq(Literal(1)), Set(1, 2, 3, 4)).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 11e79edb9928..65cf2db6e684 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -24,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD import org.apache.spark.sql.types._ @@ -85,7 +87,7 @@ class OptimizeInSuite extends PlanTest { val optimized = 
Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet)) + .where(InSet(Seq(UnresolvedAttribute("a")), (1 to 11).toSet)) .analyze comparePlans(optimized, correctAnswer) @@ -138,6 +140,49 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("OptimizedIn test: (notNull, NULL) IN (expr1, ..., exprN) transformed to Filter(null)") { + Seq(("true", false), ("false", null)).foreach { case (legacyInFalseForNull, expected) => + withSQLConf(SQLConf.LEGACY_IN_FALSE_FOR_NULL_FIELD.key -> legacyInFalseForNull) { + val originalQuery = + testRelation + .where(In(Seq(Literal.create(null, IntegerType), Literal(1)), + Seq(Literal.create( + InternalRow.fromSeq(Seq(1, 2)), + StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType))))))) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(expected, BooleanType)) + .analyze + + comparePlans(optimized, correctAnswer) + } + } + } + + test("OptimizedIn test: (notNull, NULL) IN () transformed to IsNotNull filters") { + val legacyExpected = If( + Literal(true, BooleanType), Literal(false, BooleanType), Literal(null, BooleanType)) + val newExpected = If( + Literal(false, BooleanType), Literal(false, BooleanType), Literal(null, BooleanType)) + Seq(("true", legacyExpected), ("false", newExpected)).foreach { + case (legacyInFalseForNull, expected) => + withSQLConf(SQLConf.LEGACY_IN_FALSE_FOR_NULL_FIELD.key -> legacyInFalseForNull) { + val originalQuery = + testRelation + .where(In( + Seq(Literal.create(null, IntegerType), UnresolvedAttribute("b")), Seq.empty)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(expected) + .analyze + + comparePlans(optimized, correctAnswer) + } + } + } + test("OptimizedIn test: Inset optimization disabled as " + "list expression contains attribute)") { val originalQuery = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index b84334eaf49b..06d36716f022 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -351,7 +351,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint IN (3, 4, 5)") { validateEstimatedStats( - Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), + Filter(InSet(Seq(attrInt), Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(3), max = Some(5), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) @@ -359,7 +359,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("evaluateInSet with all zeros") { validateEstimatedStats( - Filter(InSet(attrString, Set(3, 4, 5)), + Filter(InSet(Seq(attrString), Set(3, 4, 5)), StatsTestPlan(Seq(attrString), 0, AttributeMap(Seq(attrString -> ColumnStat(distinctCount = Some(0), min = None, max = None, @@ -370,7 +370,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("evaluateInSet with string") { validateEstimatedStats( - Filter(InSet(attrString, Set("A0")), + Filter(InSet(Seq(attrString), Set("A0")), 
StatsTestPlan(Seq(attrString), 10, AttributeMap(Seq(attrString -> ColumnStat(distinctCount = Some(10), min = None, max = None, @@ -382,14 +382,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( - Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), + Filter(Not(InSet(Seq(attrInt), Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), Seq(attrInt -> colStatInt.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cbool IN (true)") { validateEstimatedStats( - Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), + Filter(InSet(Seq(attrBool), Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) @@ -509,7 +509,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap(Seq(attrInt -> cornerChildColStatInt)) ) validateEstimatedStats( - Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), + Filter(InSet(Seq(attrInt), Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), Seq(attrInt -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(5), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 38ad32678037..13727b6cec6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -468,7 +468,7 @@ object DataSourceStrategy { case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) => Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) - case expressions.InSet(a: Attribute, set) => + case expressions.InSet(Seq(a: Attribute), set) => val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) Some(sources.In(a.name, set.toArray.map(toScala))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 68e7bb07b794..4812fa97464c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -89,7 +89,7 @@ object FileSourceStrategy extends Strategy with Logging { case expressions.In(Seq(a: Attribute), list) if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow))) - case expressions.InSet(a: Attribute, hset) + case expressions.InSet(Seq(a: Attribute), hset) if hset.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => getBucketSetFromIterable(a, hset.map(e => expressions.Literal(e).eval(EmptyRow))) case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index 1177bfc7a858..1c9e84eefe47 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -49,7 +49,8 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext { testTranslateFilter(LessThanOrEqual(attrInt, 1), Some(sources.LessThanOrEqual("cint", 1))) testTranslateFilter(LessThanOrEqual(1, attrInt), Some(sources.GreaterThanOrEqual("cint", 1))) - testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) + testTranslateFilter( + InSet(Seq(attrInt), Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) testTranslateFilter(In(Seq(attrInt), Seq(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index a2bc651bb2bd..90ead5de5a0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -171,7 +171,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { df) // Case 4: InSet - val inSetExpr = expressions.InSet($"j".expr, + val inSetExpr = expressions.InSet(Seq($"j".expr), Set(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3).map(lit(_).expr)) checkPrunedAnswers( bucketSpec, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index b335a92a231c..daa242949e8b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -689,7 +689,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { if useAdvanced => Some(convertInToOr(name, values)) - case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values)) + case InSet(Seq(ExtractAttribute(NonVarcharAttribute(name))), ExtractableValues(values)) if useAdvanced => Some(convertInToOr(name, values)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index 42eb67bc1dd4..7a325bf26b4c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -192,8 +192,8 @@ class HiveClientSuite(version: String) 20170102 to 20170103, 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { - case expr @ In(_, list) if expr.inSetConvertible => - InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, list.map(_.eval(EmptyRow)).toSet) }) } @@ -204,8 +204,8 @@ class HiveClientSuite(version: String) 20170102 to 20170103, 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { - case expr @ In(_, list) if expr.inSetConvertible => - InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, list.map(_.eval(EmptyRow)).toSet) }) } @@ -223,8 +223,8 @@ class HiveClientSuite(version: String) 20170101 to 20170103, 0 to 4, "ab" :: "ba" :: Nil, { - case expr @ In(_, list) if expr.inSetConvertible => - InSet(expr.value, list.map(_.eval(EmptyRow)).toSet) + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, 
list.map(_.eval(EmptyRow)).toSet) }) } From 809e80c944ec0e9635cd4944421587c4045c7457 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 26 Oct 2018 15:28:50 +0200 Subject: [PATCH 09/10] fix --- .../apache/spark/sql/catalyst/expressions/predicates.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 62e8cd0b5c5c..1b20128f72ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -399,6 +399,8 @@ case class InSet(values: Seq[Expression], hset: Set[Any]) extends InBase { override def eval(input: InternalRow): Any = { val inputValue = value.eval(input) if (checkNullEval(inputValue)) { + null + } else { if (set.contains(inputValue)) { true } else if (hasNull) { @@ -406,8 +408,6 @@ case class InSet(values: Seq[Expression], hset: Set[Any]) extends InBase { } else { false } - } else { - null } } From 389e6de3802be36b5d3b28d7090e26aa5d086905 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 31 Oct 2018 15:49:34 +0100 Subject: [PATCH 10/10] address comments --- .../apache/spark/sql/catalyst/expressions/predicates.scala | 7 ++++--- .../apache/spark/sql/catalyst/optimizer/expressions.scala | 3 +-- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1b20128f72ba..7a4672731b4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -226,9 +226,10 @@ case class InSubquery(values: Seq[Expression], query: ListQuery) // scalastyle:off line.size.limit @ExpressionDescription( usage = """ - expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN. Otherwise, if - spark.sql.legacy.inOperator.falseForNullField is false and any of the elements or fields of - the elements is null it returns null, else it returns false. + expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr1` equals to any exprN. Otherwise, if + `expr` is a single value and it is null or any exprN is null or `expr` contains multiple + values and spark.sql.legacy.inOperator.falseForNullField is false and any of the exprN or + fields of the exprN is null it returns null, else it returns false. 
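To make the documented null handling concrete, a hedged illustration (a SparkSession named spark, the SQL row-constructor syntax, and the post-change default of the legacy flag are assumed):

    spark.sql("SELECT 1 IN (1, NULL)").collect()     // true: a match wins over the null candidate
    spark.sql("SELECT 2 IN (1, NULL)").collect()     // null: no match and a null candidate
    // Multi-valued comparison with a null field: null by default, false when
    // spark.sql.legacy.inOperator.falseForNullField is set to true.
    spark.sql("SELECT (1, CAST(NULL AS INT)) IN ((1, 2))").collect()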
""", arguments = """ Arguments: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 13c2118914cf..77c4de2e08f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -224,8 +224,7 @@ object OptimizeIn extends Rule[LogicalPlan] { val isNotNull = if (SQLConf.get.inFalseForNullField) { IsNotNull(i.value) } else { - val valuesNotNull: Seq[Expression] = values.map(IsNotNull) - valuesNotNull.tail.foldLeft(valuesNotNull.head)(And) + values.map(IsNotNull).reduce(And) } If(isNotNull, FalseLiteral, Literal(null, BooleanType)) case expr @ In(values, list) if expr.inSetConvertible => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 91f28c71b9a8..ca8afe078ac6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1564,12 +1564,12 @@ object SQLConf { val LEGACY_IN_FALSE_FOR_NULL_FIELD = buildConf("spark.sql.legacy.inOperator.falseForNullField") .internal() - .doc("When set to true (default), the IN operator returns false when comparing multiple " + - "values containing a null. When set to false, it returns null, instead. This is " + + .doc("When set to true, the IN operator returns false when comparing multiple values " + + "containing a null. When set to false (default), it returns null, instead. This is " + "important especially when using NOT IN as in the second case, it filters out the rows " + "when a null is present in a field; while in the first one, those rows are returned.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val LEGACY_INTEGRALDIVIDE_RETURN_LONG = buildConf("spark.sql.legacy.integralDivide.returnBigint") .doc("If it is set to true, the div operator returns always a bigint. This behavior was " +