diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index f16972e5427e2..48d4e2aa83a7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -187,6 +187,21 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val ENABLE_HIVE_TIMESTAMP_TYPE_PARTITION_PRUNING =
+    buildConf("spark.sql.hive.metastore.partition.pruning.timestamps.enabled")
+      .internal()
+      .doc("When true, predicates on timestamp columns are pushed down to the Hive metastore.")
+      .booleanConf
+      .createWithDefault(false)
+
+  val ENABLE_HIVE_FRACTIONAL_TYPES_PARTITION_PRUNING =
+    buildConf("spark.sql.hive.metastore.partition.pruning.fractionals.enabled")
+      .internal()
+      .doc("When true, predicates on columns of fractional types (double, float, decimal) " +
+        "are pushed down to the Hive metastore.")
+      .booleanConf
+      .createWithDefault(false)
+
   val ENABLE_FALL_BACK_TO_HDFS_FOR_STATS = buildConf("spark.sql.statistics.fallBackToHdfs")
     .doc("If the table statistics are not available from table metadata enable fall back to hdfs." +
@@ -1222,6 +1237,12 @@ class SQLConf extends Serializable with Logging {
   def advancedPartitionPredicatePushdownEnabled: Boolean =
     getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN)
 
+  def pruneTimestampPartitionColumns: Boolean =
+    getConf(ENABLE_HIVE_TIMESTAMP_TYPE_PARTITION_PRUNING)
+
+  def pruneFractionalPartitionColumns: Boolean =
+    getConf(ENABLE_HIVE_FRACTIONAL_TYPES_PARTITION_PRUNING)
+
   def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS)
 
   def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)
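Note: both flags are internal and default to false, so timestamp and fractional predicates only reach the metastore when a user opts in. A minimal usage sketch, not part of the patch, assuming an existing Hive-enabled SparkSession named `spark`:

    // Hypothetical opt-in to the new pushdown paths added above.
    spark.conf.set("spark.sql.hive.metastore.partition.pruning.timestamps.enabled", "true")
    spark.conf.set("spark.sql.hive.metastore.partition.pruning.fractionals.enabled", "true")
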
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 1eac70dbf19cd..3c80bcd2d9e9b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -20,16 +20,15 @@ package org.apache.spark.sql.hive.client
 import java.lang.{Boolean => JBoolean, Integer => JInteger, Long => JLong}
 import java.lang.reflect.{InvocationTargetException, Method, Modifier}
 import java.net.URI
-import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, Set => JSet}
+import java.util.{Locale, ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet}
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
 import scala.util.Try
 import scala.util.control.NonFatal
-
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.metastore.api.{EnvironmentContext, Function => HiveFunction, FunctionType}
+import org.apache.hadoop.hive.metastore.api.{EnvironmentContext, FunctionType, Function => HiveFunction}
 import org.apache.hadoop.hive.metastore.api.{MetaException, PrincipalType, ResourceType, ResourceUri}
 import org.apache.hadoop.hive.ql.Driver
 import org.apache.hadoop.hive.ql.io.AcidUtils
@@ -38,15 +37,16 @@ import org.apache.hadoop.hive.ql.plan.AddPartitionDesc
 import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
 import org.apache.hadoop.hive.ql.session.SessionState
 import org.apache.hadoop.hive.serde.serdeConstants
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException
 import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{IntegralType, StringType}
+import org.apache.spark.sql.types.{FractionalType, IntegralType, StringType, TimestampType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
@@ -598,9 +598,17 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     }
 
     object ExtractableLiteral {
+      val pruneTimestamps = SQLConf.get.pruneTimestampPartitionColumns
+      val pruneFractionals = SQLConf.get.pruneFractionalPartitionColumns
+
       def unapply(expr: Expression): Option[String] = expr match {
         case Literal(value, _: IntegralType) => Some(value.toString)
         case Literal(value, _: StringType) => Some(quoteStringLiteral(value.toString))
+        case Literal(value, _: FractionalType) if pruneFractionals => Some(value.toString)
+        // Timestamp must be converted to yyyy-mm-dd hh:mm:ss[.fffffffff] format before
+        // it can be used for partition pruning
+        case Literal(value: SQLTimestamp, _: TimestampType) if pruneTimestamps =>
+          Some(s"'${DateTimeUtils.timestampToString(value, DateTimeUtils.TimeZoneUTC)}'")
         case _ => None
       }
     }
@@ -641,6 +649,10 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
        .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) ||
          col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME))
        .map(col => col.getName).toSet
+      if (varcharKeys.nonEmpty) {
+        logDebug("The following table columns will be ignored in partition pruning " +
+          s"because their type is varchar or char: $varcharKeys")
+      }
 
       def unapply(attr: Attribute): Option[String] = {
         if (varcharKeys.contains(attr.name)) {
@@ -687,7 +699,10 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
       case _ => None
     }
 
-    filters.flatMap(convert).mkString(" and ")
+    val result = filters.flatMap(convert).mkString(" and ")
+    logDebug(s"Conversion of $filters for metastore partition pruning resulted in: $result")
+
+    result
   }
 
   private def quoteStringLiteral(str: String): String = {
@@ -714,7 +729,6 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     if (filter.isEmpty) {
       getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
     } else {
-      logDebug(s"Hive metastore filter is '$filter'.")
      val tryDirectSqlConfVar = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL
      // We should get this config value from the metaStore. otherwise hit SPARK-18681.
      // To be compatible with hive-0.12 and hive-0.13, In the future we can achieve this by:
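Note: the quoting applied to timestamp literals above can be reproduced with the same internal helper the shim calls. A small sketch; the microsecond value is a hypothetical example, not taken from the patch:

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    // Catalyst TimestampType literals carry microseconds since the epoch; the shim renders
    // them in the yyyy-MM-dd HH:mm:ss[.fffffffff] form Hive expects, wrapped in single quotes.
    val micros = 1514073600000000L // hypothetical value: 2017-12-24 00:00:00 UTC
    val hiveLiteral = s"'${DateTimeUtils.timestampToString(micros, DateTimeUtils.TimeZoneUTC)}'"
    // hiveLiteral == "'2017-12-24 00:00:00'"
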
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
index 19765695fbcb4..14ac907dd9f62 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
@@ -17,11 +17,12 @@
 
 package org.apache.spark.sql.hive.client
 
+import java.sql.Timestamp
+import java.time.Instant
 import java.util.Collections
 
 import org.apache.hadoop.hive.metastore.api.FieldSchema
 import org.apache.hadoop.hive.serde.serdeConstants
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -60,7 +61,7 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
     "1 = intcol")
 
   filterTest("int and string filter",
-    (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil,
+    (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", StringType)) :: Nil,
     "1 = intcol and \"a\" = strcol")
 
   filterTest("skip varchar",
@@ -72,9 +73,19 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
     (Literal("p2\" and q=\"q2") === a("stringcol", StringType)) :: Nil,
     """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""")
 
+  filterTest("timestamp partition columns must be mapped to yyyy-mm-dd hh:mm:ss[.fffffffff] format",
+    (a("timestampcol", TimestampType) === Literal(812505600000000L, TimestampType)) :: Nil,
+    "timestampcol = '1995-10-01 00:00:00'")
+
+  filterTest("decimal filter",
+    (Literal(50D) === a("deccol", DecimalType(2, 0))) :: Nil,
+    "50.0 = deccol")
+
   private def filterTest(name: String, filters: Seq[Expression], result: String) = {
     test(name) {
-      withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") {
+      withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true",
+        SQLConf.ENABLE_HIVE_FRACTIONAL_TYPES_PARTITION_PRUNING.key -> "true",
+        SQLConf.ENABLE_HIVE_TIMESTAMP_TYPE_PARTITION_PRUNING.key -> "true") {
         val converted = shim.convertFilters(testTable, filters)
         if (converted != result) {
           fail(s"Expected ${filters.mkString(",")} to convert to '$result' but got '$converted'")
@@ -100,5 +111,40 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
     }
   }
 
+  test("turn on/off HIVE_FRACTIONAL_TYPES_PARTITION_PRUNING") {
+    import org.apache.spark.sql.catalyst.dsl.expressions._
+    Seq(true, false).foreach { enabled =>
+      withSQLConf(SQLConf.ENABLE_HIVE_FRACTIONAL_TYPES_PARTITION_PRUNING.key -> enabled.toString) {
+        val filters =
+          (Literal(1.0F) === a("floatcol", FloatType) ||
+            Literal(2.0D) === a("doublecol", DoubleType) ||
+            Literal(BigDecimal(3.0D)) === a("deccol", DecimalType(10, 0))) :: Nil
+        val converted = shim.convertFilters(testTable, filters)
+        if (enabled) {
+          assert(converted == "((1.0 = floatcol or 2.0 = doublecol) or 3.0 = deccol)")
+        } else {
+          assert(converted.isEmpty)
+        }
+      }
+    }
+  }
+
+  test("turn on/off HIVE_TIMESTAMP_PARTITION_PRUNING") {
+    import org.apache.spark.sql.catalyst.dsl.expressions._
+    val october23rd = Instant.parse("1984-10-23T00:00:00.00Z")
+    Seq(true, false).foreach { enabled =>
+      withSQLConf(SQLConf.ENABLE_HIVE_TIMESTAMP_TYPE_PARTITION_PRUNING.key -> enabled.toString) {
+        val filters = (Literal(new Timestamp(october23rd.toEpochMilli))
+          === a("tcol", TimestampType)) :: Nil
+        val converted = shim.convertFilters(testTable, filters)
+        if (enabled) {
+          assert(converted == "'1984-10-23 00:00:00' = tcol")
+        } else {
+          assert(converted.isEmpty)
+        }
+      }
+    }
+  }
+
   private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)()
 }
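Note: the expected string in the timestamp filter test above can be cross-checked without Spark internals. A small sketch using only java.time, with UTC matching DateTimeUtils.TimeZoneUTC:

    import java.time.{Instant, ZoneOffset}
    import java.time.format.DateTimeFormatter

    // 812505600000000L microseconds since the epoch is 1995-10-01 00:00:00 UTC, which is
    // exactly the literal the converted filter string is expected to contain.
    val micros = 812505600000000L
    val formatted = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
      .withZone(ZoneOffset.UTC)
      .format(Instant.ofEpochSecond(micros / 1000000L, (micros % 1000000L) * 1000L))
    assert(formatted == "1995-10-01 00:00:00")
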
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala
index de1be2115b2d8..6fe495b55e780 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala
@@ -24,6 +24,6 @@ import org.scalatest.Suite
 class HiveClientSuites extends Suite with HiveClientVersions {
   override def nestedSuites: IndexedSeq[Suite] = {
     // Hive 0.12 does not provide the partition filtering API we call
-    versions.filterNot(_ == "0.12").map(new HiveClientSuite(_))
+    versions.filterNot(_ == "0.12").map(new HivePartitionFilteringSuite(_))
   }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
similarity index 55%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
index ce53acef51503..22ef7c29cf7af 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
@@ -17,24 +17,36 @@
 
 package org.apache.spark.sql.hive.client
 
+import java.time.Instant
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.hive.conf.HiveConf
-import org.scalatest.BeforeAndAfterAll
-
+import org.apache.spark.sql.Column
+import org.scalatest.{BeforeAndAfterAll, Ignore}
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, In, InSet}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EmptyRow, EqualTo, Expression, In, InSet, Literal}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.types.DataTypes
 
-// TODO: Refactor this to `HivePartitionFilteringSuite`
-class HiveClientSuite(version: String)
+class HivePartitionFilteringSuite(version: String)
     extends HiveVersionSuite(version) with BeforeAndAfterAll {
   import CatalystSqlParser._
 
   private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname
 
-  private val testPartitionCount = 3 * 24 * 4
+  private val testPartitionCount = 3 * 24 * 4 * 2 * 4
+  private val chunkValues = Seq("aa", "ab", "ba", "bb")
+  private val dValues = 20170101 to 20170103
+  private val hValues = 0 to 23
+  private val tValues =
+    Seq(Instant.parse("2017-12-24T00:00:00.00Z"), Instant.parse("2017-12-25T00:00:00.00Z"))
+  private val decValues = Seq(BigDecimal(1D), BigDecimal(2D), BigDecimal(3D), BigDecimal(4D))
 
   private def init(tryDirectSql: Boolean): HiveClient = {
+    val hadoopConf = new Configuration()
+    hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql)
+    val client = buildClient(hadoopConf)
+
     val storageFormat = CatalogStorageFormat(
       locationUri = None,
       inputFormat = None,
@@ -43,26 +55,29 @@ class HiveClientSuite(version: String)
       compressed = false,
       properties = Map.empty)
 
-    val hadoopConf = new Configuration()
-    hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql)
-    val client = buildClient(hadoopConf)
     client
-      .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)")
+      .runSqlHive("CREATE TABLE test (value INT) " +
+        "PARTITIONED BY (ds INT, h INT, chunk STRING, t TIMESTAMP, d DECIMAL)")
 
     val partitions =
       for {
-        ds <- 20170101 to 20170103
-        h <- 0 to 23
-        chunk <- Seq("aa", "ab", "ba", "bb")
+        ds <- dValues
+        h <- hValues
+        chunk <- chunkValues
+        t <- tValues
+        d <- decValues
       } yield CatalogTablePartition(Map(
         "ds" -> ds.toString,
        "h" -> h.toString,
-        "chunk" -> chunk
+        "chunk" -> chunk,
+        "t" -> t.getEpochSecond.toString,
+        "d" -> d.toString
      ), storageFormat)
     assert(partitions.size == testPartitionCount)
 
     client.createPartitions(
       "default", "test", partitions, ignoreIfExists = false)
+
     client
   }
@@ -80,19 +95,11 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: ds<=>20170101") {
     // Should return all partitions where <=> is not supported
-    testMetastorePartitionFiltering(
-      "ds<=>20170101",
-      20170101 to 20170103,
-      0 to 23,
-      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+    assertNoFilterIsApplied("ds<=>20170101")
   }
 
   test("getPartitionsByFilter: ds=20170101") {
-    testMetastorePartitionFiltering(
-      "ds=20170101",
-      20170101 to 20170101,
-      0 to 23,
-      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+    testMetastorePartitionFiltering("ds=20170101", 20170101 to 20170101)
   }
 
   test("getPartitionsByFilter: ds=(20170101 + 1) and h=0") {
@@ -100,93 +107,114 @@ class HiveClientSuite(version: String)
     // comparisons to non-literal values
     testMetastorePartitionFiltering(
       "ds=(20170101 + 1) and h=0",
-      20170101 to 20170103,
-      0 to 0,
-      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+      dValues, 0 to 0)
   }
 
   test("getPartitionsByFilter: chunk='aa'") {
     testMetastorePartitionFiltering(
       "chunk='aa'",
-      20170101 to 20170103,
-      0 to 23,
+      dValues, hValues,
       "aa" :: Nil)
   }
 
   test("getPartitionsByFilter: 20170101=ds") {
     testMetastorePartitionFiltering(
       "20170101=ds",
+      20170101 to 20170101)
+  }
+
+  test("getPartitionsByFilter: must ignore unsupported expressions") {
+    testMetastorePartitionFiltering(
+      "ds is not null and chunk is not null and 20170101=ds and chunk = 'aa'",
       20170101 to 20170101,
-      0 to 23,
-      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+      hValues,
+      "aa" :: Nil)
+  }
+
+  test("getPartitionsByFilter: multiple or single expressions yield the same result") {
+    testMetastorePartitionFiltering(
+      "ds is not null and chunk is not null and (20170101=ds) and (chunk = 'aa')",
+      20170101 to 20170101,
+      hValues,
+      "aa" :: Nil)
+
+    testMetastorePartitionFiltering(
+      Seq(parseExpression("ds is not null"),
+        parseExpression("chunk is not null"),
+        parseExpression("(20170101=ds)"),
+        EqualTo(AttributeReference("chunk", DataTypes.StringType)(), Literal.apply("aa"))),
+      20170101 to 20170101,
+      hValues,
+      "aa" :: Nil,
+      tValues,
+      decValues)
   }
 
   test("getPartitionsByFilter: ds=20170101 and h=10") {
     testMetastorePartitionFiltering(
       "ds=20170101 and h=10",
       20170101 to 20170101,
-      10 to 10,
-      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+      10 to 10)
   }
 
   test("getPartitionsByFilter: ds=20170101 or ds=20170102") {
     testMetastorePartitionFiltering(
       "ds=20170101 or ds=20170102",
-      20170101 to 20170102,
-      0 to 23,
-      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+      20170101 to 20170102)
   }
 
   test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") {
     testMetastorePartitionFiltering(
       "ds in (20170102, 20170103)",
-      20170102 to 20170103,
-      0 to 23,
-      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+      20170102 to 20170103)
   }
 
   test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") {
     testMetastorePartitionFiltering(
-      "ds in (20170102, 20170103)",
-      20170102 to 20170103,
-      0 to 23,
-      "aa" :: "ab" :: "ba" :: "bb" :: Nil, {
+      Seq(parseExpression("ds in (20170102, 20170103)") match {
         case expr @ In(v, list) if expr.inSetConvertible =>
           InSet(v, list.map(_.eval(EmptyRow)).toSet)
-      })
+      }),
+      20170102 to 20170103,
+      hValues,
+      chunkValues,
+      tValues,
+      decValues)
   }
 
   test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") {
     testMetastorePartitionFiltering(
       "chunk in ('ab', 'ba')",
-      20170101 to 20170103,
-      0 to 23,
+      dValues,
+      hValues,
       "ab" :: "ba" :: Nil)
   }
 
   test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") {
     testMetastorePartitionFiltering(
-      "chunk in ('ab', 'ba')",
-      20170101 to 20170103,
-      0 to 23,
-      "ab" :: "ba" :: Nil, {
+      Seq(parseExpression("chunk in ('ab', 'ba')") match {
         case expr @ In(v, list) if expr.inSetConvertible =>
           InSet(v, list.map(_.eval(EmptyRow)).toSet)
-      })
+      }),
+      dValues,
+      hValues,
+      "ab" :: "ba" :: Nil,
+      tValues,
+      decValues)
   }
 
   test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") {
-    val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb"))
-    val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb"))
+    val day1 = (20170101 to 20170101, 8 to 23, chunkValues, tValues, decValues)
+    val day2 = (20170102 to 20170102, 0 to 7, chunkValues, tValues, decValues)
     testMetastorePartitionFiltering(
       "(ds=20170101 and h>=8) or (ds=20170102 and h<8)",
       day1 :: day2 :: Nil)
   }
 
   test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") {
-    val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb"))
+    val day1 = (20170101 to 20170101, 8 to 23, chunkValues, tValues, decValues)
     // Day 2 should include all hours because we can't build a filter for h<(7+1)
-    val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb"))
+    val day2 = (20170102 to 20170102, 0 to 23, chunkValues, tValues, decValues)
     testMetastorePartitionFiltering(
       "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))",
       day1 :: day2 :: Nil)
@@ -194,66 +222,105 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: " +
       "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") {
-    val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba"))
-    val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba"))
+    val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba"), tValues, decValues)
+    val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba"), tValues, decValues)
     testMetastorePartitionFiltering(
       "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))",
       day1 :: day2 :: Nil)
   }
 
-  private def testMetastorePartitionFiltering(
-      filterString: String,
-      expectedDs: Seq[Int],
-      expectedH: Seq[Int],
-      expectedChunks: Seq[String]): Unit = {
-    testMetastorePartitionFiltering(
-      filterString,
-      (expectedDs, expectedH, expectedChunks) :: Nil,
-      identity)
+  ignore("TODO: create hive metastore for integration test " +
+      "that supports timestamp and decimal pruning") {
+    test("getPartitionsByFilter: t = '2017-12-24T00:00:00.00Z' (timestamp test)") {
+      testMetastorePartitionFiltering(
+        "t = '2017-12-24T00:00:00.00Z'",
+        dValues,
+        hValues,
+        chunkValues,
+        Seq(Instant.parse("2017-12-24T00:00:00.00Z"))
+      )
+    }
+
+    test("getPartitionsByFilter: d = 4.0 (decimal test)") {
+      testMetastorePartitionFiltering(
+        "d = 4.0",
+        dValues,
+        hValues,
+        chunkValues,
+        tValues,
+        Seq(BigDecimal(4.0D))
+      )
+    }
+  }
+
+  private def assertNoFilterIsApplied(expression: String) = {
+    val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
+      Seq(parseExpression(expression)))
+
+    assert(filteredPartitions.size == testPartitionCount)
   }
 
   private def testMetastorePartitionFiltering(
-      filterString: String,
+      filters: Seq[Expression],
       expectedDs: Seq[Int],
       expectedH: Seq[Int],
       expectedChunks: Seq[String],
-      transform: Expression => Expression): Unit = {
+      expectedTs: Seq[Instant],
+      expectedDecs: Seq[BigDecimal]): Unit = {
     testMetastorePartitionFiltering(
-      filterString,
-      (expectedDs, expectedH, expectedChunks) :: Nil,
-      identity)
+      filters,
+      (expectedDs, expectedH, expectedChunks, expectedTs, expectedDecs) :: Nil)
   }
 
   private def testMetastorePartitionFiltering(
       filterString: String,
-      expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = {
-    testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity)
+      expectedDs: Seq[Int] = dValues,
+      expectedH: Seq[Int] = hValues,
+      expectedChunks: Seq[String] = chunkValues,
+      expectedTs: Seq[Instant] = tValues,
+      expectedDecs: Seq[BigDecimal] = decValues): Unit = {
+    testMetastorePartitionFiltering(Seq(parseExpression(filterString)),
+      expectedDs,
+      expectedH,
+      expectedChunks,
+      expectedTs,
+      expectedDecs)
   }
 
   private def testMetastorePartitionFiltering(
       filterString: String,
-      expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])],
-      transform: Expression => Expression): Unit = {
-    val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
-      Seq(
-        transform(parseExpression(filterString))
-      ))
+      expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String], Seq[Instant], Seq[BigDecimal])]
+  ): Unit = {
+    testMetastorePartitionFiltering(Seq(parseExpression(filterString)),
+      expectedPartitionCubes)
+  }
+
+  private def testMetastorePartitionFiltering(
+      predicates: Seq[Expression],
+      expectedPartitionCubes:
+        Seq[(Seq[Int], Seq[Int], Seq[String], Seq[Instant], Seq[BigDecimal])]): Unit = {
+    val filteredPartitions =
+      client.getPartitionsByFilter(client.getTable("default", "test"), predicates)
 
     val expectedPartitionCount = expectedPartitionCubes.map {
-      case (expectedDs, expectedH, expectedChunks) =>
-        expectedDs.size * expectedH.size * expectedChunks.size
+      case (expectedDs, expectedH, expectedChunks, expectedTs, expectedDecs) =>
+        expectedDs.size * expectedH.size * expectedChunks.size * expectedTs.size * expectedDecs.size
     }.sum
 
     val expectedPartitions = expectedPartitionCubes.map {
-      case (expectedDs, expectedH, expectedChunks) =>
+      case (expectedDs, expectedH, expectedChunks, expectedTs, expectedDecs) =>
         for {
           ds <- expectedDs
           h <- expectedH
           chunk <- expectedChunks
+          t <- expectedTs
+          d <- expectedDecs
         } yield Set(
           "ds" -> ds.toString,
           "h" -> h.toString,
-          "chunk" -> chunk
+          "chunk" -> chunk,
+          "t" -> t.getEpochSecond.toString,
+          "d" -> d.toString()
         )
     }.reduce(_ ++ _)