Skip to content

Commit 1bd0d5e

Browse files
committed
[SPARK-23935][SQL] Addressing review comments.
1 parent 56ff20a commit 1bd0d5e

File tree

3 files changed

+62
-30
lines changed

3 files changed

+62
-30
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
*/
6363
public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable {
6464

65+
public static final int WORD_SIZE = 8;
66+
6567
//////////////////////////////////////////////////////////////////////////////
6668
// Static methods
6769
//////////////////////////////////////////////////////////////////////////////

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,40 @@ class CodegenContext {
764764
""".stripMargin
765765
}
766766

767+
/**
768+
* Generates Java code that allocates a byte array and wraps it in an [[UnsafeArrayData]].
769+
* The generated code executes the provided fallback instead when the size of the backing
* array would exceed `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`.
*
* In the non-fallback branch the generated code writes `numElements` as a long at offset
* `Platform.BYTE_ARRAY_OFFSET` of the backing array, points the new [[UnsafeArrayData]] at
* that array and then runs the code produced by `bodyCode`. Both the backing array and the
* [[UnsafeArrayData]] variable are declared inside that branch, so they are visible to the
* code produced by `bodyCode` but not to `fallbackCode` or to code following the snippet.
770+
* @param arrayName the name of the [[UnsafeArrayData]] variable to declare
771+
* @param numElements a piece of code representing the number of elements the array should
* contain; it is spliced into the generated code more than once
772+
* @param elementSize the size of an element in bytes
773+
* @param bodyCode a function that receives the name of the backing byte array variable
774+
* and returns the code that fills up the [[UnsafeArrayData]]
775+
* @param fallbackCode a piece of code executed when the array size limit is exceeded
* @return the generated code snippet
776+
*/
777+
def createUnsafeArrayWithFallback(
778+
arrayName: String,
779+
numElements: String,
780+
elementSize: Int,
781+
bodyCode: String => String,
782+
fallbackCode: String): String = {
783+
// Names of the local variables declared by the generated code: the computed size in bytes
// and the backing byte array handed to `bodyCode`.
val arraySize = freshName("size")
784+
val arrayBytes = freshName("arrayBytes")
785+
s"""
786+
|final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
787+
| $numElements,
788+
| $elementSize);
789+
|if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
790+
| $fallbackCode
791+
|} else {
792+
| final byte[] $arrayBytes = new byte[(int)$arraySize];
793+
| UnsafeArrayData $arrayName = new UnsafeArrayData();
794+
| Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
795+
| $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
796+
| ${bodyCode(arrayBytes)}
797+
|}
798+
""".stripMargin
799+
}
800+
767801
/**
768802
* Generates code to do null safe execution, i.e. only execute the code when the input is not
769803
* null by adding null check if necessary.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -195,17 +195,14 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
195195
values: String,
196196
arrayData: String,
197197
numElements: String): String = {
198-
val byteArraySize = ctx.freshName("byteArraySize")
199-
val data = ctx.freshName("byteArray")
200198
val unsafeRow = ctx.freshName("unsafeRow")
201199
val unsafeArrayData = ctx.freshName("unsafeArrayData")
202200
val structsOffset = ctx.freshName("structsOffset")
203-
val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray"
204201
val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
205202

206203
val baseOffset = Platform.BYTE_ARRAY_OFFSET
207-
val longSize = LongType.defaultSize
208-
val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2
204+
val wordSize = UnsafeRow.WORD_SIZE
205+
val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
209206
val structSizeAsLong = structSize + "L"
210207
val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
211208
val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
@@ -223,27 +220,26 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
223220
valueAssignment
224221
}
225222

226-
s"""
227-
|final long $byteArraySize = $calculateArraySize($numElements, ${longSize + structSize});
228-
|if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
229-
| ${genCodeForAnyElements(ctx, keys, values, arrayData, numElements)}
230-
|} else {
231-
| final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize;
232-
| final byte[] $data = new byte[(int)$byteArraySize];
233-
| UnsafeArrayData $unsafeArrayData = new UnsafeArrayData();
234-
| Platform.putLong($data, $baseOffset, $numElements);
235-
| $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize);
236-
| UnsafeRow $unsafeRow = new UnsafeRow(2);
237-
| for (int z = 0; z < $numElements; z++) {
238-
| long offset = $structsOffset + z * $structSizeAsLong;
239-
| $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
240-
| $unsafeRow.pointTo($data, $baseOffset + offset, $structSize);
241-
| $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
242-
| $valueAssignmentChecked
243-
| }
244-
| $arrayData = $unsafeArrayData;
245-
|}
246-
""".stripMargin
223+
val assignmentLoop = (byteArray: String) =>
224+
s"""
225+
|final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize;
226+
|UnsafeRow $unsafeRow = new UnsafeRow(2);
227+
|for (int z = 0; z < $numElements; z++) {
228+
| long offset = $structsOffset + z * $structSizeAsLong;
229+
| $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
230+
| $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize);
231+
| $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
232+
| $valueAssignmentChecked
233+
|}
234+
|$arrayData = $unsafeArrayData;
235+
""".stripMargin
236+
237+
ctx.createUnsafeArrayWithFallback(
238+
unsafeArrayData,
239+
numElements,
240+
structSize + wordSize,
241+
assignmentLoop,
242+
genCodeForAnyElements(ctx, keys, values, arrayData, numElements))
247243
}
248244

249245
private def genCodeForAnyElements(
@@ -258,10 +254,10 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
258254

259255
val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
260256
val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
261-
s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
262-
} else {
263-
getValue(values)
264-
}
257+
s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
258+
} else {
259+
getValue(values)
260+
}
265261

266262
s"""
267263
|final Object[] $data = new Object[$numElements];

0 commit comments

Comments
 (0)