diff --git a/core/src/main/java/org/elasticsearch/action/search/SearchScrollRequest.java b/core/src/main/java/org/elasticsearch/action/search/SearchScrollRequest.java index 03a40dc8b3e9a..fbe648cceaa80 100644 --- a/core/src/main/java/org/elasticsearch/action/search/SearchScrollRequest.java +++ b/core/src/main/java/org/elasticsearch/action/search/SearchScrollRequest.java @@ -24,6 +24,9 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.Scroll; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; @@ -33,7 +36,7 @@ import static org.elasticsearch.action.ValidateActions.addValidationError; -public class SearchScrollRequest extends ActionRequest { +public class SearchScrollRequest extends ActionRequest implements ToXContentObject { private String scrollId; private Scroll scroll; @@ -145,4 +148,39 @@ public String getDescription() { return "scrollId[" + scrollId + "], scroll[" + scroll + "]"; } + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("scroll_id", scrollId); + if (scroll != null) { + builder.field("scroll", scroll.keepAlive().getStringRep()); + } + builder.endObject(); + return builder; + } + + /** + * Parse a search scroll request from a request body provided through the REST layer. + * Values that are already be set and are also found while parsing will be overridden. + */ + public void fromXContent(XContentParser parser) throws IOException { + if (parser.nextToken() != XContentParser.Token.START_OBJECT) { + throw new IllegalArgumentException("Malformed content, must start with an object"); + } else { + XContentParser.Token token; + String currentFieldName = null; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if ("scroll_id".equals(currentFieldName) && token == XContentParser.Token.VALUE_STRING) { + scrollId(parser.text()); + } else if ("scroll".equals(currentFieldName) && token == XContentParser.Token.VALUE_STRING) { + scroll(new Scroll(TimeValue.parseTimeValue(parser.text(), null, "scroll"))); + } else { + throw new IllegalArgumentException("Unknown parameter [" + currentFieldName + + "] in request body or parameter is of the wrong type[" + token + "] "); + } + } + } + } } diff --git a/core/src/main/java/org/elasticsearch/rest/action/search/RestSearchScrollAction.java b/core/src/main/java/org/elasticsearch/rest/action/search/RestSearchScrollAction.java index feba6640b65a1..59b7c660fa163 100644 --- a/core/src/main/java/org/elasticsearch/rest/action/search/RestSearchScrollAction.java +++ b/core/src/main/java/org/elasticsearch/rest/action/search/RestSearchScrollAction.java @@ -22,8 +22,6 @@ import org.elasticsearch.action.search.SearchScrollRequest; import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestRequest; @@ -58,34 +56,13 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC request.withContentOrSourceParamParserOrNull(xContentParser -> { if (xContentParser != null) { - // NOTE: if rest request with xcontent body has request parameters, these parameters override xcontent values + // NOTE: if rest request with xcontent body has request parameters, values parsed from request body have the precedence try { - buildFromContent(xContentParser, searchScrollRequest); + searchScrollRequest.fromXContent(xContentParser); } catch (IOException e) { throw new IllegalArgumentException("Failed to parse request body", e); } }}); return channel -> client.searchScroll(searchScrollRequest, new RestStatusToXContentListener<>(channel)); } - - public static void buildFromContent(XContentParser parser, SearchScrollRequest searchScrollRequest) throws IOException { - if (parser.nextToken() != XContentParser.Token.START_OBJECT) { - throw new IllegalArgumentException("Malformed content, must start with an object"); - } else { - XContentParser.Token token; - String currentFieldName = null; - while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { - if (token == XContentParser.Token.FIELD_NAME) { - currentFieldName = parser.currentName(); - } else if ("scroll_id".equals(currentFieldName) && token == XContentParser.Token.VALUE_STRING) { - searchScrollRequest.scrollId(parser.text()); - } else if ("scroll".equals(currentFieldName) && token == XContentParser.Token.VALUE_STRING) { - searchScrollRequest.scroll(new Scroll(TimeValue.parseTimeValue(parser.text(), null, "scroll"))); - } else { - throw new IllegalArgumentException("Unknown parameter [" + currentFieldName - + "] in request body or parameter is of the wrong type[" + token + "] "); - } - } - } - } } diff --git a/core/src/test/java/org/elasticsearch/action/search/SearchScrollRequestTests.java b/core/src/test/java/org/elasticsearch/action/search/SearchScrollRequestTests.java index 9773d7320d0f2..6ec9f95f489de 100644 --- a/core/src/test/java/org/elasticsearch/action/search/SearchScrollRequestTests.java +++ b/core/src/test/java/org/elasticsearch/action/search/SearchScrollRequestTests.java @@ -19,15 +19,25 @@ package org.elasticsearch.action.search; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.search.internal.InternalScrollSearchRequest; import org.elasticsearch.test.ESTestCase; import java.io.IOException; import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.hamcrest.Matchers.startsWith; public class SearchScrollRequestTests extends ESTestCase { @@ -60,6 +70,60 @@ public void testInternalScrollSearchRequestSerialization() throws IOException { } } + public void testFromXContent() throws Exception { + SearchScrollRequest searchScrollRequest = new SearchScrollRequest(); + if (randomBoolean()) { + //test that existing values get overridden + searchScrollRequest = createSearchScrollRequest(); + } + try (XContentParser parser = createParser(XContentFactory.jsonBuilder() + .startObject() + .field("scroll_id", "SCROLL_ID") + .field("scroll", "1m") + .endObject())) { + searchScrollRequest.fromXContent(parser); + } + assertEquals("SCROLL_ID", searchScrollRequest.scrollId()); + assertEquals(TimeValue.parseTimeValue("1m", null, "scroll"), searchScrollRequest.scroll().keepAlive()); + } + + public void testFromXContentWithUnknownParamThrowsException() throws Exception { + SearchScrollRequest searchScrollRequest = new SearchScrollRequest(); + XContentParser invalidContent = createParser(XContentFactory.jsonBuilder() + .startObject() + .field("scroll_id", "value_2") + .field("unknown", "keyword") + .endObject()); + + Exception e = expectThrows(IllegalArgumentException.class, + () -> searchScrollRequest.fromXContent(invalidContent)); + assertThat(e.getMessage(), startsWith("Unknown parameter [unknown]")); + } + + public void testToXContent() throws IOException { + SearchScrollRequest searchScrollRequest = new SearchScrollRequest(); + searchScrollRequest.scrollId("SCROLL_ID"); + searchScrollRequest.scroll("1m"); + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + searchScrollRequest.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertEquals("{\"scroll_id\":\"SCROLL_ID\",\"scroll\":\"1m\"}", builder.string()); + } + } + + public void testToAndFromXContent() throws IOException { + XContentType xContentType = randomFrom(XContentType.values()); + boolean humanReadable = randomBoolean(); + SearchScrollRequest originalRequest = createSearchScrollRequest(); + BytesReference originalBytes = toShuffledXContent(originalRequest, xContentType, ToXContent.EMPTY_PARAMS, humanReadable); + SearchScrollRequest parsedRequest = new SearchScrollRequest(); + try (XContentParser parser = createParser(xContentType.xContent(), originalBytes)) { + parsedRequest.fromXContent(parser); + } + assertEquals(originalRequest, parsedRequest); + BytesReference parsedBytes = XContentHelper.toXContent(parsedRequest, xContentType, humanReadable); + assertToXContentEquivalent(originalBytes, parsedBytes, xContentType); + } + public void testEqualsAndHashcode() { checkEqualsAndHashCode(createSearchScrollRequest(), SearchScrollRequestTests::copyRequest, SearchScrollRequestTests::mutate); } diff --git a/core/src/test/java/org/elasticsearch/search/scroll/RestSearchScrollActionTests.java b/core/src/test/java/org/elasticsearch/search/scroll/RestSearchScrollActionTests.java index 662bc07f90d7e..078eab68d04f5 100644 --- a/core/src/test/java/org/elasticsearch/search/scroll/RestSearchScrollActionTests.java +++ b/core/src/test/java/org/elasticsearch/search/scroll/RestSearchScrollActionTests.java @@ -20,36 +20,29 @@ package org.elasticsearch.search.scroll; import org.elasticsearch.action.search.SearchScrollRequest; +import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.common.xcontent.XContentFactory; -import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.search.RestSearchScrollAction; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestChannel; import org.elasticsearch.test.rest.FakeRestRequest; +import org.mockito.ArgumentCaptor; + +import java.util.HashMap; +import java.util.Map; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.startsWith; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyObject; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; public class RestSearchScrollActionTests extends ESTestCase { - public void testParseSearchScrollRequest() throws Exception { - XContentParser content = createParser(XContentFactory.jsonBuilder() - .startObject() - .field("scroll_id", "SCROLL_ID") - .field("scroll", "1m") - .endObject()); - - SearchScrollRequest searchScrollRequest = new SearchScrollRequest(); - RestSearchScrollAction.buildFromContent(content, searchScrollRequest); - - assertThat(searchScrollRequest.scrollId(), equalTo("SCROLL_ID")); - assertThat(searchScrollRequest.scroll().keepAlive(), equalTo(TimeValue.parseTimeValue("1m", null, "scroll"))); - } public void testParseSearchScrollRequestWithInvalidJsonThrowsException() throws Exception { RestSearchScrollAction action = new RestSearchScrollAction(Settings.EMPTY, mock(RestController.class)); @@ -59,16 +52,24 @@ public void testParseSearchScrollRequestWithInvalidJsonThrowsException() throws assertThat(e.getMessage(), equalTo("Failed to parse request body")); } - public void testParseSearchScrollRequestWithUnknownParamThrowsException() throws Exception { - SearchScrollRequest searchScrollRequest = new SearchScrollRequest(); - XContentParser invalidContent = createParser(XContentFactory.jsonBuilder() - .startObject() - .field("scroll_id", "value_2") - .field("unknown", "keyword") - .endObject()); + public void testBodyParamsOverrideQueryStringParams() throws Exception { + NodeClient nodeClient = mock(NodeClient.class); + doNothing().when(nodeClient).searchScroll(any(), any()); + + RestSearchScrollAction action = new RestSearchScrollAction(Settings.EMPTY, mock(RestController.class)); + Map params = new HashMap<>(); + params.put("scroll_id", "QUERY_STRING"); + params.put("scroll", "1000m"); + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()) + .withParams(params) + .withContent(new BytesArray("{\"scroll_id\":\"BODY\", \"scroll\":\"1m\"}"), XContentType.JSON).build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 0); + action.handleRequest(request, channel, nodeClient); - Exception e = expectThrows(IllegalArgumentException.class, - () -> RestSearchScrollAction.buildFromContent(invalidContent, searchScrollRequest)); - assertThat(e.getMessage(), startsWith("Unknown parameter [unknown]")); + ArgumentCaptor argument = ArgumentCaptor.forClass(SearchScrollRequest.class); + verify(nodeClient).searchScroll(argument.capture(), anyObject()); + SearchScrollRequest searchScrollRequest = argument.getValue(); + assertEquals("BODY", searchScrollRequest.scrollId()); + assertEquals("1m", searchScrollRequest.scroll().keepAlive().getStringRep()); } }