
Commit a55ad54

Implement unhandled filters for Parquet
1 parent 969d566 commit a55ad54
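
This commit makes the Parquet relation report which pushed-down filters it cannot handle. In Spark's data sources API, BaseRelation.unhandledFilters(filters) returns the subset of pushed filters the relation cannot guarantee to evaluate itself, so Spark keeps a Filter node only for those. Below is a minimal sketch of that contract, not Spark's own code; the object name and the rule about which filters are pushable are made up for illustration.

import org.apache.spark.sql.sources.{EqualTo, Filter, GreaterThan}

// A minimal sketch of the unhandledFilters contract (illustrative only):
// given the filters Spark pushed down, return the ones the source cannot
// evaluate itself, so Spark re-applies only those after the scan.
object UnhandledFiltersSketch {
  // Pretend the source can only push equality and greater-than on column "a".
  private def canPush(f: Filter): Boolean = f match {
    case EqualTo("a", _)     => true
    case GreaterThan("a", _) => true
    case _                   => false
  }

  def unhandledFilters(filters: Array[Filter]): Array[Filter] =
    filters.filterNot(canPush)

  def main(args: Array[String]): Unit = {
    val pushed = Array[Filter](EqualTo("a", 1), EqualTo("b", 2))
    // Prints only EqualTo(b,2): the filter Spark still has to evaluate itself.
    println(unhandledFilters(pushed).mkString(", "))
  }
}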

File tree

2 files changed: +51 −13 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala

Lines changed: 26 additions & 5 deletions
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet
 
 import java.io.Serializable
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.parquet.filter2.predicate.FilterApi._
 import org.apache.parquet.filter2.predicate._
 import org.apache.parquet.io.api.Binary
@@ -207,12 +209,31 @@ private[sql] object ParquetFilters {
      */
   }
 
+  /**
+   * Return referenced columns in [[sources.Filter]].
+   */
+  def referencedColumns(schema: StructType, predicate: sources.Filter): Array[String] = {
+    val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap
+    val referencedColumns = ArrayBuffer.empty[String]
+    def getDataTypeOf(name: String): DataType = {
+      referencedColumns += name
+      dataTypeOf(name)
+    }
+    createParquetFilter(getDataTypeOf, predicate)
+    referencedColumns.distinct.toArray
+  }
+
   /**
    * Converts data sources filters to Parquet filter predicates.
    */
   def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = {
     val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap
+    createParquetFilter(dataTypeOf, predicate)
+  }
 
+  private def createParquetFilter(
+      dataTypeOf: String => DataType,
+      predicate: sources.Filter): Option[FilterPredicate] = {
     relaxParquetValidTypeMap
 
     // NOTE:
@@ -265,18 +286,18 @@ private[sql] object ParquetFilters {
         // Pushing one side of AND down is only safe to do at the top level.
         // You can see ParquetRelation's initializeLocalJobFunc method as an example.
         for {
-          lhsFilter <- createFilter(schema, lhs)
-          rhsFilter <- createFilter(schema, rhs)
+          lhsFilter <- createParquetFilter(dataTypeOf, lhs)
+          rhsFilter <- createParquetFilter(dataTypeOf, rhs)
         } yield FilterApi.and(lhsFilter, rhsFilter)
 
       case sources.Or(lhs, rhs) =>
         for {
-          lhsFilter <- createFilter(schema, lhs)
-          rhsFilter <- createFilter(schema, rhs)
+          lhsFilter <- createParquetFilter(dataTypeOf, lhs)
+          rhsFilter <- createParquetFilter(dataTypeOf, rhs)
         } yield FilterApi.or(lhsFilter, rhsFilter)
 
       case sources.Not(pred) =>
-        createFilter(schema, pred).map(FilterApi.not)
+        createParquetFilter(dataTypeOf, pred).map(FilterApi.not)
 
       case _ => None
     }
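
The new referencedColumns helper runs the same conversion as createFilter, but replaces the schema lookup with a function that records every column name the conversion resolves, then returns the distinct names. A usage sketch follows; since ParquetFilters is private[sql], the snippet assumes it is compiled inside Spark's own source tree (hence the package clause), and the schema and filter are invented for illustration.

package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.sources.{And, EqualTo, GreaterThan}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object ReferencedColumnsExample {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("a", IntegerType),
      StructField("b", StringType),
      StructField("c", StringType)))

    val filter = And(GreaterThan("a", 1), EqualTo("b", "x"))

    // The conversion walks both sides of the And and records each column it
    // resolves through getDataTypeOf, so this should print: a, b
    println(ParquetFilters.referencedColumns(schema, filter).mkString(", "))
  }
}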

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala

Lines changed: 25 additions & 8 deletions
@@ -133,6 +133,11 @@ private[sql] class ParquetRelation(
       .map(_.toBoolean)
       .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED))
 
+  // When merging schemas is enabled and the column of the given filter does not exist,
+  // Parquet emits an exception which is an issue of Parquet (PARQUET-389).
+  private val safeParquetFilterPushDown =
+    sqlContext.conf.parquetFilterPushDown && !shouldMergeSchemas
+
   private val mergeRespectSummaries =
     sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES)
 
@@ -288,20 +293,23 @@ private[sql] class ParquetRelation(
     }
   }
 
+  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
+    if (safeParquetFilterPushDown) {
+      filters.filter(ParquetFilters.createFilter(dataSchema, _).isEmpty)
+    } else {
+      filters
+    }
+  }
+
   override def buildInternalScan(
       requiredColumns: Array[String],
       filters: Array[Filter],
       inputFiles: Array[FileStatus],
       broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = {
     val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA)
-    val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown
+    val parquetFilterPushDown = safeParquetFilterPushDown
     val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString
     val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp
-
-    // When merging schemas is enabled and the column of the given filter does not exist,
-    // Parquet emits an exception which is an issue of Parquet (PARQUET-389).
-    val safeParquetFilterPushDown = !shouldMergeSchemas && parquetFilterPushDown
-
     // Parquet row group size. We will use this value as the value for
     // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value
     // of these flags are smaller than the parquet row group size.
@@ -315,7 +323,7 @@ private[sql] class ParquetRelation(
           dataSchema,
           parquetBlockSize,
           useMetadataCache,
-          safeParquetFilterPushDown,
+          parquetFilterPushDown,
           assumeBinaryIsString,
           assumeInt96IsTimestamp) _
 
@@ -568,6 +576,15 @@ private[sql] object ParquetRelation extends Logging {
     conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName)
 
     // Try to push down filters when filter push-down is enabled.
+    val safeRequiredColumns = if (parquetFilterPushDown) {
+      val referencedColumns = filters
+        // Collects all columns referenced in Parquet filter predicates.
+        .flatMap(filter => ParquetFilters.referencedColumns(dataSchema, filter))
+      (requiredColumns ++ referencedColumns).distinct
+    } else {
+      requiredColumns
+    }
+
     if (parquetFilterPushDown) {
       filters
         // Collects all converted Parquet filter predicates. Notice that not all predicates can be
@@ -579,7 +596,7 @@ private[sql] object ParquetRelation extends Logging {
     }
 
     conf.set(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA, {
-      val requestedSchema = StructType(requiredColumns.map(dataSchema(_)))
+      val requestedSchema = StructType(safeRequiredColumns.map(dataSchema(_)))
       CatalystSchemaConverter.checkFieldNames(requestedSchema).json
     })
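
Taken together, ParquetRelation now answers Spark's unhandledFilters call with the filters that cannot be converted into Parquet predicates (or with all filters while schema merging is on, because of PARQUET-389), and initializeLocalJobFunc widens the requested schema with every column a pushed filter references, presumably so that Parquet's record-level filtering can evaluate those predicates. The following standalone sketch restates the two decisions; the helper parameters are stand-ins for ParquetFilters.createFilter(...).isDefined and ParquetFilters.referencedColumns, not Spark's actual signatures.

import org.apache.spark.sql.sources.Filter

// A simplified, standalone sketch of the two decisions made in this file.
object ParquetPushdownSketch {
  // Mirrors ParquetRelation.unhandledFilters: when push-down is unsafe
  // (schema merging is enabled, see PARQUET-389), nothing is handled by
  // Parquet and Spark must evaluate every filter itself.
  def unhandledFilters(
      safePushDown: Boolean,
      filters: Array[Filter],
      convertsToParquetPredicate: Filter => Boolean): Array[Filter] = {
    if (safePushDown) filters.filterNot(convertsToParquetPredicate) else filters
  }

  // Mirrors the safeRequiredColumns computation: columns mentioned only in
  // pushed filters are added to the requested schema so the Parquet-side
  // predicates can refer to them.
  def safeRequiredColumns(
      pushDown: Boolean,
      requiredColumns: Array[String],
      filters: Array[Filter],
      columnsReferencedBy: Filter => Array[String]): Array[String] = {
    if (pushDown) (requiredColumns ++ filters.flatMap(columnsReferencedBy)).distinct
    else requiredColumns
  }
}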
