@@ -24,7 +24,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, S
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
-import scala.util.Try
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.Path
@@ -657,17 +656,31 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
 
     val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled
 
+    object ExtractAttribute {
+      def unapply(expr: Expression): Option[Attribute] = {
+        expr match {
+          case attr: Attribute => Some(attr)
+          case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child)
+          case _ => None
+        }
+      }
+    }
+
     def convert(expr: Expression): Option[String] = expr match {
-      case In(NonVarcharAttribute(name), ExtractableLiterals(values)) if useAdvanced =>
+      case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values))
+          if useAdvanced =>
         Some(convertInToOr(name, values))
 
-      case InSet(NonVarcharAttribute(name), ExtractableValues(values)) if useAdvanced =>
+      case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values))
+          if useAdvanced =>
         Some(convertInToOr(name, values))
 
-      case op @ SpecialBinaryComparison(NonVarcharAttribute(name), ExtractableLiteral(value)) =>
+      case op @ SpecialBinaryComparison(
+          ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiteral(value)) =>
         Some(s"$name ${op.symbol} $value")
 
-      case op @ SpecialBinaryComparison(ExtractableLiteral(value), NonVarcharAttribute(name)) =>
+      case op @ SpecialBinaryComparison(
+          ExtractableLiteral(value), ExtractAttribute(NonVarcharAttribute(name))) =>
         Some(s"$value ${op.symbol} $name")
 
       case And(expr1, expr2) if useAdvanced =>
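The new `ExtractAttribute` extractor is the core of this change: the pushdown arms above can now see through a `Cast` whenever the cast cannot truncate the underlying value (`!Cast.mayTruncate`). For reference, here is a minimal, Spark-free sketch of that unwrapping idea; the tiny `Expr` ADT and the `widens` predicate are illustrative stand-ins, not Spark's `Expression` hierarchy or the real `Cast.mayTruncate` API:

object ExtractAttributeSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  // `from`/`to` stand in for the child's data type and the cast's target type.
  case class CastExpr(child: Expr, from: String, to: String) extends Expr

  // Stand-in for !Cast.mayTruncate(child.dataType, dt): in this toy model,
  // only the widening int -> long cast is considered lossless.
  private def widens(from: String, to: String): Boolean =
    from == "int" && to == "long"

  object ExtractAttribute {
    def unapply(e: Expr): Option[Attr] = e match {
      case a: Attr => Some(a)
      case CastExpr(child, from, to) if widens(from, to) => unapply(child)
      case _ => None
    }
  }

  def main(args: Array[String]): Unit = {
    // cast(ds as long) still exposes the underlying attribute ds ...
    assert(ExtractAttribute.unapply(CastExpr(Attr("ds"), "int", "long")).contains(Attr("ds")))
    // ... while a potentially truncating cast (long -> int) hides it, so no
    // partition filter would be pushed down for it.
    assert(ExtractAttribute.unapply(CastExpr(Attr("ds"), "long", "int")).isEmpty)
  }
}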
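For readers without the full Shim in view: `convertInToOr`, called from the `In` and `InSet` arms above, turns an IN over literal values into a disjunction the Hive metastore can evaluate as a partition filter string. A plausible sketch of its shape, inferred from the `Some(convertInToOr(name, values))` call sites rather than quoted from the file:

object ConvertInToOrSketch {
  // Assumed signature: the real method is private to the Shim and not shown
  // in this diff.
  def convertInToOr(name: String, values: Seq[String]): String =
    values.map(value => s"$name = $value").mkString("(", " or ", ")")

  def main(args: Array[String]): Unit = {
    // Prints "(ds = 20170102 or ds = 20170103)"
    println(convertInToOr("ds", Seq("20170102", "20170103")))
  }
}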
@@ -22,13 +22,13 @@ import org.apache.hadoop.hive.conf.HiveConf
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, In, InSet}
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.LongType
 
 // TODO: Refactor this to `HivePartitionFilteringSuite`
 class HiveClientSuite(version: String)
   extends HiveVersionSuite(version) with BeforeAndAfterAll {
-  import CatalystSqlParser._
 
   private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname
@@ -46,8 +46,7 @@ class HiveClientSuite(version: String)
     val hadoopConf = new Configuration()
     hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql)
     val client = buildClient(hadoopConf)
-    client
-      .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)")
+    client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)")
 
     val partitions =
       for {
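The for-comprehension is collapsed in this view. Judging from the expectations used throughout the suite, it enumerates ds in 20170101 to 20170103, h in 0 to 23, and chunk in "aa"/"ab"/"ba"/"bb"; a quick sanity check of the resulting testPartitionCount, under that assumption:

object PartitionCountCheck {
  def main(args: Array[String]): Unit = {
    // 3 days * 24 hours * 4 chunks == 288 partitions
    val count = (20170101 to 20170103).size * (0 to 23).size *
      Seq("aa", "ab", "ba", "bb").size
    println(count) // 288
  }
}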
@@ -66,6 +65,15 @@ client
     client
   }
 
+  private def attr(name: String): Attribute = {
+    client.getTable("default", "test").partitionSchema.fields
+      .find(field => field.name.equals(name)) match {
+      case Some(field) => AttributeReference(field.name, field.dataType)()
+      case None =>
+        fail(s"Illegal name of partition attribute: $name")
+    }
+  }
+
   override def beforeAll() {
     super.beforeAll()
     client = init(true)
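With the `attr` helper in place, the tests below build Catalyst expression trees directly through `org.apache.spark.sql.catalyst.dsl.expressions._` instead of parsing filter strings with `CatalystSqlParser.parseExpression`. A small standalone illustration of what those operators construct, assuming spark-catalyst on the classpath (the `ds` reference is built directly here rather than through the suite's `attr` helper):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.{IntegerType, LongType}

object CatalystDslDemo {
  def main(args: Array[String]): Unit = {
    val ds = AttributeReference("ds", IntegerType)()
    // ===, && and cast come from the dsl implicits; plain Int/Long values
    // are lifted to Literal automatically.
    val filter = ds === 20170101 && ds.cast(LongType) === 20170101L
    println(filter) // an And(EqualTo(ds, ...), EqualTo(Cast(ds, ...), ...)) tree
  }
}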
@@ -74,23 +82,23 @@ class HiveClientSuite(version: String)
   test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") {
     val client = init(false)
     val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
-      Seq(parseExpression("ds=20170101")))
+      Seq(attr("ds") === 20170101))
 
     assert(filteredPartitions.size == testPartitionCount)
   }
 
   test("getPartitionsByFilter: ds<=>20170101") {
     // Should return all partitions where <=> is not supported
     testMetastorePartitionFiltering(
-      "ds<=>20170101",
+      attr("ds") <=> 20170101,
       20170101 to 20170103,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
   }
 
   test("getPartitionsByFilter: ds=20170101") {
     testMetastorePartitionFiltering(
-      "ds=20170101",
+      attr("ds") === 20170101,
       20170101 to 20170101,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -100,55 +108,83 @@ class HiveClientSuite(version: String)
     // Should return all partitions where h=0 because getPartitionsByFilter does not support
     // comparisons to non-literal values
     testMetastorePartitionFiltering(
-      "ds=(20170101 + 1) and h=0",
+      attr("ds") === (Literal(20170101) + 1) && attr("h") === 0,
       20170101 to 20170103,
       0 to 0,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
   }
 
   test("getPartitionsByFilter: chunk='aa'") {
     testMetastorePartitionFiltering(
-      "chunk='aa'",
+      attr("chunk") === "aa",
       20170101 to 20170103,
       0 to 23,
       "aa" :: Nil)
   }
 
   test("getPartitionsByFilter: 20170101=ds") {
     testMetastorePartitionFiltering(
-      "20170101=ds",
+      Literal(20170101) === attr("ds"),
       20170101 to 20170101,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
   }
 
   test("getPartitionsByFilter: ds=20170101 and h=10") {
     testMetastorePartitionFiltering(
-      "ds=20170101 and h=10",
+      attr("ds") === 20170101 && attr("h") === 10,
       20170101 to 20170101,
       10 to 10,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
   }
 
+  test("getPartitionsByFilter: cast(ds as long)=20170101L and h=10") {
+    testMetastorePartitionFiltering(
+      attr("ds").cast(LongType) === 20170101L && attr("h") === 10,
+      20170101 to 20170101,
+      10 to 10,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+  }
+
   test("getPartitionsByFilter: ds=20170101 or ds=20170102") {
     testMetastorePartitionFiltering(
-      "ds=20170101 or ds=20170102",
+      attr("ds") === 20170101 || attr("ds") === 20170102,
       20170101 to 20170102,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
   }
 
   test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") {
     testMetastorePartitionFiltering(
-      "ds in (20170102, 20170103)",
+      attr("ds").in(20170102, 20170103),
       20170102 to 20170103,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
   }
 
+  test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using IN expression)") {
+    testMetastorePartitionFiltering(
+      attr("ds").cast(LongType).in(20170102L, 20170103L),
+      20170102 to 20170103,
+      0 to 23,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+  }
+
   test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") {
     testMetastorePartitionFiltering(
-      "ds in (20170102, 20170103)",
+      attr("ds").in(20170102, 20170103),
       20170102 to 20170103,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil, {
+        case expr @ In(v, list) if expr.inSetConvertible =>
+          InSet(v, list.map(_.eval(EmptyRow)).toSet)
+      })
+  }
+
+  test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using INSET expression)")
+  {
+    testMetastorePartitionFiltering(
+      attr("ds").cast(LongType).in(20170102L, 20170103L),
+      20170102 to 20170103,
+      0 to 23,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil, {
@@ -159,15 +195,15 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") {
     testMetastorePartitionFiltering(
-      "chunk in ('ab', 'ba')",
+      attr("chunk").in("ab", "ba"),
       20170101 to 20170103,
       0 to 23,
       "ab" :: "ba" :: Nil)
   }
 
   test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") {
     testMetastorePartitionFiltering(
-      "chunk in ('ab', 'ba')",
+      attr("chunk").in("ab", "ba"),
       20170101 to 20170103,
       0 to 23,
       "ab" :: "ba" :: Nil, {
@@ -179,26 +215,24 @@
   test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") {
     val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb"))
     val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb"))
-    testMetastorePartitionFiltering(
-      "(ds=20170101 and h>=8) or (ds=20170102 and h<8)",
-      day1 :: day2 :: Nil)
+    testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) ||
+      (attr("ds") === 20170102 && attr("h") < 8), day1 :: day2 :: Nil)
   }
 
   test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") {
     val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb"))
     // Day 2 should include all hours because we can't build a filter for h<(7+1)
     val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb"))
-    testMetastorePartitionFiltering(
-      "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))",
-      day1 :: day2 :: Nil)
+    testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) ||
+      (attr("ds") === 20170102 && attr("h") < (Literal(7) + 1)), day1 :: day2 :: Nil)
   }
 
   test("getPartitionsByFilter: " +
     "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") {
     val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba"))
     val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba"))
-    testMetastorePartitionFiltering(
-      "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))",
+    testMetastorePartitionFiltering(attr("chunk").in("ab", "ba") &&
+      ((attr("ds") === 20170101 && attr("h") >= 8) || (attr("ds") === 20170102 && attr("h") < 8)),
       day1 :: day2 :: Nil)
   }
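The INSET variants above pass a `transform` that rewrites an `In` whose list is entirely literal into the equivalent `InSet`, which is the shape the new `InSet` arm of `convert` in the Shim has to handle. A standalone sketch of that conversion, again assuming spark-catalyst on the classpath:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EmptyRow, In, InSet}
import org.apache.spark.sql.types.IntegerType

object InToInSetDemo {
  def main(args: Array[String]): Unit = {
    val ds = AttributeReference("ds", IntegerType)()
    val in = ds.in(20170102, 20170103).asInstanceOf[In]
    // inSetConvertible holds because every element of the IN list is a literal.
    assert(in.inSetConvertible)
    // Evaluating each literal against EmptyRow extracts the raw values.
    val inSet = InSet(in.value, in.list.map(_.eval(EmptyRow)).toSet)
    println(inSet)
  }
}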

@@ -207,41 +241,41 @@ class HiveClientSuite(version: String)
   }
 
   private def testMetastorePartitionFiltering(
-      filterString: String,
+      filterExpr: Expression,
       expectedDs: Seq[Int],
       expectedH: Seq[Int],
       expectedChunks: Seq[String]): Unit = {
     testMetastorePartitionFiltering(
-      filterString,
+      filterExpr,
       (expectedDs, expectedH, expectedChunks) :: Nil,
       identity)
   }
 
   private def testMetastorePartitionFiltering(
-      filterString: String,
+      filterExpr: Expression,
       expectedDs: Seq[Int],
       expectedH: Seq[Int],
       expectedChunks: Seq[String],
       transform: Expression => Expression): Unit = {
     testMetastorePartitionFiltering(
-      filterString,
+      filterExpr,
       (expectedDs, expectedH, expectedChunks) :: Nil,
-      identity)
+      transform)
   }
 
   private def testMetastorePartitionFiltering(
-      filterString: String,
+      filterExpr: Expression,
       expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = {
-    testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity)
+    testMetastorePartitionFiltering(filterExpr, expectedPartitionCubes, identity)
   }
 
   private def testMetastorePartitionFiltering(
-      filterString: String,
+      filterExpr: Expression,
       expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])],
       transform: Expression => Expression): Unit = {
     val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
       Seq(
-        transform(parseExpression(filterString))
+        transform(filterExpr)
       ))
 
     val expectedPartitionCount = expectedPartitionCubes.map {