
Commit 1a79f0e

cloud-fan authored and Davies Liu committed
[SPARK-8635] [SQL] improve performance of CatalystTypeConverters
In `CatalystTypeConverters.createToCatalystConverter`, we add special handling for primitive types. We can apply this strategy to more places to improve performance.

Author: Wenchen Fan <[email protected]>

Closes #7018 from cloud-fan/converter and squashes the following commits:

8b16630 [Wenchen Fan] another fix
326c82c [Wenchen Fan] optimize type converter
1 parent 4036011 commit 1a79f0e

File tree: 8 files changed (+48, −33 lines)
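Before the per-file diffs, here is a minimal standalone sketch of the strategy the commit message describes: when a type, including the element types of its arrays and maps, needs no conversion at all, the converter can simply be `identity` and values are passed through untouched. This is not the Spark code itself; `SimpleType`, `Sketch`, and `toScalaSlow` are hypothetical stand-ins.

    // Hypothetical, simplified sketch of the "whole primitive" fast path.
    sealed trait SimpleType
    case object IntType extends SimpleType
    case object StringType extends SimpleType
    case class ArrType(element: SimpleType) extends SimpleType

    object Sketch {
      private def isPrimitive(dt: SimpleType): Boolean = dt == IntType

      // True when nothing inside the type needs per-element conversion.
      private def isWholePrimitive(dt: SimpleType): Boolean = dt match {
        case dt if isPrimitive(dt) => true
        case ArrType(element)      => isWholePrimitive(element)
        case _                     => false
      }

      // Stand-in for a real per-element conversion (e.g. UTF8String -> String).
      private def toScalaSlow(v: Any): Any = v match {
        case s: Seq[_] => s.map(toScalaSlow)
        case other     => other
      }

      // Built once per schema; primitive-only values are returned unchanged.
      def createToScalaConverter(dt: SimpleType): Any => Any =
        if (isWholePrimitive(dt)) identity else toScalaSlow
    }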

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 37 additions & 23 deletions
@@ -52,6 +52,13 @@ object CatalystTypeConverters {
     }
   }
 
+  private def isWholePrimitive(dt: DataType): Boolean = dt match {
+    case dt if isPrimitive(dt) => true
+    case ArrayType(elementType, _) => isWholePrimitive(elementType)
+    case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType)
+    case _ => false
+  }
+
   private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
     val converter = dataType match {
       case udt: UserDefinedType[_] => UDTConverter(udt)
@@ -148,6 +155,8 @@ object CatalystTypeConverters {
 
     private[this] val elementConverter = getConverterForType(elementType)
 
+    private[this] val isNoChange = isWholePrimitive(elementType)
+
     override def toCatalystImpl(scalaValue: Any): Seq[Any] = {
       scalaValue match {
         case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst)
@@ -166,8 +175,10 @@ object CatalystTypeConverters {
     override def toScala(catalystValue: Seq[Any]): Seq[Any] = {
       if (catalystValue == null) {
         null
+      } else if (isNoChange) {
+        catalystValue
       } else {
-        catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala)
+        catalystValue.map(elementConverter.toScala)
       }
     }
 
@@ -183,6 +194,8 @@ object CatalystTypeConverters {
     private[this] val keyConverter = getConverterForType(keyType)
     private[this] val valueConverter = getConverterForType(valueType)
 
+    private[this] val isNoChange = isWholePrimitive(keyType) && isWholePrimitive(valueType)
+
     override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match {
       case m: Map[_, _] =>
         m.map { case (k, v) =>
@@ -203,6 +216,8 @@ object CatalystTypeConverters {
     override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = {
       if (catalystValue == null) {
         null
+      } else if (isNoChange) {
+        catalystValue
       } else {
         catalystValue.map { case (k, v) =>
           keyConverter.toScala(k) -> valueConverter.toScala(v)
@@ -258,24 +273,22 @@ object CatalystTypeConverters {
       toScala(row(column).asInstanceOf[InternalRow])
   }
 
-  private object StringConverter extends CatalystTypeConverter[Any, String, Any] {
+  private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] {
     override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match {
       case str: String => UTF8String.fromString(str)
       case utf8: UTF8String => utf8
     }
-    override def toScala(catalystValue: Any): String = catalystValue match {
-      case null => null
-      case str: String => str
-      case utf8: UTF8String => utf8.toString()
-    }
+    override def toScala(catalystValue: UTF8String): String =
+      if (catalystValue == null) null else catalystValue.toString
     override def toScalaImpl(row: InternalRow, column: Int): String = row(column).toString
   }
 
   private object DateConverter extends CatalystTypeConverter[Date, Date, Any] {
     override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue)
     override def toScala(catalystValue: Any): Date =
       if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int])
-    override def toScalaImpl(row: InternalRow, column: Int): Date = toScala(row.getInt(column))
+    override def toScalaImpl(row: InternalRow, column: Int): Date =
+      DateTimeUtils.toJavaDate(row.getInt(column))
   }
 
   private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] {
@@ -285,7 +298,7 @@ object CatalystTypeConverters {
       if (catalystValue == null) null
       else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long])
     override def toScalaImpl(row: InternalRow, column: Int): Timestamp =
-      toScala(row.getLong(column))
+      DateTimeUtils.toJavaTimestamp(row.getLong(column))
   }
 
   private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
@@ -296,10 +309,7 @@ object CatalystTypeConverters {
     }
     override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
     override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal =
-      row.get(column) match {
-        case d: JavaBigDecimal => d
-        case d: Decimal => d.toJavaBigDecimal
-      }
+      row.get(column).asInstanceOf[Decimal].toJavaBigDecimal
   }
 
   private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
@@ -362,6 +372,19 @@ object CatalystTypeConverters {
     }
   }
 
+  /**
+   * Creates a converter function that will convert Catalyst types to Scala type.
+   * Typical use case would be converting a collection of rows that have the same schema. You will
+   * call this function once to get a converter, and apply it to every row.
+   */
+  private[sql] def createToScalaConverter(dataType: DataType): Any => Any = {
+    if (isPrimitive(dataType)) {
+      identity
+    } else {
+      getConverterForType(dataType).toScala
+    }
+  }
+
   /**
   * Converts Scala objects to Catalyst rows / types.
   *
@@ -389,15 +412,6 @@ object CatalystTypeConverters {
   * produced by createToScalaConverter.
   */
  def convertToScala(catalystValue: Any, dataType: DataType): Any = {
-    getConverterForType(dataType).toScala(catalystValue)
-  }
-
-  /**
-   * Creates a converter function that will convert Catalyst types to Scala type.
-   * Typical use case would be converting a collection of rows that have the same schema. You will
-   * call this function once to get a converter, and apply it to every row.
-   */
-  private[sql] def createToScalaConverter(dataType: DataType): Any => Any = {
-    getConverterForType(dataType).toScala
+    createToScalaConverter(dataType)(catalystValue)
  }
}
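The doc comment on the relocated `createToScalaConverter` spells out the intended usage pattern: build the converter once for a schema, then apply it to every row, so the primitive fast path (`identity`) is decided a single time rather than per value. A minimal usage sketch, assuming `schema` and `internalRows` already exist in the caller's scope:

    // Hypothetical usage: build one converter per schema, reuse it for every row.
    val toScala = CatalystTypeConverters.createToScalaConverter(schema)
    val scalaRows = internalRows.map(toScala)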

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala

Lines changed: 1 addition & 2 deletions
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.types.DataType
 
@@ -39,7 +38,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
   (1 to 22).map { x =>
     val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _)
     val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n  " + _)
-    lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n  " + _)
+    val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n  " + _)
     val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n      " + _)
 
     s"""case $x =>

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 3 additions & 1 deletion
@@ -1418,12 +1418,14 @@ class DataFrame private[sql](
   lazy val rdd: RDD[Row] = {
     // use a local variable to make sure the map closure doesn't capture the whole DataFrame
     val schema = this.schema
-    queryExecution.executedPlan.execute().mapPartitions { rows =>
+    internalRowRdd.mapPartitions { rows =>
       val converter = CatalystTypeConverters.createToScalaConverter(schema)
       rows.map(converter(_).asInstanceOf[Row])
     }
   }
 
+  private[sql] def internalRowRdd = queryExecution.executedPlan.execute()
+
   /**
   * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s.
   * @group rdd
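The new `internalRowRdd` helper exposes the executed plan's output directly, so internal callers that already work on `InternalRow` (the stat and data-source files below) no longer pay for the Catalyst-to-Scala conversion hidden inside `df.rdd`. Conceptually, the change in those callers looks like this sketch, with `df` and `columns` standing in for the actual code:

    // Before: df.rdd converts every InternalRow into a Scala Row first.
    df.select(columns: _*).rdd
    // After: operate on InternalRow directly, skipping the conversion.
    df.select(columns: _*).internalRowRdd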

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ private[sql] object FrequentItems extends Logging {
       (name, originalSchema.fields(index).dataType)
     }
 
-    val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
+    val freqItems = df.select(cols.map(Column(_)) : _*).internalRowRdd.aggregate(countMaps)(
       seqOp = (counts, row) => {
         var i = 0
         while (i < numCols) {

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ private[sql] object StatFunctions extends Logging {
         s"with dataType ${data.get.dataType} not supported.")
     }
     val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
-    df.select(columns: _*).rdd.aggregate(new CovarianceCounter)(
+    df.select(columns: _*).internalRowRdd.aggregate(new CovarianceCounter)(
       seqOp = (counter, row) => {
         counter.add(row.getDouble(0), row.getDouble(1))
       },

sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala

Lines changed: 1 addition & 1 deletion
@@ -313,7 +313,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       output: Seq[Attribute],
       rdd: RDD[Row]): RDD[InternalRow] = {
     if (relation.relation.needConversion) {
-      execution.RDDConversions.rowToRowRdd(rdd.asInstanceOf[RDD[Row]], output.map(_.dataType))
+      execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType))
     } else {
       rdd.map(_.asInstanceOf[InternalRow])
     }

sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala

Lines changed: 2 additions & 2 deletions
@@ -154,7 +154,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
     writerContainer.driverSideSetup()
 
     try {
-      df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
+      df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _)
       writerContainer.commitJob()
       relation.refresh()
     } catch { case cause: Throwable =>
@@ -220,7 +220,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
     writerContainer.driverSideSetup()
 
     try {
-      df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
+      df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _)
       writerContainer.commitJob()
       relation.refresh()
     } catch { case cause: Throwable =>

sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -88,9 +88,9 @@ case class AllDataTypesScan(
         UTF8String.fromString(s"varchar_$i"),
         Seq(i, i + 1),
         Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))),
-        Map(i -> i.toString),
+        Map(i -> UTF8String.fromString(i.toString)),
         Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)),
-        Row(i, i.toString),
+        Row(i, UTF8String.fromString(i.toString)),
         Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
           InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
     }
