Skip to content

Commit e439e29

Browse files
wForget authored and cloud-fan committed
[SPARK-47463][SQL][3.5] Use V2Predicate to wrap expression with return type of boolean
Backports #45589 to 3.5.

### What changes were proposed in this pull request?
Use `V2Predicate` to wrap the `If` expression when building v2 expressions.

### Why are the changes needed?
The `PushFoldableIntoBranches` optimizer rule may fold a predicate into (if / case) branches, and `V2ExpressionBuilder` wraps `If` as a `GeneralScalarExpression`, which causes the assertion in `PushablePredicate.unapply` to fail.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added a unit test.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #46074 from wForget/SPARK-47463_3.5.

Authored-by: Zhen Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 7aea21e commit e439e29

File tree

2 files changed

+97
-72
lines changed

2 files changed

+97
-72
lines changed

sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 87 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression =>
2323
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
2424
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
2525
import org.apache.spark.sql.execution.datasources.PushableExpression
26-
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType}
26+
import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, StringType}
2727

2828
/**
2929
* The builder to generate V2 expressions from catalyst expressions.
@@ -96,45 +96,45 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
9696
generateExpression(child).map(v => new V2Cast(v, dataType))
9797
case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) =>
9898
generateAggregateFunc(aggregateFunction, isDistinct)
99-
case Abs(child, true) => generateExpressionWithName("ABS", Seq(child))
100-
case Coalesce(children) => generateExpressionWithName("COALESCE", children)
101-
case Greatest(children) => generateExpressionWithName("GREATEST", children)
102-
case Least(children) => generateExpressionWithName("LEAST", children)
103-
case Rand(child, hideSeed) =>
99+
case Abs(_, true) => generateExpressionWithName("ABS", expr, isPredicate)
100+
case _: Coalesce => generateExpressionWithName("COALESCE", expr, isPredicate)
101+
case _: Greatest => generateExpressionWithName("GREATEST", expr, isPredicate)
102+
case _: Least => generateExpressionWithName("LEAST", expr, isPredicate)
103+
case Rand(_, hideSeed) =>
104104
if (hideSeed) {
105105
Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression]))
106106
} else {
107-
generateExpressionWithName("RAND", Seq(child))
107+
generateExpressionWithName("RAND", expr, isPredicate)
108108
}
109-
case log: Logarithm => generateExpressionWithName("LOG", log.children)
110-
case Log10(child) => generateExpressionWithName("LOG10", Seq(child))
111-
case Log2(child) => generateExpressionWithName("LOG2", Seq(child))
112-
case Log(child) => generateExpressionWithName("LN", Seq(child))
113-
case Exp(child) => generateExpressionWithName("EXP", Seq(child))
114-
case pow: Pow => generateExpressionWithName("POWER", pow.children)
115-
case Sqrt(child) => generateExpressionWithName("SQRT", Seq(child))
116-
case Floor(child) => generateExpressionWithName("FLOOR", Seq(child))
117-
case Ceil(child) => generateExpressionWithName("CEIL", Seq(child))
118-
case round: Round => generateExpressionWithName("ROUND", round.children)
119-
case Sin(child) => generateExpressionWithName("SIN", Seq(child))
120-
case Sinh(child) => generateExpressionWithName("SINH", Seq(child))
121-
case Cos(child) => generateExpressionWithName("COS", Seq(child))
122-
case Cosh(child) => generateExpressionWithName("COSH", Seq(child))
123-
case Tan(child) => generateExpressionWithName("TAN", Seq(child))
124-
case Tanh(child) => generateExpressionWithName("TANH", Seq(child))
125-
case Cot(child) => generateExpressionWithName("COT", Seq(child))
126-
case Asin(child) => generateExpressionWithName("ASIN", Seq(child))
127-
case Asinh(child) => generateExpressionWithName("ASINH", Seq(child))
128-
case Acos(child) => generateExpressionWithName("ACOS", Seq(child))
129-
case Acosh(child) => generateExpressionWithName("ACOSH", Seq(child))
130-
case Atan(child) => generateExpressionWithName("ATAN", Seq(child))
131-
case Atanh(child) => generateExpressionWithName("ATANH", Seq(child))
132-
case atan2: Atan2 => generateExpressionWithName("ATAN2", atan2.children)
133-
case Cbrt(child) => generateExpressionWithName("CBRT", Seq(child))
134-
case ToDegrees(child) => generateExpressionWithName("DEGREES", Seq(child))
135-
case ToRadians(child) => generateExpressionWithName("RADIANS", Seq(child))
136-
case Signum(child) => generateExpressionWithName("SIGN", Seq(child))
137-
case wb: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", wb.children)
109+
case _: Logarithm => generateExpressionWithName("LOG", expr, isPredicate)
110+
case _: Log10 => generateExpressionWithName("LOG10", expr, isPredicate)
111+
case _: Log2 => generateExpressionWithName("LOG2", expr, isPredicate)
112+
case _: Log => generateExpressionWithName("LN", expr, isPredicate)
113+
case _: Exp => generateExpressionWithName("EXP", expr, isPredicate)
114+
case _: Pow => generateExpressionWithName("POWER", expr, isPredicate)
115+
case _: Sqrt => generateExpressionWithName("SQRT", expr, isPredicate)
116+
case _: Floor => generateExpressionWithName("FLOOR", expr, isPredicate)
117+
case _: Ceil => generateExpressionWithName("CEIL", expr, isPredicate)
118+
case _: Round => generateExpressionWithName("ROUND", expr, isPredicate)
119+
case _: Sin => generateExpressionWithName("SIN", expr, isPredicate)
120+
case _: Sinh => generateExpressionWithName("SINH", expr, isPredicate)
121+
case _: Cos => generateExpressionWithName("COS", expr, isPredicate)
122+
case _: Cosh => generateExpressionWithName("COSH", expr, isPredicate)
123+
case _: Tan => generateExpressionWithName("TAN", expr, isPredicate)
124+
case _: Tanh => generateExpressionWithName("TANH", expr, isPredicate)
125+
case _: Cot => generateExpressionWithName("COT", expr, isPredicate)
126+
case _: Asin => generateExpressionWithName("ASIN", expr, isPredicate)
127+
case _: Asinh => generateExpressionWithName("ASINH", expr, isPredicate)
128+
case _: Acos => generateExpressionWithName("ACOS", expr, isPredicate)
129+
case _: Acosh => generateExpressionWithName("ACOSH", expr, isPredicate)
130+
case _: Atan => generateExpressionWithName("ATAN", expr, isPredicate)
131+
case _: Atanh => generateExpressionWithName("ATANH", expr, isPredicate)
132+
case _: Atan2 => generateExpressionWithName("ATAN2", expr, isPredicate)
133+
case _: Cbrt => generateExpressionWithName("CBRT", expr, isPredicate)
134+
case _: ToDegrees => generateExpressionWithName("DEGREES", expr, isPredicate)
135+
case _: ToRadians => generateExpressionWithName("RADIANS", expr, isPredicate)
136+
case _: Signum => generateExpressionWithName("SIGN", expr, isPredicate)
137+
case _: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", expr, isPredicate)
138138
case and: And =>
139139
// AND expects predicate
140140
val l = generateExpression(and.left, true)
@@ -185,57 +185,56 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
185185
assert(v.isInstanceOf[V2Predicate])
186186
new V2Not(v.asInstanceOf[V2Predicate])
187187
}
188-
case UnaryMinus(child, true) => generateExpressionWithName("-", Seq(child))
189-
case BitwiseNot(child) => generateExpressionWithName("~", Seq(child))
190-
case CaseWhen(branches, elseValue) =>
188+
case UnaryMinus(_, true) => generateExpressionWithName("-", expr, isPredicate)
189+
case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate)
190+
case caseWhen @ CaseWhen(branches, elseValue) =>
191191
val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
192-
val values = branches.map(_._2).flatMap(generateExpression(_, true))
193-
if (conditions.length == branches.length && values.length == branches.length) {
192+
val values = branches.map(_._2).flatMap(generateExpression(_))
193+
val elseExprOpt = elseValue.flatMap(generateExpression(_))
194+
if (conditions.length == branches.length && values.length == branches.length &&
195+
elseExprOpt.size == elseValue.size) {
194196
val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
195197
Seq[V2Expression](c, v)
196198
}
197-
if (elseValue.isDefined) {
198-
elseValue.flatMap(generateExpression(_)).map { v =>
199-
val children = (branchExpressions :+ v).toArray[V2Expression]
200-
// The children looks like [condition1, value1, ..., conditionN, valueN, elseValue]
201-
new V2Predicate("CASE_WHEN", children)
202-
}
199+
val children = (branchExpressions ++ elseExprOpt).toArray[V2Expression]
200+
// The children looks like [condition1, value1, ..., conditionN, valueN (, elseValue)]
201+
if (isPredicate && caseWhen.dataType.isInstanceOf[BooleanType]) {
202+
Some(new V2Predicate("CASE_WHEN", children))
203203
} else {
204-
// The children looks like [condition1, value1, ..., conditionN, valueN]
205-
Some(new V2Predicate("CASE_WHEN", branchExpressions.toArray[V2Expression]))
204+
Some(new GeneralScalarExpression("CASE_WHEN", children))
206205
}
207206
} else {
208207
None
209208
}
210-
case iff: If => generateExpressionWithName("CASE_WHEN", iff.children)
209+
case _: If => generateExpressionWithName("CASE_WHEN", expr, isPredicate)
211210
case substring: Substring =>
212211
val children = if (substring.len == Literal(Integer.MAX_VALUE)) {
213212
Seq(substring.str, substring.pos)
214213
} else {
215214
substring.children
216215
}
217-
generateExpressionWithName("SUBSTRING", children)
218-
case Upper(child) => generateExpressionWithName("UPPER", Seq(child))
219-
case Lower(child) => generateExpressionWithName("LOWER", Seq(child))
216+
generateExpressionWithNameByChildren("SUBSTRING", children, substring.dataType, isPredicate)
217+
case _: Upper => generateExpressionWithName("UPPER", expr, isPredicate)
218+
case _: Lower => generateExpressionWithName("LOWER", expr, isPredicate)
220219
case BitLength(child) if child.dataType.isInstanceOf[StringType] =>
221-
generateExpressionWithName("BIT_LENGTH", Seq(child))
220+
generateExpressionWithName("BIT_LENGTH", expr, isPredicate)
222221
case Length(child) if child.dataType.isInstanceOf[StringType] =>
223-
generateExpressionWithName("CHAR_LENGTH", Seq(child))
224-
case concat: Concat => generateExpressionWithName("CONCAT", concat.children)
225-
case translate: StringTranslate => generateExpressionWithName("TRANSLATE", translate.children)
226-
case trim: StringTrim => generateExpressionWithName("TRIM", trim.children)
227-
case trim: StringTrimLeft => generateExpressionWithName("LTRIM", trim.children)
228-
case trim: StringTrimRight => generateExpressionWithName("RTRIM", trim.children)
222+
generateExpressionWithName("CHAR_LENGTH", expr, isPredicate)
223+
case _: Concat => generateExpressionWithName("CONCAT", expr, isPredicate)
224+
case _: StringTranslate => generateExpressionWithName("TRANSLATE", expr, isPredicate)
225+
case _: StringTrim => generateExpressionWithName("TRIM", expr, isPredicate)
226+
case _: StringTrimLeft => generateExpressionWithName("LTRIM", expr, isPredicate)
227+
case _: StringTrimRight => generateExpressionWithName("RTRIM", expr, isPredicate)
229228
case overlay: Overlay =>
230229
val children = if (overlay.len == Literal(-1)) {
231230
Seq(overlay.input, overlay.replace, overlay.pos)
232231
} else {
233232
overlay.children
234233
}
235-
generateExpressionWithName("OVERLAY", children)
236-
case date: DateAdd => generateExpressionWithName("DATE_ADD", date.children)
237-
case date: DateDiff => generateExpressionWithName("DATE_DIFF", date.children)
238-
case date: TruncDate => generateExpressionWithName("TRUNC", date.children)
234+
generateExpressionWithNameByChildren("OVERLAY", children, overlay.dataType, isPredicate)
235+
case _: DateAdd => generateExpressionWithName("DATE_ADD", expr, isPredicate)
236+
case _: DateDiff => generateExpressionWithName("DATE_DIFF", expr, isPredicate)
237+
case _: TruncDate => generateExpressionWithName("TRUNC", expr, isPredicate)
239238
case Second(child, _) =>
240239
generateExpression(child).map(v => new V2Extract("SECOND", v))
241240
case Minute(child, _) =>
@@ -268,12 +267,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
268267
generateExpression(child).map(v => new V2Extract("WEEK", v))
269268
case YearOfWeek(child) =>
270269
generateExpression(child).map(v => new V2Extract("YEAR_OF_WEEK", v))
271-
case encrypt: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", encrypt.children)
272-
case decrypt: AesDecrypt => generateExpressionWithName("AES_DECRYPT", decrypt.children)
273-
case Crc32(child) => generateExpressionWithName("CRC32", Seq(child))
274-
case Md5(child) => generateExpressionWithName("MD5", Seq(child))
275-
case Sha1(child) => generateExpressionWithName("SHA1", Seq(child))
276-
case sha2: Sha2 => generateExpressionWithName("SHA2", sha2.children)
270+
case _: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", expr, isPredicate)
271+
case _: AesDecrypt => generateExpressionWithName("AES_DECRYPT", expr, isPredicate)
272+
case _: Crc32 => generateExpressionWithName("CRC32", expr, isPredicate)
273+
case _: Md5 => generateExpressionWithName("MD5", expr, isPredicate)
274+
case _: Sha1 => generateExpressionWithName("SHA1", expr, isPredicate)
275+
case _: Sha2 => generateExpressionWithName("SHA2", expr, isPredicate)
277276
// TODO supports other expressions
278277
case ApplyFunctionExpression(function, children) =>
279278
val childrenExpressions = children.flatMap(generateExpression(_))
@@ -345,10 +344,26 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
345344
}
346345

347346
private def generateExpressionWithName(
348-
v2ExpressionName: String, children: Seq[Expression]): Option[V2Expression] = {
347+
v2ExpressionName: String,
348+
expr: Expression,
349+
isPredicate: Boolean): Option[V2Expression] = {
350+
generateExpressionWithNameByChildren(
351+
v2ExpressionName, expr.children, expr.dataType, isPredicate)
352+
}
353+
354+
private def generateExpressionWithNameByChildren(
355+
v2ExpressionName: String,
356+
children: Seq[Expression],
357+
dataType: DataType,
358+
isPredicate: Boolean): Option[V2Expression] = {
349359
val childrenExpressions = children.flatMap(generateExpression(_))
350360
if (childrenExpressions.length == children.length) {
351-
Some(new GeneralScalarExpression(v2ExpressionName, childrenExpressions.toArray[V2Expression]))
361+
if (isPredicate && dataType.isInstanceOf[BooleanType]) {
362+
Some(new V2Predicate(v2ExpressionName, childrenExpressions.toArray[V2Expression]))
363+
} else {
364+
Some(new GeneralScalarExpression(
365+
v2ExpressionName, childrenExpressions.toArray[V2Expression]))
366+
}
352367
} else {
353368
None
354369
}

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,16 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
626626
}
627627
}
628628
}
629+
630+
test("SPARK-47463: Pushed down v2 filter with if expression") {
631+
withTempView("t1") {
632+
spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load()
633+
.createTempView("t1")
634+
val df = sql("SELECT * FROM t1 WHERE if(i = 1, i, 0) > 0")
635+
val result = df.collect()
636+
assert(result.length == 1)
637+
}
638+
}
629639
}
630640

631641

0 commit comments

Comments
 (0)