UnsafeFixedWidthAggregationMap.java
@@ -20,7 +20,7 @@
import java.io.IOException;

import org.apache.spark.SparkEnv;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.TaskContext;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -84,7 +84,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) {
* @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
* @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
* @param groupingKeySchema the schema of the grouping key, used for row conversion.
-* @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
+* @param taskContext the current task context.
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
* @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
@@ -93,21 +93,28 @@ public UnsafeFixedWidthAggregationMap(
InternalRow emptyAggregationBuffer,
StructType aggregationBufferSchema,
StructType groupingKeySchema,
-TaskMemoryManager taskMemoryManager,
+TaskContext taskContext,
int initialCapacity,
long pageSizeBytes,
boolean enablePerfMetrics) {
this.aggregationBufferSchema = aggregationBufferSchema;
this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length());
this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
this.groupingKeySchema = groupingKeySchema;
-this.map =
-  new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
+this.map = new BytesToBytesMap(
+  taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, enablePerfMetrics);
this.enablePerfMetrics = enablePerfMetrics;

// Initialize the buffer for aggregation value
final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();

+// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
+// the end of the task. This is necessary to avoid memory leaks when the downstream operator
+// does not fully consume the aggregation map's output (e.g. aggregate followed by limit).
+taskContext.addTaskCompletionListener(context -> {
+  free();
+});
}

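The behavioral change in this file: the map now registers its own cleanup with the task, instead of relying on whoever drains its output iterator to call `free()`. A minimal Scala sketch of that pattern, assuming Spark's `TaskContext`/`TaskCompletionListener` API; `TaskScopedBuffer` is a hypothetical stand-in for the aggregation map, not code from this patch:

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

// Hypothetical holder of task-scoped memory, standing in for UnsafeFixedWidthAggregationMap.
class TaskScopedBuffer(taskContext: TaskContext) {
  private var freed = false

  // Register cleanup at construction time, so it runs at task end even if the downstream
  // consumer (e.g. a limit) abandons the iterator after the first row.
  taskContext.addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(context: TaskContext): Unit = free()
  })

  def free(): Unit = if (!freed) {
    freed = true
    // return any acquired pages to taskContext.taskMemoryManager() here
  }
}
```

Keeping `free()` idempotent (the `freed` flag) lets an operator that does drain the iterator release memory eagerly and still be safe when the task-completion callback fires later.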
HashAggregateExec.scala
@@ -310,7 +310,7 @@ case class HashAggregateExec(
initialBuffer,
bufferSchema,
groupingKeySchema,
-TaskContext.get().taskMemoryManager(),
+TaskContext.get(),
1024 * 16, // initial capacity
TaskContext.get().taskMemoryManager().pageSizeBytes,
false // disable tracking of performance metrics
TungstenAggregationIterator.scala
@@ -160,7 +160,7 @@ class TungstenAggregationIterator(
initialAggregationBuffer,
StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)),
StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
-TaskContext.get().taskMemoryManager(),
+TaskContext.get(),
1024 * 16, // initial capacity
TaskContext.get().taskMemoryManager().pageSizeBytes,
false // disable tracking of performance metrics
SQLQuerySuite.scala
@@ -2702,4 +2702,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
}

test("SPARK-25144 'distinct' causes memory leak") {
val ds = List(Foo(Some("bar"))).toDS
val result = ds.flatMap(_.bar).distinct
result.rdd.isEmpty
}
}

+case class Foo(bar: Option[String])
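For context on the new test: `distinct` is planned as a hash aggregate, and `rdd.isEmpty` only pulls the first row, so the aggregation map's output iterator is never drained; before this patch that left the map's memory unfreed at task end. A self-contained sketch of the same scenario, assuming a local session; the object name and the `spark.unsafe.exceptionOnMemoryLeak` setting (used here to turn a leaked page into a task failure rather than a log warning) are illustrative, not part of the patch:

```scala
import org.apache.spark.sql.SparkSession

case class Foo(bar: Option[String])

object DistinctLeakRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .config("spark.unsafe.exceptionOnMemoryLeak", "true") // fail the task on leaked memory
      .getOrCreate()
    import spark.implicits._

    val ds = List(Foo(Some("bar"))).toDS()
    // distinct runs through a hash aggregate; isEmpty stops after the first row,
    // leaving the aggregation map's iterator undrained.
    ds.flatMap(_.bar).distinct().rdd.isEmpty

    spark.stop()
  }
}
```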
UnsafeFixedWidthAggregationMapSuite.scala
@@ -23,6 +23,7 @@ import scala.collection.mutable
import scala.util.{Random, Try}
import scala.util.control.NonFatal

+import org.mockito.Mockito._
import org.scalatest.Matchers

import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
@@ -53,6 +54,8 @@ class UnsafeFixedWidthAggregationMapSuite
private var memoryManager: TestMemoryManager = null
private var taskMemoryManager: TaskMemoryManager = null

+private var taskContext: TaskContext = null

def testWithMemoryLeakDetection(name: String)(f: => Unit) {
def cleanup(): Unit = {
if (taskMemoryManager != null) {
@@ -66,6 +69,8 @@ class UnsafeFixedWidthAggregationMapSuite
val conf = new SparkConf().set("spark.memory.offHeap.enabled", "false")
memoryManager = new TestMemoryManager(conf)
taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
+taskContext = mock(classOf[TaskContext])
+when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)

TaskContext.setTaskContext(new TaskContextImpl(
stageId = 0,
@@ -110,7 +115,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
-taskMemoryManager,
+taskContext,
1024, // initial capacity,
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
-taskMemoryManager,
+taskContext,
1024, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
-taskMemoryManager,
+taskContext,
128, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -177,7 +182,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
-taskMemoryManager,
+taskContext,
128, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -225,7 +230,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
-taskMemoryManager,
+taskContext,
128, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -266,7 +271,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
StructType(Nil),
StructType(Nil),
-taskMemoryManager,
+taskContext,
128, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -311,7 +316,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
-taskMemoryManager,
+taskContext,
128, // initial capacity
pageSize,
false // disable perf metrics
@@ -349,7 +354,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
-taskMemoryManager,
+taskContext,
128, // initial capacity
pageSize,
false // disable perf metrics