[SPARK-10993] [SQL] Initial code generated encoder for product types #9019
Changes to `trait ScalaReflection` (package `org.apache.spark.sql.catalyst`):

@@ -17,6 +17,7 @@

```scala
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.expressions._
```
@@ -75,6 +76,242 @@ trait ScalaReflection

```scala
   */
  private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe

  /**
   * Returns the Spark SQL DataType for a given Scala type. Where there is no exact mapping
   * to a native type, an ObjectType is returned instead. Arrays, including those that hold
   * primitive types, receive special handling.
   */
  def dataTypeFor(tpe: `Type`): DataType = tpe match {
    case t if t <:< definitions.IntTpe => IntegerType
    case t if t <:< definitions.LongTpe => LongType
    case t if t <:< definitions.DoubleTpe => DoubleType
    case t if t <:< definitions.FloatTpe => FloatType
    case t if t <:< definitions.ShortTpe => ShortType
    case t if t <:< definitions.ByteTpe => ByteType
    case t if t <:< definitions.BooleanTpe => BooleanType
    case t if t <:< localTypeOf[Array[Byte]] => BinaryType
    case _ =>
      val className: String = tpe.erasure.typeSymbol.asClass.fullName
```
**Contributor (author):** Haha, probably too late to change this now :)
```scala
      className match {
        case "scala.Array" =>
          val TypeRef(_, _, Seq(arrayType)) = tpe
          val cls = arrayType match {
            case t if t <:< definitions.IntTpe => classOf[Array[Int]]
            case t if t <:< definitions.LongTpe => classOf[Array[Long]]
            case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
            case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
            case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
            case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
            case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
            case other =>
              // There is probably a better way to do this, but I couldn't find it...
              val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
              java.lang.reflect.Array.newInstance(elementType, 1).getClass
          }
          ObjectType(cls)
        case other => ObjectType(Utils.classForName(className))
      }
  }
```
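To make the mapping concrete, here is a small sketch of the results the cases above produce for a few representative types. It is written as if evaluated inside `ScalaReflection`, since `localTypeOf` is private to the trait; the expected values follow directly from the code shown.

```scala
// Expected results of dataTypeFor for representative types (sketch,
// evaluated conceptually inside the trait):
assert(dataTypeFor(localTypeOf[Int]) == IntegerType)
assert(dataTypeFor(localTypeOf[Array[Byte]]) == BinaryType)              // special-cased
assert(dataTypeFor(localTypeOf[Array[Int]]) == ObjectType(classOf[Array[Int]]))
assert(dataTypeFor(localTypeOf[String]) == ObjectType(classOf[String])) // no native mapping
```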
```scala
  /** Returns expressions for extracting all the fields from the given type. */
  def extractorsFor[T : TypeTag](inputObject: Expression): Seq[Expression] = {
    ScalaReflectionLock.synchronized {
      extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateStruct].children
    }
  }

  /** Helper for extracting internal fields from a case class. */
  protected def extractorFor(
      inputObject: Expression,
      tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
    if (!inputObject.dataType.isInstanceOf[ObjectType]) {
      inputObject
    } else {
      tpe match {
        case t if t <:< localTypeOf[Option[_]] =>
          val TypeRef(_, _, Seq(optType)) = t
          optType match {
            // For primitive types we must manually unbox the value of the object.
            case t if t <:< definitions.IntTpe =>
              Invoke(
                UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject),
                "intValue",
                IntegerType)
            case t if t <:< definitions.LongTpe =>
              Invoke(
                UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject),
                "longValue",
                LongType)
            case t if t <:< definitions.DoubleTpe =>
              Invoke(
                UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject),
                "doubleValue",
                DoubleType)
            case t if t <:< definitions.FloatTpe =>
              Invoke(
                UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject),
                "floatValue",
                FloatType)
            case t if t <:< definitions.ShortTpe =>
              Invoke(
                UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject),
                "shortValue",
                ShortType)
            case t if t <:< definitions.ByteTpe =>
              Invoke(
                UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject),
                "byteValue",
                ByteType)
            case t if t <:< definitions.BooleanTpe =>
              Invoke(
                UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject),
                "booleanValue",
                BooleanType)

            // For non-primitives, we can just extract the object from the Option and then recurse.
            case other =>
              val className: String = optType.erasure.typeSymbol.asClass.fullName
              val classObj = Utils.classForName(className)
              val optionObjectType = ObjectType(classObj)

              val unwrapped = UnwrapOption(optionObjectType, inputObject)
              expressions.If(
                IsNull(unwrapped),
                expressions.Literal.create(null, schemaFor(optType).dataType),
                extractorFor(unwrapped, optType))
          }

        case t if t <:< localTypeOf[Product] =>
          val formalTypeArgs = t.typeSymbol.asClass.typeParams
          val TypeRef(_, _, actualTypeArgs) = t
          val constructorSymbol = t.member(nme.CONSTRUCTOR)
          val params = if (constructorSymbol.isMethod) {
            constructorSymbol.asMethod.paramss
          } else {
            // Find the primary constructor, and use its parameter ordering.
            val primaryConstructorSymbol: Option[Symbol] =
              constructorSymbol.asTerm.alternatives.find(s =>
                s.isMethod && s.asMethod.isPrimaryConstructor)

            if (primaryConstructorSymbol.isEmpty) {
              sys.error("Internal SQL error: Product object did not have a primary constructor.")
            } else {
              primaryConstructorSymbol.get.asMethod.paramss
            }
          }

          CreateStruct(params.head.map { p =>
```
**Contributor:** Maybe explain why we need to use […]

**Contributor (author):** Haha, this was copied from code that I copied from a Stack Overflow article a long time ago.
```scala
            val fieldName = p.name.toString
            val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
            val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
            extractorFor(fieldValue, fieldType)
          })

        case t if t <:< localTypeOf[Array[_]] =>
          val TypeRef(_, _, Seq(elementType)) = t
          val elementDataType = dataTypeFor(elementType)
          val Schema(dataType, nullable) = schemaFor(elementType)

          if (!elementDataType.isInstanceOf[AtomicType]) {
            MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
          } else {
            NewInstance(
              classOf[GenericArrayData],
              inputObject :: Nil,
              dataType = ArrayType(dataType, nullable))
          }

        case t if t <:< localTypeOf[Seq[_]] =>
          val TypeRef(_, _, Seq(elementType)) = t
          val elementDataType = dataTypeFor(elementType)
          val Schema(dataType, nullable) = schemaFor(elementType)

          if (!elementDataType.isInstanceOf[AtomicType]) {
            MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
          } else {
            NewInstance(
              classOf[GenericArrayData],
              inputObject :: Nil,
              dataType = ArrayType(dataType, nullable))
          }

        case t if t <:< localTypeOf[Map[_, _]] =>
          val TypeRef(_, _, Seq(keyType, valueType)) = t
          val Schema(keyDataType, _) = schemaFor(keyType)
          val Schema(valueDataType, valueNullable) = schemaFor(valueType)

          val rawMap = inputObject
          val keys =
            NewInstance(
              classOf[GenericArrayData],
              Invoke(rawMap, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil,
              dataType = ObjectType(classOf[ArrayData]))
          val values =
            NewInstance(
              classOf[GenericArrayData],
              Invoke(rawMap, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil,
              dataType = ObjectType(classOf[ArrayData]))
          NewInstance(
            classOf[ArrayBasedMapData],
            keys :: values :: Nil,
            dataType = MapType(keyDataType, valueDataType, valueNullable))
```
**Contributor:** Do we need to use […]
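For orientation, the Map case above amounts to splitting the map into parallel key and value collections and pairing them back up by index. Roughly, in plain Scala (a sketch of the runtime behavior, not the generated code):

```scala
// Conceptual runtime behavior of the Map extractor above (sketch):
val m = Map("a" -> 1, "b" -> 2)
val keys: Seq[Any] = m.keys.toSeq     // becomes one GenericArrayData
val values: Seq[Any] = m.values.toSeq // becomes another GenericArrayData
// ArrayBasedMapData then pairs keys(i) with values(i) to form the map value.
```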
```scala
        case t if t <:< localTypeOf[String] =>
          StaticInvoke(
            classOf[UTF8String],
            StringType,
            "fromString",
            inputObject :: Nil)

        case t if t <:< localTypeOf[java.sql.Timestamp] =>
          StaticInvoke(
            DateTimeUtils,
            TimestampType,
            "fromJavaTimestamp",
            inputObject :: Nil)

        case t if t <:< localTypeOf[java.sql.Date] =>
          StaticInvoke(
            DateTimeUtils,
            DateType,
            "fromJavaDate",
            inputObject :: Nil)

        case t if t <:< localTypeOf[BigDecimal] =>
          StaticInvoke(
            Decimal,
            DecimalType.SYSTEM_DEFAULT,
            "apply",
            inputObject :: Nil)

        case t if t <:< localTypeOf[java.math.BigDecimal] =>
          StaticInvoke(
            Decimal,
            DecimalType.SYSTEM_DEFAULT,
            "apply",
            inputObject :: Nil)

        case t if t <:< localTypeOf[java.lang.Integer] =>
          Invoke(inputObject, "intValue", IntegerType)
        case t if t <:< localTypeOf[java.lang.Long] =>
          Invoke(inputObject, "longValue", LongType)
        case t if t <:< localTypeOf[java.lang.Double] =>
          Invoke(inputObject, "doubleValue", DoubleType)
        case t if t <:< localTypeOf[java.lang.Float] =>
          Invoke(inputObject, "floatValue", FloatType)
        case t if t <:< localTypeOf[java.lang.Short] =>
          Invoke(inputObject, "shortValue", ShortType)
        case t if t <:< localTypeOf[java.lang.Byte] =>
          Invoke(inputObject, "byteValue", ByteType)
        case t if t <:< localTypeOf[java.lang.Boolean] =>
          Invoke(inputObject, "booleanValue", BooleanType)

        case other =>
          throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
      }
    }
  }
```
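Putting the extractor logic together, a hypothetical case class (`Person` here is illustrative, not part of this PR) would yield expressions along these lines:

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.ObjectType

case class Person(name: String, age: Int)

val input = BoundReference(0, ObjectType(classOf[Person]), nullable = true)
val extractors = ScalaReflection.extractorsFor[Person](input)
// Per the cases above, roughly:
//   extractors(0) == StaticInvoke(classOf[UTF8String], StringType, "fromString",
//                      Invoke(input, "name", ObjectType(classOf[String])) :: Nil)
//   extractors(1) == Invoke(input, "age", IntegerType)
```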
```scala
  /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
  def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
    val className: String = tpe.erasure.typeSymbol.asClass.fullName
```

@@ -91,7 +328,6 @@ trait ScalaReflection

```scala
    case t if t <:< localTypeOf[Option[_]] =>
      val TypeRef(_, _, Seq(optType)) = t
      Schema(schemaFor(optType).dataType, nullable = true)
    // Need to decide if we actually need a special type here.
    case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
    case t if t <:< localTypeOf[Array[_]] =>
      val TypeRef(_, _, Seq(elementType)) = t
```
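For reference, the hunk above implies results like the following (a sketch, with the same in-trait caveat as earlier; nullability comes from the Option and Array cases shown):

```scala
// Indicative schemaFor results (sketch):
assert(schemaFor(localTypeOf[Option[Int]]) == Schema(IntegerType, nullable = true))
assert(schemaFor(localTypeOf[Array[Byte]]) == Schema(BinaryType, nullable = true))
```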
New file in `org.apache.spark.sql.catalyst.encoders`:

@@ -0,0 +1,44 @@

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.encoders

import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType

/**
 * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
 *
 * Encoders are not intended to be thread-safe and thus they are allowed to avoid internal
 * locking and reuse internal buffers to improve performance.
 */
trait Encoder[T] {
  /** Returns the schema of encoding this type of object as a Row. */
  def schema: StructType

  /** A ClassTag that can be used to construct an Array to contain a collection of `T`. */
  def clsTag: ClassTag[T]

  /**
   * Returns an encoded version of `t` as a Spark SQL row.  Note that multiple calls to
   * toRow are allowed to return the same actual [[InternalRow]] object.  Thus, the caller
   * should copy the result before making another call if required.
   */
  def toRow(t: T): InternalRow
}
```
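Because `toRow` may return the same underlying `InternalRow` across calls, callers that retain results must copy them. A minimal sketch of a caller honoring that contract (`encodeAll` is a hypothetical helper, not part of this PR):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.Encoder

// Hypothetical helper: encode a batch, copying each row because the
// encoder is allowed to reuse its internal buffer between toRow calls.
def encodeAll[T](enc: Encoder[T], items: Seq[T]): Seq[InternalRow] =
  items.map(item => enc.toRow(item).copy())
```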
New file in `org.apache.spark.sql.catalyst.encoders`:

@@ -0,0 +1,67 @@

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.encoders

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}

import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow}
import org.apache.spark.sql.types.{ObjectType, StructType}

/**
 * A factory for constructing encoders that convert Scala product types to/from the Spark SQL
 * internal binary representation.
 */
object ProductEncoder {
  def apply[T <: Product : TypeTag]: Encoder[T] = {
    // We convert the non-serializable TypeTag into a StructType and a ClassTag.
    val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType]
    val mirror = typeTag[T].mirror
    val cls = mirror.runtimeClass(typeTag[T].tpe)

    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
    val extractExpressions = ScalaReflection.extractorsFor[T](inputObject)
    new ClassEncoder[T](schema, extractExpressions, ClassTag[T](cls))
  }
}

/**
 * A generic encoder for JVM objects.
 *
 * @param schema The schema after converting `T` to a Spark SQL row.
 * @param extractExpressions A set of expressions, one for each top-level field, that can be
 *                           used to extract the values from a raw object.
 * @param clsTag A ClassTag for `T`.
 */
case class ClassEncoder[T](
    schema: StructType,
    extractExpressions: Seq[Expression],
    clsTag: ClassTag[T])
  extends Encoder[T] {

  private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
  private val inputRow = new GenericMutableRow(1)

  override def toRow(t: T): InternalRow = {
    inputRow(0) = t
    extractProjection(inputRow)
  }
}
```
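Tying the pieces together, a minimal end-to-end sketch (`Person` is again a hypothetical case class; the annotated shapes follow from `schemaFor` and the generated unsafe projection):

```scala
import org.apache.spark.sql.catalyst.encoders.{Encoder, ProductEncoder}

case class Person(name: String, age: Int)

val enc: Encoder[Person] = ProductEncoder[Person]
// enc.schema: StructType with fields "name" (StringType) and "age" (IntegerType)

val row = enc.toRow(Person("Alice", 30))
// `row` is backed by the encoder's reused projection output, so copy it
// before calling toRow again if the value must be retained:
val kept = row.copy()
```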
**Review comment:** What is the difference between this method and `def schemaFor(tpe: Type): Schema`?