141 changes: 120 additions & 21 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql.hive
import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
import org.apache.hadoop.hive.metastore.MetaStoreUtils
import org.apache.hadoop.hive.ql.Context
import org.apache.hadoop.hive.ql.metadata.Partition
import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Hive}
import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc}
import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
import org.apache.hadoop.hive.serde2.Serializer
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
import org.apache.hadoop.hive.serde2.objectinspector._
@@ -35,9 +37,13 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{BooleanType, DataType}
import org.apache.spark.sql.execution._
import org.apache.spark.{SparkHiveHadoopWriter, TaskContext, SparkException}

/* Implicits */
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.util.control._
/* Java imports */
import java.util.ArrayList


/**
* The Hive table scan operator. Column and partition pruning are both handled.
@@ -70,7 +76,20 @@ case class HiveTableScan(

@transient
val hadoopReader = new HadoopTableReader(relation.tableDesc, sc)

/**
* Attempt to retrieve the list of required column IDs and register it with the
* Hive configuration before constructing the reader.
* TODO: requires more work ...
* val hadoopReader = {
* val cNames: ArrayList[Integer] = getNeededColumnIDs()
* if (cNames.size() != 0) {
* ColumnProjectionUtils.appendReadColumnIDs(sc.hiveconf,cNames)
* } else {
* ColumnProjectionUtils.setFullyReadColumns(sc.hiveconf)
* }
* new HadoopTableReader(relation.tableDesc, sc)
* }
*/

/**
* The hive object inspector for this table, which can be used to extract values from the
* serialized row representation.
@@ -79,6 +98,21 @@ case class HiveTableScan(
lazy val objectInspector =
relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector]

/** Attempts to retrieve the list of required column IDs, for use with ColumnProjectionUtils. */
private def getNeededColumnIDs(): ArrayList[Integer] = {
val names: ArrayList[Integer] = new ArrayList[Integer]()
var i = 0
val len = relation.attributes.length
while (i < len) {
for (a <- attributes) {
if (a.name == relation.attributes(i).name) {
names.add(i)
}
}
i += 1
}
names
}
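
For comparison, a standalone sketch of the same index lookup written with zipWithIndex; the tableColumns and requiredColumns parameters are hypothetical stand-ins for relation.attributes and attributes above, and this helper is not part of the patch:

  import java.util.ArrayList

  // Hypothetical helper: map the names of the required columns to their
  // positional indices in the table schema, as getNeededColumnIDs does above.
  def neededColumnIDs(tableColumns: Seq[String], requiredColumns: Seq[String]): ArrayList[Integer] = {
    val ids = new ArrayList[Integer]()
    for ((name, index) <- tableColumns.zipWithIndex if requiredColumns.contains(name)) {
      ids.add(index)  // Int is implicitly boxed to java.lang.Integer
    }
    ids
  }
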
/**
* Functions that extract the requested attributes from the Hive output. Partition values are
* cast from strings to their declared data types.
@@ -104,6 +138,7 @@ case class HiveTableScan(
}
}
}


private def castFromString(value: String, dataType: DataType) = {
Cast(Literal(value), dataType).eval(null)
@@ -123,40 +158,104 @@ case class HiveTableScan(
* @return Partitions that are involved in the query plan.
*/
private[hive] def prunePartitions(partitions: Seq[HivePartition]) = {
boundPruningPred match {
case None => partitions
case Some(shouldKeep) => partitions.filter { part =>
val dataTypes = relation.partitionKeys.map(_.dataType)
val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) yield {
castFromString(value, dataType)
/** Mutable row reused across iterations, to avoid creating a new row
* instance on each pass through the while loop.
*/
val row = new GenericMutableRow(relation.partitionKeys.length)
if (boundPruningPred.isEmpty) {
partitions
} else {
val shouldKeep: Expression = boundPruningPred.get
val partitionSize = partitions.length
var index = 0
val filterPartition = mutable.ListBuffer[HivePartition]()
while (index < partitionSize) {
val part = partitions(index)
var i = 0
var len = relation.partitionKeys.length
var castedValues = mutable.ListBuffer[Any]()
val iter: Iterator[String] = part.getValues.iterator
while (i < len) {
castedValues += castFromString(iter.next, relation.partitionKeys(i).dataType)
i += 1
}

// Only partitioned values are needed here, since the predicate has already been bound to
// partition key attribute references.
val row = new GenericRow(castedValues.toArray)
shouldKeep.eval(row).asInstanceOf[Boolean]
i = 0
len = castedValues.length
// Copy castedValues (the partition key columns) into the reusable row.
while (i < len) {
val n = castedValues(i)
if (n.isInstanceOf[String]) {
if (n.asInstanceOf[String].toLowerCase == "null") {
row.setNullAt(i)
} else {
row.setString(i, n.asInstanceOf[String])
}
} else {
row.update(i, n)
}
i += 1
}
if (shouldKeep.eval(row).asInstanceOf[Boolean]) {
filterPartition += part
}
index += 1
}
filterPartition
}
}
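
The same "null"-string handling reappears in MyIterator below; a minimal sketch of a shared helper both places could call, assuming the catalyst GenericMutableRow API (setNullAt, setString, update) used in this file — the helper itself is hypothetical and not part of the patch:

  import org.apache.spark.sql.catalyst.expressions.GenericMutableRow

  // Hypothetical helper: Hive represents a missing value as the literal string
  // "null", which has to be written into the row as a SQL NULL rather than text.
  def setRowValue(row: GenericMutableRow, i: Int, value: Any): Unit = value match {
    case s: String if s.toLowerCase == "null" => row.setNullAt(i)
    case s: String => row.setString(i, s)
    case other => row.update(i, other)
  }
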
/**
* A custom Iterator passed to mapPartitions() from the execute() method.
*/
class MyIterator(iter: Iterator[_], mutableRow: GenericMutableRow) extends Iterator[Row] {

def execute() = {
inputRdd.map { row =>
def hasNext = iter.hasNext
def next = {
val row = iter.next()
val values = row match {
case Array(deserializedRow: AnyRef, partitionKeys: Array[String]) =>
attributeFunctions.map(_(deserializedRow, partitionKeys))
case deserializedRow: AnyRef =>
attributeFunctions.map(_(deserializedRow, Array.empty))
}
var i = 0
val len = values.length
while (i < len) {
val n = values(i)
if (n.isInstanceOf[String]) {
if (n.asInstanceOf[String].toLowerCase == "null") {
mutableRow.setNullAt(i)
} else {
mutableRow.setString(i, n.asInstanceOf[String])
}
} else if (n.isInstanceOf[HiveVarchar]) {
mutableRow.update(i, n.asInstanceOf[HiveVarchar].getValue)
} else if (n.isInstanceOf[HiveDecimal]) {
mutableRow.update(i, BigDecimal(n.asInstanceOf[HiveDecimal].bigDecimalValue))
} else {
mutableRow.update(i, n)
}
i += 1
}
buildRow(values.map {
case n: String if n.toLowerCase == "null" => null
case varchar: org.apache.hadoop.hive.common.`type`.HiveVarchar => varchar.getValue
case decimal: org.apache.hadoop.hive.common.`type`.HiveDecimal =>
BigDecimal(decimal.bigDecimalValue)
case other => other
})
mutableRow
}
}
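
The value normalization in next() (HiveVarchar to String, HiveDecimal to BigDecimal, the "null" string to SQL NULL) can be summarized as a small standalone function; a sketch only, reusing the Hive wrapper types already imported at the top of this file:

  import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}

  // Sketch of the per-value normalization performed in MyIterator.next:
  // unwrap Hive wrapper types and turn the literal string "null" into null.
  def unwrapHiveValue(v: Any): Any = v match {
    case s: String if s.toLowerCase == "null" => null
    case varchar: HiveVarchar => varchar.getValue
    case decimal: HiveDecimal => BigDecimal(decimal.bigDecimalValue)
    case other => other
  }
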


def execute() = {
/**
* mutableRow is a GenericMutableRow, created only once per partition.
*/
inputRdd.mapPartitions((iter: Iterator[_]) => {
val mutableRow = new GenericMutableRow(attributes.length)
new MyIterator(iter, mutableRow)
})
}
def output = attributes
}
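
A note on the design choice in execute(): mapPartitions lets the row buffer be allocated once per partition and reused for every element, at the cost that next() keeps returning the same mutable object, so callers must copy rows they want to retain. A minimal standalone sketch of that pattern, unrelated to Hive and not part of the patch:

  import org.apache.spark.rdd.RDD

  // Buffer-reuse pattern: one mutable array per partition, refilled and
  // returned for every element of the iterator.
  def reuseBufferPerPartition(rdd: RDD[Seq[Int]], width: Int): RDD[Array[Int]] =
    rdd.mapPartitions { iter =>
      val buffer = new Array[Int](width)  // created once per partition
      iter.map { values =>
        var i = 0
        while (i < width && i < values.length) {
          buffer(i) = values(i)  // reuse the same backing array
          i += 1
        }
        buffer
      }
    }
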
