diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index 50f32e81d997..2e2858ab34ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -153,6 +153,7 @@ object ExternalCatalogUtils {
           val index = partitionSchema.indexWhere(_.name == att.name)
           BoundReference(index, partitionSchema(index).dataType, nullable = true)
       })
+      boundPredicate.initialize(0)
 
       inputPartitions.filter { p =>
         boundPredicate.eval(p.toRow(partitionSchema, defaultTimeZoneId))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 7bf10f199f1c..b0f90a095661 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -36,6 +36,13 @@ object InterpretedPredicate {
 
 case class InterpretedPredicate(expression: Expression) extends BasePredicate {
   override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean]
+
+  override def initialize(partitionIndex: Int): Unit = {
+    expression.foreach {
+      case n: Nondeterministic => n.initialize(partitionIndex)
+      case _ =>
+    }
+  }
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index 6b6f6388d54e..243cf75b4cae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -175,6 +175,7 @@ abstract class PartitioningAwareFileIndex(
           val index = partitionColumns.indexWhere(a.name == _.name)
           BoundReference(index, partitionColumns(index).dataType, nullable = true)
       })
+      boundPredicate.initialize(0)
 
       val selected = partitions.filter {
         case PartitionPath(values, _) => boundPredicate.eval(values)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5eb34e587e95..b7b000c4d063 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2029,4 +2029,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)),
       Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2)))
   }
+
+  test("SPARK-21746: nondeterministic expressions correctly for filter predicates") {
+    withTempPath { path =>
+      val p = path.getAbsolutePath
+      Seq(1 -> "a").toDF("a", "b").write.partitionBy("a").parquet(p)
+      val df = spark.read.parquet(p)
+      checkAnswer(df.filter(rand(10) <= 1.0).select($"a"), Row(1))
+    }
+  }
 }
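Context for the contract being fixed: Spark's `Nondeterministic` expressions must have `initialize(partitionIndex)` called before `eval`, since that is where per-partition state (such as the seeded RNG behind `rand(seed)`) is created; evaluating without it fails. The two call sites patched above evaluate a bound predicate on the driver during partition pruning, so they initialize with partition index 0, and the new `InterpretedPredicate.initialize` override walks the expression tree and forwards the index to every `Nondeterministic` node. Below is a minimal standalone sketch of that contract in plain Scala; `NondeterministicLike` and `RandLike` are illustrative stand-ins written for this note, not Spark's actual classes.

```scala
import scala.util.Random

// Sketch of the initialize-before-eval contract that Nondeterministic enforces.
trait NondeterministicLike {
  private var initialized = false

  protected def initializeInternal(partitionIndex: Int): Unit
  protected def evalInternal(): Double

  // Must be called once per partition before any evaluation.
  final def initialize(partitionIndex: Int): Unit = {
    initializeInternal(partitionIndex)
    initialized = true
  }

  // Fails fast if the caller skipped initialize(), mirroring the bug fixed above.
  final def evalDouble(): Double = {
    require(initialized, "nondeterministic expression evaluated before initialize()")
    evalInternal()
  }
}

// Analogue of rand(seed): the RNG only exists after initialize() mixes in the
// partition index, which is why eval() without initialize() cannot work.
class RandLike(seed: Long) extends NondeterministicLike {
  private var rng: Random = _

  override protected def initializeInternal(partitionIndex: Int): Unit = {
    rng = new Random(seed + partitionIndex)
  }

  override protected def evalInternal(): Double = rng.nextDouble()
}

object InitializeBeforeEvalDemo extends App {
  val rand = new RandLike(10L)
  // Without this call, evalDouble() throws -- the same failure the patch avoids
  // by adding boundPredicate.initialize(0) on the driver-side pruning paths.
  rand.initialize(0)
  println(rand.evalDouble() <= 1.0) // always true, mirroring rand(10) <= 1.0 in the test
}
```

This mirrors what executors already do for generated predicates, which call `initialize(partitionIndex)` per partition; the interpreted path used for driver-side partition pruning simply lacked the equivalent call until this change.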