Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -75,6 +76,242 @@ trait ScalaReflection {
*/
private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe

/**
 * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping
 * to a native type, an ObjectType is returned. Special handling is also used for Arrays including
 * those that hold primitive types.
 *
 * Unlike `schemaFor`, this keeps non-native JVM types opaque (as `ObjectType`) instead of
 * recursively converting them into Catalyst struct types.
 */
def dataTypeFor(tpe: `Type`): DataType = tpe match {
  case t if t <:< definitions.IntTpe => IntegerType
  case t if t <:< definitions.LongTpe => LongType
  case t if t <:< definitions.DoubleTpe => DoubleType
  case t if t <:< definitions.FloatTpe => FloatType
  case t if t <:< definitions.ShortTpe => ShortType
  case t if t <:< definitions.ByteTpe => ByteType
  case t if t <:< definitions.BooleanTpe => BooleanType
  case t if t <:< localTypeOf[Array[Byte]] => BinaryType
  case _ =>
    val className: String = tpe.erasure.typeSymbol.asClass.fullName
    className match {
      case "scala.Array" =>
        // Primitive array classes cannot be obtained via Utils.classForName, so each
        // primitive element type is resolved to its runtime array class explicitly.
        val TypeRef(_, _, Seq(arrayType)) = tpe
        val cls = arrayType match {
          case t if t <:< definitions.IntTpe => classOf[Array[Int]]
          case t if t <:< definitions.LongTpe => classOf[Array[Long]]
          case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
          case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
          case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
          case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
          case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
          case other =>
            // For non-primitive element types, build a one-element array instance to
            // recover the runtime array class for that element type.
            val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
            java.lang.reflect.Array.newInstance(elementType, 1).getClass
        }
        ObjectType(cls)
      // Was `case other =>` with an unused binding; the wildcard makes that explicit.
      case _ => ObjectType(Utils.classForName(className))
    }
}

/** Produces one extraction expression per top-level field of the given type `T`. */
def extractorsFor[T : TypeTag](inputObject: Expression): Seq[Expression] = {
  ScalaReflectionLock.synchronized {
    // The top-level extractor for a product type is always a CreateStruct whose
    // children are the per-field extractors.
    val struct = extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateStruct]
    struct.children
  }
}

/**
 * Helper for extracting internal fields from a case class.
 *
 * Given an expression producing a JVM object of type `tpe`, returns an expression that
 * evaluates to the Catalyst internal representation of that object. Inputs whose dataType
 * is already a native Catalyst type (i.e. not an ObjectType) are passed through unchanged.
 */
