diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/RoutingNode.java b/server/src/main/java/org/elasticsearch/cluster/routing/RoutingNode.java index 995be24dd3fcc..d18970f2bfa5c 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/RoutingNode.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/RoutingNode.java @@ -24,10 +24,17 @@ import org.elasticsearch.index.shard.ShardId; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; /** * A {@link RoutingNode} represents a cluster node associated with a single {@link DiscoveryNode} including all shards @@ -41,6 +48,8 @@ public class RoutingNode implements Iterable { private final LinkedHashMap shards; // LinkedHashMap to preserve order + private final Map> statesToShards; + public RoutingNode(String nodeId, DiscoveryNode node, ShardRouting... shards) { this(nodeId, node, buildShardRoutingMap(shards)); } @@ -49,6 +58,13 @@ public RoutingNode(String nodeId, DiscoveryNode node, ShardRouting... shards) { this.nodeId = nodeId; this.node = node; this.shards = shards; + statesToShards = new HashMap<>(ShardRoutingState.values().length); + for (ShardRoutingState state : ShardRoutingState.values()) { + statesToShards.put(state, new HashSet<>()); + } + for (ShardRouting shardRouting : shards.values()) { + statesToShards.get(shardRouting.state()).add(shardRouting); + } } private static LinkedHashMap buildShardRoutingMap(ShardRouting... shardRoutings) { @@ -104,6 +120,7 @@ void add(ShardRouting shard) { + "] where it already exists. current [" + shards.get(shard.shardId()) + "]. new [" + shard + "]"); } shards.put(shard.shardId(), shard); + statesToShards.get(shard.state()).add(shard); } void update(ShardRouting oldShard, ShardRouting newShard) { @@ -112,11 +129,14 @@ void update(ShardRouting oldShard, ShardRouting newShard) { // TODO: change caller logic in RoutingNodes so that this check can go away return; } + statesToShards.get(oldShard.state()).remove(oldShard); + statesToShards.get(newShard.state()).add(newShard); ShardRouting previousValue = shards.put(newShard.shardId(), newShard); assert previousValue == oldShard : "expected shard " + previousValue + " but was " + oldShard; } void remove(ShardRouting shard) { + statesToShards.get(shard.state()).remove(shard); ShardRouting previousValue = shards.remove(shard.shardId()); assert previousValue == shard : "expected shard " + previousValue + " but was " + shard; } @@ -127,15 +147,7 @@ void remove(ShardRouting shard) { * @return number of shards */ public int numberOfShardsWithState(ShardRoutingState... states) { - int count = 0; - for (ShardRouting shardEntry : this) { - for (ShardRoutingState state : states) { - if (shardEntry.state() == state) { - count++; - } - } - } - return count; + return Arrays.stream(states).mapToInt(s -> statesToShards.get(s).size()).sum(); } /** @@ -144,51 +156,34 @@ public int numberOfShardsWithState(ShardRoutingState... states) { * @return List of shards */ public List shardsWithState(ShardRoutingState... states) { - List shards = new ArrayList<>(); - for (ShardRouting shardEntry : this) { - for (ShardRoutingState state : states) { - if (shardEntry.state() == state) { - shards.add(shardEntry); - } - } - } - return shards; + return Arrays.stream(states) + .map(state -> statesToShards.get(state)) + .flatMap(Collection::stream) + .collect(Collectors.toList()); } /** * Determine the shards of an index with a specific state - * @param index id of the index + * @param index id of the index * @param states set of states which should be listed * @return a list of shards */ public List shardsWithState(String index, ShardRoutingState... states) { - List shards = new ArrayList<>(); - - for (ShardRouting shardEntry : this) { - if (!shardEntry.getIndexName().equals(index)) { - continue; - } - for (ShardRoutingState state : states) { - if (shardEntry.state() == state) { - shards.add(shardEntry); - } - } - } - return shards; + return Arrays.stream(states) + .map(state -> statesToShards.get(state)) + .flatMap(Collection::stream) + .filter(shard -> shard.getIndexName().equals(index)) + .collect(Collectors.toList()); } /** * The number of shards on this node that will not be eventually relocated. */ public int numberOfOwningShards() { - int count = 0; - for (ShardRouting shardEntry : this) { - if (shardEntry.state() != ShardRoutingState.RELOCATING) { - count++; - } - } - - return count; + return Arrays.stream(ShardRoutingState.values()) + .filter(s -> s != ShardRoutingState.RELOCATING) + .mapToInt(s -> statesToShards.get(s).size()) + .sum(); } public String prettyPrint() { @@ -200,6 +195,7 @@ public String prettyPrint() { return sb.toString(); } + @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append("routingNode (["); diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/ThrottlingAllocationDecider.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/ThrottlingAllocationDecider.java index 47ecc1b894b32..ad37415f46cce 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/ThrottlingAllocationDecider.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/ThrottlingAllocationDecider.java @@ -24,6 +24,7 @@ import org.elasticsearch.cluster.routing.RecoverySource; import org.elasticsearch.cluster.routing.RoutingNode; import org.elasticsearch.cluster.routing.ShardRouting; +import org.elasticsearch.cluster.routing.ShardRoutingState; import org.elasticsearch.cluster.routing.UnassignedInfo; import org.elasticsearch.cluster.routing.allocation.RoutingAllocation; import org.elasticsearch.common.settings.ClusterSettings; @@ -122,10 +123,10 @@ public Decision canAllocate(ShardRouting shardRouting, RoutingNode node, Routing // count *just the primaries* currently doing recovery on the node and check against primariesInitialRecoveries int primariesInRecovery = 0; - for (ShardRouting shard : node) { + for (ShardRouting shard : node.shardsWithState(ShardRoutingState.INITIALIZING)) { // when a primary shard is INITIALIZING, it can be because of *initial recovery* or *relocation from another node* // we only count initial recoveries here, so we need to make sure that relocating node is null - if (shard.initializing() && shard.primary() && shard.relocatingNodeId() == null) { + if (shard.primary() && shard.relocatingNodeId() == null) { primariesInRecovery++; } }