Skip to content

Commit 5623c7a

Browse files
committed
Reimplement UnsafeExternalRowSorter in database style
1 parent df00b5c commit 5623c7a

File tree

2 files changed

+26
-22
lines changed

2 files changed

+26
-22
lines changed

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import java.io.IOException;
2121
import java.util.function.Supplier;
2222

23-
import scala.collection.AbstractIterator;
2423
import scala.collection.Iterator;
2524
import scala.math.Ordering;
2625

@@ -168,39 +167,40 @@ public void cleanupResources() {
168167
sorter.cleanupResources();
169168
}
170169

171-
public Iterator<UnsafeRow> sort() throws IOException {
170+
public Iterator<InternalRow> sort() throws IOException {
172171
try {
173172
final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
174173
if (!sortedIterator.hasNext()) {
175174
// Since we won't ever call next() on an empty iterator, we need to clean up resources
176175
// here in order to prevent memory leaks.
177176
cleanupResources();
178177
}
179-
return new AbstractIterator<UnsafeRow>() {
178+
return new RowIterator() {
180179

181180
private final int numFields = schema.length();
182181
private UnsafeRow row = new UnsafeRow(numFields);
183182

184183
@Override
185-
public boolean hasNext() {
186-
return !isReleased && sortedIterator.hasNext();
187-
}
188-
189-
@Override
190-
public UnsafeRow next() {
184+
public boolean advanceNext() {
191185
try {
192-
sortedIterator.loadNext();
193-
row.pointTo(
194-
sortedIterator.getBaseObject(),
195-
sortedIterator.getBaseOffset(),
196-
sortedIterator.getRecordLength());
197-
if (!hasNext()) {
198-
UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page
199-
row = null; // so that we don't keep references to the base object
200-
cleanupResources();
201-
return copy;
186+
if (!isReleased && sortedIterator.hasNext()) {
187+
sortedIterator.loadNext();
188+
row.pointTo(
189+
sortedIterator.getBaseObject(),
190+
sortedIterator.getBaseOffset(),
191+
sortedIterator.getRecordLength());
192+
// Here is the initial bug fix in SPARK-9364: the bug fix of use-after-free bug
193+
// when returning the last row from an iterator. For example, in
194+
// [[GroupedIterator]], we still use the last row after traversing the iterator
195+
// in `fetchNextGroupIterator`
196+
if (!sortedIterator.hasNext()) {
197+
row = row.copy(); // so that we don't have dangling pointers to freed page
198+
cleanupResources();
199+
}
200+
return true;
202201
} else {
203-
return row;
202+
row = null; // so that we don't keep references to the base object
203+
return false;
204204
}
205205
} catch (IOException e) {
206206
cleanupResources();
@@ -210,14 +210,18 @@ public UnsafeRow next() {
210210
}
211211
throw new RuntimeException("Exception should have been re-thrown in next()");
212212
}
213-
};
213+
214+
@Override
215+
public UnsafeRow getRow() { return row; }
216+
217+
}.toScala();
214218
} catch (IOException e) {
215219
cleanupResources();
216220
throw e;
217221
}
218222
}
219223

220-
public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
224+
public Iterator<InternalRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
221225
while (inputIterator.hasNext()) {
222226
insertRow(inputIterator.next());
223227
}

0 commit comments

Comments
 (0)