From 8391185995956632a29c6ae69697238e501d5175 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sat, 28 Nov 2015 11:58:23 +0800 Subject: [PATCH 01/18] init commit --- .../apache/spark/memory/MemoryConsumer.java | 20 +++- .../spark/memory/TaskMemoryManager.java | 7 ++ .../collection/ExternalAppendOnlyMap.scala | 57 +++++++--- .../util/collection/ExternalSorter.scala | 101 +++++++++++++++--- .../spark/util/collection/Spillable.scala | 48 +++++---- 5 files changed, 182 insertions(+), 51 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 36138cc9a297c..bd5ffe2e00f9f 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -45,7 +45,7 @@ protected MemoryConsumer(TaskMemoryManager taskMemoryManager) { /** * Returns the size of used memory in bytes. */ - long getUsed() { + protected long getUsed() { return used; } @@ -130,4 +130,22 @@ protected void freePage(MemoryBlock page) { used -= page.size(); taskMemoryManager.freePage(page, this); } + + /** + * Allocates a heap memory of `size`. + */ + public long allocateHeapExecutionMemory(long size) { + long granted = + taskMemoryManager.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this); + used += granted; + return granted; + } + + /** + * Release N bytes of heap memory. + */ + public void freeHeapExecutionMemory(long size) { + taskMemoryManager.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this); + used -= size; + } } diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index d31eb449eb82e..6c5c105f4db2b 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -387,4 +387,11 @@ public long cleanUpAllAllocatedMemory() { public long getMemoryConsumptionForThisTask() { return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId); } + + /** + * Returns Tungsten memory mode + */ + public MemoryMode getTungstenMemoryMode(){ + return tungstenMemoryMode; + } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index f6d81ee5bf05e..535137616ba3b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -28,7 +28,6 @@ import com.google.common.io.ByteStreams import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} import org.apache.spark.util.CompletionIterator @@ -59,10 +58,10 @@ class ExternalAppendOnlyMap[K, V, C]( serializer: Serializer = SparkEnv.get.serializer, blockManager: BlockManager = SparkEnv.get.blockManager, context: TaskContext = TaskContext.get()) - extends Iterable[(K, C)] + extends Spillable[SizeTracker](context.taskMemoryManager()) with Serializable with Logging - with Spillable[SizeTracker] { + with Iterable[(K, C)] { if (context == null) { throw new IllegalStateException( @@ -79,8 +78,6 @@ class ExternalAppendOnlyMap[K, V, C]( this(createCombiner, mergeValue, mergeCombiners, 
serializer, blockManager, TaskContext.get()) } - override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager() - private var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf @@ -115,6 +112,8 @@ class ExternalAppendOnlyMap[K, V, C]( private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() + private var memoryOrDiskIterator: Iterator[(K, C)] = null + /** * Number of files this map has spilled so far. * Exposed for testing. @@ -177,9 +176,19 @@ class ExternalAppendOnlyMap[K, V, C]( } /** - * Sort the existing contents of the in-memory map and spill them to a temporary file on disk. + * Spill in-memory map to a temporary file on disk. */ - override protected[this] def spill(collection: SizeTracker): Unit = { + override protected[this] def spill(collection: SizeTracker): Boolean = { + var spillIterator: Iterator[(K, C)] = null + if (collection == null) { + // Spill is called by TaskMemoryManager when there is not enough memory for the task. + assert(memoryOrDiskIterator != null) + spillIterator = memoryOrDiskIterator + logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + + s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") + } else { + spillIterator = currentMap.destructiveSortedIterator(keyComparator) + } val (blockId, file) = diskBlockManager.createTempLocalBlock() curWriteMetrics = new ShuffleWriteMetrics() var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) @@ -200,9 +209,8 @@ class ExternalAppendOnlyMap[K, V, C]( var success = false try { - val it = currentMap.destructiveSortedIterator(keyComparator) - while (it.hasNext) { - val kv = it.next() + while (spillIterator.hasNext) { + val kv = spillIterator.next() writer.write(kv._1, kv._2) objectsWritten += 1 @@ -235,9 +243,29 @@ class ExternalAppendOnlyMap[K, V, C]( } } - spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) + val diskMapIterator = new DiskMapIterator(file, blockId, batchSizes) + if (collection == null) { + memoryOrDiskIterator = diskMapIterator + } else { + spilledMaps.append(diskMapIterator) + } + true } + /** + * Returns a destructive iterator for iterating over the entries of this map. + * If this iterator is forced spill to disk to release memory when there is not enough memory, + * it returns pairs from an on-disk map. + */ + def destructiveIterator(mapIterator: Iterator[(K, C)]): Iterator[(K, C)] = { + memoryOrDiskIterator = mapIterator + new Iterator[(K, C)] { + + override def hasNext = memoryOrDiskIterator.hasNext + + override def next() = memoryOrDiskIterator.next() + } + } /** * Return a destructive iterator that merges the in-memory map with the spilled maps. * If no spill has occurred, simply return the in-memory map's iterator. 
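The destructiveIterator wrapper added above is what makes force-spilling safe while a caller is already iterating: the anonymous Iterator re-reads the mutable memoryOrDiskIterator field on every hasNext/next call, so when spill swaps that field to the DiskMapIterator the caller transparently continues from disk. A condensed, self-contained sketch of the same indirection (SwappableIterator and forceToDisk are illustrative names, not part of the patch):

import scala.collection.mutable.ArrayBuffer

// The caller holds this wrapper; the backing iterator can be swapped from an
// in-memory source to a spilled copy between any two next() calls.
class SwappableIterator[T](initial: Iterator[T]) extends Iterator[T] {
  private var current: Iterator[T] = initial

  override def hasNext: Boolean = current.hasNext
  override def next(): T = current.next()

  // Simulates a forced spill: drain the remaining in-memory elements into a
  // stand-in for disk (an ArrayBuffer) and keep serving from that copy.
  def forceToDisk(): Unit = {
    val spilled = new ArrayBuffer[T]
    while (current.hasNext) spilled += current.next()
    current = spilled.iterator
  }
}

object SwappableIteratorDemo {
  def main(args: Array[String]): Unit = {
    val it = new SwappableIterator(Iterator(1, 2, 3, 4, 5))
    println(it.next())   // 1, served from memory
    it.forceToDisk()     // remaining elements are "spilled"
    println(it.toList)   // List(2, 3, 4, 5), served from the spilled copy
  }
}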
@@ -248,7 +276,8 @@ class ExternalAppendOnlyMap[K, V, C]( "ExternalAppendOnlyMap.iterator is destructive and should only be called once.") } if (spilledMaps.isEmpty) { - CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap()) + destructiveIterator( + CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap())) } else { new ExternalIterator() } @@ -270,8 +299,8 @@ class ExternalAppendOnlyMap[K, V, C]( // Input streams are derived both from the in-memory map and spilled maps on disk // The in-memory map is sorted in place, while the spilled maps are already in sorted order - private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]]( - currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap()) + private val sortedMap = destructiveIterator(CompletionIterator[(K, C), Iterator[(K, C)]]( + currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap())) private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 2440139ac95e9..ab0058b350708 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -26,7 +26,6 @@ import scala.collection.mutable import com.google.common.io.ByteStreams import org.apache.spark._ -import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} @@ -92,10 +91,8 @@ private[spark] class ExternalSorter[K, V, C]( partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, serializer: Option[Serializer] = None) - extends Logging - with Spillable[WritablePartitionedPairCollection[K, C]] { - - override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager() + extends Spillable[WritablePartitionedPairCollection[K, C]](context.taskMemoryManager()) + with Logging { private val conf = SparkEnv.get.conf @@ -136,6 +133,10 @@ private[spark] class ExternalSorter[K, V, C]( private var _peakMemoryUsedBytes: Long = 0L def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes + private var isShuffleSort: Boolean = true + var forceSpillFile: Option[SpilledFile] = None + private var memoryOrDiskIterator: Iterator[((Int, K), C)] = null + // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some @@ -228,12 +229,54 @@ private[spark] class ExternalSorter[K, V, C]( } /** - * Spill our in-memory collection to a sorted file that we can merge later. - * We add this file into `spilledFiles` to find it later. + * Spill our in-memory collection to a sorted file. 
* * @param collection whichever collection we're using (map or buffer) */ - override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = { + override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]) + : Boolean = { + var spillIterator: WritablePartitionedIterator = null + if (collection == null) { + // Spill is called by TaskMemoryManager when there is not enough memory for the task. + if (isShuffleSort) { + false + } else { + assert(memoryOrDiskIterator != null) + val it = memoryOrDiskIterator + spillIterator = new WritablePartitionedIterator { + private[this] var cur = if (it.hasNext) it.next() else null + + def writeNext(writer: DiskBlockObjectWriter): Unit = { + writer.write(cur._1._2, cur._2) + cur = if (it.hasNext) it.next() else null + } + + def hasNext(): Boolean = cur != null + + def nextPartition(): Int = cur._1._1 + } + logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + + s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") + forceSpillFile = Some(spillMemoryToDisk(spillIterator)) + val spillReader = new SpillReader(forceSpillFile.get) + memoryOrDiskIterator = (0 until numPartitions).iterator.flatMap { p => + val iterator = spillReader.readNextPartition() + iterator.map(cur => ((p, cur._1), cur._2)) + } + true + } + } else { + spillIterator = collection.destructiveSortedWritablePartitionedIterator(comparator) + val spillFile = spillMemoryToDisk(spillIterator) + spills.append(spillFile) + true + } + } + + /** + * Spill contents of in-memory iterator to a temporary file on disk. + */ + private def spillMemoryToDisk(memoryIterator: WritablePartitionedIterator): SpilledFile = { // Because these files may be read during shuffle, their compression must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use // createTempShuffleBlock here; see SPARK-3426 for more context. @@ -270,12 +313,11 @@ private[spark] class ExternalSorter[K, V, C]( var success = false try { - val it = collection.destructiveSortedWritablePartitionedIterator(comparator) - while (it.hasNext) { - val partitionId = it.nextPartition() + while (memoryIterator.hasNext) { + val partitionId = memoryIterator.nextPartition() require(partitionId >= 0 && partitionId < numPartitions, s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})") - it.writeNext(writer) + memoryIterator.writeNext(writer) elementsPerPartition(partitionId) += 1 objectsWritten += 1 @@ -307,7 +349,7 @@ private[spark] class ExternalSorter[K, V, C]( } } - spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) + SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition) } /** @@ -598,6 +640,25 @@ private[spark] class ExternalSorter[K, V, C]( } } + /** + * Returns a destructive iterator for iterating over the entries of this map. + * If this iterator is forced spill to disk to release memory when there is not enough memory, + * it returns pairs from an on-disk map. 
+ */ + def destructiveIterator(memoryIterator: Iterator[((Int, K), C)]): Iterator[((Int, K), C)] = { + if (isShuffleSort) { + memoryIterator + } else { + memoryOrDiskIterator = memoryIterator + new Iterator[((Int, K), C)] { + + override def hasNext = memoryOrDiskIterator.hasNext + + override def next() = memoryOrDiskIterator.next() + } + } + } + /** * Return an iterator over all the data written to this object, grouped by partition and * aggregated by the requested aggregator. For each partition we then have an iterator over its @@ -617,21 +678,26 @@ private[spark] class ExternalSorter[K, V, C]( // we don't even need to sort by anything other than partition ID if (!ordering.isDefined) { // The user hasn't requested sorted keys, so only sort by partition ID, not key - groupByPartition(collection.partitionedDestructiveSortedIterator(None)) + groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None))) } else { // We do need to sort by both partition ID and key - groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator))) + groupByPartition(destructiveIterator( + collection.partitionedDestructiveSortedIterator(Some(keyComparator)))) } } else { // Merge spilled and in-memory data - merge(spills, collection.partitionedDestructiveSortedIterator(comparator)) + merge(spills, destructiveIterator( + collection.partitionedDestructiveSortedIterator(comparator))) } } /** * Return an iterator over all the data written to this object, aggregated by our aggregator. */ - def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2) + def iterator: Iterator[Product2[K, C]] = { + isShuffleSort = false + partitionedIterator.flatMap(pair => pair._2) + } /** * Write all the data added into this ExternalSorter into a file in the disk store. This is @@ -691,6 +757,7 @@ private[spark] class ExternalSorter[K, V, C]( buffer = null // So that the memory can be garbage-collected spills.foreach(s => s.file.delete()) spills.clear() + forceSpillFile.foreach(_.file.delete()) releaseMemory() } diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 3a48af82b1dae..282834f045977 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -17,20 +17,21 @@ package org.apache.spark.util.collection -import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} +import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager} import org.apache.spark.{Logging, SparkEnv} /** * Spills contents of an in-memory collection to disk when the memory threshold * has been exceeded. */ -private[spark] trait Spillable[C] extends Logging { +private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) + extends MemoryConsumer(taskMemoryManager) with Logging { /** * Spills the current in-memory collection to disk, and releases the memory. 
* * @param collection collection to spill to disk */ - protected def spill(collection: C): Unit + protected def spill(collection: C): Boolean // Number of elements read from input since last spill protected def elementsRead: Long = _elementsRead @@ -39,22 +40,13 @@ private[spark] trait Spillable[C] extends Logging { // It's used for checking spilling frequency protected def addElementsRead(): Unit = { _elementsRead += 1 } - // Memory manager that can be used to acquire/release memory - protected[this] def taskMemoryManager: TaskMemoryManager - - // Initial threshold for the size of a collection before we start tracking its memory usage - // For testing only - private[this] val initialMemoryThreshold: Long = - SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) - // Force this collection to spill when there are this many elements in memory // For testing only private[this] val numElementsForceSpillThreshold: Long = SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) // Threshold for this collection's size in bytes before we start tracking its memory usage - // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 - private[this] var myMemoryThreshold = initialMemoryThreshold + private[this] var myMemoryThreshold = 0L // Number of elements read from input since last spill private[this] var _elementsRead = 0L @@ -78,8 +70,7 @@ private[spark] trait Spillable[C] extends Logging { if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = - taskMemoryManager.acquireExecutionMemory(amountToRequest, MemoryMode.ON_HEAP, null) + val granted = allocateHeapExecutionMemory(amountToRequest) myMemoryThreshold += granted // If we were granted too little memory to grow further (either tryToAcquire returned 0, // or we already had more memory than myMemoryThreshold), spill the current collection @@ -98,6 +89,27 @@ private[spark] trait Spillable[C] extends Logging { shouldSpill } + /** + * Spill some data to disk to release memory, which will be called by TaskMemoryManager + * when there is not enough memory for the task. + */ + override def spill(size: Long, trigger: MemoryConsumer): Long = { + if (trigger != this && taskMemoryManager.getTungstenMemoryMode == MemoryMode.ON_HEAP) { + val isSpilled = spill(null.asInstanceOf[C]) + if (!isSpilled) { + 0L + } else { + _elementsRead = 0 + val freeMemory = myMemoryThreshold + _memoryBytesSpilled += freeMemory + releaseMemory() + freeMemory + } + } else { + 0L + } + } + /** * @return number of bytes spilled in total */ @@ -107,10 +119,8 @@ private[spark] trait Spillable[C] extends Logging { * Release our memory back to the execution pool so that other tasks can grab it. 
*/ def releaseMemory(): Unit = { - // The amount we requested does not include the initial memory tracking threshold - taskMemoryManager.releaseExecutionMemory( - myMemoryThreshold - initialMemoryThreshold, MemoryMode.ON_HEAP, null) - myMemoryThreshold = initialMemoryThreshold + freeHeapExecutionMemory(myMemoryThreshold) + myMemoryThreshold = 0L } /** From 34f24410be3f6e566900178baf90f122010846a4 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sat, 28 Nov 2015 17:48:24 +0800 Subject: [PATCH 02/18] fix minor bug --- .../scala/org/apache/spark/util/collection/ExternalSorter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index ab0058b350708..3712778d03575 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -161,7 +161,7 @@ private[spark] class ExternalSorter[K, V, C]( // Information about a spilled file. Includes sizes in bytes of "batches" written by the // serializer as we periodically reset its stream, as well as number of elements in each // partition, used to efficiently keep track of partitions when merging. - private[this] case class SpilledFile( + private[collection] case class SpilledFile( file: File, blockId: BlockId, serializerBatchSizes: Array[Long], From bedac89e3afdc8013fbf7dc8df4902291c56dd25 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 24 Jan 2016 17:13:06 +0800 Subject: [PATCH 03/18] fix some bugs --- .../collection/ExternalAppendOnlyMap.scala | 19 +++++++++++++------ .../util/collection/ExternalSorter.scala | 16 ++++++++++++---- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 535137616ba3b..13df6088b9bf1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -246,6 +246,7 @@ class ExternalAppendOnlyMap[K, V, C]( val diskMapIterator = new DiskMapIterator(file, blockId, batchSizes) if (collection == null) { memoryOrDiskIterator = diskMapIterator + currentMap = null } else { spilledMaps.append(diskMapIterator) } @@ -276,16 +277,18 @@ class ExternalAppendOnlyMap[K, V, C]( "ExternalAppendOnlyMap.iterator is destructive and should only be called once.") } if (spilledMaps.isEmpty) { - destructiveIterator( - CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap())) + CompletionIterator[(K, C), Iterator[(K, C)]]( + destructiveIterator(currentMap.iterator), freeCurrentMap()) } else { new ExternalIterator() } } private def freeCurrentMap(): Unit = { - currentMap = null // So that the memory can be garbage-collected - releaseMemory() + if (currentMap != null) { + currentMap = null // So that the memory can be garbage-collected + releaseMemory() + } } /** @@ -299,8 +302,8 @@ class ExternalAppendOnlyMap[K, V, C]( // Input streams are derived both from the in-memory map and spilled maps on disk // The in-memory map is sorted in place, while the spilled maps are already in sorted order - private val sortedMap = destructiveIterator(CompletionIterator[(K, C), Iterator[(K, C)]]( - currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap())) + 
private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator( + currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap()) private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => @@ -561,6 +564,10 @@ class ExternalAppendOnlyMap[K, V, C]( /** Convenience function to hash the given (K, C) pair by the key. */ private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1) + + override def toString(): String = { + return this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode()) + } } private[spark] object ExternalAppendOnlyMap { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 3712778d03575..cf866ff7c2ec4 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -262,7 +262,9 @@ private[spark] class ExternalSorter[K, V, C]( memoryOrDiskIterator = (0 until numPartitions).iterator.flatMap { p => val iterator = spillReader.readNextPartition() iterator.map(cur => ((p, cur._1), cur._2)) - } + } + map = null + buffer =null true } } else { @@ -753,12 +755,18 @@ private[spark] class ExternalSorter[K, V, C]( } def stop(): Unit = { - map = null // So that the memory can be garbage-collected - buffer = null // So that the memory can be garbage-collected spills.foreach(s => s.file.delete()) spills.clear() forceSpillFile.foreach(_.file.delete()) - releaseMemory() + if (map != null || buffer != null) { + map = null // So that the memory can be garbage-collected + buffer = null // So that the memory can be garbage-collected + releaseMemory() + } + } + + override def toString(): String = { + return this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode()) } /** From b5616414af8fff78f96b320cfbe3bf368d6f756c Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 24 Jan 2016 17:16:38 +0800 Subject: [PATCH 04/18] fix minor style --- .../apache/spark/util/collection/ExternalAppendOnlyMap.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 13df6088b9bf1..ed293cab5e675 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -277,8 +277,8 @@ class ExternalAppendOnlyMap[K, V, C]( "ExternalAppendOnlyMap.iterator is destructive and should only be called once.") } if (spilledMaps.isEmpty) { - CompletionIterator[(K, C), Iterator[(K, C)]]( - destructiveIterator(currentMap.iterator), freeCurrentMap()) + CompletionIterator[(K, C), Iterator[(K, C)]]( + destructiveIterator(currentMap.iterator), freeCurrentMap()) } else { new ExternalIterator() } From 16ca87bfc66e1d8ddfc1067a6bf97b6875343d61 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Sun, 24 Jan 2016 17:37:59 +0800 Subject: [PATCH 05/18] fix minor style --- .../scala/org/apache/spark/util/collection/ExternalSorter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index cf866ff7c2ec4..be3d392614243 100644 --- 
a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -264,7 +264,7 @@ private[spark] class ExternalSorter[K, V, C]( iterator.map(cur => ((p, cur._1), cur._2)) } map = null - buffer =null + buffer = null true } } else { From b9e7071e56f7d542ce5826b56fd21071456d967a Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Mon, 18 Apr 2016 17:45:11 +0800 Subject: [PATCH 06/18] update code --- .../collection/ExternalAppendOnlyMap.scala | 64 +++++++------ .../util/collection/ExternalSorter.scala | 93 ++++++++++--------- .../spark/util/collection/Spillable.scala | 10 +- 3 files changed, 91 insertions(+), 76 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index ed293cab5e675..92655b6d74dde 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -112,7 +112,7 @@ class ExternalAppendOnlyMap[K, V, C]( private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() - private var memoryOrDiskIterator: Iterator[(K, C)] = null + private var inMemoryOrDiskIterator: Iterator[(K, C)] = null /** * Number of files this map has spilled so far. @@ -176,19 +176,34 @@ class ExternalAppendOnlyMap[K, V, C]( } /** - * Spill in-memory map to a temporary file on disk. + * Sort the existing contents of the in-memory map and spill them to a temporary file on disk. */ - override protected[this] def spill(collection: SizeTracker): Boolean = { - var spillIterator: Iterator[(K, C)] = null - if (collection == null) { - // Spill is called by TaskMemoryManager when there is not enough memory for the task. - assert(memoryOrDiskIterator != null) - spillIterator = memoryOrDiskIterator - logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + - s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") - } else { - spillIterator = currentMap.destructiveSortedIterator(keyComparator) - } + override protected[this] def spill(collection: SizeTracker): Unit = { + val inMemoryIterator = currentMap.destructiveSortedIterator(keyComparator) + val diskMapIterator = spillMemoryIteratorToDisk(inMemoryIterator) + spilledMaps.append(diskMapIterator) + } + + /** + * Force to spilling the current in-memory collection to disk to release memory, + * It will be called by TaskMemoryManager when there is not enough memory for the task. + */ + override protected[this] def forceSpill(): Boolean = { + assert(inMemoryOrDiskIterator != null) + val inMemoryIterator = inMemoryOrDiskIterator + logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + + s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") + val diskMapIterator = spillMemoryIteratorToDisk(inMemoryIterator) + inMemoryOrDiskIterator = diskMapIterator + currentMap = null + true + } + + /** + * Spill the in-memory Iterator to a temporary file on disk. 
+ */ + private[this] def spillMemoryIteratorToDisk(inMemoryIterator: Iterator[(K, C)]) + : DiskMapIterator = { val (blockId, file) = diskBlockManager.createTempLocalBlock() curWriteMetrics = new ShuffleWriteMetrics() var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) @@ -209,8 +224,8 @@ class ExternalAppendOnlyMap[K, V, C]( var success = false try { - while (spillIterator.hasNext) { - val kv = spillIterator.next() + while (inMemoryIterator.hasNext) { + val kv = inMemoryIterator.next() writer.write(kv._1, kv._2) objectsWritten += 1 @@ -243,14 +258,7 @@ class ExternalAppendOnlyMap[K, V, C]( } } - val diskMapIterator = new DiskMapIterator(file, blockId, batchSizes) - if (collection == null) { - memoryOrDiskIterator = diskMapIterator - currentMap = null - } else { - spilledMaps.append(diskMapIterator) - } - true + new DiskMapIterator(file, blockId, batchSizes) } /** @@ -258,13 +266,13 @@ class ExternalAppendOnlyMap[K, V, C]( * If this iterator is forced spill to disk to release memory when there is not enough memory, * it returns pairs from an on-disk map. */ - def destructiveIterator(mapIterator: Iterator[(K, C)]): Iterator[(K, C)] = { - memoryOrDiskIterator = mapIterator + def destructiveIterator(inMemoryIterator: Iterator[(K, C)]): Iterator[(K, C)] = { + inMemoryOrDiskIterator = inMemoryIterator new Iterator[(K, C)] { - override def hasNext = memoryOrDiskIterator.hasNext + override def hasNext = inMemoryOrDiskIterator.hasNext - override def next() = memoryOrDiskIterator.next() + override def next() = inMemoryOrDiskIterator.next() } } /** @@ -566,7 +574,7 @@ class ExternalAppendOnlyMap[K, V, C]( private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1) override def toString(): String = { - return this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode()) + this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode()) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index be3d392614243..95f8658575762 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -135,7 +135,7 @@ private[spark] class ExternalSorter[K, V, C]( private var isShuffleSort: Boolean = true var forceSpillFile: Option[SpilledFile] = None - private var memoryOrDiskIterator: Iterator[((Int, K), C)] = null + private var inMemoryOrDiskIterator: Iterator[((Int, K), C)] = null // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the @@ -229,48 +229,48 @@ private[spark] class ExternalSorter[K, V, C]( } /** - * Spill our in-memory collection to a sorted file. - * + * Spill our in-memory collection to a sorted file that we can merge later. + * We add this file into `spilledFiles` to find it later. * @param collection whichever collection we're using (map or buffer) */ - override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]) - : Boolean = { - var spillIterator: WritablePartitionedIterator = null - if (collection == null) { - // Spill is called by TaskMemoryManager when there is not enough memory for the task. 
- if (isShuffleSort) { - false - } else { - assert(memoryOrDiskIterator != null) - val it = memoryOrDiskIterator - spillIterator = new WritablePartitionedIterator { - private[this] var cur = if (it.hasNext) it.next() else null - - def writeNext(writer: DiskBlockObjectWriter): Unit = { - writer.write(cur._1._2, cur._2) - cur = if (it.hasNext) it.next() else null - } - - def hasNext(): Boolean = cur != null + override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = { + val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator) + val spillFile = spillMemoryIteratorToDisk(inMemoryIterator) + spills.append(spillFile) + } - def nextPartition(): Int = cur._1._1 - } - logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + - s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") - forceSpillFile = Some(spillMemoryToDisk(spillIterator)) - val spillReader = new SpillReader(forceSpillFile.get) - memoryOrDiskIterator = (0 until numPartitions).iterator.flatMap { p => - val iterator = spillReader.readNextPartition() - iterator.map(cur => ((p, cur._1), cur._2)) + /** + * Force to spilling the current in-memory collection to disk to release memory, + * It will be called by TaskMemoryManager when there is not enough memory for the task. + */ + override protected[this] def forceSpill(): Boolean = { + if (isShuffleSort) { + false + } else { + assert(inMemoryOrDiskIterator != null) + val it = inMemoryOrDiskIterator + val inMemoryIterator = new WritablePartitionedIterator { + private[this] var cur = if (it.hasNext) it.next() else null + + def writeNext(writer: DiskBlockObjectWriter): Unit = { + writer.write(cur._1._2, cur._2) + cur = if (it.hasNext) it.next() else null } - map = null - buffer = null - true + + def hasNext(): Boolean = cur != null + + def nextPartition(): Int = cur._1._1 } - } else { - spillIterator = collection.destructiveSortedWritablePartitionedIterator(comparator) - val spillFile = spillMemoryToDisk(spillIterator) - spills.append(spillFile) + logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + + s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") + forceSpillFile = Some(spillMemoryIteratorToDisk(inMemoryIterator)) + val spillReader = new SpillReader(forceSpillFile.get) + inMemoryOrDiskIterator = (0 until numPartitions).iterator.flatMap { p => + val iterator = spillReader.readNextPartition() + iterator.map(cur => ((p, cur._1), cur._2)) + } + map = null + buffer = null true } } @@ -278,7 +278,8 @@ private[spark] class ExternalSorter[K, V, C]( /** * Spill contents of in-memory iterator to a temporary file on disk. */ - private def spillMemoryToDisk(memoryIterator: WritablePartitionedIterator): SpilledFile = { + private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator) + : SpilledFile = { // Because these files may be read during shuffle, their compression must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use // createTempShuffleBlock here; see SPARK-3426 for more context. 
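One detail in the forceSpill() path above is easy to miss: after writing forceSpillFile, the sorter rebuilds inMemoryOrDiskIterator from the SpillReader with (0 until numPartitions).iterator.flatMap, and the .iterator is what keeps that rebuild lazy, since readNextPartition() appears to hand out partitions sequentially from a single spill file. A stand-in demo of that point (names and types here are illustrative, not Spark API):

// Stateful reader stand-in: hands out one partition per call, in order, the
// way a single sequential spill file would.
object LazyReadBackDemo {
  def main(args: Array[String]): Unit = {
    val numPartitions = 3
    val partitionsOnDisk =
      Iterator(Seq("a" -> 1), Seq("b" -> 2, "c" -> 3), Seq.empty[(String, Int)])

    def readNextPartition(): Iterator[(String, Int)] = partitionsOnDisk.next().iterator

    // The .iterator keeps flatMap lazy, so readNextPartition() runs exactly once
    // per partition and only when iteration reaches it; a plain
    // (0 until numPartitions).flatMap would eagerly drain every partition first.
    val readBack: Iterator[((Int, String), Int)] =
      (0 until numPartitions).iterator.flatMap { p =>
        readNextPartition().map { case (k, v) => ((p, k), v) }
      }

    readBack.foreach(println)   // ((0,a),1) then ((1,b),2) then ((1,c),3)
  }
}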
@@ -315,11 +316,11 @@ private[spark] class ExternalSorter[K, V, C]( var success = false try { - while (memoryIterator.hasNext) { - val partitionId = memoryIterator.nextPartition() + while (inMemoryIterator.hasNext) { + val partitionId = inMemoryIterator.nextPartition() require(partitionId >= 0 && partitionId < numPartitions, s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})") - memoryIterator.writeNext(writer) + inMemoryIterator.writeNext(writer) elementsPerPartition(partitionId) += 1 objectsWritten += 1 @@ -651,12 +652,12 @@ private[spark] class ExternalSorter[K, V, C]( if (isShuffleSort) { memoryIterator } else { - memoryOrDiskIterator = memoryIterator + inMemoryOrDiskIterator = memoryIterator new Iterator[((Int, K), C)] { - override def hasNext = memoryOrDiskIterator.hasNext + override def hasNext = inMemoryOrDiskIterator.hasNext - override def next() = memoryOrDiskIterator.next() + override def next() = inMemoryOrDiskIterator.next() } } } @@ -766,7 +767,7 @@ private[spark] class ExternalSorter[K, V, C]( } override def toString(): String = { - return this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode()) + this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode()) } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 282834f045977..08ebe96ff12c1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -31,7 +31,13 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) * * @param collection collection to spill to disk */ - protected def spill(collection: C): Boolean + protected def spill(collection: C): Unit + + /** + * Force to spilling the current in-memory collection to disk to release memory, + * It will be called by TaskMemoryManager when there is not enough memory for the task. 
+ */ + protected def forceSpill(): Boolean // Number of elements read from input since last spill protected def elementsRead: Long = _elementsRead @@ -95,7 +101,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) */ override def spill(size: Long, trigger: MemoryConsumer): Long = { if (trigger != this && taskMemoryManager.getTungstenMemoryMode == MemoryMode.ON_HEAP) { - val isSpilled = spill(null.asInstanceOf[C]) + val isSpilled = forceSpill() if (!isSpilled) { 0L } else { From 49acacc4e30a04c97a8ed9c361d5326132655a76 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Mon, 18 Apr 2016 20:12:09 +0800 Subject: [PATCH 07/18] Merge branch 'apache-master' into SPARK-4452-2 --- .../collection/ExternalAppendOnlyMap.scala | 71 +- .../util/collection/ExternalSorter.scala | 112 ++- .../spark/util/collection/Spillable.scala | 52 +- .../spark/sql/execution/SparkSqlParser.scala | 792 ------------------ 4 files changed, 187 insertions(+), 840 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 95351e98261d7..2b19031e5861d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -61,10 +61,10 @@ class ExternalAppendOnlyMap[K, V, C]( blockManager: BlockManager = SparkEnv.get.blockManager, context: TaskContext = TaskContext.get(), serializerManager: SerializerManager = SparkEnv.get.serializerManager) - extends Iterable[(K, C)] + extends Spillable[SizeTracker](context.taskMemoryManager()) with Serializable with Logging - with Spillable[SizeTracker] { + with Iterable[(K, C)] { if (context == null) { throw new IllegalStateException( @@ -81,8 +81,6 @@ class ExternalAppendOnlyMap[K, V, C]( this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get()) } - override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager() - private var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf @@ -117,6 +115,8 @@ class ExternalAppendOnlyMap[K, V, C]( private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() + private var inMemoryOrDiskIterator: Iterator[(K, C)] = null + /** * Number of files this map has spilled so far. * Exposed for testing. @@ -182,6 +182,31 @@ class ExternalAppendOnlyMap[K, V, C]( * Sort the existing contents of the in-memory map and spill them to a temporary file on disk. */ override protected[this] def spill(collection: SizeTracker): Unit = { + val inMemoryIterator = currentMap.destructiveSortedIterator(keyComparator) + val diskMapIterator = spillMemoryIteratorToDisk(inMemoryIterator) + spilledMaps.append(diskMapIterator) + } + + /** + * Force to spilling the current in-memory collection to disk to release memory, + * It will be called by TaskMemoryManager when there is not enough memory for the task. 
+ */ + override protected[this] def forceSpill(): Boolean = { + assert(inMemoryOrDiskIterator != null) + val inMemoryIterator = inMemoryOrDiskIterator + logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + + s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") + val diskMapIterator = spillMemoryIteratorToDisk(inMemoryIterator) + inMemoryOrDiskIterator = diskMapIterator + currentMap = null + true + } + + /** + * Spill the in-memory Iterator to a temporary file on disk. + */ + private[this] def spillMemoryIteratorToDisk(inMemoryIterator: Iterator[(K, C)]) + : DiskMapIterator = { val (blockId, file) = diskBlockManager.createTempLocalBlock() curWriteMetrics = new ShuffleWriteMetrics() var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) @@ -202,9 +227,8 @@ class ExternalAppendOnlyMap[K, V, C]( var success = false try { - val it = currentMap.destructiveSortedIterator(keyComparator) - while (it.hasNext) { - val kv = it.next() + while (inMemoryIterator.hasNext) { + val kv = inMemoryIterator.next() writer.write(kv._1, kv._2) objectsWritten += 1 @@ -237,9 +261,23 @@ class ExternalAppendOnlyMap[K, V, C]( } } - spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) + new DiskMapIterator(file, blockId, batchSizes) } + /** + * Returns a destructive iterator for iterating over the entries of this map. + * If this iterator is forced spill to disk to release memory when there is not enough memory, + * it returns pairs from an on-disk map. + */ + def destructiveIterator(inMemoryIterator: Iterator[(K, C)]): Iterator[(K, C)] = { + inMemoryOrDiskIterator = inMemoryIterator + new Iterator[(K, C)] { + + override def hasNext = inMemoryOrDiskIterator.hasNext + + override def next() = inMemoryOrDiskIterator.next() + } + } /** * Return a destructive iterator that merges the in-memory map with the spilled maps. * If no spill has occurred, simply return the in-memory map's iterator. @@ -250,15 +288,18 @@ class ExternalAppendOnlyMap[K, V, C]( "ExternalAppendOnlyMap.iterator is destructive and should only be called once.") } if (spilledMaps.isEmpty) { - CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap()) + CompletionIterator[(K, C), Iterator[(K, C)]]( + destructiveIterator(currentMap.iterator), freeCurrentMap()) } else { new ExternalIterator() } } private def freeCurrentMap(): Unit = { - currentMap = null // So that the memory can be garbage-collected - releaseMemory() + if (currentMap != null) { + currentMap = null // So that the memory can be garbage-collected + releaseMemory() + } } /** @@ -272,8 +313,8 @@ class ExternalAppendOnlyMap[K, V, C]( // Input streams are derived both from the in-memory map and spilled maps on disk // The in-memory map is sorted in place, while the spilled maps are already in sorted order - private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]]( - currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap()) + private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator( + currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap()) private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => @@ -534,6 +575,10 @@ class ExternalAppendOnlyMap[K, V, C]( /** Convenience function to hash the given (K, C) pair by the key. 
*/ private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1) + + override def toString(): String = { + this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode()) + } } private[spark] object ExternalAppendOnlyMap { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 916053f42d072..b01b1dc638469 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -93,10 +93,8 @@ private[spark] class ExternalSorter[K, V, C]( partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, serializer: Serializer = SparkEnv.get.serializer) - extends Logging - with Spillable[WritablePartitionedPairCollection[K, C]] { - - override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager() + extends Spillable[WritablePartitionedPairCollection[K, C]](context.taskMemoryManager()) + with Logging { private val conf = SparkEnv.get.conf @@ -137,6 +135,10 @@ private[spark] class ExternalSorter[K, V, C]( private var _peakMemoryUsedBytes: Long = 0L def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes + private var isShuffleSort: Boolean = true + var forceSpillFile: Option[SpilledFile] = None + private var inMemoryOrDiskIterator: Iterator[((Int, K), C)] = null + // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some @@ -161,7 +163,7 @@ private[spark] class ExternalSorter[K, V, C]( // Information about a spilled file. Includes sizes in bytes of "batches" written by the // serializer as we periodically reset its stream, as well as number of elements in each // partition, used to efficiently keep track of partitions when merging. - private[this] case class SpilledFile( + private[collection] case class SpilledFile( file: File, blockId: BlockId, serializerBatchSizes: Array[Long], @@ -235,6 +237,52 @@ private[spark] class ExternalSorter[K, V, C]( * @param collection whichever collection we're using (map or buffer) */ override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = { + val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator) + val spillFile = spillMemoryIteratorToDisk(inMemoryIterator) + spills.append(spillFile) + } + + /** + * Force to spilling the current in-memory collection to disk to release memory, + * It will be called by TaskMemoryManager when there is not enough memory for the task. 
+ */ + override protected[this] def forceSpill(): Boolean = { + if (isShuffleSort) { + false + } else { + assert(inMemoryOrDiskIterator != null) + val it = inMemoryOrDiskIterator + val inMemoryIterator = new WritablePartitionedIterator { + private[this] var cur = if (it.hasNext) it.next() else null + + def writeNext(writer: DiskBlockObjectWriter): Unit = { + writer.write(cur._1._2, cur._2) + cur = if (it.hasNext) it.next() else null + } + + def hasNext(): Boolean = cur != null + + def nextPartition(): Int = cur._1._1 + } + logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + + s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") + forceSpillFile = Some(spillMemoryIteratorToDisk(inMemoryIterator)) + val spillReader = new SpillReader(forceSpillFile.get) + inMemoryOrDiskIterator = (0 until numPartitions).iterator.flatMap { p => + val iterator = spillReader.readNextPartition() + iterator.map(cur => ((p, cur._1), cur._2)) + } + map = null + buffer = null + true + } + } + + /** + * Spill contents of in-memory iterator to a temporary file on disk. + */ + private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator) + : SpilledFile = { // Because these files may be read during shuffle, their compression must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use // createTempShuffleBlock here; see SPARK-3426 for more context. @@ -271,12 +319,11 @@ private[spark] class ExternalSorter[K, V, C]( var success = false try { - val it = collection.destructiveSortedWritablePartitionedIterator(comparator) - while (it.hasNext) { - val partitionId = it.nextPartition() + while (inMemoryIterator.hasNext) { + val partitionId = inMemoryIterator.nextPartition() require(partitionId >= 0 && partitionId < numPartitions, s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})") - it.writeNext(writer) + inMemoryIterator.writeNext(writer) elementsPerPartition(partitionId) += 1 objectsWritten += 1 @@ -308,7 +355,7 @@ private[spark] class ExternalSorter[K, V, C]( } } - spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) + SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition) } /** @@ -599,6 +646,25 @@ private[spark] class ExternalSorter[K, V, C]( } } + /** + * Returns a destructive iterator for iterating over the entries of this map. + * If this iterator is forced spill to disk to release memory when there is not enough memory, + * it returns pairs from an on-disk map. + */ + def destructiveIterator(memoryIterator: Iterator[((Int, K), C)]): Iterator[((Int, K), C)] = { + if (isShuffleSort) { + memoryIterator + } else { + inMemoryOrDiskIterator = memoryIterator + new Iterator[((Int, K), C)] { + + override def hasNext = inMemoryOrDiskIterator.hasNext + + override def next() = inMemoryOrDiskIterator.next() + } + } + } + /** * Return an iterator over all the data written to this object, grouped by partition and * aggregated by the requested aggregator. 
For each partition we then have an iterator over its @@ -618,21 +684,26 @@ private[spark] class ExternalSorter[K, V, C]( // we don't even need to sort by anything other than partition ID if (!ordering.isDefined) { // The user hasn't requested sorted keys, so only sort by partition ID, not key - groupByPartition(collection.partitionedDestructiveSortedIterator(None)) + groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None))) } else { // We do need to sort by both partition ID and key - groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator))) + groupByPartition(destructiveIterator( + collection.partitionedDestructiveSortedIterator(Some(keyComparator)))) } } else { // Merge spilled and in-memory data - merge(spills, collection.partitionedDestructiveSortedIterator(comparator)) + merge(spills, destructiveIterator( + collection.partitionedDestructiveSortedIterator(comparator))) } } /** * Return an iterator over all the data written to this object, aggregated by our aggregator. */ - def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2) + def iterator: Iterator[Product2[K, C]] = { + isShuffleSort = false + partitionedIterator.flatMap(pair => pair._2) + } /** * Write all the data added into this ExternalSorter into a file in the disk store. This is @@ -689,11 +760,18 @@ private[spark] class ExternalSorter[K, V, C]( } def stop(): Unit = { - map = null // So that the memory can be garbage-collected - buffer = null // So that the memory can be garbage-collected spills.foreach(s => s.file.delete()) spills.clear() - releaseMemory() + forceSpillFile.foreach(_.file.delete()) + if (map != null || buffer != null) { + map = null // So that the memory can be garbage-collected + buffer = null // So that the memory can be garbage-collected + releaseMemory() + } + } + + override def toString(): String = { + this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode()) } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 25ca2037bbac6..be2077d753526 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -19,13 +19,14 @@ package org.apache.spark.util.collection import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging -import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} +import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager} /** * Spills contents of an in-memory collection to disk when the memory threshold * has been exceeded. */ -private[spark] trait Spillable[C] extends Logging { +private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) + extends MemoryConsumer(taskMemoryManager) with Logging { /** * Spills the current in-memory collection to disk, and releases the memory. * @@ -33,6 +34,12 @@ private[spark] trait Spillable[C] extends Logging { */ protected def spill(collection: C): Unit + /** + * Force to spilling the current in-memory collection to disk to release memory, + * It will be called by TaskMemoryManager when there is not enough memory for the task. 
+ */ + protected def forceSpill(): Boolean + // Number of elements read from input since last spill protected def elementsRead: Long = _elementsRead @@ -40,22 +47,13 @@ private[spark] trait Spillable[C] extends Logging { // It's used for checking spilling frequency protected def addElementsRead(): Unit = { _elementsRead += 1 } - // Memory manager that can be used to acquire/release memory - protected[this] def taskMemoryManager: TaskMemoryManager - - // Initial threshold for the size of a collection before we start tracking its memory usage - // For testing only - private[this] val initialMemoryThreshold: Long = - SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) - // Force this collection to spill when there are this many elements in memory // For testing only private[this] val numElementsForceSpillThreshold: Long = SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) // Threshold for this collection's size in bytes before we start tracking its memory usage - // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 - private[this] var myMemoryThreshold = initialMemoryThreshold + private[this] var myMemoryThreshold = 0L // Number of elements read from input since last spill private[this] var _elementsRead = 0L @@ -79,8 +77,7 @@ private[spark] trait Spillable[C] extends Logging { if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = - taskMemoryManager.acquireExecutionMemory(amountToRequest, MemoryMode.ON_HEAP, null) + val granted = allocateHeapExecutionMemory(amountToRequest) myMemoryThreshold += granted // If we were granted too little memory to grow further (either tryToAcquire returned 0, // or we already had more memory than myMemoryThreshold), spill the current collection @@ -99,6 +96,27 @@ private[spark] trait Spillable[C] extends Logging { shouldSpill } + /** + * Spill some data to disk to release memory, which will be called by TaskMemoryManager + * when there is not enough memory for the task. + */ + override def spill(size: Long, trigger: MemoryConsumer): Long = { + if (trigger != this && taskMemoryManager.getTungstenMemoryMode == MemoryMode.ON_HEAP) { + val isSpilled = forceSpill() + if (!isSpilled) { + 0L + } else { + _elementsRead = 0 + val freeMemory = myMemoryThreshold + _memoryBytesSpilled += freeMemory + releaseMemory() + freeMemory + } + } else { + 0L + } + } + /** * @return number of bytes spilled in total */ @@ -108,10 +126,8 @@ private[spark] trait Spillable[C] extends Logging { * Release our memory back to the execution pool so that other tasks can grab it. 
*/ def releaseMemory(): Unit = { - // The amount we requested does not include the initial memory tracking threshold - taskMemoryManager.releaseExecutionMemory( - myMemoryThreshold - initialMemoryThreshold, MemoryMode.ON_HEAP, null) - myMemoryThreshold = initialMemoryThreshold + freeHeapExecutionMemory(myMemoryThreshold) + myMemoryThreshold = 0L } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala deleted file mode 100644 index 8ed6ed21d0170..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ /dev/null @@ -1,792 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.{AnalysisException, SaveMode} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.parser._ -import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} -import org.apache.spark.sql.execution.command.{DescribeCommand => _, _} -import org.apache.spark.sql.execution.datasources._ - -/** - * Concrete parser for Spark SQL statements. - */ -object SparkSqlParser extends AbstractSqlParser{ - val astBuilder = new SparkSqlAstBuilder -} - -/** - * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. - */ -class SparkSqlAstBuilder extends AstBuilder { - import org.apache.spark.sql.catalyst.parser.ParserUtils._ - - /** - * Create a [[SetCommand]] logical plan. - * - * Note that we assume that everything after the SET keyword is assumed to be a part of the - * key-value pair. The split between key and value is made by searching for the first `=` - * character in the raw string. - */ - override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) { - // Construct the command. - val raw = remainder(ctx.SET.getSymbol) - val keyValueSeparatorIndex = raw.indexOf('=') - if (keyValueSeparatorIndex >= 0) { - val key = raw.substring(0, keyValueSeparatorIndex).trim - val value = raw.substring(keyValueSeparatorIndex + 1).trim - SetCommand(Some(key -> Option(value))) - } else if (raw.nonEmpty) { - SetCommand(Some(raw.trim -> None)) - } else { - SetCommand(None) - } - } - - /** - * Create a [[SetDatabaseCommand]] logical plan. - */ - override def visitUse(ctx: UseContext): LogicalPlan = withOrigin(ctx) { - SetDatabaseCommand(ctx.db.getText) - } - - /** - * Create a [[ShowTablesCommand]] logical plan. 
- * Example SQL : - * {{{ - * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; - * }}} - */ - override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) { - ShowTablesCommand( - Option(ctx.db).map(_.getText), - Option(ctx.pattern).map(string)) - } - - /** - * Create a [[ShowDatabasesCommand]] logical plan. - * Example SQL: - * {{{ - * SHOW (DATABASES|SCHEMAS) [LIKE 'identifier_with_wildcards']; - * }}} - */ - override def visitShowDatabases(ctx: ShowDatabasesContext): LogicalPlan = withOrigin(ctx) { - ShowDatabasesCommand(Option(ctx.pattern).map(string)) - } - - /** - * A command for users to list the properties for a table. If propertyKey is specified, the value - * for the propertyKey is returned. If propertyKey is not specified, all the keys and their - * corresponding values are returned. - * The syntax of using this command in SQL is: - * {{{ - * SHOW TBLPROPERTIES table_name[('propertyKey')]; - * }}} - */ - override def visitShowTblProperties( - ctx: ShowTblPropertiesContext): LogicalPlan = withOrigin(ctx) { - ShowTablePropertiesCommand( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.key).map(visitTablePropertyKey)) - } - - /** - * Create a [[RefreshTable]] logical plan. - */ - override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) { - RefreshTable(visitTableIdentifier(ctx.tableIdentifier)) - } - - /** - * Create a [[CacheTableCommand]] logical plan. - */ - override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) { - val query = Option(ctx.query).map(plan) - CacheTableCommand(ctx.identifier.getText, query, ctx.LAZY != null) - } - - /** - * Create an [[UncacheTableCommand]] logical plan. - */ - override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) { - UncacheTableCommand(ctx.identifier.getText) - } - - /** - * Create a [[ClearCacheCommand]] logical plan. - */ - override def visitClearCache(ctx: ClearCacheContext): LogicalPlan = withOrigin(ctx) { - ClearCacheCommand - } - - /** - * Create an [[ExplainCommand]] logical plan. - */ - override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) { - val options = ctx.explainOption.asScala - if (options.exists(_.FORMATTED != null)) { - logWarning("Unsupported operation: EXPLAIN FORMATTED option") - } - - // Create the explain comment. - val statement = plan(ctx.statement) - if (isExplainableStatement(statement)) { - ExplainCommand(statement, extended = options.exists(_.EXTENDED != null), - codegen = options.exists(_.CODEGEN != null)) - } else { - ExplainCommand(OneRowRelation) - } - } - - /** - * Determine if a plan should be explained at all. - */ - protected def isExplainableStatement(plan: LogicalPlan): Boolean = plan match { - case _: datasources.DescribeCommand => false - case _ => true - } - - /** - * Create a [[DescribeCommand]] logical plan. - */ - override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { - // FORMATTED and columns are not supported. Return null and let the parser decide what to do - // with this (create an exception or pass it on to a different system). - if (ctx.describeColName != null || ctx.FORMATTED != null || ctx.partitionSpec != null) { - null - } else { - datasources.DescribeCommand( - visitTableIdentifier(ctx.tableIdentifier), - ctx.EXTENDED != null) - } - } - - /** - * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). 
- */ - type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) - - /** - * Validate a create table statement and return the [[TableIdentifier]]. - */ - override def visitCreateTableHeader( - ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { - val temporary = ctx.TEMPORARY != null - val ifNotExists = ctx.EXISTS != null - assert(!temporary || !ifNotExists, - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.", - ctx) - (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null) - } - - /** - * Create a [[CreateTableUsing]] or a [[CreateTableUsingAsSelect]] logical plan. - * - * TODO add bucketing and partitioning. - */ - override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { - val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - if (external) { - throw new ParseException("Unsupported operation: EXTERNAL option", ctx) - } - val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) - val provider = ctx.tableProvider.qualifiedName.getText - - if (ctx.query != null) { - // Get the backing query. - val query = plan(ctx.query) - - // Determine the storage mode. - val mode = if (ifNotExists) { - SaveMode.Ignore - } else if (temp) { - SaveMode.Overwrite - } else { - SaveMode.ErrorIfExists - } - CreateTableUsingAsSelect(table, provider, temp, Array.empty, None, mode, options, query) - } else { - val struct = Option(ctx.colTypeList).map(createStructType) - CreateTableUsing(table, struct, provider, temp, options, ifNotExists, managedIfNoPath = false) - } - } - - /** - * Convert a table property list into a key-value map. - */ - override def visitTablePropertyList( - ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { - ctx.tableProperty.asScala.map { property => - val key = visitTablePropertyKey(property.key) - val value = Option(property.value).map(string).orNull - key -> value - }.toMap - } - - /** - * A table property key can either be String or a collection of dot separated elements. This - * function extracts the property key based on whether its a string literal or a table property - * identifier. - */ - override def visitTablePropertyKey(key: TablePropertyKeyContext): String = { - if (key.STRING != null) { - string(key.STRING) - } else { - key.getText - } - } - - /** - * Create a [[CreateDatabase]] command. - * - * For example: - * {{{ - * CREATE DATABASE [IF NOT EXISTS] database_name [COMMENT database_comment] - * [LOCATION path] [WITH DBPROPERTIES (key1=val1, key2=val2, ...)] - * }}} - */ - override def visitCreateDatabase(ctx: CreateDatabaseContext): LogicalPlan = withOrigin(ctx) { - CreateDatabase( - ctx.identifier.getText, - ctx.EXISTS != null, - Option(ctx.locationSpec).map(visitLocationSpec), - Option(ctx.comment).map(string), - Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)) - } - - /** - * Create an [[AlterDatabaseProperties]] command. - * - * For example: - * {{{ - * ALTER (DATABASE|SCHEMA) database SET DBPROPERTIES (property_name=property_value, ...); - * }}} - */ - override def visitSetDatabaseProperties( - ctx: SetDatabasePropertiesContext): LogicalPlan = withOrigin(ctx) { - AlterDatabaseProperties( - ctx.identifier.getText, - visitTablePropertyList(ctx.tablePropertyList)) - } - - /** - * Create a [[DropDatabase]] command. 
- * - * For example: - * {{{ - * DROP (DATABASE|SCHEMA) [IF EXISTS] database [RESTRICT|CASCADE]; - * }}} - */ - override def visitDropDatabase(ctx: DropDatabaseContext): LogicalPlan = withOrigin(ctx) { - DropDatabase(ctx.identifier.getText, ctx.EXISTS != null, ctx.CASCADE != null) - } - - /** - * Create a [[DescribeDatabase]] command. - * - * For example: - * {{{ - * DESCRIBE DATABASE [EXTENDED] database; - * }}} - */ - override def visitDescribeDatabase(ctx: DescribeDatabaseContext): LogicalPlan = withOrigin(ctx) { - DescribeDatabase(ctx.identifier.getText, ctx.EXTENDED != null) - } - - /** - * Create a [[CreateFunction]] command. - * - * For example: - * {{{ - * CREATE [TEMPORARY] FUNCTION [db_name.]function_name AS class_name - * [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri']]; - * }}} - */ - override def visitCreateFunction(ctx: CreateFunctionContext): LogicalPlan = withOrigin(ctx) { - val resources = ctx.resource.asScala.map { resource => - val resourceType = resource.identifier.getText.toLowerCase - resourceType match { - case "jar" | "file" | "archive" => - resourceType -> string(resource.STRING) - case other => - throw new ParseException(s"Resource Type '$resourceType' is not supported.", ctx) - } - } - - // Extract database, name & alias. - val (database, function) = visitFunctionName(ctx.qualifiedName) - CreateFunction( - database, - function, - string(ctx.className), - resources, - ctx.TEMPORARY != null) - } - - /** - * Create a [[DropFunction]] command. - * - * For example: - * {{{ - * DROP [TEMPORARY] FUNCTION [IF EXISTS] function; - * }}} - */ - override def visitDropFunction(ctx: DropFunctionContext): LogicalPlan = withOrigin(ctx) { - val (database, function) = visitFunctionName(ctx.qualifiedName) - DropFunction(database, function, ctx.EXISTS != null, ctx.TEMPORARY != null) - } - - /** - * Create a function database (optional) and name pair. - */ - private def visitFunctionName(ctx: QualifiedNameContext): (Option[String], String) = { - ctx.identifier().asScala.map(_.getText) match { - case Seq(db, fn) => (Option(db), fn) - case Seq(fn) => (None, fn) - case other => throw new ParseException(s"Unsupported function name '${ctx.getText}'", ctx) - } - } - - /** - * Create a [[DropTable]] command. - */ - override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { - if (ctx.PURGE != null) { - throw new ParseException("Unsupported operation: PURGE option", ctx) - } - if (ctx.REPLICATION != null) { - throw new ParseException("Unsupported operation: REPLICATION clause", ctx) - } - DropTable( - visitTableIdentifier(ctx.tableIdentifier), - ctx.EXISTS != null, - ctx.VIEW != null) - } - - /** - * Create a [[AlterTableRename]] command. - * - * For example: - * {{{ - * ALTER TABLE table1 RENAME TO table2; - * ALTER VIEW view1 RENAME TO view2; - * }}} - */ - override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableRename( - visitTableIdentifier(ctx.from), - visitTableIdentifier(ctx.to), - ctx.VIEW != null) - } - - /** - * Create an [[AlterTableSetProperties]] command. 
- * - * For example: - * {{{ - * ALTER TABLE table SET TBLPROPERTIES ('comment' = new_comment); - * ALTER VIEW view SET TBLPROPERTIES ('comment' = new_comment); - * }}} - */ - override def visitSetTableProperties( - ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { - AlterTableSetProperties( - visitTableIdentifier(ctx.tableIdentifier), - visitTablePropertyList(ctx.tablePropertyList), - ctx.VIEW != null) - } - - /** - * Create an [[AlterTableUnsetProperties]] command. - * - * For example: - * {{{ - * ALTER TABLE table UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); - * ALTER VIEW view UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); - * }}} - */ - override def visitUnsetTableProperties( - ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { - AlterTableUnsetProperties( - visitTableIdentifier(ctx.tableIdentifier), - visitTablePropertyList(ctx.tablePropertyList).keys.toSeq, - ctx.EXISTS != null, - ctx.VIEW != null) - } - - /** - * Create an [[AlterTableSerDeProperties]] command. - * - * For example: - * {{{ - * ALTER TABLE table [PARTITION spec] SET SERDE serde_name [WITH SERDEPROPERTIES props]; - * ALTER TABLE table [PARTITION spec] SET SERDEPROPERTIES serde_properties; - * }}} - */ - override def visitSetTableSerDe(ctx: SetTableSerDeContext): LogicalPlan = withOrigin(ctx) { - AlterTableSerDeProperties( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.STRING).map(string), - Option(ctx.tablePropertyList).map(visitTablePropertyList), - // TODO a partition spec is allowed to have optional values. This is currently violated. - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) - } - - // TODO: don't even bother parsing alter table commands related to bucketing and skewing - - override def visitBucketTable(ctx: BucketTableContext): LogicalPlan = withOrigin(ctx) { - throw new AnalysisException( - "Operation not allowed: ALTER TABLE ... CLUSTERED BY ... INTO N BUCKETS") - } - - override def visitUnclusterTable(ctx: UnclusterTableContext): LogicalPlan = withOrigin(ctx) { - throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT CLUSTERED") - } - - override def visitUnsortTable(ctx: UnsortTableContext): LogicalPlan = withOrigin(ctx) { - throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT SORTED") - } - - override def visitSkewTable(ctx: SkewTableContext): LogicalPlan = withOrigin(ctx) { - throw new AnalysisException("Operation not allowed: ALTER TABLE ... SKEWED BY ...") - } - - override def visitUnskewTable(ctx: UnskewTableContext): LogicalPlan = withOrigin(ctx) { - throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT SKEWED") - } - - override def visitUnstoreTable(ctx: UnstoreTableContext): LogicalPlan = withOrigin(ctx) { - throw new AnalysisException( - "Operation not allowed: ALTER TABLE ... NOT STORED AS DIRECTORIES") - } - - override def visitSetTableSkewLocations( - ctx: SetTableSkewLocationsContext): LogicalPlan = withOrigin(ctx) { - throw new AnalysisException( - "Operation not allowed: ALTER TABLE ... SET SKEWED LOCATION ...") - } - - /** - * Create an [[AlterTableAddPartition]] command. - * - * For example: - * {{{ - * ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1'] - * ALTER VIEW view ADD [IF NOT EXISTS] PARTITION spec - * }}} - * - * ALTER VIEW ... ADD PARTITION ... 
is not supported because the concept of partitioning - * is associated with physical tables - */ - override def visitAddTablePartition( - ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { - if (ctx.VIEW != null) { - throw new AnalysisException(s"Operation not allowed: partitioned views") - } - // Create partition spec to location mapping. - val specsAndLocs = if (ctx.partitionSpec.isEmpty) { - ctx.partitionSpecLocation.asScala.map { - splCtx => - val spec = visitNonOptionalPartitionSpec(splCtx.partitionSpec) - val location = Option(splCtx.locationSpec).map(visitLocationSpec) - spec -> location - } - } else { - // Alter View: the location clauses are not allowed. - ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec(_) -> None) - } - AlterTableAddPartition( - visitTableIdentifier(ctx.tableIdentifier), - specsAndLocs, - ctx.EXISTS != null) - } - - /** - * Create an [[AlterTableExchangePartition]] command. - * - * For example: - * {{{ - * ALTER TABLE table1 EXCHANGE PARTITION spec WITH TABLE table2; - * }}} - */ - override def visitExchangeTablePartition( - ctx: ExchangeTablePartitionContext): LogicalPlan = withOrigin(ctx) { - throw new AnalysisException( - "Operation not allowed: ALTER TABLE ... EXCHANGE PARTITION ...") - } - - /** - * Create an [[AlterTableRenamePartition]] command - * - * For example: - * {{{ - * ALTER TABLE table PARTITION spec1 RENAME TO PARTITION spec2; - * }}} - */ - override def visitRenameTablePartition( - ctx: RenameTablePartitionContext): LogicalPlan = withOrigin(ctx) { - AlterTableRenamePartition( - visitTableIdentifier(ctx.tableIdentifier), - visitNonOptionalPartitionSpec(ctx.from), - visitNonOptionalPartitionSpec(ctx.to)) - } - - /** - * Create an [[AlterTableDropPartition]] command - * - * For example: - * {{{ - * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE]; - * ALTER VIEW view DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...]; - * }}} - * - * ALTER VIEW ... DROP PARTITION ... is not supported because the concept of partitioning - * is associated with physical tables - */ - override def visitDropTablePartitions( - ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) { - if (ctx.VIEW != null) { - throw new AnalysisException(s"Operation not allowed: partitioned views") - } - if (ctx.PURGE != null) { - throw new AnalysisException(s"Operation not allowed: PURGE") - } - AlterTableDropPartition( - visitTableIdentifier(ctx.tableIdentifier), - ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec), - ctx.EXISTS != null) - } - - /** - * Create an [[AlterTableArchivePartition]] command - * - * For example: - * {{{ - * ALTER TABLE table ARCHIVE PARTITION spec; - * }}} - */ - override def visitArchiveTablePartition( - ctx: ArchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) { - throw new AnalysisException( - "Operation not allowed: ALTER TABLE ... ARCHIVE PARTITION ...") - } - - /** - * Create an [[AlterTableUnarchivePartition]] command - * - * For example: - * {{{ - * ALTER TABLE table UNARCHIVE PARTITION spec; - * }}} - */ - override def visitUnarchiveTablePartition( - ctx: UnarchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) { - throw new AnalysisException( - "Operation not allowed: ALTER TABLE ... 
UNARCHIVE PARTITION ...") - } - - /** - * Create an [[AlterTableSetFileFormat]] command - * - * For example: - * {{{ - * ALTER TABLE table [PARTITION spec] SET FILEFORMAT file_format; - * }}} - */ - override def visitSetTableFileFormat( - ctx: SetTableFileFormatContext): LogicalPlan = withOrigin(ctx) { - // AlterTableSetFileFormat currently takes both a GenericFileFormat and a - // TableFileFormatContext. This is a bit weird because it should only take one. It also should - // use a CatalogFileFormat instead of either a String or a Sequence of Strings. We will address - // this in a follow-up PR. - val (fileFormat, genericFormat) = ctx.fileFormat match { - case s: GenericFileFormatContext => - (Seq.empty[String], Option(s.identifier.getText)) - case s: TableFileFormatContext => - val elements = Seq(s.inFmt, s.outFmt) ++ Option(s.serdeCls).toSeq - (elements.map(string), None) - } - AlterTableSetFileFormat( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - fileFormat, - genericFormat)( - command(ctx)) - } - - /** - * Create an [[AlterTableSetLocation]] command - * - * For example: - * {{{ - * ALTER TABLE table [PARTITION spec] SET LOCATION "loc"; - * }}} - */ - override def visitSetTableLocation(ctx: SetTableLocationContext): LogicalPlan = withOrigin(ctx) { - AlterTableSetLocation( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - visitLocationSpec(ctx.locationSpec)) - } - - /** - * Create an [[AlterTableTouch]] command - * - * For example: - * {{{ - * ALTER TABLE table TOUCH [PARTITION spec]; - * }}} - */ - override def visitTouchTable(ctx: TouchTableContext): LogicalPlan = withOrigin(ctx) { - throw new AnalysisException("Operation not allowed: ALTER TABLE ... TOUCH ...") - } - - /** - * Create an [[AlterTableCompact]] command - * - * For example: - * {{{ - * ALTER TABLE table [PARTITION spec] COMPACT 'compaction_type'; - * }}} - */ - override def visitCompactTable(ctx: CompactTableContext): LogicalPlan = withOrigin(ctx) { - throw new AnalysisException("Operation not allowed: ALTER TABLE ... COMPACT ...") - } - - /** - * Create an [[AlterTableMerge]] command - * - * For example: - * {{{ - * ALTER TABLE table [PARTITION spec] CONCATENATE; - * }}} - */ - override def visitConcatenateTable(ctx: ConcatenateTableContext): LogicalPlan = withOrigin(ctx) { - throw new AnalysisException("Operation not allowed: ALTER TABLE ... CONCATENATE") - } - - /** - * Create an [[AlterTableChangeCol]] command - * - * For example: - * {{{ - * ALTER TABLE tableIdentifier [PARTITION spec] - * CHANGE [COLUMN] col_old_name col_new_name column_type [COMMENT col_comment] - * [FIRST|AFTER column_name] [CASCADE|RESTRICT]; - * }}} - */ - override def visitChangeColumn(ctx: ChangeColumnContext): LogicalPlan = withOrigin(ctx) { - val col = visitColType(ctx.colType()) - val comment = if (col.metadata.contains("comment")) { - Option(col.metadata.getString("comment")) - } else { - None - } - - AlterTableChangeCol( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - ctx.oldName.getText, - // We could also pass in a struct field - seems easier. - col.name, - col.dataType, - comment, - Option(ctx.after).map(_.getText), - // Note that Restrict and Cascade are mutually exclusive. 
- ctx.RESTRICT != null, - ctx.CASCADE != null)( - command(ctx)) - } - - /** - * Create an [[AlterTableAddCol]] command - * - * For example: - * {{{ - * ALTER TABLE tableIdentifier [PARTITION spec] - * ADD COLUMNS (name type [COMMENT comment], ...) [CASCADE|RESTRICT] - * }}} - */ - override def visitAddColumns(ctx: AddColumnsContext): LogicalPlan = withOrigin(ctx) { - AlterTableAddCol( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - createStructType(ctx.colTypeList), - // Note that Restrict and Cascade are mutually exclusive. - ctx.RESTRICT != null, - ctx.CASCADE != null)( - command(ctx)) - } - - /** - * Create an [[AlterTableReplaceCol]] command - * - * For example: - * {{{ - * ALTER TABLE tableIdentifier [PARTITION spec] - * REPLACE COLUMNS (name type [COMMENT comment], ...) [CASCADE|RESTRICT] - * }}} - */ - override def visitReplaceColumns(ctx: ReplaceColumnsContext): LogicalPlan = withOrigin(ctx) { - AlterTableReplaceCol( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - createStructType(ctx.colTypeList), - // Note that Restrict and Cascade are mutually exclusive. - ctx.RESTRICT != null, - ctx.CASCADE != null)( - command(ctx)) - } - - /** - * Create location string. - */ - override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) { - string(ctx.STRING) - } - - /** - * Create a [[BucketSpec]]. - */ - override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) { - BucketSpec( - ctx.INTEGER_VALUE.getText.toInt, - visitIdentifierList(ctx.identifierList), - Option(ctx.orderedIdentifierList).toSeq - .flatMap(_.orderedIdentifier.asScala) - .map(_.identifier.getText)) - } - - /** - * Convert a nested constants list into a sequence of string sequences. - */ - override def visitNestedConstantList( - ctx: NestedConstantListContext): Seq[Seq[String]] = withOrigin(ctx) { - ctx.constantList.asScala.map(visitConstantList) - } - - /** - * Convert a constants list into a String sequence. - */ - override def visitConstantList(ctx: ConstantListContext): Seq[String] = withOrigin(ctx) { - ctx.constant.asScala.map(visitStringConstant) - } -} From 7c36ef0f54d8506f3c9593fe27824840006c3646 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Mon, 18 Apr 2016 20:14:58 +0800 Subject: [PATCH 08/18] Merge branch 'apache-master' into SPARK-4452-2 --- .../spark/sql/execution/SparkSqlParser.scala | 792 ++++++++++++++++++ 1 file changed, 792 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala new file mode 100644 index 0000000000000..8ed6ed21d0170 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -0,0 +1,792 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser._ +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.execution.command.{DescribeCommand => _, _} +import org.apache.spark.sql.execution.datasources._ + +/** + * Concrete parser for Spark SQL statements. + */ +object SparkSqlParser extends AbstractSqlParser{ + val astBuilder = new SparkSqlAstBuilder +} + +/** + * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. + */ +class SparkSqlAstBuilder extends AstBuilder { + import org.apache.spark.sql.catalyst.parser.ParserUtils._ + + /** + * Create a [[SetCommand]] logical plan. + * + * Note that we assume that everything after the SET keyword is assumed to be a part of the + * key-value pair. The split between key and value is made by searching for the first `=` + * character in the raw string. + */ + override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) { + // Construct the command. + val raw = remainder(ctx.SET.getSymbol) + val keyValueSeparatorIndex = raw.indexOf('=') + if (keyValueSeparatorIndex >= 0) { + val key = raw.substring(0, keyValueSeparatorIndex).trim + val value = raw.substring(keyValueSeparatorIndex + 1).trim + SetCommand(Some(key -> Option(value))) + } else if (raw.nonEmpty) { + SetCommand(Some(raw.trim -> None)) + } else { + SetCommand(None) + } + } + + /** + * Create a [[SetDatabaseCommand]] logical plan. + */ + override def visitUse(ctx: UseContext): LogicalPlan = withOrigin(ctx) { + SetDatabaseCommand(ctx.db.getText) + } + + /** + * Create a [[ShowTablesCommand]] logical plan. + * Example SQL : + * {{{ + * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; + * }}} + */ + override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) { + ShowTablesCommand( + Option(ctx.db).map(_.getText), + Option(ctx.pattern).map(string)) + } + + /** + * Create a [[ShowDatabasesCommand]] logical plan. + * Example SQL: + * {{{ + * SHOW (DATABASES|SCHEMAS) [LIKE 'identifier_with_wildcards']; + * }}} + */ + override def visitShowDatabases(ctx: ShowDatabasesContext): LogicalPlan = withOrigin(ctx) { + ShowDatabasesCommand(Option(ctx.pattern).map(string)) + } + + /** + * A command for users to list the properties for a table. If propertyKey is specified, the value + * for the propertyKey is returned. If propertyKey is not specified, all the keys and their + * corresponding values are returned. 
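Aside: as the comment on `visitSetConfiguration` earlier in this hunk notes, the SET argument is split on the first `=`. The same rule in isolation, as a hypothetical standalone helper (not part of SparkSqlAstBuilder), returning the key together with an optional value:

{{{
// "key=value" -> Some(key, Some(value)); "key" -> Some(key, None); "" -> None.
def parseSetArgs(raw: String): Option[(String, Option[String])] = {
  val separatorIndex = raw.indexOf('=')
  if (separatorIndex >= 0) {
    // Everything before the first '=' is the key, the rest is the value.
    Some(raw.substring(0, separatorIndex).trim -> Some(raw.substring(separatorIndex + 1).trim))
  } else if (raw.trim.nonEmpty) {
    Some(raw.trim -> None) // SET key: query the current value
  } else {
    None                   // bare SET: list all properties
  }
}
}}}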
+ * The syntax of using this command in SQL is: + * {{{ + * SHOW TBLPROPERTIES table_name[('propertyKey')]; + * }}} + */ + override def visitShowTblProperties( + ctx: ShowTblPropertiesContext): LogicalPlan = withOrigin(ctx) { + ShowTablePropertiesCommand( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.key).map(visitTablePropertyKey)) + } + + /** + * Create a [[RefreshTable]] logical plan. + */ + override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) { + RefreshTable(visitTableIdentifier(ctx.tableIdentifier)) + } + + /** + * Create a [[CacheTableCommand]] logical plan. + */ + override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) { + val query = Option(ctx.query).map(plan) + CacheTableCommand(ctx.identifier.getText, query, ctx.LAZY != null) + } + + /** + * Create an [[UncacheTableCommand]] logical plan. + */ + override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) { + UncacheTableCommand(ctx.identifier.getText) + } + + /** + * Create a [[ClearCacheCommand]] logical plan. + */ + override def visitClearCache(ctx: ClearCacheContext): LogicalPlan = withOrigin(ctx) { + ClearCacheCommand + } + + /** + * Create an [[ExplainCommand]] logical plan. + */ + override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) { + val options = ctx.explainOption.asScala + if (options.exists(_.FORMATTED != null)) { + logWarning("Unsupported operation: EXPLAIN FORMATTED option") + } + + // Create the explain comment. + val statement = plan(ctx.statement) + if (isExplainableStatement(statement)) { + ExplainCommand(statement, extended = options.exists(_.EXTENDED != null), + codegen = options.exists(_.CODEGEN != null)) + } else { + ExplainCommand(OneRowRelation) + } + } + + /** + * Determine if a plan should be explained at all. + */ + protected def isExplainableStatement(plan: LogicalPlan): Boolean = plan match { + case _: datasources.DescribeCommand => false + case _ => true + } + + /** + * Create a [[DescribeCommand]] logical plan. + */ + override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { + // FORMATTED and columns are not supported. Return null and let the parser decide what to do + // with this (create an exception or pass it on to a different system). + if (ctx.describeColName != null || ctx.FORMATTED != null || ctx.partitionSpec != null) { + null + } else { + datasources.DescribeCommand( + visitTableIdentifier(ctx.tableIdentifier), + ctx.EXTENDED != null) + } + } + + /** + * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). + */ + type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) + + /** + * Validate a create table statement and return the [[TableIdentifier]]. + */ + override def visitCreateTableHeader( + ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { + val temporary = ctx.TEMPORARY != null + val ifNotExists = ctx.EXISTS != null + assert(!temporary || !ifNotExists, + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.", + ctx) + (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null) + } + + /** + * Create a [[CreateTableUsing]] or a [[CreateTableUsingAsSelect]] logical plan. + * + * TODO add bucketing and partitioning. 
+ */ + override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { + val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + if (external) { + throw new ParseException("Unsupported operation: EXTERNAL option", ctx) + } + val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) + val provider = ctx.tableProvider.qualifiedName.getText + + if (ctx.query != null) { + // Get the backing query. + val query = plan(ctx.query) + + // Determine the storage mode. + val mode = if (ifNotExists) { + SaveMode.Ignore + } else if (temp) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + CreateTableUsingAsSelect(table, provider, temp, Array.empty, None, mode, options, query) + } else { + val struct = Option(ctx.colTypeList).map(createStructType) + CreateTableUsing(table, struct, provider, temp, options, ifNotExists, managedIfNoPath = false) + } + } + + /** + * Convert a table property list into a key-value map. + */ + override def visitTablePropertyList( + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + ctx.tableProperty.asScala.map { property => + val key = visitTablePropertyKey(property.key) + val value = Option(property.value).map(string).orNull + key -> value + }.toMap + } + + /** + * A table property key can either be String or a collection of dot separated elements. This + * function extracts the property key based on whether its a string literal or a table property + * identifier. + */ + override def visitTablePropertyKey(key: TablePropertyKeyContext): String = { + if (key.STRING != null) { + string(key.STRING) + } else { + key.getText + } + } + + /** + * Create a [[CreateDatabase]] command. + * + * For example: + * {{{ + * CREATE DATABASE [IF NOT EXISTS] database_name [COMMENT database_comment] + * [LOCATION path] [WITH DBPROPERTIES (key1=val1, key2=val2, ...)] + * }}} + */ + override def visitCreateDatabase(ctx: CreateDatabaseContext): LogicalPlan = withOrigin(ctx) { + CreateDatabase( + ctx.identifier.getText, + ctx.EXISTS != null, + Option(ctx.locationSpec).map(visitLocationSpec), + Option(ctx.comment).map(string), + Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)) + } + + /** + * Create an [[AlterDatabaseProperties]] command. + * + * For example: + * {{{ + * ALTER (DATABASE|SCHEMA) database SET DBPROPERTIES (property_name=property_value, ...); + * }}} + */ + override def visitSetDatabaseProperties( + ctx: SetDatabasePropertiesContext): LogicalPlan = withOrigin(ctx) { + AlterDatabaseProperties( + ctx.identifier.getText, + visitTablePropertyList(ctx.tablePropertyList)) + } + + /** + * Create a [[DropDatabase]] command. + * + * For example: + * {{{ + * DROP (DATABASE|SCHEMA) [IF EXISTS] database [RESTRICT|CASCADE]; + * }}} + */ + override def visitDropDatabase(ctx: DropDatabaseContext): LogicalPlan = withOrigin(ctx) { + DropDatabase(ctx.identifier.getText, ctx.EXISTS != null, ctx.CASCADE != null) + } + + /** + * Create a [[DescribeDatabase]] command. + * + * For example: + * {{{ + * DESCRIBE DATABASE [EXTENDED] database; + * }}} + */ + override def visitDescribeDatabase(ctx: DescribeDatabaseContext): LogicalPlan = withOrigin(ctx) { + DescribeDatabase(ctx.identifier.getText, ctx.EXTENDED != null) + } + + /** + * Create a [[CreateFunction]] command. 
+ * + * For example: + * {{{ + * CREATE [TEMPORARY] FUNCTION [db_name.]function_name AS class_name + * [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri']]; + * }}} + */ + override def visitCreateFunction(ctx: CreateFunctionContext): LogicalPlan = withOrigin(ctx) { + val resources = ctx.resource.asScala.map { resource => + val resourceType = resource.identifier.getText.toLowerCase + resourceType match { + case "jar" | "file" | "archive" => + resourceType -> string(resource.STRING) + case other => + throw new ParseException(s"Resource Type '$resourceType' is not supported.", ctx) + } + } + + // Extract database, name & alias. + val (database, function) = visitFunctionName(ctx.qualifiedName) + CreateFunction( + database, + function, + string(ctx.className), + resources, + ctx.TEMPORARY != null) + } + + /** + * Create a [[DropFunction]] command. + * + * For example: + * {{{ + * DROP [TEMPORARY] FUNCTION [IF EXISTS] function; + * }}} + */ + override def visitDropFunction(ctx: DropFunctionContext): LogicalPlan = withOrigin(ctx) { + val (database, function) = visitFunctionName(ctx.qualifiedName) + DropFunction(database, function, ctx.EXISTS != null, ctx.TEMPORARY != null) + } + + /** + * Create a function database (optional) and name pair. + */ + private def visitFunctionName(ctx: QualifiedNameContext): (Option[String], String) = { + ctx.identifier().asScala.map(_.getText) match { + case Seq(db, fn) => (Option(db), fn) + case Seq(fn) => (None, fn) + case other => throw new ParseException(s"Unsupported function name '${ctx.getText}'", ctx) + } + } + + /** + * Create a [[DropTable]] command. + */ + override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { + if (ctx.PURGE != null) { + throw new ParseException("Unsupported operation: PURGE option", ctx) + } + if (ctx.REPLICATION != null) { + throw new ParseException("Unsupported operation: REPLICATION clause", ctx) + } + DropTable( + visitTableIdentifier(ctx.tableIdentifier), + ctx.EXISTS != null, + ctx.VIEW != null) + } + + /** + * Create a [[AlterTableRename]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 RENAME TO table2; + * ALTER VIEW view1 RENAME TO view2; + * }}} + */ + override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) { + AlterTableRename( + visitTableIdentifier(ctx.from), + visitTableIdentifier(ctx.to), + ctx.VIEW != null) + } + + /** + * Create an [[AlterTableSetProperties]] command. + * + * For example: + * {{{ + * ALTER TABLE table SET TBLPROPERTIES ('comment' = new_comment); + * ALTER VIEW view SET TBLPROPERTIES ('comment' = new_comment); + * }}} + */ + override def visitSetTableProperties( + ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { + AlterTableSetProperties( + visitTableIdentifier(ctx.tableIdentifier), + visitTablePropertyList(ctx.tablePropertyList), + ctx.VIEW != null) + } + + /** + * Create an [[AlterTableUnsetProperties]] command. + * + * For example: + * {{{ + * ALTER TABLE table UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + * ALTER VIEW view UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + * }}} + */ + override def visitUnsetTableProperties( + ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { + AlterTableUnsetProperties( + visitTableIdentifier(ctx.tableIdentifier), + visitTablePropertyList(ctx.tablePropertyList).keys.toSeq, + ctx.EXISTS != null, + ctx.VIEW != null) + } + + /** + * Create an [[AlterTableSerDeProperties]] command. 
+ * + * For example: + * {{{ + * ALTER TABLE table [PARTITION spec] SET SERDE serde_name [WITH SERDEPROPERTIES props]; + * ALTER TABLE table [PARTITION spec] SET SERDEPROPERTIES serde_properties; + * }}} + */ + override def visitSetTableSerDe(ctx: SetTableSerDeContext): LogicalPlan = withOrigin(ctx) { + AlterTableSerDeProperties( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.STRING).map(string), + Option(ctx.tablePropertyList).map(visitTablePropertyList), + // TODO a partition spec is allowed to have optional values. This is currently violated. + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) + } + + // TODO: don't even bother parsing alter table commands related to bucketing and skewing + + override def visitBucketTable(ctx: BucketTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... CLUSTERED BY ... INTO N BUCKETS") + } + + override def visitUnclusterTable(ctx: UnclusterTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT CLUSTERED") + } + + override def visitUnsortTable(ctx: UnsortTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT SORTED") + } + + override def visitSkewTable(ctx: SkewTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... SKEWED BY ...") + } + + override def visitUnskewTable(ctx: UnskewTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT SKEWED") + } + + override def visitUnstoreTable(ctx: UnstoreTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... NOT STORED AS DIRECTORIES") + } + + override def visitSetTableSkewLocations( + ctx: SetTableSkewLocationsContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... SET SKEWED LOCATION ...") + } + + /** + * Create an [[AlterTableAddPartition]] command. + * + * For example: + * {{{ + * ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1'] + * ALTER VIEW view ADD [IF NOT EXISTS] PARTITION spec + * }}} + * + * ALTER VIEW ... ADD PARTITION ... is not supported because the concept of partitioning + * is associated with physical tables + */ + override def visitAddTablePartition( + ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { + if (ctx.VIEW != null) { + throw new AnalysisException(s"Operation not allowed: partitioned views") + } + // Create partition spec to location mapping. + val specsAndLocs = if (ctx.partitionSpec.isEmpty) { + ctx.partitionSpecLocation.asScala.map { + splCtx => + val spec = visitNonOptionalPartitionSpec(splCtx.partitionSpec) + val location = Option(splCtx.locationSpec).map(visitLocationSpec) + spec -> location + } + } else { + // Alter View: the location clauses are not allowed. + ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec(_) -> None) + } + AlterTableAddPartition( + visitTableIdentifier(ctx.tableIdentifier), + specsAndLocs, + ctx.EXISTS != null) + } + + /** + * Create an [[AlterTableExchangePartition]] command. 
+ * + * For example: + * {{{ + * ALTER TABLE table1 EXCHANGE PARTITION spec WITH TABLE table2; + * }}} + */ + override def visitExchangeTablePartition( + ctx: ExchangeTablePartitionContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... EXCHANGE PARTITION ...") + } + + /** + * Create an [[AlterTableRenamePartition]] command + * + * For example: + * {{{ + * ALTER TABLE table PARTITION spec1 RENAME TO PARTITION spec2; + * }}} + */ + override def visitRenameTablePartition( + ctx: RenameTablePartitionContext): LogicalPlan = withOrigin(ctx) { + AlterTableRenamePartition( + visitTableIdentifier(ctx.tableIdentifier), + visitNonOptionalPartitionSpec(ctx.from), + visitNonOptionalPartitionSpec(ctx.to)) + } + + /** + * Create an [[AlterTableDropPartition]] command + * + * For example: + * {{{ + * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE]; + * ALTER VIEW view DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...]; + * }}} + * + * ALTER VIEW ... DROP PARTITION ... is not supported because the concept of partitioning + * is associated with physical tables + */ + override def visitDropTablePartitions( + ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) { + if (ctx.VIEW != null) { + throw new AnalysisException(s"Operation not allowed: partitioned views") + } + if (ctx.PURGE != null) { + throw new AnalysisException(s"Operation not allowed: PURGE") + } + AlterTableDropPartition( + visitTableIdentifier(ctx.tableIdentifier), + ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec), + ctx.EXISTS != null) + } + + /** + * Create an [[AlterTableArchivePartition]] command + * + * For example: + * {{{ + * ALTER TABLE table ARCHIVE PARTITION spec; + * }}} + */ + override def visitArchiveTablePartition( + ctx: ArchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... ARCHIVE PARTITION ...") + } + + /** + * Create an [[AlterTableUnarchivePartition]] command + * + * For example: + * {{{ + * ALTER TABLE table UNARCHIVE PARTITION spec; + * }}} + */ + override def visitUnarchiveTablePartition( + ctx: UnarchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... UNARCHIVE PARTITION ...") + } + + /** + * Create an [[AlterTableSetFileFormat]] command + * + * For example: + * {{{ + * ALTER TABLE table [PARTITION spec] SET FILEFORMAT file_format; + * }}} + */ + override def visitSetTableFileFormat( + ctx: SetTableFileFormatContext): LogicalPlan = withOrigin(ctx) { + // AlterTableSetFileFormat currently takes both a GenericFileFormat and a + // TableFileFormatContext. This is a bit weird because it should only take one. It also should + // use a CatalogFileFormat instead of either a String or a Sequence of Strings. We will address + // this in a follow-up PR. 
+ val (fileFormat, genericFormat) = ctx.fileFormat match { + case s: GenericFileFormatContext => + (Seq.empty[String], Option(s.identifier.getText)) + case s: TableFileFormatContext => + val elements = Seq(s.inFmt, s.outFmt) ++ Option(s.serdeCls).toSeq + (elements.map(string), None) + } + AlterTableSetFileFormat( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), + fileFormat, + genericFormat)( + command(ctx)) + } + + /** + * Create an [[AlterTableSetLocation]] command + * + * For example: + * {{{ + * ALTER TABLE table [PARTITION spec] SET LOCATION "loc"; + * }}} + */ + override def visitSetTableLocation(ctx: SetTableLocationContext): LogicalPlan = withOrigin(ctx) { + AlterTableSetLocation( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), + visitLocationSpec(ctx.locationSpec)) + } + + /** + * Create an [[AlterTableTouch]] command + * + * For example: + * {{{ + * ALTER TABLE table TOUCH [PARTITION spec]; + * }}} + */ + override def visitTouchTable(ctx: TouchTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... TOUCH ...") + } + + /** + * Create an [[AlterTableCompact]] command + * + * For example: + * {{{ + * ALTER TABLE table [PARTITION spec] COMPACT 'compaction_type'; + * }}} + */ + override def visitCompactTable(ctx: CompactTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... COMPACT ...") + } + + /** + * Create an [[AlterTableMerge]] command + * + * For example: + * {{{ + * ALTER TABLE table [PARTITION spec] CONCATENATE; + * }}} + */ + override def visitConcatenateTable(ctx: ConcatenateTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... CONCATENATE") + } + + /** + * Create an [[AlterTableChangeCol]] command + * + * For example: + * {{{ + * ALTER TABLE tableIdentifier [PARTITION spec] + * CHANGE [COLUMN] col_old_name col_new_name column_type [COMMENT col_comment] + * [FIRST|AFTER column_name] [CASCADE|RESTRICT]; + * }}} + */ + override def visitChangeColumn(ctx: ChangeColumnContext): LogicalPlan = withOrigin(ctx) { + val col = visitColType(ctx.colType()) + val comment = if (col.metadata.contains("comment")) { + Option(col.metadata.getString("comment")) + } else { + None + } + + AlterTableChangeCol( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), + ctx.oldName.getText, + // We could also pass in a struct field - seems easier. + col.name, + col.dataType, + comment, + Option(ctx.after).map(_.getText), + // Note that Restrict and Cascade are mutually exclusive. + ctx.RESTRICT != null, + ctx.CASCADE != null)( + command(ctx)) + } + + /** + * Create an [[AlterTableAddCol]] command + * + * For example: + * {{{ + * ALTER TABLE tableIdentifier [PARTITION spec] + * ADD COLUMNS (name type [COMMENT comment], ...) [CASCADE|RESTRICT] + * }}} + */ + override def visitAddColumns(ctx: AddColumnsContext): LogicalPlan = withOrigin(ctx) { + AlterTableAddCol( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), + createStructType(ctx.colTypeList), + // Note that Restrict and Cascade are mutually exclusive. 
+ ctx.RESTRICT != null, + ctx.CASCADE != null)( + command(ctx)) + } + + /** + * Create an [[AlterTableReplaceCol]] command + * + * For example: + * {{{ + * ALTER TABLE tableIdentifier [PARTITION spec] + * REPLACE COLUMNS (name type [COMMENT comment], ...) [CASCADE|RESTRICT] + * }}} + */ + override def visitReplaceColumns(ctx: ReplaceColumnsContext): LogicalPlan = withOrigin(ctx) { + AlterTableReplaceCol( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), + createStructType(ctx.colTypeList), + // Note that Restrict and Cascade are mutually exclusive. + ctx.RESTRICT != null, + ctx.CASCADE != null)( + command(ctx)) + } + + /** + * Create location string. + */ + override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) { + string(ctx.STRING) + } + + /** + * Create a [[BucketSpec]]. + */ + override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) { + BucketSpec( + ctx.INTEGER_VALUE.getText.toInt, + visitIdentifierList(ctx.identifierList), + Option(ctx.orderedIdentifierList).toSeq + .flatMap(_.orderedIdentifier.asScala) + .map(_.identifier.getText)) + } + + /** + * Convert a nested constants list into a sequence of string sequences. + */ + override def visitNestedConstantList( + ctx: NestedConstantListContext): Seq[Seq[String]] = withOrigin(ctx) { + ctx.constantList.asScala.map(visitConstantList) + } + + /** + * Convert a constants list into a String sequence. + */ + override def visitConstantList(ctx: ConstantListContext): Seq[String] = withOrigin(ctx) { + ctx.constant.asScala.map(visitStringConstant) + } +} From 70bcffa63da42c84e9a63f6be39ed7330662039e Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 19 Apr 2016 22:00:51 +0800 Subject: [PATCH 09/18] fix thread safety & add ut --- .../apache/spark/memory/MemoryConsumer.java | 4 +- .../collection/ExternalAppendOnlyMap.scala | 59 +++++++--- .../util/collection/ExternalSorter.scala | 103 +++++++++++------- .../spark/util/collection/Spillable.scala | 4 +- .../ExternalAppendOnlyMapSuite.scala | 14 +++ .../util/collection/ExternalSorterSuite.scala | 16 +++ 6 files changed, 141 insertions(+), 59 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index bd5ffe2e00f9f..1fd602b38e542 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -134,7 +134,7 @@ protected void freePage(MemoryBlock page) { /** * Allocates a heap memory of `size`. */ - public long allocateHeapExecutionMemory(long size) { + public long acquireOnHeapMemory(long size) { long granted = taskMemoryManager.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this); used += granted; @@ -144,7 +144,7 @@ public long allocateHeapExecutionMemory(long size) { /** * Release N bytes of heap memory. 
*/ - public void freeHeapExecutionMemory(long size) { + public void freeOnHeapMemory(long size) { taskMemoryManager.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this); used -= size; } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 2b19031e5861d..92bef0f38ef75 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -115,7 +115,7 @@ class ExternalAppendOnlyMap[K, V, C]( private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() - private var inMemoryOrDiskIterator: Iterator[(K, C)] = null + private var readingIterator: SpillableIterator = null /** * Number of files this map has spilled so far. @@ -192,14 +192,12 @@ class ExternalAppendOnlyMap[K, V, C]( * It will be called by TaskMemoryManager when there is not enough memory for the task. */ override protected[this] def forceSpill(): Boolean = { - assert(inMemoryOrDiskIterator != null) - val inMemoryIterator = inMemoryOrDiskIterator - logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + - s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") - val diskMapIterator = spillMemoryIteratorToDisk(inMemoryIterator) - inMemoryOrDiskIterator = diskMapIterator - currentMap = null - true + assert(readingIterator != null) + val isSpilled = readingIterator.spill() + if (isSpilled) { + currentMap = null + } + isSpilled } /** @@ -270,14 +268,10 @@ class ExternalAppendOnlyMap[K, V, C]( * it returns pairs from an on-disk map. */ def destructiveIterator(inMemoryIterator: Iterator[(K, C)]): Iterator[(K, C)] = { - inMemoryOrDiskIterator = inMemoryIterator - new Iterator[(K, C)] { - - override def hasNext = inMemoryOrDiskIterator.hasNext - - override def next() = inMemoryOrDiskIterator.next() - } + readingIterator = new SpillableIterator(inMemoryIterator) + readingIterator } + /** * Return a destructive iterator that merges the in-memory map with the spilled maps. * If no spill has occurred, simply return the in-memory map's iterator. @@ -573,6 +567,39 @@ class ExternalAppendOnlyMap[K, V, C]( context.addTaskCompletionListener(context => cleanup()) } + private[this] class SpillableIterator(var upstream: Iterator[(K, C)]) + extends Iterator[(K, C)] { + + private var nextUpstream: Iterator[(K, C)] = null + + private var cur: (K, C) = null + + def spill(): Boolean = synchronized { + if (upstream == null || nextUpstream != null) { + false + } else { + logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + + s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") + nextUpstream = spillMemoryIteratorToDisk(upstream) + true + } + } + + override def hasNext: Boolean = synchronized { + if (nextUpstream != null) { + upstream = nextUpstream + nextUpstream = null + } + val r = upstream.hasNext + if (r) { + cur = upstream.next() + } + r + } + + override def next(): (K, C) = cur + } + /** Convenience function to hash the given (K, C) pair by the key. 
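Aside: the `SpillableIterator` introduced above is the core of the force-spill path. Consuming and spilling are serialized on the same lock, and a spill simply swaps the upstream iterator for one that reads back the freshly written file. A generic, self-contained sketch of that pattern, with a `spillToDisk` callback standing in for `spillMemoryIteratorToDisk`:

{{{
// Iterator whose backing source can be swapped from memory to disk while it
// is being consumed. `spillToDisk` should write out the remaining elements
// and return an iterator that reads them back.
class SwappableIterator[A](
    source: Iterator[A],
    spillToDisk: Iterator[A] => Iterator[A]) extends Iterator[A] {

  private var upstream: Iterator[A] = source
  private var nextUpstream: Iterator[A] = null
  private var hasSpilled = false

  /** Called from the memory manager side; returns false if already spilled. */
  def spill(): Boolean = synchronized {
    if (hasSpilled) {
      false
    } else {
      nextUpstream = spillToDisk(upstream)
      hasSpilled = true
      true
    }
  }

  override def hasNext: Boolean = synchronized {
    if (nextUpstream != null) { // a spill happened since the last call: switch over
      upstream = nextUpstream
      nextUpstream = null
    }
    upstream.hasNext
  }

  override def next(): A = synchronized { upstream.next() }
}
}}}

The patch's own iterator additionally pre-reads the next element; the sketch keeps only the swap-under-lock idea.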
*/ private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b01b1dc638469..14140dc73f3a8 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -136,8 +136,8 @@ private[spark] class ExternalSorter[K, V, C]( def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes private var isShuffleSort: Boolean = true - var forceSpillFile: Option[SpilledFile] = None - private var inMemoryOrDiskIterator: Iterator[((Int, K), C)] = null + private val forceSpillFiles = new ArrayBuffer[SpilledFile] + private var readingIterator: SpillableIterator = null // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the @@ -163,7 +163,7 @@ private[spark] class ExternalSorter[K, V, C]( // Information about a spilled file. Includes sizes in bytes of "batches" written by the // serializer as we periodically reset its stream, as well as number of elements in each // partition, used to efficiently keep track of partitions when merging. - private[collection] case class SpilledFile( + private[this] case class SpilledFile( file: File, blockId: BlockId, serializerBatchSizes: Array[Long], @@ -250,31 +250,13 @@ private[spark] class ExternalSorter[K, V, C]( if (isShuffleSort) { false } else { - assert(inMemoryOrDiskIterator != null) - val it = inMemoryOrDiskIterator - val inMemoryIterator = new WritablePartitionedIterator { - private[this] var cur = if (it.hasNext) it.next() else null - - def writeNext(writer: DiskBlockObjectWriter): Unit = { - writer.write(cur._1._2, cur._2) - cur = if (it.hasNext) it.next() else null - } - - def hasNext(): Boolean = cur != null - - def nextPartition(): Int = cur._1._1 - } - logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + - s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") - forceSpillFile = Some(spillMemoryIteratorToDisk(inMemoryIterator)) - val spillReader = new SpillReader(forceSpillFile.get) - inMemoryOrDiskIterator = (0 until numPartitions).iterator.flatMap { p => - val iterator = spillReader.readNextPartition() - iterator.map(cur => ((p, cur._1), cur._2)) + assert(readingIterator != null) + val isSpilled = readingIterator.spill() + if (isSpilled) { + map = null + buffer = null } - map = null - buffer = null - true + isSpilled } } @@ -655,13 +637,8 @@ private[spark] class ExternalSorter[K, V, C]( if (isShuffleSort) { memoryIterator } else { - inMemoryOrDiskIterator = memoryIterator - new Iterator[((Int, K), C)] { - - override def hasNext = inMemoryOrDiskIterator.hasNext - - override def next() = inMemoryOrDiskIterator.next() - } + readingIterator = new SpillableIterator(memoryIterator) + readingIterator } } @@ -762,7 +739,8 @@ private[spark] class ExternalSorter[K, V, C]( def stop(): Unit = { spills.foreach(s => s.file.delete()) spills.clear() - forceSpillFile.foreach(_.file.delete()) + forceSpillFiles.foreach(s => s.file.delete()) + forceSpillFiles.clear() if (map != null || buffer != null) { map = null // So that the memory can be garbage-collected buffer = null // So that the memory can be garbage-collected @@ -770,10 +748,6 @@ private[spark] class ExternalSorter[K, V, C]( } } - override def 
toString(): String = { - this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode()) - } - /** * Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*, * group together the pairs for each partition into a sub-iterator. @@ -805,4 +779,55 @@ private[spark] class ExternalSorter[K, V, C]( (elem._1._2, elem._2) } } + + private[this] class SpillableIterator(var upstream: Iterator[((Int, K), C)]) + extends Iterator[((Int, K), C)] { + + private var nextUpstream: Iterator[((Int, K), C)] = null + + private var cur: ((Int, K), C) = null + + def spill(): Boolean = synchronized { + if (upstream == null || nextUpstream != null) { + false + } else { + val inMemoryIterator = new WritablePartitionedIterator { + private[this] var cur = if (upstream.hasNext) upstream.next() else null + + def writeNext(writer: DiskBlockObjectWriter): Unit = { + writer.write(cur._1._2, cur._2) + cur = if (upstream.hasNext) upstream.next() else null + } + + def hasNext(): Boolean = cur != null + + def nextPartition(): Int = cur._1._1 + } + logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + + s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") + val spillFile = spillMemoryIteratorToDisk(inMemoryIterator) + forceSpillFiles.append(spillFile) + val spillReader = new SpillReader(spillFile) + nextUpstream = (0 until numPartitions).iterator.flatMap { p => + val iterator = spillReader.readNextPartition() + iterator.map(cur => ((p, cur._1), cur._2)) + } + true + } + } + + override def hasNext: Boolean = synchronized { + if (nextUpstream != null) { + upstream = nextUpstream + nextUpstream = null + } + val r = upstream.hasNext + if (r) { + cur = upstream.next() + } + r + } + + override def next(): ((Int, K), C) = cur + } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index be2077d753526..a532e5a27a1cf 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -77,7 +77,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = allocateHeapExecutionMemory(amountToRequest) + val granted = acquireOnHeapMemory(amountToRequest) myMemoryThreshold += granted // If we were granted too little memory to grow further (either tryToAcquire returned 0, // or we already had more memory than myMemoryThreshold), spill the current collection @@ -126,7 +126,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) * Release our memory back to the execution pool so that other tasks can grab it. 
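Aside: when the sorter's `SpillableIterator.spill()` above writes the remaining records out, the replacement upstream has to re-attach the partition id to what the per-partition `SpillReader` returns. A minimal sketch of that reshaping, assuming a `readPartition` callback in place of the spill reader:

{{{
// Rebuild a ((partitionId, key), value) iterator from per-partition readers.
// `readPartition` stands in for SpillReader.readNextPartition and must yield
// the partitions in order, one iterator per call.
def rejoinPartitions[K, C](
    numPartitions: Int,
    readPartition: () => Iterator[(K, C)]): Iterator[((Int, K), C)] = {
  (0 until numPartitions).iterator.flatMap { p =>
    readPartition().map { case (k, c) => ((p, k), c) }
  }
}
}}}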
*/ def releaseMemory(): Unit = { - freeHeapExecutionMemory(myMemoryThreshold) + freeOnHeapMemory(myMemoryThreshold) myMemoryThreshold = 0L } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index dc3185a6d505a..d122b79934186 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -418,4 +418,18 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { } } + test("force to spill for external aggregation") { + val conf = createSparkConf(loadDefaults = false) + .set("spark.shuffle.memoryFraction", "0.01") + .set("spark.memory.useLegacyMode", "true") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") + sc = new SparkContext("local", "test", conf) + val N = 2e6.toInt + sc.parallelize(1 to N, 10) + .map { i => (i, i) } + .groupByKey() + .reduceByKey(_ ++ _) + .count() + } + } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index a1a7ac97d924b..c8f26dcb931dd 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -608,4 +608,20 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } } } + + test("force to spill for sorting") { + val conf = createSparkConf(loadDefaults = false, kryo = false) + .set("spark.shuffle.memoryFraction", "0.01") + .set("spark.memory.useLegacyMode", "true") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") + sc = new SparkContext("local", "test", conf) + val N = 2e6.toInt + val p = new org.apache.spark.HashPartitioner(10) + val p2 = new org.apache.spark.HashPartitioner(5) + sc.parallelize(1 to N, 10) + .map { x => (x % 10000) -> x.toLong } + .repartitionAndSortWithinPartitions(p2) + .repartitionAndSortWithinPartitions(p) + .count() + } } From b84ad967f6d3fc0b62aafb3510acad2d5df695b9 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Wed, 20 Apr 2016 02:16:55 +0800 Subject: [PATCH 10/18] fix Mima & minor bug --- .../collection/ExternalAppendOnlyMap.scala | 25 +++++++++++++------ .../util/collection/ExternalSorter.scala | 25 +++++++++++++------ .../ExternalAppendOnlyMapSuite.scala | 3 ++- .../util/collection/ExternalSorterSuite.scala | 5 ++-- project/MimaExcludes.scala | 5 ++++ 5 files changed, 44 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 92bef0f38ef75..a78bf4241f6f2 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -572,32 +572,41 @@ class ExternalAppendOnlyMap[K, V, C]( private var nextUpstream: Iterator[(K, C)] = null - private var cur: (K, C) = null + private var cur: (K, C) = readNext() + + private var hasSpilled: Boolean = false def spill(): Boolean = synchronized { - if (upstream == null || nextUpstream != null) { + if (hasSpilled) { false } else { logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + s"it will release 
${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") nextUpstream = spillMemoryIteratorToDisk(upstream) + hasSpilled = true true } } - override def hasNext: Boolean = synchronized { + def readNext(): (K, C) = synchronized { if (nextUpstream != null) { upstream = nextUpstream nextUpstream = null } - val r = upstream.hasNext - if (r) { - cur = upstream.next() + if (upstream.hasNext) { + upstream.next() + } else { + null } - r } - override def next(): (K, C) = cur + override def hasNext(): Boolean = cur != null + + override def next(): (K, C) = { + val r = cur + cur = readNext() + r + } } /** Convenience function to hash the given (K, C) pair by the key. */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 14140dc73f3a8..8bc0861ef9fee 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -785,10 +785,12 @@ private[spark] class ExternalSorter[K, V, C]( private var nextUpstream: Iterator[((Int, K), C)] = null - private var cur: ((Int, K), C) = null + private var cur: ((Int, K), C) = readNext() + + private var hasSpilled: Boolean = false def spill(): Boolean = synchronized { - if (upstream == null || nextUpstream != null) { + if (hasSpilled) { false } else { val inMemoryIterator = new WritablePartitionedIterator { @@ -812,22 +814,29 @@ private[spark] class ExternalSorter[K, V, C]( val iterator = spillReader.readNextPartition() iterator.map(cur => ((p, cur._1), cur._2)) } + hasSpilled = true true } } - override def hasNext: Boolean = synchronized { + def readNext(): ((Int, K), C) = synchronized { if (nextUpstream != null) { upstream = nextUpstream nextUpstream = null } - val r = upstream.hasNext - if (r) { - cur = upstream.next() + if (upstream.hasNext) { + upstream.next() + } else { + null } - r } - override def next(): ((Int, K), C) = cur + override def hasNext(): Boolean = cur != null + + override def next(): ((Int, K), C) = { + val r = cur + cur = readNext() + r + } } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index d122b79934186..5dbe9c615b291 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -422,9 +422,10 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { val conf = createSparkConf(loadDefaults = false) .set("spark.shuffle.memoryFraction", "0.01") .set("spark.memory.useLegacyMode", "true") + .set("spark.testing.memory", "100000000") .set("spark.shuffle.sort.bypassMergeThreshold", "0") sc = new SparkContext("local", "test", conf) - val N = 2e6.toInt + val N = 2e5.toInt sc.parallelize(1 to N, 10) .map { i => (i, i) } .groupByKey() diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index c8f26dcb931dd..2f35153fbd192 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -609,13 +609,14 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } } - test("force to spill for 
sorting") { + test("force to spill for external sorter") { val conf = createSparkConf(loadDefaults = false, kryo = false) .set("spark.shuffle.memoryFraction", "0.01") .set("spark.memory.useLegacyMode", "true") + .set("spark.testing.memory", "100000000") .set("spark.shuffle.sort.bypassMergeThreshold", "0") sc = new SparkContext("local", "test", conf) - val N = 2e6.toInt + val N = 2e5.toInt val p = new org.apache.spark.HashPartitioner(10) val p2 = new org.apache.spark.HashPartitioner(5) sc.parallelize(1 to N, 10) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7730823f9411b..98d6219552e5c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -634,6 +634,11 @@ object MimaExcludes { // [SPARK-14628] Simplify task metrics by always tracking read/write metrics ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.readMethod"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.writeMethod") + ) ++ Seq( + // [SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.collection.ExternalAppendOnlyMap"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.collection.ExternalSorter"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.collection.Spillable") ) case v if v.startsWith("1.6") => Seq( From dc632f5642b6bee690b351efa1855402ef7bc716 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Wed, 20 Apr 2016 22:25:42 +0800 Subject: [PATCH 11/18] fix Mima --- project/MimaExcludes.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 98d6219552e5c..2e26443c9a15a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -636,9 +636,7 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.writeMethod") ) ++ Seq( // [SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.collection.ExternalAppendOnlyMap"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.collection.ExternalSorter"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.collection.Spillable") + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.Spillable") ) case v if v.startsWith("1.6") => Seq( From d1ed4e4c0f5d15a1e0f030c397c3f8c83482315a Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Thu, 21 Apr 2016 03:51:22 +0800 Subject: [PATCH 12/18] fix ut --- .../org/apache/spark/memory/MemoryConsumer.java | 3 +-- .../apache/spark/util/collection/Spillable.scala | 13 +++++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 1fd602b38e542..840f13b39464c 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -135,8 +135,7 @@ protected void freePage(MemoryBlock page) { * Allocates a heap memory of `size`. 
*/ public long acquireOnHeapMemory(long size) { - long granted = - taskMemoryManager.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this); + long granted = taskMemoryManager.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this); used += granted; return granted; } diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index a532e5a27a1cf..4e86475d51772 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -47,13 +47,18 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) // It's used for checking spilling frequency protected def addElementsRead(): Unit = { _elementsRead += 1 } + // Initial threshold for the size of a collection before we start tracking its memory usage + // For testing only + private[this] val initialMemoryThreshold: Long = + SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) + // Force this collection to spill when there are this many elements in memory // For testing only private[this] val numElementsForceSpillThreshold: Long = SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) // Threshold for this collection's size in bytes before we start tracking its memory usage - private[this] var myMemoryThreshold = 0L + private[this] var myMemoryThreshold = initialMemoryThreshold // Number of elements read from input since last spill private[this] var _elementsRead = 0L @@ -107,7 +112,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) 0L } else { _elementsRead = 0 - val freeMemory = myMemoryThreshold + val freeMemory = myMemoryThreshold - initialMemoryThreshold _memoryBytesSpilled += freeMemory releaseMemory() freeMemory @@ -126,8 +131,8 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) * Release our memory back to the execution pool so that other tasks can grab it. 
*/ def releaseMemory(): Unit = { - freeOnHeapMemory(myMemoryThreshold) - myMemoryThreshold = 0L + freeOnHeapMemory(myMemoryThreshold - initialMemoryThreshold) + myMemoryThreshold = initialMemoryThreshold } /** From 743ef16b0274f7d2ce7f435aed96d64316c8c77e Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Thu, 21 Apr 2016 03:54:36 +0800 Subject: [PATCH 13/18] update comments --- .../main/scala/org/apache/spark/util/collection/Spillable.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 4e86475d51772..e71689a4eca9a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -58,6 +58,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) // Threshold for this collection's size in bytes before we start tracking its memory usage + // To avoid memory leak for rdd.first(), initialize this to a value orders of magnitude > 0 private[this] var myMemoryThreshold = initialMemoryThreshold // Number of elements read from input since last spill From 97fd17483fe2efebd64a1a57dfe40aa16a46f625 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Thu, 21 Apr 2016 10:11:54 +0800 Subject: [PATCH 14/18] fix thread safe --- .../spark/util/collection/ExternalAppendOnlyMap.scala | 4 ++-- .../org/apache/spark/util/collection/ExternalSorter.scala | 8 ++++---- .../org/apache/spark/util/collection/Spillable.scala | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index a78bf4241f6f2..601a765b014fa 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -81,7 +81,7 @@ class ExternalAppendOnlyMap[K, V, C]( this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get()) } - private var currentMap = new SizeTrackingAppendOnlyMap[K, C] + @volatile private var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager @@ -115,7 +115,7 @@ class ExternalAppendOnlyMap[K, V, C]( private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() - private var readingIterator: SpillableIterator = null + @volatile private var readingIterator: SpillableIterator = null /** * Number of files this map has spilled so far. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 8bc0861ef9fee..501ff58155c4a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -124,8 +124,8 @@ private[spark] class ExternalSorter[K, V, C]( // Data structures to store in-memory objects before we spill. Depending on whether we have an // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we // store them in an array buffer. 
- private var map = new PartitionedAppendOnlyMap[K, C] - private var buffer = new PartitionedPairBuffer[K, C] + @volatile private var map = new PartitionedAppendOnlyMap[K, C] + @volatile private var buffer = new PartitionedPairBuffer[K, C] // Total spilling statistics private var _diskBytesSpilled = 0L @@ -135,9 +135,9 @@ private[spark] class ExternalSorter[K, V, C]( private var _peakMemoryUsedBytes: Long = 0L def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes - private var isShuffleSort: Boolean = true + @volatile private var isShuffleSort: Boolean = true private val forceSpillFiles = new ArrayBuffer[SpilledFile] - private var readingIterator: SpillableIterator = null + @volatile private var readingIterator: SpillableIterator = null // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index e71689a4eca9a..27b40e016da09 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -58,7 +58,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) // Threshold for this collection's size in bytes before we start tracking its memory usage - // To avoid memory leak for rdd.first(), initialize this to a value orders of magnitude > 0 + // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 private[this] var myMemoryThreshold = initialMemoryThreshold // Number of elements read from input since last spill From e009d95c715879269253da2b47e669ffc2e13683 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Thu, 21 Apr 2016 10:22:45 +0800 Subject: [PATCH 15/18] fix thread safe --- .../scala/org/apache/spark/util/collection/Spillable.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 27b40e016da09..aee6399eb0c8c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -41,7 +41,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) protected def forceSpill(): Boolean // Number of elements read from input since last spill - protected def elementsRead: Long = _elementsRead + @volatile protected def elementsRead: Long = _elementsRead // Called by subclasses every time a record is read // It's used for checking spilling frequency @@ -59,13 +59,13 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) // Threshold for this collection's size in bytes before we start tracking its memory usage // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 - private[this] var myMemoryThreshold = initialMemoryThreshold + @volatile private[this] var myMemoryThreshold = initialMemoryThreshold // Number of elements read from input since last spill private[this] var _elementsRead = 0L // Number of bytes spilled in total - private[this] var _memoryBytesSpilled = 0L + @volatile private[this] var _memoryBytesSpilled = 
0L // Number of spills private[this] var _spillCount = 0 From 7ea727470735cb2a420bd5411af0202d264d9ec7 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Thu, 21 Apr 2016 10:52:51 +0800 Subject: [PATCH 16/18] fix SpillableIterator --- .../spark/util/collection/ExternalAppendOnlyMap.scala | 6 ++++-- .../org/apache/spark/util/collection/ExternalSorter.scala | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 601a765b014fa..fc71f8365cd18 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -570,13 +570,15 @@ class ExternalAppendOnlyMap[K, V, C]( private[this] class SpillableIterator(var upstream: Iterator[(K, C)]) extends Iterator[(K, C)] { + private val SPILL_LOCK = new Object() + private var nextUpstream: Iterator[(K, C)] = null private var cur: (K, C) = readNext() private var hasSpilled: Boolean = false - def spill(): Boolean = synchronized { + def spill(): Boolean = SPILL_LOCK.synchronized { if (hasSpilled) { false } else { @@ -588,7 +590,7 @@ class ExternalAppendOnlyMap[K, V, C]( } } - def readNext(): (K, C) = synchronized { + def readNext(): (K, C) = SPILL_LOCK.synchronized { if (nextUpstream != null) { upstream = nextUpstream nextUpstream = null diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 501ff58155c4a..4067acee738ed 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -783,13 +783,15 @@ private[spark] class ExternalSorter[K, V, C]( private[this] class SpillableIterator(var upstream: Iterator[((Int, K), C)]) extends Iterator[((Int, K), C)] { + private val SPILL_LOCK = new Object() + private var nextUpstream: Iterator[((Int, K), C)] = null private var cur: ((Int, K), C) = readNext() private var hasSpilled: Boolean = false - def spill(): Boolean = synchronized { + def spill(): Boolean = SPILL_LOCK.synchronized { if (hasSpilled) { false } else { @@ -819,7 +821,7 @@ private[spark] class ExternalSorter[K, V, C]( } } - def readNext(): ((Int, K), C) = synchronized { + def readNext(): ((Int, K), C) = SPILL_LOCK.synchronized { if (nextUpstream != null) { upstream = nextUpstream nextUpstream = null From e7a98d57a31923406c204e15f72c7a43579653bb Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Thu, 21 Apr 2016 11:02:21 +0800 Subject: [PATCH 17/18] fix Mima --- project/MimaExcludes.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 65b0a97e4dc12..3c9f1532f9dcd 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -660,6 +660,9 @@ object MimaExcludes { // SPARK-14704: Create accumulators in TaskMetrics ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.this") + ) ++ Seq( + // [SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.Spillable") ) case v if v.startsWith("1.6") => Seq( From 
ff3c2b85bc50ac631684e443cb7a19df9359535e Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Thu, 21 Apr 2016 14:38:59 +0800 Subject: [PATCH 18/18] update ut --- .../util/collection/ExternalAppendOnlyMapSuite.scala | 2 +- .../spark/util/collection/ExternalSorterSuite.scala | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index b6273016104b9..5141e36d9e38d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -424,7 +424,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { .set("spark.shuffle.sort.bypassMergeThreshold", "0") sc = new SparkContext("local", "test", conf) val N = 2e5.toInt - sc.parallelize(1 to N, 10) + sc.parallelize(1 to N, 2) .map { i => (i, i) } .groupByKey() .reduceByKey(_ ++ _) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 2f35153fbd192..4dd8e31c27351 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -617,12 +617,12 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { .set("spark.shuffle.sort.bypassMergeThreshold", "0") sc = new SparkContext("local", "test", conf) val N = 2e5.toInt - val p = new org.apache.spark.HashPartitioner(10) - val p2 = new org.apache.spark.HashPartitioner(5) - sc.parallelize(1 to N, 10) - .map { x => (x % 10000) -> x.toLong } - .repartitionAndSortWithinPartitions(p2) + val p = new org.apache.spark.HashPartitioner(2) + val p2 = new org.apache.spark.HashPartitioner(3) + sc.parallelize(1 to N, 3) + .map { x => (x % 100000) -> x.toLong } .repartitionAndSortWithinPartitions(p) + .repartitionAndSortWithinPartitions(p2) .count() } }
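
The core mechanism this series introduces is an iterator whose in-memory upstream can be swapped for a disk-backed replacement while another thread forces a spill. Below is a minimal, self-contained sketch of that hand-off pattern, assuming nothing beyond the Scala standard library; the names SwappableIterator and swap are illustrative only, not Spark API, and the sketch deliberately omits the actual spill-to-disk step that the real SpillableIterator delegates to spillMemoryIteratorToDisk.

// Generic rendering of the SpillableIterator hand-off introduced in patches 10 and 16.
class SwappableIterator[T <: AnyRef](private var upstream: Iterator[T]) extends Iterator[T] {

  // Guards the hand-off between the current upstream and its replacement,
  // playing the role of SPILL_LOCK in patch 16/18.
  private val SWAP_LOCK = new Object()

  private var nextUpstream: Iterator[T] = null
  private var swapped = false

  // The element next() will hand back; fetched eagerly under the lock so a concurrent
  // swap can only take effect between element reads.
  private var cur: T = readNext()

  // Called from the spilling thread: install a replacement iterator (for example one
  // backed by a spill file) exactly once; later calls are no-ops, like hasSpilled.
  def swap(replacement: Iterator[T]): Boolean = SWAP_LOCK.synchronized {
    if (swapped) {
      false
    } else {
      nextUpstream = replacement
      swapped = true
      true
    }
  }

  private def readNext(): T = SWAP_LOCK.synchronized {
    if (nextUpstream != null) {
      upstream = nextUpstream
      nextUpstream = null
    }
    if (upstream.hasNext) upstream.next() else null.asInstanceOf[T]
  }

  override def hasNext: Boolean = cur != null

  override def next(): T = {
    val r = cur
    cur = readNext()
    r
  }
}

object SwappableIteratorExample extends App {
  val it = new SwappableIterator[String](Iterator("a", "b", "c"))
  println(it.next())            // "a", served from the in-memory upstream
  it.swap(Iterator("x", "y"))   // pretend these were read back from a spill file
  println(it.toList)            // List(b, x, y): "b" was pre-fetched before the swap
}

Pre-fetching the next element inside readNext(), under the same lock that swap() takes, is what makes the forced spill safe: the replacement iterator can only be installed between reads, never while next() is handing back a value, which is why patches 10 and 16 move the read into readNext() and guard it with SPILL_LOCK.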