Skip to content

Commit c36ca65

Browse files
committed
[SPARK-15351][SQL] RowEncoder should support array as the external type for ArrayType
## What changes were proposed in this pull request?

This PR improves `RowEncoder` and `MapObjects` to support array as the external type for `ArrayType`. The idea is straightforward: we use `Object` as the external input type for `ArrayType`, and determine its actual type at runtime in `MapObjects`.

## How was this patch tested?

New test in `RowEncoderSuite`.

Author: Wenchen Fan <[email protected]>

Closes #13138 from cloud-fan/map-object.
1 parent 122302c commit c36ca65

File tree

5 files changed

+92
-55
lines changed

5 files changed

+92
-55
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ trait Row extends Serializable {
151151
* BinaryType -> byte array
152152
* ArrayType -> scala.collection.Seq (use getList for java.util.List)
153153
* MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
154-
* StructType -> org.apache.spark.sql.Row (or Product)
154+
* StructType -> org.apache.spark.sql.Row
155155
* }}}
156156
*/
157157
def apply(i: Int): Any = get(i)
@@ -176,7 +176,7 @@ trait Row extends Serializable {
176176
* BinaryType -> byte array
177177
* ArrayType -> scala.collection.Seq (use getList for java.util.List)
178178
* MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
179-
* StructType -> org.apache.spark.sql.Row (or Product)
179+
* StructType -> org.apache.spark.sql.Row
180180
* }}}
181181
*/
182182
def get(i: Int): Any

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,26 @@ import org.apache.spark.unsafe.types.UTF8String
3232
/**
3333
* A factory for constructing encoders that convert external row to/from the Spark SQL
3434
* internal binary representation.
35+
*
36+
* The following is a mapping between Spark SQL types and their allowed external types:
37+
* {{{
38+
* BooleanType -> java.lang.Boolean
39+
* ByteType -> java.lang.Byte
40+
* ShortType -> java.lang.Short
41+
* IntegerType -> java.lang.Integer
42+
* FloatType -> java.lang.Float
43+
* DoubleType -> java.lang.Double
44+
* StringType -> String
45+
* DecimalType -> java.math.BigDecimal or scala.math.BigDecimal or Decimal
46+
*
47+
* DateType -> java.sql.Date
48+
* TimestampType -> java.sql.Timestamp
49+
*
50+
* BinaryType -> byte array
51+
* ArrayType -> scala.collection.Seq or Array
52+
* MapType -> scala.collection.Map
53+
* StructType -> org.apache.spark.sql.Row or Product
54+
* }}}
3555
*/
3656
object RowEncoder {
3757
def apply(schema: StructType): ExpressionEncoder[Row] = {
@@ -166,6 +186,8 @@ object RowEncoder {
166186
// In order to support both Decimal and java/scala BigDecimal in external row, we make this
167187
// as java.lang.Object.
168188
case _: DecimalType => ObjectType(classOf[java.lang.Object])
189+
// To support both Array and Seq in an external row, we use java.lang.Object as the external type.
190+
case _: ArrayType => ObjectType(classOf[java.lang.Object])
169191
case _ => externalDataTypeFor(dt)
170192
}
171193

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 46 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -376,45 +376,6 @@ case class MapObjects private(
376376
lambdaFunction: Expression,
377377
inputData: Expression) extends Expression with NonSQLExpression {
378378

379-
@tailrec
380-
private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
381-
case NullType =>
382-
val nullTypeClassName = NullType.getClass.getName + ".MODULE$"
383-
(i: String) => s".get($i, $nullTypeClassName)"
384-
case IntegerType => (i: String) => s".getInt($i)"
385-
case LongType => (i: String) => s".getLong($i)"
386-
case FloatType => (i: String) => s".getFloat($i)"
387-
case DoubleType => (i: String) => s".getDouble($i)"
388-
case ByteType => (i: String) => s".getByte($i)"
389-
case ShortType => (i: String) => s".getShort($i)"
390-
case BooleanType => (i: String) => s".getBoolean($i)"
391-
case StringType => (i: String) => s".getUTF8String($i)"
392-
case s: StructType => (i: String) => s".getStruct($i, ${s.size})"
393-
case a: ArrayType => (i: String) => s".getArray($i)"
394-
case _: MapType => (i: String) => s".getMap($i)"
395-
case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
396-
case DecimalType.Fixed(p, s) => (i: String) => s".getDecimal($i, $p, $s)"
397-
case DateType => (i: String) => s".getInt($i)"
398-
}
399-
400-
private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
401-
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
402-
(".size()", (i: String) => s".apply($i)", false)
403-
case ObjectType(cls) if cls.isArray =>
404-
(".length", (i: String) => s"[$i]", false)
405-
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
406-
(".size()", (i: String) => s".get($i)", false)
407-
case ArrayType(t, _) =>
408-
val (sqlType, primitiveElement) = t match {
409-
case m: MapType => (m, false)
410-
case s: StructType => (s, false)
411-
case s: StringType => (s, false)
412-
case udt: UserDefinedType[_] => (udt.sqlType, false)
413-
case o => (o, true)
414-
}
415-
(".numElements()", itemAccessorMethod(sqlType), primitiveElement)
416-
}
417-
418379
override def nullable: Boolean = true
419380

420381
override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
@@ -425,7 +386,6 @@ case class MapObjects private(
425386
override def dataType: DataType = ArrayType(lambdaFunction.dataType)
426387

427388
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
428-
val javaType = ctx.javaType(dataType)
429389
val elementJavaType = ctx.javaType(loopVar.dataType)
430390
ctx.addMutableState("boolean", loopVar.isNull, "")
431391
ctx.addMutableState(elementJavaType, loopVar.value, "")
@@ -448,27 +408,61 @@ case class MapObjects private(
448408
s"new $convertedType[$dataLength]"
449409
}
450410

451-
val loopNullCheck = if (primitiveElement) {
452-
s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
453-
} else {
454-
s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
411+
// In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
412+
// of input collection at runtime for this case.
413+
val seq = ctx.freshName("seq")
414+
val array = ctx.freshName("array")
415+
val determineCollectionType = inputData.dataType match {
416+
case ObjectType(cls) if cls == classOf[Object] =>
417+
val seqClass = classOf[Seq[_]].getName
418+
s"""
419+
$seqClass $seq = null;
420+
$elementJavaType[] $array = null;
421+
if (${genInputData.value}.getClass().isArray()) {
422+
$array = ($elementJavaType[]) ${genInputData.value};
423+
} else {
424+
$seq = ($seqClass) ${genInputData.value};
425+
}
426+
"""
427+
case _ => ""
428+
}
429+
430+
431+
val (getLength, getLoopVar) = inputData.dataType match {
432+
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
433+
s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
434+
case ObjectType(cls) if cls.isArray =>
435+
s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]"
436+
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
437+
s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)"
438+
case ArrayType(et, _) =>
439+
s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex)
440+
case ObjectType(cls) if cls == classOf[Object] =>
441+
s"$seq == null ? $array.length : $seq.size()" ->
442+
s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
443+
}
444+
445+
val loopNullCheck = inputData.dataType match {
446+
case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
447+
// The elements of a primitive array can never be null.
448+
case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
449+
s"${loopVar.isNull} = false"
450+
case _ => s"${loopVar.isNull} = ${loopVar.value} == null;"
455451
}
456452

