@@ -24,12 +24,14 @@ import java.util.{Map => JavaMap}
2424import javax .annotation .Nullable
2525
2626import scala .language .existentials
27+ import scala .reflect .ClassTag
2728
2829import org .apache .spark .sql .Row
2930import org .apache .spark .sql .catalyst .expressions ._
3031import org .apache .spark .sql .catalyst .util .DateTimeUtils
3132import org .apache .spark .sql .types ._
3233import org .apache .spark .unsafe .types .UTF8String
34+ import org .apache .spark .util .Utils
3335
3436/**
3537 * Functions to convert Scala types to Catalyst types and vice versa.
@@ -39,6 +41,8 @@ object CatalystTypeConverters {
3941 // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
4042 import scala .collection .Map
4143
// Lazily-initialized alias for ScalaReflection.universe. The converters in this object do
// `import universe._` and use `universe.Type` in their signatures.
44+ lazy val universe = ScalaReflection .universe
45+
4246 private def isPrimitive (dataType : DataType ): Boolean = {
4347 dataType match {
4448 case BooleanType => true
@@ -454,4 +458,166 @@ object CatalystTypeConverters {
def convertToScala(catalystValue: Any, dataType: DataType): Any = {
  // Obtain the converter for this Catalyst data type first, then apply it to the single value.
  val toScala = createToScalaConverter(dataType)
  toScala(catalystValue)
}
461+
/**
 * Product-typed counterpart of createToScalaConverter(DataType): returns a function that turns
 * a Catalyst row into an instance of `T`, where `T` is a subtype of Product (for example a
 * case class).
 *
 * A Scala type that does not fit the given structType is not reported here; the returned
 * converter ultimately throws a ClassCastException when it is invoked.
 *
 * Intended for converting many rows that share one schema: call this method once to obtain
 * the converter, then apply that converter to each row.
 */
private[sql] def createToProductConverter[T <: Product](
    structType: StructType)(implicit classTag: ClassTag[T]): InternalRow => T = {

  // Building the converter uses Scala runtime reflection, which has thread safety issues in
  // 2.10, so the work is done while holding ScalaReflectionLock.
  // https://issues.scala-lang.org/browse/SI-6240
  // http://docs.scala-lang.org/overviews/reflection/thread-safety.html
  ScalaReflectionLock.synchronized {
    createToProductConverter(classTag, structType)
  }
}
481+
/**
 * Builds a converter from a Catalyst InternalRow to an instance of the Product type described
 * by `classTag`, by reflectively invoking that type's primary constructor.
 *
 * The parameters of the constructor's first parameter list are paired positionally with the
 * fields of `structType`; each field value is converted with
 * createToScalaConverter(Type, DataType) before being passed to the constructor.
 *
 * This overload does not synchronize on ScalaReflectionLock itself.
 *
 * @param classTag   runtime class of the Product type to instantiate
 * @param structType schema of the rows the returned converter will be applied to
 * @return a function mapping a row to an instance of `T`; a null row maps to null. The function
 *         throws a ClassCastException (with the underlying IllegalArgumentException attached
 *         as its cause) when the constructor rejects the converted arguments.
 * @throws IllegalArgumentException if no primary constructor can be found for the class
 */
private[sql] def createToProductConverter[T <: Product](
    classTag: ClassTag[T], structType: StructType): InternalRow => T = {

  import universe._

  val constructorMirror = {
    val mirror = runtimeMirror(Utils.getContextOrSparkClassLoader)
    val classSymbol = mirror.classSymbol(classTag.runtimeClass)
    val classMirror = mirror.reflectClass(classSymbol)
    val constructorSymbol = {
      // Adapted from ScalaReflection to find primary constructor.
      // https://issues.apache.org/jira/browse/SPARK-4791
      val symbol = classSymbol.toType.declaration(nme.CONSTRUCTOR)
      if (symbol.isMethod) {
        symbol.asMethod
      } else {
        // The constructor is overloaded: pick the alternative flagged as primary.
        symbol.asTerm.alternatives
          .find { s => s.isMethod && s.asMethod.isPrimaryConstructor }
          .map { _.asMethod }
          .getOrElse {
            throw new IllegalArgumentException(s"No primary constructor for ${symbol.name}")
          }
      }
    }
    classMirror.reflectConstructor(constructorSymbol)
  }

  val params = constructorMirror.symbol.paramss.head.toSeq
  val paramTypes = params.map { _.asTerm.typeSignature }
  val dataTypes = structType.fields.map { _.dataType }
  val converters: Seq[Any => Any] =
    paramTypes.zip(dataTypes).map { case (pt, dt) => createToScalaConverter(pt, dt) }

  // StructType.fields is an Array, so interpolating `dataTypes` directly would print a JVM
  // identity string instead of the types. Render it once here for use in error messages.
  val dataTypesDescription = dataTypes.mkString("[", ", ", "]")

  (row: InternalRow) => if (row == null) {
    null.asInstanceOf[T]
  } else {
    val convertedArgs =
      converters.zip(row.toSeq(dataTypes)).map { case (converter, arg) => converter(arg) }
    try {
      constructorMirror.apply(convertedArgs: _*).asInstanceOf[T]
    } catch {
      case e: IllegalArgumentException => // argument type mismatch
        val message =
          s"""|Error constructing ${classTag.runtimeClass.getName}: ${e.getMessage};
              |paramTypes: ${paramTypes}, dataTypes: ${dataTypesDescription},
              |convertedArgs: ${convertedArgs}""".stripMargin.replace("\n", " ")
        // ClassCastException has no (message, cause) constructor; attach the cause explicitly
        // so the original reflection failure stays visible in the stack trace.
        val failure = new ClassCastException(message)
        failure.initCause(e)
        throw failure
    }
  }
}
534+
/**
 * Variant of createToScalaConverter(DataType) that is additionally guided by the expected
 * Scala type of the result.
 *
 * The (Scala type, Catalyst type) pair selects the conversion. Within most conversions, a value
 * that is not in the expected Catalyst representation (including null) is returned unchanged.
 *
 * Please keep in sync with createToScalaConverter(DataType) and ScalaReflection.schemaFor[T].
 */
