From b3c0720596d17eeda7b9a4a0ab2090d617a7086d Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 3 Aug 2022 11:28:45 +0800 Subject: [PATCH 1/2] [SPARK-39961][SQL] DS V2 push-down translate Cast if the cast is safe --- .../catalyst/util/V2ExpressionBuilder.scala | 3 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 45 ++++++++----------- 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 414155537290a..b04bcc6966ab8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -88,7 +88,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } else { None } - case Cast(child, dataType, _, true) => + case Cast(child, dataType, _, ansiEnabled) + if Cast.canUpCast(child.dataType, dataType) || ansiEnabled => generateExpression(child).map(v => new V2Cast(v, dataType)) case Abs(child, true) => generateExpressionWithName("ABS", Seq(child)) case Coalesce(children) => generateExpressionWithName("COALESCE", children) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 3b226d606430c..da4f9175cd56d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -1109,7 +1109,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel "CAST(BONUS AS string) LIKE '%30%', CAST(DEPT AS byte) > 1, " + "CAST(DEPT AS short) > 1, CAST(BONUS AS decimal(20,2)) > 1200.00]" } else { - "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL]," + "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL, CAST(BONUS AS string) LIKE '%30%']" } 
checkPushedInfo(df6, expectedPlanFragment6) checkAnswer(df6, Seq(Row(2, "david", 10000, 1300, true))) @@ -1199,18 +1199,16 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkPushedInfo(df1, "PushedFilters: [CHAR_LENGTH(NAME) > 2],") checkAnswer(df1, Seq(Row("fred", 1), Row("mary", 2))) - withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { - val df2 = sql( - """ - |SELECT * - |FROM h2.test.people - |WHERE h2.my_strlen(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2 + val df2 = sql( + """ + |SELECT * + |FROM h2.test.people + |WHERE h2.my_strlen(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2 """.stripMargin) - checkFiltersRemoved(df2) - checkPushedInfo(df2, - "PushedFilters: [CHAR_LENGTH(CASE WHEN NAME = 'fred' THEN NAME ELSE 'abc' END) > 2],") - checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2))) - } + checkFiltersRemoved(df2) + checkPushedInfo(df2, + "PushedFilters: [CHAR_LENGTH(CASE WHEN NAME = 'fred' THEN NAME ELSE 'abc' END) > 2],") + checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2))) } finally { JdbcDialects.unregisterDialect(testH2Dialect) JdbcDialects.registerDialect(H2Dialect) @@ -2262,24 +2260,17 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } test("scan with aggregate push-down: partial push-down AVG with overflow") { - def createDataFrame: DataFrame = spark.read - .option("partitionColumn", "id") - .option("lowerBound", "0") - .option("upperBound", "2") - .option("numPartitions", "2") - .table("h2.test.item") - .agg(avg($"PRICE").as("avg")) - Seq(true, false).foreach { ansiEnabled => withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { - val df = createDataFrame + val df = spark.read + .option("partitionColumn", "id") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.item") + .agg(avg($"PRICE").as("avg")) checkAggregateRemoved(df, false) - df.queryExecution.optimizedPlan.collect { - case _: 
DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [COUNT(PRICE), SUM(PRICE)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [COUNT(PRICE), SUM(PRICE)]") if (ansiEnabled) { val e = intercept[SparkException] { df.collect() From feb2a131a14b91359ea9d95cdbed548b48d3aabf Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 4 Aug 2022 08:33:27 +0800 Subject: [PATCH 2/2] Check ansiEnabled before Cast.canUpCast in the Cast push-down guard --- .../apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index b04bcc6966ab8..d451c73b39d83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -89,7 +89,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { None } case Cast(child, dataType, _, ansiEnabled) - if Cast.canUpCast(child.dataType, dataType) || ansiEnabled => + if ansiEnabled || Cast.canUpCast(child.dataType, dataType) => generateExpression(child).map(v => new V2Cast(v, dataType)) case Abs(child, true) => generateExpressionWithName("ABS", Seq(child)) case Coalesce(children) => generateExpressionWithName("COALESCE", children)