diff --git a/docs/changelog/99219.yaml b/docs/changelog/99219.yaml
new file mode 100644
index 0000000000000..811e2df5f83d0
--- /dev/null
+++ b/docs/changelog/99219.yaml
@@ -0,0 +1,5 @@
+pr: 99219
+summary: Reduce copying when creating scroll/PIT ids
+area: Search
+type: enhancement
+issues: []
diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java b/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java
index 2b276fc827b14..2b7105cffe2bb 100644
--- a/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java
+++ b/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java
@@ -8,15 +8,18 @@
 package org.elasticsearch.action.search;
 
+import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.io.stream.ByteBufferStreamInput;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.InputStreamStreamInput;
 import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.util.Maps;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
@@ -24,11 +27,11 @@
 import org.elasticsearch.search.internal.ShardSearchContextId;
 import org.elasticsearch.transport.RemoteClusterAware;
 
+import java.io.ByteArrayInputStream;
 import java.io.IOException;
-import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
 import java.util.Base64;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -63,46 +66,54 @@ public static String encode(
         Map<String, AliasFilter> aliasFilter,
         TransportVersion version
     ) {
-        final Map<ShardId, SearchContextIdForNode> shards = new HashMap<>();
-        for (SearchPhaseResult searchPhaseResult : searchPhaseResults) {
-            final SearchShardTarget target = searchPhaseResult.getSearchShardTarget();
-            shards.put(
-                target.getShardId(),
-                new SearchContextIdForNode(target.getClusterAlias(), target.getNodeId(), searchPhaseResult.getContextId())
-            );
-        }
-        try (BytesStreamOutput out = new BytesStreamOutput()) {
-            out.setTransportVersion(version);
-            TransportVersion.writeVersion(version, out);
-            out.writeMap(shards);
-            out.writeMap(aliasFilter, StreamOutput::writeWriteable);
-            return Base64.getUrlEncoder().encodeToString(BytesReference.toBytes(out.bytes()));
+        final BytesReference bytesReference;
+        try (var encodedStreamOutput = new BytesStreamOutput()) {
+            try (var out = new OutputStreamStreamOutput(Base64.getUrlEncoder().wrap(encodedStreamOutput))) {
+                out.setTransportVersion(version);
+                TransportVersion.writeVersion(version, out);
+                out.writeCollection(searchPhaseResults, SearchContextId::writeSearchPhaseResult);
+                out.writeMap(aliasFilter, StreamOutput::writeWriteable);
+            }
+            bytesReference = encodedStreamOutput.bytes();
         } catch (IOException e) {
+            assert false : e;
             throw new IllegalArgumentException(e);
         }
+        final BytesRef bytesRef = bytesReference.toBytesRef();
+        return new String(bytesRef.bytes, bytesRef.offset, bytesRef.length, StandardCharsets.ISO_8859_1);
+    }
+
+    private static void writeSearchPhaseResult(StreamOutput out, SearchPhaseResult searchPhaseResult) throws IOException {
+        final SearchShardTarget target = searchPhaseResult.getSearchShardTarget();
+        target.getShardId().writeTo(out);
+        new SearchContextIdForNode(target.getClusterAlias(), target.getNodeId(), searchPhaseResult.getContextId()).writeTo(out);
     }
 
     public static SearchContextId decode(NamedWriteableRegistry namedWriteableRegistry, String id) {
-        final ByteBuffer byteBuffer;
-        try {
-            byteBuffer = ByteBuffer.wrap(Base64.getUrlDecoder().decode(id));
-        } catch (Exception e) {
-            throw new IllegalArgumentException("invalid id: [" + id + "]", e);
-        }
-        try (StreamInput in = new NamedWriteableAwareStreamInput(new ByteBufferStreamInput(byteBuffer), namedWriteableRegistry)) {
+        try (
+            var decodedInputStream = Base64.getUrlDecoder().wrap(new ByteArrayInputStream(id.getBytes(StandardCharsets.ISO_8859_1)));
+            var in = new NamedWriteableAwareStreamInput(new InputStreamStreamInput(decodedInputStream), namedWriteableRegistry)
+        ) {
             final TransportVersion version = TransportVersion.readVersion(in);
             in.setTransportVersion(version);
-            final Map<ShardId, SearchContextIdForNode> shards = in.readMap(ShardId::new, SearchContextIdForNode::new);
-            final Map<String, AliasFilter> aliasFilters = in.readMap(AliasFilter::readFrom);
+            final Map<ShardId, SearchContextIdForNode> shards = Collections.unmodifiableMap(
+                in.readCollection(Maps::newHashMapWithExpectedSize, SearchContextId::readShardsMapEntry)
+            );
+            final Map<String, AliasFilter> aliasFilters = in.readImmutableMap(AliasFilter::readFrom);
             if (in.available() > 0) {
                 throw new IllegalArgumentException("Not all bytes were read");
             }
-            return new SearchContextId(Collections.unmodifiableMap(shards), Collections.unmodifiableMap(aliasFilters));
+            return new SearchContextId(shards, aliasFilters);
         } catch (IOException e) {
+            assert false : e;
            throw new IllegalArgumentException(e);
         }
     }
 
+    private static void readShardsMapEntry(StreamInput in, Map<ShardId, SearchContextIdForNode> shards) throws IOException {
+        shards.put(new ShardId(in), new SearchContextIdForNode(in));
+    }
+
     public String[] getActualIndices() {
         final Set<String> indices = new HashSet<>();
         for (Map.Entry<ShardId, SearchContextIdForNode> entry : shards().entrySet()) {
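Note: the encode/decode pair above removes two full copies of the id bytes. Instead of serializing into a buffer, copying it out with BytesReference.toBytes, and copying again inside Base64's encodeToString, the writes now stream through Base64.getUrlEncoder().wrap(...), and the final String is built from the encoded bytes via ISO-8859-1. That charset is the trick that makes this safe: Base64 output is pure ASCII, and ISO-8859-1 maps each byte to the char with the same value, so byte-to-String and String-to-byte conversions are lossless 1:1 widenings with no real charset decoding. Below is a minimal plain-JDK sketch of the same round trip; Base64IdRoundTrip and its method names are invented for illustration and are not part of this change.

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.io.InputStream;
    import java.io.OutputStream;
    import java.nio.charset.StandardCharsets;
    import java.util.Base64;

    public class Base64IdRoundTrip {
        static String encode(byte[] payload) throws IOException {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            // Bytes are Base64-encoded as they stream through the wrapping
            // encoder, so no second copy of the payload is materialized.
            try (OutputStream out = Base64.getUrlEncoder().wrap(buffer)) {
                out.write(payload);
            }
            // ISO-8859-1 turns each encoded byte into the identical char value.
            return buffer.toString(StandardCharsets.ISO_8859_1);
        }

        static byte[] decode(String id) throws IOException {
            // The mirror image: 1:1 chars-to-bytes, then a streaming decode.
            byte[] raw = id.getBytes(StandardCharsets.ISO_8859_1);
            try (InputStream in = Base64.getUrlDecoder().wrap(new ByteArrayInputStream(raw))) {
                return in.readAllBytes();
            }
        }

        public static void main(String[] args) throws IOException {
            byte[] original = "scroll-context-42".getBytes(StandardCharsets.UTF_8);
            String encoded = encode(original);
            System.out.println(encoded + " round-trips: " + java.util.Arrays.equals(decode(encoded), original));
        }
    }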
diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchHelper.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchHelper.java
index 5680a4525b468..8702fdb16ea89 100644
--- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchHelper.java
+++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchHelper.java
@@ -8,9 +8,13 @@
 package org.elasticsearch.action.search;
 
+import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.TransportVersion;
-import org.elasticsearch.common.io.stream.ByteArrayStreamInput;
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.InputStreamStreamInput;
+import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
+import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.VersionCheckingStreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
@@ -21,8 +25,10 @@
 import org.elasticsearch.search.internal.ShardSearchContextId;
 import org.elasticsearch.transport.RemoteClusterAware;
 
+import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.io.UncheckedIOException;
+import java.nio.charset.StandardCharsets;
 import java.util.Base64;
 
 public final class TransportSearchHelper {
@@ -34,32 +40,40 @@ static InternalScrollSearchRequest internalScrollSearchRequest(ShardSearchContex
     }
 
     static String buildScrollId(AtomicArray<? extends SearchPhaseResult> searchPhaseResults) {
-        try {
-            BytesStreamOutput out = new BytesStreamOutput();
-            out.writeString(INCLUDE_CONTEXT_UUID);
-            out.writeString(searchPhaseResults.length() == 1 ? ParsedScrollId.QUERY_AND_FETCH_TYPE : ParsedScrollId.QUERY_THEN_FETCH_TYPE);
-            out.writeCollection(searchPhaseResults.asList(), (o, searchPhaseResult) -> {
-                o.writeString(searchPhaseResult.getContextId().getSessionId());
-                o.writeLong(searchPhaseResult.getContextId().getId());
-                SearchShardTarget searchShardTarget = searchPhaseResult.getSearchShardTarget();
-                if (searchShardTarget.getClusterAlias() != null) {
-                    o.writeString(
-                        RemoteClusterAware.buildRemoteIndexName(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId())
-                    );
-                } else {
-                    o.writeString(searchShardTarget.getNodeId());
-                }
-            });
-            return Base64.getUrlEncoder().encodeToString(out.copyBytes().array());
+        final BytesReference bytesReference;
+        try (var encodedStreamOutput = new BytesStreamOutput()) {
+            try (var out = new OutputStreamStreamOutput(Base64.getUrlEncoder().wrap(encodedStreamOutput))) {
+                out.writeString(INCLUDE_CONTEXT_UUID);
+                out.writeString(
+                    searchPhaseResults.length() == 1 ? ParsedScrollId.QUERY_AND_FETCH_TYPE : ParsedScrollId.QUERY_THEN_FETCH_TYPE
+                );
+                out.writeCollection(searchPhaseResults.asList(), (o, searchPhaseResult) -> {
+                    o.writeString(searchPhaseResult.getContextId().getSessionId());
+                    o.writeLong(searchPhaseResult.getContextId().getId());
+                    SearchShardTarget searchShardTarget = searchPhaseResult.getSearchShardTarget();
+                    if (searchShardTarget.getClusterAlias() != null) {
+                        o.writeString(
+                            RemoteClusterAware.buildRemoteIndexName(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId())
+                        );
+                    } else {
+                        o.writeString(searchShardTarget.getNodeId());
+                    }
+                });
+            }
+            bytesReference = encodedStreamOutput.bytes();
         } catch (IOException e) {
+            assert false : e;
             throw new UncheckedIOException(e);
         }
+        final BytesRef bytesRef = bytesReference.toBytesRef();
+        return new String(bytesRef.bytes, bytesRef.offset, bytesRef.length, StandardCharsets.ISO_8859_1);
     }
 
     static ParsedScrollId parseScrollId(String scrollId) {
-        try {
-            byte[] bytes = Base64.getUrlDecoder().decode(scrollId);
-            ByteArrayStreamInput in = new ByteArrayStreamInput(bytes);
+        try (
+            var decodedInputStream = Base64.getUrlDecoder().wrap(new ByteArrayInputStream(scrollId.getBytes(StandardCharsets.ISO_8859_1)));
+            var in = new InputStreamStreamInput(decodedInputStream)
+        ) {
             final boolean includeContextUUID;
             final String type;
             final String firstChunk = in.readString();
@@ -70,22 +84,13 @@ static ParsedScrollId parseScrollId(String scrollId) {
                 includeContextUUID = false;
                 type = firstChunk;
             }
-            SearchContextIdForNode[] context = new SearchContextIdForNode[in.readVInt()];
-            for (int i = 0; i < context.length; ++i) {
-                final String contextUUID = includeContextUUID ? in.readString() : "";
-                long id = in.readLong();
-                String target = in.readString();
-                String clusterAlias;
-                final int index = target.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR);
-                if (index == -1) {
-                    clusterAlias = null;
-                } else {
-                    clusterAlias = target.substring(0, index);
-                    target = target.substring(index + 1);
-                }
-                context[i] = new SearchContextIdForNode(clusterAlias, target, new ShardSearchContextId(contextUUID, id));
-            }
-            if (in.getPosition() != bytes.length) {
+            final SearchContextIdForNode[] context = in.readArray(
+                includeContextUUID
+                    ? TransportSearchHelper::readSearchContextIdForNodeIncludingContextUUID
+                    : TransportSearchHelper::readSearchContextIdForNodeExcludingContextUUID,
+                SearchContextIdForNode[]::new
+            );
+            if (in.available() > 0) {
                 throw new IllegalArgumentException("Not all bytes were read");
             }
             return new ParsedScrollId(scrollId, type, context);
@@ -94,6 +99,28 @@ static ParsedScrollId parseScrollId(String scrollId) {
         }
     }
 
+    private static SearchContextIdForNode readSearchContextIdForNodeIncludingContextUUID(StreamInput in) throws IOException {
+        return innerReadSearchContextIdForNode(in.readString(), in);
+    }
+
+    private static SearchContextIdForNode readSearchContextIdForNodeExcludingContextUUID(StreamInput in) throws IOException {
+        return innerReadSearchContextIdForNode("", in);
+    }
+
+    private static SearchContextIdForNode innerReadSearchContextIdForNode(String contextUUID, StreamInput in) throws IOException {
+        long id = in.readLong();
+        String target = in.readString();
+        String clusterAlias;
+        final int index = target.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR);
+        if (index == -1) {
+            clusterAlias = null;
+        } else {
+            clusterAlias = target.substring(0, index);
+            target = target.substring(index + 1);
+        }
+        return new SearchContextIdForNode(clusterAlias, target, new ShardSearchContextId(contextUUID, id));
+    }
+
     /**
      * Using the 'search.check_ccs_compatibility' setting, clients can ask for an early
      * check that inspects the incoming request and tries to verify that it can be handled by
@@ -121,7 +148,5 @@ public static void checkCCSVersionCompatibility(Writeable writeableRequest) {
         }
     }
 
-    private TransportSearchHelper() {
-
-    }
+    private TransportSearchHelper() {}
 }
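Note: parseScrollId no longer re-tests includeContextUUID on every loop iteration; it selects one of two element readers up front and hands it to StreamInput#readArray, which drives the loop. The sketch below shows the same branch-hoisting pattern with plain JDK types; Reader, Ctx, and the method names are illustrative stand-ins, not the Elasticsearch API.

    import java.io.DataInput;
    import java.io.IOException;
    import java.util.function.IntFunction;

    public class ReaderSelection {
        // Stand-in for Writeable.Reader: reads one element from the stream.
        interface Reader<T> {
            T read(DataInput in) throws IOException;
        }

        record Ctx(String uuid, long id) {}

        // Generic loop driver: the same reader is applied to every element.
        static <T> T[] readArray(DataInput in, Reader<T> reader, IntFunction<T[]> arrayFactory) throws IOException {
            int size = in.readInt();
            T[] result = arrayFactory.apply(size);
            for (int i = 0; i < size; i++) {
                result[i] = reader.read(in);
            }
            return result;
        }

        static Ctx readWithUuid(DataInput in) throws IOException {
            return new Ctx(in.readUTF(), in.readLong());
        }

        static Ctx readWithoutUuid(DataInput in) throws IOException {
            return new Ctx("", in.readLong());
        }

        static Ctx[] parse(DataInput in, boolean includeUuid) throws IOException {
            // The uuid branch is decided once, not once per element.
            return readArray(in, includeUuid ? ReaderSelection::readWithUuid : ReaderSelection::readWithoutUuid, Ctx[]::new);
        }
    }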
diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java
index af624a28cd60c..b2d9190685542 100644
--- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java
+++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java
@@ -13,6 +13,7 @@
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
@@ -1130,7 +1131,7 @@ public List<String> readOptionalStringCollectionAsList() throws IOException {
     }
 
     /**
-     * Reads a set of objects which was written using {@link StreamOutput#writeCollection}}. If the returned set contains any entries it
+     * Reads a set of objects which was written using {@link StreamOutput#writeCollection}. If the returned set contains any entries it
      * will a (mutable) {@link HashSet}. If it is empty it might be immutable. The collection that was originally written should also have
      * been a set.
      */
@@ -1169,6 +1170,20 @@ public <T extends NamedWriteable> List<T> readNamedWriteableCollectionAsList(Cla
         throw new UnsupportedOperationException("can't read named writeable from StreamInput");
     }
 
+    /**
+     * Reads a collection which was written using {@link StreamOutput#writeCollection}, accumulating the results using the provided
+     * consumer.
+     */
+    public <C> C readCollection(IntFunction<C> constructor, CheckedBiConsumer<StreamInput, C, IOException> itemConsumer)
+        throws IOException {
+        int count = readArraySize();
+        var result = constructor.apply(count);
+        for (int i = 0; i < count; i++) {
+            itemConsumer.accept(this, result);
+        }
+        return result;
+    }
+
     /**
      * Reads a collection, comprising a call to {@link #readVInt} for the size, followed by that many invocations of {@code reader}.
     *
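Note: the new readCollection overload lets a caller accumulate directly into any container built from the transmitted size, which is how SearchContextId#decode above rebuilds its shards map without an intermediate collection. A hedged usage sketch, reusing the types from that change (the enclosing method and error handling are omitted):

    // Each call to the consumer reads one (key, value) pair that was written by
    // StreamOutput#writeCollection on the other side; the map is pre-sized from
    // the wire count via the Elasticsearch Maps helper.
    Map<ShardId, SearchContextIdForNode> shards = in.readCollection(
        Maps::newHashMapWithExpectedSize,
        (stream, map) -> map.put(new ShardId(stream), new SearchContextIdForNode(stream))
    );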