ParquetFilters.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.parquet.filter2.predicate._
import org.apache.parquet.filter2.predicate.FilterApi._
import org.apache.parquet.filter2.predicate.Operators.{Column, SupportsEqNotEq, SupportsLtGt}
import org.apache.parquet.hadoop.metadata.ColumnPath
import org.apache.parquet.io.api.Binary

import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -30,6 +31,8 @@ import org.apache.spark.sql.types._
*/
private[parquet] object ParquetFilters {

import ParquetColumns._

case class SetInFilter[T <: Comparable[T]](valueSet: Set[T])
extends UserDefinedPredicate[T] with Serializable {

@@ -344,3 +347,39 @@ private[parquet] object ParquetFilters {
}
}
}

/**
 * Note that this is a hacky workaround to allow dots in column names. Currently, the column
 * APIs in Parquet's `FilterApi` only accept dot-separated names, so here we mirror those
 * column factories but build a single-element column path that may contain dots, since we
 * don't currently push down filters on nested fields.
*/
private[parquet] object ParquetColumns {
def intColumn(columnPath: String): Column[Integer] with SupportsLtGt = {
new Column[Integer] (ColumnPath.get(columnPath), classOf[Integer]) with SupportsLtGt
}

def longColumn(columnPath: String): Column[java.lang.Long] with SupportsLtGt = {
new Column[java.lang.Long] (
ColumnPath.get(columnPath), classOf[java.lang.Long]) with SupportsLtGt
}

def floatColumn(columnPath: String): Column[java.lang.Float] with SupportsLtGt = {
new Column[java.lang.Float] (
ColumnPath.get(columnPath), classOf[java.lang.Float]) with SupportsLtGt
}

def doubleColumn(columnPath: String): Column[java.lang.Double] with SupportsLtGt = {
new Column[java.lang.Double] (
ColumnPath.get(columnPath), classOf[java.lang.Double]) with SupportsLtGt
}

def booleanColumn(columnPath: String): Column[java.lang.Boolean] with SupportsEqNotEq = {
new Column[java.lang.Boolean] (
ColumnPath.get(columnPath), classOf[java.lang.Boolean]) with SupportsEqNotEq
}

def binaryColumn(columnPath: String): Column[Binary] with SupportsLtGt = {
new Column[Binary] (ColumnPath.get(columnPath), classOf[Binary]) with SupportsLtGt
}
}
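
To make the workaround concrete, here is a minimal sketch (not part of the patch) contrasting Parquet's own column builders with the ones above. The `DottedColumnExample` object is hypothetical; it is placed in the same package so it can see the `private[parquet]` object, and it assumes the `Operators.Column.getColumnPath()` and `ColumnPath.size()` accessors from parquet-mr's filter2 API.

// Hypothetical sketch: FilterApi.intColumn splits "col.dots" on the dot
// (ColumnPath.fromDotString), while ParquetColumns.intColumn keeps the whole
// name as a single path element (ColumnPath.get), which is what a top-level
// column literally named "col.dots" needs.
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.parquet.filter2.predicate.FilterApi

object DottedColumnExample {
  def main(args: Array[String]): Unit = {
    val nestedPath = FilterApi.intColumn("col.dots").getColumnPath
    val singlePath = ParquetColumns.intColumn("col.dots").getColumnPath

    assert(nestedPath.size() == 2) // interpreted as ["col", "dots"]
    assert(singlePath.size() == 1) // interpreted as ["col.dots"]
  }
}

Since filters on nested fields are not pushed down anyway (per the comment above), interpreting the full name as one path element is always the intended reading here.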
FileSourceStrategySuite.scala
@@ -487,6 +487,21 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
}
}

test("no filter pushdown for nested field access") {
val table = createTable(
files = Seq("file1" -> 1),
format = classOf[TestFileFormatWithNestedSchema].getName)

checkScan(table.where("a1 = 1"))(_ => ())
// Accessing the top-level column `a1` should push the predicate down.
checkDataFilters(Set(IsNotNull("a1"), EqualTo("a1", 1)))

// TODO: reenable check below once we update to latest master
// checkScan(table.where("a2.c1 = 1"))(_ => ())
// // Check `a2.c1` access does not push the predicate.
// checkDataFilters(Set(IsNotNull("a2")))
}

// Helpers for checking the arguments passed to the FileFormat.

protected val checkPartitionSchema =
@@ -537,7 +552,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
*/
def createTable(
files: Seq[(String, Int)],
buckets: Int = 0): DataFrame = {
buckets: Int = 0,
format: String = classOf[TestFileFormat].getName): DataFrame = {
val tempDir = Utils.createTempDir()
files.foreach {
case (name, size) =>
@@ -547,7 +563,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
}

val df = spark.read
.format(classOf[TestFileFormat].getName)
.format(format)
.load(tempDir.getCanonicalPath)

if (buckets > 0) {
@@ -632,6 +648,22 @@ class TestFileFormat extends TextBasedFileFormat {
}
}

/**
* A test [[FileFormat]] that records the arguments passed to buildReader, and returns nothing.
* Unlike the one above, this one has a nested schema.
*/
class TestFileFormatWithNestedSchema extends TestFileFormat {
override def inferSchema(
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] =
Some(StructType(Nil)
.add("a1", IntegerType)
.add("a2",
StructType(Nil)
.add("c1", IntegerType)
.add("c2", IntegerType)))
}

class LocalityTestFileSystem extends RawLocalFileSystem {
private val invocations = new AtomicInteger(0)
ParquetFilterSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import java.nio.charset.StandardCharsets

import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
import org.apache.parquet.filter2.predicate.FilterApi._
import org.apache.parquet.filter2.predicate.FilterApi.{and, gt, lt}
import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}
import org.apache.parquet.hadoop.ParquetInputFormat

@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.parquet.ParquetColumns.{doubleColumn, intColumn}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -542,4 +543,53 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
// scalastyle:on nonascii
}
}

test("SPARK-20364: Predicate pushdown for columns with a '.' in them") {
import testImplicits._

Seq(true, false).foreach { vectorized =>
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) {
withTempPath { path =>
Seq(Some(1), None).toDF("col.dots").write.parquet(path.getAbsolutePath)
assert(spark.read.parquet(path.getAbsolutePath).where("`col.dots` > 0").count() == 1)
}

withTempPath { path =>
Seq(Some(1L), None).toDF("col.dots").write.parquet(path.getAbsolutePath)
assert(spark.read.parquet(path.getAbsolutePath).where("`col.dots` >= 1L").count() == 1)
}

withTempPath { path =>
Seq(Some(1.0F), None).toDF("col.dots").write.parquet(path.getAbsolutePath)
assert(spark.read.parquet(path.getAbsolutePath).where("`col.dots` < 2.0").count() == 1)
}

withTempPath { path =>
Seq(Some(1.0D), None).toDF("col.dots").write.parquet(path.getAbsolutePath)
assert(spark.read.parquet(path.getAbsolutePath).where("`col.dots` <= 1.0D").count() == 1)
}

withTempPath { path =>
Seq(true, false).toDF("col.dots").write.parquet(path.getAbsolutePath)
assert(spark.read.parquet(path.getAbsolutePath).where("`col.dots` == true").count() == 1)
}

withTempPath { path =>
Seq("apple", null).toDF("col.dots").write.parquet(path.getAbsolutePath)
assert(
spark.read.parquet(path.getAbsolutePath).where("`col.dots` IS NOT NULL").count() == 1)
}
}
}

withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> false.toString) {
withTempPath { path =>
Seq("apple", null).toDF("col.dots").write.parquet(path.getAbsolutePath)
// This checks record-by-record filtering in Parquet's filter2.
val num = stripSparkFilter(
spark.read.parquet(path.getAbsolutePath).where("`col.dots` IS NULL")).count()
assert(num == 1)
}
}
}
}