From 6fec9fc480a2b4a9c3a1bbbde0c08b54204c5545 Mon Sep 17 00:00:00 2001
From: chenzhx
Date: Wed, 13 Jul 2022 10:43:20 +0800
Subject: [PATCH] [SPARK-38901][SQL] DS V2 supports push down misc functions

---
 .../expressions/GeneralScalarExpression.java  | 36 ++++++++++
 .../util/V2ExpressionSQLBuilder.java          |  6 ++
 .../catalyst/util/V2ExpressionBuilder.scala   |  6 ++
 .../org/apache/spark/sql/jdbc/H2Dialect.scala | 19 ++++-
 .../apache/spark/sql/jdbc/JDBCV2Suite.scala   | 71 ++++++++++++++++---
 5 files changed, 129 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
index 53c511a87f691..9ef0d481bc9c6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -364,6 +364,42 @@ *
  *    <li>Since version: 3.4.0</li>
  *   </ul>
  *  </li>
+ *  <li>Name: <code>AES_ENCRYPT</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>AES_ENCRYPT(expr, key[, mode[, padding]])</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>AES_DECRYPT</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>AES_DECRYPT(expr, key[, mode[, padding]])</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>SHA1</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>SHA1(expr)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>SHA2</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>SHA2(expr, bitLength)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>MD5</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>MD5(expr)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>CRC32</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>CRC32(expr)</code></li>
+ *    <li>Since version: 3.4.0</li>
+ *   </ul>
+ *  </li>
  * </ul>
  * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off,
  * including: add, subtract, multiply, divide, remainder, pmod.
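
For orientation, a minimal sketch (not from this patch) of how one of these functions crosses the DS V2 boundary: the connector receives a GeneralScalarExpression keyed by the documented name, with the inputs as child expressions.

```scala
import org.apache.spark.sql.connector.expressions.{Expression, Expressions, GeneralScalarExpression}

// MD5(b) as a connector sees it: name() is "MD5", children() holds the single
// input column. Dialects dispatch on the name string documented above.
val md5OfB = new GeneralScalarExpression("MD5", Array[Expression](Expressions.column("b")))
assert(md5OfB.name == "MD5")
assert(md5OfB.children.length == 1)
```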
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 541b88a5027d1..3a78a946e36a8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -149,6 +149,12 @@ public String build(Expression expr) {
       case "DATE_ADD":
       case "DATE_DIFF":
       case "TRUNC":
+      case "AES_ENCRYPT":
+      case "AES_DECRYPT":
+      case "SHA1":
+      case "SHA2":
+      case "MD5":
+      case "CRC32":
         return visitSQLFunction(name,
           Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
       case "CASE_WHEN": {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 414155537290a..a029c002d0d09 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -257,6 +257,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
       generateExpression(child).map(v => new V2Extract("WEEK", v))
     case YearOfWeek(child) =>
       generateExpression(child).map(v => new V2Extract("YEAR_OF_WEEK", v))
+    case encrypt: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", encrypt.children)
+    case decrypt: AesDecrypt => generateExpressionWithName("AES_DECRYPT", decrypt.children)
+    case Crc32(child) => generateExpressionWithName("CRC32", Seq(child))
+    case Md5(child) => generateExpressionWithName("MD5", Seq(child))
+    case Sha1(child) => generateExpressionWithName("SHA1", Seq(child))
+    case sha2: Sha2 => generateExpressionWithName("SHA2", sha2.children)
     // TODO supports other expressions
     case ApplyFunctionExpression(function, children) =>
       val childrenExpressions = children.flatMap(generateExpression(_))
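
Taken together: V2ExpressionBuilder maps Catalyst's hash expressions to named V2 scalar functions, and V2ExpressionSQLBuilder prints them back as SQL text. A minimal round-trip sketch under the default rendering (no dialect overrides):

```scala
import org.apache.spark.sql.connector.expressions.{Expression, Expressions, GeneralScalarExpression}
import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder

// The node below is what V2ExpressionBuilder produces for Catalyst's Sha1(b);
// the new cases in V2ExpressionSQLBuilder route it through visitSQLFunction.
val sqlBuilder = new V2ExpressionSQLBuilder
val sha1 = new GeneralScalarExpression("SHA1", Array[Expression](Expressions.column("b")))
println(sqlBuilder.build(sha1)) // typically prints SHA1(b); dialects may override the rendering
```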
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 4200ba91fb1b6..737e3de10a925 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -50,7 +50,7 @@ private[sql] object H2Dialect extends JdbcDialect {
     Set("ABS", "COALESCE", "GREATEST", "LEAST", "RAND", "LOG", "LOG10", "LN", "EXP",
       "POWER", "SQRT", "FLOOR", "CEIL", "ROUND", "SIN", "SINH", "COS", "COSH", "TAN",
       "TANH", "COT", "ASIN", "ACOS", "ATAN", "ATAN2", "DEGREES", "RADIANS", "SIGN",
-      "PI", "SUBSTRING", "UPPER", "LOWER", "TRANSLATE", "TRIM")
+      "PI", "SUBSTRING", "UPPER", "LOWER", "TRANSLATE", "TRIM", "MD5", "SHA1", "SHA2")

   override def isSupportedFunction(funcName: String): Boolean =
     supportedFunctions.contains(funcName)
@@ -235,5 +235,22 @@ private[sql] object H2Dialect extends JdbcDialect {
       }
       s"EXTRACT($newField FROM $source)"
     }
+
+    override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
+      if (isSupportedFunction(funcName)) {
+        funcName match {
+          case "MD5" =>
+            "RAWTOHEX(HASH('MD5', " + inputs.mkString(",") + "))"
+          case "SHA1" =>
+            "RAWTOHEX(HASH('SHA-1', " + inputs.mkString(",") + "))"
+          case "SHA2" =>
+            "RAWTOHEX(HASH('SHA-" + inputs(1) + "'," + inputs(0) + "))"
+          case _ => super.visitSQLFunction(funcName, inputs)
+        }
+      } else {
+        throw new UnsupportedOperationException(
+          s"${this.getClass.getSimpleName} does not support function: $funcName")
+      }
+    }
   }
 }
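
H2 exposes only a generic HASH(algorithm, expr) built-in that returns binary, so the dialect rewrites the standard names rather than passing them through verbatim. A standalone sketch of that mapping (the `toH2Sql` helper is hypothetical, mirroring the override above):

```scala
// Mirrors the H2 rewrite: MD5/SHA1/SHA2 become RAWTOHEX(HASH(algorithm, expr)),
// anything else falls back to the generic FUNC(args) rendering.
def toH2Sql(funcName: String, inputs: Array[String]): String = funcName match {
  case "MD5"  => "RAWTOHEX(HASH('MD5', " + inputs.mkString(",") + "))"
  case "SHA1" => "RAWTOHEX(HASH('SHA-1', " + inputs.mkString(",") + "))"
  case "SHA2" => "RAWTOHEX(HASH('SHA-" + inputs(1) + "'," + inputs(0) + "))"
  case other  => other + inputs.mkString("(", ", ", ")")
}

println(toH2Sql("SHA2", Array("B", "256"))) // RAWTOHEX(HASH('SHA-256',B))
```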
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 3b226d606430c..d5255fa1c59f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -45,6 +45,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   val tempDir = Utils.createTempDir()
   val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass"
+  val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) ++
+    Array.fill(15)(0.toByte)

   val testH2Dialect = new JdbcDialect {
     override def canHandle(url: String): Boolean = H2Dialect.canHandle(url)
@@ -178,6 +180,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     conn.prepareStatement("INSERT INTO \"test\".\"datetime\" VALUES " +
       "('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate()
     conn.prepareStatement("INSERT INTO \"test\".\"datetime\" VALUES " +
       "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate()
+
+    conn.prepareStatement("CREATE TABLE \"test\".\"binary1\" (name TEXT(32),b BINARY(20))")
+      .executeUpdate()
+    val stmt = conn.prepareStatement("INSERT INTO \"test\".\"binary1\" VALUES (?, ?)")
+    stmt.setString(1, "jen")
+    stmt.setBytes(2, testBytes)
+    stmt.executeUpdate()
   }
   H2Dialect.registerFunction("my_avg", IntegralAverage)
   H2Dialect.registerFunction("my_strlen", StrLen(CharLength))
@@ -860,7 +869,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkSortRemoved(df2)
     checkPushedInfo(df2,
       "PushedFilters: [DEPT IS NOT NULL, DEPT > 1]",
-    "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1")
+      "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1")
     checkAnswer(df2, Seq(Row(2, "david", 10000.00)))
   }

@@ -1190,6 +1199,52 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df8, Seq(Row("alex")))
   }

+  test("scan with filter push-down with misc functions") {
+    val df1 = sql("SELECT name FROM h2.test.binary1 WHERE " +
+      "md5(b) = '4371fe0aa613bcb081543a37d241adcb'")
+    checkFiltersRemoved(df1)
+    val expectedPlanFragment1 = "PushedFilters: [B IS NOT NULL, " +
+      "MD5(B) = '4371fe0aa613bcb081543a37d241adcb']"
+    checkPushedInfo(df1, expectedPlanFragment1)
+    checkAnswer(df1, Seq(Row("jen")))
+
+    val df2 = sql("SELECT name FROM h2.test.binary1 WHERE " +
+      "sha1(b) = 'cf355e86e8666f9300ef12e996acd5c629e0b0a1'")
+    checkFiltersRemoved(df2)
+    val expectedPlanFragment2 = "PushedFilters: [B IS NOT NULL, " +
+      "SHA1(B) = 'cf355e86e8666f9300ef12e996acd5c629e0b0a1'],"
+    checkPushedInfo(df2, expectedPlanFragment2)
+    checkAnswer(df2, Seq(Row("jen")))
+
+    val df3 = sql("SELECT name FROM h2.test.binary1 WHERE " +
+      "sha2(b, 256) = '911732d10153f859dec04627df38b19290ec707ff9f83910d061421fdc476109'")
+    checkFiltersRemoved(df3)
+    val expectedPlanFragment3 = "PushedFilters: [B IS NOT NULL, (SHA2(B, 256)) = " +
+      "'911732d10153f859dec04627df38b19290ec707ff9f83910d061421fdc476109']"
+    checkPushedInfo(df3, expectedPlanFragment3)
+    checkAnswer(df3, Seq(Row("jen")))
+
+    val df4 = sql("SELECT * FROM h2.test.employee WHERE crc32(name) = '142689369'")
+    checkFiltersRemoved(df4, false)
+    val expectedPlanFragment4 = "PushedFilters: [NAME IS NOT NULL], "
+    checkPushedInfo(df4, expectedPlanFragment4)
+    checkAnswer(df4, Seq(Row(6, "jen", 12000, 1200, true)))
+
+    val df5 = sql("SELECT name FROM h2.test.employee WHERE " +
+      "aes_encrypt(cast(null as string), name) is null")
+    checkFiltersRemoved(df5, false)
+    val expectedPlanFragment5 = "PushedFilters: [], "
+    checkPushedInfo(df5, expectedPlanFragment5)
+    checkAnswer(df5, Seq(Row("amy"), Row("cathy"), Row("alex"), Row("david"), Row("jen")))
+
+    val df6 = sql("SELECT name FROM h2.test.employee WHERE " +
+      "aes_decrypt(cast(null as binary), name) is null")
+    checkFiltersRemoved(df6, false)
+    val expectedPlanFragment6 = "PushedFilters: [], "
+    checkPushedInfo(df6, expectedPlanFragment6)
+    checkAnswer(df6, Seq(Row("amy"), Row("cathy"), Row("alex"), Row("david"), Row("jen")))
+  }
+
   test("scan with filter push-down with UDF") {
     JdbcDialects.unregisterDialect(H2Dialect)
     try {
@@ -1269,7 +1324,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     Seq(Row("test", "people", false), Row("test", "empty_table", false),
       Row("test", "employee", false), Row("test", "item", false), Row("test", "dept", false),
       Row("test", "person", false), Row("test", "view1", false), Row("test", "view2", false),
-      Row("test", "datetime", false)))
+      Row("test", "datetime", false), Row("test", "binary1", false)))
   }

   test("SQL API: create table as select") {
@@ -1819,12 +1874,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkFiltersRemoved(df)
     checkAggregateRemoved(df)
     checkPushedInfo(df,
-      """
-       |PushedAggregates: [VAR_POP(BONUS), VAR_POP(DISTINCT BONUS),
-       |VAR_SAMP(BONUS), VAR_SAMP(DISTINCT BONUS)],
-       |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
-       |PushedGroupByExpressions: [DEPT],
-       |""".stripMargin.replaceAll("\n", " "))
+      """
+        |PushedAggregates: [VAR_POP(BONUS), VAR_POP(DISTINCT BONUS),
+        |VAR_SAMP(BONUS), VAR_SAMP(DISTINCT BONUS)],
+        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+        |PushedGroupByExpressions: [DEPT],
+        |""".stripMargin.replaceAll("\n", " "))
     checkAnswer(df, Seq(Row(10000d, 10000d, 20000d, 20000d), Row(2500d, 2500d, 5000d, 5000d),
       Row(0d, 0d, null, null)))
   }
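
End to end, with the catalog setup from the suite above (an `h2` catalog backed by the H2 url, and `test.binary1` populated as in beforeAll), the pushdown is visible in EXPLAIN output; a hedged spark-shell sketch:

```scala
// With this patch the MD5 predicate appears under PushedFilters instead of
// being evaluated by Spark in a post-scan Filter node.
spark.sql(
  """EXPLAIN SELECT name FROM h2.test.binary1
    |WHERE md5(b) = '4371fe0aa613bcb081543a37d241adcb'""".stripMargin
).show(false)
// ... PushedFilters: [B IS NOT NULL, MD5(B) = '4371fe0aa613bcb081543a37d241adcb'] ...
```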