Skip to content

Commit 2e2b1ae

Browse files
committed
[SPARK-39784][SQL] Put Literal values on the right side of the data source filter after translating Catalyst Expression to data source filter
### What changes were proposed in this pull request? Even though the literal value could be on either side of the filter, e.g. both `a > 1` and `1 < a` are valid, after translating a Catalyst Expression to a data source filter, we want the literal value on the right side so it's easier for the data source to handle these filters. We do this kind of normalization for V1 Filter. We should have the same behavior for V2 Filter. Before this PR, for the filters that have literal values on the left side, e.g. `1 > a`, we keep it as is. After this PR, we will normalize it to `a < 1` so the data source doesn't need to check each of the filters (and do the flip). ### Why are the changes needed? I think we should follow V1 Filter's behavior and normalize the filters at Catalyst Expression to DS Filter translation time so that the literal values are on the right side; later on, the data source doesn't need to check every single filter to figure out if it needs to flip the sides. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? new test Closes #37197 from huaxingao/flip. Authored-by: huaxingao <[email protected]> Signed-off-by: huaxingao <[email protected]>
1 parent 5e6aab4 commit 2e2b1ae

File tree

2 files changed

+86
-2
lines changed

2 files changed

+86
-2
lines changed

sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
233233
val r = generateExpression(b.right)
234234
if (l.isDefined && r.isDefined) {
235235
b match {
236+
case _: Predicate if isBinaryComparisonOperator(b.sqlOperator) &&
237+
l.get.isInstanceOf[LiteralValue[_]] && r.get.isInstanceOf[FieldReference] =>
238+
Some(new V2Predicate(flipComparisonOperatorName(b.sqlOperator),
239+
Array[V2Expression](r.get, l.get)))
236240
case _: Predicate =>
237241
Some(new V2Predicate(b.sqlOperator, Array[V2Expression](l.get, r.get)))
238242
case _ =>
@@ -408,6 +412,23 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
408412
}
409413
case _ => None
410414
}
415+
416+
/**
 * Tells whether `operatorName` is one of the binary comparison operators
 * `>`, `<`, `>=`, `<=`, `=` or `<=>`.
 *
 * @param operatorName the SQL operator name to classify
 * @return `true` for the six operator names listed above, `false` for any other string
 */
private def isBinaryComparisonOperator(operatorName: String): Boolean = {
417+
operatorName match {
418+
case ">" | "<" | ">=" | "<=" | "=" | "<=>" => true
419+
case _ => false
420+
}
421+
}
422+
423+
/**
 * Returns the operator name that keeps a comparison equivalent once its two
 * operands are swapped: `>` becomes `<`, `<` becomes `>`, `>=` becomes `<=`
 * and `<=` becomes `>=`.
 *
 * Any other name is returned unchanged. That includes `=` and `<=>`, which
 * are symmetric and therefore need no flipping.
 *
 * @param operatorName the operator name of the original (unswapped) comparison
 * @return the operator name to use with the operands in swapped order
 */
private def flipComparisonOperatorName(operatorName: String): String = {
424+
operatorName match {
425+
case ">" => "<"
426+
case "<" => ">"
427+
case ">=" => "<="
428+
case "<=" => ">="
429+
case _ => operatorName
430+
}
431+
}
411432
}
412433

413434
object ColumnOrField {

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,77 @@
1818
package org.apache.spark.sql.execution.datasources.v2
1919

2020
import org.apache.spark.sql.catalyst.dsl.expressions._
21-
import org.apache.spark.sql.catalyst.expressions.Expression
21+
import org.apache.spark.sql.catalyst.expressions._
2222
import org.apache.spark.sql.catalyst.plans.PlanTest
2323
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
2424
import org.apache.spark.sql.connector.expressions.filter.Predicate
2525
import org.apache.spark.sql.test.SharedSparkSession
26-
import org.apache.spark.sql.types.BooleanType
26+
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructField, StructType}
2727

