UnsafeFixedWidthAggregationMap.java
@@ -20,8 +20,8 @@
 import java.io.IOException;
 
 import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
 import org.apache.spark.internal.config.package$;
-import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -82,27 +82,34 @@ public static boolean supportsAggregationBufferSchema(StructType schema) {
    * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
    * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
    * @param groupingKeySchema the schema of the grouping key, used for row conversion.
-   * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
+   * @param taskContext the current task context.
    * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
    * @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
    */
   public UnsafeFixedWidthAggregationMap(
       InternalRow emptyAggregationBuffer,
       StructType aggregationBufferSchema,
       StructType groupingKeySchema,
-      TaskMemoryManager taskMemoryManager,
+      TaskContext taskContext,
       int initialCapacity,
       long pageSizeBytes) {
     this.aggregationBufferSchema = aggregationBufferSchema;
     this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length());
     this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
     this.groupingKeySchema = groupingKeySchema;
-    this.map =
-      new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, true);
+    this.map = new BytesToBytesMap(
+      taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true);
 
     // Initialize the buffer for aggregation value
     final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
     this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
+
+    // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
+    // the end of the task. This is necessary to avoid memory leaks when the downstream operator
+    // does not fully consume the aggregation map's output (e.g. aggregate followed by limit).
+    taskContext.addTaskCompletionListener(context -> {
+      free();
+    });
   }

/**
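This constructor is the heart of the fix: by taking the whole `TaskContext` instead of just a `TaskMemoryManager`, the map can register a completion listener that frees its `BytesToBytesMap` pages even when a downstream operator (such as a limit) abandons the output iterator early. A minimal, self-contained sketch of the same pattern in Scala, with a hypothetical `Resource` class standing in for the map; note that `TaskContext.get()` is only non-null inside a running task on an executor:

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

// Hypothetical resource standing in for the map's allocated pages.
class Resource {
  private var freed = false
  def free(): Unit = if (!freed) { freed = true /* release memory here */ }
}

def openResourceInTask(): Resource = {
  val resource = new Resource
  // Fires when the task ends -- success, failure, or an output iterator that
  // was never fully drained -- which is exactly the guarantee the constructor
  // above now gets from taskContext.addTaskCompletionListener.
  TaskContext.get().addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(context: TaskContext): Unit = resource.free()
  })
  resource
}
```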
HashAggregateExec.scala
@@ -324,7 +324,7 @@ case class HashAggregateExec(
       initialBuffer,
       bufferSchema,
       groupingKeySchema,
-      TaskContext.get().taskMemoryManager(),
+      TaskContext.get(),
       1024 * 16, // initial capacity
       TaskContext.get().taskMemoryManager().pageSizeBytes
     )
TungstenAggregationIterator.scala
@@ -166,7 +166,7 @@ class TungstenAggregationIterator(
       initialAggregationBuffer,
       StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)),
       StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
-      TaskContext.get().taskMemoryManager(),
+      TaskContext.get(),
       1024 * 16, // initial capacity
       TaskContext.get().taskMemoryManager().pageSizeBytes
     )
SQLQuerySuite.scala
@@ -2818,4 +2818,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count === count)
}
}

test("SPARK-25144 'distinct' causes memory leak") {
val ds = List(Foo(Some("bar"))).toDS
val result = ds.flatMap(_.bar).distinct
result.rdd.isEmpty
}
}

case class Foo(bar: Option[String])
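The regression test is deliberately tiny: `distinct` is planned as a hash aggregate, and `rdd.isEmpty` reads at most one row per partition, so the aggregation map's output iterator is never exhausted. A standalone sketch of the same scenario, assuming a local SparkSession; before the fix this pattern tripped Spark's managed-memory leak detection at task end:

```scala
import org.apache.spark.sql.SparkSession

case class Foo(bar: Option[String])

object Spark25144Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("SPARK-25144 repro").getOrCreate()
    import spark.implicits._
    // distinct runs through UnsafeFixedWidthAggregationMap; isEmpty stops
    // consuming after the first row, leaving the map's output unconsumed.
    val result = List(Foo(Some("bar"))).toDS().flatMap(_.bar).distinct()
    println(result.rdd.isEmpty)
    spark.stop()
  }
}
```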
UnsafeFixedWidthAggregationMapSuite.scala
@@ -23,6 +23,7 @@ import scala.collection.mutable
 import scala.util.{Random, Try}
 import scala.util.control.NonFatal
 
+import org.mockito.Mockito._
 import org.scalatest.Matchers
 
 import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
@@ -54,6 +55,8 @@ class UnsafeFixedWidthAggregationMapSuite
   private var memoryManager: TestMemoryManager = null
   private var taskMemoryManager: TaskMemoryManager = null
 
+  private var taskContext: TaskContext = null
+
   def testWithMemoryLeakDetection(name: String)(f: => Unit) {
     def cleanup(): Unit = {
       if (taskMemoryManager != null) {
@@ -67,6 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite
     val conf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false")
     memoryManager = new TestMemoryManager(conf)
     taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
+    taskContext = mock(classOf[TaskContext])
+    when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
 
     TaskContext.setTaskContext(new TaskContextImpl(
       stageId = 0,
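Because the constructor now demands a `TaskContext`, the suite stubs one with Mockito rather than spinning up a real task. Only `taskMemoryManager()` needs stubbing; the mock's `addTaskCompletionListener` is a default no-op, so the map's cleanup callback simply never fires here and the tests release memory explicitly. A condensed sketch of that scaffolding (`TestMemoryManager` is Spark's test-scope memory manager):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.TaskContext
import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
import org.mockito.Mockito._

// A stubbed TaskContext is all the map needs to be constructed in a unit test.
val memoryManager = new TestMemoryManager(new SparkConf())
val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
val taskContext = mock(classOf[TaskContext])
when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
```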
@@ -111,7 +116,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       1024, // initial capacity,
       PAGE_SIZE_BYTES
     )
@@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       1024, // initial capacity
       PAGE_SIZE_BYTES
     )
@@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       PAGE_SIZE_BYTES
     )
@@ -176,7 +181,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       PAGE_SIZE_BYTES
     )
@@ -223,7 +228,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       PAGE_SIZE_BYTES
     )
@@ -263,7 +268,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       StructType(Nil),
       StructType(Nil),
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       PAGE_SIZE_BYTES
     )
@@ -307,7 +312,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       pageSize
     )
@@ -344,7 +349,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       pageSize
     )