@@ -23,7 +23,7 @@ import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression =>
2323import org .apache .spark .sql .connector .expressions .aggregate .{AggregateFunc , Avg , Count , CountStar , GeneralAggregateFunc , Max , Min , Sum , UserDefinedAggregateFunc }
2424import org .apache .spark .sql .connector .expressions .filter .{AlwaysFalse , AlwaysTrue , And => V2And , Not => V2Not , Or => V2Or , Predicate => V2Predicate }
2525import org .apache .spark .sql .execution .datasources .PushableExpression
26- import org .apache .spark .sql .types .{BooleanType , IntegerType , StringType }
26+ import org .apache .spark .sql .types .{BooleanType , DataType , IntegerType , StringType }
2727
2828/**
2929 * The builder to generate V2 expressions from catalyst expressions.
@@ -96,45 +96,45 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
9696 generateExpression(child).map(v => new V2Cast (v, dataType))
9797 case AggregateExpression (aggregateFunction, Complete , isDistinct, None , _) =>
9898 generateAggregateFunc(aggregateFunction, isDistinct)
99- case Abs (child , true ) => generateExpressionWithName(" ABS" , Seq (child) )
100- case Coalesce (children) => generateExpressionWithName(" COALESCE" , children )
101- case Greatest (children) => generateExpressionWithName(" GREATEST" , children )
102- case Least (children) => generateExpressionWithName(" LEAST" , children )
103- case Rand (child , hideSeed) =>
99+ case Abs (_ , true ) => generateExpressionWithName(" ABS" , expr, isPredicate )
100+ case _ : Coalesce => generateExpressionWithName(" COALESCE" , expr, isPredicate )
101+ case _ : Greatest => generateExpressionWithName(" GREATEST" , expr, isPredicate )
102+ case _ : Least => generateExpressionWithName(" LEAST" , expr, isPredicate )
103+ case Rand (_ , hideSeed) =>
104104 if (hideSeed) {
105105 Some (new GeneralScalarExpression (" RAND" , Array .empty[V2Expression ]))
106106 } else {
107- generateExpressionWithName(" RAND" , Seq (child) )
107+ generateExpressionWithName(" RAND" , expr, isPredicate )
108108 }
109- case log : Logarithm => generateExpressionWithName(" LOG" , log.children )
110- case Log10 (child) => generateExpressionWithName(" LOG10" , Seq (child) )
111- case Log2 (child) => generateExpressionWithName(" LOG2" , Seq (child) )
112- case Log (child) => generateExpressionWithName(" LN" , Seq (child) )
113- case Exp (child) => generateExpressionWithName(" EXP" , Seq (child) )
114- case pow : Pow => generateExpressionWithName(" POWER" , pow.children )
115- case Sqrt (child) => generateExpressionWithName(" SQRT" , Seq (child) )
116- case Floor (child) => generateExpressionWithName(" FLOOR" , Seq (child) )
117- case Ceil (child) => generateExpressionWithName(" CEIL" , Seq (child) )
118- case round : Round => generateExpressionWithName(" ROUND" , round.children )
119- case Sin (child) => generateExpressionWithName(" SIN" , Seq (child) )
120- case Sinh (child) => generateExpressionWithName(" SINH" , Seq (child) )
121- case Cos (child) => generateExpressionWithName(" COS" , Seq (child) )
122- case Cosh (child) => generateExpressionWithName(" COSH" , Seq (child) )
123- case Tan (child) => generateExpressionWithName(" TAN" , Seq (child) )
124- case Tanh (child) => generateExpressionWithName(" TANH" , Seq (child) )
125- case Cot (child) => generateExpressionWithName(" COT" , Seq (child) )
126- case Asin (child) => generateExpressionWithName(" ASIN" , Seq (child) )
127- case Asinh (child) => generateExpressionWithName(" ASINH" , Seq (child) )
128- case Acos (child) => generateExpressionWithName(" ACOS" , Seq (child) )
129- case Acosh (child) => generateExpressionWithName(" ACOSH" , Seq (child) )
130- case Atan (child) => generateExpressionWithName(" ATAN" , Seq (child) )
131- case Atanh (child) => generateExpressionWithName(" ATANH" , Seq (child) )
132- case atan2 : Atan2 => generateExpressionWithName(" ATAN2" , atan2.children )
133- case Cbrt (child) => generateExpressionWithName(" CBRT" , Seq (child) )
134- case ToDegrees (child) => generateExpressionWithName(" DEGREES" , Seq (child) )
135- case ToRadians (child) => generateExpressionWithName(" RADIANS" , Seq (child) )
136- case Signum (child) => generateExpressionWithName(" SIGN" , Seq (child) )
137- case wb : WidthBucket => generateExpressionWithName(" WIDTH_BUCKET" , wb.children )
109+ case _ : Logarithm => generateExpressionWithName(" LOG" , expr, isPredicate )
110+ case _ : Log10 => generateExpressionWithName(" LOG10" , expr, isPredicate )
111+ case _ : Log2 => generateExpressionWithName(" LOG2" , expr, isPredicate )
112+ case _ : Log => generateExpressionWithName(" LN" , expr, isPredicate )
113+ case _ : Exp => generateExpressionWithName(" EXP" , expr, isPredicate )
114+ case _ : Pow => generateExpressionWithName(" POWER" , expr, isPredicate )
115+ case _ : Sqrt => generateExpressionWithName(" SQRT" , expr, isPredicate )
116+ case _ : Floor => generateExpressionWithName(" FLOOR" , expr, isPredicate )
117+ case _ : Ceil => generateExpressionWithName(" CEIL" , expr, isPredicate )
118+ case _ : Round => generateExpressionWithName(" ROUND" , expr, isPredicate )
119+ case _ : Sin => generateExpressionWithName(" SIN" , expr, isPredicate )
120+ case _ : Sinh => generateExpressionWithName(" SINH" , expr, isPredicate )
121+ case _ : Cos => generateExpressionWithName(" COS" , expr, isPredicate )
122+ case _ : Cosh => generateExpressionWithName(" COSH" , expr, isPredicate )
123+ case _ : Tan => generateExpressionWithName(" TAN" , expr, isPredicate )
124+ case _ : Tanh => generateExpressionWithName(" TANH" , expr, isPredicate )
125+ case _ : Cot => generateExpressionWithName(" COT" , expr, isPredicate )
126+ case _ : Asin => generateExpressionWithName(" ASIN" , expr, isPredicate )
127+ case _ : Asinh => generateExpressionWithName(" ASINH" , expr, isPredicate )
128+ case _ : Acos => generateExpressionWithName(" ACOS" , expr, isPredicate )
129+ case _ : Acosh => generateExpressionWithName(" ACOSH" , expr, isPredicate )
130+ case _ : Atan => generateExpressionWithName(" ATAN" , expr, isPredicate )
131+ case _ : Atanh => generateExpressionWithName(" ATANH" , expr, isPredicate )
132+ case _ : Atan2 => generateExpressionWithName(" ATAN2" , expr, isPredicate )
133+ case _ : Cbrt => generateExpressionWithName(" CBRT" , expr, isPredicate )
134+ case _ : ToDegrees => generateExpressionWithName(" DEGREES" , expr, isPredicate )
135+ case _ : ToRadians => generateExpressionWithName(" RADIANS" , expr, isPredicate )
136+ case _ : Signum => generateExpressionWithName(" SIGN" , expr, isPredicate )
137+ case _ : WidthBucket => generateExpressionWithName(" WIDTH_BUCKET" , expr, isPredicate )
138138 case and : And =>
139139 // AND expects predicate
140140 val l = generateExpression(and.left, true )
@@ -185,57 +185,56 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
185185 assert(v.isInstanceOf [V2Predicate ])
186186 new V2Not (v.asInstanceOf [V2Predicate ])
187187 }
188- case UnaryMinus (child , true ) => generateExpressionWithName(" -" , Seq (child) )
189- case BitwiseNot (child) => generateExpressionWithName(" ~" , Seq (child) )
190- case CaseWhen (branches, elseValue) =>
188+ case UnaryMinus (_ , true ) => generateExpressionWithName(" -" , expr, isPredicate )
189+ case _ : BitwiseNot => generateExpressionWithName(" ~" , expr, isPredicate )
190+ case caseWhen @ CaseWhen (branches, elseValue) =>
191191 val conditions = branches.map(_._1).flatMap(generateExpression(_, true ))
192- val values = branches.map(_._2).flatMap(generateExpression(_, true ))
193- if (conditions.length == branches.length && values.length == branches.length) {
192+ val values = branches.map(_._2).flatMap(generateExpression(_))
193+ val elseExprOpt = elseValue.flatMap(generateExpression(_))
194+ if (conditions.length == branches.length && values.length == branches.length &&
195+ elseExprOpt.size == elseValue.size) {
194196 val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
195197 Seq [V2Expression ](c, v)
196198 }
197- if (elseValue.isDefined) {
198- elseValue.flatMap(generateExpression(_)).map { v =>
199- val children = (branchExpressions :+ v).toArray[V2Expression ]
200- // The children looks like [condition1, value1, ..., conditionN, valueN, elseValue]
201- new V2Predicate (" CASE_WHEN" , children)
202- }
199+ val children = (branchExpressions ++ elseExprOpt).toArray[V2Expression ]
200+ // The children look like [condition1, value1, ..., conditionN, valueN (, elseValue)]
201+ if (isPredicate && caseWhen.dataType.isInstanceOf [BooleanType ]) {
202+ Some (new V2Predicate (" CASE_WHEN" , children))
203203 } else {
204- // The children looks like [condition1, value1, ..., conditionN, valueN]
205- Some (new V2Predicate (" CASE_WHEN" , branchExpressions.toArray[V2Expression ]))
204+ Some (new GeneralScalarExpression (" CASE_WHEN" , children))
206205 }
207206 } else {
208207 None
209208 }
210- case iff : If => generateExpressionWithName(" CASE_WHEN" , iff.children )
209+ case _ : If => generateExpressionWithName(" CASE_WHEN" , expr, isPredicate )
211210 case substring : Substring =>
212211 val children = if (substring.len == Literal (Integer .MAX_VALUE )) {
213212 Seq (substring.str, substring.pos)
214213 } else {
215214 substring.children
216215 }
217- generateExpressionWithName (" SUBSTRING" , children)
218- case Upper (child) => generateExpressionWithName(" UPPER" , Seq (child) )
219- case Lower (child) => generateExpressionWithName(" LOWER" , Seq (child) )
216+ generateExpressionWithNameByChildren (" SUBSTRING" , children, substring.dataType, isPredicate )
217+ case _ : Upper => generateExpressionWithName(" UPPER" , expr, isPredicate )
218+ case _ : Lower => generateExpressionWithName(" LOWER" , expr, isPredicate )
220219 case BitLength (child) if child.dataType.isInstanceOf [StringType ] =>
221- generateExpressionWithName(" BIT_LENGTH" , Seq (child) )
220+ generateExpressionWithName(" BIT_LENGTH" , expr, isPredicate )
222221 case Length (child) if child.dataType.isInstanceOf [StringType ] =>
223- generateExpressionWithName(" CHAR_LENGTH" , Seq (child) )
224- case concat : Concat => generateExpressionWithName(" CONCAT" , concat.children )
225- case translate : StringTranslate => generateExpressionWithName(" TRANSLATE" , translate.children )
226- case trim : StringTrim => generateExpressionWithName(" TRIM" , trim.children )
227- case trim : StringTrimLeft => generateExpressionWithName(" LTRIM" , trim.children )
228- case trim : StringTrimRight => generateExpressionWithName(" RTRIM" , trim.children )
222+ generateExpressionWithName(" CHAR_LENGTH" , expr, isPredicate )
223+ case _ : Concat => generateExpressionWithName(" CONCAT" , expr, isPredicate )
224+ case _ : StringTranslate => generateExpressionWithName(" TRANSLATE" , expr, isPredicate )
225+ case _ : StringTrim => generateExpressionWithName(" TRIM" , expr, isPredicate )
226+ case _ : StringTrimLeft => generateExpressionWithName(" LTRIM" , expr, isPredicate )
227+ case _ : StringTrimRight => generateExpressionWithName(" RTRIM" , expr, isPredicate )
229228 case overlay : Overlay =>
230229 val children = if (overlay.len == Literal (- 1 )) {
231230 Seq (overlay.input, overlay.replace, overlay.pos)
232231 } else {
233232 overlay.children
234233 }
235- generateExpressionWithName (" OVERLAY" , children)
236- case date : DateAdd => generateExpressionWithName(" DATE_ADD" , date.children )
237- case date : DateDiff => generateExpressionWithName(" DATE_DIFF" , date.children )
238- case date : TruncDate => generateExpressionWithName(" TRUNC" , date.children )
234+ generateExpressionWithNameByChildren (" OVERLAY" , children, overlay.dataType, isPredicate )
235+ case _ : DateAdd => generateExpressionWithName(" DATE_ADD" , expr, isPredicate )
236+ case _ : DateDiff => generateExpressionWithName(" DATE_DIFF" , expr, isPredicate )
237+ case _ : TruncDate => generateExpressionWithName(" TRUNC" , expr, isPredicate )
239238 case Second (child, _) =>
240239 generateExpression(child).map(v => new V2Extract (" SECOND" , v))
241240 case Minute (child, _) =>
@@ -268,12 +267,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
268267 generateExpression(child).map(v => new V2Extract (" WEEK" , v))
269268 case YearOfWeek (child) =>
270269 generateExpression(child).map(v => new V2Extract (" YEAR_OF_WEEK" , v))
271- case encrypt : AesEncrypt => generateExpressionWithName(" AES_ENCRYPT" , encrypt.children )
272- case decrypt : AesDecrypt => generateExpressionWithName(" AES_DECRYPT" , decrypt.children )
273- case Crc32 (child) => generateExpressionWithName(" CRC32" , Seq (child) )
274- case Md5 (child) => generateExpressionWithName(" MD5" , Seq (child) )
275- case Sha1 (child) => generateExpressionWithName(" SHA1" , Seq (child) )
276- case sha2 : Sha2 => generateExpressionWithName(" SHA2" , sha2.children )
270+ case _ : AesEncrypt => generateExpressionWithName(" AES_ENCRYPT" , expr, isPredicate )
271+ case _ : AesDecrypt => generateExpressionWithName(" AES_DECRYPT" , expr, isPredicate )
272+ case _ : Crc32 => generateExpressionWithName(" CRC32" , expr, isPredicate )
273+ case _ : Md5 => generateExpressionWithName(" MD5" , expr, isPredicate )
274+ case _ : Sha1 => generateExpressionWithName(" SHA1" , expr, isPredicate )
275+ case _ : Sha2 => generateExpressionWithName(" SHA2" , expr, isPredicate )
277276 // TODO supports other expressions
278277 case ApplyFunctionExpression (function, children) =>
279278 val childrenExpressions = children.flatMap(generateExpression(_))
@@ -345,10 +344,26 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
345344 }
346345
347346 private def generateExpressionWithName (
348- v2ExpressionName : String , children : Seq [Expression ]): Option [V2Expression ] = {
347+ v2ExpressionName : String ,
348+ expr : Expression ,
349+ isPredicate : Boolean ): Option [V2Expression ] = {
350+ generateExpressionWithNameByChildren(
351+ v2ExpressionName, expr.children, expr.dataType, isPredicate)
352+ }
353+
354+ private def generateExpressionWithNameByChildren (
355+ v2ExpressionName : String ,
356+ children : Seq [Expression ],
357+ dataType : DataType ,
358+ isPredicate : Boolean ): Option [V2Expression ] = {
349359 val childrenExpressions = children.flatMap(generateExpression(_))
350360 if (childrenExpressions.length == children.length) {
351- Some (new GeneralScalarExpression (v2ExpressionName, childrenExpressions.toArray[V2Expression ]))
361+ if (isPredicate && dataType.isInstanceOf [BooleanType ]) {
362+ Some (new V2Predicate (v2ExpressionName, childrenExpressions.toArray[V2Expression ]))
363+ } else {
364+ Some (new GeneralScalarExpression (
365+ v2ExpressionName, childrenExpressions.toArray[V2Expression ]))
366+ }
352367 } else {
353368 None
354369 }
0 commit comments