From 174c4e55b897beaf51e395376bbb3d651d394d94 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 10 Jul 2018 00:18:31 +0800 Subject: [PATCH 1/2] free aggregate map when task ends --- .../spark/internal/config/package.scala | 8 +++++++ .../apache/spark/memory/MemoryManager.scala | 2 +- .../UnsafeFixedWidthAggregationMap.java | 17 ++++++++++----- .../spark/sql/execution/SparkStrategies.scala | 7 +------ .../aggregate/HashAggregateExec.scala | 2 +- .../TungstenAggregationIterator.scala | 2 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 21 ++++++++++++------- 7 files changed, 37 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index ba892bf7f60d6..e75bdcc00cec6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -563,4 +563,12 @@ package object config { .intConf .checkValue(v => v > 0, "The value should be a positive integer.") .createWithDefault(2000) + + private[spark] val BUFFER_PAGE_SIZE = + ConfigBuilder("spark.buffer.pageSize") + .internal() + .doc("The page size(in bytes) of the data buffers used in Sorter, HashMap, etc.") + .bytesConf(ByteUnit.BYTE) + .checkValue(_ > 0, "page size must be positive.") + .createOptional } diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 0641adc2ab699..0e937ab317e65 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -225,7 +225,7 @@ private[spark] abstract class MemoryManager( } val size = ByteArrayMethods.nextPowerOf2(maxTungstenMemory / cores / safetyFactor) val default = math.min(maxPageSize, math.max(minPageSize, size)) - conf.getSizeAsBytes("spark.buffer.pageSize", default) + conf.get(BUFFER_PAGE_SIZE).getOrElse(default) } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index c7c4c7b3e7715..7f44642afc684 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -20,8 +20,8 @@ import java.io.IOException; import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; import org.apache.spark.internal.config.package$; -import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -82,7 +82,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) { * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. * @param groupingKeySchema the schema of the grouping key, used for row conversion. - * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures. + * @param taskContext the current task context. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. */ @@ -90,19 +90,26 @@ public UnsafeFixedWidthAggregationMap( InternalRow emptyAggregationBuffer, StructType aggregationBufferSchema, StructType groupingKeySchema, - TaskMemoryManager taskMemoryManager, + TaskContext taskContext, int initialCapacity, long pageSizeBytes) { this.aggregationBufferSchema = aggregationBufferSchema; this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length()); this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; - this.map = - new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, true); + this.map = new BytesToBytesMap( + taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true); // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the sorter's output (e.g. sort followed by limit). + taskContext.addTaskCompletionListener(context -> { + free(); + }); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 07a6fcae83b70..cfbcb9aad65c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -73,12 +73,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), child) => - // With whole stage codegen, Spark releases resources only when all the output data of the - // query plan are consumed. It's possible that `CollectLimitExec` only consumes a little - // data from child plan and finishes the query without releasing resources. Here we wrap - // the child plan with `LocalLimitExec`, to stop the processing of whole stage codegen and - // trigger the resource releasing work, after we consume `limit` rows. - CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + CollectLimitExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } case Limit(IntegerLiteral(limit), Sort(order, true, child)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8c7b2c187cccd..2cac0cfce28de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -328,7 +328,7 @@ case class HashAggregateExec( initialBuffer, bufferSchema, groupingKeySchema, - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 9dc334c1ead3c..c1911235f8df3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -166,7 +166,7 @@ class TungstenAggregationIterator( initialAggregationBuffer, StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 3e31d22e15c0e..5c15ecd42fa0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,6 +23,7 @@ import scala.collection.mutable import scala.util.{Random, Try} import scala.util.control.NonFatal +import org.mockito.Mockito._ import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} @@ -54,6 +55,8 @@ class UnsafeFixedWidthAggregationMapSuite private var memoryManager: TestMemoryManager = null private var taskMemoryManager: TaskMemoryManager = null + private var taskContext: TaskContext = null + def testWithMemoryLeakDetection(name: String)(f: => Unit) { def cleanup(): Unit = { if (taskMemoryManager != null) { @@ -67,6 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite val conf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + taskContext = mock(classOf[TaskContext]) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -111,7 +116,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity, PAGE_SIZE_BYTES ) @@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity PAGE_SIZE_BYTES ) @@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -176,7 +181,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -223,7 +228,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -263,7 +268,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, StructType(Nil), StructType(Nil), - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -307,7 +312,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) @@ -344,7 +349,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) From 7f2b653bb2fdeab7f52014e39eba3be7219bb961 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 10 Jul 2018 15:45:26 +0800 Subject: [PATCH 2/2] address comments --- .../scala/org/apache/spark/internal/config/package.scala | 8 -------- .../scala/org/apache/spark/memory/MemoryManager.scala | 2 +- .../sql/execution/UnsafeFixedWidthAggregationMap.java | 2 +- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index e75bdcc00cec6..ba892bf7f60d6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -563,12 +563,4 @@ package object config { .intConf .checkValue(v => v > 0, "The value should be a positive integer.") .createWithDefault(2000) - - private[spark] val BUFFER_PAGE_SIZE = - ConfigBuilder("spark.buffer.pageSize") - .internal() - .doc("The page size(in bytes) of the data buffers used in Sorter, HashMap, etc.") - .bytesConf(ByteUnit.BYTE) - .checkValue(_ > 0, "page size must be positive.") - .createOptional } diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 0e937ab317e65..0641adc2ab699 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -225,7 +225,7 @@ private[spark] abstract class MemoryManager( } val size = ByteArrayMethods.nextPowerOf2(maxTungstenMemory / cores / safetyFactor) val default = math.min(maxPageSize, math.max(minPageSize, size)) - conf.get(BUFFER_PAGE_SIZE).getOrElse(default) + conf.getSizeAsBytes("spark.buffer.pageSize", default) } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 7f44642afc684..c8cf44b51df77 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -106,7 +106,7 @@ public UnsafeFixedWidthAggregationMap( // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator - // does not fully consume the sorter's output (e.g. sort followed by limit). + // does not fully consume the aggregation map's output (e.g. aggregate followed by limit). taskContext.addTaskCompletionListener(context -> { free(); });