diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 414155537290a..d451c73b39d83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -88,7 +88,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } else { None } - case Cast(child, dataType, _, true) => + case Cast(child, dataType, _, ansiEnabled) + if ansiEnabled || Cast.canUpCast(child.dataType, dataType) => generateExpression(child).map(v => new V2Cast(v, dataType)) case Abs(child, true) => generateExpressionWithName("ABS", Seq(child)) case Coalesce(children) => generateExpressionWithName("COALESCE", children) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 3b226d606430c..da4f9175cd56d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -1109,7 +1109,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel "CAST(BONUS AS string) LIKE '%30%', CAST(DEPT AS byte) > 1, " + "CAST(DEPT AS short) > 1, CAST(BONUS AS decimal(20,2)) > 1200.00]" } else { - "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL]," + "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL, CAST(BONUS AS string) LIKE '%30%']" } checkPushedInfo(df6, expectedPlanFragment6) checkAnswer(df6, Seq(Row(2, "david", 10000, 1300, true))) @@ -1199,18 +1199,16 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkPushedInfo(df1, "PushedFilters: [CHAR_LENGTH(NAME) > 2],") checkAnswer(df1, Seq(Row("fred", 1), Row("mary", 2))) - withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { - val df2 = sql( - """ - |SELECT * - |FROM h2.test.people - |WHERE h2.my_strlen(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2 + val df2 = sql( + """ + |SELECT * + |FROM h2.test.people + |WHERE h2.my_strlen(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2 """.stripMargin) - checkFiltersRemoved(df2) - checkPushedInfo(df2, - "PushedFilters: [CHAR_LENGTH(CASE WHEN NAME = 'fred' THEN NAME ELSE 'abc' END) > 2],") - checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2))) - } + checkFiltersRemoved(df2) + checkPushedInfo(df2, + "PushedFilters: [CHAR_LENGTH(CASE WHEN NAME = 'fred' THEN NAME ELSE 'abc' END) > 2],") + checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2))) } finally { JdbcDialects.unregisterDialect(testH2Dialect) JdbcDialects.registerDialect(H2Dialect) @@ -2262,24 +2260,17 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } test("scan with aggregate push-down: partial push-down AVG with overflow") { - def createDataFrame: DataFrame = spark.read - .option("partitionColumn", "id") - .option("lowerBound", "0") - .option("upperBound", "2") - .option("numPartitions", "2") - .table("h2.test.item") - .agg(avg($"PRICE").as("avg")) - Seq(true, false).foreach { ansiEnabled => withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { - val df = createDataFrame + val df = spark.read + .option("partitionColumn", "id") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.item") + .agg(avg($"PRICE").as("avg")) checkAggregateRemoved(df, false) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [COUNT(PRICE), SUM(PRICE)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [COUNT(PRICE), SUM(PRICE)]") if (ansiEnabled) { val e = intercept[SparkException] { df.collect()