
Commit 408b384

Cheolsoo Park authored and marmbrus committed
[SPARK-6910] [SQL] Support for pushing predicates down to metastore for partition pruning
This PR supersedes my old one, apache#6921. Since my patch has changed quite a bit, I am opening a new PR to make it easier to review. The changes include:

* Implement a `toMetastoreFilter()` function in `HiveShim` that takes a `Seq[Expression]` and converts it into a filter string for the Hive metastore.
  * This function matches all the `AttributeReference` + `BinaryComparisonOp` + `Integral/StringType` patterns in the `Seq[Expression]` and folds them into a string.
* Change the `hiveQlPartitions` field in `MetastoreRelation` to a `getHiveQlPartitions()` function that takes a filter string parameter.
* Call `getHiveQlPartitions()` in `HiveTableScan` with a filter string.

There are some cases in which predicate pushdown is disabled:

Case | Predicate pushdown
------- | -----------------------------
Hive integral and string types | Yes
Hive varchar type | No
Hive 0.13 and newer | Yes
Hive 0.12 and older | No
convertMetastoreParquet=false | Yes
convertMetastoreParquet=true | No

In the case of `convertMetastoreParquet=true`, predicates are not pushed down because the conversion happens in an `Analyzer` rule (`HiveMetastoreCatalog.ParquetConversions`). At that point, `HiveTableScan` hasn't run yet, so predicates are not available. But from reading the source code, I think converting the entire Hive table with all its partitions into a `ParquetRelation` is intentional, because the `ParquetRelation` can then be cached and reused for any query against that table. Please correct me if I am wrong.

cc marmbrus

Author: Cheolsoo Park <[email protected]>

Closes apache#7216 from piaozhexiu/SPARK-6910-2 and squashes the following commits:

aa1490f [Cheolsoo Park] Fix ordering of imports
c212c4d [Cheolsoo Park] Incorporate review comments
5e93f9d [Cheolsoo Park] Predicate pushdown into Hive metastore
1 parent b7bcbe2 commit 408b384
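
To make the pushdown concrete, here is a minimal, self-contained sketch of the predicate-to-filter-string folding that `Shim_v0_13.getPartitionsByFilter` performs in this patch. The `PartValue`/`KeyComparison` types and the `toMetastoreFilter` helper are stand-ins invented for illustration, not the actual Catalyst classes:

```scala
// Simplified stand-ins for Catalyst expressions; these names are illustrative only.
sealed trait PartValue
case class IntVal(v: Long) extends PartValue
case class StrVal(v: String) extends PartValue

// One predicate of the form <partition key> <comparison op> <literal>.
case class KeyComparison(key: String, symbol: String, value: PartValue)

object MetastoreFilterSketch {
  // Mirrors the folding in Shim_v0_13.getPartitionsByFilter: integral and string
  // partition keys are pushed down; varchar keys (and anything else) are skipped.
  def toMetastoreFilter(predicates: Seq[KeyComparison], varcharKeys: Set[String]): String =
    predicates.flatMap {
      case KeyComparison(key, op, IntVal(v)) => Some(s"$key$op$v")
      case KeyComparison(key, op, StrVal(v)) if !varcharKeys.contains(key) =>
        Some(key + op + "\"" + v + "\"")
      case _ => None
    }.mkString(" and ")

  def main(args: Array[String]): Unit = {
    // Prints: ds="2015-07-01" and hr>10
    println(toMetastoreFilter(
      Seq(KeyComparison("ds", "=", StrVal("2015-07-01")), KeyComparison("hr", ">", IntVal(10))),
      varcharKeys = Set.empty))
  }
}
```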

File tree

9 files changed: +137 −44 lines changed


sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala

Lines changed: 31 additions & 27 deletions
@@ -301,7 +301,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
     val result = if (metastoreRelation.hiveQlTable.isPartitioned) {
       val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys)
       val partitionColumnDataTypes = partitionSchema.map(_.dataType)
-      val partitions = metastoreRelation.hiveQlPartitions.map { p =>
+      // We're converting the entire table into ParquetRelation, so predicates to Hive metastore
+      // are empty.
+      val partitions = metastoreRelation.getHiveQlPartitions().map { p =>
         val location = p.getLocation
         val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map {
           case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null)
@@ -644,32 +646,6 @@ private[hive] case class MetastoreRelation
     new Table(tTable)
   }

-  @transient val hiveQlPartitions: Seq[Partition] = table.getAllPartitions.map { p =>
-    val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
-    tPartition.setDbName(databaseName)
-    tPartition.setTableName(tableName)
-    tPartition.setValues(p.values)
-
-    val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
-    tPartition.setSd(sd)
-    sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))
-
-    sd.setLocation(p.storage.location)
-    sd.setInputFormat(p.storage.inputFormat)
-    sd.setOutputFormat(p.storage.outputFormat)
-
-    val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
-    sd.setSerdeInfo(serdeInfo)
-    serdeInfo.setSerializationLib(p.storage.serde)
-
-    val serdeParameters = new java.util.HashMap[String, String]()
-    serdeInfo.setParameters(serdeParameters)
-    table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
-    p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
-
-    new Partition(hiveQlTable, tPartition)
-  }
-
   @transient override lazy val statistics: Statistics = Statistics(
     sizeInBytes = {
       val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE)
@@ -690,6 +666,34 @@ private[hive] case class MetastoreRelation
     }
   )

+  def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = {
+    table.getPartitions(predicates).map { p =>
+      val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
+      tPartition.setDbName(databaseName)
+      tPartition.setTableName(tableName)
+      tPartition.setValues(p.values)
+
+      val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
+      tPartition.setSd(sd)
+      sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))
+
+      sd.setLocation(p.storage.location)
+      sd.setInputFormat(p.storage.inputFormat)
+      sd.setOutputFormat(p.storage.outputFormat)
+
+      val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
+      sd.setSerdeInfo(serdeInfo)
+      serdeInfo.setSerializationLib(p.storage.serde)
+
+      val serdeParameters = new java.util.HashMap[String, String]()
+      serdeInfo.setParameters(serdeParameters)
+      table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
+      p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
+
+      new Partition(hiveQlTable, tPartition)
+    }
+  }
+
   /** Only compare database and tablename, not alias. */
   override def sameResult(plan: LogicalPlan): Boolean = {
     plan match {

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ import scala.reflect.ClassTag

 import com.esotericsoftware.kryo.Kryo
 import com.esotericsoftware.kryo.io.{Input, Output}
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala

Lines changed: 2 additions & 2 deletions
@@ -125,7 +125,7 @@ private[hive] trait HiveStrategies {
         InterpretedPredicate.create(castedPredicate)
       }

-      val partitions = relation.hiveQlPartitions.filter { part =>
+      val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part =>
        val partitionValues = part.getValues
        var i = 0
        while (i < partitionValues.size()) {
@@ -213,7 +213,7 @@ private[hive] trait HiveStrategies {
          projectList,
          otherPredicates,
          identity[Seq[Expression]],
-          HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil
+          HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil
      case _ =>
        Nil
  }

sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala

Lines changed: 10 additions & 1 deletion
@@ -21,6 +21,7 @@ import java.io.PrintStream
 import java.util.{Map => JMap}

 import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException}
+import org.apache.spark.sql.catalyst.expressions.Expression

 private[hive] case class HiveDatabase(
     name: String,
@@ -71,7 +72,12 @@ private[hive] case class HiveTable(

   def isPartitioned: Boolean = partitionColumns.nonEmpty

-  def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this)
+  def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = {
+    predicates match {
+      case Nil => client.getAllPartitions(this)
+      case _ => client.getPartitionsByFilter(this, predicates)
+    }
+  }

   // Hive does not support backticks when passing names to the client.
   def qualifiedName: String = s"$database.$name"
@@ -132,6 +138,9 @@ private[hive] trait ClientInterface {
   /** Returns all partitions for the given table. */
   def getAllPartitions(hTable: HiveTable): Seq[HivePartition]

+  /** Returns partitions filtered by predicates for the given table. */
+  def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition]
+
   /** Loads a static partition into an existing table. */
   def loadPartition(
       loadPath: String,

sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala

Lines changed: 13 additions & 9 deletions
@@ -17,27 +17,24 @@

 package org.apache.spark.sql.hive.client

-import java.io.{BufferedReader, InputStreamReader, File, PrintStream}
-import java.net.URI
-import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet}
+import java.io.{File, PrintStream}
+import java.util.{Map => JMap}
 import javax.annotation.concurrent.GuardedBy

+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.util.CircularBuffer

 import scala.collection.JavaConversions._
 import scala.language.reflectiveCalls

 import org.apache.hadoop.fs.Path
-import org.apache.hadoop.hive.metastore.api.Database
 import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema}
 import org.apache.hadoop.hive.metastore.{TableType => HTableType}
-import org.apache.hadoop.hive.metastore.api
-import org.apache.hadoop.hive.metastore.api.FieldSchema
-import org.apache.hadoop.hive.ql.metadata
 import org.apache.hadoop.hive.ql.metadata.Hive
-import org.apache.hadoop.hive.ql.session.SessionState
 import org.apache.hadoop.hive.ql.processors._
-import org.apache.hadoop.hive.ql.Driver
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.ql.{Driver, metadata}

 import org.apache.spark.Logging
 import org.apache.spark.sql.execution.QueryExecutionException
@@ -316,6 +313,13 @@ private[hive] class ClientWrapper(
     shim.getAllPartitions(client, qlTable).map(toHivePartition)
   }

+  override def getPartitionsByFilter(
+      hTable: HiveTable,
+      predicates: Seq[Expression]): Seq[HivePartition] = withHiveState {
+    val qlTable = toQlTable(hTable)
+    shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition)
+  }
+
   override def listTables(dbName: String): Seq[String] = withHiveState {
     client.getAllTables(dbName)
   }

sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala

Lines changed: 67 additions & 1 deletion
@@ -31,6 +31,11 @@ import org.apache.hadoop.hive.ql.Driver
 import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
 import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
 import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.serde.serdeConstants
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference, BinaryComparison}
+import org.apache.spark.sql.types.{StringType, IntegralType}

 /**
  * A shim that defines the interface between ClientWrapper and the underlying Hive library used to
@@ -61,6 +66,8 @@ private[client] sealed abstract class Shim {

   def getAllPartitions(hive: Hive, table: Table): Seq[Partition]

+  def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition]
+
   def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor

   def getDriverResults(driver: Driver): Seq[String]
@@ -109,7 +116,7 @@ private[client] sealed abstract class Shim {

 }

-private[client] class Shim_v0_12 extends Shim {
+private[client] class Shim_v0_12 extends Shim with Logging {

   private lazy val startMethod =
     findStaticMethod(
@@ -196,6 +203,17 @@ private[client] class Shim_v0_12 extends Shim {
   override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
     getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq

+  override def getPartitionsByFilter(
+      hive: Hive,
+      table: Table,
+      predicates: Seq[Expression]): Seq[Partition] = {
+    // getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12.
+    // See HIVE-4888.
+    logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " +
+      "Please use Hive 0.13 or higher.")
+    getAllPartitions(hive, table)
+  }
+
   override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
     getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]

@@ -267,6 +285,12 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
       classOf[Hive],
       "getAllPartitionsOf",
       classOf[Table])
+  private lazy val getPartitionsByFilterMethod =
+    findMethod(
+      classOf[Hive],
+      "getPartitionsByFilter",
+      classOf[Table],
+      classOf[String])
   private lazy val getCommandProcessorMethod =
     findStaticMethod(
       classOf[CommandProcessorFactory],
@@ -288,6 +312,48 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
   override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
     getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq

+  override def getPartitionsByFilter(
+      hive: Hive,
+      table: Table,
+      predicates: Seq[Expression]): Seq[Partition] = {
+    // hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
+    val varcharKeys = table.getPartitionKeys
+      .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
+      .map(col => col.getName).toSet
+
+    // Hive getPartitionsByFilter() takes a string that represents partition
+    // predicates like "str_key=\"value\" and int_key=1 ..."
+    val filter = predicates.flatMap { expr =>
+      expr match {
+        case op @ BinaryComparison(lhs, rhs) => {
+          lhs match {
+            case AttributeReference(_, _, _, _) => {
+              rhs.dataType match {
+                case _: IntegralType =>
+                  Some(lhs.prettyString + op.symbol + rhs.prettyString)
+                case _: StringType if (!varcharKeys.contains(lhs.prettyString)) =>
+                  Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
+                case _ => None
+              }
+            }
+            case _ => None
+          }
+        }
+        case _ => None
+      }
+    }.mkString(" and ")
+
+    val partitions =
+      if (filter.isEmpty) {
+        getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
+      } else {
+        logDebug(s"Hive metastore filter is '$filter'.")
+        getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
+      }
+
+    partitions.toSeq
+  }
+
   override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
     getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor]

sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala

Lines changed: 4 additions & 3 deletions
@@ -44,7 +44,7 @@ private[hive]
 case class HiveTableScan(
     requestedAttributes: Seq[Attribute],
     relation: MetastoreRelation,
-    partitionPruningPred: Option[Expression])(
+    partitionPruningPred: Seq[Expression])(
     @transient val context: HiveContext)
   extends LeafNode {

@@ -56,7 +56,7 @@ case class HiveTableScan(

   // Bind all partition key attribute references in the partition pruning predicate for later
   // evaluation.
-  private[this] val boundPruningPred = partitionPruningPred.map { pred =>
+  private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
     require(
       pred.dataType == BooleanType,
       s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.")
@@ -133,7 +133,8 @@ case class HiveTableScan(
   protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) {
     hadoopReader.makeRDDForTable(relation.hiveQlTable)
   } else {
-    hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
+    hadoopReader.makeRDDForPartitionedTable(
+      prunePartitions(relation.getHiveQlPartitions(partitionPruningPred)))
   }

   override def output: Seq[Attribute] = attributes
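
For context, here is a hedged end-to-end usage sketch of the kind of scan this change affects. The `logs` table and its `ds` (string) / `hr` (int) partition keys are made up for illustration; per the table in the commit message, the pushdown only applies on Hive 0.13+ and with `convertMetastoreParquet=false`:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

// Hypothetical usage; the `logs` table and its partition keys are illustrative only.
object PartitionPushdownExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("partition-pushdown-example"))
    val hiveContext = new HiveContext(sc)
    // Pushdown does not apply when the table is converted to ParquetRelation (see commit message).
    hiveContext.setConf("spark.sql.hive.convertMetastoreParquet", "false")

    // With this patch, HiveTableScan hands the partition predicates to
    // MetastoreRelation.getHiveQlPartitions, and on Hive 0.13+ Shim_v0_13 asks the
    // metastore only for partitions matching ds="2015-07-01" and hr>10.
    hiveContext.sql("SELECT count(*) FROM logs WHERE ds = '2015-07-01' AND hr > 10").show()
  }
}
```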

sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala

Lines changed: 8 additions & 0 deletions
@@ -20,7 +20,9 @@ package org.apache.spark.sql.hive.client
 import java.io.File

 import org.apache.spark.{Logging, SparkFunSuite}
+import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo}
 import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.util.Utils

 /**
@@ -151,6 +153,12 @@ class VersionsSuite extends SparkFunSuite with Logging {
       client.getAllPartitions(client.getTable("default", "src_part"))
     }

+    test(s"$version: getPartitionsByFilter") {
+      client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo(
+        AttributeReference("key", IntegerType, false)(NamedExpression.newExprId),
+        Literal(1))))
+    }
+
     test(s"$version: loadPartition") {
       client.loadPartition(
         emptyDir,

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
       case p @ HiveTableScan(columns, relation, _) =>
         val columnNames = columns.map(_.name)
         val partValues = if (relation.table.isPartitioned) {
-          p.prunePartitions(relation.hiveQlPartitions).map(_.getValues)
+          p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues)
         } else {
           Seq.empty
         }
