From a392ef52d7e64d137e48f1fb236506fd6aad2109 Mon Sep 17 00:00:00 2001 From: Johan Lasperas Date: Fri, 24 Nov 2023 15:49:13 +0100 Subject: [PATCH 1/2] Don't push down row group filters that overflow --- .../datasources/parquet/ParquetFilters.scala | 5 +- .../parquet/ParquetFilterSuite.scala | 71 +++++++++++++++++++ .../parquet/ParquetQuerySuite.scala | 20 ++++++ 3 files changed, 95 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index c1d02ba5a227d..1ed5261dbad45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -613,7 +613,10 @@ class ParquetFilters( value == null || (nameToParquetField(name).fieldType match { case ParquetBooleanType => value.isInstanceOf[JBoolean] case ParquetIntegerType if value.isInstanceOf[Period] => true - case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number] + case ParquetByteType | ParquetShortType | ParquetIntegerType => value match { + case v: Number => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue + case _ => false + } case ParquetLongType => value.isInstanceOf[JLong] || value.isInstanceOf[Duration] case ParquetFloatType => value.isInstanceOf[JFloat] case ParquetDoubleType => value.isInstanceOf[JDouble] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 4ed5297ff4ead..d515b30e3de3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import java.lang.{Long => JLong} import java.math.{BigDecimal => JBigDecimal} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} @@ -906,6 +907,76 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } } + test("don't push down filters that would result in overflows") { + val schema = StructType(Seq( + StructField("cbyte", ByteType), + StructField("cshort", ShortType), + StructField("cint", IntegerType) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + val parquetFilters = createParquetFilters(parquetSchema) + + for { + column <- Seq("cbyte", "cshort", "cint") + value <- Seq(JLong.MAX_VALUE, JLong.MIN_VALUE): Seq[JLong] + } { + val filters = Seq( + sources.LessThan(column, value), + sources.LessThanOrEqual(column, value), + sources.GreaterThan(column, value), + sources.GreaterThanOrEqual(column, value), + sources.EqualTo(column, value), + sources.EqualNullSafe(column, value), + sources.Not(sources.EqualTo(column, value)), + sources.In(column, Array(value)) + ) + for (filter <- filters) { + assert(parquetFilters.createFilter(filter).isEmpty, + s"Row group filter $filter shouldn't be pushed down.") + } + } + } + + test("don't push down filters when value type doesn't match column type") { + val schema = StructType(Seq( + StructField("cbyte", ByteType), + StructField("cshort", ShortType), + StructField("cint", 
IntegerType),
+      StructField("clong", LongType),
+      StructField("cfloat", FloatType),
+      StructField("cdouble", DoubleType),
+      StructField("cboolean", BooleanType),
+      StructField("cstring", StringType),
+      StructField("cdate", DateType),
+      StructField("ctimestamp", TimestampType),
+      StructField("cbinary", BinaryType),
+      StructField("cdecimal", DecimalType(10, 0))
+    ))
+
+    val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)
+    val parquetFilters = createParquetFilters(parquetSchema)
+
+    val filters = Seq(
+      sources.LessThan("cbyte", "1"),
+      sources.LessThan("cshort", "1"),
+      sources.LessThan("cint", "1"),
+      sources.LessThan("clong", "1"),
+      sources.LessThan("cfloat", 1.0D),
+      sources.LessThan("cdouble", 1.0F),
+      sources.LessThan("cboolean", "true"),
+      sources.LessThan("cstring", 1),
+      sources.LessThan("cdate", Timestamp.valueOf("2018-01-01 00:00:00")),
+      sources.LessThan("ctimestamp", Date.valueOf("2018-01-01")),
+      sources.LessThan("cbinary", 1),
+      sources.LessThan("cdecimal", 1234)
+    )
+    for (filter <- filters) {
+      assert(parquetFilters.createFilter(filter).isEmpty,
+        s"Row group filter $filter shouldn't be pushed down.")
+    }
+  }
+
   test("SPARK-6554: don't push down predicates which reference partition columns") {
     import testImplicits._
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index dc8a89c12ca61..43103db522bac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -1095,6 +1095,26 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
     }
   }
 
+  test("row group skipping doesn't overflow when reading into larger type") {
+    withTempPath { path =>
+      Seq(0).toDF("a").write.parquet(path.toString)
+      // The vectorized and non-vectorized readers produce different exceptions. Row group
+      // skipping is common to both code paths, so testing a single reader is enough.
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+        // Reading integer 'a' as a long isn't supported. Check that an exception is raised instead
+        // of incorrectly skipping the single row group and producing incorrect results.
+        val exception = intercept[SparkException] {
+          spark.read
+            .schema("a LONG")
+            .parquet(path.toString)
+            .where(s"a < ${Long.MaxValue}")
+            .collect()
+        }
+        assert(exception.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException])
+      }
+    }
+  }
+
   test("SPARK-36825, SPARK-36852: create table with ANSI intervals") {
     withTable("tbl") {
       sql("create table tbl (c1 interval day, c2 interval year to month) using parquet")

From 1f88c4f44c4eaa95dbf1fe0461fed39ccb5afed1 Mon Sep 17 00:00:00 2001
From: Johan Lasperas
Date: Mon, 27 Nov 2023 14:05:54 +0100
Subject: [PATCH 2/2] Don't filter integer columns with non-integer values (float/double, decimal)

---
 .../datasources/parquet/ParquetFilters.scala  |  7 ++++--
 .../parquet/ParquetFilterSuite.scala          | 24 +++++++++----------
 2 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 1ed5261dbad45..fee20d1b86ca4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.execution.datasources.parquet
 
-import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong}
+import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort}
 import java.math.{BigDecimal => JBigDecimal}
 import java.nio.charset.StandardCharsets.UTF_8
 import java.sql.{Date, Timestamp}
@@ -614,7 +614,10 @@ class ParquetFilters(
       case ParquetBooleanType => value.isInstanceOf[JBoolean]
       case ParquetIntegerType if value.isInstanceOf[Period] => true
       case ParquetByteType | ParquetShortType | ParquetIntegerType => value match {
-          case v: Number => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue
+          // Byte/Short/Int are all stored as INT32 in Parquet, so filters are built using type Int.
+          // We don't create a filter if the value would overflow.
+ case _: JByte | _: JShort | _: Integer => true + case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue case _ => false } case ParquetLongType => value.isInstanceOf[JLong] || value.isInstanceOf[Duration] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index d515b30e3de3a..d7ed5c4d35426 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File -import java.lang.{Long => JLong} +import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} @@ -919,7 +919,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared for { column <- Seq("cbyte", "cshort", "cint") - value <- Seq(JLong.MAX_VALUE, JLong.MIN_VALUE): Seq[JLong] + value <- Seq(JLong.MAX_VALUE, JLong.MIN_VALUE).map(JLong.valueOf) } { val filters = Seq( sources.LessThan(column, value), @@ -958,18 +958,18 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared val parquetFilters = createParquetFilters(parquetSchema) val filters = Seq( - sources.LessThan("cbyte", "1"), - sources.LessThan("cshort", "1"), - sources.LessThan("cint", "1"), - sources.LessThan("clong", "1"), - sources.LessThan("cfloat", 1.0D), - sources.LessThan("cdouble", 1.0F), - sources.LessThan("cboolean", "true"), - sources.LessThan("cstring", 1), + sources.LessThan("cbyte", String.valueOf("1")), + sources.LessThan("cshort", JBigDecimal.valueOf(1)), + sources.LessThan("cint", JFloat.valueOf(JFloat.NaN)), + sources.LessThan("clong", String.valueOf("1")), + sources.LessThan("cfloat", JDouble.valueOf(1.0D)), + sources.LessThan("cdouble", JFloat.valueOf(1.0F)), + sources.LessThan("cboolean", String.valueOf("true")), + sources.LessThan("cstring", Integer.valueOf(1)), sources.LessThan("cdate", Timestamp.valueOf("2018-01-01 00:00:00")), sources.LessThan("ctimestamp", Date.valueOf("2018-01-01")), - sources.LessThan("cbinary", 1), - sources.LessThan("cdecimal", 1234) + sources.LessThan("cbinary", Integer.valueOf(1)), + sources.LessThan("cdecimal", Integer.valueOf(1234)) ) for (filter <- filters) { assert(parquetFilters.createFilter(filter).isEmpty,
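
Reviewer note (editor's addition, not part of the patch series): the sketch below replays the scenario the new ParquetQuerySuite test covers as a standalone program, useful for trying the fix against a local build. The object name and the /tmp path are illustrative; `a < Long.MaxValue` is the same predicate the test uses.

import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession

// Hypothetical standalone repro of the row group skipping overflow.
object RowGroupOverflowRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("overflow-repro").master("local[1]").getOrCreate()
    import spark.implicits._

    val path = "/tmp/row-group-overflow-repro" // illustrative scratch location
    Seq(0).toDF("a").write.mode("overwrite").parquet(path)

    // Pre-fix behavior: the predicate value was narrowed with Number.intValue,
    // and Long.MaxValue.toInt == -1, so the pushed-down row group filter became
    // `a < -1`. The only row group (min = max = 0) was skipped and the query
    // silently returned zero rows. Post-fix: no filter is pushed down, and the
    // vectorized reader fails on the unsupported INT32-to-LONG read instead of
    // producing an incorrect result.
    try {
      val rows = spark.read
        .schema("a LONG")
        .parquet(path)
        .where(s"a < ${Long.MaxValue}")
        .collect()
      println(s"read ${rows.length} row(s)") // pre-fix: prints 0 (wrong)
    } catch {
      case e: SparkException => println(s"post-fix: failed as expected: ${e.getMessage}")
    } finally {
      spark.stop()
    }
  }
}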
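
A second sketch, plain Scala with no Spark dependency: it shows why the follow-up commit stops accepting every Number, since the range check from the first commit is not sufficient on its own. A fractional value passes the [Int.MinValue, Int.MaxValue] check and still changes the predicate's meaning once narrowed to Int. The object name is illustrative.

// Standalone illustration; mirrors the reasoning, not the patch's code.
object FractionalNarrowingPitfall {
  def main(args: Array[String]): Unit = {
    val value = java.lang.Double.valueOf(0.5)

    // The first commit's check accepts this value: 0.5.longValue() == 0,
    // which is comfortably inside [Int.MinValue, Int.MaxValue].
    assert(value.longValue() >= Int.MinValue && value.longValue() <= Int.MaxValue)

    // But narrowing turns the predicate `a < 0.5` into the filter `a < 0`:
    // a row group with stats min = max = 0 matches `0 < 0.5`, yet `a < 0`
    // would prune it, silently dropping matching rows.
    val narrowed = value.intValue()
    assert(narrowed == 0 && 0 < 0.5)
  }
}

This is why the second commit only accepts Byte/Short/Integer as-is, range-checks Long, and rejects Float, Double and BigDecimal (including NaN) outright.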
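
Finally, the check the series converges on, distilled into a self-contained function for reference. The object and method names are illustrative, but the match mirrors the final logic added to ParquetFilters above.

// Minimal sketch of the final value check for INT32-backed Parquet columns.
object Int32FilterCheck {
  def canMakeInt32Filter(value: Any): Boolean = value match {
    // Int-sized types can always be represented in an INT32 filter.
    case _: java.lang.Byte | _: java.lang.Short | _: Integer => true
    // Longs qualify only when they fit in an Int.
    case v: java.lang.Long =>
      v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue
    // Everything else (Float, Double, BigDecimal, String, ...) is rejected.
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    assert(canMakeInt32Filter(Integer.valueOf(1)))
    assert(canMakeInt32Filter(java.lang.Long.valueOf(42L)))
    assert(!canMakeInt32Filter(java.lang.Long.valueOf(Long.MaxValue)))
    assert(!canMakeInt32Filter(java.lang.Double.valueOf(0.5)))
    assert(!canMakeInt32Filter("1"))
  }
}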