From 149aa35f1451d0cca62f3157b5d79d463d134af3 Mon Sep 17 00:00:00 2001
From: huaxingao
Date: Thu, 14 Jul 2022 14:23:54 -0700
Subject: [PATCH] Literal values should be on the right side of the data source filter

---
 .../datasources/v2/DataSourceV2Strategy.scala | 36 +++++++++-
 .../v2/DataSourceV2StrategySuite.scala        | 69 ++++++++++++++++++-
 2 files changed, 101 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 16c6b331d109..036b6e4862d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, ResolveDefaultColumns, V2ExpressionBuilder}
 import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDelete, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable}
 import org.apache.spark.sql.connector.catalog.index.SupportsIndex
-import org.apache.spark.sql.connector.expressions.FieldReference
+import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
 import org.apache.spark.sql.connector.read.LocalScan
 import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
@@ -502,11 +502,43 @@ private[sql] object DataSourceV2Strategy {
 
   private def translateLeafNodeFilterV2(predicate: Expression): Option[Predicate] = {
     predicate match {
-      case PushablePredicate(expr) => Some(expr)
+      case PushablePredicate(expr) =>
+        // Data sources expect `column op literal`, so rewrite `literal op column` by swapping
+        // the operands. Only binary comparisons are rewritten: the operand order of any other
+        // two-argument predicate may be significant, so those are passed through unchanged.
+        val children = expr.children()
+        val name = expr.name()
+        children.headOption match {
+          case Some(LiteralValue(_, _))
+              if children.length == 2 && flippedComparisonFilterNames.contains(name) =>
+            Some(new Predicate(flipComparisonFilterName(name),
+              Array(children(1), children(0))))
+          case _ => Some(expr)
+        }
       case _ => None
     }
   }
 
+  /**
+   * Binary comparison operators, each mapped to the operator that gives the same result once
+   * its two operands are swapped. `=` and `<=>` are symmetric, so they map to themselves.
+   */
+  private val flippedComparisonFilterNames: Map[String, String] = Map(
+    "=" -> "=",
+    "<=>" -> "<=>",
+    ">" -> "<",
+    "<" -> ">",
+    ">=" -> "<=",
+    "<=" -> ">=")
+
+  /**
+   * Returns the operator name to use after swapping the operands of `filterName`, e.g. `>` for
+   * `<`. Names that are not binary comparisons are returned unchanged.
+   */
+  private def flipComparisonFilterName(filterName: String): String = {
+    flippedComparisonFilterNames.getOrElse(filterName, filterName)
+  }
+
   /**
    * Tries to translate a Catalyst [[Expression]] into data source [[Filter]].
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
index 66dc65cf6813..f5563122e7fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
@@ -18,14 +18,79 @@
 package org.apache.spark.sql.execution.datasources.v2
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.BooleanType
+import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructField, StructType}
 
 class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
+  // Integer-typed column expressions (top-level and nested struct fields), each paired with
+  // the quoted column name expected in the translated V2 filter.
+  val attrInts = Seq(
+    $"cint".int,
+    $"c.int".int,
+    GetStructField($"a".struct(StructType(
+      StructField("cstr", StringType, nullable = true) ::
+        StructField("cint", IntegerType, nullable = true) :: Nil)), 1, None),
+    GetStructField($"a".struct(StructType(
+      StructField("c.int", IntegerType, nullable = true) ::
+        StructField("cstr", StringType, nullable = true) :: Nil)), 0, None),
+    GetStructField($"a.b".struct(StructType(
+      StructField("cstr1", StringType, nullable = true) ::
+        StructField("cstr2", StringType, nullable = true) ::
+        StructField("cint", IntegerType, nullable = true) :: Nil)), 2, None),
+    GetStructField($"a.b".struct(StructType(
+      StructField("c.int", IntegerType, nullable = true) :: Nil)), 0, None),
+    GetStructField(GetStructField($"a".struct(StructType(
+      StructField("cstr1", StringType, nullable = true) ::
+        StructField("b", StructType(StructField("cint", IntegerType, nullable = true) ::
+          StructField("cstr2", StringType, nullable = true) :: Nil)) :: Nil)), 1, None), 0, None)
+  ).zip(Seq(
+    "cint",
+    "`c.int`", // single level field that contains `dot` in name
+    "a.cint", // two level nested field
+    "a.`c.int`", // two level nested field, and nested level contains `dot`
+    "`a.b`.cint", // two level nested field, and top level contains `dot`
+    "`a.b`.`c.int`", // two level nested field, and both levels contain `dot`
+    "a.b.cint" // three level nested field
+  ))
+
+  test("translate simple expression") { attrInts
+    .foreach { case (attrInt, intColName) =>
+      testTranslateFilter(EqualTo(attrInt, 1),
+        Some(new Predicate("=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
+      testTranslateFilter(EqualTo(1, attrInt),
+        Some(new Predicate("=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
+
+      testTranslateFilter(EqualNullSafe(attrInt, 1),
+        Some(new Predicate("<=>", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
+      testTranslateFilter(EqualNullSafe(1, attrInt),
+        Some(new Predicate("<=>", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
+
+      testTranslateFilter(GreaterThan(attrInt, 1),
+        Some(new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
+      testTranslateFilter(GreaterThan(1, attrInt),
+        Some(new Predicate("<", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
+
+      testTranslateFilter(LessThan(attrInt, 1),
+        Some(new Predicate("<", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
+      testTranslateFilter(LessThan(1, attrInt),
+        Some(new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
+
+      testTranslateFilter(GreaterThanOrEqual(attrInt, 1),
+        Some(new Predicate(">=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
+      testTranslateFilter(GreaterThanOrEqual(1, attrInt),
+        Some(new Predicate("<=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
+
+      testTranslateFilter(LessThanOrEqual(attrInt, 1),
+        Some(new Predicate("<=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
+      testTranslateFilter(LessThanOrEqual(1, attrInt),
+        Some(new Predicate(">=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
+    }
+  }
+
   test("SPARK-36644: Push down boolean column filter") {
     testTranslateFilter($"col".boolean,
       Some(new Predicate("=", Array(FieldReference("col"), LiteralValue(true, BooleanType)))))