
Commit 20fc87a

Improve a special case for non-deterministic filters in the optimizer
1 parent b35660d commit 20fc87a

File tree

6 files changed: +62 -1 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala

Lines changed: 1 addition & 0 deletions
@@ -153,6 +153,7 @@ object ExternalCatalogUtils {
             val index = partitionSchema.indexWhere(_.name == att.name)
             BoundReference(index, partitionSchema(index).dataType, nullable = true)
         })
+      boundPredicate.initialize(0)
 
       inputPartitions.filter { p =>
         boundPredicate.eval(p.toRow(partitionSchema, defaultTimeZoneId))
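The new boundPredicate.initialize(0) call matters because a partition-pruning predicate can now contain nondeterministic expressions (for example rand), and those carry per-partition state that must be seeded before eval is called. A minimal, self-contained sketch of that contract, using hypothetical classes rather than Spark's actual internals:

// Illustrative only: these classes are hypothetical and simply mimic the
// initialize-before-eval contract that nondeterministic expressions rely on.
trait NondeterministicLike {
  private var initialized = false
  protected def initInternal(partitionIndex: Int): Unit
  protected def doEval(): Boolean

  final def initialize(partitionIndex: Int): Unit = {
    initInternal(partitionIndex)
    initialized = true
  }

  final def evalBoolean(): Boolean = {
    // Evaluating before initialization would use unseeded state, so fail fast.
    require(initialized, "expression must be initialized before eval")
    doEval()
  }
}

// A rand-like predicate: the random generator is seeded per partition.
class RandThresholdPredicate(seed: Long, threshold: Double) extends NondeterministicLike {
  private var rng: scala.util.Random = _
  override protected def initInternal(partitionIndex: Int): Unit =
    rng = new scala.util.Random(seed + partitionIndex)
  override protected def doEval(): Boolean = rng.nextDouble() <= threshold
}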

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 7 additions & 0 deletions
@@ -36,6 +36,13 @@ object InterpretedPredicate {
 
 case class InterpretedPredicate(expression: Expression) extends BasePredicate {
   override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean]
+
+  override def initialize(partitionIndex: Int): Unit = {
+    expression.foreach {
+      case n: Nondeterministic => n.initialize(partitionIndex)
+      case _ =>
+    }
+  }
 }
 
 /**
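With initialize overridden here, a caller that builds an interpreted predicate is expected to seed it once per partition before evaluating rows. A hedged usage sketch (the helper below is illustrative and not part of this patch; it assumes the condition's attribute references are already bound to row ordinals):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, InterpretedPredicate}

// Illustrative helper: seed any Nondeterministic subexpressions for partition 0,
// then evaluate the predicate against each row.
def filterRows(condition: Expression, rows: Iterator[InternalRow]): Iterator[InternalRow] = {
  val predicate = InterpretedPredicate(condition)
  predicate.initialize(0)
  rows.filter(predicate.eval)
}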

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 19 additions & 1 deletion
@@ -23,6 +23,24 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 
+/**
+ * A pattern that matches a filter operator whose condition is deterministic,
+ * or whose child is a LeafNode.
+ */
+object FilterOperation extends PredicateHelper {
+  type ReturnType = (Expression, LogicalPlan)
+
+  def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
+    case Filter(condition, child) if condition.deterministic =>
+      Some((condition, child))
+
+    case Filter(condition, child: LeafNode) =>
+      Some((condition, child))
+
+    case _ => None
+  }
+}
+
 /**
  * A pattern that matches any number of project or filter operations on top of another relational
  * operator. All filter operators are collected and their conditions are broken up and returned
@@ -60,7 +78,7 @@ object PhysicalOperation extends PredicateHelper
       val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
       (Some(substitutedFields), filters, other, collectAliases(substitutedFields))
 
-    case Filter(condition, child) if condition.deterministic =>
+    case FilterOperation(condition, child) =>
       val (fields, filters, other, aliases) = collectProjectsAndFilters(child)
       val substitutedCondition = substitute(aliases)(condition)
       (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)
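The effect of swapping Filter(...) if condition.deterministic for FilterOperation(...) is that PhysicalOperation keeps collecting a filter whose condition is nondeterministic, as long as that filter sits directly on a leaf (scan) node, where pruning columns beneath it cannot change the rows the condition observes. A restatement of the rule as a standalone function (hypothetical helper, not part of the patch):

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan}

// Illustrative helper mirroring FilterOperation.unapply.
def collectableFilter(plan: LogicalPlan): Option[(Expression, LogicalPlan)] = plan match {
  // A deterministic condition is always safe to collect and push around.
  case Filter(condition, child) if condition.deterministic => Some((condition, child))
  // A nondeterministic condition is collected only when the filter is directly
  // above a leaf node, so column pruning below it cannot drop or reorder the
  // rows the condition sees.
  case Filter(condition, child: LeafNode) => Some((condition, child))
  case _ => None
}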

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala

Lines changed: 1 addition & 0 deletions
@@ -175,6 +175,7 @@ abstract class PartitioningAwareFileIndex(
           val index = partitionColumns.indexWhere(a.name == _.name)
           BoundReference(index, partitionColumns(index).dataType, nullable = true)
       })
+      boundPredicate.initialize(0)
 
       val selected = partitions.filter {
         case PartitionPath(values, _) => boundPredicate.eval(values)

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 9 additions & 0 deletions
@@ -2029,4 +2029,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)),
       Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2)))
   }
+
+  test("SPARK-21707: nondeterministic expressions work correctly for filter predicates") {
+    withTempPath { path =>
+      val p = path.getAbsolutePath
+      Seq(1 -> "a").toDF("a", "b").write.partitionBy("a").parquet(p)
+      val df = spark.read.parquet(p)
+      checkAnswer(df.filter(rand(10) <= 1.0).select($"a"), Row(1))
+    }
+  }
 }
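At the DataFrame API level, the scenario this test exercises looks roughly like the spark-shell style snippet below (the path is illustrative). Because rand(10) <= 1.0 references no data columns, the predicate can participate in partition pruning, which is the path that now calls initialize(0) before evaluating it:

import org.apache.spark.sql.functions.rand

// spark and $-interpolation come from a spark-shell (or SparkSession) context;
// the output path below is only an example.
val p = "/tmp/spark-21707-demo"
Seq(1 -> "a").toDF("a", "b").write.partitionBy("a").parquet(p)

// rand(10) <= 1.0 is always true, so the single row survives the filter.
spark.read.parquet(p).filter(rand(10) <= 1.0).select($"a").show()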

sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala

Lines changed: 25 additions & 0 deletions
@@ -570,4 +570,29 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
         df1.groupBy("j").agg(max("k")))
     }
   }
+
+  test("SPARK-21707: read only the fields the user needs when the filter condition is nondeterministic") {
+    withTable("bucketed_table") {
+      df1.write
+        .format("parquet")
+        .partitionBy("i")
+        .bucketBy(8, "j", "k")
+        .saveAsTable("bucketed_table")
+
+      val table = spark.table("bucketed_table").select($"i", $"j", $"k")
+      assert(table.queryExecution.sparkPlan.inputSet.toSeq.length == 3)
+
+      // The filter condition is nondeterministic and references no fields.
+      val table1 = spark.table("bucketed_table").where(rand(10) <= 0.5).select($"i")
+      assert(table1.queryExecution.sparkPlan.inputSet.toSeq.length == 1)
+      assert(table1.queryExecution.sparkPlan.inputSet != table.queryExecution.sparkPlan.inputSet)
+
+      // The filter condition is nondeterministic and references one field.
+      val table2 = spark.table("bucketed_table")
+        .where(rand(10) <= 0.5 && $"j" > 1)
+        .select($"i")
+      assert(table2.queryExecution.sparkPlan.inputSet.toSeq.length == 2)
+      assert(table2.queryExecution.sparkPlan.inputSet != table.queryExecution.sparkPlan.inputSet)
+    }
+  }
 }
