|
 package org.apache.spark.sql.execution.datasources.parquet
 
 import java.io.File
+import java.lang.{Double => JDouble, Float => JFloat, Long => JLong}
 import java.math.{BigDecimal => JBigDecimal}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
@@ -902,6 +903,76 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared |
     }
   }
 
+  test("don't push down filters that would result in overflows") {
+    val schema = StructType(Seq(
+      StructField("cbyte", ByteType),
+      StructField("cshort", ShortType),
+      StructField("cint", IntegerType)
+    ))
+
+    val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)
+    val parquetFilters = createParquetFilters(parquetSchema)
+
+    for {
+      column <- Seq("cbyte", "cshort", "cint")
+      value <- Seq(JLong.MAX_VALUE, JLong.MIN_VALUE).map(JLong.valueOf)
+    } {
+      val filters = Seq(
+        sources.LessThan(column, value),
+        sources.LessThanOrEqual(column, value),
+        sources.GreaterThan(column, value),
+        sources.GreaterThanOrEqual(column, value),
+        sources.EqualTo(column, value),
+        sources.EqualNullSafe(column, value),
+        sources.Not(sources.EqualTo(column, value)),
+        sources.In(column, Array(value))
+      )
+      for (filter <- filters) {
+        assert(parquetFilters.createFilter(filter).isEmpty,
+          s"Row group filter $filter shouldn't be pushed down.")
+      }
+    }
+  }
+
+  test("don't push down filters when value type doesn't match column type") {
+    val schema = StructType(Seq(
+      StructField("cbyte", ByteType),
+      StructField("cshort", ShortType),
+      StructField("cint", IntegerType),
+      StructField("clong", LongType),
+      StructField("cfloat", FloatType),
+      StructField("cdouble", DoubleType),
+      StructField("cboolean", BooleanType),
+      StructField("cstring", StringType),
+      StructField("cdate", DateType),
+      StructField("ctimestamp", TimestampType),
+      StructField("cbinary", BinaryType),
+      StructField("cdecimal", DecimalType(10, 0))
+    ))
+
+    val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)
+    val parquetFilters = createParquetFilters(parquetSchema)
+
+    val filters = Seq(
+      sources.LessThan("cbyte", String.valueOf("1")),
+      sources.LessThan("cshort", JBigDecimal.valueOf(1)),
+      sources.LessThan("cint", JFloat.valueOf(JFloat.NaN)),
+      sources.LessThan("clong", String.valueOf("1")),
+      sources.LessThan("cfloat", JDouble.valueOf(1.0D)),
+      sources.LessThan("cdouble", JFloat.valueOf(1.0F)),
+      sources.LessThan("cboolean", String.valueOf("true")),
+      sources.LessThan("cstring", Integer.valueOf(1)),
+      sources.LessThan("cdate", Timestamp.valueOf("2018-01-01 00:00:00")),
+      sources.LessThan("ctimestamp", Date.valueOf("2018-01-01")),
+      sources.LessThan("cbinary", Integer.valueOf(1)),
+      sources.LessThan("cdecimal", Integer.valueOf(1234))
+    )
+    for (filter <- filters) {
+      assert(parquetFilters.createFilter(filter).isEmpty,
+        s"Row group filter $filter shouldn't be pushed down.")
+    }
+  }
+
   test("SPARK-6554: don't push down predicates which reference partition columns") {
     import testImplicits._
 
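What the first new test guards against: Parquet row-group predicates compare values at the column's physical type, so an out-of-range comparison value cannot simply be narrowed to fit — integral narrowing wraps around and silently changes the predicate. A minimal standalone sketch of the hazard (plain Scala; `OverflowSketch` and the sample data are illustrative, not part of this patch):

object OverflowSketch {
  def main(args: Array[String]): Unit = {
    // Narrowing Long.MaxValue to Byte wraps to -1 instead of saturating.
    val value: Long = Long.MaxValue
    val narrowed: Byte = value.toByte
    println(narrowed) // -1

    // `cbyte < Long.MaxValue` holds for every byte, but a filter built from
    // the wrapped value would read `cbyte < -1` and prune rows incorrectly.
    val sample: Seq[Byte] = Seq(-2, -1, 0, 1, 127)
    println(sample.count(_ < value))    // 5: the correct semantics
    println(sample.count(_ < narrowed)) // 1: what a wrapped pushdown would keep
  }
}

Refusing to push such filters down leaves them to be evaluated by Spark after the scan, which is always safe.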
|
|
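The second new test pins down only the negative case: a value whose type does not match the column type must not convert to a row-group filter. For contrast, a sketch of the positive counterpart, reusing the `conf`, `createParquetFilters`, and `sources` names that appear in the diff above; the `isDefined` expectation for the matching-type path is an assumption here, not something this patch asserts:

// Hypothetical companion check inside the same suite.
val schema = StructType(Seq(StructField("cint", IntegerType)))
val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)
val parquetFilters = createParquetFilters(parquetSchema)

// Integer value against an IntegerType column: pushdown should succeed.
assert(parquetFilters.createFilter(
  sources.LessThan("cint", Integer.valueOf(1))).isDefined)

// String value against the same column: rejected, as the test above asserts.
assert(parquetFilters.createFilter(
  sources.LessThan("cint", String.valueOf("1"))).isEmpty)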