Skip to content

Commit bf764a3

Browse files
cloud-fan authored and gatorsmile committed
[SPARK-22384][SQL][FOLLOWUP] Refine partition pruning when attribute is wrapped in Cast
## What changes were proposed in this pull request? As mentioned in #21586, `Cast.mayTruncate` is not 100% safe: for example, casting string to boolean is allowed. Since changing `Cast.mayTruncate` would also change the behavior of Dataset, here I propose to add a new `Cast.canSafeCast` for partition pruning. ## How was this patch tested? New test cases. Author: Wenchen Fan <[email protected]> Closes #21712 from cloud-fan/safeCast.
1 parent ca8243f commit bf764a3

File tree

3 files changed

+41
-4
lines changed

3 files changed

+41
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,26 @@ object Cast {
134134
toPrecedence > 0 && fromPrecedence > toPrecedence
135135
}
136136

137+
/**
138+
* Returns true iff we can safely cast the `from` type to `to` type without any truncation or
139+
* loss of precision, e.g. int -> long, date -> timestamp.
140+
*/
141+
def canSafeCast(from: AtomicType, to: AtomicType): Boolean = (from, to) match {
142+
// Identical types need no conversion at all.
case _ if from == to => true
143+
// Numeric -> decimal is accepted only when the target decimal reports itself wider than `from`.
case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true
144+
// Decimal -> numeric is accepted only when the source decimal reports itself tighter than `to`.
case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true
145+
// Widening along the numeric precedence order (see legalNumericPrecedence).
// NOTE(review): this check is purely order-based; if TypeCoercion.numericPrecedence lists
// integral types before fractional ones, e.g. long -> float would be accepted even though it
// can lose precision -- confirm this is intended.
case (from, to) if legalNumericPrecedence(from, to) => true
146+
case (DateType, TimestampType) => true
147+
// Any atomic type is accepted when the target is string.
case (_, StringType) => true
148+
// Every other combination is treated as unsafe.
case _ => false
149+
}
150+
151+
/**
 * Returns true when `from` is listed in `TypeCoercion.numericPrecedence` and `to` is listed
 * strictly after it.
 */
private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = {
152+
// indexOf yields -1 for a type that does not appear in the precedence list.
val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from)
153+
val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to)
154+
// `from` must be listed; requiring a strictly larger index for `to` also rules out an
// unlisted `to` (-1) and the from == to case.
fromPrecedence >= 0 && fromPrecedence < toPrecedence
155+
}
156+
137157
def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
138158
case (NullType, _) => true
139159
case (_, _) if from == to => false

sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException
4545
import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType}
4646
import org.apache.spark.sql.catalyst.expressions._
4747
import org.apache.spark.sql.internal.SQLConf
48-
import org.apache.spark.sql.types.{IntegralType, StringType}
48+
import org.apache.spark.sql.types.{AtomicType, IntegralType, StringType}
4949
import org.apache.spark.unsafe.types.UTF8String
5050
import org.apache.spark.util.Utils
5151

@@ -660,7 +660,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
660660
// Extracts the attribute underneath `expr`, recursing through Cast nodes that pass the
// cast-safety guard; returns None for any other expression.
def unapply(expr: Expression): Option[Attribute] = {
661661
expr match {
662662
case attr: Attribute => Some(attr)
663-
case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child)
663+
// NOTE(review): `child @ AtomicType()` presumably matches only children whose data type is
// atomic, which is what would make the asInstanceOf[AtomicType] below safe -- confirm.
case Cast(child @ AtomicType(), dt: AtomicType, _)
664+
if Cast.canSafeCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child)
664665
// Anything else (including casts that fail the guard) yields no attribute.
case _ => None
665666
}
666667
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfterAll
2424
import org.apache.spark.sql.catalyst.catalog._
2525
import org.apache.spark.sql.catalyst.dsl.expressions._
2626
import org.apache.spark.sql.catalyst.expressions._
27-
import org.apache.spark.sql.types.LongType
27+
import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType}
2828

2929
// TODO: Refactor this to `HivePartitionFilteringSuite`
3030
class HiveClientSuite(version: String)
@@ -122,6 +122,22 @@ class HiveClientSuite(version: String)
122122
"aa" :: Nil)
123123
}
124124

125+
// Per the test name, cast(chunk as int) must not be treated as a valid partition predicate.
// NOTE(review): the three value lists are presumably the expected ds, h and chunk partition
// values (here the full ranges, i.e. nothing pruned) -- confirm against
// testMetastorePartitionFiltering.
test("getPartitionsByFilter: cast(chunk as int)=1 (not a valid partition predicate)") {
126+
testMetastorePartitionFiltering(
127+
attr("chunk").cast(IntegerType) === 1,
128+
20170101 to 20170103,
129+
0 to 23,
130+
"aa" :: "ab" :: "ba" :: "bb" :: Nil)
131+
}
132+
133+
// Per the test name, cast(chunk as boolean) must not be treated as a valid partition predicate.
// NOTE(review): the three value lists are presumably the expected ds, h and chunk partition
// values (here the full ranges, i.e. nothing pruned) -- confirm against
// testMetastorePartitionFiltering.
test("getPartitionsByFilter: cast(chunk as boolean)=true (not a valid partition predicate)") {
134+
testMetastorePartitionFiltering(
135+
attr("chunk").cast(BooleanType) === true,
136+
20170101 to 20170103,
137+
0 to 23,
138+
"aa" :: "ab" :: "ba" :: "bb" :: Nil)
139+
}
140+
125141
test("getPartitionsByFilter: 20170101=ds") {
126142
testMetastorePartitionFiltering(
127143
Literal(20170101) === attr("ds"),
@@ -138,7 +154,7 @@ class HiveClientSuite(version: String)
138154
"aa" :: "ab" :: "ba" :: "bb" :: Nil)
139155
}
140156

141-
test("getPartitionsByFilter: chunk in cast(ds as long)=20170101L") {
157+
test("getPartitionsByFilter: cast(ds as long)=20170101L and h=10") {
142158
testMetastorePartitionFiltering(
143159
attr("ds").cast(LongType) === 20170101L && attr("h") === 10,
144160
20170101 to 20170101,

0 commit comments

Comments
 (0)