diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/Request.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/Request.java index d62f47c0e318f..97e56bef5d1aa 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/Request.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/Request.java @@ -34,6 +34,7 @@ import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchScrollRequest; import org.elasticsearch.action.support.ActiveShardCount; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.WriteRequest; @@ -63,7 +64,7 @@ final class Request { - private static final XContentType REQUEST_BODY_CONTENT_TYPE = XContentType.JSON; + static final XContentType REQUEST_BODY_CONTENT_TYPE = XContentType.JSON; final String method; final String endpoint; @@ -338,6 +339,11 @@ static Request search(SearchRequest searchRequest) throws IOException { return new Request(HttpGet.METHOD_NAME, endpoint, params.getParams(), entity); } + static Request searchScroll(SearchScrollRequest searchScrollRequest) throws IOException { + HttpEntity entity = createEntity(searchScrollRequest, REQUEST_BODY_CONTENT_TYPE); + return new Request("GET", "/_search/scroll", Collections.emptyMap(), entity); + } + private static HttpEntity createEntity(ToXContent toXContent, XContentType xContentType) throws IOException { BytesRef source = XContentHelper.toXContent(toXContent, xContentType, false).toBytesRef(); return new ByteArrayEntity(source.bytes, source.offset, source.length, ContentType.create(xContentType.mediaType())); @@ -483,7 +489,7 @@ Params withWaitForActiveShards(ActiveShardCount activeShardCount) { return this; } - Params withIndicesOptions (IndicesOptions indicesOptions) { + Params withIndicesOptions(IndicesOptions indicesOptions) { putParam("ignore_unavailable", Boolean.toString(indicesOptions.ignoreUnavailable())); putParam("allow_no_indices", Boolean.toString(indicesOptions.allowNoIndices())); String expandWildcards; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java index 47645817c8491..ff4101be7c0de 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java @@ -38,6 +38,7 @@ import org.elasticsearch.action.main.MainResponse; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.SearchScrollRequest; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.common.CheckedFunction; @@ -325,6 +326,27 @@ public void searchAsync(SearchRequest searchRequest, ActionListenerSearch Scroll + * API on elastic.co + */ + public SearchResponse searchScroll(SearchScrollRequest searchScrollRequest, Header... headers) throws IOException { + return performRequestAndParseEntity(searchScrollRequest, Request::searchScroll, SearchResponse::fromXContent, emptySet(), headers); + } + + /** + * Asynchronously executes a search using the Search Scroll api + * + * See Search Scroll + * API on elastic.co + */ + public void searchScrollAsync(SearchScrollRequest searchScrollRequest, ActionListener listener, Header... headers) { + performRequestAsyncAndParseEntity(searchScrollRequest, Request::searchScroll, SearchResponse::fromXContent, + listener, emptySet(), headers); + } + private Resp performRequestAndParseEntity(Req request, CheckedFunction requestConverter, CheckedFunction entityParser, @@ -354,6 +376,7 @@ Resp performRequest(Req request, } throw parseResponseException(e); } + try { return responseConverter.apply(response); } catch(Exception e) { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RequestTests.java index 12f0d991e7f59..d06f4deda725f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RequestTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RequestTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchScrollRequest; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.WriteRequest; @@ -40,6 +41,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.lucene.uid.Versions; +import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParser; @@ -714,10 +716,27 @@ public void testSearch() throws Exception { if (searchSourceBuilder == null) { assertNull(request.entity); } else { - BytesReference expectedBytes = XContentHelper.toXContent(searchSourceBuilder, XContentType.JSON, false); - assertEquals(XContentType.JSON.mediaType(), request.entity.getContentType().getValue()); - assertEquals(expectedBytes, new BytesArray(EntityUtils.toByteArray(request.entity))); + assertToXContentBody(searchSourceBuilder, request.entity); + } + } + + public void testSearchScroll() throws IOException { + SearchScrollRequest searchScrollRequest = new SearchScrollRequest(); + searchScrollRequest.scrollId(randomAlphaOfLengthBetween(5, 10)); + if (randomBoolean()) { + searchScrollRequest.scroll(randomPositiveTimeValue()); } + Request request = Request.searchScroll(searchScrollRequest); + assertEquals("GET", request.method); + assertEquals("/_search/scroll", request.endpoint); + assertEquals(0, request.params.size()); + assertToXContentBody(searchScrollRequest, request.entity); + } + + private static void assertToXContentBody(ToXContent expectedBody, HttpEntity actualEntity) throws IOException { + BytesReference expectedBytes = XContentHelper.toXContent(expectedBody, Request.REQUEST_BODY_CONTENT_TYPE, false); + assertEquals(XContentType.JSON.mediaType(), actualEntity.getContentType().getValue()); + assertEquals(expectedBytes, new BytesArray(EntityUtils.toByteArray(actualEntity))); } public void testParams() { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 05883a066a5da..8c5cdc6d68933 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -33,6 +33,7 @@ import org.apache.http.message.BasicHttpResponse; import org.apache.http.message.BasicRequestLine; import org.apache.http.message.BasicStatusLine; +import org.apache.http.nio.entity.NStringEntity; import org.elasticsearch.Build; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; @@ -41,21 +42,26 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.main.MainRequest; import org.elasticsearch.action.main.MainResponse; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.SearchResponseSections; +import org.elasticsearch.action.search.SearchScrollRequest; +import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.cbor.CborXContent; import org.elasticsearch.common.xcontent.smile.SmileXContent; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.aggregations.Aggregation; +import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.suggest.Suggest; import org.elasticsearch.test.ESTestCase; import org.junit.Before; import org.mockito.ArgumentMatcher; -import org.mockito.Matchers; import org.mockito.internal.matchers.ArrayEquals; import org.mockito.internal.matchers.VarargMatcher; @@ -68,6 +74,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.client.RestClientTestUtil.randomHeaders; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.hamcrest.CoreMatchers.instanceOf; import static org.mockito.Matchers.anyMapOf; @@ -76,6 +83,8 @@ import static org.mockito.Matchers.anyVararg; import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.eq; +import static org.mockito.Matchers.isNotNull; +import static org.mockito.Matchers.isNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -95,49 +104,70 @@ public void initClient() { } public void testPingSuccessful() throws IOException { - Header[] headers = RestClientTestUtil.randomHeaders(random(), "Header"); + Header[] headers = randomHeaders(random(), "Header"); Response response = mock(Response.class); when(response.getStatusLine()).thenReturn(newStatusLine(RestStatus.OK)); when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class), anyObject(), anyVararg())).thenReturn(response); assertTrue(restHighLevelClient.ping(headers)); verify(restClient).performRequest(eq("HEAD"), eq("/"), eq(Collections.emptyMap()), - Matchers.isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers))); + isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers))); } public void testPing404NotFound() throws IOException { - Header[] headers = RestClientTestUtil.randomHeaders(random(), "Header"); + Header[] headers = randomHeaders(random(), "Header"); Response response = mock(Response.class); when(response.getStatusLine()).thenReturn(newStatusLine(RestStatus.NOT_FOUND)); when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class), anyObject(), anyVararg())).thenReturn(response); assertFalse(restHighLevelClient.ping(headers)); verify(restClient).performRequest(eq("HEAD"), eq("/"), eq(Collections.emptyMap()), - Matchers.isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers))); + isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers))); } public void testPingSocketTimeout() throws IOException { - Header[] headers = RestClientTestUtil.randomHeaders(random(), "Header"); + Header[] headers = randomHeaders(random(), "Header"); when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class), anyObject(), anyVararg())).thenThrow(new SocketTimeoutException()); expectThrows(SocketTimeoutException.class, () -> restHighLevelClient.ping(headers)); verify(restClient).performRequest(eq("HEAD"), eq("/"), eq(Collections.emptyMap()), - Matchers.isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers))); + isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers))); } public void testInfo() throws IOException { - Header[] headers = RestClientTestUtil.randomHeaders(random(), "Header"); - Response response = mock(Response.class); + Header[] headers = randomHeaders(random(), "Header"); MainResponse testInfo = new MainResponse("nodeName", Version.CURRENT, new ClusterName("clusterName"), "clusterUuid", Build.CURRENT, true); - when(response.getEntity()).thenReturn( - new StringEntity(toXContent(testInfo, XContentType.JSON, false).utf8ToString(), ContentType.APPLICATION_JSON)); - when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class), - anyObject(), anyVararg())).thenReturn(response); + mockResponse(testInfo); MainResponse receivedInfo = restHighLevelClient.info(headers); assertEquals(testInfo, receivedInfo); verify(restClient).performRequest(eq("GET"), eq("/"), eq(Collections.emptyMap()), - Matchers.isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers))); + isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers))); + } + + public void testSearchScroll() throws IOException { + Header[] headers = randomHeaders(random(), "Header"); + SearchResponse mockSearchResponse = new SearchResponse(new SearchResponseSections(SearchHits.empty(), InternalAggregations.EMPTY, + null, false, false, null, 1), randomAlphaOfLengthBetween(5, 10), 5, 5, 100, new ShardSearchFailure[0]); + mockResponse(mockSearchResponse); + SearchResponse searchResponse = restHighLevelClient.searchScroll(new SearchScrollRequest(randomAlphaOfLengthBetween(5, 10)), + headers); + assertEquals(mockSearchResponse.getScrollId(), searchResponse.getScrollId()); + assertEquals(0, searchResponse.getHits().totalHits); + assertEquals(5, searchResponse.getTotalShards()); + assertEquals(5, searchResponse.getSuccessfulShards()); + assertEquals(100, searchResponse.getTook().getMillis()); + verify(restClient).performRequest(eq("GET"), eq("/_search/scroll"), eq(Collections.emptyMap()), + isNotNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers))); + } + + private void mockResponse(ToXContent toXContent) throws IOException { + Response response = mock(Response.class); + ContentType contentType = ContentType.parse(Request.REQUEST_BODY_CONTENT_TYPE.mediaType()); + String requestBody = toXContent(toXContent, Request.REQUEST_BODY_CONTENT_TYPE, false).utf8ToString(); + when(response.getEntity()).thenReturn(new NStringEntity(requestBody, contentType)); + when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class), + anyObject(), anyVararg())).thenReturn(response); } public void testRequestValidation() {