Skip to content
128 changes: 128 additions & 0 deletions core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.memory;


import java.io.IOException;

import org.apache.spark.unsafe.memory.MemoryBlock;


/**
* An memory consumer of TaskMemoryManager, which support spilling.
*/
public abstract class MemoryConsumer {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about naming this class SpillableMemoryConsumer ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it too long?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The length is about the same as TaskMemoryManager - so not too long.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm neutral on the name change. At first I thought that the name MemoryConsumer might not make sense if it was used by places that can't spill, but I suppose that those places could just have spill() return 0. So I'm fine sticking with the current name.


private final TaskMemoryManager taskMemoryManager;
private final long pageSize;
private long used;

protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) {
this.taskMemoryManager = taskMemoryManager;
this.pageSize = pageSize;
this.used = 0;
}

protected MemoryConsumer(TaskMemoryManager taskMemoryManager) {
this(taskMemoryManager, taskMemoryManager.pageSizeBytes());
}

/**
* Returns the size of used memory in bytes.
*/
long getUsed() {
return used;
}

/**
* Force spill during building.
*
* For testing.
*/
public void spill() throws IOException {
spill(Long.MAX_VALUE, this);
}

/**
* Spill some data to disk to release memory, which will be called by TaskMemoryManager
* when there is not enough memory for the task.
*
* This should be implemented by subclass.
*
* Note: In order to avoid possible deadlock, should not call acquireMemory() from spill().
*
* @param size the amount of memory should be released
* @param trigger the MemoryConsumer that trigger this spilling
* @return the amount of released memory in bytes
* @throws IOException
*/
public abstract long spill(long size, MemoryConsumer trigger) throws IOException;

/**
* Acquire `size` bytes memory.
*
* If there is not enough memory, throws OutOfMemoryError.
*/
protected void acquireMemory(long size) {
long got = taskMemoryManager.acquireExecutionMemory(size, this);
if (got < size) {
taskMemoryManager.releaseExecutionMemory(got, this);
taskMemoryManager.showMemoryUsage();
throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got);
}
used += got;
}

/**
* Release `size` bytes memory.
*/
protected void releaseMemory(long size) {
used -= size;
taskMemoryManager.releaseExecutionMemory(size, this);
}

/**
* Allocate a memory block with at least `required` bytes.
*
* Throws IOException if there is not enough memory.
*
* @throws OutOfMemoryError
*/
protected MemoryBlock allocatePage(long required) {
MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this);
if (page == null || page.size() < required) {
long got = 0;
if (page != null) {
got = page.size();
freePage(page);
}
taskMemoryManager.showMemoryUsage();
throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);
}
used += page.size();
return page;
}

/**
* Free a memory block.
*/
protected void freePage(MemoryBlock page) {
used -= page.size();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, maybe an invalid concern, but is it safe to call page.size() on a freed page?

taskMemoryManager.freePage(page, this);
}
}
138 changes: 108 additions & 30 deletions core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@

package org.apache.spark.memory;

import java.util.*;
import javax.annotation.concurrent.GuardedBy;
import java.io.IOException;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashSet;

import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.Utils;

/**
* Manages the memory allocated by an individual task.
Expand Down Expand Up @@ -100,30 +105,105 @@ public class TaskMemoryManager {
*/
private final boolean inHeap;

/**
* The size of memory granted to each consumer.
*/
@GuardedBy("this")
private final HashSet<MemoryConsumer> consumers;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment to explain that this field is guarded by synchronizing on this (or use @GuardedBy("this")).


/**
* Construct a new TaskMemoryManager.
*/
public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap();
this.memoryManager = memoryManager;
this.taskAttemptId = taskAttemptId;
this.consumers = new HashSet<>();
}

/**
* Acquire N bytes of memory for execution, evicting cached blocks if necessary.
* Acquire N bytes of memory for a consumer. If there is no enough memory, it will call
* spill() of consumers to release more memory.
*
* @return number of bytes successfully granted (<= N).
*/
public long acquireExecutionMemory(long size) {
return memoryManager.acquireExecutionMemory(size, taskAttemptId);
public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
assert(required >= 0);
synchronized (this) {
long got = memoryManager.acquireExecutionMemory(required, taskAttemptId);

// try to release memory from other consumers first, then we can reduce the frequency of
// spilling, avoid to have too many spilled files.
if (got < required) {
// Call spill() on other consumers to release memory
for (MemoryConsumer c: consumers) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this approach still have the same concern about concurrent modification of consumers while iterating over it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we never remove it, and it will not add more under this lock.

if (c != null && c != consumer && c.getUsed() > 0) {
try {
long released = c.spill(required - got, consumer);
if (released > 0) {
logger.info("Task {} released {} from {} for {}", taskAttemptId,
Utils.bytesToString(released), c, consumer);
got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId);
if (got >= required) {
break;
}
}
} catch (IOException e) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this catch clause be moved to wrap c.spill() at line 142 ?

logger.error("error while calling spill() on " + c, e);
throw new OutOfMemoryError("error while calling spill() on " + c + " : "
+ e.getMessage());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using string concatenation to pass the IOException's message, why not use regular exception chaining here? Does OutOfMemoryError not support that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not support that.

}
}
}
}

