
Commit 9e66a53

marmbrus authored and yhuai committed
[SPARK-10993] [SQL] Initial code generated encoder for product types

This PR is a first cut at code generating an encoder that takes a Scala `Product` type and converts it directly into the Tungsten binary format. This is done through the addition of a new set of expressions that can be used to invoke methods on raw JVM objects, extracting fields and converting the result into the required format. These can then be used directly in an `UnsafeProjection`, allowing us to leverage the existing encoding logic. According to some simple benchmarks, this can significantly speed up conversion (~4x). However, replacing CatalystConverters is deferred to a later PR to keep this PR at a reasonable size.

```scala
case class SomeInts(a: Int, b: Int, c: Int, d: Int, e: Int)

val data = SomeInts(1, 2, 3, 4, 5)
val encoder = ProductEncoder[SomeInts]
val converter = CatalystTypeConverters.createToCatalystConverter(
  ScalaReflection.schemaFor[SomeInts].dataType)

(1 to 5).foreach { iter =>
  benchmark(s"converter $iter") {
    var i = 100000000
    while (i > 0) {
      val res = converter(data).asInstanceOf[InternalRow]
      assert(res.getInt(0) == 1)
      assert(res.getInt(1) == 2)
      i -= 1
    }
  }

  benchmark(s"encoder $iter") {
    var i = 100000000
    while (i > 0) {
      val res = encoder.toRow(data)
      assert(res.getInt(0) == 1)
      assert(res.getInt(1) == 2)
      i -= 1
    }
  }
}
```

Results:

```
[info] converter 1: 7170ms
[info] encoder 1: 1888ms
[info] converter 2: 6763ms
[info] encoder 2: 1824ms
[info] converter 3: 6912ms
[info] encoder 3: 1802ms
[info] converter 4: 7131ms
[info] encoder 4: 1798ms
[info] converter 5: 7350ms
[info] encoder 5: 1912ms
```

Author: Michael Armbrust <[email protected]>

Closes #9019 from marmbrus/productEncoder.
1 parent a8226a9 commit 9e66a53
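
For intuition about the new expressions, the extractor for a single `String` field works out to an expression tree roughly like the one below. This is a hedged sketch rather than code from the patch; `Person` is a hypothetical case class, while `BoundReference`, `Invoke`, and `StaticInvoke` are the expressions this commit introduces and uses.

```scala
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Invoke, StaticInvoke}
import org.apache.spark.sql.types.{ObjectType, StringType}
import org.apache.spark.unsafe.types.UTF8String

case class Person(name: String) // hypothetical example class

// Slot 0 of the input row holds the raw JVM object.
val person = BoundReference(0, ObjectType(classOf[Person]), nullable = true)

// Conceptually: UTF8String.fromString(person.name)
val extractName =
  StaticInvoke(
    classOf[UTF8String], // static target
    StringType,          // result type
    "fromString",        // method to call
    Invoke(person, "name", ObjectType(classOf[String])) :: Nil)
```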

File tree

8 files changed: +910 -2 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 237 additions & 1 deletion
```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst
 
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 import org.apache.spark.sql.catalyst.expressions._
@@ -75,6 +76,242 @@ trait ScalaReflection {
    */
   private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
 
+  /**
+   * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping
+   * to a native type, an ObjectType is returned. Special handling is also used for Arrays including
+   * those that hold primitive types.
+   */
+  def dataTypeFor(tpe: `Type`): DataType = tpe match {
+    case t if t <:< definitions.IntTpe => IntegerType
+    case t if t <:< definitions.LongTpe => LongType
+    case t if t <:< definitions.DoubleTpe => DoubleType
+    case t if t <:< definitions.FloatTpe => FloatType
+    case t if t <:< definitions.ShortTpe => ShortType
+    case t if t <:< definitions.ByteTpe => ByteType
+    case t if t <:< definitions.BooleanTpe => BooleanType
+    case t if t <:< localTypeOf[Array[Byte]] => BinaryType
+    case _ =>
+      val className: String = tpe.erasure.typeSymbol.asClass.fullName
+      className match {
+        case "scala.Array" =>
+          val TypeRef(_, _, Seq(arrayType)) = tpe
+          val cls = arrayType match {
+            case t if t <:< definitions.IntTpe => classOf[Array[Int]]
+            case t if t <:< definitions.LongTpe => classOf[Array[Long]]
+            case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
+            case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
+            case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
+            case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
+            case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
+            case other =>
+              // There is probably a better way to do this, but I couldn't find it...
+              val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
+              java.lang.reflect.Array.newInstance(elementType, 1).getClass
+
+          }
+          ObjectType(cls)
+        case other => ObjectType(Utils.classForName(className))
+      }
+  }
+
+  /** Returns expressions for extracting all the fields from the given type. */
+  def extractorsFor[T : TypeTag](inputObject: Expression): Seq[Expression] = {
+    ScalaReflectionLock.synchronized {
+      extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateStruct].children
+    }
+  }
+
+  /** Helper for extracting internal fields from a case class. */
+  protected def extractorFor(
+      inputObject: Expression,
+      tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
+    if (!inputObject.dataType.isInstanceOf[ObjectType]) {
+      inputObject
+    } else {
+      tpe match {
+        case t if t <:< localTypeOf[Option[_]] =>
+          val TypeRef(_, _, Seq(optType)) = t
+          optType match {
+            // For primitive types we must manually unbox the value of the object.
+            case t if t <:< definitions.IntTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject),
+                "intValue",
+                IntegerType)
+            case t if t <:< definitions.LongTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject),
+                "longValue",
+                LongType)
+            case t if t <:< definitions.DoubleTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject),
+                "doubleValue",
+                DoubleType)
+            case t if t <:< definitions.FloatTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject),
+                "floatValue",
+                FloatType)
+            case t if t <:< definitions.ShortTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject),
+                "shortValue",
+                ShortType)
+            case t if t <:< definitions.ByteTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject),
+                "byteValue",
+                ByteType)
+            case t if t <:< definitions.BooleanTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject),
+                "booleanValue",
+                BooleanType)
+
+            // For non-primitives, we can just extract the object from the Option and then recurse.
+            case other =>
+              val className: String = optType.erasure.typeSymbol.asClass.fullName
+              val classObj = Utils.classForName(className)
+              val optionObjectType = ObjectType(classObj)
+
+              val unwrapped = UnwrapOption(optionObjectType, inputObject)
+              expressions.If(
+                IsNull(unwrapped),
+                expressions.Literal.create(null, schemaFor(optType).dataType),
+                extractorFor(unwrapped, optType))
+          }
+
+        case t if t <:< localTypeOf[Product] =>
+          val formalTypeArgs = t.typeSymbol.asClass.typeParams
+          val TypeRef(_, _, actualTypeArgs) = t
+          val constructorSymbol = t.member(nme.CONSTRUCTOR)
+          val params = if (constructorSymbol.isMethod) {
+            constructorSymbol.asMethod.paramss
+          } else {
+            // Find the primary constructor, and use its parameter ordering.
+            val primaryConstructorSymbol: Option[Symbol] =
+              constructorSymbol.asTerm.alternatives.find(s =>
+                s.isMethod && s.asMethod.isPrimaryConstructor)
+
+            if (primaryConstructorSymbol.isEmpty) {
+              sys.error("Internal SQL error: Product object did not have a primary constructor.")
+            } else {
+              primaryConstructorSymbol.get.asMethod.paramss
+            }
+          }
+
+          CreateStruct(params.head.map { p =>
+            val fieldName = p.name.toString
+            val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+            val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+            extractorFor(fieldValue, fieldType)
+          })
+
+        case t if t <:< localTypeOf[Array[_]] =>
+          val TypeRef(_, _, Seq(elementType)) = t
+          val elementDataType = dataTypeFor(elementType)
+          val Schema(dataType, nullable) = schemaFor(elementType)
+
+          if (!elementDataType.isInstanceOf[AtomicType]) {
+            MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
+          } else {
+            NewInstance(
+              classOf[GenericArrayData],
+              inputObject :: Nil,
+              dataType = ArrayType(dataType, nullable))
+          }
+
+        case t if t <:< localTypeOf[Seq[_]] =>
+          val TypeRef(_, _, Seq(elementType)) = t
+          val elementDataType = dataTypeFor(elementType)
+          val Schema(dataType, nullable) = schemaFor(elementType)
+
+          if (!elementDataType.isInstanceOf[AtomicType]) {
+            MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
+          } else {
+            NewInstance(
+              classOf[GenericArrayData],
+              inputObject :: Nil,
+              dataType = ArrayType(dataType, nullable))
+          }
+
+        case t if t <:< localTypeOf[Map[_, _]] =>
+          val TypeRef(_, _, Seq(keyType, valueType)) = t
+          val Schema(keyDataType, _) = schemaFor(keyType)
+          val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+
+          val rawMap = inputObject
+          val keys =
+            NewInstance(
+              classOf[GenericArrayData],
+              Invoke(rawMap, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil,
+              dataType = ObjectType(classOf[ArrayData]))
+          val values =
+            NewInstance(
+              classOf[GenericArrayData],
+              Invoke(rawMap, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil,
+              dataType = ObjectType(classOf[ArrayData]))
+          NewInstance(
+            classOf[ArrayBasedMapData],
+            keys :: values :: Nil,
+            dataType = MapType(keyDataType, valueDataType, valueNullable))
+
+        case t if t <:< localTypeOf[String] =>
+          StaticInvoke(
+            classOf[UTF8String],
+            StringType,
+            "fromString",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[java.sql.Timestamp] =>
+          StaticInvoke(
+            DateTimeUtils,
+            TimestampType,
+            "fromJavaTimestamp",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[java.sql.Date] =>
+          StaticInvoke(
+            DateTimeUtils,
+            DateType,
+            "fromJavaDate",
+            inputObject :: Nil)
+        case t if t <:< localTypeOf[BigDecimal] =>
+          StaticInvoke(
+            Decimal,
+            DecimalType.SYSTEM_DEFAULT,
+            "apply",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[java.math.BigDecimal] =>
+          StaticInvoke(
+            Decimal,
+            DecimalType.SYSTEM_DEFAULT,
+            "apply",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[java.lang.Integer] =>
+          Invoke(inputObject, "intValue", IntegerType)
+        case t if t <:< localTypeOf[java.lang.Long] =>
+          Invoke(inputObject, "longValue", LongType)
+        case t if t <:< localTypeOf[java.lang.Double] =>
+          Invoke(inputObject, "doubleValue", DoubleType)
+        case t if t <:< localTypeOf[java.lang.Float] =>
+          Invoke(inputObject, "floatValue", FloatType)
+        case t if t <:< localTypeOf[java.lang.Short] =>
+          Invoke(inputObject, "shortValue", ShortType)
+        case t if t <:< localTypeOf[java.lang.Byte] =>
+          Invoke(inputObject, "byteValue", ByteType)
+        case t if t <:< localTypeOf[java.lang.Boolean] =>
+          Invoke(inputObject, "booleanValue", BooleanType)
+
+        case other =>
+          throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
+      }
+    }
+  }
+
   /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
   def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
     val className: String = tpe.erasure.typeSymbol.asClass.fullName
@@ -91,7 +328,6 @@ trait ScalaReflection {
       case t if t <:< localTypeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
         Schema(schemaFor(optType).dataType, nullable = true)
-      // Need to decide if we actually need a special type here.
       case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
      case t if t <:< localTypeOf[Array[_]] =>
        val TypeRef(_, _, Seq(elementType)) = t
```
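
Taken together, `dataTypeFor` and `extractorsFor` give everything needed to compile a field-extraction pipeline. Here is a minimal sketch of the wiring, assuming a hypothetical case class `SomeInts`; it is essentially what `ProductEncoder.apply` below does:

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.ObjectType

case class SomeInts(a: Int, b: Int) // hypothetical example class

// The raw object is passed through slot 0 of a one-column input row.
val inputObject = BoundReference(0, ObjectType(classOf[SomeInts]), nullable = true)

// One extractor expression per constructor parameter (here: invoke `a`, invoke `b`).
val extractors = ScalaReflection.extractorsFor[SomeInts](inputObject)

// Compile the extractors into a single generated UnsafeProjection.
val projection = GenerateUnsafeProjection.generate(extractors)
```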
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala

Lines changed: 44 additions & 0 deletions

```diff
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
+ *
+ * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking
+ * and reuse internal buffers to improve performance.
+ */
+trait Encoder[T] {
+  /** Returns the schema of encoding this type of object as a Row. */
+  def schema: StructType
+
+  /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */
+  def clsTag: ClassTag[T]
+
+  /**
+   * Returns an encoded version of `t` as a Spark SQL row.  Note that multiple calls to
+   * toRow are allowed to return the same actual [[InternalRow]] object.  Thus, the caller should
+   * copy the result before making another call if required.
+   */
+  def toRow(t: T): InternalRow
+}
```
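
The buffer-reuse contract on `toRow` is easy to trip over: the returned `InternalRow` may be the same object on every call. A hedged usage sketch, assuming the `ProductEncoder` factory below and a hypothetical `Person` case class:

```scala
import org.apache.spark.sql.catalyst.InternalRow

case class Person(name: String, age: Int) // hypothetical example class

val encoder = ProductEncoder[Person]
val people = Seq(Person("a", 1), Person("b", 2))

// Risky: each call may return the same reused row object, so all elements
// could end up aliasing the encoder's internal buffer.
// val rows = people.map(encoder.toRow)

// Safe: copy each result before the next toRow call overwrites the buffer.
val rows: Seq[InternalRow] = people.map(p => encoder.toRow(p).copy())
```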
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala

Lines changed: 67 additions & 0 deletions

```diff
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
+import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow}
+import org.apache.spark.sql.types.{ObjectType, StructType}
+
+/**
+ * A factory for constructing encoders that convert Scala's product type to/from the Spark SQL
+ * internal binary representation.
+ */
+object ProductEncoder {
+  def apply[T <: Product : TypeTag]: Encoder[T] = {
+    // We convert the not-serializable TypeTag into StructType and ClassTag.
+    val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType]
+    val mirror = typeTag[T].mirror
+    val cls = mirror.runtimeClass(typeTag[T].tpe)
+
+    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+    val extractExpressions = ScalaReflection.extractorsFor[T](inputObject)
+    new ClassEncoder[T](schema, extractExpressions, ClassTag[T](cls))
+  }
+}
+
+/**
+ * A generic encoder for JVM objects.
+ *
+ * @param schema The schema after converting `T` to a Spark SQL row.
+ * @param extractExpressions A set of expressions, one for each top-level field that can be used to
+ *                           extract the values from a raw object.
+ * @param clsTag A classtag for `T`.
+ */
+case class ClassEncoder[T](
+    schema: StructType,
+    extractExpressions: Seq[Expression],
+    clsTag: ClassTag[T])
+  extends Encoder[T] {
+
+  private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+  private val inputRow = new GenericMutableRow(1)
+
+  override def toRow(t: T): InternalRow = {
+    inputRow(0) = t
+    extractProjection(inputRow)
+  }
+}
```
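
End to end, usage mirrors the benchmark in the commit message; a short hedged sketch, with `SomeInts` as the case class from the description:

```scala
case class SomeInts(a: Int, b: Int, c: Int, d: Int, e: Int)

val encoder = ProductEncoder[SomeInts]

// The derived schema has one field per constructor parameter.
assert(encoder.schema.fieldNames.toSeq == Seq("a", "b", "c", "d", "e"))

// toRow runs the generated UnsafeProjection over the raw object.
val row = encoder.toRow(SomeInts(1, 2, 3, 4, 5))
assert(row.getInt(0) == 1 && row.getInt(4) == 5)
```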

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 3 additions & 1 deletion
```diff
@@ -177,6 +177,8 @@ class CodeGenContext {
     case _: MapType => "MapData"
     case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
     case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
+    case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
+    case ObjectType(cls) => cls.getName
     case _ => "Object"
   }
 
@@ -395,7 +397,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
 
     logDebug({
       // Only add extra debugging info to byte code when we are going to print the source code.
-      evaluator.setDebuggingInformation(false, true, false)
+      evaluator.setDebuggingInformation(true, true, false)
       withLineNums
     })
 
```

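The two new `javaType` cases let generated Java code declare variables for arbitrary JVM objects, arrays included. Illustratively (a hedged sketch; `ctx` stands for a `CodeGenContext` instance):

```scala
// Assuming a CodeGenContext instance `ctx`:
ctx.javaType(ObjectType(classOf[java.lang.String]))        // "java.lang.String"
ctx.javaType(ObjectType(classOf[Array[java.lang.String]])) // "java.lang.String[]"
```
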