@@ -301,7 +301,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
val result = if (metastoreRelation.hiveQlTable.isPartitioned) {
val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys)
val partitionColumnDataTypes = partitionSchema.map(_.dataType)
val partitions = metastoreRelation.hiveQlPartitions.map { p =>
// We're converting the entire table into ParquetRelation, so predicates to Hive metastore
// are empty.
val partitions = metastoreRelation.getHiveQlPartitions().map { p =>
val location = p.getLocation
val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map {
case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null)
@@ -644,32 +646,6 @@ private[hive] case class MetastoreRelation
new Table(tTable)
}

@transient val hiveQlPartitions: Seq[Partition] = table.getAllPartitions.map { p =>
val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
tPartition.setDbName(databaseName)
tPartition.setTableName(tableName)
tPartition.setValues(p.values)

val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
tPartition.setSd(sd)
sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))

sd.setLocation(p.storage.location)
sd.setInputFormat(p.storage.inputFormat)
sd.setOutputFormat(p.storage.outputFormat)

val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
sd.setSerdeInfo(serdeInfo)
serdeInfo.setSerializationLib(p.storage.serde)

val serdeParameters = new java.util.HashMap[String, String]()
serdeInfo.setParameters(serdeParameters)
table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }

new Partition(hiveQlTable, tPartition)
}

@transient override lazy val statistics: Statistics = Statistics(
sizeInBytes = {
val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE)
@@ -690,6 +666,34 @@ private[hive] case class MetastoreRelation
}
)

def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = {
table.getPartitions(predicates).map { p =>
val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
tPartition.setDbName(databaseName)
tPartition.setTableName(tableName)
tPartition.setValues(p.values)

val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
tPartition.setSd(sd)
sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))

sd.setLocation(p.storage.location)
sd.setInputFormat(p.storage.inputFormat)
sd.setOutputFormat(p.storage.outputFormat)

val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
sd.setSerdeInfo(serdeInfo)
serdeInfo.setSerializationLib(p.storage.serde)

val serdeParameters = new java.util.HashMap[String, String]()
serdeInfo.setParameters(serdeParameters)
table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }

new Partition(hiveQlTable, tPartition)
}
}

/** Only compare database and tablename, not alias. */
override def sameResult(plan: LogicalPlan): Boolean = {
plan match {
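Note: getHiveQlPartitions defaults its predicate list to Nil, so call sites that still need every partition (the ParquetRelation conversion above, PruningSuite) keep the old hiveQlPartitions behavior, while HiveTableScan can now forward its pruning predicates to the metastore. A minimal caller sketch, assuming a hypothetical relation: MetastoreRelation over a table partitioned by an integer column named part (neither name is part of this diff):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal, NamedExpression}
    import org.apache.spark.sql.types.IntegerType

    // All partitions, exactly what the removed hiveQlPartitions val used to return.
    val allPartitions = relation.getHiveQlPartitions()

    // Only partitions satisfying part = 1; the predicate is pushed to the metastore
    // when the shim supports it (Hive 0.13+) and is otherwise fetched unfiltered.
    val prunedPartitions = relation.getHiveQlPartitions(Seq(
      EqualTo(AttributeReference("part", IntegerType, false)(NamedExpression.newExprId), Literal(1))))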
@@ -27,6 +27,7 @@ import scala.reflect.ClassTag

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
@@ -125,7 +125,7 @@ private[hive] trait HiveStrategies {
InterpretedPredicate.create(castedPredicate)
}

val partitions = relation.hiveQlPartitions.filter { part =>
val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part =>
val partitionValues = part.getValues
var i = 0
while (i < partitionValues.size()) {
@@ -213,7 +213,7 @@ private[hive] trait HiveStrategies {
projectList,
otherPredicates,
identity[Seq[Expression]],
HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil
HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil
case _ =>
Nil
}
@@ -21,6 +21,7 @@ import java.io.PrintStream
import java.util.{Map => JMap}

import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException}
import org.apache.spark.sql.catalyst.expressions.Expression

private[hive] case class HiveDatabase(
name: String,
@@ -71,7 +72,12 @@ private[hive] case class HiveTable(

def isPartitioned: Boolean = partitionColumns.nonEmpty

def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this)
def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = {
predicates match {
case Nil => client.getAllPartitions(this)
case _ => client.getPartitionsByFilter(this, predicates)
}
}

// Hive does not support backticks when passing names to the client.
def qualifiedName: String = s"$database.$name"
@@ -132,6 +138,9 @@ private[hive] trait ClientInterface {
/** Returns all partitions for the given table. */
def getAllPartitions(hTable: HiveTable): Seq[HivePartition]

/** Returns partitions filtered by predicates for the given table. */
def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition]

/** Loads a static partition into an existing table. */
def loadPartition(
loadPath: String,
@@ -17,27 +17,24 @@

package org.apache.spark.sql.hive.client

import java.io.{BufferedReader, InputStreamReader, File, PrintStream}
import java.net.URI
import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet}
import java.io.{File, PrintStream}
import java.util.{Map => JMap}
import javax.annotation.concurrent.GuardedBy

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.util.CircularBuffer

import scala.collection.JavaConversions._
import scala.language.reflectiveCalls

import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.metastore.api.Database
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema}
import org.apache.hadoop.hive.metastore.{TableType => HTableType}
import org.apache.hadoop.hive.metastore.api
import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.ql.metadata
import org.apache.hadoop.hive.ql.metadata.Hive
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.ql.{Driver, metadata}

import org.apache.spark.Logging
import org.apache.spark.sql.execution.QueryExecutionException
@@ -316,6 +313,13 @@ private[hive] class ClientWrapper(
shim.getAllPartitions(client, qlTable).map(toHivePartition)
}

override def getPartitionsByFilter(
hTable: HiveTable,
predicates: Seq[Expression]): Seq[HivePartition] = withHiveState {
val qlTable = toQlTable(hTable)
shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition)
}

override def listTables(dbName: String): Seq[String] = withHiveState {
client.getAllTables(dbName)
}
@@ -31,6 +31,11 @@ import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde.serdeConstants

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference, BinaryComparison}
import org.apache.spark.sql.types.{StringType, IntegralType}

/**
* A shim that defines the interface between ClientWrapper and the underlying Hive library used to
@@ -61,6 +66,8 @@ private[client] sealed abstract class Shim {

def getAllPartitions(hive: Hive, table: Table): Seq[Partition]

def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition]

def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor

def getDriverResults(driver: Driver): Seq[String]
@@ -109,7 +116,7 @@ private[client] sealed abstract class Shim {

}

private[client] class Shim_v0_12 extends Shim {
private[client] class Shim_v0_12 extends Shim with Logging {

private lazy val startMethod =
findStaticMethod(
@@ -196,6 +203,17 @@ private[client] class Shim_v0_12 extends Shim {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq

override def getPartitionsByFilter(
hive: Hive,
table: Table,
predicates: Seq[Expression]): Seq[Partition] = {
// getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12.
// See HIVE-4888.
logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " +
"Please use Hive 0.13 or higher.")
getAllPartitions(hive, table)
}

override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]

@@ -267,6 +285,12 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
classOf[Hive],
"getAllPartitionsOf",
classOf[Table])
private lazy val getPartitionsByFilterMethod =
findMethod(
classOf[Hive],
"getPartitionsByFilter",
classOf[Table],
classOf[String])
private lazy val getCommandProcessorMethod =
findStaticMethod(
classOf[CommandProcessorFactory],
@@ -288,6 +312,48 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq

override def getPartitionsByFilter(
hive: Hive,
table: Table,
predicates: Seq[Expression]): Seq[Partition] = {
// hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
val varcharKeys = table.getPartitionKeys
.filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
.map(col => col.getName).toSet

// Hive getPartitionsByFilter() takes a string that represents partition
// predicates like "str_key=\"value\" and int_key=1 ..."
val filter = predicates.flatMap { expr =>
expr match {
case op @ BinaryComparison(lhs, rhs) => {
lhs match {
case AttributeReference(_, _, _, _) => {
rhs.dataType match {
case _: IntegralType =>
Some(lhs.prettyString + op.symbol + rhs.prettyString)
case _: StringType if (!varcharKeys.contains(lhs.prettyString)) =>
Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
case _ => None
}
}
case _ => None
}
}
case _ => None
}
}.mkString(" and ")

val partitions =
if (filter.isEmpty) {
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
} else {
logDebug(s"Hive metastore filter is '$filter'.")
getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
}

partitions.toSeq
}

override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor]

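Note: the Hive 0.13 shim only renders binary comparisons whose left side is a partition attribute and whose right side has an integral or string type, quoting string values and skipping varchar partition keys; anything it cannot render is dropped from the filter string and left to client-side pruning, and the 0.12 shim always falls back to getAllPartitions (HIVE-4888), so the pushdown is purely an optimization. A worked example, assuming hypothetical partition columns part INT, region VARCHAR(8), and ds STRING:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.{IntegerType, StringType}

    // Hive varchar columns surface in Catalyst as StringType, hence the varcharKeys check above.
    val predicates = Seq(
      EqualTo(AttributeReference("part", IntegerType, false)(NamedExpression.newExprId), Literal(1)),
      EqualTo(AttributeReference("region", StringType, false)(NamedExpression.newExprId), Literal("EU")),
      GreaterThanOrEqual(
        AttributeReference("ds", StringType, false)(NamedExpression.newExprId), Literal("2015-01-01")))

    // Rendered metastore filter string for the predicates above:
    //   part=1 and ds>="2015-01-01"
    // part = 1 keeps its integral comparison, ds >= "2015-01-01" is quoted as a string
    // comparison, and region = "EU" is dropped because region is a varchar partition key.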
@@ -44,7 +44,7 @@ private[hive]
case class HiveTableScan(
requestedAttributes: Seq[Attribute],
relation: MetastoreRelation,
partitionPruningPred: Option[Expression])(
partitionPruningPred: Seq[Expression])(
@transient val context: HiveContext)
extends LeafNode {

@@ -56,7 +56,7 @@ case class HiveTableScan(

// Bind all partition key attribute references in the partition pruning predicate for later
// evaluation.
private[this] val boundPruningPred = partitionPruningPred.map { pred =>
private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
require(
pred.dataType == BooleanType,
s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.")
@@ -133,7 +133,8 @@ case class HiveTableScan(
protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
} else {
hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
hadoopReader.makeRDDForPartitionedTable(
prunePartitions(relation.getHiveQlPartitions(partitionPruningPred)))
}

override def output: Seq[Attribute] = attributes
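Note: partitionPruningPred is now a Seq. The unreduced list is handed to getHiveQlPartitions for metastore-side filtering, and the same list is folded into a single conjunction for the existing client-side prunePartitions pass, so partitions the metastore filter could not eliminate (for example on Hive 0.12) are still pruned locally. For illustration, with p1 and p2 standing for hypothetical boolean expressions:

    // Seq(p1, p2).reduceLeftOption(And) == Some(And(p1, p2))
    // Seq.empty[Expression].reduceLeftOption(And) == None   (no client-side pruning predicate)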
@@ -20,7 +20,9 @@ package org.apache.spark.sql.hive.client
import java.io.File

import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo}
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.Utils

/**
@@ -151,6 +153,12 @@ class VersionsSuite extends SparkFunSuite with Logging {
client.getAllPartitions(client.getTable("default", "src_part"))
}

test(s"$version: getPartitionsByFilter") {
client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo(
AttributeReference("key", IntegerType, false)(NamedExpression.newExprId),
Literal(1))))
}

test(s"$version: loadPartition") {
client.loadPartition(
emptyDir,
@@ -151,7 +151,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
case p @ HiveTableScan(columns, relation, _) =>
val columnNames = columns.map(_.name)
val partValues = if (relation.table.isPartitioned) {
p.prunePartitions(relation.hiveQlPartitions).map(_.getValues)
p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues)
} else {
Seq.empty
}