diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 4c4ec000d0930..a337bd6d98b25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -66,13 +66,21 @@ trait ConstraintHelper { val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull]) predicates.foreach { case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = predicates - eq - inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) - inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) - case eq @ EqualTo(l @ Cast(_: Attribute, _, _), r: Attribute) => - inferredConstraints ++= replaceConstraints(predicates - eq, r, l) - case eq @ EqualTo(l: Attribute, r @ Cast(_: Attribute, _, _)) => - inferredConstraints ++= replaceConstraints(predicates - eq, l, r) + val candidates = predicates - eq + inferredConstraints ++= replaceConstraints(candidates, l, r) + inferredConstraints ++= replaceConstraints(candidates, r, l) + case eq @ EqualTo(l @ Cast(la: Attribute, _, tz), r: Attribute) => + val candidates = predicates - eq + inferredConstraints ++= replaceConstraints(candidates, la, Cast(r, la.dataType, tz)) + inferredConstraints ++= replaceConstraints(candidates, r, l) + case eq @ EqualTo(l: Attribute, r @ Cast(ra: Attribute, _, tz)) => + val candidates = predicates - eq + inferredConstraints ++= replaceConstraints(candidates, l, r) + inferredConstraints ++= replaceConstraints(candidates, ra, Cast(l, ra.dataType, tz)) + case eq @ EqualTo(Cast(la: Attribute, _, ltz), Cast(ra: Attribute, _, rtz)) => + val candidates = predicates - eq + inferredConstraints ++= replaceConstraints(candidates, la, Cast(ra, la.dataType, ltz)) + inferredConstraints ++= replaceConstraints(candidates, ra, Cast(la, ra.dataType, rtz)) case _ => // No inference } inferredConstraints -- constraints @@ -83,6 +91,8 @@ trait ConstraintHelper { source: Expression, destination: Expression): Set[Expression] = constraints.map(_ transform { case e: Expression if e.semanticEquals(source) => destination + }).map(_ transform { + case Cast(cast @ Cast(e, _, _), dt, _) if cast == destination && dt == e.dataType => e }) /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 79bd573f1d84a..ef24bbd08529f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegerType, LongType} +import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType, ShortType} class InferFiltersFromConstraintsSuite extends PlanTest { @@ -36,6 +36,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { InferFiltersFromConstraints, CombineFilters, SimplifyBinaryComparison, + SimplifyCasts, BooleanSimplification, PruneFilters) :: Nil } @@ -265,7 +266,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } - test("Constraints should be inferred from cast equality constraint(filter higher data type)") { + test("Constraints inferred from cast equality constraint(filter higher data type)") { val testRelation1 = LocalRelation('a.int) val testRelation2 = LocalRelation('b.long) val originalLeft = testRelation1.subquery('left) @@ -275,45 +276,66 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val right = testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right) Seq(Some("left.a".attr.cast(LongType) === "right.b".attr), - Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition => - testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition) - } - - Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)), + Some("right.b".attr === "left.a".attr.cast(LongType)), + Some("left.a".attr === "right.b".attr.cast(IntegerType)), Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition => - testConstraintsAfterJoin( - originalLeft, - originalRight, - testRelation1.where(IsNotNull('a)).subquery('left), - right, - Inner, - condition) + testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition) } } - test("Constraints shouldn't be inferred from cast equality constraint(filter lower data type)") { + test("Constraints inferred from cast equality constraint(filter lower data type)") { val testRelation1 = LocalRelation('a.int) val testRelation2 = LocalRelation('b.long) val originalLeft = testRelation1.where('a === 1).subquery('left) val originalRight = testRelation2.subquery('right) val left = testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left) - val right = testRelation2.where(IsNotNull('b)).subquery('right) + val right = testRelation2.where(IsNotNull('b) && 'b.cast(IntegerType) === 1).subquery('right) Seq(Some("left.a".attr.cast(LongType) === "right.b".attr), - Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition => + Some("right.b".attr === "left.a".attr.cast(LongType)), + Some("left.a".attr === "right.b".attr.cast(IntegerType)), + Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition => testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition) } + } - Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)), - Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition => - testConstraintsAfterJoin( - originalLeft, - originalRight, - left, - testRelation2.where(IsNotNull('b) && 'b.attr.cast(IntegerType) === 1).subquery('right), - Inner, - condition) + test("Constraints inferred from cast equality constraint(filter decimal type)") { + val testRelation1 = LocalRelation('a.decimal(18, 0)) + val testRelation2 = LocalRelation('b.long) + val originalLeft = testRelation1.where('a === 1).subquery('left) + val originalRight = testRelation2.subquery('right) + + val decimalOne = Literal(1).cast(DecimalType(1, 0)).cast(DecimalType(18, 0)) + + val left = testRelation1.where(IsNotNull('a) && 'a === decimalOne).subquery('left) + val right = testRelation2.where(IsNotNull('b) + && 'b.cast(DecimalType(18, 0)) === decimalOne).subquery('right) + + Seq(Some("left.a".attr.cast(DecimalType(20, 0)) === "right.b".attr.cast(DecimalType(20, 0))), + Some("right.b".attr.cast(DecimalType(20, 0)) === "left.a".attr.cast(DecimalType(20, 0))), + Some("right.b".attr.cast(IntegerType) === "left.a".attr.cast(IntegerType)), + Some("left.a".attr === "right.b".attr.cast(DecimalType(18, 0))), + Some("left.a".attr.cast(LongType) === "right.b".attr)).foreach { condition => + testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition) + } + } + + test("Constraints inferred from cast equality constraint(filter bigint type)") { + val testRelation1 = LocalRelation('a.decimal(18, 0)) + val testRelation2 = LocalRelation('b.long) + val originalLeft = testRelation1.subquery('left) + val originalRight = testRelation2.where('b === 1L).subquery('right) + + val left = testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left) + val right = testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right) + + Seq(Some("left.a".attr.cast(DecimalType(20, 0)) === "right.b".attr.cast(DecimalType(20, 0))), + Some("right.b".attr.cast(DecimalType(20, 0)) === "left.a".attr.cast(DecimalType(20, 0))), + Some("right.b".attr.cast(IntegerType) === "left.a".attr.cast(IntegerType)), + Some("left.a".attr === "right.b".attr.cast(DecimalType(18, 0))), + Some("left.a".attr.cast(LongType) === "right.b".attr)).foreach { condition => + testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition) } } }