Skip to content

Commit a93fd4b

Browse files
committed
codegen FromUnsafe
1 parent e57d6b5 commit a93fd4b

File tree

5 files changed

+96
-109
lines changed

5 files changed

+96
-109
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala

Lines changed: 0 additions & 70 deletions
This file was deleted.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,7 @@ object FromUnsafeProjection {
152152
*/
153153
def apply(fields: Seq[DataType]): Projection = {
154154
create(fields.zipWithIndex.map(x => {
155-
val b = new BoundReference(x._2, x._1, true)
156-
// todo: this is quite slow, maybe remove this whole projection after remove generic getter of
157-
// InternalRow?
158-
b.dataType match {
159-
case _: StructType | _: ArrayType | _: MapType => FromUnsafe(b)
160-
case _ => b
161-
}
155+
new BoundReference(x._2, x._1, true)
162156
}))
163157
}
164158

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,7 @@ class CodeGenContext {
137137
dataType match {
138138
case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
139139
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
140-
// The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes)
141-
case StringType => s"$row.update($ordinal, $value.clone())"
140+
case StringType => s"$row.update($ordinal, $value)"
142141
case _ => s"$row.update($ordinal, $value)"
143142
}
144143
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala

Lines changed: 91 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
2121

2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
24-
import org.apache.spark.sql.types.{StringType, StructType, DataType}
24+
import org.apache.spark.sql.types._
2525

2626

2727
/**
@@ -36,34 +36,94 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
3636
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
3737
in.map(BindReferences.bindReference(_, inputSchema))
3838

39-
private def genUpdater(
39+
private def createCodeForStruct(
4040
ctx: CodeGenContext,
41-
setter: String,
42-
dataType: DataType,
43-
ordinal: Int,
44-
value: String): String = {
45-
dataType match {
46-
case struct: StructType =>
47-
val rowTerm = ctx.freshName("row")
48-
val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) =>
49-
val colTerm = ctx.freshName("col")
50-
s"""
51-
if ($value.isNullAt($i)) {
52-
$rowTerm.setNullAt($i);
53-
} else {
54-
${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")};
55-
${genUpdater(ctx, rowTerm, dt, i, colTerm)};
56-
}
57-
"""
58-
}.mkString("\n")
59-
s"""
60-
$genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length});
61-
$updates
62-
$setter.update($ordinal, $rowTerm.copy());
63-
"""
64-
case _ =>
65-
ctx.setColumn(setter, dataType, ordinal, value)
66-
}
41+
input: String,
42+
schema: StructType): GeneratedExpressionCode = {
43+
val tmp = ctx.freshName("tmp")
44+
val output = ctx.freshName("safeRow")
45+
val values = ctx.freshName("values")
46+
val rowClass = classOf[GenericInternalRow].getName
47+
48+
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
49+
val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt)
50+
s"""
51+
if (!$tmp.isNullAt($i)) {
52+
${converter.code}
53+
$values[$i] = ${converter.primitive};
54+
}
55+
"""
56+
}.mkString("\n")
57+
58+
val code = s"""
59+
final InternalRow $tmp = $input;
60+
final Object[] $values = new Object[${schema.length}];
61+
$fieldWriters
62+
final InternalRow $output = new $rowClass($values);
63+
"""
64+
65+
GeneratedExpressionCode(code, "false", output)
66+
}
67+
68+
private def createCodeForArray(
69+
ctx: CodeGenContext,
70+
input: String,
71+
elementType: DataType): GeneratedExpressionCode = {
72+
val tmp = ctx.freshName("tmp")
73+
val output = ctx.freshName("safeArray")
74+
val values = ctx.freshName("values")
75+
val numElements = ctx.freshName("numElements")
76+
val index = ctx.freshName("index")
77+
val arrayClass = classOf[GenericArrayData].getName
78+
79+
val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, index), elementType)
80+
val code = s"""
81+
final ArrayData $tmp = $input;
82+
final int $numElements = $tmp.numElements();
83+
final Object[] $values = new Object[$numElements];
84+
for (int $index = 0; $index < $numElements; $index++) {
85+
if (!$tmp.isNullAt($index)) {
86+
${elementConverter.code}
87+
$values[$index] = ${elementConverter.primitive};
88+
}
89+
}
90+
final ArrayData $output = new $arrayClass($values);
91+
"""
92+
93+
GeneratedExpressionCode(code, "false", output)
94+
}
95+
96+
private def createCodeForMap(
97+
ctx: CodeGenContext,
98+
input: String,
99+
keyType: DataType,
100+
valueType: DataType): GeneratedExpressionCode = {
101+
val tmp = ctx.freshName("tmp")
102+
val output = ctx.freshName("safeMap")
103+
val mapClass = classOf[ArrayBasedMapData].getName
104+
105+
val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType)
106+
val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", valueType)
107+
val code = s"""
108+
final MapData $tmp = $input;
109+
${keyConverter.code}
110+
${valueConverter.code}
111+
final MapData $output = new $mapClass(${keyConverter.primitive}, ${valueConverter.primitive});
112+
"""
113+
114+
GeneratedExpressionCode(code, "false", output)
115+
}
116+
117+
private def convertToSafe(
118+
ctx: CodeGenContext,
119+
input: String,
120+
dataType: DataType): GeneratedExpressionCode = dataType match {
121+
case s: StructType => createCodeForStruct(ctx, input, s)
122+
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
123+
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
124+
// UTF8String acts as a pointer if it's inside UnsafeRow, so copy it to make it safe.
125+
case StringType => GeneratedExpressionCode("", "false", s"$input.clone()")
126+
case _ => GeneratedExpressionCode("", "false", input)
67127
}
68128

69129
protected def create(expressions: Seq[Expression]): Projection = {
@@ -72,12 +132,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
72132
case (NoOp, _) => ""
73133
case (e, i) =>
74134
val evaluationCode = e.gen(ctx)
135+
val converter = convertToSafe(ctx, evaluationCode.primitive, e.dataType)
75136
evaluationCode.code +
76137
s"""
77138
if (${evaluationCode.isNull}) {
78139
mutableRow.setNullAt($i);
79140
} else {
80-
${genUpdater(ctx, "mutableRow", e.dataType, i, evaluationCode.primitive)};
141+
${converter.code}
142+
${ctx.setColumn("mutableRow", e.dataType, i, converter.primitive)};
81143
}
82144
"""
83145
}

sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,9 @@ case class DummyPlan(child: SparkPlan) extends UnaryNode {
112112

113113
override protected def doExecute(): RDD[InternalRow] = {
114114
child.execute().mapPartitions { iter =>
115-
// cache all strings to make sure we have deep copied UTF8String inside incoming
115+
// This `DummyPlan` is in safe mode, so we don't need to copy even if we hold some
116+
// values gotten from the incoming rows.
117+
// We cache all strings here to make sure we have deep-copied the UTF8Strings inside incoming
116118
// safe InternalRow.
117119
val strings = new scala.collection.mutable.ArrayBuffer[UTF8String]
118120
iter.foreach { row =>

0 commit comments

Comments
 (0)