diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index 50f32e81d997..2e2858ab34ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -153,6 +153,7 @@ object ExternalCatalogUtils {
             val index = partitionSchema.indexWhere(_.name == att.name)
             BoundReference(index, partitionSchema(index).dataType, nullable = true)
         })
+      boundPredicate.initialize(0)
 
       inputPartitions.filter { p =>
         boundPredicate.eval(p.toRow(partitionSchema, defaultTimeZoneId))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 7bf10f199f1c..b0f90a095661 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -36,6 +36,13 @@ object InterpretedPredicate {
 
 case class InterpretedPredicate(expression: Expression) extends BasePredicate {
   override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean]
+
+  override def initialize(partitionIndex: Int): Unit = {
+    expression.foreach {
+      case n: Nondeterministic => n.initialize(partitionIndex)
+      case _ =>
+    }
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 8d034c21a496..717841136f33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -23,6 +23,24 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 
+/**
+ * A pattern that matches a [[Filter]] operator whose condition is deterministic,
+ * or whose child is a [[LeafNode]].
+ */
+object FilterOperation extends PredicateHelper {
+  type ReturnType = (Expression, LogicalPlan)
+
+  def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
+    case Filter(condition, child) if condition.deterministic =>
+      Some((condition, child))
+
+    case Filter(condition, child: LeafNode) =>
+      Some((condition, child))
+
+    case _ => None
+  }
+}
+
 /**
  * A pattern that matches any number of project or filter operations on top of another relational
  * operator.  All filter operators are collected and their conditions are broken up and returned
@@ -60,7 +78,7 @@ object PhysicalOperation extends PredicateHelper {
         val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
         (Some(substitutedFields), filters, other, collectAliases(substitutedFields))
 
-      case Filter(condition, child) if condition.deterministic =>
+      case FilterOperation(condition, child) =>
         val (fields, filters, other, aliases) = collectProjectsAndFilters(child)
         val substitutedCondition = substitute(aliases)(condition)
         (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index 6b6f6388d54e..243cf75b4cae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -175,6 +175,7 @@ abstract class PartitioningAwareFileIndex(
           val index = partitionColumns.indexWhere(a.name == _.name)
           BoundReference(index, partitionColumns(index).dataType, nullable = true)
       })
+      boundPredicate.initialize(0)
 
       val selected = partitions.filter {
         case PartitionPath(values, _) => boundPredicate.eval(values)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5eb34e587e95..8e3c8fe69026 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2029,4 +2029,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)),
       Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2)))
   }
+
+  test("SPARK-21707: nondeterministic expressions are evaluated correctly in filter predicates") {
+    withTempPath { path =>
+      val p = path.getAbsolutePath
+      Seq(1 -> "a").toDF("a", "b").write.partitionBy("a").parquet(p)
+      val df = spark.read.parquet(p)
+      checkAnswer(df.filter(rand(10) <= 1.0).select($"a"), Row(1))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index ba0ca666b5c1..d78f29490c54 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -570,4 +570,29 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
         df1.groupBy("j").agg(max("k")))
     }
   }
+
+  test("SPARK-21707: read only the needed fields when the filter condition is nondeterministic") {
+    withTable("bucketed_table") {
+      df1.write
+        .format("parquet")
+        .partitionBy("i")
+        .bucketBy(8, "j", "k")
+        .saveAsTable("bucketed_table")
+
+      val table = spark.table("bucketed_table").select($"i", $"j", $"k")
+      assert(table.queryExecution.sparkPlan.inputSet.toSeq.length == 3)
+
+      // The filter condition is nondeterministic and references no fields.
+      val table1 = spark.table("bucketed_table").where(rand(10) <= 0.5).select($"i")
+      assert(table1.queryExecution.sparkPlan.inputSet.toSeq.length == 1)
+      assert(table1.queryExecution.sparkPlan.inputSet != table.queryExecution.sparkPlan.inputSet)
+
+      // The filter condition is nondeterministic and references one field.
+      val table2 = spark.table("bucketed_table")
+        .where(rand(10) <= 0.5 && $"j" > 1)
+        .select($"i")
+      assert(table2.queryExecution.sparkPlan.inputSet.toSeq.length == 2)
+      assert(table2.queryExecution.sparkPlan.inputSet != table.queryExecution.sparkPlan.inputSet)
+    }
+  }
 }
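
A minimal sketch, separate from the patch above, of why boundPredicate.initialize(0) must be called before eval(): an interpreted predicate over a Nondeterministic expression needs a partition index before it can be evaluated. It assumes Spark 2.x Catalyst internals as shown in the diff; the object name below is hypothetical.

// Sketch only, not part of the patch. Assumes Spark 2.x Catalyst classes on the classpath.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{InterpretedPredicate, LessThan, Literal, MonotonicallyIncreasingID}

object InitializeBeforeEvalSketch {
  def main(args: Array[String]): Unit = {
    // A nondeterministic predicate: monotonically_increasing_id() < 100.
    val condition = LessThan(MonotonicallyIncreasingID(), Literal(100L))
    val predicate = InterpretedPredicate(condition)

    // Nondeterministic expressions require initialize() before eval(); skipping this
    // call is exactly the failure the patch fixes for partition-pruning predicates.
    predicate.initialize(0)

    println(predicate.eval(InternalRow.empty)) // true: the first id in partition 0 is 0
  }
}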