Commit 326c82c

optimize type converter
Parent: c337844

File tree: 5 files changed, 50 additions, 32 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 37 additions & 23 deletions
@@ -52,6 +52,13 @@ object CatalystTypeConverters {
     }
   }
 
+  private def isWholePrimitive(dt: DataType): Boolean = dt match {
+    case dt if isPrimitive(dt) => true
+    case ArrayType(elementType, _) => isWholePrimitive(elementType)
+    case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType)
+    case _ => false
+  }
+
   private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
     val converter = dataType match {
       case udt: UserDefinedType[_] => UDTConverter(udt)
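
The new helper decides whether a nested type can skip conversion entirely. A minimal standalone sketch of the same recursion, with a simplified stand-in for the file's private isPrimitive check (the stand-in's coverage is an assumption, not the exact production predicate):

import org.apache.spark.sql.types._

object WholePrimitiveSketch {
  // Stand-in for CatalystTypeConverters.isPrimitive (assumed coverage).
  private def isPrimitive(dt: DataType): Boolean = dt match {
    case BooleanType | ByteType | ShortType | IntegerType |
         LongType | FloatType | DoubleType => true
    case _ => false
  }

  // Same shape as the new helper: a type is "whole primitive" when no
  // element anywhere inside it needs converting.
  def isWholePrimitive(dt: DataType): Boolean = dt match {
    case dt if isPrimitive(dt) => true
    case ArrayType(elementType, _) => isWholePrimitive(elementType)
    case MapType(keyType, valueType, _) =>
      isWholePrimitive(keyType) && isWholePrimitive(valueType)
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    println(isWholePrimitive(ArrayType(IntegerType)))            // true
    println(isWholePrimitive(MapType(IntegerType, DoubleType)))  // true
    println(isWholePrimitive(ArrayType(StringType)))             // false: strings convert to UTF8String
  }
}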
@@ -148,6 +155,8 @@ object CatalystTypeConverters {
 
     private[this] val elementConverter = getConverterForType(elementType)
 
+    private[this] val isNoChange = isWholePrimitive(elementType)
+
     override def toCatalystImpl(scalaValue: Any): Seq[Any] = {
       scalaValue match {
         case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst)
@@ -166,8 +175,10 @@ object CatalystTypeConverters {
     override def toScala(catalystValue: Seq[Any]): Seq[Any] = {
       if (catalystValue == null) {
         null
+      } else if (isNoChange) {
+        catalystValue
       } else {
-        catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala)
+        catalystValue.map(elementConverter.toScala)
       }
     }
 
@@ -183,6 +194,8 @@ object CatalystTypeConverters {
     private[this] val keyConverter = getConverterForType(keyType)
     private[this] val valueConverter = getConverterForType(valueType)
 
+    private[this] val isNoChange = isWholePrimitive(keyType) && isWholePrimitive(valueType)
+
     override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match {
       case m: Map[_, _] =>
         m.map { case (k, v) =>
@@ -203,6 +216,8 @@ object CatalystTypeConverters {
     override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = {
       if (catalystValue == null) {
         null
+      } else if (isNoChange) {
+        catalystValue
       } else {
         catalystValue.map { case (k, v) =>
           keyConverter.toScala(k) -> valueConverter.toScala(v)
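
With isNoChange in place, both collection converters hand back the incoming collection untouched when every nested element is primitive, instead of rebuilding it with a per-element map. A hedged sketch of the observable effect, going through the public convertToScala entry point:

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{ArrayType, IntegerType}

val values = Seq(1, 2, 3)
val result = CatalystTypeConverters.convertToScala(values, ArrayType(IntegerType))

// Fast path: the very same Seq instance comes back, nothing was copied.
assert(result.asInstanceOf[Seq[Any]] eq values)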
@@ -258,24 +273,22 @@ object CatalystTypeConverters {
       toScala(row(column).asInstanceOf[InternalRow])
     }
 
-  private object StringConverter extends CatalystTypeConverter[Any, String, Any] {
+  private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] {
     override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match {
       case str: String => UTF8String.fromString(str)
       case utf8: UTF8String => utf8
     }
-    override def toScala(catalystValue: Any): String = catalystValue match {
-      case null => null
-      case str: String => str
-      case utf8: UTF8String => utf8.toString()
-    }
+    override def toScala(catalystValue: UTF8String): String =
+      if (catalystValue == null) null else catalystValue.toString
     override def toScalaImpl(row: InternalRow, column: Int): String = row(column).toString
   }
 
   private object DateConverter extends CatalystTypeConverter[Date, Date, Any] {
     override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue)
     override def toScala(catalystValue: Any): Date =
       if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int])
-    override def toScalaImpl(row: InternalRow, column: Int): Date = toScala(row.getInt(column))
+    override def toScalaImpl(row: InternalRow, column: Int): Date =
+      DateTimeUtils.toJavaDate(row.getInt(column))
   }
 
   private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] {
@@ -285,7 +298,7 @@ object CatalystTypeConverters {
       if (catalystValue == null) null
       else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long])
     override def toScalaImpl(row: InternalRow, column: Int): Timestamp =
-      toScala(row.getLong(column))
+      DateTimeUtils.toJavaTimestamp(row.getLong(column))
   }
 
   private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
@@ -296,10 +309,7 @@ object CatalystTypeConverters {
     }
     override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
     override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal =
-      row.get(column) match {
-        case d: JavaBigDecimal => d
-        case d: Decimal => d.toJavaBigDecimal
-      }
+      row.get(column).asInstanceOf[Decimal].toJavaBigDecimal
   }
 
   private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
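
A common thread in the String, Date, Timestamp, and BigDecimal changes: toScalaImpl no longer pattern-matches on loosely typed values but assumes the column already holds Catalyst's internal representation (Int days for dates, Long microseconds for timestamps, Decimal for decimals) and converts it directly. A small sketch of the direct date path (values here are illustrative):

import org.apache.spark.sql.catalyst.util.DateTimeUtils

// DateType is stored internally as days since the Unix epoch.
val daysSinceEpoch = 0
val javaDate = DateTimeUtils.toJavaDate(daysSinceEpoch)
println(javaDate)  // 1970-01-01 (rendered in the local time zone)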
@@ -362,6 +372,19 @@ object CatalystTypeConverters {
     }
   }
 
+  /**
+   * Creates a converter function that will convert Catalyst types to Scala type.
+   * Typical use case would be converting a collection of rows that have the same schema. You will
+   * call this function once to get a converter, and apply it to every row.
+   */
+  private[sql] def createToScalaConverter(dataType: DataType): Any => Any = {
+    if (isPrimitive(dataType)) {
+      identity
+    } else {
+      getConverterForType(dataType).toScala
+    }
+  }
+
   /**
    * Converts Scala objects to Catalyst rows / types.
    *
@@ -389,15 +412,6 @@ object CatalystTypeConverters {
    * produced by createToScalaConverter.
    */
   def convertToScala(catalystValue: Any, dataType: DataType): Any = {
-    getConverterForType(dataType).toScala(catalystValue)
-  }
-
-  /**
-   * Creates a converter function that will convert Catalyst types to Scala type.
-   * Typical use case would be converting a collection of rows that have the same schema. You will
-   * call this function once to get a converter, and apply it to every row.
-   */
-  private[sql] def createToScalaConverter(dataType: DataType): Any => Any = {
-    getConverterForType(dataType).toScala
+    createToScalaConverter(dataType)(catalystValue)
   }
 }
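
The net effect of moving createToScalaConverter up and adding the isPrimitive shortcut: a converter is built once per data type, and primitive types get a free identity function. A minimal usage sketch (assumes a caller inside the private[sql] scope the method requires):

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.IntegerType

// Build once, apply per row/value.
val toScalaInt = CatalystTypeConverters.createToScalaConverter(IntegerType)

// For primitives this is `identity`: no allocation, no converter dispatch.
assert(toScalaInt(42) == 42)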

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala

Lines changed: 1 addition & 2 deletions
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.types.DataType
 
@@ -39,7 +38,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
     (1 to 22).map { x =>
       val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _)
       val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _)
-      lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _)
+      val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _)
       val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n " + _)
 
       s"""case $x =>

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala

Lines changed: 9 additions & 4 deletions
@@ -22,6 +22,7 @@ import scala.collection.mutable.{Map => MutableMap}
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.types.{ArrayType, StructField, StructType}
 import org.apache.spark.sql.{Column, DataFrame}
 
@@ -110,13 +111,17 @@ private[sql] object FrequentItems extends Logging {
         baseCounts
       }
     )
-    val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
-    val resultRow = InternalRow(justItems : _*)
+
     // append frequent Items to the column name for easy debugging
     val outputCols = colInfo.map { v =>
       StructField(v._1 + "_freqItems", ArrayType(v._2, false))
     }
-    val schema = StructType(outputCols).toAttributes
-    new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow)))
+    val schema = StructType(outputCols)
+
+    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
+    val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
+    val resultRow = converter(InternalRow(justItems : _*)).asInstanceOf[InternalRow]
+
+    new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, Seq(resultRow)))
   }
 }
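
The reordering matters because a LocalRelation stores rows in Catalyst's internal format, so the frequent-item sequences are now run through a to-Catalyst converter for the result schema before being wrapped. A hedged standalone sketch of that conversion step (the column name and data here are made up):

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("a_freqItems", ArrayType(StringType, containsNull = false))))

// Build one converter for the whole row schema, then convert the row:
// Scala Strings inside the array become Catalyst's internal UTF8String.
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
val internalRow = toCatalyst(Row(Seq("x", "y")))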

sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala

Lines changed: 1 addition & 1 deletion
@@ -313,7 +313,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       output: Seq[Attribute],
       rdd: RDD[Row]): RDD[InternalRow] = {
     if (relation.relation.needConversion) {
-      execution.RDDConversions.rowToRowRdd(rdd.asInstanceOf[RDD[Row]], output.map(_.dataType))
+      execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType))
     } else {
       rdd.map(_.asInstanceOf[InternalRow])
     }
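
This one is pure cleanup: the parameter is already declared as RDD[Row], so the asInstanceOf could never change anything. A trivial illustration (the method name is hypothetical):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

def toInternal(rdd: RDD[Row]): Unit = {
  val direct: RDD[Row] = rdd                         // already the declared type
  val casted: RDD[Row] = rdd.asInstanceOf[RDD[Row]]  // identical at runtime, just noise
}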

sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -88,9 +88,9 @@ case class AllDataTypesScan(
       UTF8String.fromString(s"varchar_$i"),
       Seq(i, i + 1),
       Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))),
-      Map(i -> i.toString),
+      Map(i -> UTF8String.fromString(i.toString)),
       Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)),
-      Row(i, i.toString),
+      Row(i, UTF8String.fromString(i.toString)),
       Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
         InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
     }
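
These test rows bypass the converters, so once StringConverter's signature tightened to UTF8String, any string the scan emits in Catalyst positions has to be the internal type already. A small sketch of the wrapping (the import path is the one this Spark era appears to use; treat it as an assumption):

import org.apache.spark.unsafe.types.UTF8String

// Internal representation of the string "str_1":
val s = UTF8String.fromString("str_1")
assert(s.toString == "str_1")  // round-trips back to a java.lang.String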
