
Commit 40f9dbb

Ngone51 authored and cloud-fan committed
[SPARK-31425][SQL][CORE] UnsafeKVExternalSorter/VariableLengthRowBasedKeyValueBatch should also respect UnsafeAlignedOffset
### What changes were proposed in this pull request?

Make `UnsafeKVExternalSorter`/`VariableLengthRowBasedKeyValueBatch` also respect `UnsafeAlignedOffset` when reading records, and update some out-of-date comments.

### Why are the changes needed?

Since `BytesToBytesMap` respects `UnsafeAlignedOffset` when writing a record, `UnsafeKVExternalSorter` must also respect `UnsafeAlignedOffset` when reading a record back from `BytesToBytesMap`; otherwise it causes a data correctness issue. Unlike `UnsafeKVExternalSorter`, which may read records from `BytesToBytesMap`, `VariableLengthRowBasedKeyValueBatch` writes and reads records by itself. Thus, similar to apache#22053 and the [comment](apache#22053 (comment)) there, the fix for `VariableLengthRowBasedKeyValueBatch` is more of an improvement for SPARC platform support.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Manually ran `HashAggregationQueryWithControlledFallbackSuite` with `UAO_SIZE=8` to simulate the SPARC platform; the tests only pass with this fix.

Closes apache#28195 from Ngone51/fix_uao.

Authored-by: yi.wu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent b2e9e17 commit 40f9dbb
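To make the correctness issue concrete, here is a minimal standalone sketch of the failure mode this commit fixes, using plain `java.nio` rather than Spark's `Platform`/`UnsafeAlignedOffset` (all names below are illustrative): once the writer stores a length field in an 8-byte word, a reader that still assumes the old fixed 4-byte layout computes every subsequent offset incorrectly.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class UaoMismatchSketch {
  // Simulates a platform where unaligned access is unsafe, so length words
  // are written UAO_SIZE = 8 bytes wide (as UnsafeAlignedOffset would do).
  static final int UAO_SIZE = 8;

  static void putSize(ByteBuffer buf, int offset, int value) {
    if (UAO_SIZE == 4) buf.putInt(offset, value);
    else buf.putLong(offset, value);
  }

  public static void main(String[] args) {
    ByteBuffer page = ByteBuffer.allocate(64).order(ByteOrder.LITTLE_ENDIAN);
    putSize(page, 0, 42);          // writer stores a length in an 8-byte slot
    int broken = page.getInt(4);   // pre-fix reader: "next field is at +4"
    long fixed = page.getLong(0);  // UAO-aware reader uses the real word size
    System.out.println(broken);    // 0  -> reads the upper half of the slot
    System.out.println(fixed);     // 42 -> correct
  }
}
```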

File tree: 6 files changed (+78 lines, -56 lines)


common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java

Lines changed: 11 additions & 3 deletions
@@ -28,12 +28,20 @@ public class UnsafeAlignedOffset {
 
   private static final int UAO_SIZE = Platform.unaligned() ? 4 : 8;
 
+  private static int TEST_UAO_SIZE = 0;
+
+  // used for test only
+  public static void setUaoSize(int size) {
+    assert size == 0 || size == 4 || size == 8;
+    TEST_UAO_SIZE = size;
+  }
+
   public static int getUaoSize() {
-    return UAO_SIZE;
+    return TEST_UAO_SIZE == 0 ? UAO_SIZE : TEST_UAO_SIZE;
   }
 
   public static int getSize(Object object, long offset) {
-    switch (UAO_SIZE) {
+    switch (getUaoSize()) {
       case 4:
         return Platform.getInt(object, offset);
       case 8:
@@ -46,7 +54,7 @@ public static int getSize(Object object, long offset) {
   }
 
   public static void putSize(Object object, long offset, int value) {
-    switch (UAO_SIZE) {
+    switch (getUaoSize()) {
       case 4:
         Platform.putInt(object, offset, value);
         break;
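For context, a small self-contained sketch of the test-override pattern this diff introduces; `AlignedOffsetSketch` is a hypothetical stand-in for `UnsafeAlignedOffset`, with `UAO_SIZE` hard-coded to 4 instead of derived from `Platform.unaligned()`. The point of routing `getSize`/`putSize` through `getUaoSize()` is that an override set by a test propagates to every read and write site.

```java
final class AlignedOffsetSketch {
  private static final int UAO_SIZE = 4;  // the "platform default" here
  private static int TEST_UAO_SIZE = 0;   // 0 means "no test override"

  static void setUaoSize(int size) {      // test-only hook, as in the diff
    assert size == 0 || size == 4 || size == 8;
    TEST_UAO_SIZE = size;
  }

  static int getUaoSize() {
    return TEST_UAO_SIZE == 0 ? UAO_SIZE : TEST_UAO_SIZE;
  }

  public static void main(String[] args) {
    System.out.println(getUaoSize()); // 4: the platform default
    setUaoSize(8);                    // simulate a SPARC-like platform
    System.out.println(getUaoSize()); // 8: the override wins
    setUaoSize(0);                    // reset so later callers see the default
  }
}
```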

core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java

Lines changed: 7 additions & 7 deletions
@@ -54,13 +54,13 @@
  * probably be using sorting instead of hashing for better cache locality.
  *
  * The key and values under the hood are stored together, in the following format:
- * Bytes 0 to 4: len(k) (key length in bytes) + len(v) (value length in bytes) + 4
- * Bytes 4 to 8: len(k)
- * Bytes 8 to 8 + len(k): key data
- * Bytes 8 + len(k) to 8 + len(k) + len(v): value data
- * Bytes 8 + len(k) + len(v) to 8 + len(k) + len(v) + 8: pointer to next pair
+ * First uaoSize bytes: len(k) (key length in bytes) + len(v) (value length in bytes) + uaoSize
+ * Next uaoSize bytes: len(k)
+ * Next len(k) bytes: key data
+ * Next len(v) bytes: value data
+ * Last 8 bytes: pointer to next pair
  *
- * This means that the first four bytes store the entire record (key + value) length. This format
+ * It means first uaoSize bytes store the entire record (key + value + uaoSize) length. This format
  * is compatible with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
  * so we can pass records from this map directly into the sorter to sort records in place.
  */
@@ -706,7 +706,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff
     // Here, we'll copy the data into our data pages. Because we only store a relative offset from
     // the key address instead of storing the absolute address of the value, the key and value
     // must be stored in the same memory page.
-    // (8 byte key length) (key) (value) (8 byte pointer to next value)
+    // (total length) (key length) (key) (value) (8 byte pointer to next value)
     int uaoSize = UnsafeAlignedOffset.getUaoSize();
     final long recordLength = (2L * uaoSize) + klen + vlen + 8;
     if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
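As a worked example of the record layout documented above (the lengths are assumed values, not code from the commit), the field offsets for both word sizes come out as follows; the whole record grows by 8 bytes when uaoSize is 8, because both leading length words double in width.

```java
public class RecordLayoutSketch {
  public static void main(String[] args) {
    int klen = 16, vlen = 24;                    // assumed example lengths
    for (int uaoSize : new int[] {4, 8}) {
      long totalLenAt = 0;                       // len(k) + len(v) + uaoSize
      long keyLenAt = uaoSize;                   // len(k)
      long keyAt = 2L * uaoSize;                 // key data
      long valueAt = keyAt + klen;               // value data
      long nextPtrAt = valueAt + vlen;           // 8-byte pointer to next pair
      long recordLength = 2L * uaoSize + klen + vlen + 8;  // as in append()
      System.out.printf(
          "uaoSize=%d: totalLen@%d keyLen@%d key@%d value@%d nextPtr@%d recordLength=%d%n",
          uaoSize, totalLenAt, keyLenAt, keyAt, valueAt, nextPtrAt, recordLength);
    }
  }
}
```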

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java

Lines changed: 1 addition & 1 deletion
@@ -235,7 +235,7 @@ public void expandPointerArray(LongArray newArray) {
 
   /**
    * Inserts a record to be sorted. Assumes that the record pointer points to a record length
-   * stored as a 4-byte integer, followed by the record's bytes.
+   * stored as a uaoSize(4 or 8) bytes integer, followed by the record's bytes.
    *
    * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
    * @param keyPrefix a user-defined key prefix

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java

Lines changed: 19 additions & 15 deletions
@@ -19,11 +19,12 @@
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
 
 /**
  * An implementation of `RowBasedKeyValueBatch` in which key-value records have variable lengths.
  *
- * The format for each record looks like this:
+ * The format for each record looks like this (in case of uaoSize = 4):
  * [4 bytes total size = (klen + vlen + 4)] [4 bytes key size = klen]
  * [UnsafeRow for key of length klen] [UnsafeRow for Value of length vlen]
  * [8 bytes pointer to next]
@@ -41,18 +42,19 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB
   @Override
   public UnsafeRow appendRow(Object kbase, long koff, int klen,
                              Object vbase, long voff, int vlen) {
-    final long recordLength = 8L + klen + vlen + 8;
+    int uaoSize = UnsafeAlignedOffset.getUaoSize();
+    final long recordLength = 2 * uaoSize + klen + vlen + 8L;
     // if run out of max supported rows or page size, return null
     if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) {
       return null;
     }
 
     long offset = page.getBaseOffset() + pageCursor;
     final long recordOffset = offset;
-    Platform.putInt(base, offset, klen + vlen + 4);
-    Platform.putInt(base, offset + 4, klen);
+    UnsafeAlignedOffset.putSize(base, offset, klen + vlen + uaoSize);
+    UnsafeAlignedOffset.putSize(base, offset + uaoSize, klen);
 
-    offset += 8;
+    offset += 2 * uaoSize;
     Platform.copyMemory(kbase, koff, base, offset, klen);
     offset += klen;
     Platform.copyMemory(vbase, voff, base, offset, vlen);
@@ -61,11 +63,11 @@ public UnsafeRow appendRow(Object kbase, long koff, int klen,
 
     pageCursor += recordLength;
 
-    keyOffsets[numRows] = recordOffset + 8;
+    keyOffsets[numRows] = recordOffset + 2 * uaoSize;
 
     keyRowId = numRows;
-    keyRow.pointTo(base, recordOffset + 8, klen);
-    valueRow.pointTo(base, recordOffset + 8 + klen, vlen);
+    keyRow.pointTo(base, recordOffset + 2 * uaoSize, klen);
+    valueRow.pointTo(base, recordOffset + 2 * uaoSize + klen, vlen);
     numRows++;
     return valueRow;
   }
@@ -79,7 +81,7 @@ public UnsafeRow getKeyRow(int rowId) {
     assert(rowId < numRows);
     if (keyRowId != rowId) { // if keyRowId == rowId, desired keyRow is already cached
       long offset = keyOffsets[rowId];
-      int klen = Platform.getInt(base, offset - 4);
+      int klen = UnsafeAlignedOffset.getSize(base, offset - UnsafeAlignedOffset.getUaoSize());
       keyRow.pointTo(base, offset, klen);
       // set keyRowId so we can check if desired row is cached
       keyRowId = rowId;
@@ -99,9 +101,10 @@ public UnsafeRow getValueFromKey(int rowId) {
       getKeyRow(rowId);
     }
     assert(rowId >= 0);
+    int uaoSize = UnsafeAlignedOffset.getUaoSize();
    long offset = keyRow.getBaseOffset();
    int klen = keyRow.getSizeInBytes();
-    int vlen = Platform.getInt(base, offset - 8) - klen - 4;
+    int vlen = UnsafeAlignedOffset.getSize(base, offset - uaoSize * 2) - klen - uaoSize;
    valueRow.pointTo(base, offset + klen, vlen);
    return valueRow;
   }
@@ -141,14 +144,15 @@ public boolean next() {
       return false;
     }
 
-    totalLength = Platform.getInt(base, offsetInPage) - 4;
-    currentklen = Platform.getInt(base, offsetInPage + 4);
+    int uaoSize = UnsafeAlignedOffset.getUaoSize();
+    totalLength = UnsafeAlignedOffset.getSize(base, offsetInPage) - uaoSize;
+    currentklen = UnsafeAlignedOffset.getSize(base, offsetInPage + uaoSize);
     currentvlen = totalLength - currentklen;
 
-    key.pointTo(base, offsetInPage + 8, currentklen);
-    value.pointTo(base, offsetInPage + 8 + currentklen, currentvlen);
+    key.pointTo(base, offsetInPage + 2 * uaoSize, currentklen);
+    value.pointTo(base, offsetInPage + 2 * uaoSize + currentklen, currentvlen);
 
-    offsetInPage += 8 + totalLength + 8;
+    offsetInPage += 2 * uaoSize + totalLength + 8;
     recordsInPage -= 1;
     return true;
   }
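A hedged round-trip sketch of the symmetry this diff restores, again on plain `java.nio` instead of Spark's `Platform` API (all names are illustrative): as long as the `appendRow`-style writer and the iterator-style reader agree on the word size, the key and value lengths survive the trip for either setting.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class RoundTripSketch {
  static int uaoSize = 8;  // flip to 4: the round trip still works, by design

  static void putSize(ByteBuffer b, int off, int v) {
    if (uaoSize == 4) b.putInt(off, v); else b.putLong(off, v);
  }

  static int getSize(ByteBuffer b, int off) {
    return uaoSize == 4 ? b.getInt(off) : (int) b.getLong(off);
  }

  public static void main(String[] args) {
    int klen = 16, vlen = 24;
    ByteBuffer page = ByteBuffer.allocate(128).order(ByteOrder.LITTLE_ENDIAN);

    // Write side, mirroring appendRow: [total][klen][key][value][next ptr].
    putSize(page, 0, klen + vlen + uaoSize);  // total size word
    putSize(page, uaoSize, klen);             // key size word
    // ... key bytes at 2 * uaoSize, value bytes after them, 8-byte pointer last

    // Read side, mirroring the iterator's next():
    int totalLength = getSize(page, 0) - uaoSize;
    int currentklen = getSize(page, uaoSize);
    int currentvlen = totalLength - currentklen;
    System.out.println(currentklen + " / " + currentvlen);  // 16 / 24
  }
}
```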

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java

Lines changed: 12 additions & 8 deletions
@@ -35,6 +35,7 @@
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.KVIterator;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
 import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
 import org.apache.spark.unsafe.memory.MemoryBlock;
@@ -141,9 +142,10 @@ public UnsafeKVExternalSorter(
 
         // Get encoded memory address
         // baseObject + baseOffset point to the beginning of the key data in the map, but that
-        // the KV-pair's length data is stored in the word immediately before that address
+        // the KV-pair's length data is stored at 2 * uaoSize bytes immediately before that address
         MemoryBlock page = loc.getMemoryPage();
-        long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
+        long address = taskMemoryManager.encodePageNumberAndOffset(page,
+          baseOffset - 2 * UnsafeAlignedOffset.getUaoSize());
 
         // Compute prefix
         row.pointTo(baseObject, baseOffset, loc.getKeyLength());
@@ -262,10 +264,11 @@ public int compare(
         Object baseObj2,
         long baseOff2,
         int baseLen2) {
+      int uaoSize = UnsafeAlignedOffset.getUaoSize();
       // Note that since ordering doesn't need the total length of the record, we just pass 0
       // into the row.
-      row1.pointTo(baseObj1, baseOff1 + 4, 0);
-      row2.pointTo(baseObj2, baseOff2 + 4, 0);
+      row1.pointTo(baseObj1, baseOff1 + uaoSize, 0);
+      row2.pointTo(baseObj2, baseOff2 + uaoSize, 0);
       return ordering.compare(row1, row2);
     }
   }
@@ -289,11 +292,12 @@ public boolean next() throws IOException {
         long recordOffset = underlying.getBaseOffset();
         int recordLen = underlying.getRecordLength();
 
-        // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
+        // Note that recordLen = keyLen + valueLen + uaoSize (for the keyLen itself)
+        int uaoSize = UnsafeAlignedOffset.getUaoSize();
         int keyLen = Platform.getInt(baseObj, recordOffset);
-        int valueLen = recordLen - keyLen - 4;
-        key.pointTo(baseObj, recordOffset + 4, keyLen);
-        value.pointTo(baseObj, recordOffset + 4 + keyLen, valueLen);
+        int valueLen = recordLen - keyLen - uaoSize;
+        key.pointTo(baseObj, recordOffset + uaoSize, keyLen);
+        value.pointTo(baseObj, recordOffset + uaoSize + keyLen, valueLen);
 
         return true;
       } else {
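One subtlety worth spelling out, with assumed values rather than Spark code: the map hands the sorter an address rewound by `2 * uaoSize`, so a spilled record starts at its total-length word, while `SortedIterator`'s `recordOffset` points just past that word, leaving `[keyLen][key][value]` with `recordLen = keyLen + valueLen + uaoSize`, which is the arithmetic in the hunk above.

```java
public class SorterViewSketch {
  public static void main(String[] args) {
    int uaoSize = 8, klen = 16, vlen = 24;

    // View 1: as stored by BytesToBytesMap, addressed from the start of
    // the key data; rewinding both length words reaches the record start.
    long keyDataOffset = 2L * uaoSize;
    long handedToSorter = keyDataOffset - 2L * uaoSize;  // record start

    // View 2: inside SortedIterator, past the total-length word:
    // [keyLen][key][value], so recordLen = keyLen + valueLen + uaoSize.
    int recordLen = klen + vlen + uaoSize;
    int valueLen = recordLen - klen - uaoSize;           // recovers vlen
    long keyOffset = /* recordOffset + */ uaoSize;       // key after keyLen word
    long valueOffset = keyOffset + klen;                 // value after the key

    System.out.println(handedToSorter);                  // 0: record start
    System.out.println(valueLen + " @ " + valueOffset);  // 24 @ 24
  }
}
```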

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 28 additions & 22 deletions
@@ -31,6 +31,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.UnsafeAlignedOffset
 
 
 class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
@@ -1055,30 +1056,35 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySu
     Seq("true", "false").foreach { enableTwoLevelMaps =>
       withSQLConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key ->
         enableTwoLevelMaps) {
-        (1 to 3).foreach { fallbackStartsAt =>
-          withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
-            s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
-            // Create a new df to make sure its physical operator picks up
-            // spark.sql.TungstenAggregate.testFallbackStartsAt.
-            // todo: remove it?
-            val newActual = Dataset.ofRows(spark, actual.logicalPlan)
-
-            QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
-              case Some(errorMessage) =>
-                val newErrorMessage =
-                  s"""
-                     |The following aggregation query failed when using HashAggregate with
-                     |controlled fallback (it falls back to bytes to bytes map once it has processed
-                     |${fallbackStartsAt - 1} input rows and to sort-based aggregation once it has
-                     |processed $fallbackStartsAt input rows). The query is ${actual.queryExecution}
-                     |
-                     |$errorMessage
-                   """.stripMargin
-
-                fail(newErrorMessage)
-              case None => // Success
+        Seq(4, 8).foreach { uaoSize =>
+          UnsafeAlignedOffset.setUaoSize(uaoSize)
+          (1 to 3).foreach { fallbackStartsAt =>
+            withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
+              s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
+              // Create a new df to make sure its physical operator picks up
+              // spark.sql.TungstenAggregate.testFallbackStartsAt.
+              // todo: remove it?
+              val newActual = Dataset.ofRows(spark, actual.logicalPlan)
+
+              QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
+                case Some(errorMessage) =>
+                  val newErrorMessage =
+                    s"""
+                       |The following aggregation query failed when using HashAggregate with
+                       |controlled fallback (it falls back to bytes to bytes map once it has
+                       |processed ${fallbackStartsAt - 1} input rows and to sort-based aggregation
+                       |once it has processed $fallbackStartsAt input rows).
+                       |The query is ${actual.queryExecution}
+                       |$errorMessage
+                     """.stripMargin
+
+                  fail(newErrorMessage)
+                case None => // Success
+              }
             }
           }
+          // reset static uaoSize to avoid affect other tests
+          UnsafeAlignedOffset.setUaoSize(0)
        }
      }
    }
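One design note on the test change: `setUaoSize` mutates shared static state, and the diff restores it with a trailing `setUaoSize(0)` call, which is skipped if `fail()` throws mid-loop. A hedged sketch of a more defensive variant (a hypothetical helper, not part of the commit) that always restores the default:

```java
import org.apache.spark.unsafe.UnsafeAlignedOffset;

public final class UaoTestHelper {
  private UaoTestHelper() {}

  // Runs a test body with a pinned UAO size and always restores the platform
  // default, even if the body throws (e.g. a failed assertion).
  public static void withUaoSize(int size, Runnable body) {
    UnsafeAlignedOffset.setUaoSize(size);
    try {
      body.run();
    } finally {
      UnsafeAlignedOffset.setUaoSize(0); // 0 = fall back to the real UAO_SIZE
    }
  }
}
```

With such a helper, each iteration of the `Seq(4, 8)` loop could wrap its body per size instead of relying on straight-line execution reaching the reset.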
