
Commit 20e6d74

[SPARK-20364] Parquet predicate pushdown on columns with dots return empty results (apache-spark-on-k8s#170)
1 parent 035cb7f commit 20e6d74

File tree

3 files changed: +125 −4 lines changed
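To make the behavior being fixed concrete, here is a minimal reproduction sketch. It is not part of the commit, it assumes an active SparkSession named `spark`, and the scratch path is hypothetical. Before this change the Parquet filter pushed down for "col.dots" referenced a nested path, so the query returned 0 rows even though one row matches; with the fix it returns 1.

// Hedged reproduction sketch (assumes a SparkSession `spark`; path is hypothetical).
import spark.implicits._

val path = "/tmp/spark-20364-repro"   // hypothetical scratch location
Seq(Some(1), None).toDF("col.dots").write.mode("overwrite").parquet(path)
// With Parquet filter pushdown enabled, this previously returned 0; expected 1 after the fix.
println(spark.read.parquet(path).where("`col.dots` > 0").count())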

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala

Lines changed: 40 additions & 1 deletion
@@ -18,7 +18,8 @@
 package org.apache.spark.sql.execution.datasources.parquet
 
 import org.apache.parquet.filter2.predicate._
-import org.apache.parquet.filter2.predicate.FilterApi._
+import org.apache.parquet.filter2.predicate.Operators.{Column, SupportsEqNotEq, SupportsLtGt}
+import org.apache.parquet.hadoop.metadata.ColumnPath
 import org.apache.parquet.io.api.Binary
 
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -30,6 +31,8 @@ import org.apache.spark.sql.types._
  */
 private[parquet] object ParquetFilters {
 
+  import ParquetColumns._
+
   case class SetInFilter[T <: Comparable[T]](valueSet: Set[T])
     extends UserDefinedPredicate[T] with Serializable {
 
@@ -344,3 +347,39 @@ private[parquet] object ParquetFilters {
     }
   }
 }
+
+/**
+ * Note that this is a hacky workaround to allow dots in column names. Currently, the column
+ * APIs in Parquet's `FilterApi` only accept dot-separated names, so here we mimic those
+ * columns but allow only a single-element column path that may contain dots, since we do not
+ * currently push down filters with nested fields.
+ */
+private[parquet] object ParquetColumns {
+  def intColumn(columnPath: String): Column[Integer] with SupportsLtGt = {
+    new Column[Integer](ColumnPath.get(columnPath), classOf[Integer]) with SupportsLtGt
+  }
+
+  def longColumn(columnPath: String): Column[java.lang.Long] with SupportsLtGt = {
+    new Column[java.lang.Long](
+      ColumnPath.get(columnPath), classOf[java.lang.Long]) with SupportsLtGt
+  }
+
+  def floatColumn(columnPath: String): Column[java.lang.Float] with SupportsLtGt = {
+    new Column[java.lang.Float](
+      ColumnPath.get(columnPath), classOf[java.lang.Float]) with SupportsLtGt
+  }
+
+  def doubleColumn(columnPath: String): Column[java.lang.Double] with SupportsLtGt = {
+    new Column[java.lang.Double](
+      ColumnPath.get(columnPath), classOf[java.lang.Double]) with SupportsLtGt
+  }
+
+  def booleanColumn(columnPath: String): Column[java.lang.Boolean] with SupportsEqNotEq = {
+    new Column[java.lang.Boolean](
+      ColumnPath.get(columnPath), classOf[java.lang.Boolean]) with SupportsEqNotEq
+  }
+
+  def binaryColumn(columnPath: String): Column[Binary] with SupportsLtGt = {
+    new Column[Binary](ColumnPath.get(columnPath), classOf[Binary]) with SupportsLtGt
+  }
+}
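The comment above hints at why `ColumnPath.get` is used instead of Parquet's own column factories. Below is a minimal, hedged sketch, not part of the commit (the demo object name is made up), illustrating the difference: `FilterApi`'s factories parse the name as a dot-separated path, whereas `ColumnPath.get` keeps the whole name as a single path element.

import org.apache.parquet.filter2.predicate.FilterApi
import org.apache.parquet.hadoop.metadata.ColumnPath

object DottedColumnDemo {
  def main(args: Array[String]): Unit = {
    // FilterApi.intColumn parses "col.dots" as the two-element nested path [col, dots],
    // which matches nothing in a file whose schema has one flat column named "col.dots".
    val viaFilterApi = FilterApi.intColumn("col.dots").getColumnPath
    // ColumnPath.get takes path elements verbatim, so "col.dots" stays one element,
    // which is what ParquetColumns.intColumn above relies on.
    val viaGet = ColumnPath.get("col.dots")

    println(viaFilterApi.toArray.length)  // expected: 2
    println(viaGet.toArray.length)        // expected: 1
  }
}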

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala

Lines changed: 34 additions & 2 deletions
@@ -487,6 +487,21 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
     }
   }
 
+  test("no filter pushdown for nested field access") {
+    val table = createTable(
+      files = Seq("file1" -> 1),
+      format = classOf[TestFileFormatWithNestedSchema].getName)
+
+    checkScan(table.where("a1 = 1"))(_ => ())
+    // Check that `a1` access pushes the predicate.
+    checkDataFilters(Set(IsNotNull("a1"), EqualTo("a1", 1)))
+
+    // TODO: re-enable the check below once we update to the latest master.
+    // checkScan(table.where("a2.c1 = 1"))(_ => ())
+    // // Check that `a2.c1` access does not push the predicate.
+    // checkDataFilters(Set(IsNotNull("a2")))
+  }
+
   // Helpers for checking the arguments passed to the FileFormat.
 
   protected val checkPartitionSchema =
@@ -537,7 +552,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
    */
  def createTable(
      files: Seq[(String, Int)],
-      buckets: Int = 0): DataFrame = {
+      buckets: Int = 0,
+      format: String = classOf[TestFileFormat].getName): DataFrame = {
    val tempDir = Utils.createTempDir()
    files.foreach {
      case (name, size) =>
@@ -547,7 +563,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
    }
 
    val df = spark.read
-      .format(classOf[TestFileFormat].getName)
+      .format(format)
      .load(tempDir.getCanonicalPath)
 
    if (buckets > 0) {
@@ -632,6 +648,22 @@ class TestFileFormat extends TextBasedFileFormat {
   }
 }
 
+/**
+ * A test [[FileFormat]] that records the arguments passed to buildReader, and returns nothing.
+ * Unlike the one above, this one has a nested schema.
+ */
+class TestFileFormatWithNestedSchema extends TestFileFormat {
+  override def inferSchema(
+      sparkSession: SparkSession,
+      options: Map[String, String],
+      files: Seq[FileStatus]): Option[StructType] =
+    Some(StructType(Nil)
+      .add("a1", IntegerType)
+      .add("a2",
+        StructType(Nil)
+          .add("c1", IntegerType)
+          .add("c2", IntegerType)))
+}
 
 class LocalityTestFileSystem extends RawLocalFileSystem {
   private val invocations = new AtomicInteger(0)
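As a hedged aside on the test added above: backticked `col.dots` refers to a single top-level column whose name happens to contain a dot, while a2.c1 without backticks is nested-field access, and this commit does not push predicates on nested fields down to Parquet. The sketch below (not part of the commit; assumes a SparkSession `spark` and a hypothetical scratch path) contrasts the two cases.

// Hedged sketch contrasting a dotted flat column with a genuinely nested field.
import spark.implicits._

val dir = "/tmp/spark-20364-nested-demo"   // hypothetical scratch location

Seq(Some(1), None).toDF("col.dots").write.mode("overwrite").parquet(s"$dir/flat")
spark.read.parquet(s"$dir/flat").where("`col.dots` > 0").explain()
// expected: the scan's pushed filters mention the flat column col.dots

spark.range(1)
  .selectExpr("id AS a1", "named_struct('c1', 1, 'c2', 2) AS a2")
  .write.mode("overwrite").parquet(s"$dir/nested")
spark.read.parquet(s"$dir/nested").where("a2.c1 = 1").explain()
// expected: no Parquet filter is pushed for the nested field a2.c1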

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala

Lines changed: 51 additions & 1 deletion
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
 import java.nio.charset.StandardCharsets
 
 import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
-import org.apache.parquet.filter2.predicate.FilterApi._
+import org.apache.parquet.filter2.predicate.FilterApi.{and, gt, lt}
 import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}
 import org.apache.parquet.hadoop.ParquetInputFormat
 
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.parquet.ParquetColumns.{doubleColumn, intColumn}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -542,4 +543,53 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       // scalastyle:on nonascii
     }
   }
+
+  test("SPARK-20364: Predicate pushdown for columns with a '.' in them") {
+    import testImplicits._
+
+    Seq(true, false).foreach { vectorized =>
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) {
+        withTempPath { path =>
+          Seq(Some(1), None).toDF("col.dots").write.parquet(path.getAbsolutePath)
+          assert(spark.read.parquet(path.getAbsolutePath).where("`col.dots` > 0").count() == 1)
+        }
+
+        withTempPath { path =>
+          Seq(Some(1L), None).toDF("col.dots").write.parquet(path.getAbsolutePath)
+          assert(spark.read.parquet(path.getAbsolutePath).where("`col.dots` >= 1L").count() == 1)
+        }
+
+        withTempPath { path =>
+          Seq(Some(1.0F), None).toDF("col.dots").write.parquet(path.getAbsolutePath)
+          assert(spark.read.parquet(path.getAbsolutePath).where("`col.dots` < 2.0").count() == 1)
+        }
+
+        withTempPath { path =>
+          Seq(Some(1.0D), None).toDF("col.dots").write.parquet(path.getAbsolutePath)
+          assert(spark.read.parquet(path.getAbsolutePath).where("`col.dots` <= 1.0D").count() == 1)
+        }
+
+        withTempPath { path =>
+          Seq(true, false).toDF("col.dots").write.parquet(path.getAbsolutePath)
+          assert(spark.read.parquet(path.getAbsolutePath).where("`col.dots` == true").count() == 1)
+        }
+
+        withTempPath { path =>
+          Seq("apple", null).toDF("col.dots").write.parquet(path.getAbsolutePath)
+          assert(
+            spark.read.parquet(path.getAbsolutePath).where("`col.dots` IS NOT NULL").count() == 1)
+        }
+      }
+    }
+
+    withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> false.toString) {
+      withTempPath { path =>
+        Seq("apple", null).toDF("col.dots").write.parquet(path.getAbsolutePath)
+        // This checks record-by-record filtering in Parquet's filter2.
+        val num = stripSparkFilter(
+          spark.read.parquet(path.getAbsolutePath).where("`col.dots` IS NULL")).count()
+        assert(num == 1)
+      }
+    }
+  }
 }
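For context on the final block of the test above, here is a short, hedged sketch of the session settings it depends on. Disabling the vectorized reader makes Spark fall back to parquet-mr's record-level reader, so the stripped IS NULL count reflects Parquet's own filter2 record filtering rather than Spark's Filter operator.

// Session settings the record-by-record check relies on (keys from Spark's SQLConf);
// this is a sketch for illustration, not code from the commit.
spark.conf.set("spark.sql.parquet.filterPushdown", "true")          // push predicates into Parquet
spark.conf.set("spark.sql.parquet.enableVectorizedReader", "false") // use parquet-mr's record-level filter2 path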
