From 6b1d5dc8874bba7c707428818123ec63fd7e84f0 Mon Sep 17 00:00:00 2001
From: Ameen Tayyebi
Date: Wed, 27 Dec 2017 21:56:13 -0500
Subject: [PATCH] [SPARK-22913][SQL] Improve Hive partition pruning for timestamp and fractional columns

Add support for pushing partition predicates on timestamp and fractional
column types down to the Hive metastore. Pruning for these types is gated
behind new internal options that default to false, because it is not clear
which Hive metastore implementations support predicates on columns of these
types.

The AWS Glue Data Catalog
(http://docs.aws.amazon.com/glue/latest/dg/populate-data-catalog.html)
does support filters on timestamp and fractional columns, and pushing these
filters down to it yields significant performance improvements in our use
cases.

As part of this change, the Hive pruning suite is renamed (resolving an
existing TODO) and two ignored tests are added to validate partition pruning
through integration tests. They are ignored because the integration test
setup uses a Hive client that throws errors when it sees partition column
filters on non-integral, non-string columns. Active unit tests are added to
validate the filter conversion. A short usage sketch for the new flags
follows the diff.
---
 .../apache/spark/sql/internal/SQLConf.scala   |  21 ++
 .../spark/sql/hive/client/HiveShim.scala      |  28 ++-
 .../spark/sql/hive/client/FiltersSuite.scala  |  52 +++-
 .../sql/hive/client/HiveClientSuites.scala    |   2 +-
 ...cala => HivePartitionFilteringSuite.scala} | 233 +++++++++++-------
 5 files changed, 242 insertions(+), 94 deletions(-)
 rename sql/hive/src/test/scala/org/apache/spark/sql/hive/client/{HiveClientSuite.scala => HivePartitionFilteringSuite.scala} (55%)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index f16972e5427e2..48d4e2aa83a7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -187,6 +187,21 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val ENABLE_HIVE_TIMESTAMP_TYPE_PARTITION_PRUNING =
+    buildConf("spark.sql.hive.metastore.partition.pruning.timestamps.enabled")
+      .internal()
+      .doc("When true, predicates on columns of type timestamp are pushed to the Hive metastore.")
+      .booleanConf
+      .createWithDefault(false)
+
+  val ENABLE_HIVE_FRACTIONAL_TYPES_PARTITION_PRUNING =
+    buildConf("spark.sql.hive.metastore.partition.pruning.fractionals.enabled")
+      .internal()
+      .doc("When true, predicates on columns of fractional types (double, float, decimal) " +
+        "are pushed to the Hive metastore.")
+      .booleanConf
+      .createWithDefault(false)
+
   val ENABLE_FALL_BACK_TO_HDFS_FOR_STATS = buildConf("spark.sql.statistics.fallBackToHdfs")
     .doc("If the table statistics are not available from table metadata enable fall back to hdfs."
+ @@ -1222,6 +1237,12 @@ class SQLConf extends Serializable with Logging { def advancedPartitionPredicatePushdownEnabled: Boolean = getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) + def pruneTimestampPartitionColumns: Boolean = + getConf(ENABLE_HIVE_TIMESTAMP_TYPE_PARTITION_PRUNING) + + def pruneFractionalPartitionColumns: Boolean = + getConf(ENABLE_HIVE_FRACTIONAL_TYPES_PARTITION_PRUNING) + def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS) def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 1eac70dbf19cd..3c80bcd2d9e9b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -20,16 +20,15 @@ package org.apache.spark.sql.hive.client import java.lang.{Boolean => JBoolean, Integer => JInteger, Long => JLong} import java.lang.reflect.{InvocationTargetException, Method, Modifier} import java.net.URI -import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, Set => JSet} +import java.util.{Locale, ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.util.Try import scala.util.control.NonFatal - import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{EnvironmentContext, Function => HiveFunction, FunctionType} +import org.apache.hadoop.hive.metastore.api.{EnvironmentContext, FunctionType, Function => HiveFunction} import org.apache.hadoop.hive.metastore.api.{MetaException, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.io.AcidUtils @@ -38,15 +37,16 @@ import org.apache.hadoop.hive.ql.plan.AddPartitionDesc import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants - import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.sql.types.{FractionalType, IntegralType, StringType, TimestampType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -598,9 +598,17 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } object ExtractableLiteral { + val pruneTimestamps = SQLConf.get.pruneTimestampPartitionColumns + val pruneFractionals = SQLConf.get.pruneFractionalPartitionColumns + def unapply(expr: Expression): Option[String] = expr match { case Literal(value, _: IntegralType) => Some(value.toString) case Literal(value, _: StringType) => Some(quoteStringLiteral(value.toString)) + case Literal(value, _: FractionalType) if pruneFractionals => Some(value.toString) 
+ // Timestamp must be converted to yyyy-mm-dd hh:mm:ss[.fffffffff] format before + // it can be used for partition pruning + case Literal(value: SQLTimestamp, _: TimestampType) if pruneTimestamps => + Some(s"'${DateTimeUtils.timestampToString(value, DateTimeUtils.TimeZoneUTC)}'") case _ => None } } @@ -641,6 +649,10 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) .map(col => col.getName).toSet + if (varcharKeys.nonEmpty) { + logDebug(s"Following table columns will be ignored in " + + s"partition pruning because their type is varchar: $varcharKeys") + } def unapply(attr: Attribute): Option[String] = { if (varcharKeys.contains(attr.name)) { @@ -687,7 +699,10 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { case _ => None } - filters.flatMap(convert).mkString(" and ") + val result = filters.flatMap(convert).mkString(" and ") + logDebug(s"Conversion of $filters for metastore partition pruning resulted in $result") + + result } private def quoteStringLiteral(str: String): String = { @@ -714,7 +729,6 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { if (filter.isEmpty) { getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] } else { - logDebug(s"Hive metastore filter is '$filter'.") val tryDirectSqlConfVar = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL // We should get this config value from the metaStore. otherwise hit SPARK-18681. // To be compatible with hive-0.12 and hive-0.13, In the future we can achieve this by: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 19765695fbcb4..14ac907dd9f62 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.hive.client +import java.sql.Timestamp +import java.time.Instant import java.util.Collections import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.serde.serdeConstants - import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -60,7 +61,7 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { "1 = intcol") filterTest("int and string filter", - (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil, + (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", StringType)) :: Nil, "1 = intcol and \"a\" = strcol") filterTest("skip varchar", @@ -72,9 +73,19 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { (Literal("p2\" and q=\"q2") === a("stringcol", StringType)) :: Nil, """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""") + filterTest("timestamp partition columns must be mapped to yyyy-mm-dd hh:mm:ss[.fffffffff] format", + (a("timestampcol", TimestampType) === Literal(812505600000000L, TimestampType)) :: Nil, + "timestampcol = '1995-10-01 00:00:00'") + + filterTest("decimal filter", + (Literal(50D) === a("deccol", DecimalType(2, 0))) :: Nil, + "50.0 = deccol") + private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name) { - withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") { + 
withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true", + SQLConf.ENABLE_HIVE_FRACTIONAL_TYPES_PARTITION_PRUNING.key -> "true", + SQLConf.ENABLE_HIVE_TIMESTAMP_TYPE_PARTITION_PRUNING.key -> "true") { val converted = shim.convertFilters(testTable, filters) if (converted != result) { fail(s"Expected ${filters.mkString(",")} to convert to '$result' but got '$converted'") @@ -100,5 +111,40 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { } } + test("turn on/off HIVE_FRACTIONAL_TYPES_PARTITION_PRUNING") { + import org.apache.spark.sql.catalyst.dsl.expressions._ + Seq(true, false).foreach { enabled => + withSQLConf(SQLConf.ENABLE_HIVE_FRACTIONAL_TYPES_PARTITION_PRUNING.key -> enabled.toString) { + val filters = + (Literal(1.0F) === a("floatcol", FloatType) || + Literal(2.0D) === a("doublecol", DoubleType) || + Literal(BigDecimal(3.0D)) === a("deccol", DecimalType(10, 0))) :: Nil + val converted = shim.convertFilters(testTable, filters) + if (enabled) { + assert(converted == "((1.0 = floatcol or 2.0 = doublecol) or 3.0 = deccol)") + } else { + assert(converted.isEmpty) + } + } + } + } + + test("turn on/off HIVE_TIMESTAMP_PARTITION_PRUNING") { + import org.apache.spark.sql.catalyst.dsl.expressions._ + val october23rd = Instant.parse("1984-10-23T00:00:00.00Z") + Seq(true, false).foreach { enabled => + withSQLConf(SQLConf.ENABLE_HIVE_TIMESTAMP_TYPE_PARTITION_PRUNING.key -> enabled.toString) { + val filters = (Literal(new Timestamp(october23rd.toEpochMilli)) + === a("tcol", TimestampType)) :: Nil + val converted = shim.convertFilters(testTable, filters) + if (enabled) { + assert(converted == "'1984-10-23 00:00:00' = tcol") + } else { + assert(converted.isEmpty) + } + } + } + } + private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala index de1be2115b2d8..6fe495b55e780 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala @@ -24,6 +24,6 @@ import org.scalatest.Suite class HiveClientSuites extends Suite with HiveClientVersions { override def nestedSuites: IndexedSeq[Suite] = { // Hive 0.12 does not provide the partition filtering API we call - versions.filterNot(_ == "0.12").map(new HiveClientSuite(_)) + versions.filterNot(_ == "0.12").map(new HivePartitionFilteringSuite(_)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala similarity index 55% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala index ce53acef51503..22ef7c29cf7af 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala @@ -17,24 +17,36 @@ package org.apache.spark.sql.hive.client +import java.time.Instant + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.conf.HiveConf -import org.scalatest.BeforeAndAfterAll - +import org.apache.spark.sql.Column +import org.scalatest.{BeforeAndAfterAll, Ignore} import org.apache.spark.sql.catalyst.catalog._ 
-import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, In, InSet} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EmptyRow, EqualTo, Expression, In, InSet, Literal} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.types.DataTypes -// TODO: Refactor this to `HivePartitionFilteringSuite` -class HiveClientSuite(version: String) +class HivePartitionFilteringSuite(version: String) extends HiveVersionSuite(version) with BeforeAndAfterAll { import CatalystSqlParser._ private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname - private val testPartitionCount = 3 * 24 * 4 + private val testPartitionCount = 3 * 24 * 4 * 2 * 4 + private val chunkValues = Seq("aa", "ab", "ba", "bb") + private val dValues = 20170101 to 20170103 + private val hValues = 0 to 23 + private val tValues = + Seq(Instant.parse("2017-12-24T00:00:00.00Z"), Instant.parse("2017-12-25T00:00:00.00Z")) + private val decValues = Seq(BigDecimal(1D), BigDecimal(2D), BigDecimal(3D), BigDecimal(4D)) private def init(tryDirectSql: Boolean): HiveClient = { + val hadoopConf = new Configuration() + hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql) + val client = buildClient(hadoopConf) + val storageFormat = CatalogStorageFormat( locationUri = None, inputFormat = None, @@ -43,26 +55,29 @@ class HiveClientSuite(version: String) compressed = false, properties = Map.empty) - val hadoopConf = new Configuration() - hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql) - val client = buildClient(hadoopConf) client - .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)") + .runSqlHive("CREATE TABLE test (value INT) " + + "PARTITIONED BY (ds INT, h INT, chunk STRING, t TIMESTAMP, d DECIMAL)") val partitions = for { - ds <- 20170101 to 20170103 - h <- 0 to 23 - chunk <- Seq("aa", "ab", "ba", "bb") + ds <- dValues + h <- hValues + chunk <- chunkValues + t <- tValues + d <- decValues } yield CatalogTablePartition(Map( "ds" -> ds.toString, "h" -> h.toString, - "chunk" -> chunk + "chunk" -> chunk, + "t" -> t.getEpochSecond.toString, + "d" -> d.toString ), storageFormat) assert(partitions.size == testPartitionCount) client.createPartitions( "default", "test", partitions, ignoreIfExists = false) + client } @@ -80,19 +95,11 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds<=>20170101") { // Should return all partitions where <=> is not supported - testMetastorePartitionFiltering( - "ds<=>20170101", - 20170101 to 20170103, - 0 to 23, - "aa" :: "ab" :: "ba" :: "bb" :: Nil) + assertNoFilterIsApplied("ds<=>20170101") } test("getPartitionsByFilter: ds=20170101") { - testMetastorePartitionFiltering( - "ds=20170101", - 20170101 to 20170101, - 0 to 23, - "aa" :: "ab" :: "ba" :: "bb" :: Nil) + testMetastorePartitionFiltering("ds=20170101", 20170101 to 20170101) } test("getPartitionsByFilter: ds=(20170101 + 1) and h=0") { @@ -100,93 +107,114 @@ class HiveClientSuite(version: String) // comparisons to non-literal values testMetastorePartitionFiltering( "ds=(20170101 + 1) and h=0", - 20170101 to 20170103, - 0 to 0, - "aa" :: "ab" :: "ba" :: "bb" :: Nil) + dValues, 0 to 0) } test("getPartitionsByFilter: chunk='aa'") { testMetastorePartitionFiltering( "chunk='aa'", - 20170101 to 20170103, - 0 to 23, + dValues, hValues, "aa" :: Nil) } test("getPartitionsByFilter: 20170101=ds") { testMetastorePartitionFiltering( "20170101=ds", + 20170101 to 20170101) + } + + test("getPartitionsByFilter: must ignore unsupported 
expressions") { + testMetastorePartitionFiltering( + "ds is not null and chunk is not null and 20170101=ds and chunk = 'aa'", 20170101 to 20170101, - 0 to 23, - "aa" :: "ab" :: "ba" :: "bb" :: Nil) + hValues, + "aa" :: Nil) + } + + test("getPartitionsByFilter: multiple or single expressions expressions yield the same result") { + testMetastorePartitionFiltering( + "ds is not null and chunk is not null and (20170101=ds) and (chunk = 'aa')", + 20170101 to 20170101, + hValues, + "aa" :: Nil) + + testMetastorePartitionFiltering( + Seq(parseExpression("ds is not null"), + parseExpression("chunk is not null"), + parseExpression("(20170101=ds)"), + EqualTo(AttributeReference("chunk", DataTypes.StringType)(), Literal.apply("aa"))), + 20170101 to 20170101, + hValues, + "aa" :: Nil, + tValues, + decValues) } test("getPartitionsByFilter: ds=20170101 and h=10") { testMetastorePartitionFiltering( "ds=20170101 and h=10", 20170101 to 20170101, - 10 to 10, - "aa" :: "ab" :: "ba" :: "bb" :: Nil) + 10 to 10) } test("getPartitionsByFilter: ds=20170101 or ds=20170102") { testMetastorePartitionFiltering( "ds=20170101 or ds=20170102", - 20170101 to 20170102, - 0 to 23, - "aa" :: "ab" :: "ba" :: "bb" :: Nil) + 20170101 to 20170102) } test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") { testMetastorePartitionFiltering( "ds in (20170102, 20170103)", - 20170102 to 20170103, - 0 to 23, - "aa" :: "ab" :: "ba" :: "bb" :: Nil) + 20170102 to 20170103) } test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") { testMetastorePartitionFiltering( - "ds in (20170102, 20170103)", - 20170102 to 20170103, - 0 to 23, - "aa" :: "ab" :: "ba" :: "bb" :: Nil, { + Seq(parseExpression("ds in (20170102, 20170103)") match { case expr @ In(v, list) if expr.inSetConvertible => InSet(v, list.map(_.eval(EmptyRow)).toSet) - }) + }), + 20170102 to 20170103, + hValues, + chunkValues, + tValues, + decValues) } test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") { testMetastorePartitionFiltering( "chunk in ('ab', 'ba')", - 20170101 to 20170103, - 0 to 23, + dValues, + hValues, "ab" :: "ba" :: Nil) } test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") { testMetastorePartitionFiltering( - "chunk in ('ab', 'ba')", - 20170101 to 20170103, - 0 to 23, - "ab" :: "ba" :: Nil, { + Seq(parseExpression("chunk in ('ab', 'ba')") match { case expr @ In(v, list) if expr.inSetConvertible => InSet(v, list.map(_.eval(EmptyRow)).toSet) - }) + }), + dValues, + hValues, + "ab" :: "ba" :: Nil, + tValues, + decValues) } test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") { - val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) - val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb")) + val day1 = (20170101 to 20170101, 8 to 23, chunkValues, tValues, decValues) + val day2 = (20170102 to 20170102, 0 to 7, chunkValues, tValues, decValues) testMetastorePartitionFiltering( "(ds=20170101 and h>=8) or (ds=20170102 and h<8)", day1 :: day2 :: Nil) } test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") { - val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) + val day1 = (20170101 to 20170101, 8 to 23, chunkValues, tValues, decValues) // Day 2 should include all hours because we can't build a filter for h<(7+1) - val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb")) + val day2 = (20170102 to 20170102, 0 to 23, chunkValues, tValues, decValues) 
testMetastorePartitionFiltering( "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))", day1 :: day2 :: Nil) @@ -194,66 +222,105 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: " + "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") { - val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba")) - val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba")) + val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba"), tValues, decValues) + val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba"), tValues, decValues) testMetastorePartitionFiltering( "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))", day1 :: day2 :: Nil) } - private def testMetastorePartitionFiltering( - filterString: String, - expectedDs: Seq[Int], - expectedH: Seq[Int], - expectedChunks: Seq[String]): Unit = { - testMetastorePartitionFiltering( - filterString, - (expectedDs, expectedH, expectedChunks) :: Nil, - identity) + ignore("TODO: create hive metastore for integration test " + + "that supports timestamp and decimal pruning") { + test("getPartitionsByFilter: t = '2017-12-24T00:00:00.00Z' (timestamp test)") { + testMetastorePartitionFiltering( + "t = '2017-12-24T00:00:00.00Z'", + dValues, + hValues, + chunkValues, + Seq(Instant.parse("2017-12-24T00:00:00.00Z")) + ) + } + + test("getPartitionsByFilter: d = 4.0 (decimal test)") { + testMetastorePartitionFiltering( + "d = 4.0", + dValues, + hValues, + chunkValues, + tValues, + Seq(BigDecimal(4.0D)) + ) + } + } + + private def assertNoFilterIsApplied(expression: String) = { + val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), + Seq(parseExpression(expression))) + + assert(filteredPartitions.size == testPartitionCount) } private def testMetastorePartitionFiltering( - filterString: String, + filters: Seq[Expression], expectedDs: Seq[Int], expectedH: Seq[Int], expectedChunks: Seq[String], - transform: Expression => Expression): Unit = { + expectedTs: Seq[Instant], + expectedDecs: Seq[BigDecimal]): Unit = { testMetastorePartitionFiltering( - filterString, - (expectedDs, expectedH, expectedChunks) :: Nil, - identity) + filters, + (expectedDs, expectedH, expectedChunks, expectedTs, expectedDecs) :: Nil) } private def testMetastorePartitionFiltering( filterString: String, - expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = { - testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity) + expectedDs: Seq[Int] = dValues, + expectedH: Seq[Int] = hValues, + expectedChunks: Seq[String] = chunkValues, + expectedTs: Seq[Instant] = tValues, + expectedDecs: Seq[BigDecimal] = decValues): Unit = { + testMetastorePartitionFiltering(Seq(parseExpression(filterString)), + expectedDs, + expectedH, + expectedChunks, + expectedTs, + expectedDecs) } private def testMetastorePartitionFiltering( filterString: String, - expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])], - transform: Expression => Expression): Unit = { - val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), - Seq( - transform(parseExpression(filterString)) - )) + expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String], Seq[Instant], Seq[BigDecimal])] + ): Unit = { + testMetastorePartitionFiltering(Seq(parseExpression(filterString)), + expectedPartitionCubes) + } + + private def testMetastorePartitionFiltering( + predicates: Seq[Expression], + expectedPartitionCubes: + Seq[(Seq[Int], Seq[Int], Seq[String], Seq[Instant], 
Seq[BigDecimal])]): Unit = { + val filteredPartitions = + client.getPartitionsByFilter(client.getTable("default", "test"), predicates) val expectedPartitionCount = expectedPartitionCubes.map { - case (expectedDs, expectedH, expectedChunks) => - expectedDs.size * expectedH.size * expectedChunks.size + case (expectedDs, expectedH, expectedChunks, expectedTs, expectedDecs) => + expectedDs.size * expectedH.size * expectedChunks.size * expectedTs.size * expectedDecs.size }.sum val expectedPartitions = expectedPartitionCubes.map { - case (expectedDs, expectedH, expectedChunks) => + case (expectedDs, expectedH, expectedChunks, expectedTs, expectedDecs) => for { ds <- expectedDs h <- expectedH chunk <- expectedChunks + t <- expectedTs + d <- expectedDecs } yield Set( "ds" -> ds.toString, "h" -> h.toString, - "chunk" -> chunk + "chunk" -> chunk, + "t" -> t.getEpochSecond.toString, + "d" -> d.toString() ) }.reduce(_ ++ _)
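
Notes (not part of the patch): a minimal usage sketch for the two new internal
flags, assuming a metastore that accepts timestamp and decimal partition
filters (for example the AWS Glue Data Catalog). The session setup, the table
name "events", and the partition columns t and d below are illustrative only
and do not come from this change.

import java.sql.Timestamp

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("partition-pruning-sketch")
  .enableHiveSupport()
  // Both flags are internal and default to false; enable them explicitly.
  .config("spark.sql.hive.metastore.partition.pruning.timestamps.enabled", "true")
  .config("spark.sql.hive.metastore.partition.pruning.fractionals.enabled", "true")
  .getOrCreate()

import spark.implicits._

// With the flags enabled, Shim_v0_13.convertFilters should translate the
// predicates below into a metastore filter string along the lines of
//   t = '2017-12-24 00:00:00' and d = 4.0
// (the timestamp rendered in UTC), so only matching partitions are fetched.
spark.table("events")  // hypothetical table partitioned by (t TIMESTAMP, d DECIMAL)
  .where($"t" === Timestamp.valueOf("2017-12-24 00:00:00"))
  .where($"d" === BigDecimal(4))
  .show()

Setting the flags when the session is built avoids relying on runtime config
changes: ExtractableLiteral reads SQLConf.get when it is first initialized, so
values changed afterwards may not take effect.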