private[sql] def createToScalaConverter(
    universeType: universe.Type, dataType: DataType): Any => Any = {

  import universe._

  (universeType, dataType) match {
    // Option[X]: convert as X, then wrap so that null becomes None.
    case (optionType, innerDataType) if optionType <:< typeOf[Option[_]] =>
      val TypeRef(_, _, Seq(wrappedType)) = optionType
      val innerConverter: Any => Any = createToScalaConverter(wrappedType, innerDataType)
      (catalystValue: Any) => Option(innerConverter(catalystValue))

    // User-defined types deserialize themselves; null stays null.
    case (_, udt: UserDefinedType[_]) =>
      val udtConverter: Any => Any = {
        case null => null
        case datum => udt.deserialize(datum)
      }
      udtConverter

    case (_, _: BinaryType) => identity

    case (arrayScalaType, _: ArrayType) if arrayScalaType <:< typeOf[Array[_]] =>
      throw new UnsupportedOperationException("Array[_] is not supported; try using Seq instead.")

    case (seqType, arrayType: ArrayType) if seqType <:< typeOf[Seq[_]] =>
      val TypeRef(_, _, Seq(elementType)) = seqType
      val elementConverter: Any => Any =
        createToScalaConverter(elementType, arrayType.elementType)
      val seqConverter: Any => Any = {
        case arrayData: ArrayData =>
          arrayData.toArray[Any](arrayType.elementType).map(elementConverter).toSeq
        case other => other
      }
      seqConverter

    case (mapScalaType, mapType: MapType) if mapScalaType <:< typeOf[Map[_, _]] =>
      val TypeRef(_, _, Seq(keyType, valueType)) = mapScalaType
      val keyConverter: Any => Any = createToScalaConverter(keyType, mapType.keyType)
      val valueConverter: Any => Any = createToScalaConverter(valueType, mapType.valueType)
      val mapConverter: Any => Any = {
        case mapData: MapData =>
          val convertedKeys = mapData.keyArray().toArray[Any](mapType.keyType).map(keyConverter)
          val convertedValues =
            mapData.valueArray().toArray[Any](mapType.valueType).map(valueConverter)
          convertedKeys.zip(convertedValues).toMap
        case other => other
      }
      mapConverter

    // Nested Product (e.g. case class): load its class by name and build a row converter.
    case (productType, nestedStruct: StructType) if productType <:< typeOf[Product] =>
      val className = productType.erasure.typeSymbol.asClass.fullName
      if (!Utils.classIsLoadable(className)) {
        throw new IllegalArgumentException(s"$className is not loadable")
      }
      val productTag = scala.reflect.ClassTag(Utils.classForName(className))
      createToProductConverter(productTag, nestedStruct).asInstanceOf[Any => Any]

    case (stringType, StringType) if stringType <:< typeOf[String] =>
      val stringConverter: Any => Any = {
        case utf8: UTF8String => utf8.toString
        case other => other
      }
      stringConverter

    case (dateType, DateType) if dateType <:< typeOf[Date] =>
      val dateConverter: Any => Any = {
        case days: Int => DateTimeUtils.toJavaDate(days)
        case other => other
      }
      dateConverter

    case (timestampType, TimestampType) if timestampType <:< typeOf[Timestamp] =>
      val timestampConverter: Any => Any = {
        case ts: Long => DateTimeUtils.toJavaTimestamp(ts)
        case other => other
      }
      timestampConverter

    case (decimalScalaType, _: DecimalType) if decimalScalaType <:< typeOf[BigDecimal] =>
      val scalaDecimalConverter: Any => Any = {
        case decimal: Decimal => decimal.toBigDecimal
        case other => other
      }
      scalaDecimalConverter

    case (decimalJavaType, _: DecimalType)
        if decimalJavaType <:< typeOf[java.math.BigDecimal] =>
      val javaDecimalConverter: Any => Any = {
        case decimal: Decimal => decimal.toJavaBigDecimal
        case other => other
      }
      javaDecimalConverter

    // Pass non-string primitives through. (Strings are converted from UTF8Strings above.)
    // For everything else, hope for the best.
    case _ => identity
  }
}
457623}
0 commit comments