Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ protected void handleMessage(
int numRemovedBlocks = blockManager.removeBlocks(msg.appId, msg.execId, msg.blockIds);
callback.onSuccess(new BlocksRemoved(numRemovedBlocks).toByteBuffer());

} else if (msgObj instanceof GetLocalDirsForExecutors) {
GetLocalDirsForExecutors msg = (GetLocalDirsForExecutors) msgObj;
checkAuth(client, msg.appId);
Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds);
callback.onSuccess(new LocalDirsForExecutors(localDirs).toByteBuffer());

} else {
throw new UnsupportedOperationException("Unexpected message: " + msgObj);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,21 @@
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;

import com.codahale.metrics.MetricSet;
import com.google.common.collect.Lists;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.shuffle.protocol.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.crypto.AuthClientBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.NoOpRpcHandler;
Expand Down Expand Up @@ -182,14 +183,54 @@ public void onSuccess(ByteBuffer response) {
@Override
public void onFailure(Throwable e) {
logger.warn("Error trying to remove RDD blocks " + Arrays.toString(blockIds) +
" via external shuffle service from executor: " + execId, e);
" via external shuffle service from executor: " + execId, e);
numRemovedBlocksFuture.complete(0);
client.close();
}
});
return numRemovedBlocksFuture;
}

/**
 * Requests the local directories of the given executors from the shuffle service listening
 * on {@code host:port} and reports the outcome through {@code hostLocalDirsCompletable}.
 *
 * <p>The reply is handled in an RPC callback: on success the decoded
 * {@link LocalDirsForExecutors} reply is converted to a map and used to complete the future;
 * on any failure (client creation, RPC failure, or an error while decoding the reply) the
 * future is completed exceptionally. The client is closed once a reply or failure arrives.
 *
 * @param host host of the shuffle service to query
 * @param port port of the shuffle service to query
 * @param execIds IDs of the executors whose local directories are requested
 * @param hostLocalDirsCompletable future completed with the executor ID to local dirs mapping,
 *                                 or completed exceptionally on error
 */
public void getHostLocalDirs(
    String host,
    int port,
    String[] execIds,
    CompletableFuture<Map<String, String[]>> hostLocalDirsCompletable) {
  checkInit();
  GetLocalDirsForExecutors getLocalDirsMessage = new GetLocalDirsForExecutors(appId, execIds);
  try {
    TransportClient client = clientFactory.createClient(host, port);
    client.sendRpc(getLocalDirsMessage.toByteBuffer(), new RpcResponseCallback() {
      @Override
      public void onSuccess(ByteBuffer response) {
        try {
          BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(response);
          hostLocalDirsCompletable.complete(
            ((LocalDirsForExecutors) msgObj).getLocalDirsByExec());
        } catch (Throwable t) {
          // Log the throwable itself: its cause is frequently null, which would hide
          // both the message and the stack trace of the actual failure.
          logger.warn("Error trying to get the host local dirs for " +
            Arrays.toString(getLocalDirsMessage.execIds) + " via external shuffle service",
            t);
          hostLocalDirsCompletable.completeExceptionally(t);
        } finally {
          client.close();
        }
      }

      @Override
      public void onFailure(Throwable t) {
        // Same as above: log the failure itself rather than its (possibly null) cause.
        logger.warn("Error trying to get the host local dirs for " +
          Arrays.toString(getLocalDirsMessage.execIds) + " via external shuffle service",
          t);
        hostLocalDirsCompletable.completeExceptionally(t);
        client.close();
      }
    });
  } catch (IOException e) {
    hostLocalDirsCompletable.completeExceptionally(e);
  } catch (InterruptedException e) {
    // Restore the interrupt flag so code further up the stack can still observe it.
    Thread.currentThread().interrupt();
    hostLocalDirsCompletable.completeExceptionally(e);
  }
}

