diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
new file mode 100644
index 0000000000000..008799cc77395
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory;
+
+import java.io.IOException;
+
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+/**
+ * A memory consumer of TaskMemoryManager that supports spilling.
+ */
+public abstract class MemoryConsumer {
+
+  private final TaskMemoryManager taskMemoryManager;
+  private final long pageSize;
+  private long used;
+
+  protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) {
+    this.taskMemoryManager = taskMemoryManager;
+    this.pageSize = pageSize;
+    this.used = 0;
+  }
+
+  protected MemoryConsumer(TaskMemoryManager taskMemoryManager) {
+    this(taskMemoryManager, taskMemoryManager.pageSizeBytes());
+  }
+
+  /**
+   * Returns the size of used memory in bytes.
+   */
+  long getUsed() {
+    return used;
+  }
+
+  /**
+   * Force a spill during building.
+   *
+   * For testing.
+   */
+  public void spill() throws IOException {
+    spill(Long.MAX_VALUE, this);
+  }
+
+  /**
+   * Spill some data to disk to release memory. This will be called by TaskMemoryManager
+   * when there is not enough memory for the task.
+   *
+   * This should be implemented by subclasses.
+   *
+   * Note: in order to avoid possible deadlocks, a consumer should not call acquireMemory()
+   * from within spill().
+   *
+   * @param size the amount of memory that should be released
+   * @param trigger the MemoryConsumer that triggered this spilling
+   * @return the amount of released memory in bytes
+   * @throws IOException
+   */
+  public abstract long spill(long size, MemoryConsumer trigger) throws IOException;
+
+  /**
+   * Acquire `size` bytes of memory.
+   *
+   * Throws an OutOfMemoryError if there is not enough memory.
+   */
+  protected void acquireMemory(long size) {
+    long got = taskMemoryManager.acquireExecutionMemory(size, this);
+    if (got < size) {
+      taskMemoryManager.releaseExecutionMemory(got, this);
+      taskMemoryManager.showMemoryUsage();
+      throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got);
+    }
+    used += got;
+  }
+
+  /**
+   * Release `size` bytes of memory.
+   */
+  protected void releaseMemory(long size) {
+    used -= size;
+    taskMemoryManager.releaseExecutionMemory(size, this);
+  }
+
+  /**
+   * Allocate a memory block with at least `required` bytes.
+   *
+   * Throws an OutOfMemoryError if there is not enough memory.
+   *
+   * @throws OutOfMemoryError
+   */
+  protected MemoryBlock allocatePage(long required) {
+    MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this);
+    if (page == null || page.size() < required) {
+      long got = 0;
+      if (page != null) {
+        got = page.size();
+        freePage(page);
+      }
+      taskMemoryManager.showMemoryUsage();
+      throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);
+    }
+    used += page.size();
+    return page;
+  }
+
+  /**
+   * Free a memory block.
+   */
+  protected void freePage(MemoryBlock page) {
+    used -= page.size();
+    taskMemoryManager.freePage(page, this);
+  }
+}
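Reviewer note: the consumer contract above is small enough to sketch end to end. The class below is hypothetical (`ExampleConsumer` is not part of this patch) and only illustrates how a subclass is expected to combine `allocatePage()`/`freePage()` with an implementation of `spill()`; a real consumer would persist its in-memory data before freeing the page.

```java
import java.io.IOException;

import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.memory.MemoryBlock;

// Hypothetical consumer, for illustration only: it holds a single page and
// "spills" by dropping it. A real consumer would write the data to disk first.
public final class ExampleConsumer extends MemoryConsumer {
  private MemoryBlock page = null;

  public ExampleConsumer(TaskMemoryManager taskMemoryManager) {
    super(taskMemoryManager);
  }

  /** Lazily allocates a page; may force other consumers (or this one) to spill. */
  public void ensurePage(long required) {
    if (page == null) {
      // Throws OutOfMemoryError if the request cannot be satisfied even after spilling.
      page = allocatePage(required);
    }
  }

  @Override
  public long spill(long size, MemoryConsumer trigger) throws IOException {
    if (page == null) {
      return 0L;  // nothing to release
    }
    long released = page.size();
    MemoryBlock toFree = page;
    page = null;          // ... a real consumer would persist the page contents here ...
    freePage(toFree);     // updates `used` and returns the memory to the manager
    return released;
  }
}
```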
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 7b31c90dac666..4230575446d31 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -17,13 +17,18 @@
 
 package org.apache.spark.memory;
 
-import java.util.*;
+import javax.annotation.concurrent.GuardedBy;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.BitSet;
+import java.util.HashSet;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.util.Utils;
 
 /**
  * Manages the memory allocated by an individual task.
@@ -100,6 +105,12 @@ public class TaskMemoryManager {
    */
   private final boolean inHeap;
 
+  /**
+   * The memory consumers registered for this task, used for spilling.
+   */
+  @GuardedBy("this")
+  private final HashSet<MemoryConsumer> consumers;
+
   /**
    * Construct a new TaskMemoryManager.
    */
@@ -107,23 +118,92 @@ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
     this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap();
     this.memoryManager = memoryManager;
     this.taskAttemptId = taskAttemptId;
+    this.consumers = new HashSet<>();
   }
 
   /**
-   * Acquire N bytes of memory for execution, evicting cached blocks if necessary.
+   * Acquire N bytes of memory for a consumer. If there is not enough memory, it will call
+   * spill() on consumers to release more memory.
+   *
    * @return number of bytes successfully granted (<= N).
    */
-  public long acquireExecutionMemory(long size) {
-    return memoryManager.acquireExecutionMemory(size, taskAttemptId);
+  public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
+    assert(required >= 0);
+    synchronized (this) {
+      long got = memoryManager.acquireExecutionMemory(required, taskAttemptId);
+
+      // Try to release memory from other consumers first, so that we can reduce the frequency
+      // of spilling and avoid having too many spilled files.
+      if (got < required) {
+        // Call spill() on other consumers to release memory
+        for (MemoryConsumer c: consumers) {
+          if (c != null && c != consumer && c.getUsed() > 0) {
+            try {
+              long released = c.spill(required - got, consumer);
+              if (released > 0) {
+                logger.info("Task {} released {} from {} for {}", taskAttemptId,
+                  Utils.bytesToString(released), c, consumer);
+                got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId);
+                if (got >= required) {
+                  break;
+                }
+              }
+            } catch (IOException e) {
+              logger.error("error while calling spill() on " + c, e);
+              throw new OutOfMemoryError("error while calling spill() on " + c + " : "
+                + e.getMessage());
+            }
+          }
+        }
+      }
+
+      // call spill() on itself
+      if (got < required && consumer != null) {
+        try {
+          long released = consumer.spill(required - got, consumer);
+          if (released > 0) {
+            logger.info("Task {} released {} from itself ({})", taskAttemptId,
+              Utils.bytesToString(released), consumer);
+            got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId);
+          }
+        } catch (IOException e) {
+          logger.error("error while calling spill() on " + consumer, e);
+          throw new OutOfMemoryError("error while calling spill() on " + consumer + " : "
+            + e.getMessage());
+        }
+      }
+
+      consumers.add(consumer);
+      logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
+      return got;
+    }
   }
 
   /**
-   * Release N bytes of execution memory.
+   * Release N bytes of execution memory for a MemoryConsumer.
    */
-  public void releaseExecutionMemory(long size) {
+  public void releaseExecutionMemory(long size, MemoryConsumer consumer) {
+    logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer);
     memoryManager.releaseExecutionMemory(size, taskAttemptId);
   }
 
+  /**
+   * Dump the memory usage of all consumers.
+   */
+  public void showMemoryUsage() {
+    logger.info("Memory used in task " + taskAttemptId);
+    synchronized (this) {
+      for (MemoryConsumer c: consumers) {
+        if (c.getUsed() > 0) {
+          logger.info("Acquired by " + c + ": " + Utils.bytesToString(c.getUsed()));
+        }
+      }
+    }
+  }
+
+  /**
+   * Return the page size in bytes.
+   */
   public long pageSizeBytes() {
     return memoryManager.pageSizeBytes();
   }
@@ -134,42 +214,40 @@ public long pageSizeBytes() {
    *
    * Returns `null` if there was not enough memory to allocate the page.
    */
-  public MemoryBlock allocatePage(long size) {
+  public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
     if (size > MAXIMUM_PAGE_SIZE_BYTES) {
       throw new IllegalArgumentException(
         "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes");
     }
+    long acquired = acquireExecutionMemory(size, consumer);
+    if (acquired <= 0) {
+      return null;
+    }
+
     final int pageNumber;
     synchronized (this) {
       pageNumber = allocatedPages.nextClearBit(0);
       if (pageNumber >= PAGE_TABLE_SIZE) {
+        releaseExecutionMemory(acquired, consumer);
         throw new IllegalStateException(
           "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
       }
       allocatedPages.set(pageNumber);
     }
-    final long acquiredExecutionMemory = acquireExecutionMemory(size);
-    if (acquiredExecutionMemory != size) {
-      releaseExecutionMemory(acquiredExecutionMemory);
-      synchronized (this) {
-        allocatedPages.clear(pageNumber);
-      }
-      return null;
-    }
-    final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(size);
+    final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(acquired);
     page.pageNumber = pageNumber;
     pageTable[pageNumber] = page;
     if (logger.isTraceEnabled()) {
-      logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
+      logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired);
     }
     return page;
   }
 
   /**
-   * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
+   * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}.
    */
-  public void freePage(MemoryBlock page) {
+  public void freePage(MemoryBlock page, MemoryConsumer consumer) {
     assert (page.pageNumber != -1) :
       "Called freePage() on memory that wasn't allocated with allocatePage()";
     assert(allocatedPages.get(page.pageNumber));
@@ -182,14 +260,14 @@ public void freePage(MemoryBlock page) {
     }
     long pageSize = page.size();
     memoryManager.tungstenMemoryAllocator().free(page);
-    releaseExecutionMemory(pageSize);
+    releaseExecutionMemory(pageSize, consumer);
   }
 
   /**
    * Given a memory page and offset within that page, encode this address into a 64-bit long.
    * This address will remain valid as long as the corresponding page has not been freed.
    *
-   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}/
+   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage}.
    * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
    *                     this should be the value that you would pass as the base offset into an
    *                     UNSAFE call (e.g. page.baseOffset() + something).
@@ -261,17 +339,17 @@ public long getOffsetInPage(long pagePlusOffsetAddress) {
    * value can be used to detect memory leaks.
    */
   public long cleanUpAllAllocatedMemory() {
-    long freedBytes = 0;
-    for (MemoryBlock page : pageTable) {
-      if (page != null) {
-        freedBytes += page.size();
-        freePage(page);
+    synchronized (this) {
+      Arrays.fill(pageTable, null);
+      for (MemoryConsumer c: consumers) {
+        if (c != null && c.getUsed() > 0) {
+          // In the case of a failed task, it is normal to see leaked memory
+          logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c);
+        }
       }
+      consumers.clear();
     }
-
-    freedBytes += memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
-
-    return freedBytes;
+    return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
   }
 
   /**
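Reviewer note: the control flow of `acquireExecutionMemory` above (spill others first, self-spill last) is easier to see stripped of Spark types. Below is a self-contained sketch, with a hypothetical `Consumer` interface standing in for `MemoryConsumer` and a plain byte counter standing in for the underlying `MemoryManager`; it is an illustration of the protocol, not the patch's implementation.

```java
import java.util.ArrayList;
import java.util.List;

// Self-contained illustration of the two-phase spill protocol used by
// acquireExecutionMemory(); `Consumer` here is hypothetical, not Spark's API.
final class SpillProtocolDemo {
  interface Consumer {
    long used();
    long spill(long size);  // returns bytes released
  }

  private long freeBytes;
  private final List<Consumer> consumers = new ArrayList<>();

  SpillProtocolDemo(long freeBytes) { this.freeBytes = freeBytes; }

  void register(Consumer c) { consumers.add(c); }

  private long grant(long wanted) {
    long got = Math.min(wanted, freeBytes);
    freeBytes -= got;
    return got;
  }

  long acquire(long required, Consumer requester) {
    long got = grant(required);
    // Phase 1: release memory held by *other* consumers first, so the
    // requester keeps its partially built in-memory data intact.
    for (Consumer c : consumers) {
      if (got >= required) break;
      if (c != requester && c.used() > 0) {
        freeBytes += c.spill(required - got);
        got += grant(required - got);
      }
    }
    // Phase 2: as a last resort, ask the requester to spill itself.
    if (got < required) {
      freeBytes += requester.spill(required - got);
      got += grant(required - got);
    }
    return got;  // may still be < required; the caller decides whether to fail
  }
}
```

Spilling others first also tends to produce fewer, larger spill files, which is exactly what the comment in the patch is aiming for.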
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index f43236f41ae7b..400d8520019b9 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -31,15 +31,15 @@
 import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.DummySerializerInstance;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.DiskBlockObjectWriter;
 import org.apache.spark.storage.TempShuffleBlockId;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
 /**
@@ -58,23 +58,18 @@
  * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
  * specialized merge procedure that avoids extra serialization/deserialization.
  */
-final class ShuffleExternalSorter {
+final class ShuffleExternalSorter extends MemoryConsumer {
 
   private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);
 
   @VisibleForTesting
   static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
 
-  private final int initialSize;
   private final int numPartitions;
-  private final int pageSizeBytes;
-  @VisibleForTesting
-  final int maxRecordSizeBytes;
   private final TaskMemoryManager taskMemoryManager;
   private final BlockManager blockManager;
   private final TaskContext taskContext;
   private final ShuffleWriteMetrics writeMetrics;
-  private long numRecordsInsertedSinceLastSpill = 0;
 
   /** Force this sorter to spill when there are this many elements in memory. For testing only */
   private final long numElementsForSpillThreshold;
@@ -98,8 +93,7 @@ final class ShuffleExternalSorter {
   // These variables are reset after spilling:
   @Nullable private ShuffleInMemorySorter inMemSorter;
   @Nullable private MemoryBlock currentPage = null;
-  private long currentPagePosition = -1;
-  private long freeSpaceInCurrentPage = 0;
+  private long pageCursor = -1;
 
   public ShuffleExternalSorter(
       TaskMemoryManager memoryManager,
@@ -108,42 +102,21 @@ public ShuffleExternalSorter(
       int initialSize,
       int numPartitions,
       SparkConf conf,
-      ShuffleWriteMetrics writeMetrics) throws IOException {
+      ShuffleWriteMetrics writeMetrics) {
+    super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES,
+      memoryManager.pageSizeBytes()));
     this.taskMemoryManager = memoryManager;
     this.blockManager = blockManager;
     this.taskContext = taskContext;
-    this.initialSize = initialSize;
-    this.peakMemoryUsedBytes = initialSize;
     this.numPartitions = numPartitions;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.numElementsForSpillThreshold =
       conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
-    this.pageSizeBytes = (int) Math.min(
-      PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, taskMemoryManager.pageSizeBytes());
-    this.maxRecordSizeBytes = pageSizeBytes - 4;
     this.writeMetrics = writeMetrics;
-    initializeForWriting();
-
-    // preserve first page to ensure that we have at least one page to work with. Otherwise,
-    // other operators in the same task may starve this sorter (SPARK-9709).
-    acquireNewPageIfNecessary(pageSizeBytes);
-  }
-
-  /**
-   * Allocates new sort data structures. Called when creating the sorter and after each spill.
-   */
-  private void initializeForWriting() throws IOException {
-    // TODO: move this sizing calculation logic into a static method of sorter:
-    final long memoryRequested = initialSize * 8L;
-    final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryRequested);
-    if (memoryAcquired != memoryRequested) {
-      taskMemoryManager.releaseExecutionMemory(memoryAcquired);
-      throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
-    }
-
+    acquireMemory(initialSize * 8L);
     this.inMemSorter = new ShuffleInMemorySorter(initialSize);
-    numRecordsInsertedSinceLastSpill = 0;
+    this.peakMemoryUsedBytes = getMemoryUsage();
   }
 
   /**
@@ -242,6 +215,8 @@ private void writeSortedFile(boolean isLastFile) throws IOException {
       }
     }
 
+    inMemSorter.reset();
+
     if (!isLastFile) {  // i.e. this is a spill file
       // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
       // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
@@ -266,9 +241,12 @@ private void writeSortedFile(boolean isLastFile) throws IOException {
   /**
    * Sort and spill the current records in response to memory pressure.
    */
-  @VisibleForTesting
-  void spill() throws IOException {
-    assert(inMemSorter != null);
+  @Override
+  public long spill(long size, MemoryConsumer trigger) throws IOException {
+    if (trigger != this || inMemSorter == null || inMemSorter.numRecords() == 0) {
+      return 0L;
+    }
+
     logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
       Thread.currentThread().getId(),
       Utils.bytesToString(getMemoryUsage()),
@@ -276,13 +254,9 @@ void spill() throws IOException {
" times" : " time"); writeSortedFile(false); - final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage(); - inMemSorter = null; - taskMemoryManager.releaseExecutionMemory(inMemSorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - - initializeForWriting(); + return spillSize; } private long getMemoryUsage() { @@ -312,18 +286,12 @@ private long freeMemory() { updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { - taskMemoryManager.freePage(block); memoryFreed += block.size(); - } - if (inMemSorter != null) { - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); - inMemSorter = null; - taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage); + freePage(block); } allocatedPages.clear(); currentPage = null; - currentPagePosition = -1; - freeSpaceInCurrentPage = 0; + pageCursor = 0; return memoryFreed; } @@ -332,16 +300,16 @@ private long freeMemory() { */ public void cleanupResources() { freeMemory(); + if (inMemSorter != null) { + long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter = null; + releaseMemory(sorterMemoryUsage); + } for (SpillInfo spill : spills) { if (spill.file.exists() && !spill.file.delete()) { logger.error("Unable to delete spill file {}", spill.file.getPath()); } } - if (inMemSorter != null) { - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); - inMemSorter = null; - taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage); - } } /** @@ -352,16 +320,27 @@ public void cleanupResources() { private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { - logger.debug("Attempting to expand sort pointer array"); - final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage(); - final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; - final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryToGrowPointerArray); - if (memoryAcquired < memoryToGrowPointerArray) { - taskMemoryManager.releaseExecutionMemory(memoryAcquired); - spill(); + long used = inMemSorter.getMemoryUsage(); + long needed = used + inMemSorter.getMemoryToExpand(); + try { + acquireMemory(needed); // could trigger spilling + } catch (OutOfMemoryError e) { + // should have trigger spilling + assert(inMemSorter.hasSpaceForAnotherRecord()); + return; + } + // check if spilling is triggered or not + if (inMemSorter.hasSpaceForAnotherRecord()) { + releaseMemory(needed); } else { - inMemSorter.expandPointerArray(); - taskMemoryManager.releaseExecutionMemory(oldPointerArrayMemoryUsage); + try { + inMemSorter.expandPointerArray(); + releaseMemory(used); + } catch (OutOfMemoryError oom) { + // Just in case that JVM had run out of memory + releaseMemory(needed); + spill(); + } } } } @@ -370,96 +349,46 @@ private void growPointerArrayIfNecessary() throws IOException { * Allocates more memory in order to insert an additional record. This will request additional * memory from the memory manager and spill if the requested memory can not be obtained. * - * @param requiredSpace the required space in the data page, in bytes, including space for storing + * @param required the required space in the data page, in bytes, including space for storing * the record size. This must be less than or equal to the page size (records * that exceed the page size are handled via a different code path which uses * special overflow pages). 
    */
-  private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
-    growPointerArrayIfNecessary();
-    if (requiredSpace > freeSpaceInCurrentPage) {
-      logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
-        freeSpaceInCurrentPage);
-      // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
-      // without using the free space at the end of the current page. We should also do this for
-      // BytesToBytesMap.
-      if (requiredSpace > pageSizeBytes) {
-        throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
-          pageSizeBytes + ")");
-      } else {
-        currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-        if (currentPage == null) {
-          spill();
-          currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-          if (currentPage == null) {
-            throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
-          }
-        }
-        currentPagePosition = currentPage.getBaseOffset();
-        freeSpaceInCurrentPage = pageSizeBytes;
-        allocatedPages.add(currentPage);
-      }
+  private void acquireNewPageIfNecessary(int required) {
+    if (currentPage == null ||
+        pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) {
+      // TODO: try to find space in previous pages
+      currentPage = allocatePage(required);
+      pageCursor = currentPage.getBaseOffset();
+      allocatedPages.add(currentPage);
     }
   }
 
   /**
    * Write a record to the shuffle sorter.
    */
-  public void insertRecord(
-      Object recordBaseObject,
-      long recordBaseOffset,
-      int lengthInBytes,
-      int partitionId) throws IOException {
+  public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId)
+      throws IOException {
 
-    if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) {
+    // for tests
+    assert(inMemSorter != null);
+    if (inMemSorter.numRecords() > numElementsForSpillThreshold) {
       spill();
     }
 
     growPointerArrayIfNecessary();
     // Need 4 bytes to store the record length.
-    final int totalSpaceRequired = lengthInBytes + 4;
-
-    // --- Figure out where to insert the new record ----------------------------------------------
-
-    final MemoryBlock dataPage;
-    long dataPagePosition;
-    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
-    if (useOverflowPage) {
-      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
-      // The record is larger than the page size, so allocate a special overflow page just to hold
-      // that record.
-      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-      if (overflowPage == null) {
-        spill();
-        overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-        if (overflowPage == null) {
-          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
-        }
-      }
-      allocatedPages.add(overflowPage);
-      dataPage = overflowPage;
-      dataPagePosition = overflowPage.getBaseOffset();
-    } else {
-      // The record is small enough to fit in a regular data page, but the current page might not
-      // have enough space to hold it (or no pages have been allocated yet).
-      acquireNewPageIfNecessary(totalSpaceRequired);
-      dataPage = currentPage;
-      dataPagePosition = currentPagePosition;
-      // Update bookkeeping information
-      freeSpaceInCurrentPage -= totalSpaceRequired;
-      currentPagePosition += totalSpaceRequired;
-    }
-    final Object dataPageBaseObject = dataPage.getBaseObject();
-
-    final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
-    Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
-    dataPagePosition += 4;
-    Platform.copyMemory(
-      recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
-    assert(inMemSorter != null);
+    final int required = length + 4;
+    acquireNewPageIfNecessary(required);
+
+    assert(currentPage != null);
+    final Object base = currentPage.getBaseObject();
+    final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
+    Platform.putInt(base, pageCursor, length);
+    pageCursor += 4;
+    Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
+    pageCursor += length;
     inMemSorter.insertRecord(recordAddress, partitionId);
-    numRecordsInsertedSinceLastSpill += 1;
   }
 
   /**
@@ -475,6 +404,9 @@ public SpillInfo[] closeAndGetSpills() throws IOException {
       // Do not count the final file towards the spill count.
       writeSortedFile(true);
       freeMemory();
+      long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+      inMemSorter = null;
+      releaseMemory(sorterMemoryUsage);
     }
     return spills.toArray(new SpillInfo[spills.size()]);
   } catch (IOException e) {
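Reviewer note: after this change every record in a sorter page uses the same framing, a 4-byte length prefix followed by the payload, with `pageCursor` advancing past each record. A standalone sketch of that framing (hypothetical demo class, using a `ByteBuffer` in place of a Tungsten `MemoryBlock`):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical re-creation of the framing used by insertRecord():
// [int length][length bytes of record data], with a cursor advancing past each record.
final class PageFramingDemo {
  private final ByteBuffer page = ByteBuffer.allocate(4096).order(ByteOrder.nativeOrder());

  /** Returns the offset of the framed record inside the page. */
  int append(byte[] record) {
    int recordOffset = page.position();
    page.putInt(record.length);  // 4-byte length prefix
    page.put(record);            // record payload
    return recordOffset;
  }

  public static void main(String[] args) {
    PageFramingDemo demo = new PageFramingDemo();
    int off = demo.append("hello".getBytes());
    System.out.println("record stored at page offset " + off);
  }
}
```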
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index a8dee6c6101c1..e630575d1ae19 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -37,33 +37,51 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) {
    * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
    * records.
    */
-  private long[] pointerArray;
+  private long[] array;
 
   /**
    * The position in the pointer array where new records can be inserted.
    */
-  private int pointerArrayInsertPosition = 0;
+  private int pos = 0;
 
   public ShuffleInMemorySorter(int initialSize) {
     assert (initialSize > 0);
-    this.pointerArray = new long[initialSize];
-    this.sorter = new Sorter<PackedRecordPointer, long[]>(ShuffleSortDataFormat.INSTANCE);
+    this.array = new long[initialSize];
+    this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
   }
 
-  public void expandPointerArray() {
-    final long[] oldArray = pointerArray;
+  public int numRecords() {
+    return pos;
+  }
+
+  public void reset() {
+    pos = 0;
+  }
+
+  private int newLength() {
     // Guard against overflow:
-    final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
-    pointerArray = new long[newLength];
-    System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+    return array.length <= Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
+  }
+
+  /**
+   * Returns the memory needed to expand.
+   */
+  public long getMemoryToExpand() {
+    return ((long) (newLength() - array.length)) * 8;
+  }
+
+  public void expandPointerArray() {
+    final long[] oldArray = array;
+    array = new long[newLength()];
+    System.arraycopy(oldArray, 0, array, 0, oldArray.length);
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pointerArrayInsertPosition + 1 < pointerArray.length;
+    return pos < array.length;
   }
 
   public long getMemoryUsage() {
-    return pointerArray.length * 8L;
+    return array.length * 8L;
   }
 
@@ -78,15 +96,15 @@ public long getMemoryUsage() {
    */
   public void insertRecord(long recordPointer, int partitionId) {
     if (!hasSpaceForAnotherRecord()) {
-      if (pointerArray.length == Integer.MAX_VALUE) {
+      if (array.length == Integer.MAX_VALUE) {
         throw new IllegalStateException("Sort pointer array has reached maximum size");
       } else {
         expandPointerArray();
       }
     }
-    pointerArray[pointerArrayInsertPosition] =
+    array[pos] =
       PackedRecordPointer.packPointer(recordPointer, partitionId);
-    pointerArrayInsertPosition++;
+    pos++;
   }
 
@@ -118,7 +136,7 @@ public void loadNext() {
    * Return an iterator over record pointers in sorted order.
    */
   public ShuffleSorterIterator getSortedIterator() {
-    sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
-    return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
+    sorter.sort(array, 0, pos, SORT_COMPARATOR);
+    return new ShuffleSorterIterator(pos, array);
   }
 }
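Reviewer note: the new `newLength()`/`getMemoryToExpand()` pair lets the external sorter reserve memory *before* copying the pointer array. A standalone check of the arithmetic (hypothetical demo class): doubling is capped at `Integer.MAX_VALUE`, and the reservation is 8 bytes per additional pointer slot.

```java
// Standalone check of ShuffleInMemorySorter's growth arithmetic.
public final class GrowthMathDemo {
  static int newLength(int length) {
    // Guard against int overflow: only double while 2 * length fits in an int.
    return length <= Integer.MAX_VALUE / 2 ? length * 2 : Integer.MAX_VALUE;
  }

  static long memoryToExpand(int length) {
    // 8 bytes per additional long slot in the pointer array.
    return ((long) (newLength(length) - length)) * 8;
  }

  public static void main(String[] args) {
    System.out.println(newLength(4096));          // 8192
    System.out.println(memoryToExpand(4096));     // 32768 bytes
    System.out.println(newLength(1_500_000_000)); // Integer.MAX_VALUE (doubling would overflow)
  }
}
```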
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index f6c5c944bd77b..e19b37864293c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -127,12 +127,6 @@ public UnsafeShuffleWriter(
     open();
   }
 
-  @VisibleForTesting
-  public int maxRecordSizeBytes() {
-    assert(sorter != null);
-    return sorter.maxRecordSizeBytes;
-  }
-
   private void updatePeakMemoryUsed() {
     // sorter can be null if this writer is closed
     if (sorter != null) {
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index f035bdac810bd..e36709c6fc849 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -18,14 +18,20 @@
 package org.apache.spark.unsafe.map;
 
 import javax.annotation.Nullable;
+import java.io.File;
+import java.io.IOException;
 import java.util.Iterator;
 import java.util.LinkedList;
-import java.util.List;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.spark.SparkEnv;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.array.LongArray;
@@ -33,7 +39,8 @@
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
 
 /**
  * An append-only hash map where keys and values are contiguous regions of bytes.
@@ -54,7 +61,7 @@
  * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
  * so we can pass records from this map directly into the sorter to sort records in place.
  */
-public final class BytesToBytesMap {
+public final class BytesToBytesMap extends MemoryConsumer {
 
   private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
 
@@ -62,27 +69,22 @@ public final class BytesToBytesMap {
 
   private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
 
-  /**
-   * Special record length that is placed after the last record in a data page.
-   */
-  private static final int END_OF_PAGE_MARKER = -1;
-
   private final TaskMemoryManager taskMemoryManager;
 
   /**
    * A linked list for tracking all allocated data pages so that we can free all of our memory.
    */
-  private final List<MemoryBlock> dataPages = new LinkedList<MemoryBlock>();
+  private final LinkedList<MemoryBlock> dataPages = new LinkedList<>();
 
   /**
    * The data page that will be used to store keys and values for new hashtable entries. When this
    * page becomes full, a new page will be allocated and this pointer will change to point to that
    * new page.
    */
-  private MemoryBlock currentDataPage = null;
+  private MemoryBlock currentPage = null;
 
   /**
-   * Offset into `currentDataPage` that points to the location where new data can be inserted into
+   * Offset into `currentPage` that points to the location where new data can be inserted into
    * the page. This does not incorporate the page's base offset.
    */
   private long pageCursor = 0;
@@ -116,6 +118,11 @@ public final class BytesToBytesMap {
   // full base addresses in the page table for off-heap mode so that we can reconstruct the full
   // absolute memory addresses.
 
+  /**
+   * Whether or not the longArray can grow. We will not insert more elements if it's false.
+   */
+  private boolean canGrowArray = true;
+
   /**
    * A {@link BitSet} used to track location of the map where the key is set.
    * Size of the bitset should be half of the size of the long array.
@@ -164,13 +171,20 @@ public final class BytesToBytesMap {
 
   private long peakMemoryUsedBytes = 0L;
 
+  private final BlockManager blockManager;
+
+  private volatile MapIterator destructiveIterator = null;
+
+  private LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
+
   public BytesToBytesMap(
       TaskMemoryManager taskMemoryManager,
+      BlockManager blockManager,
       int initialCapacity,
       double loadFactor,
       long pageSizeBytes,
       boolean enablePerfMetrics) {
+    super(taskMemoryManager, pageSizeBytes);
     this.taskMemoryManager = taskMemoryManager;
+    this.blockManager = blockManager;
     this.loadFactor = loadFactor;
     this.loc = new Location();
     this.pageSizeBytes = pageSizeBytes;
@@ -187,18 +201,13 @@ public BytesToBytesMap(
         TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
     }
     allocate(initialCapacity);
-
-    // Acquire a new page as soon as we construct the map to ensure that we have at least
-    // one page to work with. Otherwise, other operators in the same task may starve this
-    // map (SPARK-9747).
-    acquireNewPage();
   }
 
   public BytesToBytesMap(
       TaskMemoryManager taskMemoryManager,
       int initialCapacity,
       long pageSizeBytes) {
-    this(taskMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
+    this(taskMemoryManager, initialCapacity, pageSizeBytes, false);
   }
 
   public BytesToBytesMap(
@@ -208,6 +217,7 @@ public BytesToBytesMap(
       boolean enablePerfMetrics) {
     this(
       taskMemoryManager,
+      SparkEnv.get() != null ? SparkEnv.get().blockManager() : null,
       initialCapacity,
       0.70,
       pageSizeBytes,
@@ -219,61 +229,153 @@ public BytesToBytesMap(
    */
   public int numElements() { return numElements; }
 
-  public static final class BytesToBytesMapIterator implements Iterator<Location> {
+  public final class MapIterator implements Iterator<Location> {
 
-    private final int numRecords;
-    private final Iterator<MemoryBlock> dataPagesIterator;
+    private int numRecords;
     private final Location loc;
 
     private MemoryBlock currentPage = null;
-    private int currentRecordNumber = 0;
+    private int recordsInPage = 0;
     private Object pageBaseObject;
     private long offsetInPage;
 
     // If this iterator destructive or not. When it is true, it frees each page as it moves onto
     // next one.
     private boolean destructive = false;
-    private BytesToBytesMap bmap;
+    private UnsafeSorterSpillReader reader = null;
 
-    private BytesToBytesMapIterator(
-        int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc,
-        boolean destructive, BytesToBytesMap bmap) {
+    private MapIterator(int numRecords, Location loc, boolean destructive) {
       this.numRecords = numRecords;
-      this.dataPagesIterator = dataPagesIterator;
       this.loc = loc;
       this.destructive = destructive;
-      this.bmap = bmap;
-      if (dataPagesIterator.hasNext()) {
-        advanceToNextPage();
+      if (destructive) {
+        destructiveIterator = this;
       }
     }
 
     private void advanceToNextPage() {
-      if (destructive && currentPage != null) {
-        dataPagesIterator.remove();
-        this.bmap.taskMemoryManager.freePage(currentPage);
+      synchronized (this) {
+        int nextIdx = dataPages.indexOf(currentPage) + 1;
+        if (destructive && currentPage != null) {
+          dataPages.remove(currentPage);
+          freePage(currentPage);
+          nextIdx--;
+        }
+        if (dataPages.size() > nextIdx) {
+          currentPage = dataPages.get(nextIdx);
+          pageBaseObject = currentPage.getBaseObject();
+          offsetInPage = currentPage.getBaseOffset();
+          recordsInPage = Platform.getInt(pageBaseObject, offsetInPage);
+          offsetInPage += 4;
+        } else {
+          currentPage = null;
+          if (reader != null) {
+            // remove the spill file from disk
+            File file = spillWriters.removeFirst().getFile();
+            if (file != null && file.exists()) {
+              if (!file.delete()) {
+                logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+              }
+            }
+          }
+          try {
+            reader = spillWriters.getFirst().getReader(blockManager);
+            recordsInPage = -1;
+          } catch (IOException e) {
+            // Scala iterators do not handle exceptions
+            Platform.throwException(e);
+          }
+        }
       }
-      currentPage = dataPagesIterator.next();
-      pageBaseObject = currentPage.getBaseObject();
-      offsetInPage = currentPage.getBaseOffset();
     }
 
     @Override
     public boolean hasNext() {
-      return currentRecordNumber != numRecords;
+      if (numRecords == 0) {
+        if (reader != null) {
+          // remove the spill file from disk
+          File file = spillWriters.removeFirst().getFile();
+          if (file != null && file.exists()) {
+            if (!file.delete()) {
+              logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+            }
+          }
+        }
+      }
+      return numRecords > 0;
     }
 
     @Override
     public Location next() {
-      int totalLength = Platform.getInt(pageBaseObject, offsetInPage);
-      if (totalLength == END_OF_PAGE_MARKER) {
+      if (recordsInPage == 0) {
         advanceToNextPage();
-        totalLength = Platform.getInt(pageBaseObject, offsetInPage);
       }
-      loc.with(currentPage, offsetInPage);
-      offsetInPage += 4 + totalLength;
-      currentRecordNumber++;
-      return loc;
+      numRecords--;
+      if (currentPage != null) {
+        int totalLength = Platform.getInt(pageBaseObject, offsetInPage);
+        loc.with(currentPage, offsetInPage);
+        offsetInPage += 4 + totalLength;
+        recordsInPage--;
+        return loc;
+      } else {
+        assert(reader != null);
+        if (!reader.hasNext()) {
+          advanceToNextPage();
+        }
+        try {
+          reader.loadNext();
+        } catch (IOException e) {
+          // Scala iterators do not handle exceptions
+          Platform.throwException(e);
+        }
+        loc.with(reader.getBaseObject(), reader.getBaseOffset(), reader.getRecordLength());
+        return loc;
+      }
+    }
+
+    public long spill(long numBytes) throws IOException {
+      synchronized (this) {
+        if (!destructive || dataPages.size() == 1) {
+          return 0L;
+        }
+
+        // TODO: use existing ShuffleWriteMetrics
+        ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
+
+        long released = 0L;
+        while (dataPages.size() > 0) {
+          MemoryBlock block = dataPages.getLast();
+          // The currentPage is used, cannot be released
+          if (block == currentPage) {
+            break;
+          }
+
+          Object base = block.getBaseObject();
+          long offset = block.getBaseOffset();
+          int numRecords = Platform.getInt(base, offset);
+          offset += 4;
+          final UnsafeSorterSpillWriter writer =
+            new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords);
+          while (numRecords > 0) {
+            int length = Platform.getInt(base, offset);
+            writer.write(base, offset + 4, length, 0);
+            offset += 4 + length;
+            numRecords--;
+          }
+          writer.close();
+          spillWriters.add(writer);
+
+          dataPages.removeLast();
+          released += block.size();
+          freePage(block);
+
+          if (released >= numBytes) {
+            break;
+          }
+        }
+
+        return released;
+      }
     }
 
     @Override
@@ -290,8 +392,8 @@ public void remove() {
    * If any other lookups or operations are performed on this map while iterating over it, including
    * `lookup()`, the behavior of the returned iterator is undefined.
    */
-  public BytesToBytesMapIterator iterator() {
-    return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, false, this);
+  public MapIterator iterator() {
+    return new MapIterator(numElements, loc, false);
   }
 
   /**
@@ -304,8 +406,8 @@ public BytesToBytesMapIterator iterator() {
    * If any other lookups or operations are performed on this map while iterating over it, including
    * `lookup()`, the behavior of the returned iterator is undefined.
    */
-  public BytesToBytesMapIterator destructiveIterator() {
-    return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, true, this);
+  public MapIterator destructiveIterator() {
+    return new MapIterator(numElements, loc, true);
   }
 
   /**
@@ -314,11 +416,8 @@ public BytesToBytesMapIterator destructiveIterator() {
    *
    * This function always returns the same {@link Location} instance to avoid object allocation.
    */
-  public Location lookup(
-      Object keyBaseObject,
-      long keyBaseOffset,
-      int keyRowLengthBytes) {
-    safeLookup(keyBaseObject, keyBaseOffset, keyRowLengthBytes, loc);
+  public Location lookup(Object keyBase, long keyOffset, int keyLength) {
+    safeLookup(keyBase, keyOffset, keyLength, loc);
     return loc;
   }
 
@@ -327,18 +426,14 @@ public Location lookup(
    *
    * This is a thread-safe version of `lookup`; it can be used by multiple threads.
    */
-  public void safeLookup(
-      Object keyBaseObject,
-      long keyBaseOffset,
-      int keyRowLengthBytes,
-      Location loc) {
+  public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) {
     assert(bitset != null);
     assert(longArray != null);
 
     if (enablePerfMetrics) {
       numKeyLookups++;
     }
-    final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes);
+    final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength);
     int pos = hashcode & mask;
     int step = 1;
     while (true) {
@@ -354,16 +449,16 @@ public void safeLookup(
         if ((int) (stored) == hashcode) {
           // Full hash code matches.  Let's compare the keys for equality.
           loc.with(pos, hashcode, true);
-          if (loc.getKeyLength() == keyRowLengthBytes) {
+          if (loc.getKeyLength() == keyLength) {
             final MemoryLocation keyAddress = loc.getKeyAddress();
-            final Object storedKeyBaseObject = keyAddress.getBaseObject();
-            final long storedKeyBaseOffset = keyAddress.getBaseOffset();
+            final Object storedkeyBase = keyAddress.getBaseObject();
+            final long storedkeyOffset = keyAddress.getBaseOffset();
             final boolean areEqual = ByteArrayMethods.arrayEquals(
-              keyBaseObject,
-              keyBaseOffset,
-              storedKeyBaseObject,
-              storedKeyBaseOffset,
-              keyRowLengthBytes
+              keyBase,
+              keyOffset,
+              storedkeyBase,
+              storedkeyOffset,
+              keyLength
             );
             if (areEqual) {
               return;
@@ -410,18 +505,18 @@ private void updateAddressesAndSizes(long fullKeyAddress) {
         taskMemoryManager.getOffsetInPage(fullKeyAddress));
     }
 
-    private void updateAddressesAndSizes(final Object page, final long offsetInPage) {
-      long position = offsetInPage;
-      final int totalLength = Platform.getInt(page, position);
+    private void updateAddressesAndSizes(final Object base, final long offset) {
+      long position = offset;
+      final int totalLength = Platform.getInt(base, position);
       position += 4;
-      keyLength = Platform.getInt(page, position);
+      keyLength = Platform.getInt(base, position);
       position += 4;
       valueLength = totalLength - keyLength - 4;
 
-      keyMemoryLocation.setObjAndOffset(page, position);
+      keyMemoryLocation.setObjAndOffset(base, position);
 
       position += keyLength;
-      valueMemoryLocation.setObjAndOffset(page, position);
+      valueMemoryLocation.setObjAndOffset(base, position);
     }
 
     private Location with(int pos, int keyHashcode, boolean isDefined) {
@@ -443,6 +538,19 @@ private Location with(MemoryBlock page, long offsetInPage) {
       return this;
     }
 
+    /**
+     * This is only used for spilling.
+     */
+    private Location with(Object base, long offset, int length) {
+      this.isDefined = true;
+      this.memoryPage = null;
+      keyLength = Platform.getInt(base, offset);
+      valueLength = length - 4 - keyLength;
+      keyMemoryLocation.setObjAndOffset(base, offset + 4);
+      valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength);
+      return this;
+    }
+
     /**
      * Returns the memory page that contains the current record.
      * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}.
@@ -517,9 +625,9 @@ public int getValueLength() {
      * As an example usage, here's the proper way to store a new key:
      *
      * <pre>
-     *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
+     *   Location loc = map.lookup(keyBase, keyOffset, keyLength);
      *   if (!loc.isDefined()) {
-     *     if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
+     *     if (!loc.putNewKey(keyBase, keyOffset, keyLength, ...)) {
      *       // handle failure to grow map (by spilling, for example)
      *     }
      *   }
@@ -531,113 +639,59 @@ public int getValueLength() {
      * @return true if the put() was successful and false if the put() failed because memory could
      *         not be acquired.
      */
-    public boolean putNewKey(
-        Object keyBaseObject,
-        long keyBaseOffset,
-        int keyLengthBytes,
-        Object valueBaseObject,
-        long valueBaseOffset,
-        int valueLengthBytes) {
+    public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
+        Object valueBase, long valueOffset, int valueLength) {
       assert (!isDefined) : "Can only set value once for a key";
-      assert (keyLengthBytes % 8 == 0);
-      assert (valueLengthBytes % 8 == 0);
+      assert (keyLength % 8 == 0);
+      assert (valueLength % 8 == 0);
       assert(bitset != null);
       assert(longArray != null);
 
-      if (numElements == MAX_CAPACITY) {
-        throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
+      if (numElements == MAX_CAPACITY || !canGrowArray) {
+        return false;
       }
 
       // Here, we'll copy the data into our data pages. Because we only store a relative offset from
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
       // (8 byte key length) (key) (value)
-      final long requiredSize = 8 + keyLengthBytes + valueLengthBytes;
-
-      // --- Figure out where to insert the new record ---------------------------------------------
-
-      final MemoryBlock dataPage;
-      final Object dataPageBaseObject;
-      final long dataPageInsertOffset;
-      boolean useOverflowPage = requiredSize > pageSizeBytes - 8;
-      if (useOverflowPage) {
-        // The record is larger than the page size, so allocate a special overflow page just to hold
-        // that record.
-        final long overflowPageSize = requiredSize + 8;
-        MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-        if (overflowPage == null) {
-          logger.debug("Failed to acquire {} bytes of memory", overflowPageSize);
+      final long recordLength = 8 + keyLength + valueLength;
+      if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
+        if (!acquireNewPage(recordLength + 4L)) {
           return false;
         }
-        dataPages.add(overflowPage);
-        dataPage = overflowPage;
-        dataPageBaseObject = overflowPage.getBaseObject();
-        dataPageInsertOffset = overflowPage.getBaseOffset();
-      } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
-        // The record can fit in a data page, but either we have not allocated any pages yet or
-        // the current page does not have enough space.
-        if (currentDataPage != null) {
-          // There wasn't enough space in the current page, so write an end-of-page marker:
-          final Object pageBaseObject = currentDataPage.getBaseObject();
-          final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
-          Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
-        }
-        if (!acquireNewPage()) {
-          return false;
-        }
-        dataPage = currentDataPage;
-        dataPageBaseObject = currentDataPage.getBaseObject();
-        dataPageInsertOffset = currentDataPage.getBaseOffset();
-      } else {
-        // There is enough space in the current data page.
-        dataPage = currentDataPage;
-        dataPageBaseObject = currentDataPage.getBaseObject();
-        dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor;
       }
 
       // --- Append the key and value data to the current data page --------------------------------
-
-      long insertCursor = dataPageInsertOffset;
-
-      // Compute all of our offsets up-front:
-      final long recordOffset = insertCursor;
-      insertCursor += 4;
-      final long keyLengthOffset = insertCursor;
-      insertCursor += 4;
-      final long keyDataOffsetInPage = insertCursor;
-      insertCursor += keyLengthBytes;
-      final long valueDataOffsetInPage = insertCursor;
-      insertCursor += valueLengthBytes; // word used to store the value size
-
-      Platform.putInt(dataPageBaseObject, recordOffset,
-        keyLengthBytes + valueLengthBytes + 4);
-      Platform.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
-      // Copy the key
-      Platform.copyMemory(
-        keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes);
-      // Copy the value
-      Platform.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject,
-        valueDataOffsetInPage, valueLengthBytes);
-
-      // --- Update bookeeping data structures -----------------------------------------------------
-
-      if (useOverflowPage) {
-        // Store the end-of-page marker at the end of the data page
-        Platform.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
-      } else {
-        pageCursor += requiredSize;
-      }
-
+      final Object base = currentPage.getBaseObject();
+      long offset = currentPage.getBaseOffset() + pageCursor;
+      final long recordOffset = offset;
+      Platform.putInt(base, offset, keyLength + valueLength + 4);
+      Platform.putInt(base, offset + 4, keyLength);
+      offset += 8;
+      Platform.copyMemory(keyBase, keyOffset, base, offset, keyLength);
+      offset += keyLength;
+      Platform.copyMemory(valueBase, valueOffset, base, offset, valueLength);
+
+      // --- Update bookkeeping data structures -----------------------------------------------------
+      offset = currentPage.getBaseOffset();
+      Platform.putInt(base, offset, Platform.getInt(base, offset) + 1);
+      pageCursor += recordLength;
       numElements++;
       bitset.set(pos);
       final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
-        dataPage, recordOffset);
+        currentPage, recordOffset);
       longArray.set(pos * 2, storedKeyAddress);
       longArray.set(pos * 2 + 1, keyHashcode);
       updateAddressesAndSizes(storedKeyAddress);
       isDefined = true;
+
       if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
-        growAndRehash();
+        try {
+          growAndRehash();
+        } catch (OutOfMemoryError oom) {
+          canGrowArray = false;
+        }
       }
       return true;
     }
@@ -647,18 +701,26 @@ public boolean putNewKey(
    * Acquire a new page from the memory manager.
    * @return whether there is enough space to allocate the new page.
    */
-  private boolean acquireNewPage() {
-    MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
-    if (newPage == null) {
-      logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+  private boolean acquireNewPage(long required) {
+    try {
+      currentPage = allocatePage(required);
+    } catch (OutOfMemoryError e) {
       return false;
     }
-    dataPages.add(newPage);
-    pageCursor = 0;
-    currentDataPage = newPage;
+    dataPages.add(currentPage);
+    Platform.putInt(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0);
+    pageCursor = 4;
     return true;
   }
 
+  @Override
+  public long spill(long size, MemoryConsumer trigger) throws IOException {
+    if (trigger != this && destructiveIterator != null) {
+      return destructiveIterator.spill(size);
+    }
+    return 0L;
+  }
+
   /**
    * Allocate new data structures for this map. When calling this outside of the constructor,
    * make sure to keep references to the old data structures so that you can free them.
@@ -670,6 +732,7 @@ private void allocate(int capacity) {
     // The capacity needs to be divisible by 64 so that our bit set can be sized properly
     capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64);
     assert (capacity <= MAX_CAPACITY);
+    acquireMemory(capacity * 16L);
     longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
     bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
 
@@ -677,6 +740,19 @@ private void allocate(int capacity) {
     this.mask = capacity - 1;
   }
 
+  /**
+   * Free the memory used by longArray.
+   */
+  public void freeArray() {
+    updatePeakMemoryUsed();
+    if (longArray != null) {
+      long used = longArray.memoryBlock().size();
+      longArray = null;
+      releaseMemory(used);
+      bitset = null;
+    }
+  }
+
   /**
    * Free all allocated memory associated with this map, including the storage for keys and values
    * as well as the hash map array itself.
@@ -684,16 +760,23 @@ private void allocate(int capacity) {
    * This method is idempotent and can be called multiple times.
    */
   public void free() {
-    updatePeakMemoryUsed();
-    longArray = null;
-    bitset = null;
+    freeArray();
    Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
     while (dataPagesIterator.hasNext()) {
       MemoryBlock dataPage = dataPagesIterator.next();
       dataPagesIterator.remove();
-      taskMemoryManager.freePage(dataPage);
+      freePage(dataPage);
     }
     assert(dataPages.isEmpty());
+
+    while (!spillWriters.isEmpty()) {
+      File file = spillWriters.removeFirst().getFile();
+      if (file != null && file.exists()) {
+        if (!file.delete()) {
+          logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+        }
+      }
+    }
   }
 
   public TaskMemoryManager getTaskMemoryManager() {
@@ -782,7 +865,13 @@ void growAndRehash() {
     final int oldCapacity = (int) oldBitSet.capacity();
 
     // Allocate the new data structures
-    allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
+    try {
+      allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
+    } catch (OutOfMemoryError oom) {
+      longArray = oldLongArray;
+      bitset = oldBitSet;
+      throw oom;
+    }
 
     // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it)
     for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) {
@@ -806,6 +895,7 @@ void growAndRehash() {
         }
       }
     }
+    releaseMemory(oldLongArray.memoryBlock().size());
 
     if (enablePerfMetrics) {
       timeSpentResizingNs += System.nanoTime() - resizeStartTime;
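Reviewer note: the spill support in this file hinges on the page layout change above. The `END_OF_PAGE_MARKER` sentinel is gone; each data page now starts with a 4-byte record count, which is what both `MapIterator` and the spill path walk. A standalone sketch of scanning that layout (hypothetical demo class, using a `ByteBuffer` in place of a Tungsten page):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Sketch of the page layout BytesToBytesMap writes after this patch:
// [int numRecords] followed by records of the form
// [int totalLength][int keyLength][key bytes][value bytes].
final class MapPageLayoutDemo {
  static void scanPage(ByteBuffer page) {
    page.order(ByteOrder.nativeOrder());
    page.position(0);
    int numRecords = page.getInt();            // page header: record count
    while (numRecords-- > 0) {
      int totalLength = page.getInt();         // keyLength + valueLength + 4
      int keyLength = page.getInt();
      int valueLength = totalLength - keyLength - 4;
      System.out.println("key=" + keyLength + "B value=" + valueLength + "B");
      page.position(page.position() + keyLength + valueLength);  // skip payload
    }
  }

  public static void main(String[] args) {
    ByteBuffer page = ByteBuffer.allocate(64).order(ByteOrder.nativeOrder());
    byte[] key = {1, 2, 3, 4, 5, 6, 7, 8};
    byte[] value = {9, 9, 9, 9, 9, 9, 9, 9};
    page.putInt(1);                              // one record in this page
    page.putInt(key.length + value.length + 4);  // total length
    page.putInt(key.length);                     // key length
    page.put(key).put(value);
    scanPage(page);                              // prints "key=8B value=8B"
  }
}
```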
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index e317ea391c556..49a5a4b13b70d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -17,39 +17,34 @@
 
 package org.apache.spark.util.collection.unsafe.sort;
 
+import javax.annotation.Nullable;
 import java.io.File;
 import java.io.IOException;
 import java.util.LinkedList;
 
-import javax.annotation.Nullable;
-
-import scala.runtime.AbstractFunction0;
-import scala.runtime.BoxedUnit;
-
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.storage.BlockManager;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.TaskCompletionListener;
 import org.apache.spark.util.Utils;
 
 /**
  * External sorter based on {@link UnsafeInMemorySorter}.
  */
-public final class UnsafeExternalSorter {
+public final class UnsafeExternalSorter extends MemoryConsumer {
 
   private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
 
-  private final long pageSizeBytes;
   private final PrefixComparator prefixComparator;
   private final RecordComparator recordComparator;
-  private final int initialSize;
   private final TaskMemoryManager taskMemoryManager;
   private final BlockManager blockManager;
   private final TaskContext taskContext;
@@ -69,14 +64,12 @@ public final class UnsafeExternalSorter {
   private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
 
   // These variables are reset after spilling:
-  @Nullable private UnsafeInMemorySorter inMemSorter;
-  // Whether the in-mem sorter is created internally, or passed in from outside.
-  // If it is passed in from outside, we shouldn't release the in-mem sorter's memory.
-  private boolean isInMemSorterExternal = false;
+  @Nullable private volatile UnsafeInMemorySorter inMemSorter;
+
   private MemoryBlock currentPage = null;
-  private long currentPagePosition = -1;
-  private long freeSpaceInCurrentPage = 0;
+  private long pageCursor = -1;
   private long peakMemoryUsedBytes = 0;
+  private volatile SpillableIterator readingIterator = null;
 
   public static UnsafeExternalSorter createWithExistingInMemorySorter(
       TaskMemoryManager taskMemoryManager,
@@ -86,7 +79,7 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
       PrefixComparator prefixComparator,
       int initialSize,
       long pageSizeBytes,
-      UnsafeInMemorySorter inMemorySorter) throws IOException {
+      UnsafeInMemorySorter inMemorySorter) {
     return new UnsafeExternalSorter(taskMemoryManager, blockManager,
       taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
   }
@@ -98,7 +91,7 @@ public static UnsafeExternalSorter create(
       RecordComparator recordComparator,
       PrefixComparator prefixComparator,
       int initialSize,
-      long pageSizeBytes) throws IOException {
+      long pageSizeBytes) {
     return new UnsafeExternalSorter(taskMemoryManager, blockManager,
       taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
   }
@@ -111,60 +104,41 @@ private UnsafeExternalSorter(
       PrefixComparator prefixComparator,
       int initialSize,
       long pageSizeBytes,
-      @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException {
+      @Nullable UnsafeInMemorySorter existingInMemorySorter) {
+    super(taskMemoryManager, pageSizeBytes);
     this.taskMemoryManager = taskMemoryManager;
     this.blockManager = blockManager;
     this.taskContext = taskContext;
     this.recordComparator = recordComparator;
     this.prefixComparator = prefixComparator;
-    this.initialSize = initialSize;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
     // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.fileBufferSizeBytes = 32 * 1024;
-    this.pageSizeBytes = pageSizeBytes;
+    // TODO: metrics tracking + integration with shuffle write metrics
+    // need to connect the write metrics to task metrics so we count the spill IO somewhere.
     this.writeMetrics = new ShuffleWriteMetrics();
 
     if (existingInMemorySorter == null) {
-      initializeForWriting();
-      // Acquire a new page as soon as we construct the sorter to ensure that we have at
-      // least one page to work with. Otherwise, other operators in the same task may starve
-      // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter.
-      acquireNewPage();
+      this.inMemSorter =
+        new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
+      acquireMemory(inMemSorter.getMemoryUsage());
     } else {
-      this.isInMemSorterExternal = true;
       this.inMemSorter = existingInMemorySorter;
+      // memory for the existing in-memory sorter will be acquired after the map is freed
     }
+    this.peakMemoryUsedBytes = getMemoryUsage();
 
     // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
     // the end of the task. This is necessary to avoid memory leaks when the downstream operator
     // does not fully consume the sorter's output (e.g. sort followed by limit).
-    taskContext.addOnCompleteCallback(new AbstractFunction0() {
-      @Override
-      public BoxedUnit apply() {
-        cleanupResources();
-        return null;
+    taskContext.addTaskCompletionListener(
+      new TaskCompletionListener() {
+        @Override
+        public void onTaskCompletion(TaskContext context) {
+          cleanupResources();
+        }
       }
-    });
-  }
-
-  // TODO: metrics tracking + integration with shuffle write metrics
-  // need to connect the write metrics to task metrics so we count the spill IO somewhere.
-
-  /**
-   * Allocates new sort data structures. Called when creating the sorter and after each spill.
-   */
-  private void initializeForWriting() throws IOException {
-    // Note: Do not track memory for the pointer array for now because of SPARK-10474.
-    // In more detail, in TungstenAggregate we only reserve a page, but when we fall back to
-    // sort-based aggregation we try to acquire a page AND a pointer array, which inevitably
-    // fails if all other memory is already occupied. It should be safe to not track the array
-    // because its memory footprint is frequently much smaller than that of a page. This is a
-    // temporary hack that we should address in 1.6.0.
-    // TODO: track the pointer array memory!
-    this.writeMetrics = new ShuffleWriteMetrics();
-    this.inMemSorter =
-      new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
-    this.isInMemSorterExternal = false;
+    );
   }
 
   /**
@@ -173,14 +147,27 @@ private void initializeForWriting() throws IOException {
    */
   @VisibleForTesting
   public void closeCurrentPage() {
-    freeSpaceInCurrentPage = 0;
+    if (currentPage != null) {
+      pageCursor = currentPage.getBaseOffset() + currentPage.size();
+    }
   }
 
   /**
    * Sort and spill the current records in response to memory pressure.
    */
-  public void spill() throws IOException {
-    assert(inMemSorter != null);
+  @Override
+  public long spill(long size, MemoryConsumer trigger) throws IOException {
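+    // When another consumer triggers this spill, only the reading-phase iterator (if
+    // one exists) can safely release memory; records still being inserted are spilled
+    // only when this sorter itself is the trigger.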
+    if (trigger != this) {
+      if (readingIterator != null) {
+        return readingIterator.spill();
+      }
+      return 0L;
+    }
+
+    if (inMemSorter == null || inMemSorter.numRecords() <= 0) {
+      return 0L;
+    }
+
     logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
       Thread.currentThread().getId(),
       Utils.bytesToString(getMemoryUsage()),
@@ -202,6 +189,8 @@ public void spill() throws IOException {
         spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
       }
       spillWriter.close();
+
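+      // Reuse the existing pointer array for the next batch of records instead of
+      // freeing and reallocating it.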
+      inMemSorter.reset();
     }
 
     final long spillSize = freeMemory();
@@ -210,7 +199,7 @@ public void spill() throws IOException {
     // written to disk. This also counts the space needed to store the sorter's pointer array.
     taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
 
-    initializeForWriting();
+    return spillSize;
   }
 
   /**
@@ -246,7 +235,7 @@ public int getNumberOfAllocatedPages() {
   }
 
   /**
-   * Free this sorter's in-memory data structures, including its data pages and pointer array.
+   * Free this sorter's data pages.
    *
    * @return the number of bytes freed.
    */
@@ -254,14 +243,12 @@ private long freeMemory() {
     updatePeakMemoryUsed();
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
-      taskMemoryManager.freePage(block);
       memoryFreed += block.size();
+      freePage(block);
     }
-    // TODO: track in-memory sorter memory usage (SPARK-10474)
     allocatedPages.clear();
     currentPage = null;
-    currentPagePosition = -1;
-    freeSpaceInCurrentPage = 0;
+    pageCursor = 0;
     return memoryFreed;
   }
 
@@ -283,8 +270,15 @@ private void deleteSpillFiles() {
    * Frees this sorter's in-memory data structures and cleans up its spill files.
    */
   public void cleanupResources() {
-    deleteSpillFiles();
-    freeMemory();
+    synchronized (this) {
+      deleteSpillFiles();
+      freeMemory();
+      if (inMemSorter != null) {
+        long used = inMemSorter.getMemoryUsage();
+        inMemSorter = null;
+        releaseMemory(used);
+      }
+    }
   }
 
   /**
@@ -295,8 +289,28 @@ public void cleanupResources() {
   private void growPointerArrayIfNecessary() throws IOException {
     assert(inMemSorter != null);
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
-      // TODO: track the pointer array memory! (SPARK-10474)
-      inMemSorter.expandPointerArray();
+      long used = inMemSorter.getMemoryUsage();
+      long needed = used + inMemSorter.getMemoryToExpand();
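+      // Reserve enough memory for the expanded array up front; depending on whether
+      // the request triggered a spill, either the whole reservation or just the old
+      // array's share is released below.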
+      try {
+        acquireMemory(needed);  // could trigger spilling
+      } catch (OutOfMemoryError e) {
+        // our own spill() must have been triggered, freeing space in the sorter
+        assert(inMemSorter.hasSpaceForAnotherRecord());
+        return;
+      }
+      // check whether spilling already freed enough space
+      if (inMemSorter.hasSpaceForAnotherRecord()) {
+        releaseMemory(needed);
+      } else {
+        try {
+          inMemSorter.expandPointerArray();
+          releaseMemory(used);
+        } catch (OutOfMemoryError oom) {
+          // in case the JVM itself ran out of memory while allocating the new array
+          releaseMemory(needed);
+          spill();
+        }
+      }
     }
   }
 
@@ -304,101 +318,38 @@ private void growPointerArrayIfNecessary() throws IOException {
    * Allocates more memory in order to insert an additional record. This will request additional
    * memory from the memory manager and spill if the requested memory can not be obtained.
    *
-   * @param requiredSpace the required space in the data page, in bytes, including space for storing
+   * @param required the required space in the data page, in bytes, including space for storing
    *                      the record size. This must be less than or equal to the page size (records
    *                      that exceed the page size are handled via a different code path which uses
    *                      special overflow pages).
    */
-  private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
-    assert (requiredSpace <= pageSizeBytes);
-    if (requiredSpace > freeSpaceInCurrentPage) {
-      logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
-        freeSpaceInCurrentPage);
-      // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
-      // without using the free space at the end of the current page. We should also do this for
-      // BytesToBytesMap.
-      if (requiredSpace > pageSizeBytes) {
-        throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
-          pageSizeBytes + ")");
-      } else {
-        acquireNewPage();
-      }
+  private void acquireNewPageIfNecessary(int required) {
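+    // Allocate a fresh page whenever there is no current page or the record does not
+    // fit in the remaining space of the current one.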
+    if (currentPage == null ||
+      pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) {
+      // TODO: try to find space on previous pages
+      currentPage = allocatePage(required);
+      pageCursor = currentPage.getBaseOffset();
+      allocatedPages.add(currentPage);
     }
   }
 
-  /**
-   * Acquire a new page from the memory manager.
-   *
-   * If there is not enough space to allocate the new page, spill all existing ones
-   * and try again. If there is still not enough space, report error to the caller.
-   */
-  private void acquireNewPage() throws IOException {
-    currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-    if (currentPage == null) {
-      spill();
-      currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-      if (currentPage == null) {
-        throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
-      }
-    }
-    currentPagePosition = currentPage.getBaseOffset();
-    freeSpaceInCurrentPage = pageSizeBytes;
-    allocatedPages.add(currentPage);
-  }
-
   /**
    * Write a record to the sorter.
    */
-  public void insertRecord(
-      Object recordBaseObject,
-      long recordBaseOffset,
-      int lengthInBytes,
-      long prefix) throws IOException {
+  public void insertRecord(Object recordBase, long recordOffset, int length, long prefix)
+    throws IOException {
 
     growPointerArrayIfNecessary();
     // Need 4 bytes to store the record length.
-    final int totalSpaceRequired = lengthInBytes + 4;
-
-    // --- Figure out where to insert the new record ----------------------------------------------
-
-    final MemoryBlock dataPage;
-    long dataPagePosition;
-    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
-    if (useOverflowPage) {
-      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
-      // The record is larger than the page size, so allocate a special overflow page just to hold
-      // that record.
-      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-      if (overflowPage == null) {
-        spill();
-        overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-        if (overflowPage == null) {
-          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
-        }
-      }
-      allocatedPages.add(overflowPage);
-      dataPage = overflowPage;
-      dataPagePosition = overflowPage.getBaseOffset();
-    } else {
-      // The record is small enough to fit in a regular data page, but the current page might not
-      // have enough space to hold it (or no pages have been allocated yet).
-      acquireNewPageIfNecessary(totalSpaceRequired);
-      dataPage = currentPage;
-      dataPagePosition = currentPagePosition;
-      // Update bookkeeping information
-      freeSpaceInCurrentPage -= totalSpaceRequired;
-      currentPagePosition += totalSpaceRequired;
-    }
-    final Object dataPageBaseObject = dataPage.getBaseObject();
-
-    // --- Insert the record ----------------------------------------------------------------------
-
-    final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
-    Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
-    dataPagePosition += 4;
-    Platform.copyMemory(
-      recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
+    final int required = length + 4;
+    acquireNewPageIfNecessary(required);
+
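+    // Page layout for each record: [length (4 bytes)][record bytes]; the in-memory
+    // sorter only keeps the encoded (address, prefix) pair.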
+    final Object base = currentPage.getBaseObject();
+    final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
+    Platform.putInt(base, pageCursor, length);
+    pageCursor += 4;
+    Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
+    pageCursor += length;
     assert(inMemSorter != null);
     inMemSorter.insertRecord(recordAddress, prefix);
   }
@@ -411,59 +362,24 @@ public void insertRecord(
    *
    * record length = key length + value length + 4
    */
-  public void insertKVRecord(
-      Object keyBaseObj, long keyOffset, int keyLen,
-      Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException {
+  public void insertKVRecord(Object keyBase, long keyOffset, int keyLen,
+      Object valueBase, long valueOffset, int valueLen, long prefix)
+    throws IOException {
 
     growPointerArrayIfNecessary();
-    final int totalSpaceRequired = keyLen + valueLen + 4 + 4;
-
-    // --- Figure out where to insert the new record ----------------------------------------------
-
-    final MemoryBlock dataPage;
-    long dataPagePosition;
-    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
-    if (useOverflowPage) {
-      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
-      // The record is larger than the page size, so allocate a special overflow page just to hold
-      // that record.
-      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-      if (overflowPage == null) {
-        spill();
-        overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-        if (overflowPage == null) {
-          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
-        }
-      }
-      allocatedPages.add(overflowPage);
-      dataPage = overflowPage;
-      dataPagePosition = overflowPage.getBaseOffset();
-    } else {
-      // The record is small enough to fit in a regular data page, but the current page might not
-      // have enough space to hold it (or no pages have been allocated yet).
-      acquireNewPageIfNecessary(totalSpaceRequired);
-      dataPage = currentPage;
-      dataPagePosition = currentPagePosition;
-      // Update bookkeeping information
-      freeSpaceInCurrentPage -= totalSpaceRequired;
-      currentPagePosition += totalSpaceRequired;
-    }
-    final Object dataPageBaseObject = dataPage.getBaseObject();
-
-    // --- Insert the record ----------------------------------------------------------------------
-
-    final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
-    Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4);
-    dataPagePosition += 4;
-
-    Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen);
-    dataPagePosition += 4;
-
-    Platform.copyMemory(keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen);
-    dataPagePosition += keyLen;
-
-    Platform.copyMemory(valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen);
+    final int required = keyLen + valueLen + 4 + 4;
+    acquireNewPageIfNecessary(required);
+
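+    // Page layout: [key length + value length + 4 (int)][key length (int)][key bytes][value bytes]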
+    final Object base = currentPage.getBaseObject();
+    final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
+    Platform.putInt(base, pageCursor, keyLen + valueLen + 4);
+    pageCursor += 4;
+    Platform.putInt(base, pageCursor, keyLen);
+    pageCursor += 4;
+    Platform.copyMemory(keyBase, keyOffset, base, pageCursor, keyLen);
+    pageCursor += keyLen;
+    Platform.copyMemory(valueBase, valueOffset, base, pageCursor, valueLen);
+    pageCursor += valueLen;
 
     assert(inMemSorter != null);
     inMemSorter.insertRecord(recordAddress, prefix);
@@ -475,10 +391,10 @@ public void insertKVRecord(
    */
   public UnsafeSorterIterator getSortedIterator() throws IOException {
     assert(inMemSorter != null);
-    final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator();
-    int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
+    readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
+    int numIteratorsToMerge = spillWriters.size() + (readingIterator.hasNext() ? 1 : 0);
     if (spillWriters.isEmpty()) {
-      return inMemoryIterator;
+      return readingIterator;
     } else {
       final UnsafeSorterSpillMerger spillMerger =
         new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
@@ -486,9 +402,113 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
         spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
       }
       spillWriters.clear();
-      spillMerger.addSpillIfNotEmpty(inMemoryIterator);
+      spillMerger.addSpillIfNotEmpty(readingIterator);
 
       return spillMerger.getSortedIterator();
     }
   }
+
+  /**
+   * An UnsafeSorterIterator that supports spilling.
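+   *
+   * While records are being read, a concurrent spill request writes the remaining
+   * in-memory records to disk; the next loadNext() then frees the last in-memory page
+   * and transparently continues from the spill-file reader.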
+   */
+  class SpillableIterator extends UnsafeSorterIterator {
+    private UnsafeSorterIterator upstream;
+    private UnsafeSorterIterator nextUpstream = null;
+    private MemoryBlock lastPage = null;
+    private boolean loaded = false;
+    private int numRecords = 0;
+
+    public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
+      this.upstream = inMemIterator;
+      this.numRecords = inMemIterator.numRecordsLeft();
+    }
+
+    public long spill() throws IOException {
+      synchronized (this) {
+        if (!(upstream instanceof UnsafeInMemorySorter.SortedIterator && nextUpstream == null
+          && numRecords > 0)) {
+          return 0L;
+        }
+
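+        // Clone the in-memory iterator so the current read position is preserved while
+        // the remaining records are written out to a spill file.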
+        UnsafeInMemorySorter.SortedIterator inMemIterator =
+          ((UnsafeInMemorySorter.SortedIterator) upstream).clone();
+
+        final UnsafeSorterSpillWriter spillWriter =
+          new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords);
+        while (inMemIterator.hasNext()) {
+          inMemIterator.loadNext();
+          final Object baseObject = inMemIterator.getBaseObject();
+          final long baseOffset = inMemIterator.getBaseOffset();
+          final int recordLength = inMemIterator.getRecordLength();
+          spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix());
+        }
+        spillWriter.close();
+        spillWriters.add(spillWriter);
+        nextUpstream = spillWriter.getReader(blockManager);
+
+        long released = 0L;
+        synchronized (UnsafeExternalSorter.this) {
+          // release all pages except the one currently being read
+          for (MemoryBlock page : allocatedPages) {
+            if (!loaded || page.getBaseObject() != inMemIterator.getBaseObject()) {
+              released += page.size();
+              freePage(page);
+            } else {
+              lastPage = page;
+            }
+          }
+          allocatedPages.clear();
+        }
+        return released;
+      }
+    }
+
+    @Override
+    public boolean hasNext() {
+      return numRecords > 0;
+    }
+
+    @Override
+    public void loadNext() throws IOException {
+      synchronized (this) {
+        loaded = true;
+        if (nextUpstream != null) {
+          // a spill happened since the last loadNext(); the record last read from the
+          // in-memory iterator is no longer needed, so free its page and switch over
+          if (lastPage != null) {
+            freePage(lastPage);
+            lastPage = null;
+          }
+          upstream = nextUpstream;
+          nextUpstream = null;
+
+          assert(inMemSorter != null);
+          long used = inMemSorter.getMemoryUsage();
+          inMemSorter = null;
+          releaseMemory(used);
+        }
+        numRecords--;
+        upstream.loadNext();
+      }
+    }
+
+    @Override
+    public Object getBaseObject() {
+      return upstream.getBaseObject();
+    }
+
+    @Override
+    public long getBaseOffset() {
+      return upstream.getBaseOffset();
+    }
+
+    @Override
+    public int getRecordLength() {
+      return upstream.getRecordLength();
+    }
+
+    @Override
+    public long getKeyPrefix() {
+      return upstream.getKeyPrefix();
+    }
+  }
 }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 5aad72c374c37..1480f0681ed9c 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -70,12 +70,12 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
   * Within this buffer, position {@code 2 * i} holds a pointer to the record at
    * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
    */
-  private long[] pointerArray;
+  private long[] array;
 
   /**
    * The position in the sort buffer where new records can be inserted.
    */
-  private int pointerArrayInsertPosition = 0;
+  private int pos = 0;
 
   public UnsafeInMemorySorter(
       final TaskMemoryManager memoryManager,
@@ -83,37 +83,43 @@ public UnsafeInMemorySorter(
       final PrefixComparator prefixComparator,
       int initialSize) {
     assert (initialSize > 0);
-    this.pointerArray = new long[initialSize * 2];
+    this.array = new long[initialSize * 2];
     this.memoryManager = memoryManager;
     this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
     this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
   }
 
+  public void reset() {
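+    // Drop all inserted records but keep the backing array for reuse.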
+    pos = 0;
+  }
+
   /**
    * @return the number of records that have been inserted into this sorter.
    */
   public int numRecords() {
-    return pointerArrayInsertPosition / 2;
+    return pos / 2;
   }
 
-  public long getMemoryUsage() {
-    return pointerArray.length * 8L;
+  private int newLength() {
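+    // Double the array length, guarding against integer overflow.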
+    return array.length < Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
+  }
+
+  public long getMemoryToExpand() {
+    return (long) (newLength() - array.length) * 8L;
   }
 
-  static long getMemoryRequirementsForPointerArray(long numEntries) {
-    return numEntries * 2L * 8L;
+  public long getMemoryUsage() {
+    return array.length * 8L;
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pointerArrayInsertPosition + 2 < pointerArray.length;
+    return pos + 2 <= array.length;
   }
 
   public void expandPointerArray() {
-    final long[] oldArray = pointerArray;
-    // Guard against overflow:
-    final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
-    pointerArray = new long[newLength];
-    System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+    final long[] oldArray = array;
+    array = new long[newLength()];
+    System.arraycopy(oldArray, 0, array, 0, oldArray.length);
   }
 
   /**
@@ -127,10 +133,10 @@ public void insertRecord(long recordPointer, long keyPrefix) {
     if (!hasSpaceForAnotherRecord()) {
       expandPointerArray();
     }
-    pointerArray[pointerArrayInsertPosition] = recordPointer;
-    pointerArrayInsertPosition++;
-    pointerArray[pointerArrayInsertPosition] = keyPrefix;
-    pointerArrayInsertPosition++;
+    array[pos] = recordPointer;
+    pos++;
+    array[pos] = keyPrefix;
+    pos++;
   }
 
   public static final class SortedIterator extends UnsafeSorterIterator {
@@ -153,11 +159,25 @@ private SortedIterator(
       this.sortBuffer = sortBuffer;
     }
 
+    public SortedIterator clone() {
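+      // shallow copy: shares the underlying sort buffer; only the read position and
+      // the currently-loaded record fields are duplicated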
+      SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer);
+      iter.position = position;
+      iter.baseObject = baseObject;
+      iter.baseOffset = baseOffset;
+      iter.keyPrefix = keyPrefix;
+      iter.recordLength = recordLength;
+      return iter;
+    }
+
     @Override
     public boolean hasNext() {
       return position < sortBufferInsertPosition;
     }
 
+    public int numRecordsLeft() {
+      return (sortBufferInsertPosition - position) / 2;
+    }
+
     @Override
     public void loadNext() {
       // This pointer points to a 4-byte record length, followed by the record's bytes
@@ -187,7 +207,7 @@ public void loadNext() {
    * {@code next()} will return the same mutable object.
    */
   public SortedIterator getSortedIterator() {
-    sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
-    return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
+    sorter.sort(array, 0, pos / 2, sortComparator);
+    return new SortedIterator(memoryManager, pos, array);
   }
 }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 501dfe77d13cb..039e940a357ea 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -20,18 +20,18 @@
 import java.io.*;
 
 import com.google.common.io.ByteStreams;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.Platform;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 /**
  * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
  * of the file format).
  */
-final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
+public final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
   private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class);
 
   private final File file;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index e59a84ff8d118..234e21140a1dd 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -35,7 +35,7 @@
  *
  *   [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
  */
-final class UnsafeSorterSpillWriter {
+public final class UnsafeSorterSpillWriter {
 
   static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
 
diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
index 6c9a71c3855b0..b0cf2696a397f 100644
--- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import com.google.common.annotations.VisibleForTesting
 
+import org.apache.spark.util.Utils
 import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging}
 import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore}
 import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -215,8 +216,12 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte
   final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized {
     val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L)
     if (curMem < numBytes) {
-      throw new SparkException(
-        s"Internal error: release called on $numBytes bytes but task only has $curMem")
+      if (Utils.isTesting) {
+        throw new SparkException(
+          s"Internal error: release called on $numBytes bytes but task only has $curMem")
+      } else {
+        logWarning(s"Internal error: release called on $numBytes bytes but task only has $curMem")
+      }
     }
     if (executionMemoryForTask.contains(taskAttemptId)) {
       executionMemoryForTask(taskAttemptId) -= numBytes
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index a76891acf0baf..9e002621a6909 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -78,7 +78,7 @@ private[spark] trait Spillable[C] extends Logging {
     if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
       // Claim up to double our current memory from the shuffle memory pool
       val amountToRequest = 2 * currentMemory - myMemoryThreshold
-      val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest)
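+      // these on-heap collections track their own spilling, so no MemoryConsumer is
+      // registered here (hence the null)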
+      val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest, null)
       myMemoryThreshold += granted
       // If we were granted too little memory to grow further (either tryToAcquire returned 0,
       // or we already had more memory than myMemoryThreshold), spill the current collection
@@ -107,7 +107,7 @@ private[spark] trait Spillable[C] extends Logging {
    */
   def releaseMemory(): Unit = {
     // The amount we requested does not include the initial memory tracking threshold
-    taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold)
+    taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold, null)
     myMemoryThreshold = initialMemoryThreshold
   }
 
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index f381db0c62653..dab7b0592cb4e 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.memory;
 
+import java.io.IOException;
+
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -25,19 +27,40 @@
 
 public class TaskMemoryManagerSuite {
 
+  class TestMemoryConsumer extends MemoryConsumer {
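+    // Minimal consumer for the tests below: spill() simply releases everything this
+    // consumer currently holds.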
+    TestMemoryConsumer(TaskMemoryManager memoryManager) {
+      super(memoryManager);
+    }
+
+    @Override
+    public long spill(long size, MemoryConsumer trigger) throws IOException {
+      long used = getUsed();
+      releaseMemory(used);
+      return used;
+    }
+
+    void use(long size) {
+      acquireMemory(size);
+    }
+
+    void free(long size) {
+      releaseMemory(size);
+    }
+  }
+
   @Test
   public void leakedPageMemoryIsDetected() {
     final TaskMemoryManager manager = new TaskMemoryManager(
-      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
-    manager.allocatePage(4096);  // leak memory
+      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+    manager.allocatePage(4096, null);  // leak memory
     Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
   }
 
   @Test
   public void encodePageNumberAndOffsetOffHeap() {
     final TaskMemoryManager manager = new TaskMemoryManager(
-      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0);
-    final MemoryBlock dataPage = manager.allocatePage(256);
+      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0);
+    final MemoryBlock dataPage = manager.allocatePage(256, null);
     // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
     // encode. This test exercises that corner-case:
     final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
@@ -49,11 +72,53 @@ public void encodePageNumberAndOffsetOffHeap() {
   @Test
   public void encodePageNumberAndOffsetOnHeap() {
     final TaskMemoryManager manager = new TaskMemoryManager(
-      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
-    final MemoryBlock dataPage = manager.allocatePage(256);
+      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+    final MemoryBlock dataPage = manager.allocatePage(256, null);
     final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
     Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
     Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
   }
 
+  @Test
+  public void cooperativeSpilling() {
+    final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf());
+    memoryManager.limit(100);
+    final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0);
+
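+    // Two consumers share a 100-byte budget: whenever one cannot get the memory it
+    // asks for, the manager forces the other to spill (its usage drops to 0).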
+    TestMemoryConsumer c1 = new TestMemoryConsumer(manager);
+    TestMemoryConsumer c2 = new TestMemoryConsumer(manager);
+    c1.use(100);
+    assert(c1.getUsed() == 100);
+    c2.use(100);
+    assert(c2.getUsed() == 100);
+    assert(c1.getUsed() == 0);  // spilled
+    c1.use(100);
+    assert(c1.getUsed() == 100);
+    assert(c2.getUsed() == 0);  // spilled
+
+    c1.use(50);
+    assert(c1.getUsed() == 50);  // spilled
+    assert(c2.getUsed() == 0);
+    c2.use(50);
+    assert(c1.getUsed() == 50);
+    assert(c2.getUsed() == 50);
+
+    c1.use(100);
+    assert(c1.getUsed() == 100);
+    assert(c2.getUsed() == 0);  // spilled
+
+    c1.free(20);
+    assert(c1.getUsed() == 80);
+    c2.use(10);
+    assert(c1.getUsed() == 80);
+    assert(c2.getUsed() == 10);
+    c2.use(100);
+    assert(c2.getUsed() == 100);
+    assert(c1.getUsed() == 0);  // spilled
+
+    c1.free(0);
+    c2.free(100);
+    assert(manager.cleanUpAllAllocatedMemory() == 0);
+  }
+
 }
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
index 7fb2f92ca80e8..9a43f1f3a9235 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
@@ -17,25 +17,29 @@
 
 package org.apache.spark.shuffle.sort;
 
-import org.apache.spark.shuffle.sort.PackedRecordPointer;
+import java.io.IOException;
+
 import org.junit.Test;
-import static org.junit.Assert.*;
 
 import org.apache.spark.SparkConf;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.memory.TaskMemoryManager;
-import static org.apache.spark.shuffle.sort.PackedRecordPointer.*;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
+import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PARTITION_ID;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 
 public class PackedRecordPointerSuite {
 
   @Test
-  public void heap() {
+  public void heap() throws IOException {
     final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
     final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
-    final MemoryBlock page0 = memoryManager.allocatePage(128);
-    final MemoryBlock page1 = memoryManager.allocatePage(128);
+      new TaskMemoryManager(new TestMemoryManager(conf), 0);
+    final MemoryBlock page0 = memoryManager.allocatePage(128, null);
+    final MemoryBlock page1 = memoryManager.allocatePage(128, null);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
       page1.getBaseOffset() + 42);
     PackedRecordPointer packedPointer = new PackedRecordPointer();
@@ -49,12 +53,12 @@ public void heap() {
   }
 
   @Test
-  public void offHeap() {
+  public void offHeap() throws IOException {
     final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true");
     final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
-    final MemoryBlock page0 = memoryManager.allocatePage(128);
-    final MemoryBlock page1 = memoryManager.allocatePage(128);
+      new TaskMemoryManager(new TestMemoryManager(conf), 0);
+    final MemoryBlock page0 = memoryManager.allocatePage(128, null);
+    final MemoryBlock page1 = memoryManager.allocatePage(128, null);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
       page1.getBaseOffset() + 42);
     PackedRecordPointer packedPointer = new PackedRecordPointer();
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 5049a5306ff21..2293b1bbc113e 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -26,7 +26,7 @@
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.SparkConf;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.memory.TaskMemoryManager;
 
@@ -60,8 +60,8 @@ public void testBasicSorting() throws Exception {
     };
     final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
     final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
-    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+      new TaskMemoryManager(new TestMemoryManager(conf), 0);
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
     final Object baseObject = dataPage.getBaseObject();
     final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
     final HashPartitioner hashPartitioner = new HashPartitioner(4);
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index d65926949c036..4763395d7d401 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -54,13 +54,14 @@
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.storage.*;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
 public class UnsafeShuffleWriterSuite {
 
   static final int NUM_PARTITITONS = 4;
+  TestMemoryManager memoryManager;
   TaskMemoryManager taskMemoryManager;
   final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
   File mergedOutputFile;
@@ -106,10 +107,11 @@ public void setUp() throws IOException {
     partitionSizesInMergedFile = null;
     spillFilesCreated.clear();
     conf = new SparkConf()
-      .set("spark.buffer.pageSize", "128m")
+      .set("spark.buffer.pageSize", "1m")
       .set("spark.unsafe.offHeap", "false");
     taskMetrics = new TaskMetrics();
-    taskMemoryManager =  new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
+    memoryManager = new TestMemoryManager(conf);
+    taskMemoryManager =  new TaskMemoryManager(memoryManager, 0);
 
     when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
     when(blockManager.getDiskWriter(
@@ -344,9 +346,7 @@ private void testMergingSpills(
     }
     assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
 
-    assertEquals(
-      HashMultiset.create(dataToWrite),
-      HashMultiset.create(readRecordsFromFile()));
+    assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile()));
     assertSpillFilesWereCleanedUp();
     ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
     assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
@@ -398,20 +398,14 @@ public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
 
   @Test
   public void writeEnoughDataToTriggerSpill() throws Exception {
-    taskMemoryManager = spy(taskMemoryManager);
-    doCallRealMethod() // initialize sort buffer
-      .doCallRealMethod() // allocate initial data page
-      .doReturn(0L) // deny request to allocate new page
-      .doCallRealMethod() // grant new sort buffer and data page
-      .when(taskMemoryManager).acquireExecutionMemory(anyLong());
+    memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES);
     final UnsafeShuffleWriter writer = createWriter(false);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
-    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
-    for (int i = 0; i < 128 + 1; i++) {
+    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 10];
+    for (int i = 0; i < 10 + 1; i++) {
       dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
     }
     writer.write(dataToWrite.iterator());
-    verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
     assertEquals(2, spillFilesCreated.size());
     writer.stop(true);
     readRecordsFromFile();
@@ -426,19 +420,13 @@ public void writeEnoughDataToTriggerSpill() throws Exception {
 
   @Test
   public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
-    taskMemoryManager = spy(taskMemoryManager);
-    doCallRealMethod() // initialize sort buffer
-      .doCallRealMethod() // allocate initial data page
-      .doReturn(0L) // deny request to allocate new page
-      .doCallRealMethod() // grant new sort buffer and data page
-      .when(taskMemoryManager).acquireExecutionMemory(anyLong());
+    memoryManager.limit(UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE * 16);
     final UnsafeShuffleWriter writer = createWriter(false);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
     for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
       dataToWrite.add(new Tuple2<Object, Object>(i, i));
     }
     writer.write(dataToWrite.iterator());
-    verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
     assertEquals(2, spillFilesCreated.size());
     writer.stop(true);
     readRecordsFromFile();
@@ -473,11 +461,11 @@ public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
     dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(new byte[1])));
     // We should be able to write a record that's right _at_ the max record size
-    final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()];
+    final byte[] atMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes() - 4];
     new Random(42).nextBytes(atMaxRecordSize);
     dataToWrite.add(new Tuple2<Object, Object>(2, ByteBuffer.wrap(atMaxRecordSize)));
     // Inserting a record that's larger than the max record size
-    final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1];
+    final byte[] exceedsMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes()];
     new Random(42).nextBytes(exceedsMaxRecordSize);
     dataToWrite.add(new Tuple2<Object, Object>(3, ByteBuffer.wrap(exceedsMaxRecordSize)));
     writer.write(dataToWrite.iterator());
@@ -524,7 +512,7 @@ public void testPeakMemoryUsed() throws Exception {
       for (int i = 0; i < numRecordsPerPage * 10; i++) {
         writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
         newPeakMemory = writer.getPeakMemoryUsedBytes();
-        if (i % numRecordsPerPage == 0 && i != 0) {
+        if (i % numRecordsPerPage == 0) {
           // The first page is allocated in constructor, another page will be allocated after
           // every numRecordsPerPage records (peak memory should change).
           assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 6e52496cf933b..92bd45e5fa241 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -17,40 +17,117 @@
 
 package org.apache.spark.unsafe.map;
 
-import java.lang.Exception;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
 import java.nio.ByteBuffer;
 import java.util.*;
 
-import org.apache.spark.memory.TaskMemoryManager;
-import org.junit.*;
-import static org.hamcrest.Matchers.greaterThan;
-import static org.junit.Assert.*;
+import scala.Tuple2;
+import scala.Tuple2$;
+import scala.runtime.AbstractFunction1;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
 import org.apache.spark.SparkConf;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
-import org.apache.spark.unsafe.memory.*;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.memory.MemoryLocation;
+import org.apache.spark.util.Utils;
+
+import static org.hamcrest.Matchers.greaterThan;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Mockito.when;
 
 
 public abstract class AbstractBytesToBytesMapSuite {
 
   private final Random rand = new Random(42);
 
-  private GrantEverythingMemoryManager memoryManager;
+  private TestMemoryManager memoryManager;
   private TaskMemoryManager taskMemoryManager;
   private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
 
+  final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+  File tempDir;
+
+  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+
+  private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+    @Override
+    public OutputStream apply(OutputStream stream) {
+      return stream;
+    }
+  }
+
   @Before
   public void setup() {
     memoryManager =
-      new GrantEverythingMemoryManager(
+      new TestMemoryManager(
         new SparkConf().set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator()));
     taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
+
+    tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test");
+    spillFilesCreated.clear();
+    MockitoAnnotations.initMocks(this);
+    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+    when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
+      @Override
+      public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
+        TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
+        File file = File.createTempFile("spillFile", ".spill", tempDir);
+        spillFilesCreated.add(file);
+        return Tuple2$.MODULE$.apply(blockId, file);
+      }
+    });
+    when(blockManager.getDiskWriter(
+      any(BlockId.class),
+      any(File.class),
+      any(SerializerInstance.class),
+      anyInt(),
+      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+      @Override
+      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+        Object[] args = invocationOnMock.getArguments();
+
+        return new DiskBlockObjectWriter(
+          (File) args[1],
+          (SerializerInstance) args[2],
+          (Integer) args[3],
+          new CompressStream(),
+          false,
+          (ShuffleWriteMetrics) args[4]
+        );
+      }
+    });
+    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
+      .then(returnsSecondArg());
   }
 
   @After
   public void tearDown() {
+    Utils.deleteRecursively(tempDir);
+    tempDir = null;
+
     Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
     if (taskMemoryManager != null) {
       long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask();
@@ -415,9 +492,8 @@ public void randomizedTestWithRecordsLargerThanPageSize() {
 
   @Test
   public void failureToAllocateFirstPage() {
-    memoryManager.markExecutionAsOutOfMemory();
+    memoryManager.limit(1024);  // enough for the longArray, but not for a data page
     BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
-    memoryManager.markExecutionAsOutOfMemory();
     try {
       final long[] emptyArray = new long[0];
       final BytesToBytesMap.Location loc =
@@ -439,7 +515,7 @@ public void failureToGrow() {
       int i;
       for (i = 0; i < 127; i++) {
         if (i > 0) {
-          memoryManager.markExecutionAsOutOfMemory();
+          memoryManager.limit(0);
         }
         final long[] arr = new long[]{i};
         final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
@@ -456,6 +532,44 @@ public void failureToGrow() {
     }
   }
 
+  @Test
+  public void spillInIterator() throws IOException {
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, blockManager, 1, 0.75, 1024, false);
+    try {
+      int i;
+      for (i = 0; i < 1024; i++) {
+        final long[] arr = new long[]{i};
+        final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
+        loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+      }
+      BytesToBytesMap.MapIterator iter = map.iterator();
+      for (i = 0; i < 100; i++) {
+        iter.next();
+      }
+      // The non-destructive iterator is not spillable
+      Assert.assertEquals(0, iter.spill(1024L * 10));
+      for (i = 100; i < 1024; i++) {
+        iter.next();
+      }
+
+      BytesToBytesMap.MapIterator iter2 = map.destructiveIterator();
+      for (i = 0; i < 100; i++) {
+        iter2.next();
+      }
+      Assert.assertTrue(iter2.spill(1024) >= 1024);
+      for (i = 100; i < 1024; i++) {
+        iter2.next();
+      }
+      assertFalse(iter2.hasNext());
+    } finally {
+      map.free();
+      for (File spillFile : spillFilesCreated) {
+        assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+          spillFile.exists());
+      }
+    }
+  }
+
   @Test
   public void initialCapacityBoundsChecking() {
     try {
@@ -500,7 +614,7 @@ public void testPeakMemoryUsed() {
           Platform.LONG_ARRAY_OFFSET,
           8);
         newPeakMemory = map.getPeakMemoryUsedBytes();
-        if (i % numRecordsPerPage == 0 && i > 0) {
+        if (i % numRecordsPerPage == 0) {
           // We allocated a new page for this record, so peak memory should change
           assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
         } else {
@@ -519,11 +633,4 @@ public void testPeakMemoryUsed() {
     }
   }
 
-  @Test
-  public void testAcquirePageInConstructor() {
-    final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
-    assertEquals(1, map.getNumDataPages());
-    map.free();
-  }
-
 }
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 94d50b94fde3f..cfead0e5924b8 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -36,28 +36,29 @@
 import org.mockito.MockitoAnnotations;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
-import static org.hamcrest.Matchers.greaterThanOrEqualTo;
-import static org.junit.Assert.*;
-import static org.mockito.AdditionalAnswers.returnsSecondArg;
-import static org.mockito.Answers.RETURNS_SMART_NULLS;
-import static org.mockito.Mockito.*;
 
 import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
 public class UnsafeExternalSorterSuite {
 
   final LinkedList<File> spillFilesCreated = new LinkedList<File>();
-  final GrantEverythingMemoryManager memoryManager =
-    new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
+  final TestMemoryManager memoryManager =
+    new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
   final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
   // Use integer comparison for comparing prefixes (which are partition ids, in this case)
   final PrefixComparator prefixComparator = new PrefixComparator() {
@@ -86,7 +87,7 @@ public int compare(
   @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
 
 
-  private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "64m");
+  private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m");
 
   private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
     @Override
@@ -233,7 +234,7 @@ public void spillingOccursInResponseToMemoryPressure() throws Exception {
       insertNumber(sorter, numRecords - i);
     }
     assertEquals(1, sorter.getNumberOfAllocatedPages());
-    memoryManager.markExecutionAsOutOfMemory();
+    memoryManager.markExecutionAsOutOfMemoryOnce();
     // The insertion of this record should trigger a spill:
     insertNumber(sorter, 0);
     // Ensure that spill files were created
@@ -311,6 +312,62 @@ public void sortingRecordsThatExceedPageSize() throws Exception {
     assertSpillFilesWereCleanedUp();
   }
 
+  @Test
+  public void forcedSpillingWithReadIterator() throws Exception {
+    final UnsafeExternalSorter sorter = newSorter();
+    long[] record = new long[100];
+    int recordSize = record.length * 8;
+    int n = (int) pageSizeBytes / recordSize * 3;
+    for (int i = 0; i < n; i++) {
+      record[0] = (long) i;
+      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+    }
+    assert(sorter.getNumberOfAllocatedPages() >= 2);
+    UnsafeExternalSorter.SpillableIterator iter =
+      (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
+    int lastv = 0;
+    for (int i = 0; i < n / 3; i++) {
+      iter.hasNext();
+      iter.loadNext();
+      assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
+      lastv = i;
+    }
+    assert(iter.spill() > 0);
+    assert(iter.spill() == 0);
+    assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == lastv);
+    for (int i = n / 3; i < n; i++) {
+      iter.hasNext();
+      iter.loadNext();
+      assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
+    }
+    sorter.cleanupResources();
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void forcedSpillingWithNotReadIterator() throws Exception {
+    final UnsafeExternalSorter sorter = newSorter();
+    long[] record = new long[100];
+    int recordSize = record.length * 8;
+    int n = (int) pageSizeBytes / recordSize * 3;
+    for (int i = 0; i < n; i++) {
+      record[0] = (long) i;
+      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+    }
+    assert(sorter.getNumberOfAllocatedPages() >= 2);
+    UnsafeExternalSorter.SpillableIterator iter =
+      (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
+    assert(iter.spill() > 0);
+    assert(iter.spill() == 0);
+    for (int i = 0; i < n; i++) {
+      iter.hasNext();
+      iter.loadNext();
+      assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
+    }
+    sorter.cleanupResources();
+    assertSpillFilesWereCleanedUp();
+  }
+
   @Test
   public void testPeakMemoryUsed() throws Exception {
     final long recordLengthBytes = 8;
@@ -334,7 +391,7 @@ public void testPeakMemoryUsed() throws Exception {
         insertNumber(sorter, i);
         newPeakMemory = sorter.getPeakMemoryUsedBytes();
         // The first page is pre-allocated on instantiation
-        if (i % numRecordsPerPage == 0 && i > 0) {
+        if (i % numRecordsPerPage == 0) {
           // We allocated a new page for this record, so peak memory should change
           assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
         } else {
@@ -358,21 +415,5 @@ public void testPeakMemoryUsed() throws Exception {
     }
   }
 
-  @Test
-  public void testReservePageOnInstantiation() throws Exception {
-    final UnsafeExternalSorter sorter = newSorter();
-    try {
-      assertEquals(1, sorter.getNumberOfAllocatedPages());
-      // Inserting a new record doesn't allocate more memory since we already have a page
-      long peakMemory = sorter.getPeakMemoryUsedBytes();
-      insertNumber(sorter, 100);
-      assertEquals(peakMemory, sorter.getPeakMemoryUsedBytes());
-      assertEquals(1, sorter.getNumberOfAllocatedPages());
-    } finally {
-      sorter.cleanupResources();
-      assertSpillFilesWereCleanedUp();
-    }
-  }
-
 }
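
The two new tests above pin down the contract of UnsafeExternalSorter.SpillableIterator. Condensed into a sketch (assuming a sorter populated as in the tests above; no API beyond what those tests exercise):

    // First spill() moves the not-yet-consumed in-memory records to a spill
    // file and returns the number of bytes freed; a second call is a no-op.
    UnsafeExternalSorter.SpillableIterator iter =
      (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
    assert(iter.spill() > 0);
    assert(iter.spill() == 0);
    // The record returned most recently stays addressable after the spill,
    // and iteration continues from disk with the same values, in order.
    while (iter.hasNext()) {
      iter.loadNext();
      long value = Platform.getLong(iter.getBaseObject(), iter.getBaseOffset());
    }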
 
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index d5de56a0512f9..642f6585f8a15 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -20,17 +20,19 @@
 import java.util.Arrays;
 
 import org.junit.Test;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.*;
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.mock;
 
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.SparkConf;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.isIn;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
 
 public class UnsafeInMemorySorterSuite {
 
@@ -44,7 +46,7 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset,
   public void testSortingEmptyInput() {
     final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
       new TaskMemoryManager(
-        new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
+        new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
       mock(RecordComparator.class),
       mock(PrefixComparator.class),
       100);
@@ -66,8 +68,8 @@ public void testSortingOnlyByIntegerPrefix() throws Exception {
       "Mango"
     };
     final TaskMemoryManager memoryManager = new TaskMemoryManager(
-      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
-    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
     final Object baseObject = dataPage.getBaseObject();
     // Write the records into the data page:
     long position = dataPage.getBaseOffset();
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 0242cbc9244a8..203dab934ca1f 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -149,7 +149,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
     // cause is preserved
     val thrownDueToTaskFailure = intercept[SparkException] {
       sc.parallelize(Seq(0)).mapPartitions { iter =>
-        TaskContext.get().taskMemoryManager().allocatePage(128)
+        TaskContext.get().taskMemoryManager().allocatePage(128, null)
         throw new Exception("intentional task failure")
         iter
       }.count()
@@ -159,7 +159,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
     // If the task succeeded but memory was leaked, then the task should fail due to that leak
     val thrownDueToMemoryLeak = intercept[SparkException] {
       sc.parallelize(Seq(0)).mapPartitions { iter =>
-        TaskContext.get().taskMemoryManager().allocatePage(128)
+        TaskContext.get().taskMemoryManager().allocatePage(128, null)
         iter
       }.count()
     }
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
index 1265087743a98..4a9479cf490fb 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
@@ -145,20 +145,20 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     val manager = createMemoryManager(1000L)
     val taskMemoryManager = new TaskMemoryManager(manager, 0)
 
-    assert(taskMemoryManager.acquireExecutionMemory(100L) === 100L)
-    assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
-    assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
-    assert(taskMemoryManager.acquireExecutionMemory(200L) === 100L)
-    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
-    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 100L)
+    assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L)
+    assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L)
+    assert(taskMemoryManager.acquireExecutionMemory(200L, null) === 100L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L)
 
-    taskMemoryManager.releaseExecutionMemory(500L)
-    assert(taskMemoryManager.acquireExecutionMemory(300L) === 300L)
-    assert(taskMemoryManager.acquireExecutionMemory(300L) === 200L)
+    taskMemoryManager.releaseExecutionMemory(500L, null)
+    assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 300L)
+    assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 200L)
 
     taskMemoryManager.cleanUpAllAllocatedMemory()
-    assert(taskMemoryManager.acquireExecutionMemory(1000L) === 1000L)
-    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+    assert(taskMemoryManager.acquireExecutionMemory(1000L, null) === 1000L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L)
   }
 
   test("two tasks requesting full execution memory") {
@@ -168,15 +168,15 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     val futureTimeout: Duration = 20.seconds
 
     // Have both tasks request 500 bytes, then wait until both requests have been granted:
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L) }
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, null) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
     assert(Await.result(t1Result1, futureTimeout) === 500L)
     assert(Await.result(t2Result1, futureTimeout) === 500L)
 
     // Have both tasks each request 500 bytes more; both should immediately return 0 as they are
     // both now at 1 / N
-    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
-    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
     assert(Await.result(t1Result2, 200.millis) === 0L)
     assert(Await.result(t2Result2, 200.millis) === 0L)
   }
@@ -188,15 +188,15 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     val futureTimeout: Duration = 20.seconds
 
     // Have both tasks request 250 bytes, then wait until both requests have been granted:
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L) }
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, null) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) }
     assert(Await.result(t1Result1, futureTimeout) === 250L)
     assert(Await.result(t2Result1, futureTimeout) === 250L)
 
     // Have both tasks each request 500 bytes more.
     // We should only grant 250 bytes to each of them on this second request
-    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
-    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
     assert(Await.result(t1Result2, futureTimeout) === 250L)
     assert(Await.result(t2Result2, futureTimeout) === 250L)
   }
@@ -208,17 +208,17 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     val futureTimeout: Duration = 20.seconds
 
     // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) }
     assert(Await.result(t1Result1, futureTimeout) === 1000L)
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) }
     // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
     // to make sure the other thread blocks for some time otherwise.
     Thread.sleep(300)
-    t1MemManager.releaseExecutionMemory(250L)
+    t1MemManager.releaseExecutionMemory(250L, null)
     // The memory freed from t1 should now be granted to t2.
     assert(Await.result(t2Result1, futureTimeout) === 250L)
     // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory.
-    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, null) }
     assert(Await.result(t2Result2, 200.millis) === 0L)
   }
 
@@ -229,18 +229,18 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     val futureTimeout: Duration = 20.seconds
 
     // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) }
     assert(Await.result(t1Result1, futureTimeout) === 1000L)
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
     // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
     // to make sure the other thread blocks for some time otherwise.
     Thread.sleep(300)
     // t1 releases all of its memory, so t2 should be able to grab all of the memory
     t1MemManager.cleanUpAllAllocatedMemory()
     assert(Await.result(t2Result1, futureTimeout) === 500L)
-    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
     assert(Await.result(t2Result2, futureTimeout) === 500L)
-    val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
     assert(Await.result(t2Result3, 200.millis) === 0L)
   }
 
@@ -251,13 +251,13 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     val t2MemManager = new TaskMemoryManager(memoryManager, 2)
     val futureTimeout: Duration = 20.seconds
 
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, null) }
     assert(Await.result(t1Result1, futureTimeout) === 700L)
 
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, null) }
     assert(Await.result(t2Result1, futureTimeout) === 300L)
 
-    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L) }
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, null) }
     assert(Await.result(t1Result2, 200.millis) === 0L)
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
similarity index 71%
rename from core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
rename to core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
index fe102d8aeb2a5..77e43554ee27c 100644
--- a/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
+++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
@@ -22,16 +22,22 @@ import scala.collection.mutable
 import org.apache.spark.SparkConf
 import org.apache.spark.storage.{BlockStatus, BlockId}
 
-class GrantEverythingMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) {
+class TestMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) {
   private[memory] override def doAcquireExecutionMemory(
       numBytes: Long,
       evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
-    if (oom) {
-      oom = false
+    if (oomOnce) {
+      oomOnce = false
       0
-    } else {
+    } else if (available >= numBytes) {
       _executionMemoryUsed += numBytes // To suppress warnings when freeing unallocated memory
+      available -= numBytes
       numBytes
+    } else {
+      _executionMemoryUsed += available
+      val grant = available
+      available = 0
+      grant
     }
   }
   override def acquireStorageMemory(
@@ -42,13 +48,23 @@ class GrantEverythingMemoryManager(conf: SparkConf) extends MemoryManager(conf,
       blockId: BlockId,
       numBytes: Long,
       evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
-  override def releaseStorageMemory(numBytes: Long): Unit = { }
+  override def releaseExecutionMemory(numBytes: Long): Unit = {
+    available += numBytes
+    _executionMemoryUsed -= numBytes
+  }
+  override def releaseStorageMemory(numBytes: Long): Unit = {}
   override def maxExecutionMemory: Long = Long.MaxValue
   override def maxStorageMemory: Long = Long.MaxValue
 
-  private var oom = false
+  private var oomOnce = false
+  private var available = Long.MaxValue
 
-  def markExecutionAsOutOfMemory(): Unit = {
-    oom = true
+  def markExecutionAsOutOfMemoryOnce(): Unit = {
+    oomOnce = true
   }
+
+  def limit(avail: Long): Unit = {
+    available = avail
+  }
+
 }
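
The renamed TestMemoryManager now exposes two failure knobs instead of one. A minimal Java sketch of how the suites above drive them (the wrapper class is hypothetical; TestMemoryManager, TaskMemoryManager, markExecutionAsOutOfMemoryOnce() and limit() are the APIs introduced in this patch):

    import org.apache.spark.SparkConf;
    import org.apache.spark.memory.TaskMemoryManager;
    import org.apache.spark.memory.TestMemoryManager;

    public class TestMemoryManagerSketch {
      public static void main(String[] args) {
        TestMemoryManager memoryManager =
          new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
        TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);

        // Knob 1: fail only the next underlying grant, forcing exactly one
        // spill at a chosen point (see spillingOccursInResponseToMemoryPressure).
        memoryManager.markExecutionAsOutOfMemoryOnce();

        // Knob 2: cap all future grants; requests beyond the cap are satisfied
        // only partially, down to 0 once the cap is exhausted.
        memoryManager.limit(4096);
        long got = taskMemoryManager.acquireExecutionMemory(8192, null);
        // got is at most 4096 here; the exact value can depend on which other
        // consumers are registered and able to spill.
      }
    }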
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 810c74fd2fb96..f7063d1e5c829 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -96,15 +96,10 @@ void insertRow(UnsafeRow row) throws IOException {
     );
     numRowsInserted++;
     if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) {
-      spill();
+      sorter.spill();
     }
   }
 
-  @VisibleForTesting
-  void spill() throws IOException {
-    sorter.spill();
-  }
-
   /**
    * Return the peak memory used so far, in bytes.
    */
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 82c645df284de..889f97003450c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -165,7 +165,7 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRo
   public KVIterator<UnsafeRow, UnsafeRow> iterator() {
     return new KVIterator<UnsafeRow, UnsafeRow>() {
 
-      private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator =
+      private final BytesToBytesMap.MapIterator mapLocationIterator =
         map.destructiveIterator();
       private final UnsafeRow key = new UnsafeRow();
       private final UnsafeRow value = new UnsafeRow();
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 46301f0042954..845f2ae6859b7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -17,13 +17,13 @@
 
 package org.apache.spark.sql.execution;
 
-import java.io.IOException;
-
 import javax.annotation.Nullable;
+import java.io.IOException;
 
 import com.google.common.annotations.VisibleForTesting;
 
 import org.apache.spark.TaskContext;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
@@ -33,7 +33,6 @@
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.collection.unsafe.sort.*;
 
 /**
@@ -84,18 +83,16 @@ public UnsafeKVExternalSorter(
         /* initialSize */ 4096,
         pageSizeBytes);
     } else {
-      // Insert the records into the in-memory sorter.
-      // We will use the number of elements in the map as the initialSize of the
-      // UnsafeInMemorySorter. Because UnsafeInMemorySorter does not accept 0 as the initialSize,
-      // we will use 1 as its initial size if the map is empty.
-      // TODO: track pointer array memory used by this in-memory sorter! (SPARK-10474)
+      // The memory needed by UnsafeInMemorySorter should be no more than the
+      // map's longArray, so free that array first to make room for the sorter.
+      map.freeArray();
+      // The memory used by UnsafeInMemorySorter will be counted at the end of this block.
       final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
         taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements()));
 
       // We cannot use the destructive iterator here because we are reusing the existing memory
       // pages in BytesToBytesMap to hold records during sorting.
       // The only new memory we are allocating is the pointer/prefix array.
-      BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
+      BytesToBytesMap.MapIterator iter = map.iterator();
       final int numKeyFields = keySchema.size();
       UnsafeRow row = new UnsafeRow();
       while (iter.hasNext()) {
@@ -117,7 +114,7 @@ public UnsafeKVExternalSorter(
       }
 
       sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
-        taskContext.taskMemoryManager(),
+        taskMemoryManager,
         blockManager,
         taskContext,
         new KVComparator(ordering, keySchema.length()),
@@ -128,6 +125,8 @@ public UnsafeKVExternalSorter(
 
       sorter.spill();
       map.free();
+      // Count the memory used by UnsafeInMemorySorter against this task.
+      taskMemoryManager.acquireExecutionMemory(inMemSorter.getMemoryUsage(), sorter);
     }
   }
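
The ordering in the non-empty-map branch above is deliberate: the map's longArray is freed before the sorter's pointer/prefix array exists, and that array is charged to the task only after the spill completes. A toy model of the accounting (hypothetical names and sizes, not Spark API):

    public class HandOffSketch {
      public static void main(String[] args) {
        long free = 100;        // execution memory available to the task
        long mapArray = 40;     // size of the map's longArray
        long sorterArray = 30;  // UnsafeInMemorySorter's pointer/prefix array

        free += mapArray;       // map.freeArray(): release before the sorter allocates
        // ... sorter is built over the map's existing pages, then spills ...
        free -= sorterArray;    // acquireExecutionMemory(inMemSorter.getMemoryUsage(), sorter)
        assert free >= 0;       // holds whenever sorterArray <= mapArray, the
                                // invariant the comment in the block relies on
      }
    }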
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index dbf4863b767bf..a38623623a441 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -24,7 +24,7 @@ import scala.util.{Try, Random}
 import org.scalatest.Matchers
 
 import org.apache.spark.{SparkConf, TaskContextImpl, TaskContext, SparkFunSuite}
-import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager}
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
 import org.apache.spark.sql.test.SharedSQLContext
@@ -48,7 +48,7 @@ class UnsafeFixedWidthAggregationMapSuite
   private def emptyAggregationBuffer: InternalRow = InternalRow(0)
   private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
 
-  private var memoryManager: GrantEverythingMemoryManager = null
+  private var memoryManager: TestMemoryManager = null
   private var taskMemoryManager: TaskMemoryManager = null
 
   def testWithMemoryLeakDetection(name: String)(f: => Unit) {
@@ -62,7 +62,7 @@ class UnsafeFixedWidthAggregationMapSuite
 
     test(name) {
       val conf = new SparkConf().set("spark.unsafe.offHeap", "false")
-      memoryManager = new GrantEverythingMemoryManager(conf)
+      memoryManager = new TestMemoryManager(conf)
       taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
 
       TaskContext.setTaskContext(new TaskContextImpl(
@@ -193,10 +193,6 @@ class UnsafeFixedWidthAggregationMapSuite
     // Convert the map into a sorter
     val sorter = map.destructAndCreateExternalSorter()
 
-    withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
-      assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
-    }
-
     // Add more keys to the sorter and make sure the results come out sorted.
     val additionalKeys = randomStrings(1024)
     val keyConverter = UnsafeProjection.create(groupKeySchema)
@@ -208,7 +204,7 @@ class UnsafeFixedWidthAggregationMapSuite
       sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
 
       if ((i % 100) == 0) {
-        memoryManager.markExecutionAsOutOfMemory()
+        memoryManager.markExecutionAsOutOfMemoryOnce()
         sorter.closeCurrentPage()
       }
     }
@@ -251,7 +247,7 @@ class UnsafeFixedWidthAggregationMapSuite
       sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
 
       if ((i % 100) == 0) {
-        memoryManager.markExecutionAsOutOfMemory()
+        memoryManager.markExecutionAsOutOfMemoryOnce()
         sorter.closeCurrentPage()
       }
     }
@@ -294,16 +290,12 @@ class UnsafeFixedWidthAggregationMapSuite
     // Convert the map into a sorter. Right now, it contains one record.
     val sorter = map.destructAndCreateExternalSorter()
 
-    withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
-      assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
-    }
-
     // Add more keys to the sorter and make sure the results come out sorted.
     (1 to 4096).foreach { i =>
       sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0))
 
       if ((i % 100) == 0) {
-        memoryManager.markExecutionAsOutOfMemory()
+        memoryManager.markExecutionAsOutOfMemoryOnce()
         sorter.closeCurrentPage()
       }
     }
@@ -342,7 +334,7 @@ class UnsafeFixedWidthAggregationMapSuite
       buf.setInt(0, str.length)
     }
     // Simulate running out of space
-    memoryManager.markExecutionAsOutOfMemory()
+    memoryManager.limit(0)
     val str = rand.nextString(1024)
     val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str)))
     assert(buf == null)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 13dc1754c9ff0..7b80963ec8708 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import scala.util.Random
 
 import org.apache.spark._
-import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager}
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection}
@@ -109,7 +109,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
       pageSize: Long,
       spill: Boolean): Unit = {
     val memoryManager =
-      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"))
+      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"))
     val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
     TaskContext.setTaskContext(new TaskContextImpl(
       stageId = 0,
@@ -128,7 +128,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
       sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow])
       // 1% chance we will spill
       if (rand.nextDouble() < 0.01 && spill) {
-        memoryManager.markExecutionAsOutOfMemory()
+        memoryManager.markExecutionAsOutOfMemoryOnce()
         sorter.closeCurrentPage()
       }
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
deleted file mode 100644
index 475037bd45379..0000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark._
-import org.apache.spark.memory.TaskMemoryManager
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.test.SharedSQLContext
-
-class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext {
-
-  test("memory acquired on construction") {
-    val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.memoryManager, 0)
-    val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
-    TaskContext.setTaskContext(taskContext)
-
-    // Assert that a page is allocated before processing starts
-    var iter: TungstenAggregationIterator = null
-    try {
-      val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => {
-        () => new InterpretedMutableProjection(expr, schema)
-      }
-      val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
-      iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty,
-        0, Seq.empty, newMutableProjection, Seq.empty, None,
-        dummyAccum, dummyAccum, dummyAccum, dummyAccum)
-      val numPages = iter.getHashMap.getNumDataPages
-      assert(numPages === 1)
-    } finally {
-      // Clean up
-      if (iter != null) {
-        iter.free()
-      }
-      TaskContext.unset()
-    }
-  }
-}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
index ebe90d9e63d83..09847cec9c4ca 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
@@ -23,6 +23,8 @@
 import java.util.LinkedList;
 import java.util.Map;
 
+import org.apache.spark.unsafe.Platform;
+
 /**
  * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array.
  */
@@ -45,9 +47,6 @@ private boolean shouldPool(long size) {
 
   @Override
   public MemoryBlock allocate(long size) throws OutOfMemoryError {
-    if (size % 8 != 0) {
-      throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
-    }
     if (shouldPool(size)) {
       synchronized (this) {
       final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
@@ -64,8 +63,8 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError {
         }
       }
     }
-    long[] array = new long[(int) (size / 8)];
-    return MemoryBlock.fromLongArray(array);
+    long[] array = new long[(int) ((size + 7) / 8)];
+    return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size);
   }
 
   @Override
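
With the multiple-of-8 guard gone, the heap allocator rounds the backing long[] up to whole words while the returned MemoryBlock keeps the requested size. A standalone check of the arithmetic (hypothetical class name):

    public class RoundUpCheck {
      public static void main(String[] args) {
        long size = 13;                 // no longer required to be a multiple of 8
        long words = (size + 7) / 8;    // 2 backing longs = 16 bytes reserved
        assert words * 8 >= size;       // always enough room for `size` bytes
        System.out.println(words);      // prints 2; block.size() still reports 13
      }
    }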
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
index cda7826c8c99b..98ce711176e43 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
@@ -26,9 +26,6 @@ public class UnsafeMemoryAllocator implements MemoryAllocator {
 
   @Override
   public MemoryBlock allocate(long size) throws OutOfMemoryError {
-    if (size % 8 != 0) {
-      throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
-    }
     long address = Platform.allocateMemory(size);
     return new MemoryBlock(null, address, size);
   }