From 269b15f40fceb9c628f9c9d57cdf042e05e07ac9 Mon Sep 17 00:00:00 2001 From: Zhan Zhang Date: Wed, 22 Oct 2014 15:51:20 -0700 Subject: [PATCH] make wrap consistent with InsertIntoHiveTable.wrapperFor --- .../sql/hive/orc/OrcTableOperations.scala | 3 +- .../apache/spark/sql/hive/orc/package.scala | 46 ++++++++++--------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala index 793142567512..4b2a8511e8e4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala @@ -188,10 +188,11 @@ private[sql] case class InsertIntoOrcTable( val fieldOIs = standardOI .getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray val outputData = new Array[Any](fieldOIs.length) + val wrappers = fieldOIs.map(HadoopTypeConverter.wrapperFor) iter.map { row => var i = 0 while (i < row.length) { - outputData(i) = HadoopTypeConverter.wrap((row(i), fieldOIs(i))) + outputData(i) = wrappers(i)(row(i)) i += 1 } orcSerde.serialize(outputData, standardOI) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala index 14fe51b92d4b..81eea055e8b6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala @@ -45,32 +45,36 @@ package object orc { // TypeConverter for InsertIntoOrcTable object HadoopTypeConverter extends HiveInspectors { - def wrap(a: (Any, ObjectInspector)): Any = a match { - case (s: String, oi: JavaHiveVarcharObjectInspector) => - new HiveVarchar(s, s.size) + def wrapperFor(oi: ObjectInspector): Any => Any = oi match { + case _: JavaHiveVarcharObjectInspector => + (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) - case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) => - new HiveDecimal(bd.underlying()) + case _: JavaHiveDecimalObjectInspector => + (o: Any) => new HiveDecimal(o.asInstanceOf[BigDecimal].underlying()) - case (row: Row, oi: StandardStructObjectInspector) => - val struct = oi.create() - row.zip(oi.getAllStructFieldRefs: Seq[StructField]).foreach { - case (data, field) => - oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector)) + case soi: StandardStructObjectInspector => + val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) + (o: Any) => { + val struct = soi.create() + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { + (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + } + struct } - struct - case (s: Seq[_], oi: ListObjectInspector) => - val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector)) - seqAsJavaList(wrappedSeq) - case (m: Map[_, _], oi: MapObjectInspector) => - val keyOi = oi.getMapKeyObjectInspector - val valueOi = oi.getMapValueObjectInspector - val wrappedMap = m.map { case (key, value) => wrap(key, keyOi) -> wrap(value, valueOi) } - mapAsJavaMap(wrappedMap) + case loi: ListObjectInspector => + val wrapper = wrapperFor(loi.getListElementObjectInspector) + (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) - case (obj, _) => - obj + case moi: MapObjectInspector => + val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) + val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) + (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => + keyWrapper(key) -> valueWrapper(value) + }) + + case _ => + identity[Any] } }