@Override
public void close() {
checkInit();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import org.apache.commons.lang3.tuple.Pair;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -369,6 +371,19 @@ public int removeBlocks(String appId, String execId, String[] blockIds) {
return numRemovedBlocks;
}

/**
 * Returns the registered local directories for each of the requested executors of an
 * application.
 *
 * @param appId ID of the application the executors belong to
 * @param execIds IDs of the executors to look up; repeated IDs are collapsed to one entry
 * @return map from executor ID to that executor's local directories
 * @throws RuntimeException if any requested executor is not registered
 */
public Map<String, String[]> getLocalDirs(String appId, String[] execIds) {
  return Arrays.stream(execIds)
    // Collectors.toMap (without a merge function) throws IllegalStateException on a
    // duplicate key, so a request repeating an executor ID must be de-duplicated first.
    .distinct()
    .map(exec -> {
      ExecutorShuffleInfo info = executors.get(new AppExecId(appId, exec));
      if (info == null) {
        throw new RuntimeException(
          String.format("Executor is not registered (appId=%s, execId=%s)", appId, exec));
      }
      return Pair.of(exec, info.localDirs);
    })
    .collect(Collectors.toMap(Pair::getKey, Pair::getValue));
}

/** Simply encodes an executor's full ID, which is appId + execId. */
public static class AppExecId {
public final String appId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public abstract class BlockTransferMessage implements Encodable {
public enum Type {
OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8),
FETCH_SHUFFLE_BLOCKS(9);
FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11);

private final byte id;

Expand Down Expand Up @@ -76,6 +76,8 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
case 7: return RemoveBlocks.decode(buf);
case 8: return BlocksRemoved.decode(buf);
case 9: return FetchShuffleBlocks.decode(buf);
case 10: return GetLocalDirsForExecutors.decode(buf);
case 11: return LocalDirsForExecutors.decode(buf);
default: throw new IllegalArgumentException("Unknown message type: " + type);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public String toString() {
/**
 * Two {@code BlocksRemoved} messages are equal when they report the same number of
 * removed blocks.
 */
public boolean equals(Object other) {
  // instanceof is already false for null, so no separate null check is needed.
  if (other instanceof BlocksRemoved) {
    BlocksRemoved o = (BlocksRemoved) other;
    // Single primitive comparison; a second return after this one would be unreachable.
    return numRemovedBlocks == o.numRemovedBlocks;
  }
  return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public boolean equals(Object other) {
if (other != null && other instanceof ExecutorShuffleInfo) {
ExecutorShuffleInfo o = (ExecutorShuffleInfo) other;
return Arrays.equals(localDirs, o.localDirs)
&& Objects.equal(subDirsPerLocalDir, o.subDirsPerLocalDir)
&& subDirsPerLocalDir == o.subDirsPerLocalDir
&& Objects.equal(shuffleManager, o.shuffleManager);
}
return false;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.shuffle.protocol;

import java.util.Arrays;

import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;

import org.apache.spark.network.protocol.Encoders;

// Needed by ScalaDoc. See SPARK-7726
import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;

/** Request to get the local dirs for the given executors. */
public class GetLocalDirsForExecutors extends BlockTransferMessage {
  /** ID of the application the executors belong to. */
  public final String appId;
  /** IDs of the executors whose local dirs are requested. */
  public final String[] execIds;

  /**
   * @param appId ID of the application the executors belong to
   * @param execIds IDs of the executors whose local dirs are requested (stored as-is, not copied)
   */
  public GetLocalDirsForExecutors(String appId, String[] execIds) {
    this.appId = appId;
    this.execIds = execIds;
  }

  @Override
  protected Type type() { return Type.GET_LOCAL_DIRS_FOR_EXECUTORS; }

  @Override
  public int hashCode() {
    return Objects.hashCode(appId) * 41 + Arrays.hashCode(execIds);
  }

  @Override
  public String toString() {
    return Objects.toStringHelper(this)
      .add("appId", appId)
      .add("execIds", Arrays.toString(execIds))
      .toString();
  }

  /** Compares the application ID and the executor IDs (element-wise, order-sensitive). */
  @Override
  public boolean equals(Object other) {
    if (other instanceof GetLocalDirsForExecutors) {
      GetLocalDirsForExecutors o = (GetLocalDirsForExecutors) other;
      // Null-safe comparison, matching the null-safe Objects.hashCode(appId) in hashCode().
      return Objects.equal(appId, o.appId) && Arrays.equals(execIds, o.execIds);
    }
    return false;
  }

  @Override
  public int encodedLength() {
    return Encoders.Strings.encodedLength(appId) + Encoders.StringArrays.encodedLength(execIds);
  }

  /** Writes the application ID followed by the executor IDs; {@link #decode} reads them back in that order. */
  @Override
  public void encode(ByteBuf buf) {
    Encoders.Strings.encode(buf, appId);
    Encoders.StringArrays.encode(buf, execIds);
  }

  /** Reads a message from {@code buf} in the field order written by {@link #encode}. */
  public static GetLocalDirsForExecutors decode(ByteBuf buf) {
    String appId = Encoders.Strings.decode(buf);
    String[] execIds = Encoders.StringArrays.decode(buf);
    return new GetLocalDirsForExecutors(appId, execIds);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.shuffle.protocol;

import java.util.*;

import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;

import org.apache.spark.network.protocol.Encoders;

// Needed by ScalaDoc. See SPARK-7726
import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;

/**
 * The reply to get local dirs giving back the dirs for each of the requested executors.
 *
 * <p>The mapping is stored flattened into three arrays: the executor IDs, the number of
 * local dirs for each executor (same order), and all local dirs concatenated in that order.
 */
public class LocalDirsForExecutors extends BlockTransferMessage {
  private final String[] execIds;
  private final int[] numLocalDirsByExec;
  private final String[] allLocalDirs;

  /**
   * Flattens the given mapping into the array representation of this message.
   *
   * @param localDirsByExec local dirs keyed by executor ID
   */
  public LocalDirsForExecutors(Map<String, String[]> localDirsByExec) {
    this.execIds = new String[localDirsByExec.size()];
    this.numLocalDirsByExec = new int[localDirsByExec.size()];
    List<String> localDirs = new ArrayList<>();
    int index = 0;
    for (Map.Entry<String, String[]> e: localDirsByExec.entrySet()) {
      execIds[index] = e.getKey();
      numLocalDirsByExec[index] = e.getValue().length;
      Collections.addAll(localDirs, e.getValue());
      index++;
    }
    this.allLocalDirs = localDirs.toArray(new String[0]);
  }

  /** Used by {@link #decode} to rebuild a message directly from its flattened arrays. */
  private LocalDirsForExecutors(String[] execIds, int[] numLocalDirsByExec, String[] allLocalDirs) {
    this.execIds = execIds;
    this.numLocalDirsByExec = numLocalDirsByExec;
    this.allLocalDirs = allLocalDirs;
  }

  @Override
  protected Type type() { return Type.LOCAL_DIRS_FOR_EXECUTORS; }

  /** Hashes every field that {@link #equals} compares. */
  @Override
  public int hashCode() {
    int result = Arrays.hashCode(execIds);
    result = 31 * result + Arrays.hashCode(numLocalDirsByExec);
    result = 31 * result + Arrays.hashCode(allLocalDirs);
    return result;
  }

  @Override
  public String toString() {
    return Objects.toStringHelper(this)
      .add("execIds", Arrays.toString(execIds))
      .add("numLocalDirsByExec", Arrays.toString(numLocalDirsByExec))
      .add("allLocalDirs", Arrays.toString(allLocalDirs))
      .toString();
  }

  @Override
  public boolean equals(Object other) {
    if (other instanceof LocalDirsForExecutors) {
      LocalDirsForExecutors o = (LocalDirsForExecutors) other;
      return Arrays.equals(execIds, o.execIds)
        && Arrays.equals(numLocalDirsByExec, o.numLocalDirsByExec)
        && Arrays.equals(allLocalDirs, o.allLocalDirs);
    }
    return false;
  }

  @Override
  public int encodedLength() {
    return Encoders.StringArrays.encodedLength(execIds)
      + Encoders.IntArrays.encodedLength(numLocalDirsByExec)
      + Encoders.StringArrays.encodedLength(allLocalDirs);
  }

  /** Writes the three arrays in the order that {@link #decode} reads them back. */
  @Override
  public void encode(ByteBuf buf) {
    Encoders.StringArrays.encode(buf, execIds);
    Encoders.IntArrays.encode(buf, numLocalDirsByExec);
    Encoders.StringArrays.encode(buf, allLocalDirs);
  }

  /** Reads a message from {@code buf} in the field order written by {@link #encode}. */
  public static LocalDirsForExecutors decode(ByteBuf buf) {
    String[] execIds = Encoders.StringArrays.decode(buf);
    int[] numLocalDirsByExec = Encoders.IntArrays.decode(buf);
    String[] allLocalDirs = Encoders.StringArrays.decode(buf);
    return new LocalDirsForExecutors(execIds, numLocalDirsByExec, allLocalDirs);
  }

  /**
   * Rebuilds the executor ID to local dirs mapping from the flattened arrays.
   *
   * @return a new map from executor ID to a copy of that executor's local dirs
   * @throws IllegalStateException if the arrays are inconsistent with each other, e.g. in a
   *         malformed decoded message
   */
  public Map<String, String[]> getLocalDirsByExec() {
    if (execIds.length != numLocalDirsByExec.length) {
      throw new IllegalStateException(String.format(
        "Malformed message: %d executor ids but %d local dir counts",
        execIds.length, numLocalDirsByExec.length));
    }
    Map<String, String[]> localDirsByExec = new HashMap<>();
    int localDirsIndex = 0;
    for (int index = 0; index < execIds.length; index++) {
      int length = numLocalDirsByExec[index];
      // Arrays.copyOfRange pads with nulls when the range runs past the end of the source
      // array, so an inconsistent count must be rejected explicitly rather than copied.
      if (length < 0 || length > allLocalDirs.length - localDirsIndex) {
        throw new IllegalStateException(String.format(
          "Malformed message: executor %s claims %d local dirs but only %d remain",
          execIds[index], length, allLocalDirs.length - localDirsIndex));
      }
      localDirsByExec.put(execIds[index],
        Arrays.copyOfRange(allLocalDirs, localDirsIndex, localDirsIndex + length));
      localDirsIndex += length;
    }
    return localDirsByExec;
  }
}
Loading