2828
class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
29+
// Fixture table: each entry pairs a Catalyst expression (a plain int attribute
// or a GetStructField chain into a struct attribute) with a column-name string.
// The name strings back-quote any name part that itself contains a dot, as the
// per-entry comments in the second Seq describe.
val attrInts = Seq(
30+
$"cint".int,
31+
$"c.int".int,
32+
GetStructField($"a".struct(StructType(
33+
StructField("cstr", StringType, nullable = true) ::
34+
StructField("cint", IntegerType, nullable = true) :: Nil)), 1, None),
35+
GetStructField($"a".struct(StructType(
36+
StructField("c.int", IntegerType, nullable = true) ::
37+
StructField("cstr", StringType, nullable = true) :: Nil)), 0, None),
38+
GetStructField($"a.b".struct(StructType(
39+
StructField("cstr1", StringType, nullable = true) ::
40+
StructField("cstr2", StringType, nullable = true) ::
41+
StructField("cint", IntegerType, nullable = true) :: Nil)), 2, None),
42+
GetStructField($"a.b".struct(StructType(
43+
StructField("c.int", IntegerType, nullable = true) :: Nil)), 0, None),
44+
GetStructField(GetStructField($"a".struct(StructType(
45+
StructField("cstr1", StringType, nullable = true) ::
46+
StructField("b", StructType(StructField("cint", IntegerType, nullable = true) ::
47+
StructField("cstr2", StringType, nullable = true) :: Nil)) :: Nil)), 1, None), 0, None)
48+
).zip(Seq(
49+
"cint",
50+
"`c.int`", // single level field that contains `dot` in name
51+
"a.cint", // two level nested field
52+
"a.`c.int`", // two level nested field, and nested level contains `dot`
53+
"`a.b`.cint", // two level nested field, and top level contains `dot`
54+
"`a.b`.`c.int`", // two level nested field, and both levels contain `dot`
55+
"a.b.cint" // three level nested field
56+
))
57+
58+
// For every (expression, column name) pair in attrInts, and for each of
// EqualTo, EqualNullSafe, GreaterThan, LessThan, GreaterThanOrEqual and
// LessThanOrEqual, two cases are checked: column-vs-literal and literal-vs-column.
// The expected Predicate always has the FieldReference first and the
// LiteralValue second. When the literal is the left operand of the Catalyst
// expression, the expected operator is the mirrored one (">" <-> "<",
// ">=" <-> "<="), while "=" and "<=>" stay the same.
// NOTE(review): testTranslateFilter is defined outside this view; presumably it
// translates the Catalyst expression and compares the result with the expected
// value - confirm against its definition.
test("SPARK-39784: translate binary expression") { attrInts
59+
.foreach { case (attrInt, intColName) =>
60+
testTranslateFilter(EqualTo(attrInt, 1),
61+
Some(new Predicate("=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
62+
testTranslateFilter(EqualTo(1, attrInt),
63+
Some(new Predicate("=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
64+
65+
testTranslateFilter(EqualNullSafe(attrInt, 1),
66+
Some(new Predicate("<=>", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
67+
testTranslateFilter(EqualNullSafe(1, attrInt),
68+
Some(new Predicate("<=>", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
69+
70+
testTranslateFilter(GreaterThan(attrInt, 1),
71+
Some(new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
72+
testTranslateFilter(GreaterThan(1, attrInt),
73+
Some(new Predicate("<", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
74+
75+
testTranslateFilter(LessThan(attrInt, 1),
76+
Some(new Predicate("<", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
77+
testTranslateFilter(LessThan(1, attrInt),
78+
Some(new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
79+
80+
testTranslateFilter(GreaterThanOrEqual(attrInt, 1),
81+
Some(new Predicate(">=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
82+
testTranslateFilter(GreaterThanOrEqual(1, attrInt),
83+
Some(new Predicate("<=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
84+
85+
testTranslateFilter(LessThanOrEqual(attrInt, 1),
86+
Some(new Predicate("<=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
87+
testTranslateFilter(LessThanOrEqual(1, attrInt),
88+
Some(new Predicate(">=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
89+
}
90+
}
91+
2992
test("SPARK-36644: Push down boolean column filter") {
3093
testTranslateFilter($"col".boolean,
3194
Some(new Predicate("=", Array(FieldReference("col"), LiteralValue(true, BooleanType)))))

0 commit comments

Comments
 (0)