File: CatalogV2Implicits.scala (org.apache.spark.sql.connector.catalog)
@@ -20,7 +20,9 @@ package org.apache.spark.sql.connector.catalog
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.connector.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform}
+import org.apache.spark.sql.internal.SQLConf

/**
* Conversion helpers for working with v2 [[CatalogPlugin]].
@@ -132,4 +134,10 @@ private[sql] object CatalogV2Implicits {
       part
     }
   }
 
+  private lazy val catalystSqlParser = new CatalystSqlParser(SQLConf.get)
+
+  def parseColumnPath(name: String): Seq[String] = {
+    catalystSqlParser.parseMultipartIdentifier(name)
+  }
 }
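
For context, a minimal sketch of what the new helper returns, assuming Spark-internal code (`CatalogV2Implicits` is `private[sql]`, so this only compiles inside Spark's own `org.apache.spark.sql` packages):

  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath

  parseColumnPath("a.b.c")    // Seq("a", "b", "c") -- a path into nested columns
  parseColumnPath("`a.b`.c")  // Seq("a.b", "c")    -- backticks quote a part that itself contains a dot
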
File: SQLConf.scala (org.apache.spark.sql.internal)
@@ -2049,6 +2049,17 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val NESTED_PREDICATE_PUSHDOWN_ENABLED =
+    buildConf("spark.sql.optimizer.nestedPredicatePushdown.enabled")
+      .internal()
+      .doc("When true, Spark tries to push down predicates for nested columns and/or names " +
+        "containing `dots` to data sources. Currently, Parquet implements both optimizations " +
+        "while ORC only supports predicates for names containing `dots`. The other data " +
+        "sources don't support this feature yet.")
+      .version("3.0.0")
+      .booleanConf
+      .createWithDefault(true)
 
   val SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED =
     buildConf("spark.sql.optimizer.serializer.nestedSchemaPruning.enabled")
       .internal()
@@ -3035,6 +3046,8 @@ class SQLConf extends Serializable with Logging {

   def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED)
 
+  def nestedPredicatePushdownEnabled: Boolean = getConf(NESTED_PREDICATE_PUSHDOWN_ENABLED)
+
   def serializerNestedSchemaPruningEnabled: Boolean =
     getConf(SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED)

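
A hypothetical spark-shell snippet showing how the new flag would be toggled; the session itself is not part of this diff:

  // Restore the pre-3.0 behavior: only push down filters on top-level, dot-free columns.
  spark.conf.set("spark.sql.optimizer.nestedPredicatePushdown.enabled", "false")

  // Spark-internal code reads the same flag through the accessor added above:
  //   SQLConf.get.nestedPredicatePushdownEnabled
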
File: filters.scala (org.apache.spark.sql.sources)
@@ -18,6 +18,7 @@
package org.apache.spark.sql.sources

import org.apache.spark.annotation.{Evolving, Stable}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath

////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines all the filters that we can push down to the data sources.
@@ -32,6 +33,10 @@ import org.apache.spark.annotation.{Evolving, Stable}
sealed abstract class Filter {
/**
* List of columns that are referenced by this filter.
+ *
+ * Note that each element in `references` represents a column; `dots` are used as separators
+ * for nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion.
+ *
* @since 2.1.0
*/
def references: Array[String]
@@ -40,12 +45,32 @@ sealed abstract class Filter {
case f: Filter => f.references
case _ => Array.empty
}

+  /**
+   * List of columns that are referenced by this filter.
+   *
+   * @return each element is a column name as an array of its multipart-identifier parts
+   * @since 3.0.0
+   */
+  def v2references: Array[Array[String]] = {
+    this.references.map(parseColumnPath(_).toArray)
+  }
+
+  /**
+   * Whether any of the references of this filter contains a nested column.
+   */
+  private[sql] def containsNestedColumn: Boolean = {
+    this.v2references.exists(_.length > 1)
+  }
}

/**
- * A filter that evaluates to `true` iff the attribute evaluates to a value
+ * A filter that evaluates to `true` iff the column evaluates to a value
* equal to `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -58,6 +83,9 @@ case class EqualTo(attribute: String, value: Any) extends Filter {
* in that it returns `true` (rather than NULL) if both inputs are NULL, and `false`
* (rather than NULL) if one of the input is NULL and the other is not NULL.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
* @since 1.5.0
*/
@Stable
@@ -69,6 +97,9 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to a value
* greater than `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -80,6 +111,9 @@ case class GreaterThan(attribute: String, value: Any) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to a value
* greater than or equal to `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -91,6 +125,9 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to a value
* less than `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -102,6 +139,9 @@ case class LessThan(attribute: String, value: Any) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to a value
* less than or equal to `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -112,6 +152,9 @@ case class LessThanOrEqual(attribute: String, value: Any) extends Filter {
/**
* A filter that evaluates to `true` iff the attribute evaluates to one of the values in the array.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -139,6 +182,9 @@ case class In(attribute: String, values: Array[Any]) extends Filter {
/**
* A filter that evaluates to `true` iff the attribute evaluates to null.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -149,6 +195,9 @@ case class IsNull(attribute: String) extends Filter {
/**
* A filter that evaluates to `true` iff the attribute evaluates to a non-null value.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
* @since 1.3.0
*/
@Stable
@@ -190,6 +239,9 @@ case class Not(child: Filter) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to
* a string that starts with `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
* @since 1.3.1
*/
@Stable
@@ -201,6 +253,9 @@ case class StringStartsWith(attribute: String, value: String) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to
* a string that ends with `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
* @since 1.3.1
*/
@Stable
@@ -212,6 +267,9 @@ case class StringEndsWith(attribute: String, value: String) extends Filter {
* A filter that evaluates to `true` iff the attribute evaluates to
* a string that contains the string `value`.
*
+ * @param attribute of the column to be evaluated; `dots` are used as separators
+ *                  for nested columns. If any part of the names contains `dots`,
+ *                  it is quoted to avoid confusion.
* @since 1.3.1
*/
@Stable
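
A small sketch of how the new representation distinguishes nesting from literal dots, assuming the patch above (`containsNestedColumn` is `private[sql]`, so it is shown only in comments):

  import org.apache.spark.sql.sources.EqualTo

  val nested = EqualTo("a.b", 1)      // struct column `a`, field `b`
  nested.references                   // Array("a.b")
  nested.v2references                 // Array(Array("a", "b"))
  // nested.containsNestedColumn == true

  val dotted = EqualTo("`a.b`", 1)    // one top-level column literally named "a.b"
  dotted.v2references                 // Array(Array("a.b"))
  // dotted.containsNestedColumn == false
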
File: DataSourceStrategy.scala (org.apache.spark.sql.execution.datasources)
@@ -652,10 +652,19 @@ object DataSourceStrategy {
  */
 object PushableColumn {
   def unapply(e: Expression): Option[String] = {
-    def helper(e: Expression) = e match {
-      case a: Attribute => Some(a.name)
+    val nestedPredicatePushdownEnabled = SQLConf.get.nestedPredicatePushdownEnabled
+    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+    def helper(e: Expression): Option[Seq[String]] = e match {
+      case a: Attribute =>
+        if (nestedPredicatePushdownEnabled || !a.name.contains(".")) {
+          Some(Seq(a.name))
+        } else {
+          None
+        }
+      case s: GetStructField if nestedPredicatePushdownEnabled =>
+        helper(s.child).map(_ :+ s.childSchema(s.ordinal).name)
       case _ => None
     }
-    helper(e)
+    helper(e).map(_.quoted)
   }
 }
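
Sketch of what the rewritten extractor yields for two hypothetical resolved expressions (illustrative names, not code from this diff):

  // GetStructField(AttributeReference("person"), <ordinal of "name">)
  //   => Some("person.name")       with the flag on
  // AttributeReference("person.name")   (a top-level name containing a dot)
  //   => Some("`person.name`")     with the flag on
  //   => None                      with the flag off

Because every part is quoted if it contains dots, a nested field path and a dotted top-level name can no longer collide in the data source filter API.
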
File: ParquetFilters.scala (org.apache.spark.sql.execution.datasources.parquet)
@@ -27,7 +27,7 @@ import scala.collection.JavaConverters.asScalaBufferConverter
import org.apache.parquet.filter2.predicate._
import org.apache.parquet.filter2.predicate.SparkFilterApi._
import org.apache.parquet.io.api.Binary
-import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator}
+import org.apache.parquet.schema.{DecimalMetadata, GroupType, MessageType, OriginalType, PrimitiveComparator, PrimitiveType, Type}
import org.apache.parquet.schema.OriginalType._
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
@@ -49,15 +49,35 @@ class ParquetFilters(
     pushDownInFilterThreshold: Int,
     caseSensitive: Boolean) {
   // A map which contains parquet field name and data type, if predicate push down applies.
-  private val nameToParquetField : Map[String, ParquetField] = {
-    // Here we don't flatten the fields in the nested schema but just look up through
-    // root fields. Currently, accessing to nested fields does not push down filters
-    // and it does not support to create filters for them.
-    val primitiveFields =
-      schema.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
-        f.getName -> ParquetField(f.getName,
-          ParquetSchemaType(f.getOriginalType,
-            f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata))
+  //
+  // Each key in `nameToParquetField` represents a column; `dots` are used as separators for
+  // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion.
+  // See `org.apache.spark.sql.connector.catalog.quote` for implementation details.
+  private val nameToParquetField : Map[String, ParquetPrimitiveField] = {
+    // Recursively traverse the parquet schema to get primitive fields that can be pushed down.
+    // `parentFieldNames` is used to keep track of the current nested level when traversing.
+    def getPrimitiveFields(
+        fields: Seq[Type],
+        parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = {
+      fields.flatMap {
+        case p: PrimitiveType =>
+          Some(ParquetPrimitiveField(fieldNames = parentFieldNames :+ p.getName,
+            fieldType = ParquetSchemaType(p.getOriginalType,
+              p.getPrimitiveTypeName, p.getTypeLength, p.getDecimalMetadata)))
+        // Note that when g is a `Struct`, `g.getOriginalType` is `null`.
+        // When g is a `Map`, `g.getOriginalType` is `MAP`.
+        // When g is a `List`, `g.getOriginalType` is `LIST`.
+        case g: GroupType if g.getOriginalType == null =>
+          getPrimitiveFields(g.getFields.asScala, parentFieldNames :+ g.getName)
+        // Parquet only supports push-down for primitive types; as a result, Map and List types
+        // are removed.
+        case _ => None
+      }
+    }
+
+    val primitiveFields = getPrimitiveFields(schema.getFields.asScala).map { field =>
+      import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+      (field.fieldNames.toSeq.quoted, field)
+    }
     if (caseSensitive) {
       primitiveFields.toMap
@@ -75,13 +95,13 @@
   }
 
   /**
-   * Holds a single field information stored in the underlying parquet file.
+   * Holds the information of a single primitive field stored in the underlying parquet file.
    *
-   * @param fieldName field name in parquet file
+   * @param fieldNames the field name as an array of multipart-identifier parts in the parquet file
    * @param fieldType field type related info in parquet file
    */
-  private case class ParquetField(
-      fieldName: String,
+  private case class ParquetPrimitiveField(
+      fieldNames: Array[String],
       fieldType: ParquetSchemaType)
 
   private case class ParquetSchemaType(
@@ -472,13 +492,8 @@
       case _ => false
     }
 
-  // Parquet does not allow dots in the column name because dots are used as a column path
-  // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates
-  // with missing columns. The incorrect results could be got from Parquet when we push down
-  // filters for the column having dots in the names. Thus, we do not push down such filters.
-  // See SPARK-20364.
   private def canMakeFilterOn(name: String, value: Any): Boolean = {
-    nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
+    nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value)
   }

/**
@@ -509,38 +524,38 @@
     predicate match {
       case sources.IsNull(name) if canMakeFilterOn(name, null) =>
         makeEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), null))
+          .map(_(nameToParquetField(name).fieldNames, null))
       case sources.IsNotNull(name) if canMakeFilterOn(name, null) =>
         makeNotEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), null))
+          .map(_(nameToParquetField(name).fieldNames, null))
 
       case sources.EqualTo(name, value) if canMakeFilterOn(name, value) =>
         makeEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
       case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) =>
         makeNotEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
 
       case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) =>
         makeEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
       case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) =>
         makeNotEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
 
       case sources.LessThan(name, value) if canMakeFilterOn(name, value) =>
         makeLt.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
       case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
         makeLtEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
 
       case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) =>
         makeGt.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
       case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
         makeGtEq.lift(nameToParquetField(name).fieldType)
-          .map(_(Array(nameToParquetField(name).fieldName), value))
+          .map(_(nameToParquetField(name).fieldNames, value))
 
       case sources.And(lhs, rhs) =>
         // At here, it is not safe to just convert one side and remove the other side
@@ -591,13 +606,13 @@
           && values.distinct.length <= pushDownInFilterThreshold =>
         values.distinct.flatMap { v =>
           makeEq.lift(nameToParquetField(name).fieldType)
-            .map(_(Array(nameToParquetField(name).fieldName), v))
+            .map(_(nameToParquetField(name).fieldNames, v))
         }.reduceLeftOption(FilterApi.or)
 
       case sources.StringStartsWith(name, prefix)
           if pushDownStartWith && canMakeFilterOn(name, prefix) =>
         Option(prefix).map { v =>
-          FilterApi.userDefined(binaryColumn(Array(nameToParquetField(name).fieldName)),
+          FilterApi.userDefined(binaryColumn(nameToParquetField(name).fieldNames),
             new UserDefinedPredicate[Binary] with Serializable {
               private val strToBinary = Binary.fromReusedByteArray(v.getBytes)
               private val size = strToBinary.length
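
Finally, a hypothetical end-to-end check of the Parquet path (the file location and column names are made up, and the exact PushedFilters rendering varies by version):

  val df = spark.range(3).selectExpr("named_struct('n', cast(id as string)) AS rec")
  df.write.mode("overwrite").parquet("/tmp/nested_pushdown_demo")

  spark.read.parquet("/tmp/nested_pushdown_demo")
    .filter("rec.n = '1'")
    .explain()
  // With spark.sql.optimizer.nestedPredicatePushdown.enabled=true, the scan node should
  // list something like PushedFilters: [IsNotNull(rec.n), EqualTo(rec.n,1)].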