From 0903f018c3adbc9ac0dcd79a4397279b42cc8549 Mon Sep 17 00:00:00 2001
From: "biaobiao.sun" <1319027852@qq.com>
Date: Sat, 6 Aug 2022 17:24:32 +0800
Subject: [PATCH] [SPARK-39929][SQL] DS V2 supports push down string
functions(non ANSI)
---
.../expressions/GeneralScalarExpression.java | 18 ++++++++++++++++++
.../connector/util/V2ExpressionSQLBuilder.java | 3 +++
.../catalyst/util/V2ExpressionBuilder.scala | 7 ++++++-
.../org/apache/spark/sql/jdbc/H2Dialect.scala | 3 ++-
.../apache/spark/sql/jdbc/JDBCV2Suite.scala | 16 ++++++++++++++++
5 files changed, 45 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
index 9ef0d481bc9c6..8339e341b8ebd 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -340,6 +340,24 @@
*
Since version: 3.4.0
*
*
+ * Name: BIT_LENGTH
+ *
+ * - SQL semantic:
BIT_LENGTH(src)
+ * - Since version: 3.4.0
+ *
+ *
+ * Name: CHAR_LENGTH
+ *
+ * - SQL semantic:
CHAR_LENGTH(src)
+ * - Since version: 3.4.0
+ *
+ *
+ * Name: CONCAT
+ *
+ * - SQL semantic:
CONCAT(col1, col2, ..., colN)
+ * - Since version: 3.4.0
+ *
+ *
* Name: OVERLAY
*
* - SQL semantic:
OVERLAY(string, replace, position[, length])
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 3a78a946e36a8..4fa132ccfd1c6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -155,6 +155,9 @@ public String build(Expression expr) {
case "SHA2":
case "MD5":
case "CRC32":
+ case "BIT_LENGTH":
+ case "CHAR_LENGTH":
+ case "CONCAT":
return visitSQLFunction(name,
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
case "CASE_WHEN": {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 89115bf7ab51a..18101ec3df8ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression =>
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.execution.datasources.PushableExpression
-import org.apache.spark.sql.types.{BooleanType, IntegerType}
+import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType}
/**
* The builder to generate V2 expressions from catalyst expressions.
@@ -217,6 +217,11 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
generateExpressionWithName("SUBSTRING", children)
case Upper(child) => generateExpressionWithName("UPPER", Seq(child))
case Lower(child) => generateExpressionWithName("LOWER", Seq(child))
+ case BitLength(child) if child.dataType.isInstanceOf[StringType] =>
+ generateExpressionWithName("BIT_LENGTH", Seq(child))
+ case Length(child) if child.dataType.isInstanceOf[StringType] =>
+ generateExpressionWithName("CHAR_LENGTH", Seq(child))
+ case concat: Concat => generateExpressionWithName("CONCAT", concat.children)
case translate: StringTranslate => generateExpressionWithName("TRANSLATE", translate.children)
case trim: StringTrim => generateExpressionWithName("TRIM", trim.children)
case trim: StringTrimLeft => generateExpressionWithName("LTRIM", trim.children)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 737e3de10a925..7665bb91c6ee4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -50,7 +50,8 @@ private[sql] object H2Dialect extends JdbcDialect {
Set("ABS", "COALESCE", "GREATEST", "LEAST", "RAND", "LOG", "LOG10", "LN", "EXP",
"POWER", "SQRT", "FLOOR", "CEIL", "ROUND", "SIN", "SINH", "COS", "COSH", "TAN",
"TANH", "COT", "ASIN", "ACOS", "ATAN", "ATAN2", "DEGREES", "RADIANS", "SIGN",
- "PI", "SUBSTRING", "UPPER", "LOWER", "TRANSLATE", "TRIM", "MD5", "SHA1", "SHA2")
+ "PI", "SUBSTRING", "UPPER", "LOWER", "TRANSLATE", "TRIM", "MD5", "SHA1", "SHA2",
+ "BIT_LENGTH", "CHAR_LENGTH", "CONCAT")
override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 02dff0973fe12..37ef58ea04d7a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -1449,6 +1449,22 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
"PushedFilters: [NAME IS NOT NULL]"
checkPushedInfo(df5, expectedPlanFragment5)
checkAnswer(df5, Seq(Row(6, "jen", 12000, 1200, true)))
+
+ val df6 = sql("SELECT * FROM h2.test.employee WHERE bit_length(name) = 40")
+ checkFiltersRemoved(df6)
+ checkPushedInfo(df6, "[NAME IS NOT NULL, BIT_LENGTH(NAME) = 40]")
+ checkAnswer(df6, Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true)))
+
+ val df7 = sql("SELECT * FROM h2.test.employee WHERE char_length(name) = 5")
+ checkFiltersRemoved(df7)
+ checkPushedInfo(df7, "[NAME IS NOT NULL, CHAR_LENGTH(NAME) = 5]")
+ checkAnswer(df7, Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true)))
+
+ val df8 = sql("SELECT * FROM h2.test.employee WHERE " +
+ "concat(name, ',' , cast(salary as string)) = 'cathy,9000.00'")
+ checkFiltersRemoved(df8)
+ checkPushedInfo(df8, "[(CONCAT(NAME, ',', CAST(SALARY AS string))) = 'cathy,9000.00']")
+ checkAnswer(df8, Seq(Row(1, "cathy", 9000, 1200, false)))
}
test("scan with aggregate push-down: MAX AVG with filter and group by") {