457453
val code = s"""
458454
${genInputData.code}
455+
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
459456

460-
boolean ${ev.isNull} = ${genInputData.value} == null;
461-
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
462-
463-
if (!${ev.isNull}) {
457+
if (!${genInputData.isNull}) {
458+
$determineCollectionType
464459
$convertedType[] $convertedArray = null;
465-
int $dataLength = ${genInputData.value}$lengthFunction;
460+
int $dataLength = $getLength;
466461
$convertedArray = $arrayConstructor;
467462

468463
int $loopIndex = 0;
469464
while ($loopIndex < $dataLength) {
470-
${loopVar.value} =
471-
($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
465+
${loopVar.value} = ($elementJavaType) ($getLoopVar);
472466
$loopNullCheck
473467

474468
${genFunction.code}
@@ -481,11 +475,10 @@ case class MapObjects private(
481475
$loopIndex += 1;
482476
}
483477

484-
${ev.isNull} = false;
485478
${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
486479
}
487480
"""
488-
ev.copy(code = code)
481+
ev.copy(code = code, isNull = genInputData.isNull)
489482
}
490483
}
491484

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
3737
def this(primitiveArray: Array[Byte]) = this(primitiveArray.toSeq)
3838
def this(primitiveArray: Array[Boolean]) = this(primitiveArray.toSeq)
3939

40+
def this(seqOrArray: Any) = this(seqOrArray match {
41+
case seq: Seq[Any] => seq
42+
case array: Array[_] => array.toSeq
43+
})
44+
4045
override def copy(): ArrayData = new GenericArrayData(array.clone())
4146

4247
override def numElements(): Int = array.length

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,23 @@ class RowEncoderSuite extends SparkFunSuite {
185185
assert(encoder.serializer.head.nullable == false)
186186
}
187187

188+
test("RowEncoder should support array as the external type for ArrayType") {
189+
val schema = new StructType()
190+
.add("array", ArrayType(IntegerType))
191+
.add("nestedArray", ArrayType(ArrayType(StringType)))
192+
.add("deepNestedArray", ArrayType(ArrayType(ArrayType(LongType))))
193+
val encoder = RowEncoder(schema)
194+
val input = Row(
195+
Array(1, 2, null),
196+
Array(Array("abc", null), null),
197+
Array(Seq(Array(0L, null), null), null))
198+
val row = encoder.toRow(input)
199+
val convertedBack = encoder.fromRow(row)
200+
assert(convertedBack.getSeq(0) == Seq(1, 2, null))
201+
assert(convertedBack.getSeq(1) == Seq(Seq("abc", null), null))
202+
assert(convertedBack.getSeq(2) == Seq(Seq(Seq(0L, null), null), null))
203+
}
204+
188205
private def encodeDecodeTest(schema: StructType): Unit = {
189206
test(s"encode/decode: ${schema.simpleString}") {
190207
val encoder = RowEncoder(schema)

0 commit comments

Comments
 (0)