1717
1818package org .apache .spark .sql .catalyst
1919
20+ import org .apache .spark .sql .catalyst .util .DateTimeUtils
2021import org .apache .spark .unsafe .types .UTF8String
2122import org .apache .spark .util .Utils
2223import org .apache .spark .sql .catalyst .expressions ._
@@ -75,6 +76,242 @@ trait ScalaReflection {
7576 */
private def localTypeOf[T: TypeTag]: `Type` = {
  // Re-home the implicitly supplied tag in this trait's `mirror` before reading its type.
  val tag = typeTag[T]
  tag.in(mirror).tpe
}
7778
/**
 * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping
 * to a native type, an ObjectType is returned. Special handling is also used for Arrays including
 * those that hold primitive types.
 *
 * The lookup runs while holding [[ScalaReflectionLock]], like the other reflection-based
 * methods of this trait.
 *
 * @param tpe the Scala reflection type to map
 * @return a primitive catalyst type for Scala primitives, `BinaryType` for `Array[Byte]`,
 *         and otherwise an `ObjectType` wrapping the runtime class of `tpe`
 * @throws UnsupportedOperationException if the element type of an array maps to a DataType
 *         for which no element class can be determined
 */
def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized {
  tpe match {
    case t if t <:< definitions.IntTpe => IntegerType
    case t if t <:< definitions.LongTpe => LongType
    case t if t <:< definitions.DoubleTpe => DoubleType
    case t if t <:< definitions.FloatTpe => FloatType
    case t if t <:< definitions.ShortTpe => ShortType
    case t if t <:< definitions.ByteTpe => ByteType
    case t if t <:< definitions.BooleanTpe => BooleanType
    case t if t <:< localTypeOf[Array[Byte]] => BinaryType
    case _ =>
      val className: String = tpe.erasure.typeSymbol.asClass.fullName
      className match {
        case "scala.Array" =>
          val TypeRef(_, _, Seq(arrayType)) = tpe
          val cls = arrayType match {
            case t if t <:< definitions.IntTpe => classOf[Array[Int]]
            case t if t <:< definitions.LongTpe => classOf[Array[Long]]
            case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
            case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
            case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
            case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
            case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
            case other =>
              // Primitive element types were handled above, so the recursive call yields
              // either an ObjectType or, for a nested Array[Byte], BinaryType. The latter
              // has no `cls`, so it is mapped to its array class explicitly instead of
              // being cast (which would fail with a ClassCastException).
              val elementClass: Class[_] = dataTypeFor(other) match {
                case objectType: ObjectType => objectType.cls
                case BinaryType => classOf[Array[Byte]]
                case unexpected =>
                  throw new UnsupportedOperationException(
                    s"Array element type $other (mapped to $unexpected) is not supported")
              }
              // There is probably a better way to do this, but I couldn't find it...
              java.lang.reflect.Array.newInstance(elementClass, 1).getClass
          }
          ObjectType(cls)
        case _ => ObjectType(Utils.classForName(className))
      }
  }
}
116+
/**
 * Returns expressions for extracting all the fields from the given type.
 *
 * The extractor built for `T` is cast to a `CreateStruct` and its children are returned;
 * the cast fails if the extractor for `T` is any other kind of expression.
 */
def extractorsFor[T: TypeTag](inputObject: Expression): Seq[Expression] = {
  ScalaReflectionLock.synchronized {
    val structExtractor = extractorFor(inputObject, typeTag[T].tpe)
    structExtractor.asInstanceOf[CreateStruct].children
  }
}
123+
/**
 * Helper for extracting internal fields from a case class.
 *
 * Builds an extraction expression for the value produced by `inputObject` by dispatching on
 * `tpe`. When `inputObject.dataType` is not an `ObjectType`, `inputObject` is returned
 * unchanged. The whole method runs while holding [[ScalaReflectionLock]].
 *
 * The collection cases (Array, Seq, Map) are tested before the Product case:
 * `scala.collection.immutable.List` also extends `Product`, so testing Product first would
 * treat a `List` field as if it were a case class.
 *
 * @param inputObject expression producing the value to extract from
 * @param tpe the Scala reflection type of the value produced by `inputObject`
 * @return the extraction expression built for `tpe`
 * @throws UnsupportedOperationException if no case matches `tpe`
 */
protected def extractorFor(
    inputObject: Expression,
    tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
  if (!inputObject.dataType.isInstanceOf[ObjectType]) {
    inputObject
  } else {
    tpe match {
      case t if t <:< localTypeOf[Option[_]] =>
        val TypeRef(_, _, Seq(optType)) = t

        // For primitive types we must manually unbox the value of the object.
        def unboxed(boxedClass: Class[_], accessor: String, dataType: DataType): Expression =
          Invoke(UnwrapOption(ObjectType(boxedClass), inputObject), accessor, dataType)

        optType match {
          case o if o <:< definitions.IntTpe =>
            unboxed(classOf[java.lang.Integer], "intValue", IntegerType)
          case o if o <:< definitions.LongTpe =>
            unboxed(classOf[java.lang.Long], "longValue", LongType)
          case o if o <:< definitions.DoubleTpe =>
            unboxed(classOf[java.lang.Double], "doubleValue", DoubleType)
          case o if o <:< definitions.FloatTpe =>
            unboxed(classOf[java.lang.Float], "floatValue", FloatType)
          case o if o <:< definitions.ShortTpe =>
            unboxed(classOf[java.lang.Short], "shortValue", ShortType)
          case o if o <:< definitions.ByteTpe =>
            unboxed(classOf[java.lang.Byte], "byteValue", ByteType)
          case o if o <:< definitions.BooleanTpe =>
            unboxed(classOf[java.lang.Boolean], "booleanValue", BooleanType)

          // For non-primitives, we can just extract the object from the Option and then recurse.
          case _ =>
            val className: String = optType.erasure.typeSymbol.asClass.fullName
            val classObj = Utils.classForName(className)
            val optionObjectType = ObjectType(classObj)

            val unwrapped = UnwrapOption(optionObjectType, inputObject)
            expressions.If(
              IsNull(unwrapped),
              expressions.Literal.create(null, schemaFor(optType).dataType),
              extractorFor(unwrapped, optType))
        }

      // Arrays and Seqs share the same handling: elements whose DataType is not an AtomicType
      // are converted one by one, otherwise the collection is wrapped as a whole.
      case t if t <:< localTypeOf[Array[_]] || t <:< localTypeOf[Seq[_]] =>
        val TypeRef(_, _, Seq(elementType)) = t
        val elementDataType = dataTypeFor(elementType)
        val Schema(dataType, nullable) = schemaFor(elementType)

        if (!elementDataType.isInstanceOf[AtomicType]) {
          MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
        } else {
          NewInstance(
            classOf[GenericArrayData],
            inputObject :: Nil,
            dataType = ArrayType(dataType, nullable))
        }

      case t if t <:< localTypeOf[Map[_, _]] =>
        val TypeRef(_, _, Seq(keyType, valueType)) = t
        val Schema(keyDataType, _) = schemaFor(keyType)
        val Schema(valueDataType, valueNullable) = schemaFor(valueType)

        val keys =
          NewInstance(
            classOf[GenericArrayData],
            Invoke(inputObject, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil,
            dataType = ObjectType(classOf[ArrayData]))
        val values =
          NewInstance(
            classOf[GenericArrayData],
            Invoke(inputObject, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil,
            dataType = ObjectType(classOf[ArrayData]))
        NewInstance(
          classOf[ArrayBasedMapData],
          keys :: values :: Nil,
          dataType = MapType(keyDataType, valueDataType, valueNullable))

      case t if t <:< localTypeOf[String] =>
        StaticInvoke(
          classOf[UTF8String],
          StringType,
          "fromString",
          inputObject :: Nil)

      case t if t <:< localTypeOf[java.sql.Timestamp] =>
        StaticInvoke(
          DateTimeUtils,
          TimestampType,
          "fromJavaTimestamp",
          inputObject :: Nil)

      case t if t <:< localTypeOf[java.sql.Date] =>
        StaticInvoke(
          DateTimeUtils,
          DateType,
          "fromJavaDate",
          inputObject :: Nil)

      case t if t <:< localTypeOf[BigDecimal] =>
        StaticInvoke(
          Decimal,
          DecimalType.SYSTEM_DEFAULT,
          "apply",
          inputObject :: Nil)

      case t if t <:< localTypeOf[java.math.BigDecimal] =>
        StaticInvoke(
          Decimal,
          DecimalType.SYSTEM_DEFAULT,
          "apply",
          inputObject :: Nil)

      case t if t <:< localTypeOf[java.lang.Integer] =>
        Invoke(inputObject, "intValue", IntegerType)
      case t if t <:< localTypeOf[java.lang.Long] =>
        Invoke(inputObject, "longValue", LongType)
      case t if t <:< localTypeOf[java.lang.Double] =>
        Invoke(inputObject, "doubleValue", DoubleType)
      case t if t <:< localTypeOf[java.lang.Float] =>
        Invoke(inputObject, "floatValue", FloatType)
      case t if t <:< localTypeOf[java.lang.Short] =>
        Invoke(inputObject, "shortValue", ShortType)
      case t if t <:< localTypeOf[java.lang.Byte] =>
        Invoke(inputObject, "byteValue", ByteType)
      case t if t <:< localTypeOf[java.lang.Boolean] =>
        Invoke(inputObject, "booleanValue", BooleanType)

      // Deliberately tested after the collection cases above (see the method comment).
      case t if t <:< localTypeOf[Product] =>
        val formalTypeArgs = t.typeSymbol.asClass.typeParams
        val TypeRef(_, _, actualTypeArgs) = t
        val constructorSymbol = t.member(nme.CONSTRUCTOR)
        val params = if (constructorSymbol.isMethod) {
          constructorSymbol.asMethod.paramss
        } else {
          // Find the primary constructor, and use its parameter ordering.
          val primaryConstructorSymbol: Symbol =
            constructorSymbol.asTerm.alternatives
              .find(s => s.isMethod && s.asMethod.isPrimaryConstructor)
              .getOrElse {
                sys.error("Internal SQL error: Product object did not have a primary constructor.")
              }
          primaryConstructorSymbol.asMethod.paramss
        }

        // One struct field per parameter of the first constructor parameter list; each field
        // is read by invoking the accessor named after the parameter, then extracted
        // recursively with the class's type parameters substituted by the actual arguments.
        CreateStruct(params.head.map { p =>
          val fieldName = p.name.toString
          val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
          val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
          extractorFor(fieldValue, fieldType)
        })

      case other =>
        throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
    }
  }
}
314+
78315 /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
79316 def schemaFor (tpe : `Type`): Schema = ScalaReflectionLock .synchronized {
80317 val className : String = tpe.erasure.typeSymbol.asClass.fullName
@@ -91,7 +328,6 @@ trait ScalaReflection {
91328 case t if t <:< localTypeOf[Option [_]] =>
92329 val TypeRef (_, _, Seq (optType)) = t
93330 Schema (schemaFor(optType).dataType, nullable = true )
94- // Need to decide if we actually need a special type here.
95331 case t if t <:< localTypeOf[Array [Byte ]] => Schema (BinaryType , nullable = true )
96332 case t if t <:< localTypeOf[Array [_]] =>
97333 val TypeRef (_, _, Seq (elementType)) = t
0 commit comments