diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 405e52946415..fbba002f1f80 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -255,11 +255,18 @@ private MapIterator(int numRecords, Location loc, boolean destructive) {
     }
 
     private void advanceToNextPage() {
+      // SPARK-26265: We will first lock this `MapIterator` and then `TaskMemoryManager` when going
+      // to free a memory page by calling `freePage`. At the same time, it is possible that another
+      // memory consumer first locks `TaskMemoryManager` and then this `MapIterator` when it
+      // acquires memory and causes spilling on this `MapIterator`. To avoid deadlock here, we keep
+      // a reference to the page to free and free it after releasing the lock of `MapIterator`.
+      MemoryBlock pageToFree = null;
+
       synchronized (this) {
         int nextIdx = dataPages.indexOf(currentPage) + 1;
         if (destructive && currentPage != null) {
           dataPages.remove(currentPage);
-          freePage(currentPage);
+          pageToFree = currentPage;
           nextIdx --;
         }
         if (dataPages.size() > nextIdx) {
@@ -283,6 +290,9 @@ private void advanceToNextPage() {
           }
         }
       }
+      if (pageToFree != null) {
+        freePage(pageToFree);
+      }
     }
 
     @Override
@@ -329,52 +339,50 @@ public Location next() {
       }
     }
 
-    public long spill(long numBytes) throws IOException {
-      synchronized (this) {
-        if (!destructive || dataPages.size() == 1) {
-          return 0L;
-        }
+    public synchronized long spill(long numBytes) throws IOException {
+      if (!destructive || dataPages.size() == 1) {
+        return 0L;
+      }
 
-        updatePeakMemoryUsed();
+      updatePeakMemoryUsed();
 
-        // TODO: use existing ShuffleWriteMetrics
-        ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
+      // TODO: use existing ShuffleWriteMetrics
+      ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
 
-        long released = 0L;
-        while (dataPages.size() > 0) {
-          MemoryBlock block = dataPages.getLast();
-          // The currentPage is used, cannot be released
-          if (block == currentPage) {
-            break;
-          }
+      long released = 0L;
+      while (dataPages.size() > 0) {
+        MemoryBlock block = dataPages.getLast();
+        // The currentPage is used, cannot be released
+        if (block == currentPage) {
+          break;
+        }
 
-          Object base = block.getBaseObject();
-          long offset = block.getBaseOffset();
-          int numRecords = UnsafeAlignedOffset.getSize(base, offset);
-          int uaoSize = UnsafeAlignedOffset.getUaoSize();
-          offset += uaoSize;
-          final UnsafeSorterSpillWriter writer =
-            new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords);
-          while (numRecords > 0) {
-            int length = UnsafeAlignedOffset.getSize(base, offset);
-            writer.write(base, offset + uaoSize, length, 0);
-            offset += uaoSize + length + 8;
-            numRecords--;
-          }
-          writer.close();
-          spillWriters.add(writer);
+        Object base = block.getBaseObject();
+        long offset = block.getBaseOffset();
+        int numRecords = UnsafeAlignedOffset.getSize(base, offset);
+        int uaoSize = UnsafeAlignedOffset.getUaoSize();
+        offset += uaoSize;
+        final UnsafeSorterSpillWriter writer =
+          new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords);
+        while (numRecords > 0) {
+          int length = UnsafeAlignedOffset.getSize(base, offset);
+          writer.write(base, offset + uaoSize, length, 0);
+          offset += uaoSize + length + 8;
+          numRecords--;
+        }
+        writer.close();
+        spillWriters.add(writer);
 
-          dataPages.removeLast();
-          released += block.size();
-          freePage(block);
+        dataPages.removeLast();
+        released += block.size();
+        freePage(block);
 
-          if (released >= numBytes) {
-            break;
-          }
+        if (released >= numBytes) {
+          break;
         }
-
-        return released;
       }
+
+      return released;
     }
 
     @Override
diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
index 0bbaea6b834b..6aa577d1bf79 100644
--- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
+++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
@@ -38,12 +38,12 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
     return used;
   }
 
-  void use(long size) {
+  public void use(long size) {
    long got = taskMemoryManager.acquireExecutionMemory(size, this);
    used += got;
  }
 
-  void free(long size) {
+  public void free(long size) {
    used -= size;
    taskMemoryManager.releaseExecutionMemory(size, this);
  }
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index aa29232e73e1..089dda7568a7 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -33,6 +33,8 @@
 
 import org.apache.spark.SparkConf;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.memory.TestMemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.network.util.JavaUtils;
@@ -667,4 +669,49 @@ public void testPeakMemoryUsed() {
     }
   }
 
+  @Test
+  public void avoidDeadlock() throws InterruptedException {
+    memoryManager.limit(PAGE_SIZE_BYTES);
+    MemoryMode mode = useOffHeapMemoryAllocator() ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP;
+    TestMemoryConsumer c1 = new TestMemoryConsumer(taskMemoryManager, mode);
+    BytesToBytesMap map =
+      new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.5, 1024);
+
+    Thread thread = new Thread(() -> {
+      int i = 0;
+      long used = 0;
+      while (i < 10) {
+        c1.use(10000000);
+        used += 10000000;
+        i++;
+      }
+      c1.free(used);
+    });
+
+    try {
+      int i;
+      for (i = 0; i < 1024; i++) {
+        final long[] arr = new long[]{i};
+        final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
+        loc.append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+      }
+
+      // Start acquiring memory from another memory consumer to trigger spilling on the map.
+      thread.start();
+
+      BytesToBytesMap.MapIterator iter = map.destructiveIterator();
+      for (i = 0; i < 1024; i++) {
+        iter.next();
+      }
+      assertFalse(iter.hasNext());
+    } finally {
+      map.free();
+      thread.join();
+      for (File spillFile : spillFilesCreated) {
+        assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+          spillFile.exists());
+      }
+    }
+  }
+
 }
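
Note on the locking pattern: the bug fixed here is a lock-order inversion. The iterator
thread took the `MapIterator` monitor and then, via `freePage`, the `TaskMemoryManager`
monitor, while a consumer spilling under memory pressure took them in the opposite order
(the manager's synchronized acquisition path calls `MapIterator.spill()`). The standalone
toy below is not Spark code; every name in it (SketchManager, SketchIterator, advanceSafe,
and so on) is hypothetical. It only condenses the two shapes: the unsafe one the patch
removes, and the fixed one where the second lock is taken only after the first is released.

// LockOrderSketch.java: a minimal sketch of the lock-ordering fix, not Spark code.
import java.util.ArrayDeque;
import java.util.Deque;

public class LockOrderSketch {

  // Stands in for TaskMemoryManager: all page bookkeeping synchronizes on this object.
  static class SketchManager {
    synchronized void freePage(Object page) {
      System.out.println("freed " + page);
    }
  }

  // Stands in for BytesToBytesMap.MapIterator.
  static class SketchIterator {
    private final SketchManager manager;
    private final Deque<Object> pages = new ArrayDeque<>();

    SketchIterator(SketchManager manager) {
      this.manager = manager;
      pages.add("page-0");
    }

    // Deadlock-prone shape (the old advanceToNextPage): the manager's monitor is
    // taken while the iterator's monitor is held, i.e. lock order iterator -> manager.
    void advanceUnsafe() {
      synchronized (this) {
        Object page = pages.poll();
        if (page != null) {
          manager.freePage(page);
        }
      }
    }

    // Fixed shape (the new advanceToNextPage): only cheap state changes happen under
    // the iterator's monitor; freePage runs after that monitor has been released.
    void advanceSafe() {
      Object pageToFree;
      synchronized (this) {
        pageToFree = pages.poll();
      }
      if (pageToFree != null) {
        manager.freePage(pageToFree);
      }
    }

    // Entered from the manager side while its monitor is already held, like
    // MapIterator.spill(): lock order manager -> iterator, never the reverse.
    synchronized long spill(long numBytes) {
      return 0L;
    }
  }

  public static void main(String[] args) {
    SketchIterator it = new SketchIterator(new SketchManager());
    it.advanceSafe();  // frees page-0 without holding the iterator monitor
  }
}

Two follow-up observations on the patch itself. Rewriting `spill` as a `synchronized`
method is behaviorally identical to the `synchronized (this)` block it replaces; the
deadlock is broken solely by moving `freePage` out of the iterator's critical section in
`advanceToNextPage`. And `spill` may still end up re-entering the manager's monitor while
holding both locks, but that path starts on the manager side, so Java monitor reentrancy
makes it safe: the ordering manager -> iterator is never reversed.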