Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 48 additions & 10 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,31 +70,43 @@ object HiveFunctionRegistry
}

/**
 * Maps a Java class (as seen in a Hive UDF signature) to the corresponding [[DataType]].
 *
 * Recognised inputs are Hadoop/Hive writables, boxed Java classes, Java primitive classes
 * (`java.lang.Xxx.TYPE`), and arrays. `Array[Byte]` is matched before the generic array case and
 * therefore maps to `BinaryType`; any other array maps to an `ArrayType` whose element type is
 * derived recursively from the component class.
 *
 * There is no fall-through case: a class not listed here fails the match with a
 * `scala.MatchError`.
 *
 * @param clz the Java class to translate
 * @return the data type associated with `clz`
 */
def javaClassToDataType(clz: Class[_]): DataType = clz match {
  // writable
  case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
  case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
  case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
  case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
  case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
  case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType
  case c: Class[_] if c == classOf[hadoopIo.Text] => StringType
  case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType
  case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType
  case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType
  case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType
  case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType

  // java class
  case c: Class[_] if c == classOf[java.lang.String] => StringType
  case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
  case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
  case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
  // Must precede the generic `isArray` case below so byte arrays become BinaryType.
  case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
  case c: Class[_] if c == classOf[java.lang.Short] => ShortType
  case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
  case c: Class[_] if c == classOf[java.lang.Long] => LongType
  case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
  case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
  case c: Class[_] if c == classOf[java.lang.Float] => FloatType
  case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType

  // primitive type
  case c: Class[_] if c == java.lang.Short.TYPE => ShortType
  case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
  case c: Class[_] if c == java.lang.Long.TYPE => LongType
  case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
  case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
  case c: Class[_] if c == java.lang.Float.TYPE => FloatType
  case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType

  // Any remaining array: derive the element type from the component class.
  case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType))
}
}
Expand All @@ -111,11 +123,19 @@ trait HiveFunctionFactory {
case i: hadoopIo.IntWritable => i.get
case t: hadoopIo.Text => t.toString
case l: hadoopIo.LongWritable => l.get
case d: hadoopIo.DoubleWritable => d.get()
case d: hadoopIo.DoubleWritable => d.get
case d: hiveIo.DoubleWritable => d.get
case s: hiveIo.ShortWritable => s.get
case b: hadoopIo.BooleanWritable => b.get()
case b: hadoopIo.BooleanWritable => b.get
case b: hiveIo.ByteWritable => b.get
case b: hadoopIo.FloatWritable => b.get
case b: hadoopIo.BytesWritable => {
val bytes = new Array[Byte](b.getLength)
System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength)
bytes
}
case t: hiveIo.TimestampWritable => t.getTimestamp
case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue())
case list: java.util.List[_] => list.map(unwrap)
case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap
case array: Array[_] => array.map(unwrap).toSeq
Expand All @@ -127,6 +147,9 @@ trait HiveFunctionFactory {
case p: java.lang.Byte => p
case p: java.lang.Boolean => p
case str: String => str
case p: BigDecimal => p
case p: Array[Byte] => p
case p: java.sql.Timestamp => p
}
}

Expand Down Expand Up @@ -252,13 +275,17 @@ trait HiveInspectors {

/** Converts native catalyst types to the types expected by Hive */
def wrap(a: Any): AnyRef = a match {
case s: String => new hadoopIo.Text(s)
case s: String => new hadoopIo.Text(s) // TODO why should be Text?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good question. Text is what seemed to work for the test cases, but if there is some other type or way to know we should investigate that.

case i: Int => i: java.lang.Integer
case b: Boolean => b: java.lang.Boolean
case f: Float => f: java.lang.Float
case d: Double => d: java.lang.Double
case l: Long => l: java.lang.Long
case l: Short => l: java.lang.Short
case l: Byte => l: java.lang.Byte
case b: BigDecimal => b.bigDecimal
case b: Array[Byte] => b
case t: java.sql.Timestamp => t
case s: Seq[_] => seqAsJavaList(s.map(wrap))
case m: Map[_,_] =>
mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) })
Expand All @@ -280,6 +307,8 @@ trait HiveInspectors {
case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector
case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
}

def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
Expand Down Expand Up @@ -307,6 +336,14 @@ trait HiveInspectors {
case _: JavaShortObjectInspector => ShortType
case _: WritableByteObjectInspector => ByteType
case _: JavaByteObjectInspector => ByteType
case _: WritableFloatObjectInspector => FloatType
case _: JavaFloatObjectInspector => FloatType
case _: WritableBinaryObjectInspector => BinaryType
case _: JavaBinaryObjectInspector => BinaryType
case _: WritableHiveDecimalObjectInspector => DecimalType
case _: JavaHiveDecimalObjectInspector => DecimalType
case _: WritableTimestampObjectInspector => TimestampType
case _: JavaTimestampObjectInspector => TimestampType
}

implicit class typeInfoConversions(dt: DataType) {
Expand All @@ -324,6 +361,7 @@ trait HiveInspectors {
case ShortType => shortTypeInfo
case StringType => stringTypeInfo
case DecimalType => decimalTypeInfo
case TimestampType => timestampTypeInfo
case NullType => voidTypeInfo
}
}
Expand Down