diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 124cb001b5cf..5dfc64d7b6c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -62,18 +62,15 @@ private[sql] object H2Dialect extends JdbcDialect {
           assert(f.children().length == 1)
           val distinct = if (f.isDistinct) "DISTINCT " else ""
           Some(s"STDDEV_SAMP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_POP" =>
+        case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && !f.isDistinct =>
           assert(f.children().length == 2)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"COVAR_POP($distinct${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" =>
+          Some(s"COVAR_POP(${f.children().head}, ${f.children().last})")
+        case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && !f.isDistinct =>
           assert(f.children().length == 2)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"COVAR_SAMP($distinct${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "CORR" =>
+          Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})")
+        case f: GeneralAggregateFunc if f.name() == "CORR" && !f.isDistinct =>
           assert(f.children().length == 2)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"CORR($distinct${f.children().head}, ${f.children().last})")
+          Some(s"CORR(${f.children().head}, ${f.children().last})")
         case _ => None
       }
     )
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 108348fbcd3a..0a713bdb76ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -1652,23 +1652,37 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   }
 
   test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") {
-    val df = sql("SELECT COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" +
+    val df1 = sql("SELECT COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" +
       " FROM h2.test.employee WHERE dept > 0 GROUP BY DePt")
-    checkFiltersRemoved(df)
-    checkAggregateRemoved(df)
-    checkPushedInfo(df, "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " +
+    checkFiltersRemoved(df1)
+    checkAggregateRemoved(df1)
+    checkPushedInfo(df1, "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " +
       "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]")
-    checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
+    checkAnswer(df1, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
+
+    val df2 = sql("SELECT COVAR_POP(DISTINCT bonus, bonus), COVAR_SAMP(DISTINCT bonus, bonus)" +
+      " FROM h2.test.employee WHERE dept > 0 GROUP BY DePt")
+    checkFiltersRemoved(df2)
+    checkAggregateRemoved(df2, false)
+    checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]")
+    checkAnswer(df2, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
   }
 
   test("scan with aggregate push-down: CORR with filter and group by") {
-    val df = sql("SELECT CORR(bonus, bonus) FROM h2.test.employee WHERE dept > 0" +
+    val df1 = sql("SELECT CORR(bonus, bonus) FROM h2.test.employee WHERE dept > 0" +
       " GROUP BY DePt")
-    checkFiltersRemoved(df)
-    checkAggregateRemoved(df)
-    checkPushedInfo(df, "PushedAggregates: [CORR(BONUS, BONUS)], " +
+    checkFiltersRemoved(df1)
+    checkAggregateRemoved(df1)
+    checkPushedInfo(df1, "PushedAggregates: [CORR(BONUS, BONUS)], " +
       "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]")
-    checkAnswer(df, Seq(Row(1d), Row(1d), Row(null)))
+    checkAnswer(df1, Seq(Row(1d), Row(1d), Row(null)))
+
+    val df2 = sql("SELECT CORR(DISTINCT bonus, bonus) FROM h2.test.employee WHERE dept > 0" +
+      " GROUP BY DePt")
+    checkFiltersRemoved(df2)
+    checkAggregateRemoved(df2, false)
+    checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]")
+    checkAnswer(df2, Seq(Row(1d), Row(1d), Row(null)))
   }
 
   test("scan with aggregate push-down: aggregate over alias push down") {
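
The guard added to H2Dialect above is the whole change: SQL is generated for COVAR_POP, COVAR_SAMP and CORR only in the non-DISTINCT form, and the DISTINCT form falls through to None so Spark computes the aggregate itself (which is what the new df2 assertions verify with checkAggregateRemoved(df2, false)). Below is a minimal, self-contained Scala sketch of that pattern using hypothetical stand-in types, not the real Spark connector API.

// Sketch only: AggFunc stands in for GeneralAggregateFunc (name, DISTINCT flag, column names).
object DistinctGuardSketch {
  final case class AggFunc(name: String, isDistinct: Boolean, children: Seq[String])

  // Return the dialect SQL for a push-down, or None when the aggregate cannot be pushed.
  def compileAggregate(f: AggFunc): Option[String] = f match {
    case AggFunc(n @ ("COVAR_POP" | "COVAR_SAMP" | "CORR"), false, Seq(left, right)) =>
      Some(s"$n($left, $right)") // non-DISTINCT, two-argument form is supported
    case _ =>
      None // DISTINCT (or anything unrecognized): no push-down, Spark keeps the aggregate
  }

  def main(args: Array[String]): Unit = {
    println(compileAggregate(AggFunc("CORR", isDistinct = false, Seq("BONUS", "BONUS"))))
    // Some(CORR(BONUS, BONUS)) -> pushed to the JDBC source
    println(compileAggregate(AggFunc("CORR", isDistinct = true, Seq("BONUS", "BONUS"))))
    // None -> aggregate stays in the Spark plan
  }
}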