diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 8f8d8010e26b..b827878f82fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -194,7 +194,7 @@ case class InMemoryTableScanExec(
   }
 
   // Returned filter predicate should return false iff it is impossible for the input expression
-  // to evaluate to `true' based on statistics collected about this partition batch.
+  // to evaluate to `true` based on statistics collected about this partition batch.
   @transient lazy val buildFilter: PartialFunction[Expression, Expression] = {
     case And(lhs: Expression, rhs: Expression)
       if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) =>
@@ -237,6 +237,34 @@ case class InMemoryTableScanExec(
       if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty =>
       list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
         l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
+
+    // As an example of how this works, imagine that the id column is stored as follows:
+    // __________________________________________
+    // | Partition ID | lowerBound | upperBound |
+    // |--------------|------------|------------|
+    // | p1           | '1'        | '9'        |
+    // | p2           | '10'       | '19'       |
+    // | p3           | '20'       | '29'       |
+    // | p4           | '30'       | '39'       |
+    // | p5           | '40'       | '49'       |
+    // |______________|____________|____________|
+    //
+    // Given the filter df.filter($"id".startsWith("2")),
+    // substr is applied to the lowerBound and upperBound:
+    // ________________________________________________________________________________________
+    // | Partition ID | lowerBound.substr(0, Length("2")) | upperBound.substr(0, Length("2")) |
+    // |--------------|-----------------------------------|-----------------------------------|
+    // | p1           | '1'                               | '9'                               |
+    // | p2           | '1'                               | '1'                               |
+    // | p3           | '2'                               | '2'                               |
+    // | p4           | '3'                               | '3'                               |
+    // | p5           | '4'                               | '4'                               |
+    // |______________|___________________________________|___________________________________|
+    //
+    // We can see that only p1 and p3 need to be read.
+    case StartsWith(a: AttributeReference, ExtractableLiteral(l)) =>
+      statsFor(a).lowerBound.substr(0, Length(l)) <= l &&
+        l <= statsFor(a).upperBound.substr(0, Length(l))
   }
 
   lazy val partitionFilters: Seq[Expression] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index af493e93b519..b3a5c687f775 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -170,6 +170,15 @@ class PartitionBatchPruningSuite
     }
   }
 
+  // Support `StartsWith` predicate
+  checkBatchPruning("SELECT CAST(s AS INT) FROM pruningStringData WHERE s like '18%'", 1, 1)(
+    180 to 189
+  )
+  checkBatchPruning("SELECT CAST(s AS INT) FROM pruningStringData WHERE s like '%'", 5, 11)(
+    100 to 200
+  )
+  checkBatchPruning("SELECT CAST(s AS INT) FROM pruningStringData WHERE '18%' like s", 5, 11)(Seq())
+
   // With disable IN_MEMORY_PARTITION_PRUNING option
   test("disable IN_MEMORY_PARTITION_PRUNING") {
     spark.conf.set(SQLConf.IN_MEMORY_PARTITION_PRUNING.key, false)
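
Note (illustration only, not part of the patch): the standalone Scala sketch below mirrors the bound check introduced by the new `StartsWith` case in `buildFilter`, using plain `String` values in place of Spark's internal partition-statistics expressions (`statsFor`, `substr`, `Length`). The object name `StartsWithPruningSketch` and the helper `mightContainPrefix` are hypothetical names for this sketch, not APIs from the codebase.

object StartsWithPruningSketch {
  // A batch may contain a value starting with `prefix` only if truncating its string
  // lowerBound and upperBound to the prefix length brackets the prefix lexicographically.
  // This mirrors the patch's predicate:
  //   statsFor(a).lowerBound.substr(0, Length(l)) <= l &&
  //     l <= statsFor(a).upperBound.substr(0, Length(l))
  def mightContainPrefix(lowerBound: String, upperBound: String, prefix: String): Boolean = {
    val lo = lowerBound.take(prefix.length)
    val hi = upperBound.take(prefix.length)
    lo.compareTo(prefix) <= 0 && prefix.compareTo(hi) <= 0
  }

  def main(args: Array[String]): Unit = {
    // The five batches from the comment table in the patch: (partition id, (lowerBound, upperBound)).
    val batches = Seq(
      "p1" -> ("1", "9"),
      "p2" -> ("10", "19"),
      "p3" -> ("20", "29"),
      "p4" -> ("30", "39"),
      "p5" -> ("40", "49"))
    // For the filter df.filter($"id".startsWith("2")), only p1 and p3 survive pruning.
    val kept = batches.collect { case (id, (lo, hi)) if mightContainPrefix(lo, hi, "2") => id }
    println(kept.mkString(", "))  // prints: p1, p3
  }
}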