Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.ArrayList;
import java.util.List;

import com.codahale.metrics.Counter;
import io.netty.channel.Channel;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
Expand Down Expand Up @@ -66,6 +67,8 @@ public class TransportContext {
private final RpcHandler rpcHandler;
private final boolean closeIdleConnections;
private final boolean isClientOnly;
// Number of registered connections to the shuffle service
private Counter registeredConnections = new Counter();

/**
* Force to create MessageEncoder and MessageDecoder so that we can make sure they will be created
Expand Down Expand Up @@ -221,7 +224,7 @@ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
rpcHandler, conf.maxChunksBeingTransferred());
return new TransportChannelHandler(client, responseHandler, requestHandler,
conf.connectionTimeoutMs(), closeIdleConnections);
conf.connectionTimeoutMs(), closeIdleConnections, this);
}

/**
Expand All @@ -234,4 +237,8 @@ private ChunkFetchRequestHandler createChunkFetchHandler(TransportChannelHandler
}

public TransportConf getConf() { return conf; }

public Counter getRegisteredConnections() {
return registeredConnections;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import org.apache.spark.network.TransportContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -57,18 +58,21 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
private final TransportRequestHandler requestHandler;
private final long requestTimeoutNs;
private final boolean closeIdleConnections;
private final TransportContext transportContext;

public TransportChannelHandler(
TransportClient client,
TransportResponseHandler responseHandler,
TransportRequestHandler requestHandler,
long requestTimeoutMs,
boolean closeIdleConnections) {
boolean closeIdleConnections,
TransportContext transportContext) {
this.client = client;
this.responseHandler = responseHandler;
this.requestHandler = requestHandler;
this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000;
this.closeIdleConnections = closeIdleConnections;
this.transportContext = transportContext;
}

public TransportClient getClient() {
Expand Down Expand Up @@ -176,4 +180,16 @@ public TransportResponseHandler getResponseHandler() {
return responseHandler;
}

@Override
public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
transportContext.getRegisteredConnections().inc();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to call super.blah in these overrides so that other handlers in the pipeline also get the notification.

super.channelRegistered(ctx);
}

@Override
public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
transportContext.getRegisteredConnections().dec();
super.channelUnregistered(ctx);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.List;
import java.util.concurrent.TimeUnit;

import com.codahale.metrics.Counter;
import com.codahale.metrics.MetricSet;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
Expand Down Expand Up @@ -159,4 +160,8 @@ public void close() {
}
bootstrap = null;
}

public Counter getRegisteredConnections() {
return context.getRegisteredConnections();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.codahale.metrics.Metric;
import com.codahale.metrics.MetricSet;
import com.codahale.metrics.Timer;
import com.codahale.metrics.Counter;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -173,22 +174,29 @@ private void checkAuth(TransportClient client, String appId) {
/**
* A simple class to wrap all shuffle service wrapper metrics
*/
private class ShuffleMetrics implements MetricSet {
@VisibleForTesting
public class ShuffleMetrics implements MetricSet {
private final Map<String, Metric> allMetrics;
// Time latency for open block request in ms
private final Timer openBlockRequestLatencyMillis = new Timer();
// Time latency for executor registration latency in ms
private final Timer registerExecutorRequestLatencyMillis = new Timer();
// Block transfer rate in byte per second
private final Meter blockTransferRateBytes = new Meter();
// Number of active connections to the shuffle service
private Counter activeConnections = new Counter();
// Number of registered connections to the shuffle service
private Counter registeredConnections = new Counter();

private ShuffleMetrics() {
public ShuffleMetrics() {
allMetrics = new HashMap<>();
allMetrics.put("openBlockRequestLatencyMillis", openBlockRequestLatencyMillis);
allMetrics.put("registerExecutorRequestLatencyMillis", registerExecutorRequestLatencyMillis);
allMetrics.put("blockTransferRateBytes", blockTransferRateBytes);
allMetrics.put("registeredExecutorsSize",
(Gauge<Integer>) () -> blockManager.getRegisteredExecutorsSize());
allMetrics.put("numActiveConnections", activeConnections);
allMetrics.put("numRegisteredConnections", registeredConnections);
}

@Override
Expand Down Expand Up @@ -244,4 +252,16 @@ public ManagedBuffer next() {
}
}

@Override
public void channelActive(TransportClient client) {
metrics.activeConnections.inc();
super.channelActive(client);
}

@Override
public void channelInactive(TransportClient client) {
metrics.activeConnections.dec();
super.channelInactive(client);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Method;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not used.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much. Sorry about that.

import java.nio.charset.StandardCharsets;
import java.nio.ByteBuffer;
import java.util.List;
Expand Down Expand Up @@ -170,15 +171,6 @@ protected void serviceInit(Configuration conf) throws Exception {
TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf));
blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile);

// register metrics on the block handler into the Node Manager's metrics system.
YarnShuffleServiceMetrics serviceMetrics =
new YarnShuffleServiceMetrics(blockHandler.getAllMetrics());

MetricsSystemImpl metricsSystem = (MetricsSystemImpl) DefaultMetricsSystem.instance();
metricsSystem.register(
"sparkShuffleService", "Metrics on the Spark Shuffle Service", serviceMetrics);
logger.info("Registered metrics with Hadoop's DefaultMetricsSystem");

// If authentication is enabled, set up the shuffle server to use a
// special RPC handler that filters out unauthenticated fetch requests
List<TransportServerBootstrap> bootstraps = Lists.newArrayList();
Expand All @@ -199,6 +191,18 @@ protected void serviceInit(Configuration conf) throws Exception {
port = shuffleServer.getPort();
boundPort = port;
String authEnabledString = authEnabled ? "enabled" : "not enabled";

// register metrics on the block handler into the Node Manager's metrics system.
blockHandler.getAllMetrics().getMetrics().put("numRegisteredConnections",
shuffleServer.getRegisteredConnections());
YarnShuffleServiceMetrics serviceMetrics =
new YarnShuffleServiceMetrics(blockHandler.getAllMetrics());

MetricsSystemImpl metricsSystem = (MetricsSystemImpl) DefaultMetricsSystem.instance();
metricsSystem.register(
"sparkShuffleService", "Metrics on the Spark Shuffle Service", serviceMetrics);
logger.info("Registered metrics with Hadoop's DefaultMetricsSystem");

logger.info("Started YARN shuffle service for Spark on port {}. " +
"Authentication is {}. Registered executor file is {}", port, authEnabledString,
registeredExecutorFile);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ public static void collectMetric(
throw new IllegalStateException(
"Not supported class type of metric[" + name + "] for value " + gaugeValue);
}
} else if (metric instanceof Counter) {
Counter c = (Counter) metric;
long counterValue = c.getCount();
metricsRecordBuilder.addGauge(new ShuffleServiceMetricsInfo(name, "Number of " +
"connections to shuffle service " + name), counterValue);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
server = transportContext.createServer(port, bootstraps.asJava)

shuffleServiceSource.registerMetricSet(server.getAllMetrics)
blockHandler.getAllMetrics.getMetrics.put("numRegisteredConnections",
server.getRegisteredConnections)
shuffleServiceSource.registerMetricSet(blockHandler.getAllMetrics)
masterMetricsSystem.registerSource(shuffleServiceSource)
masterMetricsSystem.start()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers {
test("metrics named as expected") {
val allMetrics = Set(
"openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis",
"blockTransferRateBytes", "registeredExecutorsSize")
"blockTransferRateBytes", "registeredExecutorsSize", "numActiveConnections",
"numRegisteredConnections")

metrics.getMetrics.keySet().asScala should be (allMetrics)
}
Expand Down