Commit d13ac55

Hacky approach to copying of UnsafeRows for sort followed by limit.
1 parent 845bea3 commit d13ac55
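
For context, here is a minimal sketch of the query shape this commit targets, mirroring the DataFrame built in the new test; it assumes the 1.5-era `TestSQLContext` and is illustrative only, not part of the commit:

```scala
import org.apache.spark.sql.test.TestSQLContext.implicits._

// Sort the whole dataset, then keep only the first few rows. The limit stops pulling
// from the sort's output early, which is why rows handed out by the unsafe sorter
// must not point into memory pages that the sorter has already freed.
val df = (1 to 100).map(v => Tuple1(v)).toDF("a")
val topTen = df.sort($"a".asc).limit(10).collect()
```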

6 files changed: +104 additions, -44 deletions

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java

Lines changed: 9 additions & 3 deletions
```diff
@@ -120,9 +120,11 @@ public UnsafeFixedWidthAggregationMap(
     this.bufferPool = new ObjectPool(initialCapacity);
 
     InternalRow initRow = initProjection.apply(emptyRow);
-    this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)];
+    int emptyBufferSize = bufferConverter.getSizeRequirement(initRow);
+    this.emptyBuffer = new byte[emptyBufferSize];
     int writtenLength = bufferConverter.writeRow(
-      initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool);
+      initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize,
+      bufferPool);
     assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!";
     // re-use the empty buffer only when there is no object saved in pool.
     reuseEmptyBuffer = bufferPool.size() == 0;
@@ -142,6 +144,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
       groupingKey,
       groupingKeyConversionScratchSpace,
       PlatformDependent.BYTE_ARRAY_OFFSET,
+      groupingKeySize,
       keyPool);
     assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
 
@@ -157,7 +160,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
       // There is some objects referenced by emptyBuffer, so generate a new one
       InternalRow initRow = initProjection.apply(emptyRow);
       bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET,
-        bufferPool);
+        groupingKeySize, bufferPool);
     }
     loc.putNewKey(
       groupingKeyConversionScratchSpace,
@@ -175,6 +178,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
       address.getBaseObject(),
       address.getBaseOffset(),
       bufferConverter.numFields(),
+      loc.getValueLength(),
       bufferPool
     );
     return currentBuffer;
@@ -214,12 +218,14 @@ public MapEntry next() {
           keyAddress.getBaseObject(),
           keyAddress.getBaseOffset(),
           keyConverter.numFields(),
+          loc.getKeyLength(),
           keyPool
         );
         entry.value.pointTo(
           valueAddress.getBaseObject(),
           valueAddress.getBaseOffset(),
           bufferConverter.numFields(),
+          loc.getValueLength(),
           bufferPool
         );
         return entry;
```

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 25 additions & 2 deletions
```diff
@@ -63,11 +63,15 @@ public final class UnsafeRow extends MutableRow {
 
   public Object getBaseObject() { return baseObject; }
   public long getBaseOffset() { return baseOffset; }
+  public int getSizeInBytes() { return sizeInBytes; }
   public ObjectPool getPool() { return pool; }
 
   /** The number of fields in this row, used for calculating the bitset width (and in assertions) */
   private int numFields;
 
+  /** The size of this row's backing data, in bytes) */
+  private int sizeInBytes;
+
   public int length() { return numFields; }
 
   /** The width of the null tracking bit set, in bytes */
@@ -95,14 +99,17 @@ public UnsafeRow() { }
   * @param baseObject the base object
   * @param baseOffset the offset within the base object
   * @param numFields the number of fields in this row
+  * @param sizeInBytes the size of this row's backing data, in bytes
   * @param pool the object pool to hold arbitrary objects
   */
-  public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) {
+  public void pointTo(
+      Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) {
     assert numFields >= 0 : "numFields should >= 0";
     this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
     this.numFields = numFields;
+    this.sizeInBytes = sizeInBytes;
     this.pool = pool;
   }
 
@@ -338,7 +345,23 @@ public double getDouble(int i) {
 
   @Override
   public InternalRow copy() {
-    throw new UnsupportedOperationException();
+    if (pool != null) {
+      throw new UnsupportedOperationException(
+        "Copy is not supported for UnsafeRows that use object pools");
+    } else {
+      UnsafeRow rowCopy = new UnsafeRow();
+      final byte[] rowDataCopy = new byte[sizeInBytes];
+      PlatformDependent.copyMemory(
+        baseObject,
+        baseOffset,
+        rowDataCopy,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        sizeInBytes
+      );
+      rowCopy.pointTo(
+        rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null);
+      return rowCopy;
+    }
   }
 
   @Override
```
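
A minimal usage sketch (not part of the commit) of what the new `copy()` enables for a pool-free row; it mirrors the converter round-trip used in `UnsafeRowConverterSuite`, and the concrete field values are illustrative:

```scala
import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, UnsafeRow, UnsafeRowConverter}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
import org.apache.spark.unsafe.PlatformDependent

val fieldTypes: Array[DataType] = Array(LongType, IntegerType)
val converter = new UnsafeRowConverter(fieldTypes)
val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 42L)
row.setInt(1, 7)

// Write the row into a long[] buffer using the widened writeRow signature.
val sizeRequired = converter.getSizeRequirement(row)
val buffer = new Array[Long](sizeRequired / 8)
converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)

val unsafeRow = new UnsafeRow()
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)

// copy() allocates a fresh byte[] of sizeInBytes and re-points a new UnsafeRow at it,
// so the copy survives even if the original backing buffer is reused or freed.
val detached = unsafeRow.copy()
java.util.Arrays.fill(buffer, 0L) // clobber the original backing data
assert(detached.getLong(0) == 42L)
```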

sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java

Lines changed: 11 additions & 25 deletions
```diff
@@ -54,7 +54,6 @@ final class UnsafeExternalRowSorter {
   private final StructType schema;
   private final UnsafeRowConverter rowConverter;
   private final PrefixComputer prefixComputer;
-  private final ObjectPool objPool = new ObjectPool(128);
   private final UnsafeExternalSorter sorter;
   private byte[] rowConversionBuffer = new byte[1024 * 8];
 
@@ -77,7 +76,7 @@ public UnsafeExternalRowSorter(
       sparkEnv.shuffleMemoryManager(),
       sparkEnv.blockManager(),
       taskContext,
-      new RowComparator(ordering, schema.length(), objPool),
+      new RowComparator(ordering, schema.length(), null),
       prefixComparator,
       4096,
       sparkEnv.conf()
@@ -100,7 +99,7 @@ void insertRow(InternalRow row) throws IOException {
       rowConversionBuffer = new byte[sizeRequirement];
     }
     final int bytesWritten = rowConverter.writeRow(
-      row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, objPool);
+      row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, sizeRequirement, null);
     assert (bytesWritten == sizeRequirement);
     final long prefix = prefixComputer.computePrefix(row);
     sorter.insertRecord(
@@ -143,31 +142,18 @@ public boolean hasNext() {
         return sortedIterator.hasNext();
       }
 
-      /**
-       * Called prior to returning this iterator's last row. This copies the row's data into an
-       * on-heap byte array so that the pointer to the row data will not be dangling after the
-       * sorter's memory pages are freed.
-       */
-      private void detachRowFromPage(UnsafeRow row, int rowLength) {
-        final byte[] rowDataCopy = new byte[rowLength];
-        PlatformDependent.copyMemory(
-          row.getBaseObject(),
-          row.getBaseOffset(),
-          rowDataCopy,
-          PlatformDependent.BYTE_ARRAY_OFFSET,
-          rowLength
-        );
-        row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, row.getPool());
-      }
-
       @Override
       public InternalRow next() {
         try {
           sortedIterator.loadNext();
           row.pointTo(
-            sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, objPool);
+            sortedIterator.getBaseObject(),
+            sortedIterator.getBaseOffset(),
+            numFields,
+            sortedIterator.getRecordLength(),
+            null);
           if (!hasNext()) {
-            detachRowFromPage(row, sortedIterator.getRecordLength());
+            row.copy(); // so that we don't have dangling pointers to freed page
             cleanupResources();
           }
           return row;
@@ -198,7 +184,7 @@ public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IO
    * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise.
    */
   public static boolean supportsSchema(StructType schema) {
-    // TODO: add spilling note.
+    // TODO: add spilling note to explain why we do this for now:
     for (StructField field : schema.fields()) {
       if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) {
         return false;
@@ -222,8 +208,8 @@ public RowComparator(Ordering<InternalRow> ordering, int numFields, ObjectPool o
 
     @Override
     public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
-      row1.pointTo(baseObj1, baseOff1, numFields, objPool);
-      row2.pointTo(baseObj2, baseOff2, numFields, objPool);
+      row1.pointTo(baseObj1, baseOff1, numFields, -1, objPool);
+      row2.pointTo(baseObj2, baseOff2, numFields, -1, objPool);
       return ordering.compare(row1, row2);
     }
   }
```
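
A hedged consumer-side sketch of why the row copying matters here: the iterator returned by `sort()` re-points a single `UnsafeRow` on every `next()` call, so a caller that stops early and retains rows (as a limit does) must detach them. (`RowComparator` can pass -1 for the size because comparing rows only reads fields and never copies them.) The name `sortedRows` below is an assumed stand-in for that iterator:

```scala
import org.apache.spark.sql.catalyst.InternalRow

// `sortedRows` stands in for the Iterator[InternalRow] returned by
// UnsafeExternalRowSorter.sort(). Each next() re-points one reused UnsafeRow, so any
// row kept past the current iteration step must be detached with copy().
def firstN(sortedRows: Iterator[InternalRow], n: Int): Seq[InternalRow] =
  sortedRows.take(n).map(_.copy()).toVector
```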

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala

Lines changed: 8 additions & 2 deletions
```diff
@@ -70,10 +70,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
    * @param row the row to convert
    * @param baseObject the base object of the destination address
    * @param baseOffset the base offset of the destination address
+   * @param rowLengthInBytes the length calculated by `getSizeRequirement(row)`
    * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
    */
-  def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = {
-    unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool)
+  def writeRow(
+      row: InternalRow,
+      baseObject: Object,
+      baseOffset: Long,
+      rowLengthInBytes: Int,
+      pool: ObjectPool): Int = {
+    unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes, pool)
 
     if (writers.length > 0) {
       // zero-out the bitset
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala

Lines changed: 24 additions & 12 deletions
```diff
@@ -44,11 +44,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(row)
     assert(sizeRequired === 8 + (3 * 8))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getLong(1) === 1)
     assert(unsafeRow.getInt(2) === 2)
@@ -73,12 +75,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten = converter.writeRow(
+      row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
     val pool = new ObjectPool(10)
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getString(1) === "Hello")
     assert(unsafeRow.get(2) === "World".getBytes)
@@ -111,12 +115,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(row)
     assert(sizeRequired === 8 + (8 * 3))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, pool)
     assert(numBytesWritten === sizeRequired)
     assert(pool.size === 2)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.get(1) === Decimal(1))
     assert(unsafeRow.get(2) === Array(2))
@@ -142,11 +148,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(sizeRequired === 8 + (8 * 4) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getString(1) === "Hello")
     // Date is represented as Int in unsafeRow
@@ -190,12 +198,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
     val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(
-      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
+      sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val createdFromNull = new UnsafeRow()
     createdFromNull.pointTo(
-      createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+      createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
+      sizeRequired, null)
     for (i <- 0 to fieldTypes.length - 1) {
       assert(createdFromNull.isNullAt(i))
     }
@@ -233,10 +243,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val pool = new ObjectPool(1)
     val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2)
     converter.writeRow(
-      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
+      sizeRequired, pool)
     val setToNullAfterCreation = new UnsafeRow()
     setToNullAfterCreation.pointTo(
-      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
+      sizeRequired, pool)
 
     assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
     assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
```

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala

Lines changed: 27 additions & 0 deletions
```diff
@@ -36,6 +36,33 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
     TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
   }
 
+  ignore("sort followed by limit should not leak memory") {
+    // TODO: this test is going to fail until we implement a proper iterator interface
+    // with a close() method.
+    TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true")
+    checkAnswer(
+      (1 to 100).map(v => Tuple1(v)).toDF("a"),
+      (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
+      (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
+      sortAnswers = false
+    )
+  }
+
+  test("sort followed by limit") {
+    TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false")
+    try {
+      checkAnswer(
+        (1 to 100).map(v => Tuple1(v)).toDF("a"),
+        (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
+        (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
+        sortAnswers = false
+      )
+    } finally {
+      TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true")
+
+    }
+  }
+
   // Test sorting on different data types
   for (
     dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType)
```
