Commit 01ec2cd

Merge branch 'master' into mllib-main
2 parents 9420692 + f735884 commit 01ec2cd

2 files changed: 49 additions (+), 11 deletions (−)

project/SparkBuild.scala

Lines changed: 1 addition & 1 deletion
@@ -507,7 +507,7 @@ object SparkBuild extends Build {
     |import org.apache.spark.sql.catalyst.util._
     |import org.apache.spark.sql.execution
     |import org.apache.spark.sql.hive._
-    |import org.apache.spark.sql.hive.TestHive._
+    |import org.apache.spark.sql.hive.test.TestHive._
     |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin
   )

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

Lines changed: 48 additions & 10 deletions
@@ -70,31 +70,43 @@ private[hive] object HiveFunctionRegistry
   }
 
   def javaClassToDataType(clz: Class[_]): DataType = clz match {
+    // writable
     case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
     case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
     case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
     case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
     case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
+    case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType
     case c: Class[_] if c == classOf[hadoopIo.Text] => StringType
     case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType
     case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType
     case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType
     case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType
+    case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType
+
+    // java class
     case c: Class[_] if c == classOf[java.lang.String] => StringType
-    case c: Class[_] if c == java.lang.Short.TYPE => ShortType
-    case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
-    case c: Class[_] if c == java.lang.Long.TYPE => LongType
-    case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
-    case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
-    case c: Class[_] if c == java.lang.Float.TYPE => FloatType
-    case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
+    case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
+    case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
+    case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
+    case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
     case c: Class[_] if c == classOf[java.lang.Short] => ShortType
     case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
     case c: Class[_] if c == classOf[java.lang.Long] => LongType
     case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
     case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
     case c: Class[_] if c == classOf[java.lang.Float] => FloatType
     case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
+
+    // primitive type
+    case c: Class[_] if c == java.lang.Short.TYPE => ShortType
+    case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
+    case c: Class[_] if c == java.lang.Long.TYPE => LongType
+    case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
+    case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
+    case c: Class[_] if c == java.lang.Float.TYPE => FloatType
+    case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
+
     case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType))
   }
 }
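
The reorganized javaClassToDataType now groups its cases into writables, boxed Java classes, and JVM primitive types, and adds Timestamp, decimal, and binary coverage. Both the boxed and the primitive groups matter because Java reflection reports a parameter declared as int via Integer.TYPE but one declared as java.lang.Integer via classOf[java.lang.Integer]. Below is a minimal, self-contained sketch of the same dispatch pattern; ToyDataType, toyClassToDataType, and the toy type objects are hypothetical stand-ins for the Catalyst types, not part of this patch.

    object ToyTypeMapping {
      sealed trait ToyDataType
      case object ToyIntegerType extends ToyDataType
      case object ToyTimestampType extends ToyDataType
      case object ToyDecimalType extends ToyDataType
      case object ToyBinaryType extends ToyDataType

      // Dispatch on the runtime Class object, mirroring the structure of
      // javaClassToDataType: boxed classes first, then the primitive Class objects.
      def toyClassToDataType(clz: Class[_]): ToyDataType = clz match {
        case c if c == classOf[java.sql.Timestamp]   => ToyTimestampType
        case c if c == classOf[java.math.BigDecimal] => ToyDecimalType
        case c if c == classOf[Array[Byte]]          => ToyBinaryType
        case c if c == classOf[java.lang.Integer]    => ToyIntegerType
        // Integer.TYPE is the Class object reflection reports for the primitive int
        case c if c == java.lang.Integer.TYPE        => ToyIntegerType
      }

      def main(args: Array[String]): Unit = {
        assert(toyClassToDataType(java.lang.Integer.TYPE) == ToyIntegerType)
        assert(toyClassToDataType(classOf[java.sql.Timestamp]) == ToyTimestampType)
      }
    }
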
@@ -111,11 +123,19 @@ private[hive] trait HiveFunctionFactory {
     case i: hadoopIo.IntWritable => i.get
     case t: hadoopIo.Text => t.toString
     case l: hadoopIo.LongWritable => l.get
-    case d: hadoopIo.DoubleWritable => d.get()
+    case d: hadoopIo.DoubleWritable => d.get
     case d: hiveIo.DoubleWritable => d.get
     case s: hiveIo.ShortWritable => s.get
-    case b: hadoopIo.BooleanWritable => b.get()
+    case b: hadoopIo.BooleanWritable => b.get
     case b: hiveIo.ByteWritable => b.get
+    case b: hadoopIo.FloatWritable => b.get
+    case b: hadoopIo.BytesWritable => {
+      val bytes = new Array[Byte](b.getLength)
+      System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength)
+      bytes
+    }
+    case t: hiveIo.TimestampWritable => t.getTimestamp
+    case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue())
     case list: java.util.List[_] => list.map(unwrap)
     case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap
     case array: Array[_] => array.map(unwrap).toSeq
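
The new BytesWritable branch copies out only the first getLength bytes because getBytes returns the writable's internal buffer, which Hadoop may keep larger than the valid payload. A standalone sketch of that copy, assuming only org.apache.hadoop.io.BytesWritable on the classpath; bytesOf is a hypothetical helper name, not part of this patch.

    import org.apache.hadoop.io.BytesWritable

    object BytesWritableCopy {
      // Copy exactly the valid prefix; the backing array from getBytes may be longer.
      def bytesOf(b: BytesWritable): Array[Byte] = {
        val bytes = new Array[Byte](b.getLength)
        System.arraycopy(b.getBytes, 0, bytes, 0, b.getLength)
        bytes
      }

      def main(args: Array[String]): Unit = {
        val w = new BytesWritable(Array[Byte](1, 2, 3))
        w.setSize(2) // valid payload is now 2 bytes; the buffer still holds 3
        assert(bytesOf(w).sameElements(Array[Byte](1, 2)))
      }
    }
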
@@ -127,6 +147,9 @@ private[hive] trait HiveFunctionFactory {
     case p: java.lang.Byte => p
     case p: java.lang.Boolean => p
     case str: String => str
+    case p: BigDecimal => p
+    case p: Array[Byte] => p
+    case p: java.sql.Timestamp => p
   }
 }

@@ -252,13 +275,17 @@ private[hive] trait HiveInspectors {
 
   /** Converts native catalyst types to the types expected by Hive */
   def wrap(a: Any): AnyRef = a match {
-    case s: String => new hadoopIo.Text(s)
+    case s: String => new hadoopIo.Text(s) // TODO why should be Text?
     case i: Int => i: java.lang.Integer
     case b: Boolean => b: java.lang.Boolean
+    case f: Float => f: java.lang.Float
     case d: Double => d: java.lang.Double
     case l: Long => l: java.lang.Long
     case l: Short => l: java.lang.Short
     case l: Byte => l: java.lang.Byte
+    case b: BigDecimal => b.bigDecimal
+    case b: Array[Byte] => b
+    case t: java.sql.Timestamp => t
     case s: Seq[_] => seqAsJavaList(s.map(wrap))
     case m: Map[_,_] =>
       mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) })
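
The additions to wrap rely on two standard Scala features: type ascription (i: java.lang.Integer) boxes a primitive into the java.lang wrapper object Hive expects, and scala.math.BigDecimal's bigDecimal field exposes the underlying java.math.BigDecimal. A small sketch of that boxing outside the HiveInspectors trait; box is a hypothetical helper, not part of this patch.

    object WrapSketch {
      // Type ascription triggers the implicit primitive-to-wrapper conversion,
      // so every branch returns a proper AnyRef.
      def box(a: Any): AnyRef = a match {
        case i: Int                => i: java.lang.Integer
        case f: Float              => f: java.lang.Float
        case b: BigDecimal         => b.bigDecimal // the wrapped java.math.BigDecimal
        case t: java.sql.Timestamp => t            // already a Java object, passed through
      }

      def main(args: Array[String]): Unit = {
        assert(box(42).isInstanceOf[java.lang.Integer])
        assert(box(1.5f).isInstanceOf[java.lang.Float])
        assert(box(BigDecimal("1.5")).isInstanceOf[java.math.BigDecimal])
      }
    }
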
@@ -280,6 +307,8 @@ private[hive] trait HiveInspectors {
     case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
     case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector
     case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
+    case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
+    case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
   }
 
   def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
@@ -307,6 +336,14 @@ private[hive] trait HiveInspectors {
     case _: JavaShortObjectInspector => ShortType
     case _: WritableByteObjectInspector => ByteType
     case _: JavaByteObjectInspector => ByteType
+    case _: WritableFloatObjectInspector => FloatType
+    case _: JavaFloatObjectInspector => FloatType
+    case _: WritableBinaryObjectInspector => BinaryType
+    case _: JavaBinaryObjectInspector => BinaryType
+    case _: WritableHiveDecimalObjectInspector => DecimalType
+    case _: JavaHiveDecimalObjectInspector => DecimalType
+    case _: WritableTimestampObjectInspector => TimestampType
+    case _: JavaTimestampObjectInspector => TimestampType
   }
 
   implicit class typeInfoConversions(dt: DataType) {
@@ -324,6 +361,7 @@ private[hive] trait HiveInspectors {
       case ShortType => shortTypeInfo
       case StringType => stringTypeInfo
       case DecimalType => decimalTypeInfo
+      case TimestampType => timestampTypeInfo
       case NullType => voidTypeInfo
     }
   }
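
typeInfoConversions uses Scala's implicit-class enrichment to hang a conversion method off DataType without modifying it; the patch only extends the internal match with TimestampType. A rough sketch of that enrichment pattern with toy types; ToyTypeOps and the toy names are hypothetical, not part of this patch.

    object TypeInfoSketch {
      sealed trait ToyDataType
      case object ToyTimestampType extends ToyDataType

      // Enrichment via implicit class: adds toName to ToyDataType without touching it.
      implicit class ToyTypeOps(dt: ToyDataType) {
        def toName: String = dt match {
          case ToyTimestampType => "timestamp"
        }
      }

      def main(args: Array[String]): Unit = {
        assert((ToyTimestampType: ToyDataType).toName == "timestamp")
      }
    }
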