// call spill() on itself
if (got < required && consumer != null) {
try {
long released = consumer.spill(required - got, consumer);
if (released > 0) {
logger.info("Task {} released {} from itself ({})", taskAttemptId,
Utils.bytesToString(released), consumer);
got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId);
}
} catch (IOException e) {
logger.error("error while calling spill() on " + consumer, e);
throw new OutOfMemoryError("error while calling spill() on " + consumer + " : "
+ e.getMessage());
}
}

consumers.add(consumer);
logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
return got;
}
}

/**
* Release N bytes of execution memory.
* Release N bytes of execution memory for a MemoryConsumer.
*/
public void releaseExecutionMemory(long size) {
public void releaseExecutionMemory(long size, MemoryConsumer consumer) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an assert to make sure size >= 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer);
memoryManager.releaseExecutionMemory(size, taskAttemptId);
}

/**
* Dump the memory usage of all consumers.
*/
public void showMemoryUsage() {
logger.info("Memory used in task " + taskAttemptId);
synchronized (this) {
for (MemoryConsumer c: consumers) {
if (c.getUsed() > 0) {
logger.info("Acquired by " + c + ": " + Utils.bytesToString(c.getUsed()));
}
}
}
}

/**
* Return the page size in bytes.
*/
public long pageSizeBytes() {
return memoryManager.pageSizeBytes();
}
Expand All @@ -134,42 +214,40 @@ public long pageSizeBytes() {
*
* Returns `null` if there was not enough memory to allocate the page.
*/
public MemoryBlock allocatePage(long size) {
public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
if (size > MAXIMUM_PAGE_SIZE_BYTES) {
throw new IllegalArgumentException(
"Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes");
}

long acquired = acquireExecutionMemory(size, consumer);
if (acquired <= 0) {
return null;
}

final int pageNumber;
synchronized (this) {
pageNumber = allocatedPages.nextClearBit(0);
if (pageNumber >= PAGE_TABLE_SIZE) {
releaseExecutionMemory(acquired, consumer);
throw new IllegalStateException(
"Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
}
allocatedPages.set(pageNumber);
}
final long acquiredExecutionMemory = acquireExecutionMemory(size);
if (acquiredExecutionMemory != size) {
releaseExecutionMemory(acquiredExecutionMemory);
synchronized (this) {
allocatedPages.clear(pageNumber);
}
return null;
}
final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(size);
final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(acquired);
page.pageNumber = pageNumber;
pageTable[pageNumber] = page;
if (logger.isTraceEnabled()) {
logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired);
}
return page;
}

/**
* Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
* Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}.
*/
public void freePage(MemoryBlock page) {
public void freePage(MemoryBlock page, MemoryConsumer consumer) {
assert (page.pageNumber != -1) :
"Called freePage() on memory that wasn't allocated with allocatePage()";
assert(allocatedPages.get(page.pageNumber));
Expand All @@ -182,14 +260,14 @@ public void freePage(MemoryBlock page) {
}
long pageSize = page.size();
memoryManager.tungstenMemoryAllocator().free(page);
releaseExecutionMemory(pageSize);
releaseExecutionMemory(pageSize, consumer);
}

/**
* Given a memory page and offset within that page, encode this address into a 64-bit long.
* This address will remain valid as long as the corresponding page has not been freed.
*
* @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}/
* @param page a data page allocated by {@link TaskMemoryManager#allocatePage}/
* @param offsetInPage an offset in this page which incorporates the base offset. In other words,
* this should be the value that you would pass as the base offset into an
* UNSAFE call (e.g. page.baseOffset() + something).
Expand Down Expand Up @@ -261,17 +339,17 @@ public long getOffsetInPage(long pagePlusOffsetAddress) {
* value can be used to detect memory leaks.
*/
public long cleanUpAllAllocatedMemory() {
long freedBytes = 0;
for (MemoryBlock page : pageTable) {
if (page != null) {
freedBytes += page.size();
freePage(page);
synchronized (this) {
Arrays.fill(pageTable, null);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea.

for (MemoryConsumer c: consumers) {
if (c != null && c.getUsed() > 0) {
// In case of failed task, it's normal to see leaked memory
logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c);
}
}
consumers.clear();
}

freedBytes += memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);

return freedBytes;
return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
}

/**
Expand Down
Loading