protected def extractorFor(
    inputObject: Expression,
    tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
  if (!inputObject.dataType.isInstanceOf[ObjectType]) {
    // Already a native Catalyst value; nothing to extract.
    inputObject
  } else {
    tpe match {
      case t if t <:< localTypeOf[Option[_]] =>
        val TypeRef(_, _, Seq(optType)) = t
        optType match {
          // For primitive types we must manually unbox the value of the object.
          case t if t <:< definitions.IntTpe =>
            Invoke(
              UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject),
              "intValue",
              IntegerType)
          case t if t <:< definitions.LongTpe =>
            Invoke(
              UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject),
              "longValue",
              LongType)
          case t if t <:< definitions.DoubleTpe =>
            Invoke(
              UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject),
              "doubleValue",
              DoubleType)
          case t if t <:< definitions.FloatTpe =>
            Invoke(
              UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject),
              "floatValue",
              FloatType)
          case t if t <:< definitions.ShortTpe =>
            Invoke(
              UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject),
              "shortValue",
              ShortType)
          case t if t <:< definitions.ByteTpe =>
            Invoke(
              UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject),
              "byteValue",
              ByteType)
          case t if t <:< definitions.BooleanTpe =>
            Invoke(
              UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject),
              "booleanValue",
              BooleanType)

          // For non-primitives, we can just extract the object from the Option and recurse.
          case _ =>
            val className: String = optType.erasure.typeSymbol.asClass.fullName
            val classObj = Utils.classForName(className)
            val optionObjectType = ObjectType(classObj)

            val unwrapped = UnwrapOption(optionObjectType, inputObject)
            expressions.If(
              IsNull(unwrapped),
              expressions.Literal.create(null, schemaFor(optType).dataType),
              extractorFor(unwrapped, optType))
        }

      // NOTE(review): some Seq implementations (e.g. scala.List) also extend Product and will
      // therefore be handled by this branch instead of the Seq branch below — confirm intended.
      case t if t <:< localTypeOf[Product] =>
        val formalTypeArgs = t.typeSymbol.asClass.typeParams
        val TypeRef(_, _, actualTypeArgs) = t
        val constructorSymbol = t.member(nme.CONSTRUCTOR)
        val params = if (constructorSymbol.isMethod) {
          constructorSymbol.asMethod.paramss
        } else {
          // Multiple constructors exist: find the primary one and use its parameter ordering.
          val primaryConstructorSymbol: Option[Symbol] =
            constructorSymbol.asTerm.alternatives.find(s =>
              s.isMethod && s.asMethod.isPrimaryConstructor)

          if (primaryConstructorSymbol.isEmpty) {
            sys.error("Internal SQL error: Product object did not have a primary constructor.")
          } else {
            primaryConstructorSymbol.get.asMethod.paramss
          }
        }

        // `paramss` groups parameters by parameter list; only the first parameter list of a
        // case class constructor defines its fields, hence `params.head`.
        CreateStruct(params.head.map { p =>
          val fieldName = p.name.toString
          val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
          val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
          extractorFor(fieldValue, fieldType)
        })

      // Arrays and Seqs are handled identically (previously two duplicated branches):
      // complex element types are converted element-by-element via MapObjects, while
      // collections of atomic element types can be wrapped directly.
      case t if t <:< localTypeOf[Array[_]] || t <:< localTypeOf[Seq[_]] =>
        val TypeRef(_, _, Seq(elementType)) = t
        val elementDataType = dataTypeFor(elementType)
        val Schema(dataType, nullable) = schemaFor(elementType)

        if (!elementDataType.isInstanceOf[AtomicType]) {
          MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
        } else {
          NewInstance(
            classOf[GenericArrayData],
            inputObject :: Nil,
            dataType = ArrayType(dataType, nullable))
        }

      case t if t <:< localTypeOf[Map[_, _]] =>
        val TypeRef(_, _, Seq(keyType, valueType)) = t
        val Schema(keyDataType, _) = schemaFor(keyType)
        val Schema(valueDataType, valueNullable) = schemaFor(valueType)

        // Convert keys and values separately, then combine into an ArrayBasedMapData.
        val rawMap = inputObject
        val keys =
          NewInstance(
            classOf[GenericArrayData],
            Invoke(rawMap, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil,
            dataType = ObjectType(classOf[ArrayData]))
        val values =
          NewInstance(
            classOf[GenericArrayData],
            Invoke(rawMap, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil,
            dataType = ObjectType(classOf[ArrayData]))
        NewInstance(
          classOf[ArrayBasedMapData],
          keys :: values :: Nil,
          dataType = MapType(keyDataType, valueDataType, valueNullable))

      case t if t <:< localTypeOf[String] =>
        StaticInvoke(
          classOf[UTF8String],
          StringType,
          "fromString",
          inputObject :: Nil)

      case t if t <:< localTypeOf[java.sql.Timestamp] =>
        StaticInvoke(
          DateTimeUtils,
          TimestampType,
          "fromJavaTimestamp",
          inputObject :: Nil)

      case t if t <:< localTypeOf[java.sql.Date] =>
        StaticInvoke(
          DateTimeUtils,
          DateType,
          "fromJavaDate",
          inputObject :: Nil)

      case t if t <:< localTypeOf[BigDecimal] =>
        StaticInvoke(
          Decimal,
          DecimalType.SYSTEM_DEFAULT,
          "apply",
          inputObject :: Nil)

      case t if t <:< localTypeOf[java.math.BigDecimal] =>
        StaticInvoke(
          Decimal,
          DecimalType.SYSTEM_DEFAULT,
          "apply",
          inputObject :: Nil)

      // Boxed Java primitives only need to be unboxed.
      case t if t <:< localTypeOf[java.lang.Integer] =>
        Invoke(inputObject, "intValue", IntegerType)
      case t if t <:< localTypeOf[java.lang.Long] =>
        Invoke(inputObject, "longValue", LongType)
      case t if t <:< localTypeOf[java.lang.Double] =>
        Invoke(inputObject, "doubleValue", DoubleType)
      case t if t <:< localTypeOf[java.lang.Float] =>
        Invoke(inputObject, "floatValue", FloatType)
      case t if t <:< localTypeOf[java.lang.Short] =>
        Invoke(inputObject, "shortValue", ShortType)
      case t if t <:< localTypeOf[java.lang.Byte] =>
        Invoke(inputObject, "byteValue", ByteType)
      case t if t <:< localTypeOf[java.lang.Boolean] =>
        Invoke(inputObject, "booleanValue", BooleanType)

      case other =>
        throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
    }
  }
}

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
val className: String = tpe.erasure.typeSymbol.asClass.fullName
Expand All @@ -91,7 +328,6 @@ trait ScalaReflection {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.encoders

import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType

/**
 * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
 *
 * Encoders are not intended to be thread-safe and thus they are allowed to avoid internal
 * locking and reuse internal buffers to improve performance.
 */
trait Encoder[T] {
  /** Returns the schema of encoding this type of object as a Row. */
  def schema: StructType

  /** A ClassTag that can be used to construct an Array to contain a collection of `T`. */
  def clsTag: ClassTag[T]

  /**
   * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
   * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
   * copy the result before making another call if required.
   */
  def toRow(t: T): InternalRow
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.encoders

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}

import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow}
import org.apache.spark.sql.types.{ObjectType, StructType}

