2020import java .io .IOException ;
2121import java .util .LinkedList ;
2222
23+ import scala .runtime .AbstractFunction0 ;
24+ import scala .runtime .BoxedUnit ;
25+
2326import com .google .common .annotations .VisibleForTesting ;
2427import org .slf4j .Logger ;
2528import org .slf4j .LoggerFactory ;
@@ -41,10 +44,7 @@ public final class UnsafeExternalSorter {
4144
4245 private final Logger logger = LoggerFactory .getLogger (UnsafeExternalSorter .class );
4346
44- private static final int PAGE_SIZE = 1 << 27 ; // 128 megabytes
45- @ VisibleForTesting
46- static final int MAX_RECORD_SIZE = PAGE_SIZE - 4 ;
47-
47+ private final long pageSizeBytes ;
4848 private final PrefixComparator prefixComparator ;
4949 private final RecordComparator recordComparator ;
5050 private final int initialSize ;
@@ -91,7 +91,19 @@ public UnsafeExternalSorter(
9191 this .initialSize = initialSize ;
9292 // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
9393 this .fileBufferSizeBytes = (int ) conf .getSizeAsKb ("spark.shuffle.file.buffer" , "32k" ) * 1024 ;
94+ this .pageSizeBytes = conf .getSizeAsBytes ("spark.buffer.pageSize" , "64m" );
9495 initializeForWriting ();
96+
97+ // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
98+ // the end of the task. This is necessary to avoid memory leaks in when the downstream operator
99+ // does not fully consume the sorter's output (e.g. sort followed by limit).
100+ taskContext .addOnCompleteCallback (new AbstractFunction0 <BoxedUnit >() {
101+ @ Override
102+ public BoxedUnit apply () {
103+ freeMemory ();
104+ return null ;
105+ }
106+ });
95107 }
96108
97109 // TODO: metrics tracking + integration with shuffle write metrics
@@ -147,7 +159,11 @@ public void spill() throws IOException {
147159 }
148160
149161 private long getMemoryUsage () {
150- return sorter .getMemoryUsage () + (allocatedPages .size () * (long ) PAGE_SIZE );
162+ long totalPageSize = 0 ;
163+ for (MemoryBlock page : allocatedPages ) {
164+ totalPageSize += page .size ();
165+ }
166+ return sorter .getMemoryUsage () + totalPageSize ;
151167 }
152168
153169 @ VisibleForTesting
@@ -214,23 +230,23 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
214230 // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
215231 // without using the free space at the end of the current page. We should also do this for
216232 // BytesToBytesMap.
217- if (requiredSpace > PAGE_SIZE ) {
233+ if (requiredSpace > pageSizeBytes ) {
218234 throw new IOException ("Required space " + requiredSpace + " is greater than page size (" +
219- PAGE_SIZE + ")" );
235+ pageSizeBytes + ")" );
220236 } else {
221- final long memoryAcquired = shuffleMemoryManager .tryToAcquire (PAGE_SIZE );
222- if (memoryAcquired < PAGE_SIZE ) {
237+ final long memoryAcquired = shuffleMemoryManager .tryToAcquire (pageSizeBytes );
238+ if (memoryAcquired < pageSizeBytes ) {
223239 shuffleMemoryManager .release (memoryAcquired );
224240 spill ();
225- final long memoryAcquiredAfterSpilling = shuffleMemoryManager .tryToAcquire (PAGE_SIZE );
226- if (memoryAcquiredAfterSpilling != PAGE_SIZE ) {
241+ final long memoryAcquiredAfterSpilling = shuffleMemoryManager .tryToAcquire (pageSizeBytes );
242+ if (memoryAcquiredAfterSpilling != pageSizeBytes ) {
227243 shuffleMemoryManager .release (memoryAcquiredAfterSpilling );
228- throw new IOException ("Unable to acquire " + PAGE_SIZE + " bytes of memory" );
244+ throw new IOException ("Unable to acquire " + pageSizeBytes + " bytes of memory" );
229245 }
230246 }
231- currentPage = memoryManager .allocatePage (PAGE_SIZE );
247+ currentPage = memoryManager .allocatePage (pageSizeBytes );
232248 currentPagePosition = currentPage .getBaseOffset ();
233- freeSpaceInCurrentPage = PAGE_SIZE ;
249+ freeSpaceInCurrentPage = pageSizeBytes ;
234250 allocatedPages .add (currentPage );
235251 }
236252 }
0 commit comments