@@ -153,6 +153,7 @@ object ExternalCatalogUtils {
val index = partitionSchema.indexWhere(_.name == att.name)
BoundReference(index, partitionSchema(index).dataType, nullable = true)
})
boundPredicate.initialize(0)

inputPartitions.filter { p =>
boundPredicate.eval(p.toRow(partitionSchema, defaultTimeZoneId))
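The single added line above is the crux of the fix: the predicate built here for client-side partition pruning is an InterpretedPredicate, and any Nondeterministic expression inside it (for example rand) must have initialize called before its first eval. Below is a minimal illustrative sketch of that requirement, assuming the Catalyst expression API of this Spark version; the variable names and values are not part of the patch.

// --- illustrative sketch, not part of the patch ---
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{InterpretedPredicate, LessThanOrEqual, Literal, Rand}

// Rand keeps per-partition state (its RNG), which is only set up by initialize().
val pred = InterpretedPredicate(LessThanOrEqual(Rand(10L), Literal(1.0)))
pred.initialize(0)            // seed Rand for "partition" 0, mirroring the added call
pred.eval(InternalRow.empty)  // safe to evaluate; skipping initialize is expected to
                              // trip Rand's "should be initialized before eval" check
// --- end sketch ---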
@@ -36,6 +36,13 @@ object InterpretedPredicate {

case class InterpretedPredicate(expression: Expression) extends BasePredicate {
  override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean]

  override def initialize(partitionIndex: Int): Unit = {
    expression.foreach {
      case n: Nondeterministic => n.initialize(partitionIndex)
      case _ =>
    }
  }
}

/**
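The new initialize override uses expression.foreach, which visits every node in the expression tree, so nondeterministic leaves nested below And or comparison operators are reached as well. A short sketch under the same assumptions as above (all names outside the diff are illustrative):

// --- illustrative sketch, not part of the patch ---
import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, InterpretedPredicate, Literal, Rand}

// Rand sits two levels below the root; foreach still reaches it and seeds its RNG.
val condition = And(GreaterThan(Rand(42L), Literal(0.5)), Literal(true))
val predicate = InterpretedPredicate(condition)
predicate.initialize(partitionIndex = 0)
// --- end sketch ---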
@@ -23,6 +23,24 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._

/**
 * A pattern that matches a Filter operator whose condition is deterministic,
 * or whose child is a LeafNode.
 */
object FilterOperation extends PredicateHelper {
  type ReturnType = (Expression, LogicalPlan)

  def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
    case Filter(condition, child) if condition.deterministic =>
      Some((condition, child))

    case Filter(condition, child: LeafNode) =>
      Some((condition, child))

    case _ => None
  }
}

/**
* A pattern that matches any number of project or filter operations on top of another relational
* operator. All filter operators are collected and their conditions are broken up and returned
@@ -60,7 +78,7 @@ object PhysicalOperation extends PredicateHelper {
val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
(Some(substitutedFields), filters, other, collectAliases(substitutedFields))

- case Filter(condition, child) if condition.deterministic =>
+ case FilterOperation(condition, child) =>
val (fields, filters, other, aliases) = collectProjectsAndFilters(child)
val substitutedCondition = substitute(aliases)(condition)
(fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)
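To see what the new extractor changes in PhysicalOperation's collection, consider a nondeterministic filter sitting directly on a leaf relation. The old guard (condition.deterministic) rejected it; FilterOperation also accepts it because the child is a LeafNode, so projection and filter collection can still run for such plans. The sketch below is illustrative only, assuming FilterOperation lives alongside PhysicalOperation in org.apache.spark.sql.catalyst.planning and that the Catalyst test DSL is available:

// --- illustrative sketch, not part of the patch ---
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{LessThanOrEqual, Literal, Rand}
import org.apache.spark.sql.catalyst.planning.FilterOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation}

// A leaf relation with one column, filtered by a nondeterministic condition.
val relation = LocalRelation('a.int)
val plan = Filter(LessThanOrEqual(Rand(7L), Literal(0.5)), relation)

plan match {
  case FilterOperation(cond, child) =>
    // Matched even though cond.deterministic is false, because child is a LeafNode.
    println(s"collected: $cond over ${child.nodeName}")
  case _ =>
    // The previous pattern `Filter(condition, child) if condition.deterministic`
    // would have fallen through to here.
    println("not collected")
}
// --- end sketch ---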
@@ -175,6 +175,7 @@ abstract class PartitioningAwareFileIndex(
val index = partitionColumns.indexWhere(a.name == _.name)
BoundReference(index, partitionColumns(index).dataType, nullable = true)
})
boundPredicate.initialize(0)

val selected = partitions.filter {
case PartitionPath(values, _) => boundPredicate.eval(values)
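For context on the boundPredicate initialized here: the pushed-down partition filter is rewritten so each attribute becomes a BoundReference into the partition-values row, which is what the transform a few lines above the added call does. A rough sketch of that binding step, with every name and ordinal below chosen for illustration only:

// --- illustrative sketch, not part of the patch ---
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, EqualTo}
import org.apache.spark.sql.catalyst.expressions.{InterpretedPredicate, Literal}
import org.apache.spark.sql.types.IntegerType

// Assume a single partition column "a: int" at ordinal 0.
val a = AttributeReference("a", IntegerType, nullable = true)()
val pruningFilter = EqualTo(a, Literal(1))

// Replace the attribute with a positional reference into the partition row.
val bound = pruningFilter.transform {
  case _: AttributeReference => BoundReference(0, IntegerType, nullable = true)
}

val boundPredicate = InterpretedPredicate(bound)
boundPredicate.initialize(0)  // the call this patch adds before evaluating partitions
// --- end sketch ---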
@@ -2029,4 +2029,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)),
Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2)))
}

test("SPARK-21707: nondeterministic expressions correctly for filter predicates") {
withTempPath { path =>
val p = path.getAbsolutePath
Seq(1 -> "a").toDF("a", "b").write.partitionBy("a").parquet(p)
val df = spark.read.parquet(p)
checkAnswer(df.filter(rand(10) <= 1.0).select($"a"), Row(1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test can pass on current master.

}
}
}
@@ -570,4 +570,29 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
df1.groupBy("j").agg(max("k")))
}
}

test("SPARK-21707 read user need fields when the condition of filter is nondeterministic") {
withTable("bucketed_table") {
df1.write
.format("parquet")
.partitionBy("i")
.bucketBy(8, "j", "k")
.saveAsTable("bucketed_table")

val table = spark.table("bucketed_table").select($"i", $"j", $"k")
assert(table.queryExecution.sparkPlan.inputSet.toSeq.length == 3)

// the condition of filter is nondeterministic and no fields.
val table1 = spark.table("bucketed_table").where(rand(10) <= 0.5).select($"i")
assert(table1.queryExecution.sparkPlan.inputSet.toSeq.length == 1)
assert(table1.queryExecution.sparkPlan.inputSet != table.queryExecution.sparkPlan.inputSet)

// the condition of filter is nondeterministic and one fields.
val table2 = spark.table("bucketed_table")
.where(rand(10) <= 0.5 && $"j" > 1)
.select($"i")
assert(table2.queryExecution.sparkPlan.inputSet.toSeq.length == 2)
assert(table2.queryExecution.sparkPlan.inputSet != table.queryExecution.sparkPlan.inputSet)
}
}
}