diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
index 821fb22112f87..ca5bb1a1c6971 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql.hive
 import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
 import org.apache.hadoop.hive.metastore.MetaStoreUtils
 import org.apache.hadoop.hive.ql.Context
+import org.apache.hadoop.hive.ql.metadata.Partition
 import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Hive}
 import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc}
+import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
 import org.apache.hadoop.hive.serde2.Serializer
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
 import org.apache.hadoop.hive.serde2.objectinspector._
@@ -35,9 +37,13 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.types.{BooleanType, DataType}
 import org.apache.spark.sql.execution._
 import org.apache.spark.{SparkHiveHadoopWriter, TaskContext, SparkException}
-
 /* Implicits */
 import scala.collection.JavaConversions._
+import scala.collection.mutable
+import scala.util.control._
+/* Java library */
+import java.util.ArrayList
+
 
 /**
  * The Hive table scan operator. Column and partition pruning are both handled.
@@ -70,7 +76,20 @@ case class HiveTableScan(
 
   @transient
   val hadoopReader = new HadoopTableReader(relation.tableDesc, sc)
-
+  /**
+   * Attempt to retrieve the list of column ids that are required.
+   * TODO: requires more work ...
+   *  val hadoopReader = {
+   *    val cNames: ArrayList[Integer] = getNeededColumnIDs()
+   *    if (cNames.size() != 0) {
+   *      ColumnProjectionUtils.appendReadColumnIDs(sc.hiveconf, cNames)
+   *    } else {
+   *      ColumnProjectionUtils.setFullyReadColumns(sc.hiveconf)
+   *    }
+   *    new HadoopTableReader(relation.tableDesc, sc)
+   *  }
+   */
+
   /**
    * The hive object inspector for this table, which can be used to extract values from the
    * serialized row representation.
@@ -79,6 +98,21 @@ case class HiveTableScan(
   lazy val objectInspector =
     relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector]
 
+  /** Retrieves the list of column ids that are required; used by ColumnProjectionUtils. */
+  private def getNeededColumnIDs(): ArrayList[Integer] = {
+    val names: ArrayList[Integer] = new ArrayList[Integer]()
+    var i = 0
+    val len = relation.attributes.length
+    while (i < len) {
+      for (a <- attributes) {
+        if (a.name == relation.attributes(i).name) {
+          names.add(i)
+        }
+      }
+      i += 1
+    }
+    names
+  }
   /**
    * Functions that extract the requested attributes from the hive output. Partitioned values are
    * casted from string to its declared data type.
@@ -104,6 +138,7 @@ case class HiveTableScan(
       }
     }
   }
+
 
   private def castFromString(value: String, dataType: DataType) = {
     Cast(Literal(value), dataType).eval(null)
@@ -123,40 +158,104 @@ case class HiveTableScan(
    * @return Partitions that are involved in the query plan.
    */
   private[hive] def prunePartitions(partitions: Seq[HivePartition]) = {
-    boundPruningPred match {
-      case None => partitions
-      case Some(shouldKeep) => partitions.filter { part =>
-        val dataTypes = relation.partitionKeys.map(_.dataType)
-        val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) yield {
-          castFromString(value, dataType)
+    /** A mutable row, reused to avoid creating a new row instance on
+     *  each iteration of the while loop below.
+     */
+    var row = new GenericMutableRow(relation.partitionKeys.length)
+    if (boundPruningPred == None) {
+      partitions
+    } else {
+      val shouldKeep: Expression = boundPruningPred.get
+      val partitionSize = partitions.length
+      var index = 0
+      var filterPartition = mutable.ListBuffer[HivePartition]()
+      while (index < partitionSize) {
+        val part = partitions(index)
+        var i = 0
+        var len = relation.partitionKeys.length
+        var castedValues = mutable.ListBuffer[Any]()
+        val iter: Iterator[String] = part.getValues.iterator
+        while (i < len) {
+          castedValues += castFromString(iter.next, relation.partitionKeys(i).dataType)
+          i += 1
         }
-
         // Only partitioned values are needed here, since the predicate has already been bound to
         // partition key attribute references.
-        val row = new GenericRow(castedValues.toArray)
-        shouldKeep.eval(row).asInstanceOf[Boolean]
+        i = 0
+        len = castedValues.length
+        // castedValues holds the casted partition key values, one per column of the row.
+        while (i < len) {
+          val n = castedValues(i)
+          if (n.isInstanceOf[String]) {
+            if (n.asInstanceOf[String].toLowerCase == "null") {
+              row.setNullAt(i)
+            } else {
+              row.setString(i, n.asInstanceOf[String])
+            }
+          }
+          else {
+            row.update(i, n)
+          }
+          i += 1
+        }
+        if (shouldKeep.eval(row).asInstanceOf[Boolean]) {
+          filterPartition += part
+        }
+        index += 1
       }
+      filterPartition
     }
   }
 
+  /**
+   * A custom Iterator produced by the closure passed to mapPartitions() in execute().
+   */
+  class MyIterator(iter: Iterator[_], mutableRow: GenericMutableRow) extends Iterator[Row] {
-  def execute() = {
-    inputRdd.map { row =>
+    def hasNext = iter.hasNext
+    def next = {
+      val row = iter.next()
       val values = row match {
         case Array(deserializedRow: AnyRef, partitionKeys: Array[String]) =>
           attributeFunctions.map(_(deserializedRow, partitionKeys))
         case deserializedRow: AnyRef => attributeFunctions.map(_(deserializedRow, Array.empty))
+      }
+      var i = 0
+      val len = values.length
+      while (i < len) {
+        val n = values(i)
+        if (n.isInstanceOf[String]) {
+          if (n.asInstanceOf[String].toLowerCase == "null") {
+            mutableRow.setNullAt(i)
+          } else {
+            mutableRow.setString(i, n.asInstanceOf[String])
+          }
+        }
+        else if (n.isInstanceOf[HiveVarchar]) {
+          mutableRow.update(i, n.asInstanceOf[HiveVarchar].getValue)
+        }
+        else if (n.isInstanceOf[HiveDecimal]) {
+          mutableRow.update(i, BigDecimal(n.asInstanceOf[HiveDecimal].bigDecimalValue))
+        }
+        else {
+          mutableRow.update(i, n)
+        }
+        i += 1
       }
-      buildRow(values.map {
-        case n: String if n.toLowerCase == "null" => null
-        case varchar: org.apache.hadoop.hive.common.`type`.HiveVarchar => varchar.getValue
-        case decimal: org.apache.hadoop.hive.common.`type`.HiveDecimal =>
-          BigDecimal(decimal.bigDecimalValue)
-        case other => other
-      })
+      mutableRow
     }
   }
-
+
+  def execute() = {
+    /**
+     * mutableRow is a GenericMutableRow that is created only once per partition.
+     */
+    inputRdd.mapPartitions((iter: Iterator[_]) => {
+      var mutableRow = new GenericMutableRow(attributes.length)
+      new MyIterator(iter, mutableRow)
+    })
+
+  }
   def output = attributes
 }
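
The heart of the patch is the switch from `inputRdd.map` (one row object allocated per record) to `inputRdd.mapPartitions` with an iterator that refills a single `GenericMutableRow`. A minimal, Spark-free sketch of that pattern follows; `MutableRow`, `RowReusingIterator`, and `RowReuseDemo` are hypothetical stand-ins invented for illustration, not part of the patch or of Spark.

```scala
/** A tiny mutable row: a fixed-size Array[Any] with update/setNull helpers.
 *  Stand-in for Catalyst's GenericMutableRow. */
class MutableRow(size: Int) {
  private val values = new Array[Any](size)
  def update(i: Int, v: Any): Unit = { values(i) = v }
  def setNullAt(i: Int): Unit = { values(i) = null }
  def apply(i: Int): Any = values(i)
  override def toString: String = values.mkString("[", ", ", "]")
}

/** Converts raw string records into rows while reusing the single `row` instance,
 *  analogous to the patch's MyIterator. */
class RowReusingIterator(raw: Iterator[Array[String]], row: MutableRow)
  extends Iterator[MutableRow] {

  override def hasNext: Boolean = raw.hasNext

  override def next(): MutableRow = {
    val record = raw.next()
    var i = 0
    while (i < record.length) {
      // Mirror the patch's handling of Hive's "null" string marker.
      if (record(i) != null && record(i).equalsIgnoreCase("null")) row.setNullAt(i)
      else row.update(i, record(i))
      i += 1
    }
    row // the same instance is returned on every call
  }
}

object RowReuseDemo {
  def main(args: Array[String]): Unit = {
    // One "partition" of raw records; in the patch this is the iterator handed to
    // the closure passed to inputRdd.mapPartitions.
    val partition = Iterator(Array("a", "1"), Array("NULL", "2"), Array("c", "3"))
    val rows = new RowReusingIterator(partition, new MutableRow(2))
    rows.foreach(println) // prints the single row's contents as it is rewritten
  }
}
```

Because the same instance is handed out on every `next()` call, consumers must read or copy the row before advancing the iterator; that constraint is what makes the trick safe inside a streaming scan but unsafe anywhere rows are buffered.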