
Commit 67d0688

vlyubin authored and marmbrus committed
[SQL] [SPARK-6620] Speed up toDF() and rdd() functions by constructing converters in ScalaReflection
cc marmbrus

Author: Volodymyr Lyubinets <[email protected]>

Closes #5279 from vlyubin/speedup and squashes the following commits:

e75a387 [Volodymyr Lyubinets] Changes to ScalaUDF
11a20ec [Volodymyr Lyubinets] Avoid creating a tuple
c327bc9 [Volodymyr Lyubinets] Moved the only remaining function from DataTypeConversions to DateUtils
dec6802 [Volodymyr Lyubinets] Addresed review feedback
74301fa [Volodymyr Lyubinets] Addressed review comments
afa3aa5 [Volodymyr Lyubinets] Minor refactoring, added license, removed debug output
881dc60 [Volodymyr Lyubinets] Moved to a separate module; addressed review comments; one extra place of usage; changed behaviour for Java
8cad6e2 [Volodymyr Lyubinets] Addressed review commments
41b2aa9 [Volodymyr Lyubinets] Creating converters for ScalaReflection stuff, and more
1 parent 23d5f88 commit 67d0688
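
The gist of the change: instead of re-inspecting the DataType of every single value during conversion (as the old ScalaReflection.convertToCatalyst did), a converter closure is now built once per schema and reused for every row, which is what speeds up toDF() and rdd(). Below is a minimal standalone sketch of that pattern; SimpleType, IntType, StringType and SeqType are hypothetical stand-ins for illustration only, not Spark classes.

// Illustrative sketch of the "build a converter once, reuse it per row" idea behind
// this commit. The types below are hypothetical stand-ins, not Spark classes.
sealed trait SimpleType
case object IntType extends SimpleType
case object StringType extends SimpleType
case class SeqType(element: SimpleType) extends SimpleType

object ConverterSketch {
  // Slow path: the type is pattern-matched again for every single value.
  def convert(value: Any, dt: SimpleType): Any = (value, dt) match {
    case (s: Seq[_], SeqType(elem)) => s.map(convert(_, elem))
    case (other, _) => other
  }

  // Fast path: resolve the type structure once and return a reusable closure.
  def createConverter(dt: SimpleType): Any => Any = dt match {
    case SeqType(elem) =>
      val elementConverter = createConverter(elem) // built once per schema
      (value: Any) => value.asInstanceOf[Seq[_]].map(elementConverter)
    case _ =>
      (value: Any) => value // primitives pass through unchanged
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq(Seq(1, 2, 3), Seq(4, 5))
    val toCatalyst = createConverter(SeqType(IntType)) // one-time cost
    rows.foreach(r => println(toCatalyst(r)))          // per-row cost is one closure call
  }
}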

File tree

17 files changed (+929, -461 lines)


mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala

Lines changed: 7 additions & 10 deletions
@@ -25,10 +25,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 @BeanInfo
-case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) {
-  /** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */
-  def this(rawText: String, wantedTokens: Array[String]) = this(rawText, wantedTokens.toSeq)
-}
+case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
 
 class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
   import org.apache.spark.ml.feature.RegexTokenizerSuite._
@@ -46,14 +43,14 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
       .setOutputCol("tokens")
 
     val dataset0 = sqlContext.createDataFrame(Seq(
-      TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")),
-      TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct"))
+      TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")),
+      TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct"))
    ))
    testRegexTokenizer(tokenizer, dataset0)
 
    val dataset1 = sqlContext.createDataFrame(Seq(
-      TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")),
-      TokenizerTestData("Te,st. punct", Seq("punct"))
+      TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")),
+      TokenizerTestData("Te,st. punct", Array("punct"))
    ))
 
    tokenizer.setMinTokenLength(3)
@@ -64,8 +61,8 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
      .setGaps(true)
      .setMinTokenLength(0)
    val dataset2 = sqlContext.createDataFrame(Seq(
-      TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")),
-      TokenizerTestData("Te,st. punct", Seq("Te,st.", "", "punct"))
+      TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")),
+      TokenizerTestData("Te,st. punct", Array("Te,st.", "", "punct"))
    ))
    testRegexTokenizer(tokenizer, dataset2)
  }
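
The test data moves from Seq[String] to Array[String]: the new converters in the next file normalize both to a Catalyst Seq, and the Array form is directly constructible from JavaTokenizerSuite, which is presumably the "changed behaviour for Java" noted in the commit log. A small, hypothetical standalone check of that conversion through the public convertToCatalyst entry point added below (assumes the patched spark-catalyst jar is on the classpath; not part of the commit):

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{ArrayType, StringType}

// Hypothetical check: an Array[String] field is normalized to a Seq by the
// Array case in CatalystTypeConverters.convertToCatalyst below.
object ArrayConversionSketch {
  def main(args: Array[String]): Unit = {
    val tokens = Array("Test", "for", "tokenization", ".")
    val converted = CatalystTypeConverters.convertToCatalyst(tokens, ArrayType(StringType))
    println(converted) // prints a Seq containing the same four tokens
  }
}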
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 295 additions & 0 deletions
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import java.util.{Map => JavaMap}
+
+import scala.collection.mutable.HashMap
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+/**
+ * Functions to convert Scala types to Catalyst types and vice versa.
+ */
+object CatalystTypeConverters {
+  // The Predef.Map is scala.collection.immutable.Map.
+  // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
+  import scala.collection.Map
+
+  /**
+   * Converts Scala objects to catalyst rows / types. This method is slow, and for batch
+   * conversion you should be using converter produced by createToCatalystConverter.
+   * Note: This is always called after schemaFor has been called.
+   *       This ordering is important for UDT registration.
+   */
+  def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
+    // Check UDT first since UDTs can override other types
+    case (obj, udt: UserDefinedType[_]) =>
+      udt.serialize(obj)
+
+    case (o: Option[_], _) =>
+      o.map(convertToCatalyst(_, dataType)).orNull
+
+    case (s: Seq[_], arrayType: ArrayType) =>
+      s.map(convertToCatalyst(_, arrayType.elementType))
+
+    case (s: Array[_], arrayType: ArrayType) =>
+      s.toSeq.map(convertToCatalyst(_, arrayType.elementType))
+
+    case (m: Map[_, _], mapType: MapType) =>
+      m.map { case (k, v) =>
+        convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
+      }
+
+    case (jmap: JavaMap[_, _], mapType: MapType) =>
+      val iter = jmap.entrySet.iterator
+      var listOfEntries: List[(Any, Any)] = List()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        listOfEntries :+= (convertToCatalyst(entry.getKey, mapType.keyType),
+          convertToCatalyst(entry.getValue, mapType.valueType))
+      }
+      listOfEntries.toMap
+
+    case (p: Product, structType: StructType) =>
+      val ar = new Array[Any](structType.size)
+      val iter = p.productIterator
+      var idx = 0
+      while (idx < structType.size) {
+        ar(idx) = convertToCatalyst(iter.next(), structType.fields(idx).dataType)
+        idx += 1
+      }
+      new GenericRowWithSchema(ar, structType)
+
+    case (d: BigDecimal, _) =>
+      Decimal(d)
+
+    case (d: java.math.BigDecimal, _) =>
+      Decimal(d)
+
+    case (d: java.sql.Date, _) =>
+      DateUtils.fromJavaDate(d)
+
+    case (r: Row, structType: StructType) =>
+      val converters = structType.fields.map {
+        f => (item: Any) => convertToCatalyst(item, f.dataType)
+      }
+      convertRowWithConverters(r, structType, converters)
+
+    case (other, _) =>
+      other
+  }
+
+  /**
+   * Creates a converter function that will convert Scala objects to the specified catalyst type.
+   * Typical use case would be converting a collection of rows that have the same schema. You will
+   * call this function once to get a converter, and apply it to every row.
+   */
+  private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = {
+    def extractOption(item: Any): Any = item match {
+      case opt: Option[_] => opt.orNull
+      case other => other
+    }
+
+    dataType match {
+      // Check UDT first since UDTs can override other types
+      case udt: UserDefinedType[_] =>
+        (item) => extractOption(item) match {
+          case null => null
+          case other => udt.serialize(other)
+        }
+
+      case arrayType: ArrayType =>
+        val elementConverter = createToCatalystConverter(arrayType.elementType)
+        (item: Any) => {
+          extractOption(item) match {
+            case a: Array[_] => a.toSeq.map(elementConverter)
+            case s: Seq[_] => s.map(elementConverter)
+            case null => null
+          }
+        }
+
+      case mapType: MapType =>
+        val keyConverter = createToCatalystConverter(mapType.keyType)
+        val valueConverter = createToCatalystConverter(mapType.valueType)
+        (item: Any) => {
+          extractOption(item) match {
+            case m: Map[_, _] =>
+              m.map { case (k, v) =>
+                keyConverter(k) -> valueConverter(v)
+              }
+
+            case jmap: JavaMap[_, _] =>
+              val iter = jmap.entrySet.iterator
+              val convertedMap: HashMap[Any, Any] = HashMap()
+              while (iter.hasNext) {
+                val entry = iter.next()
+                convertedMap(keyConverter(entry.getKey)) = valueConverter(entry.getValue)
+              }
+              convertedMap
+
+            case null => null
+          }
+        }
+
+      case structType: StructType =>
+        val converters = structType.fields.map(f => createToCatalystConverter(f.dataType))
+        (item: Any) => {
+          extractOption(item) match {
+            case r: Row =>
+              convertRowWithConverters(r, structType, converters)
+
+            case p: Product =>
+              val ar = new Array[Any](structType.size)
+              val iter = p.productIterator
+              var idx = 0
+              while (idx < structType.size) {
+                ar(idx) = converters(idx)(iter.next())
+                idx += 1
+              }
+              new GenericRowWithSchema(ar, structType)
+
+            case null =>
+              null
+          }
+        }
+
+      case dateType: DateType => (item: Any) => extractOption(item) match {
+        case d: java.sql.Date => DateUtils.fromJavaDate(d)
+        case other => other
+      }
+
+      case _ =>
+        (item: Any) => extractOption(item) match {
+          case d: BigDecimal => Decimal(d)
+          case d: java.math.BigDecimal => Decimal(d)
+          case other => other
+        }
+    }
+  }
+
+  /**
+   * Converts Catalyst types used internally in rows to standard Scala types
+   * This method is slow, and for batch conversion you should be using converter
+   * produced by createToScalaConverter.
+   */
+  def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
+    // Check UDT first since UDTs can override other types
+    case (d, udt: UserDefinedType[_]) =>
+      udt.deserialize(d)
+
+    case (s: Seq[_], arrayType: ArrayType) =>
+      s.map(convertToScala(_, arrayType.elementType))
+
+    case (m: Map[_, _], mapType: MapType) =>
+      m.map { case (k, v) =>
+        convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
+      }
+
+    case (r: Row, s: StructType) =>
+      convertRowToScala(r, s)
+
+    case (d: Decimal, _: DecimalType) =>
+      d.toJavaBigDecimal
+
+    case (i: Int, DateType) =>
+      DateUtils.toJavaDate(i)
+
+    case (other, _) =>
+      other
+  }
+
+  /**
+   * Creates a converter function that will convert Catalyst types to Scala type.
+   * Typical use case would be converting a collection of rows that have the same schema. You will
+   * call this function once to get a converter, and apply it to every row.
+   */
+  private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match {
+    // Check UDT first since UDTs can override other types
+    case udt: UserDefinedType[_] =>
+      (item: Any) => if (item == null) null else udt.deserialize(item)
+
+    case arrayType: ArrayType =>
+      val elementConverter = createToScalaConverter(arrayType.elementType)
+      (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter)
+
+    case mapType: MapType =>
+      val keyConverter = createToScalaConverter(mapType.keyType)
+      val valueConverter = createToScalaConverter(mapType.valueType)
+      (item: Any) => if (item == null) {
+        null
+      } else {
+        item.asInstanceOf[Map[_, _]].map { case (k, v) =>
+          keyConverter(k) -> valueConverter(v)
+        }
+      }
+
+    case s: StructType =>
+      val converters = s.fields.map(f => createToScalaConverter(f.dataType))
+      (item: Any) => {
+        if (item == null) {
+          null
+        } else {
+          convertRowWithConverters(item.asInstanceOf[Row], s, converters)
+        }
+      }
+
+    case _: DecimalType =>
+      (item: Any) => item match {
+        case d: Decimal => d.toJavaBigDecimal
+        case other => other
+      }
+
+    case DateType =>
+      (item: Any) => item match {
+        case i: Int => DateUtils.toJavaDate(i)
+        case other => other
+      }
+
+    case other =>
+      (item: Any) => item
+  }
+
+  def convertRowToScala(r: Row, schema: StructType): Row = {
+    val ar = new Array[Any](r.size)
+    var idx = 0
+    while (idx < r.size) {
+      ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType)
+      idx += 1
+    }
+    new GenericRowWithSchema(ar, schema)
+  }
+
+  /**
+   * Converts a row by applying the provided set of converter functions. It is used for both
+   * toScala and toCatalyst conversions.
+   */
+  private[sql] def convertRowWithConverters(
+      row: Row,
+      schema: StructType,
+      converters: Array[Any => Any]): Row = {
+    val ar = new Array[Any](row.size)
+    var idx = 0
+    while (idx < row.size) {
+      ar(idx) = converters(idx)(row(idx))
+      idx += 1
+    }
+    new GenericRowWithSchema(ar, schema)
+  }
+}
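
Both factory methods above are private[sql], so only Spark's own SQL code paths (createDataFrame, toDF(), rdd(), ScalaUDF) can call them. A hedged sketch of how such a caller uses them, assuming the file is compiled under the org.apache.spark.sql package so the private[sql] members are visible; the object name is illustrative only:

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types._

// Illustrative only: mirrors how toDF()/rdd() are expected to use the new converters.
object ConverterUsageSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("name", StringType, nullable = false),
      StructField("age", IntegerType, nullable = false)))

    // Build the converters once for the whole schema...
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
    val toScala = CatalystTypeConverters.createToScalaConverter(schema)

    // ...then apply them per element, avoiding per-value pattern matching on DataType.
    val scalaRows = Seq(("Alice", 30), ("Bob", 25))
    val catalystRows = scalaRows.map(toCatalyst)
    catalystRows.map(toScala).foreach(println)
  }
}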

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 0 additions & 55 deletions
@@ -46,61 +46,6 @@ trait ScalaReflection {
 
   case class Schema(dataType: DataType, nullable: Boolean)
 
-  /**
-   * Converts Scala objects to catalyst rows / types.
-   * Note: This is always called after schemaFor has been called.
-   *       This ordering is important for UDT registration.
-   */
-  def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
-    // Check UDT first since UDTs can override other types
-    case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
-    case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
-    case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
-    case (s: Array[_], arrayType: ArrayType) => if (arrayType.elementType.isPrimitive) {
-      s.toSeq
-    } else {
-      s.toSeq.map(convertToCatalyst(_, arrayType.elementType))
-    }
-    case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
-      convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
-    }
-    case (p: Product, structType: StructType) =>
-      new GenericRow(
-        p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) =>
-          convertToCatalyst(elem, field.dataType)
-        }.toArray)
-    case (d: BigDecimal, _) => Decimal(d)
-    case (d: java.math.BigDecimal, _) => Decimal(d)
-    case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d)
-    case (r: Row, structType: StructType) =>
-      new GenericRow(
-        r.toSeq.zip(structType.fields).map { case (elem, field) =>
-          convertToCatalyst(elem, field.dataType)
-        }.toArray)
-    case (other, _) => other
-  }
-
-  /** Converts Catalyst types used internally in rows to standard Scala types */
-  def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
-    // Check UDT first since UDTs can override other types
-    case (d, udt: UserDefinedType[_]) => udt.deserialize(d)
-    case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType))
-    case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
-      convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
-    }
-    case (r: Row, s: StructType) => convertRowToScala(r, s)
-    case (d: Decimal, _: DecimalType) => d.toJavaBigDecimal
-    case (i: Int, DateType) => DateUtils.toJavaDate(i)
-    case (other, _) => other
-  }
-
-  def convertRowToScala(r: Row, schema: StructType): Row = {
-    // TODO: This is very slow!!!
-    new GenericRowWithSchema(
-      r.toSeq.zip(schema.fields.map(_.dataType))
-        .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray, schema)
-  }
-
   /** Returns a Sequence of attributes for the given case class type. */
   def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
     case Schema(s: StructType, _) =>
