diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsDelete.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsDelete.java index 80aa57ca1877..af08eeb28059 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsDelete.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsDelete.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.catalog; import org.apache.spark.annotation.Experimental; -import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.v2.FilterV2; /** * A mix-in interface for {@link Table} delete support. Data sources can implement this @@ -41,5 +41,5 @@ public interface SupportsDelete { * @param filters filter expressions, used to select rows to delete when all expressions match * @throws IllegalArgumentException If the delete is rejected due to required effort */ - void deleteWhere(Filter[] filters); + void deleteWhere(FilterV2[] filters); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownFilters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownFilters.java index bee9e5508ca6..ba9605948416 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownFilters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownFilters.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.read; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.v2.FilterV2; /** * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to @@ -33,10 +33,10 @@ public interface SupportsPushDownFilters extends ScanBuilder { * Rows should be returned from the data source if and only if all of the filters match. That is, * filters must be interpreted as ANDed together. */ - Filter[] pushFilters(Filter[] filters); + FilterV2[] pushFilters(FilterV2[] filters); /** - * Returns the filters that are pushed to the data source via {@link #pushFilters(Filter[])}. + * Returns the filters that are pushed to the data source via {@link #pushFilters(FilterV2[])}. * * There are 3 kinds of filters: * 1. pushable filters which don't need to be evaluated again after scanning. @@ -45,8 +45,8 @@ public interface SupportsPushDownFilters extends ScanBuilder { * 3. non-pushable filters. * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. * - * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} + * It's possible that there is no filters in the query and {@link #pushFilters(FilterV2[])} * is never called, empty array should be returned for this case. 
*/ - Filter[] pushedFilters(); + FilterV2[] pushedFilters(); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java index 6063a155ee20..82775605b9f6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java @@ -17,8 +17,8 @@ package org.apache.spark.sql.connector.write; -import org.apache.spark.sql.sources.AlwaysTrue$; -import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.v2.AlwaysTrue$; +import org.apache.spark.sql.sources.v2.FilterV2; /** * Write builder trait for tables that support overwrite by filter. @@ -36,10 +36,10 @@ public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate { * @param filters filters used to match data to overwrite * @return this write builder for method chaining */ - WriteBuilder overwrite(Filter[] filters); + WriteBuilder overwrite(FilterV2[] filters); @Override default WriteBuilder truncate() { - return overwrite(new Filter[] { AlwaysTrue$.MODULE$ }); + return overwrite(new FilterV2[] { AlwaysTrue$.MODULE$ }); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/FilterV2.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/FilterV2.java new file mode 100644 index 000000000000..be8fff6f09d1 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/FilterV2.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.connector.expressions.NamedReference; + +@Experimental +public abstract class FilterV2 { + /** + * Returns list of columns that are referenced by this filter. 
+ */ + public abstract NamedReference[] references(); + + protected NamedReference[] findReferences(Object value) { + if (value instanceof FilterV2) { + return ((FilterV2) value).references(); + } else { + return new NamedReference[0]; + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index a1ab55a7185c..6e96313f1de3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.sources import org.apache.spark.annotation.{Evolving, Stable} +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.sources.v2.FilterV2 //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. @@ -31,7 +34,7 @@ import org.apache.spark.annotation.{Evolving, Stable} @Stable abstract class Filter { /** - * List of columns that are referenced by this filter. + * List of top-level columns that are referenced by this filter. * @since 2.1.0 */ def references: Array[String] @@ -40,6 +43,17 @@ abstract class Filter { case f: Filter => f.references case _ => Array.empty } + + private[sql] def toV2: FilterV2 + + private[sql] def attToRef(attribute: String): NamedReference = { + FieldReference(Seq(attribute)) + } + + private[sql] def toV2Value(value: Any): Any = value match { + case f: Filter => f.toV2 + case _ => value + } } /** @@ -51,6 +65,8 @@ abstract class Filter { @Stable case class EqualTo(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + + override private[sql] def toV2 = v2.EqualTo(attToRef(attribute), toV2Value(value)) } /** @@ -63,6 +79,8 @@ case class EqualTo(attribute: String, value: Any) extends Filter { @Stable case class EqualNullSafe(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + + override private[sql] def toV2 = v2.EqualNullSafe(attToRef(attribute), toV2Value(value)) } /** @@ -74,6 +92,8 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter { @Stable case class GreaterThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + + override private[sql] def toV2 = v2.GreaterThan(attToRef(attribute), toV2Value(value)) } /** @@ -85,6 +105,9 @@ case class GreaterThan(attribute: String, value: Any) extends Filter { @Stable case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + + override private[sql] def toV2 = + v2.GreaterThanOrEqual(attToRef(attribute), toV2Value(value)) } /** @@ -96,6 +119,8 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { @Stable case class LessThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + + override private[sql] def toV2 = v2.LessThan(attToRef(attribute), toV2Value(value)) } /** @@ -107,6 +132,9 @@ case class LessThan(attribute: String, value: Any) extends Filter { @Stable case class LessThanOrEqual(attribute: String,
value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + + override private[sql] def toV2 = + v2.LessThanOrEqual(attToRef(attribute), toV2Value(value)) } /** @@ -134,6 +162,8 @@ case class In(attribute: String, values: Array[Any]) extends Filter { } override def references: Array[String] = Array(attribute) ++ values.flatMap(findReferences) + + override private[sql] def toV2 = v2.In(attToRef(attribute), values.map(toV2Value)) } /** @@ -144,6 +174,8 @@ case class In(attribute: String, values: Array[Any]) extends Filter { @Stable case class IsNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) + + override private[sql] def toV2 = v2.IsNull(attToRef(attribute)) } /** @@ -154,6 +186,8 @@ case class IsNull(attribute: String) extends Filter { @Stable case class IsNotNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) + + override private[sql] def toV2 = v2.IsNotNull(attToRef(attribute)) } /** @@ -164,6 +198,8 @@ case class IsNotNull(attribute: String) extends Filter { @Stable case class And(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references + + override private[sql] def toV2 = v2.And(left.toV2, right.toV2) } /** @@ -174,6 +210,8 @@ case class And(left: Filter, right: Filter) extends Filter { @Stable case class Or(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references + + override private[sql] def toV2 = v2.Or(left.toV2, right.toV2) } /** @@ -184,6 +222,8 @@ case class Or(left: Filter, right: Filter) extends Filter { @Stable case class Not(child: Filter) extends Filter { override def references: Array[String] = child.references + + override private[sql] def toV2 = v2.Not(child.toV2) } /** @@ -195,6 +235,8 @@ case class Not(child: Filter) extends Filter { @Stable case class StringStartsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + + override private[sql] def toV2 = v2.StringStartsWith(attToRef(attribute), value) } /** @@ -206,6 +248,8 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { @Stable case class StringEndsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + + override private[sql] def toV2 = v2.StringEndsWith(attToRef(attribute), value) } /** @@ -217,6 +261,8 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { @Stable case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + + override private[sql] def toV2 = v2.StringContains(attToRef(attribute), value) } /** @@ -225,6 +271,8 @@ case class StringContains(attribute: String, value: String) extends Filter { @Evolving case class AlwaysTrue() extends Filter { override def references: Array[String] = Array.empty + + override private[sql] def toV2 = v2.AlwaysTrue } @Evolving @@ -237,6 +285,8 @@ object AlwaysTrue extends AlwaysTrue { @Evolving case class AlwaysFalse() extends Filter { override def references: Array[String] = Array.empty + + override private[sql] def toV2 = v2.AlwaysFalse } @Evolving diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/v2/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/v2/filters.scala new file mode 100644 index 
000000000000..dfbc86edadfd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/v2/filters.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.connector.expressions.NamedReference + +/** + * A filter that evaluates to `true` iff the field evaluates to a value + * equal to `value`. + * + * @since 3.0.0 + */ +@Experimental +case class EqualTo(ref: NamedReference, value: Any) extends FilterV2 { + override def references: Array[NamedReference] = Array(ref) ++ findReferences(value) +} + +/** + * Performs equality comparison, similar to [[EqualTo]]. However, this differs from [[EqualTo]] + * in that it returns `true` (rather than NULL) if both inputs are NULL, and `false` + * (rather than NULL) if one of the input is NULL and the other is not NULL. + * + * @since 3.0.0 + */ +@Experimental +case class EqualNullSafe(ref: NamedReference, value: Any) extends FilterV2 { + override def references: Array[NamedReference] = Array(ref) ++ findReferences(value) +} + +/** + * A filter that evaluates to `true` iff the field evaluates to a value + * greater than `value`. + * + * @since 3.0.0 + */ +@Experimental +case class GreaterThan(ref: NamedReference, value: Any) extends FilterV2 { + override def references: Array[NamedReference] = Array(ref) ++ findReferences(value) +} + +/** + * A filter that evaluates to `true` iff the field evaluates to a value + * greater than or equal to `value`. + * + * @since 3.0.0 + */ +@Experimental +case class GreaterThanOrEqual(ref: NamedReference, value: Any) extends FilterV2 { + override def references: Array[NamedReference] = Array(ref) ++ findReferences(value) +} + +/** + * A filter that evaluates to `true` iff the field evaluates to a value + * less than `value`. + * + * @since 3.0.0 + */ +@Experimental +case class LessThan(ref: NamedReference, value: Any) extends FilterV2 { + override def references: Array[NamedReference] = Array(ref) ++ findReferences(value) +} + +/** + * A filter that evaluates to `true` iff the field evaluates to a value + * less than or equal to `value`. + * + * @since 3.0.0 + */ +@Experimental +case class LessThanOrEqual(ref: NamedReference, value: Any) extends FilterV2 { + override def references: Array[NamedReference] = Array(ref) ++ findReferences(value) +} + +/** + * A filter that evaluates to `true` iff the field evaluates to one of the values in the array. 
+ * + * @since 3.0.0 + */ +@Experimental +case class In(ref: NamedReference, values: Array[Any]) extends FilterV2 { + override def hashCode(): Int = { + var h = ref.hashCode + values.foreach { v => + h *= 41 + h += v.hashCode() + } + h + } + override def equals(o: Any): Boolean = o match { + case In(a, vs) => + a == ref && vs.length == values.length && vs.zip(values).forall(x => x._1 == x._2) + case _ => false + } + override def toString: String = { + s"In($ref, [${values.mkString(",")}])" + } + + override def references: Array[NamedReference] = Array(ref) ++ values.flatMap(findReferences) +} + +/** + * A filter that evaluates to `true` iff the field evaluates to null. + * + * @since 3.0.0 + */ +@Experimental +case class IsNull(ref: NamedReference) extends FilterV2 { + override def references: Array[NamedReference] = Array(ref) +} + +/** + * A filter that evaluates to `true` iff the field evaluates to a non-null value. + * + * @since 3.0.0 + */ +@Experimental +case class IsNotNull(ref: NamedReference) extends FilterV2 { + override def references: Array[NamedReference] = Array(ref) +} + +/** + * A filter that evaluates to `true` iff both `left` and `right` evaluate to `true`. + * + * @since 3.0.0 + */ +@Experimental +case class And(left: FilterV2, right: FilterV2) extends FilterV2 { + override def references: Array[NamedReference] = left.references ++ right.references +} + +/** + * A filter that evaluates to `true` iff at least one of `left` or `right` evaluates to `true`. + * + * @since 3.0.0 + */ +@Experimental +case class Or(left: FilterV2, right: FilterV2) extends FilterV2 { + override def references: Array[NamedReference] = left.references ++ right.references +} + +/** + * A filter that evaluates to `true` iff `child` evaluates to `false`. + * + * @since 3.0.0 + */ +@Experimental +case class Not(child: FilterV2) extends FilterV2 { + override def references: Array[NamedReference] = child.references() +} + +/** + * A filter that evaluates to `true` iff the field evaluates to + * a string that starts with `value`. + * + * @since 3.0.0 + */ +@Experimental +case class StringStartsWith(ref: NamedReference, value: String) extends FilterV2 { + override def references: Array[NamedReference] = Array(ref) +} + +/** + * A filter that evaluates to `true` iff the field evaluates to + * a string that ends with `value`. + * + * @since 3.0.0 + */ +@Experimental +case class StringEndsWith(ref: NamedReference, value: String) extends FilterV2 { + override def references: Array[NamedReference] = Array(ref) +} + +/** + * A filter that evaluates to `true` iff the field evaluates to + * a string that contains the string `value`. + * + * @since 3.0.0 + */ +@Experimental +case class StringContains(ref: NamedReference, value: String) extends FilterV2 { + override def references: Array[NamedReference] = Array(ref) +} + +/** + * A filter that always evaluates to `true`. + */ +@Experimental +case class AlwaysTrue() extends FilterV2 { + override def references: Array[NamedReference] = Array.empty +} + +@Experimental +object AlwaysTrue extends AlwaysTrue { +} + +/** + * A filter that always evaluates to `false`.
+ */ +@Experimental +case class AlwaysFalse() extends FilterV2 { + override def references: Array[NamedReference] = Array.empty +} + +@Experimental +object AlwaysFalse extends AlwaysFalse { +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 201860e5135b..5e6dbe26ba22 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -27,9 +27,10 @@ import org.scalatest.Assertions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.expressions.{IdentityTransform, Transform} +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.write._ -import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull} +import org.apache.spark.sql.sources.v2.{And, EqualTo, FilterV2, IsNotNull} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -59,8 +60,8 @@ class InMemoryTable( def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq - private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames) - private val partIndexes = partFieldNames.map(schema.fieldIndex) + private val partFieldNames = partitioning.flatMap(_.references).toSeq + private val partIndexes = partFieldNames.flatMap(_.fieldNames).map(schema.fieldIndex) private def getKey(row: InternalRow): Seq[Any] = partIndexes.map(row.toSeq(schema)(_)) @@ -107,7 +108,7 @@ class InMemoryTable( this } - override def overwrite(filters: Array[Filter]): WriteBuilder = { + override def overwrite(filters: Array[FilterV2]): WriteBuilder = { assert(writer == Append) writer = new Overwrite(filters) this @@ -145,7 +146,7 @@ class InMemoryTable( } } - private class Overwrite(filters: Array[Filter]) extends TestBatchWrite { + private class Overwrite(filters: Array[FilterV2]) extends TestBatchWrite { override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { val deleteKeys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters) dataMap --= deleteKeys @@ -160,7 +161,7 @@ class InMemoryTable( } } - override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized { + override def deleteWhere(filters: Array[FilterV2]): Unit = dataMap.synchronized { dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters) } } @@ -170,8 +171,8 @@ object InMemoryTable { def filtersToKeys( keys: Iterable[Seq[Any]], - partitionNames: Seq[String], - filters: Array[Filter]): Iterable[Seq[Any]] = { + partitionNames: Seq[NamedReference], + filters: Array[FilterV2]): Iterable[Seq[Any]] = { keys.filter { partValues => filters.flatMap(splitAnd).forall { case EqualTo(attr, value) => @@ -185,8 +186,8 @@ object InMemoryTable { } private def extractValue( - attr: String, - partFieldNames: Seq[String], + attr: NamedReference, + partFieldNames: Seq[NamedReference], partValues: Seq[Any]): Any = { partFieldNames.zipWithIndex.find(_._1 == attr) match { case Some((_, partIndex)) => @@ -196,7 +197,7 @@ object InMemoryTable { } } - private def splitAnd(filter: Filter): Seq[Filter] = { + private def splitAnd(filter: FilterV2): Seq[FilterV2] = { filter match { case And(left, right) => splitAnd(left) ++ splitAnd(right) case _ => 
filter :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index d44cb11e2876..7035f959de3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -433,6 +433,7 @@ object DataSourceStrategy { * is case insensitive. We should change attribute names to match the ones in the schema, * so we do not need to worry about case sensitivity anymore. */ + // TODO (DB Tsai): Handle the case sensitivity of nested columns protected[sql] def normalizeExprs( exprs: Seq[Expression], attributes: Seq[AttributeReference]): Seq[Expression] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala index 0b5658715377..9a619b8fb689 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.orc -import org.apache.spark.sql.sources.{And, Filter} +import org.apache.spark.sql.sources.v2.{And, FilterV2} import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType} /** @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType} */ trait OrcFiltersBase { - private[sql] def buildTree(filters: Seq[Filter]): Option[Filter] = { + private[sql] def buildTree(filters: Seq[FilterV2]): Option[FilterV2] = { filters match { case Seq() => None case Seq(filter) => Some(filter) @@ -36,16 +36,6 @@ trait OrcFiltersBase { } } - // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters - // in order to distinguish predicate pushdown for nested columns. - protected[sql] def quoteAttributeNameIfNeeded(name: String) : String = { - if (!name.contains("`") && name.contains(".")) { - s"`$name`" - } else { - name - } - } - /** * Return true if this is a searchable type in ORC. * Both CharType and VarcharType are cleaned at AstBuilder.
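The `quoteAttributeNameIfNeeded` helper removed above existed because the old `Filter` API carried column references as single dotted strings, so a nested field and a top-level column whose name contains a `.` were indistinguishable without backtick quoting. With `FilterV2`, references are `NamedReference`s whose name parts stay separate, which is what makes the quoting unnecessary. A minimal sketch of that distinction, assuming the `FieldReference(parts: Seq[String])` factory and `fieldNames()` accessor used elsewhere in this patch, and the same `org.apache.spark.sql` package visibility as those call sites:

import org.apache.spark.sql.connector.expressions.FieldReference

// Hypothetical references, purely to illustrate what the dotted-string encoding could not express.
val nested = FieldReference(Seq("person", "age"))   // the nested field person.age
val dotted = FieldReference(Seq("person.age"))      // a top-level column literally named "person.age"

// The name parts are preserved, so the two references are no longer ambiguous.
assert(nested.fieldNames.toSeq == Seq("person", "age"))
assert(dotted.fieldNames.toSeq == Seq("person.age"))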
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 44de8f275fea..3ec11bd5199d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -18,22 +18,188 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.sql.{AnalysisException, Strategy} -import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.{AlterNamespaceSetProperties, AlterTable, AppendData, CreateNamespace, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeNamespace, DescribeTable, DropNamespace, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, RefreshTable, RenameTable, Repartition, ReplaceTable, ReplaceTableAsSelect, SetCatalogAndNamespace, ShowCurrentNamespace, ShowNamespaces, ShowTableProperties, ShowTables} import org.apache.spark.sql.connector.catalog.{StagingTableCatalog, TableCapability} +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.v2.FilterV2 +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.StringType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String object DataSourceV2Strategy extends Strategy with PredicateHelper { import DataSourceV2Implicits._ + /** + * Tries to convert an [[Expression]] that can be pushed down to a [[NamedReference]] + */ + private def toNamedRef(e: Expression): Option[NamedReference] = { + def helper(e: Expression): Option[Seq[String]] = e match { + case a: Attribute => + Some(Seq(a.name)) + case s: GetStructField => + helper(s.child).map(_ ++ Seq(s.childSchema(s.ordinal).name)) + case _ => + None + } + helper(e).map(fieldNames => FieldReference(fieldNames)) + } + + private def translateLeafNodeFilterV2(predicate: Expression): Option[FilterV2] = predicate match { + case expressions.EqualTo(e: Expression, Literal(v, t)) => + toNamedRef(e).map(field => sources.v2.EqualTo(field, convertToScala(v, t))) + case expressions.EqualTo(Literal(v, t), e: Expression) => + toNamedRef(e).map(field => sources.v2.EqualTo(field, convertToScala(v, t))) + + case expressions.EqualNullSafe(e: Expression, Literal(v, t)) => + toNamedRef(e).map(field => 
sources.v2.EqualNullSafe(field, convertToScala(v, t))) + case expressions.EqualNullSafe(Literal(v, t), e: Expression) => + toNamedRef(e).map(field => sources.v2.EqualNullSafe(field, convertToScala(v, t))) + + case expressions.GreaterThan(e: Expression, Literal(v, t)) => + toNamedRef(e).map(field => sources.v2.GreaterThan(field, convertToScala(v, t))) + case expressions.GreaterThan(Literal(v, t), e: Expression) => + toNamedRef(e).map(field => sources.v2.LessThan(field, convertToScala(v, t))) + + case expressions.LessThan(e: Expression, Literal(v, t)) => + toNamedRef(e).map(field => sources.v2.LessThan(field, convertToScala(v, t))) + case expressions.LessThan(Literal(v, t), e: Expression) => + toNamedRef(e).map(field => sources.v2.GreaterThan(field, convertToScala(v, t))) + + case expressions.GreaterThanOrEqual(e: Expression, Literal(v, t)) => + toNamedRef(e).map(field => sources.v2.GreaterThanOrEqual(field, convertToScala(v, t))) + case expressions.GreaterThanOrEqual(Literal(v, t), e: Expression) => + toNamedRef(e).map(field => sources.v2.LessThanOrEqual(field, convertToScala(v, t))) + + case expressions.LessThanOrEqual(e: Expression, Literal(v, t)) => + toNamedRef(e).map(field => sources.v2.LessThanOrEqual(field, convertToScala(v, t))) + case expressions.LessThanOrEqual(Literal(v, t), e: Expression) => + toNamedRef(e).map(field => sources.v2.GreaterThanOrEqual(field, convertToScala(v, t))) + + case expressions.InSet(e: Expression, set) => + val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) + toNamedRef(e).map(field => sources.v2.In(field, set.toArray.map(toScala))) + + // Because we only convert In to InSet in Optimizer when there are more than certain + // items. So it is possible we still get an In expression here that needs to be pushed + // down. + case expressions.In(e: Expression, list) if list.forall(_.isInstanceOf[Literal]) => + val hSet = list.map(_.eval(EmptyRow)) + val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) + toNamedRef(e).map(field => sources.v2.In(field, hSet.toArray.map(toScala))) + + case expressions.IsNull(e: Expression) => + toNamedRef(e).map(field => sources.v2.IsNull(field)) + case expressions.IsNotNull(e: Expression) => + toNamedRef(e).map(field => sources.v2.IsNotNull(field)) + case expressions.StartsWith(e: Expression, Literal(v: UTF8String, StringType)) => + toNamedRef(e).map(field => sources.v2.StringStartsWith(field, v.toString)) + + case expressions.EndsWith(e: Expression, Literal(v: UTF8String, StringType)) => + toNamedRef(e).map(field => sources.v2.StringEndsWith(field, v.toString)) + + case expressions.Contains(e: Expression, Literal(v: UTF8String, StringType)) => + toNamedRef(e).map(field => sources.v2.StringContains(field, v.toString)) + + case expressions.Literal(true, BooleanType) => + Some(sources.v2.AlwaysTrue) + + case expressions.Literal(false, BooleanType) => + Some(sources.v2.AlwaysFalse) + + case _ => None + } + + /** + * Tries to translate a Catalyst [[Expression]] into data source [[FilterV2]]. + * + * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. + */ + protected[sql] def translateFilterV2(predicate: Expression): Option[FilterV2] = { + translateFilterV2WithMapping(predicate, None) + } + /** + * Tries to translate a Catalyst [[Expression]] into data source [[FilterV2]]. 
+ * + * @param predicate The input [[Expression]] to be translated as [[FilterV2]] + * @param translatedFilterToExpr An optional map from leaf node filter expressions to its + * translated [[FilterV2]]. The map is used for rebuilding + * [[Expression]] from [[FilterV2]]. + * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. + */ + protected[sql] def translateFilterV2WithMapping( + predicate: Expression, + translatedFilterToExpr: Option[mutable.HashMap[sources.v2.FilterV2, Expression]]) + : Option[sources.v2.FilterV2] = { + predicate match { + case expressions.And(left, right) => + // See SPARK-12218 for detailed discussion + // It is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have (a = 2 AND trim(b) = 'blah') OR (c > 0) + // and we do not understand how to convert trim(b) = 'blah'. + // If we only convert a = 2, we will end up with + // (a = 2) OR (c > 0), which will generate wrong results. + // Pushing one leg of AND down is only safe to do at the top level. + // You can see ParquetFilters' createFilter for more details. + for { + leftFilter <- translateFilterV2WithMapping(left, translatedFilterToExpr) + rightFilter <- translateFilterV2WithMapping(right, translatedFilterToExpr) + } yield sources.v2.And(leftFilter, rightFilter) + + case expressions.Or(left, right) => + for { + leftFilter <- translateFilterV2WithMapping(left, translatedFilterToExpr) + rightFilter <- translateFilterV2WithMapping(right, translatedFilterToExpr) + } yield sources.v2.Or(leftFilter, rightFilter) + + case expressions.Not(child) => + translateFilterV2WithMapping(child, translatedFilterToExpr).map(sources.v2.Not) + + case other => + val filter = translateLeafNodeFilterV2(other) + if (filter.isDefined && translatedFilterToExpr.isDefined) { + translatedFilterToExpr.get(filter.get) = predicate + } + filter + } + } + + protected[sql] def rebuildExpressionFromFilter( + filter: FilterV2, + translatedFilterToExpr: mutable.HashMap[sources.v2.FilterV2, Expression]): Expression = { + filter match { + case sources.v2.And(left, right) => + expressions.And(rebuildExpressionFromFilter(left, translatedFilterToExpr), + rebuildExpressionFromFilter(right, translatedFilterToExpr)) + case sources.v2.Or(left, right) => + expressions.Or(rebuildExpressionFromFilter(left, translatedFilterToExpr), + rebuildExpressionFromFilter(right, translatedFilterToExpr)) + case sources.v2.Not(pred) => + expressions.Not(rebuildExpressionFromFilter(pred, translatedFilterToExpr)) + case other => + translatedFilterToExpr.getOrElse(other, + throw new AnalysisException( + s"Fail to rebuild expression: missing key $filter in `translatedFilterToExpr`")) + } + } + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2ScanRelation) => // projection and filters were already pushed down in the optimizer. @@ -143,7 +309,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) => // fail if any filter cannot be converted. correctness depends on removing all matching data. 
val filters = splitConjunctivePredicates(deleteExpr).map { - filter => DataSourceStrategy.translateFilter(deleteExpr).getOrElse( + filter => translateFilterV2(deleteExpr).getOrElse( throw new AnalysisException(s"Cannot translate expression to source filter: $filter")) }.toArray r.table.asWritable match { @@ -168,7 +334,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { // correctness depends on removing all matching data. val filters = DataSourceStrategy.normalizeExprs(condition.toSeq, output) .flatMap(splitConjunctivePredicates(_).map { - f => DataSourceStrategy.translateFilter(f).getOrElse( + f => translateFilterV2(f).getOrElse( throw new AnalysisException(s"Exec update failed:" + s" cannot translate expression to source filter: $f")) }).toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala index afebbfd01db2..3802591fb8a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.connector.catalog.SupportsDelete -import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.v2.FilterV2 case class DeleteFromTableExec( table: SupportsDelete, - condition: Array[Filter]) extends V2CommandExec { + condition: Array[FilterV2]) extends V2CommandExec { override protected def run(): Seq[InternalRow] = { table.deleteWhere(condition) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileFormatV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileFormatV2.scala new file mode 100644 index 000000000000..3326659044d5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileFormatV2.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec} +import org.apache.hadoop.mapreduce.Job + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.datasources.{OutputWriterFactory, PartitionedFile} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.FilterV2 +import org.apache.spark.sql.types.{DataType, StructType} + + +/** + * Used to read and write data stored in files to/from the [[InternalRow]] format. + */ +trait FileFormatV2 { + /** + * When possible, this method should return the schema of the given `files`. When the format + * does not support inference, or no valid files are given, it should return None. In these cases + * Spark will require that the user specify the schema manually. + */ + def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] + + /** + * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can + * be put here. For example, user defined output committer can be configured here + * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. + */ + def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory + + /** + * Returns whether this format supports returning columnar batch or not. + * + * TODO: we should just have different traits for the different formats. + */ + def supportBatch(sparkSession: SparkSession, dataSchema: StructType): Boolean = { + false + } + + /** + * Returns concrete column vector class names for each column to be used in a columnar batch + * if this format supports returning columnar batch. + */ + def vectorTypes( + requiredSchema: StructType, + partitionSchema: StructType, + sqlConf: SQLConf): Option[Seq[String]] = { + None + } + + /** + * Returns whether a file with `path` could be split or not. + */ + def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + false + } + + /** + * Returns a function that can be used to read a single file in as an Iterator of InternalRow. + * + * @param dataSchema The global data schema. It can be either specified by the user, or + * reconciled/merged from all underlying data files. If any partition columns + * are contained in the files, they are preserved in this schema. + * @param partitionSchema The schema of the partition column row that will be present in each + * PartitionedFile. These columns should be appended to the rows that + * are produced by the iterator. + * @param requiredSchema The schema of the data that should be output for each row. This may be a + * subset of the columns that are present in the file if column pruning has + * occurred. + * @param filters A set of filters that can optionally be used to reduce the number of rows output + * @param options A set of string -> string configuration options.
+ * @return + */ + protected def buildReader( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[FilterV2], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + throw new UnsupportedOperationException(s"buildReader is not supported for $this") + } + + /** + * Exactly the same as [[buildReader]] except that the reader function returned by this method + * appends partition values to [[InternalRow]]s produced by the reader function [[buildReader]] + * returns. + */ + def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[FilterV2], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + val dataReader = buildReader( + sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf) + + new (PartitionedFile => Iterator[InternalRow]) with Serializable { + private val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + + // Using lazy val to avoid serialization + private lazy val appendPartitionColumns = + GenerateUnsafeProjection.generate(fullSchema, fullSchema) + + override def apply(file: PartitionedFile): Iterator[InternalRow] = { + // Using local val to avoid per-row lazy val check (pre-mature optimization?...) + val converter = appendPartitionColumns + + // Note that we have to apply the converter even though `file.partitionValues` is empty. + // This is because the converter is also responsible for converting safe `InternalRow`s into + // `UnsafeRow`s. + if (partitionSchema.isEmpty) { + dataReader(file).map { dataRow => + converter(dataRow) + } + } else { + val joinedRow = new JoinedRow() + dataReader(file).map { dataRow => + converter(joinedRow(dataRow, file.partitionValues)) + } + } + } + } + } + + /** + * Returns whether this format supports the given [[DataType]] in read/write path. + * By default all data types are supported. + */ + def supportDataType(dataType: DataType): Boolean = true +} + +/** + * The base class file format that is based on text file. 
+ */ +abstract class TextBasedFileFormat extends FileFormatV2 { + private var codecFactory: CompressionCodecFactory = _ + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + if (codecFactory == null) { + codecFactory = new CompressionCodecFactory( + sparkSession.sessionState.newHadoopConfWithOptions(options)) + } + val codec = codecFactory.getCodec(path) + codec == null || codec.isInstanceOf[SplittableCompressionCodec] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 55104a2b21de..436af4540f81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan, Statist import org.apache.spark.sql.execution.PartitionedFileUtil import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.v2.FilterV2 import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -137,7 +138,7 @@ abstract class FileScan( StructType(readDataSchema.fields ++ readPartitionSchema.fields) // Returns whether the two given arrays of [[Filter]]s are equivalent. - protected def equivalentFilters(a: Array[Filter], b: Array[Filter]): Boolean = { + protected def equivalentFilters(a: Array[FilterV2], b: Array[FilterV2]): Boolean = { a.sortBy(_.hashCode()).sameElements(b.sortBy(_.hashCode())) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 634ecfdf7e1d..add60279df6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -39,14 +39,14 @@ object PushDownUtils extends PredicateHelper { // expressions. For a `And`/`Or` predicate, it is possible that the predicate is partially // pushed down. This map can be used to construct a catalyst filter expression from the // input filter, or a superset(partial push down filter) of the input filter. - val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression] - val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter] + val translatedFilterToExpr = mutable.HashMap.empty[sources.v2.FilterV2, Expression] + val translatedFilters = mutable.ArrayBuffer.empty[sources.v2.FilterV2] // Catalyst filter expression that can't be translated to data source filters. val untranslatableExprs = mutable.ArrayBuffer.empty[Expression] for (filterExpr <- filters) { - val translated = - DataSourceStrategy.translateFilterWithMapping(filterExpr, Some(translatedFilterToExpr)) + val translated = DataSourceV2Strategy + .translateFilterV2WithMapping(filterExpr, Some(translatedFilterToExpr)) if (translated.isEmpty) { untranslatableExprs += filterExpr } else { @@ -58,11 +58,11 @@ object PushDownUtils extends PredicateHelper { // the data source cannot guarantee the rows returned can pass these filters. // As a result we must return it so Spark can plan an extra filter operator. 
val postScanFilters = r.pushFilters(translatedFilters.toArray).map { filter => - DataSourceStrategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr) + DataSourceV2Strategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr) } // The filters which are marked as pushed to this data source val pushedFilters = r.pushedFilters().map { filter => - DataSourceStrategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr) + DataSourceV2Strategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr) } (pushedFilters, untranslatableExprs ++ postScanFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala index bf67e972976b..c33185b0476a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala @@ -28,7 +28,9 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog.SupportsWrite import org.apache.spark.sql.connector.write.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.sources.{AlwaysTrue, Filter, InsertableRelation} +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.InsertableRelation +import org.apache.spark.sql.sources.v2.{AlwaysTrue, FilterV2} import org.apache.spark.sql.util.CaseInsensitiveStringMap /** @@ -59,11 +61,11 @@ case class AppendDataExecV1( */ case class OverwriteByExpressionExecV1( table: SupportsWrite, - deleteWhere: Array[Filter], + deleteWhere: Array[FilterV2], writeOptions: CaseInsensitiveStringMap, plan: LogicalPlan) extends V1FallbackWriters { - private def isTruncate(filters: Array[Filter]): Boolean = { + private def isTruncate(filters: Array[FilterV2]): Boolean = { filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 7d8a115c126e..f6ad0fa3ef48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingT import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} -import org.apache.spark.sql.sources.{AlwaysTrue, Filter} +import org.apache.spark.sql.sources.v2.{AlwaysTrue, FilterV2} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{LongAccumulator, Utils} @@ -268,11 +268,11 @@ case class AppendDataExec( */ case class OverwriteByExpressionExec( table: SupportsWrite, - deleteWhere: Array[Filter], + deleteWhere: Array[FilterV2], writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { - private def isTruncate(filters: Array[Filter]): 
Boolean = { + private def isTruncate(filters: Array[FilterV2]): Boolean = { filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 40784516a6f3..154f4f373e8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.v2.FilterV2 import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -36,7 +37,7 @@ case class OrcScan( readDataSchema: StructType, readPartitionSchema: StructType, options: CaseInsensitiveStringMap, - pushedFilters: Array[Filter]) + pushedFilters: Array[FilterV2]) extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) { override def isSplitable(path: Path): Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 8d1d4ec45915..ba5f797fc6d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -22,11 +22,15 @@ import scala.collection.JavaConverters._ import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.v2.FilterV2 +import org.apache.spark.sql.types.DataType import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -48,20 +52,22 @@ case class OrcScanBuilder( readDataSchema(), readPartitionSchema(), options, pushedFilters()) } - private var _pushedFilters: Array[Filter] = Array.empty + private var _pushedFilters: Array[FilterV2] = Array.empty - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushFilters(filters: Array[FilterV2]): Array[FilterV2] = { if (sparkSession.sessionState.conf.orcFilterPushDown) { OrcFilters.createFilter(schema, filters).foreach { f => // The pushed filters will be set in `hadoopConf`. After that, we can simply use the // changed `hadoopConf` in executors. 
OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames) } - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + // TODO: Fix me for nested data + val dataTypeMap: Map[NamedReference, DataType] = + schema.map(f => FieldReference(Seq(f.name)) -> f.dataType).toMap _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray } filters } - override def pushedFilters(): Array[Filter] = _pushedFilters + override def pushedFilters(): Array[FilterV2] = _pushedFilters } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index b2fc724057eb..f155737322b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.{PartitionedFile, RecordReader import org.apache.spark.sql.execution.datasources.parquet._ import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.v2.FilterV2 import org.apache.spark.sql.types.{AtomicType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration @@ -58,7 +58,7 @@ case class ParquetPartitionReaderFactory( dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, - filters: Array[Filter]) extends FilePartitionReaderFactory with Logging { + filters: Array[FilterV2]) extends FilePartitionReaderFactory with Logging { private val isCaseSensitive = sqlConf.caseSensitiveAnalysis private val resultSchema = StructType(partitionSchema.fields ++ readDataSchema.fields) private val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled @@ -137,12 +137,15 @@ case class ParquetPartitionReaderFactory( val parquetSchema = footerFileMetaData.getSchema val parquetFilters = new ParquetFilters(parquetSchema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive) + /* filters // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` // is used here. 
.flatMap(parquetFilters.createFilter) .reduceOption(FilterApi.and) + */ + None } else { None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 7e6ea41cf0b8..ebf6d0549a84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.datasources.parquet.{ParquetReadSupport, P import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.v2.FilterV2 import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -38,7 +39,7 @@ case class ParquetScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, - pushedFilters: Array[Filter], + pushedFilters: Array[FilterV2], options: CaseInsensitiveStringMap) extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) { override def isSplitable(path: Path): Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 87db00077e79..a758905410c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder -import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.v2.FilterV2 import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -53,12 +53,13 @@ case class ParquetScanBuilder( new SparkToParquetSchemaConverter(sparkSession.sessionState.conf).convert(schema) val parquetFilters = new ParquetFilters(parquetSchema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive) - parquetFilters.convertibleFilters(this.filters).toArray + // parquetFilters.convertibleFilters(this.filters).toArray + this.filters } - private var filters: Array[Filter] = Array.empty + private var filters: Array[FilterV2] = Array.empty - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushFilters(filters: Array[FilterV2]): Array[FilterV2] = { this.filters = filters this.filters } @@ -66,7 +67,7 @@ case class ParquetScanBuilder( // Note: for Parquet, the actual filter push down happens in [[ParquetPartitionReaderFactory]]. // It requires the Parquet physical schema to determine whether a filter is convertible. // All filters that can be converted to Parquet are pushed down. 
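// A hypothetical, minimal scan builder (not part of this patch) sketching the FilterV2 push-down
// contract that builders like the one above implement: pushFilters returns the filters Spark must
// still evaluate after the scan, while pushedFilters reports what the source accepted. The class
// and the pushability test are made up; only the interface signatures come from this patch.
class ExamplePushDownSketch extends org.apache.spark.sql.connector.read.SupportsPushDownFilters {
  import org.apache.spark.sql.sources.v2.FilterV2

  private var pushed = Array.empty[FilterV2]

  override def pushFilters(filters: Array[FilterV2]): Array[FilterV2] = {
    // Accept only filters whose references are single, top-level columns.
    val (supported, unsupported) = filters.partition(_.references.forall(_.fieldNames.length == 1))
    pushed = supported
    unsupported  // post-scan filters: re-evaluated by Spark on the rows returned by the source
  }

  override def pushedFilters(): Array[FilterV2] = pushed

  override def build(): org.apache.spark.sql.connector.read.Scan =
    throw new UnsupportedOperationException("scan construction is out of scope for this sketch")
}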
- override def pushedFilters(): Array[Filter] = pushedParquetFilters + override def pushedFilters(): Array[FilterV2] = pushedParquetFilters override def build(): Scan = { ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java index 9386ab51d64f..f28ee0e093ed 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java @@ -25,8 +25,8 @@ import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.*; -import org.apache.spark.sql.sources.Filter; -import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.v2.FilterV2; +import org.apache.spark.sql.sources.v2.GreaterThan; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -46,7 +46,7 @@ static class AdvancedScanBuilder implements ScanBuilder, Scan, SupportsPushDownFilters, SupportsPushDownRequiredColumns { private StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); - private Filter[] filters = new Filter[0]; + private FilterV2[] filters = new FilterV2[0]; @Override public void pruneColumns(StructType requiredSchema) { @@ -59,31 +59,31 @@ public StructType readSchema() { } @Override - public Filter[] pushFilters(Filter[] filters) { - Filter[] supported = Arrays.stream(filters).filter(f -> { + public FilterV2[] pushFilters(FilterV2[] filters) { + FilterV2[] supported = Arrays.stream(filters).filter(f -> { if (f instanceof GreaterThan) { GreaterThan gt = (GreaterThan) f; - return gt.attribute().equals("i") && gt.value() instanceof Integer; + return gt.ref().describe().equals("i") && gt.value() instanceof Integer; } else { return false; } - }).toArray(Filter[]::new); + }).toArray(FilterV2[]::new); - Filter[] unsupported = Arrays.stream(filters).filter(f -> { + FilterV2[] unsupported = Arrays.stream(filters).filter(f -> { if (f instanceof GreaterThan) { GreaterThan gt = (GreaterThan) f; - return !gt.attribute().equals("i") || !(gt.value() instanceof Integer); + return !gt.ref().describe().equals("i") || !(gt.value() instanceof Integer); } else { return true; } - }).toArray(Filter[]::new); + }).toArray(FilterV2[]::new); this.filters = supported; return unsupported; } @Override - public Filter[] pushedFilters() { + public FilterV2[] pushedFilters() { return filters; } @@ -101,9 +101,9 @@ public Batch toBatch() { public static class AdvancedBatch implements Batch { // Exposed for testing. 
public StructType requiredSchema; - public Filter[] filters; + public FilterV2[] filters; - AdvancedBatch(StructType requiredSchema, Filter[] filters) { + AdvancedBatch(StructType requiredSchema, FilterV2[] filters) { this.requiredSchema = requiredSchema; this.filters = filters; } @@ -113,10 +113,10 @@ public InputPartition[] planInputPartitions() { List res = new ArrayList<>(); Integer lowerBound = null; - for (Filter filter : filters) { + for (FilterV2 filter : filters) { if (filter instanceof GreaterThan) { GreaterThan f = (GreaterThan) filter; - if ("i".equals(f.attribute()) && f.value() instanceof Integer) { + if ("i".equals(f.ref().describe()) && f.value() instanceof Integer) { lowerBound = (Integer) f.value(); break; } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 55c71c7d02d2..3f08bede235e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.{Filter, GreaterThan} +import org.apache.spark.sql.sources.v2.{FilterV2, GreaterThan} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -92,11 +92,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession { checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) if (cls == classOf[AdvancedDataSourceV2]) { val batch = getBatch(q2) - assert(batch.filters.flatMap(_.references).toSet == Set("i")) + assert(batch.filters.flatMap(_.references.map(_.describe())).toSet == Set("i")) assert(batch.requiredSchema.fieldNames === Seq("i", "j")) } else { val batch = getJavaBatch(q2) - assert(batch.filters.flatMap(_.references).toSet == Set("i")) + assert(batch.filters.flatMap(_.references.map(_.describe)).toSet == Set("i")) assert(batch.requiredSchema.fieldNames === Seq("i", "j")) } @@ -104,11 +104,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession { checkAnswer(q3, (7 until 10).map(i => Row(i))) if (cls == classOf[AdvancedDataSourceV2]) { val batch = getBatch(q3) - assert(batch.filters.flatMap(_.references).toSet == Set("i")) + assert(batch.filters.flatMap(_.references.map(_.describe)).toSet == Set("i")) assert(batch.requiredSchema.fieldNames === Seq("i")) } else { val batch = getJavaBatch(q3) - assert(batch.filters.flatMap(_.references).toSet == Set("i")) + assert(batch.filters.flatMap(_.references.map(_.describe)).toSet == Set("i")) assert(batch.requiredSchema.fieldNames === Seq("i")) } @@ -481,7 +481,7 @@ class AdvancedScanBuilder extends ScanBuilder with Scan with SupportsPushDownFilters with SupportsPushDownRequiredColumns { var requiredSchema = new StructType().add("i", "int").add("j", "int") - var filters = Array.empty[Filter] + var filters = Array.empty[FilterV2] override def pruneColumns(requiredSchema: StructType): Unit = { this.requiredSchema = requiredSchema @@ -489,27 +489,27 @@ class AdvancedScanBuilder extends ScanBuilder override def readSchema(): StructType = requiredSchema - override def pushFilters(filters: 
Array[Filter]): Array[Filter] = { + override def pushFilters(filters: Array[FilterV2]): Array[FilterV2] = { val (supported, unsupported) = filters.partition { - case GreaterThan("i", _: Int) => true + case GreaterThan(ref, _: Int) if ref.fieldNames().sameElements(Array("i")) => true case _ => false } this.filters = supported unsupported } - override def pushedFilters(): Array[Filter] = filters + override def pushedFilters(): Array[FilterV2] = filters override def build(): Scan = this override def toBatch: Batch = new AdvancedBatch(filters, requiredSchema) } -class AdvancedBatch(val filters: Array[Filter], val requiredSchema: StructType) extends Batch { +class AdvancedBatch(val filters: Array[FilterV2], val requiredSchema: StructType) extends Batch { override def planInputPartitions(): Array[InputPartition] = { val lowerBound = filters.collectFirst { - case GreaterThan("i", v: Int) => v + case GreaterThan(ref, v: Int) if ref.fieldNames.sameElements(Array("i")) => v } val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index de843ba4375d..3911d99d9e33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -28,10 +28,11 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveM import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.connector.write.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder} -import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils} +import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.sources._ +import org.apache.spark.sql.sources.v2.FilterV2 import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with BeforeAndAfter { @@ -262,8 +263,8 @@ class InMemoryTableWithV1Fallback( TableCapability.TRUNCATE).asJava @volatile private var dataMap: mutable.Map[Seq[Any], Seq[Row]] = mutable.Map.empty - private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames) - private val partIndexes = partFieldNames.map(schema.fieldIndex(_)) + private val partFieldNames = partitioning.flatMap(_.references).toSeq + private val partIndexes = partFieldNames.map(_.describe()).map(schema.fieldIndex(_)) def getData: Seq[Row] = dataMap.values.flatten.toSeq @@ -285,7 +286,7 @@ class InMemoryTableWithV1Fallback( this } - override def overwrite(filters: Array[Filter]): WriteBuilder = { + override def overwrite(filters: Array[FilterV2]): WriteBuilder = { val keys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters) dataMap --= keys mode = "overwrite" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index 528c3474a17c..dab83f6c3603 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -125,8 +125,8 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor assert(pushedFilters.isEmpty, "Unsupported filters should not show in pushed filters") } else { assert(pushedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters) - assert(maybeFilter.isEmpty, s"Couldn't generate filter predicate for $pushedFilters") +// val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters) +// assert(maybeFilter.isEmpty, s"Couldn't generate filter predicate for $pushedFilters") } case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 286bb1e92026..f5239f10dd0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -1493,7 +1493,7 @@ class ParquetV2FilterSuite extends ParquetFilterSuite { // In this test suite, all the simple predicates are convertible here. assert(parquetFilters.convertibleFilters(sourceFilters) === pushedFilters) val pushedParquetFilters = pushedFilters.map { pred => - val maybeFilter = parquetFilters.createFilter(pred) + val maybeFilter: Option[FilterPredicate] = None // parquetFilters.createFilter(pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") maybeFilter.get } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala index 87d541d2d22b..32b1544e587d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala @@ -160,7 +160,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession { // 'a is not null and 'a > 1 val filters = scanNodes.head.scan.asInstanceOf[ParquetScan].pushedFilters assert(filters.length == 2) - assert(filters.flatMap(_.references).distinct === Array("a")) + // assert(filters.flatMap(_.references).map(_.name()).distinct === Array("a")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala index 1cb7a2156c3d..9a61652cad3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.sources.v2.FiltersV2Suite.ref /** * Unit test suites for data source filters. 
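The hunk below adds assertions that each v1 filter converts to its FilterV2 counterpart through a new `toV2` bridge. The patch's actual `toV2` implementation is not shown in this part of the diff; a rough sketch of the shape such a conversion presumably takes (hypothetical standalone helper; recursion into filter-valued operands and the remaining predicate types are elided):

import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.{sources => v1}
import org.apache.spark.sql.sources.v2

// Hypothetical sketch: map a few v1 predicates to v2, turning the v1 attribute string into a
// single-part FieldReference, which is what the FiltersSuite assertions below expect.
def toV2Sketch(filter: v1.Filter): v2.FilterV2 = filter match {
  case v1.EqualTo(attr, value)            => v2.EqualTo(FieldReference(Seq(attr)), value)
  case v1.EqualNullSafe(attr, value)      => v2.EqualNullSafe(FieldReference(Seq(attr)), value)
  case v1.GreaterThan(attr, value)        => v2.GreaterThan(FieldReference(Seq(attr)), value)
  case v1.GreaterThanOrEqual(attr, value) => v2.GreaterThanOrEqual(FieldReference(Seq(attr)), value)
  case v1.LessThan(attr, value)           => v2.LessThan(FieldReference(Seq(attr)), value)
  case v1.LessThanOrEqual(attr, value)    => v2.LessThanOrEqual(FieldReference(Seq(attr)), value)
  case other =>
    throw new UnsupportedOperationException(s"Not covered by this sketch: $other")
}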
@@ -27,26 +28,51 @@ class FiltersSuite extends SparkFunSuite { test("EqualTo references") { assert(EqualTo("a", "1").references.toSeq == Seq("a")) assert(EqualTo("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) + + // Testing v1 to v2 filter conversions + assert(EqualTo("a", "1").toV2 == v2.EqualTo(ref("a"), "1")) + assert(EqualTo("a", EqualTo("b", "2")).toV2 == + v2.EqualTo(ref("a"), v2.EqualTo(ref("b"), "2"))) } test("EqualNullSafe references") { assert(EqualNullSafe("a", "1").references.toSeq == Seq("a")) assert(EqualNullSafe("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) + + // Testing v1 to v2 filter conversions + assert(EqualNullSafe("a", "1").toV2 == v2.EqualNullSafe(ref("a"), "1")) + assert(EqualNullSafe("a", EqualTo("b", "2")).toV2 == + v2.EqualNullSafe(ref("a"), v2.EqualTo(ref("b"), "2"))) } test("GreaterThan references") { assert(GreaterThan("a", "1").references.toSeq == Seq("a")) assert(GreaterThan("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) + + // Testing v1 to v2 filter conversions + assert(GreaterThan("a", "1").toV2 == v2.GreaterThan(ref("a"), "1")) + assert(GreaterThan("a", EqualTo("b", "2")).toV2 == + v2.GreaterThan(ref("a"), v2.EqualTo(ref("b"), "2"))) } test("GreaterThanOrEqual references") { assert(GreaterThanOrEqual("a", "1").references.toSeq == Seq("a")) assert(GreaterThanOrEqual("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) + + // Testing v1 to v2 filter conversions + assert(GreaterThanOrEqual("a", "1").toV2 == v2.GreaterThanOrEqual(ref("a"), "1")) + assert(GreaterThanOrEqual("a", EqualTo("b", "2")).toV2 == + v2.GreaterThanOrEqual(ref("a"), v2.EqualTo(ref("b"), "2"))) } test("LessThan references") { assert(LessThan("a", "1").references.toSeq == Seq("a")) assert(LessThan("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) + + // Testing v1 to v2 filter conversions + assert(LessThan("a", "1").toV2 == v2.LessThan(ref("a"), "1")) + assert(LessThan("a", EqualTo("b", "2")).toV2 == + v2.LessThan(ref("a"), v2.EqualTo(ref("b"), "2"))) } test("LessThanOrEqual references") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FiltersV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FiltersV2Suite.scala new file mode 100644 index 000000000000..6c338671d99d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FiltersV2Suite.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2 + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.sources.v2.FiltersV2Suite.ref + +class FiltersV2Suite extends SparkFunSuite { + + test("References with nested columns") { + assert(EqualTo(ref("a", "B"), "1").references.map(_.describe()).toSeq == Seq("a.B")) + assert(EqualTo(ref("a", "b.c"), "1").references.map(_.describe()).toSeq == Seq("a.`b.c`")) + assert(EqualTo(ref("`a`.b", "c"), "1").references.map(_.describe()).toSeq == Seq("```a``.b`.c")) + } + + test("EqualTo references") { + assert(EqualTo(ref("a"), "1").references.map(_.describe()).toSeq == Seq("a")) + assert(EqualTo(ref("a"), EqualTo(ref("b"), "2")).references.map(_.describe()).toSeq == + Seq("a", "b")) + } + + test("EqualNullSafe references") { + assert(EqualNullSafe(ref("a"), "1").references.map(_.describe()).toSeq == Seq("a")) + assert(EqualNullSafe(ref("a"), EqualTo(ref("b"), "2")).references.map(_.describe()).toSeq == + Seq("a", "b")) + } + + test("GreaterThan references") { + assert(GreaterThan(ref("a"), "1").references.map(_.describe()).toSeq == Seq("a")) + assert(GreaterThan(ref("a"), EqualTo(ref("b"), "2")).references.map(_.describe()).toSeq + == Seq("a", "b")) + } + + test("GreaterThanOrEqual references") { + assert(GreaterThanOrEqual(ref("a"), "1").references.map(_.describe()).toSeq == Seq("a")) + assert(GreaterThanOrEqual(ref("a"), EqualTo(ref("b"), "2")).references.map(_.describe()).toSeq + == Seq("a", "b")) + } + + test("LessThan references") { + assert(LessThan(ref("a"), "1").references.map(_.describe()).toSeq == + Seq("a")) + assert(LessThan(ref("a"), EqualTo(ref("b"), "2")).references.map(_.describe()).toSeq == + Seq("a", "b")) + } + + test("LessThanOrEqual references") { + assert(LessThanOrEqual(ref("a"), "1").references.map(_.describe()).toSeq == + Seq("a")) + assert(LessThanOrEqual(ref("a"), EqualTo(ref("b"), "2")).references.map(_.describe()).toSeq == + Seq("a", "b")) + } + + test("In references") { + assert(In(ref("a"), Array("1")).references.map(_.describe()).toSeq == Seq("a")) + assert(In(ref("a"), Array("1", EqualTo(ref("b"), "2"))).references.map(_.describe()).toSeq + == Seq("a", "b")) + } + + test("IsNull references") { + assert(IsNull(ref("a")).references.map(_.describe()).toSeq + == Seq("a")) + } + + test("IsNotNull references") { + assert(IsNotNull(ref("a")).references.map(_.describe()).toSeq + == Seq("a")) + } + + test("And references") { + assert(And(EqualTo(ref("a"), "1"), EqualTo(ref("b"), "1")).references.map(_.describe()).toSeq == + Seq("a", "b")) + } + + test("Or references") { + assert(Or(EqualTo(ref("a"), "1"), EqualTo(ref("b"), "1")).references.map(_.describe()).toSeq == + Seq("a", "b")) + } + + test("StringStartsWith references") { + assert(StringStartsWith(ref("a"), "str").references.map(_.describe()).toSeq == Seq("a")) + } + + test("StringEndsWith references") { + assert(StringEndsWith(ref("a"), "str").references.map(_.describe()).toSeq == Seq("a")) + } + + test("StringContains references") { + assert(StringContains(ref("a"), "str").references.map(_.describe()).toSeq == Seq("a")) + } +} + +object FiltersV2Suite { + private[sql] def ref(parts: String*): FieldReference = { + new FieldReference(parts) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala new file mode 100644 index 000000000000..e69de29bb2d1 diff 
--git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 995c5ed317de..c991de57728a 100644 --- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.orc +import scala.reflect.ClassTag + import org.apache.orc.storage.common.`type`.HiveDecimal import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder @@ -24,7 +26,9 @@ import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.orc.storage.serde2.io.HiveDecimalWritable import org.apache.spark.SparkException +import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.v2.FilterV2 import org.apache.spark.sql.types._ /** @@ -63,8 +67,15 @@ private[sql] object OrcFilters extends OrcFiltersBase { /** * Create ORC filter as a SearchArgument instance. */ - def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + def createFilter(schema: StructType, filters: Seq[Filter])(implicit d: DummyImplicit) + : Option[SearchArgument] = { + createFilter(schema, filters.map(_.toV2)) + } + + def createFilter(schema: StructType, filters: Seq[FilterV2]): Option[SearchArgument] = { + // TODO: Fix me for nested data + val dataTypeMap: Map[NamedReference, DataType] = + schema.map(f => FieldReference(Seq(f.name)) -> f.dataType).toMap // Combines all convertible filters using `And` to produce a single conjunction val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) conjunctionOptional.map { conjunction => @@ -77,13 +88,13 @@ private[sql] object OrcFilters extends OrcFiltersBase { def convertibleFilters( schema: StructType, - dataTypeMap: Map[String, DataType], - filters: Seq[Filter]): Seq[Filter] = { - import org.apache.spark.sql.sources._ + dataTypeMap: Map[NamedReference, DataType], + filters: Seq[FilterV2]): Seq[FilterV2] = { + import org.apache.spark.sql.sources.v2._ def convertibleFiltersHelper( - filter: Filter, - canPartialPushDown: Boolean): Option[Filter] = filter match { + filter: FilterV2, + canPartialPushDown: Boolean): Option[FilterV2] = filter match { // At here, it is not safe to just convert one side and remove the other side // if we do not understand what the parent filters are. // @@ -171,10 +182,10 @@ private[sql] object OrcFilters extends OrcFiltersBase { * @return the builder so far. */ private def buildSearchArgument( - dataTypeMap: Map[String, DataType], - expression: Filter, + dataTypeMap: Map[NamedReference, DataType], + expression: FilterV2, builder: Builder): Builder = { - import org.apache.spark.sql.sources._ + import org.apache.spark.sql.sources.v2._ expression match { case And(left, right) => @@ -207,64 +218,54 @@ private[sql] object OrcFilters extends OrcFiltersBase { * @return the builder so far. 
*/ private def buildLeafSearchArgument( - dataTypeMap: Map[String, DataType], - expression: Filter, + dataTypeMap: Map[NamedReference, DataType], + expression: FilterV2, builder: Builder): Option[Builder] = { - def getType(attribute: String): PredicateLeaf.Type = - getPredicateLeafType(dataTypeMap(attribute)) + def getType(field: NamedReference): PredicateLeaf.Type = + getPredicateLeafType(dataTypeMap(field)) - import org.apache.spark.sql.sources._ + import org.apache.spark.sql.sources.v2._ // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). expression match { - case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end()) + case EqualTo(field, value) if isSearchableType(dataTypeMap(field)) => + val castedValue = castLiteralValue(value, dataTypeMap(field)) + Some(builder.startAnd().equals(field.name(), getType(field), castedValue).end()) - case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end()) + case EqualNullSafe(field, value) if isSearchableType(dataTypeMap(field)) => + val castedValue = castLiteralValue(value, dataTypeMap(field)) + Some(builder.startAnd().nullSafeEquals(field.name(), getType(field), castedValue).end()) - case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end()) + case LessThan(field, value) if isSearchableType(dataTypeMap(field)) => + val castedValue = castLiteralValue(value, dataTypeMap(field)) + Some(builder.startAnd().lessThan(field.name(), getType(field), castedValue).end()) - case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case LessThanOrEqual(field, value) if isSearchableType(dataTypeMap(field)) => + val castedValue = castLiteralValue(value, dataTypeMap(field)) + Some(builder.startAnd().lessThanEquals(field.name(), getType(field), castedValue).end()) - case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case GreaterThan(field, value) if isSearchableType(dataTypeMap(field)) => + val castedValue = castLiteralValue(value, dataTypeMap(field)) + Some(builder.startNot().lessThanEquals(field.name(), getType(field), castedValue).end()) - case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val 
castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end()) + case GreaterThanOrEqual(field, value) if isSearchableType(dataTypeMap(field)) => + val castedValue = castLiteralValue(value, dataTypeMap(field)) + Some(builder.startNot().lessThan(field.name(), getType(field), castedValue).end()) - case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - Some(builder.startAnd().isNull(quotedName, getType(attribute)).end()) + case IsNull(field) if isSearchableType(dataTypeMap(field)) => + Some(builder.startAnd().isNull(field.name(), getType(field)).end()) - case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - Some(builder.startNot().isNull(quotedName, getType(attribute)).end()) + case IsNotNull(field) if isSearchableType(dataTypeMap(field)) => + Some(builder.startNot().isNull(field.name(), getType(field)).end()) - case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) - Some(builder.startAnd().in(quotedName, getType(attribute), + case In(field, values) if isSearchableType(dataTypeMap(field)) => + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(field))) + Some(builder.startAnd().in(field.name(), getType(field), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None } } } - diff --git a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index d09236a93433..4686926f8182 100644 --- a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -57,7 +57,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { DataSourceV2ScanRelation(_, OrcScan(_, _, _, _, _, _, _, pushedFilters), _)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") assert(pushedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters) + val maybeFilter: Option[SearchArgument] = None // OrcFilters.createFilter(query.schema, pushedFilters) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pushedFilters") checker(maybeFilter.get) diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 948ab44a8c19..ebbd8bf00f41 100644 --- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -24,7 +24,9 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable import org.apache.spark.SparkException +import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.v2.FilterV2 import org.apache.spark.sql.types._ /** @@ -63,8 +65,15 @@ private[sql] object OrcFilters extends 
OrcFiltersBase { /** * Create ORC filter as a SearchArgument instance. */ - def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + def createFilter(schema: StructType, filters: Seq[Filter])(implicit d: DummyImplicit) + : Option[SearchArgument] = { + createFilter(schema, filters.map(_.toV2)) + } + + def createFilter(schema: StructType, filters: Seq[FilterV2]): Option[SearchArgument] = { + // TODO: Fix me for nested data + val dataTypeMap: Map[NamedReference, DataType] = + schema.map(f => FieldReference(Seq(f.name)) -> f.dataType).toMap // Combines all convertible filters using `And` to produce a single conjunction val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) conjunctionOptional.map { conjunction => @@ -77,13 +86,13 @@ private[sql] object OrcFilters extends OrcFiltersBase { def convertibleFilters( schema: StructType, - dataTypeMap: Map[String, DataType], - filters: Seq[Filter]): Seq[Filter] = { + dataTypeMap: Map[NamedReference, DataType], + filters: Seq[FilterV2]): Seq[FilterV2] = { import org.apache.spark.sql.sources._ def convertibleFiltersHelper( - filter: Filter, - canPartialPushDown: Boolean): Option[Filter] = filter match { + filter: FilterV2, + canPartialPushDown: Boolean): Option[FilterV2] = filter match { // At here, it is not safe to just convert one side and remove the other side // if we do not understand what the parent filters are. // @@ -95,11 +104,11 @@ private[sql] object OrcFilters extends OrcFiltersBase { // Pushing one side of AND down is only safe to do at the top level or in the child // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate // can be safely removed. - case And(left, right) => + case v2.And(left, right) => val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) (leftResultOptional, rightResultOptional) match { - case (Some(leftResult), Some(rightResult)) => Some(And(leftResult, rightResult)) + case (Some(leftResult), Some(rightResult)) => Some(v2.And(leftResult, rightResult)) case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) case _ => None @@ -116,14 +125,14 @@ private[sql] object OrcFilters extends OrcFiltersBase { // The predicate can be converted as // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) // As per the logical in And predicate, we can push down (a1 OR b1). - case Or(left, right) => + case v2.Or(left, right) => for { lhs <- convertibleFiltersHelper(left, canPartialPushDown) rhs <- convertibleFiltersHelper(right, canPartialPushDown) - } yield Or(lhs, rhs) - case Not(pred) => + } yield v2.Or(lhs, rhs) + case v2.Not(pred) => val childResultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) - childResultOptional.map(Not) + childResultOptional.map(v2.Not) case other => for (_ <- buildLeafSearchArgument(dataTypeMap, other, newBuilder())) yield other } @@ -171,23 +180,23 @@ private[sql] object OrcFilters extends OrcFiltersBase { * @return the builder so far. 
*/ private def buildSearchArgument( - dataTypeMap: Map[String, DataType], - expression: Filter, + dataTypeMap: Map[NamedReference, DataType], + expression: FilterV2, builder: Builder): Builder = { import org.apache.spark.sql.sources._ expression match { - case And(left, right) => + case v2.And(left, right) => val lhs = buildSearchArgument(dataTypeMap, left, builder.startAnd()) val rhs = buildSearchArgument(dataTypeMap, right, lhs) rhs.end() - case Or(left, right) => + case v2.Or(left, right) => val lhs = buildSearchArgument(dataTypeMap, left, builder.startOr()) val rhs = buildSearchArgument(dataTypeMap, right, lhs) rhs.end() - case Not(child) => + case v2.Not(child) => buildSearchArgument(dataTypeMap, child, builder.startNot()).end() case other => @@ -207,11 +216,11 @@ private[sql] object OrcFilters extends OrcFiltersBase { * @return the builder so far. */ private def buildLeafSearchArgument( - dataTypeMap: Map[String, DataType], - expression: Filter, + dataTypeMap: Map[NamedReference, DataType], + expression: FilterV2, builder: Builder): Option[Builder] = { - def getType(attribute: String): PredicateLeaf.Type = - getPredicateLeafType(dataTypeMap(attribute)) + def getType(field: NamedReference): PredicateLeaf.Type = + getPredicateLeafType(dataTypeMap(field)) import org.apache.spark.sql.sources._ @@ -219,48 +228,39 @@ private[sql] object OrcFilters extends OrcFiltersBase { // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). expression match { - case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end()) + case v2.EqualTo(ref, value) if isSearchableType(dataTypeMap(ref)) => + val castedValue = castLiteralValue(value, dataTypeMap(ref)) + Some(builder.startAnd().equals(ref.describe(), getType(ref), castedValue).end()) - case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end()) + case v2.EqualNullSafe(ref, value) if isSearchableType(dataTypeMap(ref)) => + val castedValue = castLiteralValue(value, dataTypeMap(ref)) + Some(builder.startAnd().nullSafeEquals(ref.describe(), getType(ref), castedValue).end()) - case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end()) + case v2.LessThan(ref, value) if isSearchableType(dataTypeMap(ref)) => + val castedValue = castLiteralValue(value, dataTypeMap(ref)) + Some(builder.startAnd().lessThan(ref.describe(), getType(ref), castedValue).end()) - case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case v2.LessThanOrEqual(ref, value) if isSearchableType(dataTypeMap(ref)) => + val castedValue = castLiteralValue(value, 
dataTypeMap(ref)) + Some(builder.startAnd().lessThanEquals(ref.describe(), getType(ref), castedValue).end()) - case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case v2.GreaterThan(ref, value) if isSearchableType(dataTypeMap(ref)) => + val castedValue = castLiteralValue(value, dataTypeMap(ref)) + Some(builder.startNot().lessThanEquals(ref.describe(), getType(ref), castedValue).end()) - case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end()) + case v2.GreaterThanOrEqual(ref, value) if isSearchableType(dataTypeMap(ref)) => + val castedValue = castLiteralValue(value, dataTypeMap(ref)) + Some(builder.startNot().lessThan(ref.describe(), getType(ref), castedValue).end()) - case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - Some(builder.startAnd().isNull(quotedName, getType(attribute)).end()) + case v2.IsNull(ref) if isSearchableType(dataTypeMap(ref)) => + Some(builder.startAnd().isNull(ref.describe(), getType(ref)).end()) - case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - Some(builder.startNot().isNull(quotedName, getType(attribute)).end()) + case v2.IsNotNull(ref) if isSearchableType(dataTypeMap(ref)) => + Some(builder.startNot().isNull(ref.describe(), getType(ref)).end()) - case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) - Some(builder.startAnd().in(quotedName, getType(attribute), + case v2.In(ref, values) if isSearchableType(dataTypeMap(ref)) => + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(ref))) + Some(builder.startAnd().in(ref.describe(), getType(ref), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index cd1bffb6b7ab..4d954ae587ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -25,10 +25,13 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.spark.SparkException import org.apache.spark.internal.Logging +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.execution.datasources.orc.{OrcFilters => DatasourceOrcFilters} import org.apache.spark.sql.execution.datasources.orc.OrcFilters.buildTree import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.sources._ +import org.apache.spark.sql.sources.v2.FilterV2 import org.apache.spark.sql.types._ /** @@ -69,11 +72,17 @@ private[orc] object OrcFilters extends Logging { method } - def createFilter(schema: StructType, filters: Array[Filter]): Option[SearchArgument] = { + 
def createFilter(schema: StructType, filters: Seq[Filter])(implicit d: DummyImplicit) + : Option[SearchArgument] = { + createFilter(schema, filters.map(_.toV2)) + } + + def createFilter(schema: StructType, filters: Seq[FilterV2]): Option[SearchArgument] = { if (HiveUtils.isHive23) { DatasourceOrcFilters.createFilter(schema, filters).asInstanceOf[Option[SearchArgument]] } else { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + val dataTypeMap: Map[NamedReference, DataType] = + schema.map(f => FieldReference(Seq(f.name)) -> f.dataType).toMap // Combines all convertible filters using `And` to produce a single conjunction val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) conjunctionOptional.map { conjunction => @@ -87,13 +96,13 @@ private[orc] object OrcFilters extends Logging { def convertibleFilters( schema: StructType, - dataTypeMap: Map[String, DataType], - filters: Seq[Filter]): Seq[Filter] = { + dataTypeMap: Map[NamedReference, DataType], + filters: Seq[FilterV2]): Seq[FilterV2] = { import org.apache.spark.sql.sources._ def convertibleFiltersHelper( - filter: Filter, - canPartialPushDown: Boolean): Option[Filter] = filter match { + filter: FilterV2, + canPartialPushDown: Boolean): Option[FilterV2] = filter match { // At here, it is not safe to just convert one side and remove the other side // if we do not understand what the parent filters are. // @@ -105,11 +114,11 @@ private[orc] object OrcFilters extends Logging { // Pushing one side of AND down is only safe to do at the top level or in the child // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate // can be safely removed. - case And(left, right) => + case v2.And(left, right) => val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) (leftResultOptional, rightResultOptional) match { - case (Some(leftResult), Some(rightResult)) => Some(And(leftResult, rightResult)) + case (Some(leftResult), Some(rightResult)) => Some(v2.And(leftResult, rightResult)) case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) case _ => None @@ -126,14 +135,14 @@ private[orc] object OrcFilters extends Logging { // The predicate can be converted as // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) // As per the logical in And predicate, we can push down (a1 OR b1). - case Or(left, right) => + case v2.Or(left, right) => for { lhs <- convertibleFiltersHelper(left, canPartialPushDown) rhs <- convertibleFiltersHelper(right, canPartialPushDown) - } yield Or(lhs, rhs) - case Not(pred) => + } yield v2.Or(lhs, rhs) + case v2.Not(pred) => val childResultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) - childResultOptional.map(Not) + childResultOptional.map(v2.Not) case other => for (_ <- buildLeafSearchArgument(dataTypeMap, other, newBuilder())) yield other } @@ -151,21 +160,21 @@ private[orc] object OrcFilters extends Logging { * @return the builder so far. 
*/ private def buildSearchArgument( - dataTypeMap: Map[String, DataType], - expression: Filter, + dataTypeMap: Map[NamedReference, DataType], + expression: FilterV2, builder: Builder): Builder = { expression match { - case And(left, right) => + case v2.And(left, right) => val lhs = buildSearchArgument(dataTypeMap, left, builder.startAnd()) val rhs = buildSearchArgument(dataTypeMap, right, lhs) rhs.end() - case Or(left, right) => + case v2.Or(left, right) => val lhs = buildSearchArgument(dataTypeMap, left, builder.startOr()) val rhs = buildSearchArgument(dataTypeMap, right, lhs) rhs.end() - case Not(child) => + case v2.Not(child) => buildSearchArgument(dataTypeMap, child, builder.startNot()).end() case other => @@ -185,8 +194,8 @@ private[orc] object OrcFilters extends Logging { * @return the builder so far. */ private def buildLeafSearchArgument( - dataTypeMap: Map[String, DataType], - expression: Filter, + dataTypeMap: Map[NamedReference, DataType], + expression: FilterV2, builder: Builder): Option[Builder] = { def isSearchableType(dataType: DataType): Boolean = dataType match { // Only the values in the Spark types below can be recognized by @@ -207,47 +216,47 @@ private[orc] object OrcFilters extends Logging { // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). - case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case v2.EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => val bd = builder.startAnd() val method = findMethod(bd.getClass, "equals", classOf[String], classOf[Object]) Some(method.invoke(bd, attribute, value.asInstanceOf[AnyRef]).asInstanceOf[Builder].end()) - case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case v2.EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => val bd = builder.startAnd() val method = findMethod(bd.getClass, "nullSafeEquals", classOf[String], classOf[Object]) Some(method.invoke(bd, attribute, value.asInstanceOf[AnyRef]).asInstanceOf[Builder].end()) - case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case v2.LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => val bd = builder.startAnd() val method = findMethod(bd.getClass, "lessThan", classOf[String], classOf[Object]) Some(method.invoke(bd, attribute, value.asInstanceOf[AnyRef]).asInstanceOf[Builder].end()) - case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case v2.LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => val bd = builder.startAnd() val method = findMethod(bd.getClass, "lessThanEquals", classOf[String], classOf[Object]) Some(method.invoke(bd, attribute, value.asInstanceOf[AnyRef]).asInstanceOf[Builder].end()) - case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case v2.GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => val bd = builder.startNot() val method = findMethod(bd.getClass, "lessThanEquals", classOf[String], classOf[Object]) Some(method.invoke(bd, attribute, value.asInstanceOf[AnyRef]).asInstanceOf[Builder].end()) - case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case v2.GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => val bd = builder.startNot() val method = findMethod(bd.getClass, "lessThan", classOf[String], 
classOf[Object]) Some(method.invoke(bd, attribute, value.asInstanceOf[AnyRef]).asInstanceOf[Builder].end()) - case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + case v2.IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => val bd = builder.startAnd() val method = findMethod(bd.getClass, "isNull", classOf[String]) Some(method.invoke(bd, attribute).asInstanceOf[Builder].end()) - case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + case v2.IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => val bd = builder.startNot() val method = findMethod(bd.getClass, "isNull", classOf[String]) Some(method.invoke(bd, attribute).asInstanceOf[Builder].end()) - case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => + case v2.In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => val bd = builder.startAnd() val method = findMethod(bd.getClass, "in", classOf[String], classOf[Array[Object]]) Some(method.invoke(bd, attribute, values.map(_.asInstanceOf[AnyRef]))