/**
 * A factory for constructing encoders that convert Scala's product type to/from the Spark SQL
 * internal binary representation.
 */
object ProductEncoder {
  def apply[T <: Product : TypeTag]: Encoder[T] = {
    val tag = typeTag[T]
    // TypeTags are not serializable, so we eagerly derive a StructType and a ClassTag
    // from the tag instead of capturing it.
    val runtimeClass = tag.mirror.runtimeClass(tag.tpe)
    val structType = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType]
    // Slot 0 of the input row holds the raw object to encode.
    val boundInput = BoundReference(0, ObjectType(runtimeClass), nullable = true)
    new ClassEncoder[T](
      structType,
      ScalaReflection.extractorsFor[T](boundInput),
      ClassTag[T](runtimeClass))
  }
}

/**
 * A generic encoder for JVM objects.
 *
 * @param schema The schema after converting `T` to a Spark SQL row.
 * @param extractExpressions A set of expressions, one for each top-level field that can be used to
 *                           extract the values from a raw object.
 * @param clsTag A classtag for `T`.
 */
case class ClassEncoder[T](
    schema: StructType,
    extractExpressions: Seq[Expression],
    clsTag: ClassTag[T])
  extends Encoder[T] {

  // The generated projection and the single-slot input row are reused across calls;
  // Encoder's contract (not thread-safe, result may be reused) makes this legal.
  private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
  private val inputRow = new GenericMutableRow(1)

  override def toRow(t: T): InternalRow = {
    inputRow.update(0, t)
    extractProjection.apply(inputRow)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ class CodeGenContext {
case _: MapType => "MapData"
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
case ObjectType(cls) => cls.getName
case _ => "Object"
}

Expand Down Expand Up @@ -395,7 +397,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin

logDebug({
// Only add extra debugging info to byte code when we are going to print the source code.
evaluator.setDebuggingInformation(false, true, false)
evaluator.setDebuggingInformation(true, true, false)
withLineNums
})

Expand Down
Loading