Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 39 additions & 15 deletions client/rest/src/main/java/org/elasticsearch/client/RestClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.HttpResponse;
import org.apache.http.client.AuthCache;
import org.apache.http.client.ClientProtocolException;
import org.apache.http.client.methods.HttpEntityEnclosingRequestBase;
import org.apache.http.client.methods.HttpHead;
Expand All @@ -34,8 +35,11 @@
import org.apache.http.client.methods.HttpPut;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.client.methods.HttpTrace;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.impl.auth.BasicScheme;
import org.apache.http.impl.client.BasicAuthCache;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.http.nio.client.methods.HttpAsyncMethods;
import org.apache.http.nio.protocol.HttpAsyncRequestProducer;
Expand Down Expand Up @@ -92,7 +96,7 @@ public class RestClient implements Closeable {
private final long maxRetryTimeoutMillis;
private final String pathPrefix;
private final AtomicInteger lastHostIndex = new AtomicInteger(0);
private volatile Set<HttpHost> hosts;
private volatile HostTuple<Set<HttpHost>> hostTuple;
private final ConcurrentMap<HttpHost, DeadHostState> blacklist = new ConcurrentHashMap<>();
private final FailureListener failureListener;

Expand Down Expand Up @@ -122,11 +126,13 @@ public synchronized void setHosts(HttpHost... hosts) {
throw new IllegalArgumentException("hosts must not be null nor empty");
}
Set<HttpHost> httpHosts = new HashSet<>();
AuthCache authCache = new BasicAuthCache();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I think this might not be desired behavior. On every sniff the auth cache is invalidated completely. I think it would be better to add the hosts to the set. Then we should do a set difference to find the hosts that have been added and a set difference to find the hosts that have been removed. The added hosts get a new entry in the authcache and the removed ones get removed from the cache

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we update the AuthCache rather than replace it, then it will potentially impact an ongoing request. I checked out the BasicAuthCache and BasicScheme code and it looks pretty lightweight?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right. We do not need to do this

for (HttpHost host : hosts) {
Objects.requireNonNull(host, "host cannot be null");
httpHosts.add(host);
authCache.put(host, new BasicScheme());
}
this.hosts = Collections.unmodifiableSet(httpHosts);
this.hostTuple = new HostTuple<>(Collections.unmodifiableSet(httpHosts), authCache);
this.blacklist.clear();
}

Expand Down Expand Up @@ -315,19 +321,22 @@ public void performRequestAsync(String method, String endpoint, Map<String, Stri
setHeaders(request, headers);
FailureTrackingResponseListener failureTrackingResponseListener = new FailureTrackingResponseListener(responseListener);
long startTime = System.nanoTime();
performRequestAsync(startTime, nextHost().iterator(), request, ignoreErrorCodes, httpAsyncResponseConsumerFactory,
failureTrackingResponseListener);
performRequestAsync(startTime, nextHost(), request, ignoreErrorCodes, httpAsyncResponseConsumerFactory,
failureTrackingResponseListener);
}

private void performRequestAsync(final long startTime, final Iterator<HttpHost> hosts, final HttpRequestBase request,
private void performRequestAsync(final long startTime, final HostTuple<Iterator<HttpHost>> hostTuple, final HttpRequestBase request,
final Set<Integer> ignoreErrorCodes,
final HttpAsyncResponseConsumerFactory httpAsyncResponseConsumerFactory,
final FailureTrackingResponseListener listener) {
final HttpHost host = hosts.next();
final HttpHost host = hostTuple.hosts.next();
//we stream the request body if the entity allows for it
HttpAsyncRequestProducer requestProducer = HttpAsyncMethods.create(host, request);
HttpAsyncResponseConsumer<HttpResponse> asyncResponseConsumer = httpAsyncResponseConsumerFactory.createHttpAsyncResponseConsumer();
client.execute(requestProducer, asyncResponseConsumer, new FutureCallback<HttpResponse>() {
final HttpAsyncRequestProducer requestProducer = HttpAsyncMethods.create(host, request);
final HttpAsyncResponseConsumer<HttpResponse> asyncResponseConsumer =
httpAsyncResponseConsumerFactory.createHttpAsyncResponseConsumer();
final HttpClientContext context = HttpClientContext.create();
context.setAuthCache(hostTuple.authCache);
client.execute(requestProducer, asyncResponseConsumer, context, new FutureCallback<HttpResponse>() {
@Override
public void completed(HttpResponse httpResponse) {
try {
Expand Down Expand Up @@ -366,7 +375,7 @@ public void failed(Exception failure) {
}

private void retryIfPossible(Exception exception) {
if (hosts.hasNext()) {
if (hostTuple.hosts.hasNext()) {
//in case we are retrying, check whether maxRetryTimeout has been reached
long timeElapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime);
long timeout = maxRetryTimeoutMillis - timeElapsedMillis;
Expand All @@ -377,7 +386,7 @@ private void retryIfPossible(Exception exception) {
} else {
listener.trackFailure(exception);
request.reset();
performRequestAsync(startTime, hosts, request, ignoreErrorCodes, httpAsyncResponseConsumerFactory, listener);
performRequestAsync(startTime, hostTuple, request, ignoreErrorCodes, httpAsyncResponseConsumerFactory, listener);
}
} else {
listener.onDefinitiveFailure(exception);
Expand Down Expand Up @@ -415,17 +424,18 @@ private void setHeaders(HttpRequest httpRequest, Header[] requestHeaders) {
* The iterator returned will never be empty. In case there are no healthy hosts available, or dead ones to be be retried,
* one dead host gets returned so that it can be retried.
*/
private Iterable<HttpHost> nextHost() {
private HostTuple<Iterator<HttpHost>> nextHost() {
final HostTuple<Set<HttpHost>> hostTuple = this.hostTuple;
Collection<HttpHost> nextHosts = Collections.emptySet();
do {
Set<HttpHost> filteredHosts = new HashSet<>(hosts);
Set<HttpHost> filteredHosts = new HashSet<>(hostTuple.hosts);
for (Map.Entry<HttpHost, DeadHostState> entry : blacklist.entrySet()) {
if (System.nanoTime() - entry.getValue().getDeadUntilNanos() < 0) {
filteredHosts.remove(entry.getKey());
}
}
if (filteredHosts.isEmpty()) {
//last resort: if there are no good hosts to use, return a single dead one, the one that's closest to being retried
//last resort: if there are no good host to use, return a single dead one, the one that's closest to being retried
List<Map.Entry<HttpHost, DeadHostState>> sortedHosts = new ArrayList<>(blacklist.entrySet());
if (sortedHosts.size() > 0) {
Collections.sort(sortedHosts, new Comparator<Map.Entry<HttpHost, DeadHostState>>() {
Expand All @@ -444,7 +454,7 @@ public int compare(Map.Entry<HttpHost, DeadHostState> o1, Map.Entry<HttpHost, De
nextHosts = rotatedHosts;
}
} while(nextHosts.isEmpty());
return nextHosts;
return new HostTuple<>(nextHosts.iterator(), hostTuple.authCache);
}

/**
Expand Down Expand Up @@ -686,4 +696,18 @@ public void onFailure(HttpHost host) {

}
}

/**
* {@code HostTuple} enables the {@linkplain HttpHost}s and {@linkplain AuthCache} to be set together in a thread
* safe, volatile way.
*/
private static class HostTuple<T> {
public final T hosts;
public final AuthCache authCache;

public HostTuple(final T hosts, final AuthCache authCache) {
this.hosts = hosts;
this.authCache = authCache;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.conn.ConnectTimeoutException;
import org.apache.http.impl.auth.BasicScheme;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.http.message.BasicHttpResponse;
import org.apache.http.message.BasicStatusLine;
Expand Down Expand Up @@ -73,13 +75,15 @@ public class RestClientMultipleHostsTests extends RestClientTestCase {
public void createRestClient() throws IOException {
CloseableHttpAsyncClient httpClient = mock(CloseableHttpAsyncClient.class);
when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class),
any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
any(HttpClientContext.class), any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
@Override
public Future<HttpResponse> answer(InvocationOnMock invocationOnMock) throws Throwable {
HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0];
HttpUriRequest request = (HttpUriRequest)requestProducer.generateRequest();
HttpHost httpHost = requestProducer.getTarget();
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[2];
HttpClientContext context = (HttpClientContext) invocationOnMock.getArguments()[2];
assertThat(context.getAuthCache().get(httpHost), instanceOf(BasicScheme.class));
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[3];
//return the desired status code or exception depending on the path
if (request.getURI().getPath().equals("/soe")) {
futureCallback.failed(new SocketTimeoutException(httpHost.toString()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
import org.apache.http.Consts;
import org.apache.http.Header;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
import org.apache.http.util.EntityUtils;
import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement;
import org.elasticsearch.mocksocket.MockHttpServer;
Expand All @@ -48,7 +52,10 @@
import static org.elasticsearch.client.RestClientTestUtil.getAllStatusCodes;
import static org.elasticsearch.client.RestClientTestUtil.getHttpMethods;
import static org.elasticsearch.client.RestClientTestUtil.randomStatusCode;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.startsWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;

/**
Expand All @@ -66,22 +73,10 @@ public class RestClientSingleHostIntegTests extends RestClientTestCase {

@BeforeClass
public static void startHttpServer() throws Exception {
String pathPrefixWithoutLeadingSlash;
if (randomBoolean()) {
pathPrefixWithoutLeadingSlash = "testPathPrefix/" + randomAsciiOfLengthBetween(1, 5);
pathPrefix = "/" + pathPrefixWithoutLeadingSlash;
} else {
pathPrefix = pathPrefixWithoutLeadingSlash = "";
}

pathPrefix = randomBoolean() ? "/testPathPrefix/" + randomAsciiOfLengthBetween(1, 5) : "";
httpServer = createHttpServer();
defaultHeaders = RestClientTestUtil.randomHeaders(getRandom(), "Header-default");
RestClientBuilder restClientBuilder = RestClient.builder(
new HttpHost(httpServer.getAddress().getHostString(), httpServer.getAddress().getPort())).setDefaultHeaders(defaultHeaders);
if (pathPrefix.length() > 0) {
restClientBuilder.setPathPrefix((randomBoolean() ? "/" : "") + pathPrefixWithoutLeadingSlash);
}
restClient = restClientBuilder.build();
restClient = createRestClient(false, true);
}

private static HttpServer createHttpServer() throws Exception {
Expand Down Expand Up @@ -129,6 +124,35 @@ public void handle(HttpExchange httpExchange) throws IOException {
}
}

private static RestClient createRestClient(final boolean useAuth, final boolean usePreemptiveAuth) {
// provide the username/password for every request
final BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider();
credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials("user", "pass"));

final RestClientBuilder restClientBuilder = RestClient.builder(
new HttpHost(httpServer.getAddress().getHostString(), httpServer.getAddress().getPort())).setDefaultHeaders(defaultHeaders);
if (pathPrefix.length() > 0) {
// sometimes cut off the leading slash
restClientBuilder.setPathPrefix(randomBoolean() ? pathPrefix.substring(1) : pathPrefix);
}

if (useAuth) {
restClientBuilder.setHttpClientConfigCallback(new RestClientBuilder.HttpClientConfigCallback() {
@Override
public HttpAsyncClientBuilder customizeHttpClient(final HttpAsyncClientBuilder httpClientBuilder) {
if (usePreemptiveAuth == false) {
// disable preemptive auth by ignoring any authcache
httpClientBuilder.disableAuthCaching();
}

return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
}
});
}

return restClientBuilder.build();
}

@AfterClass
public static void stopHttpServers() throws IOException {
restClient.close();
Expand Down Expand Up @@ -159,7 +183,7 @@ public void testHeaders() throws IOException {

assertEquals(method, esResponse.getRequestLine().getMethod());
assertEquals(statusCode, esResponse.getStatusLine().getStatusCode());
assertEquals((pathPrefix.length() > 0 ? pathPrefix : "") + "/" + statusCode, esResponse.getRequestLine().getUri());
assertEquals(pathPrefix + "/" + statusCode, esResponse.getRequestLine().getUri());
assertHeaders(defaultHeaders, requestHeaders, esResponse.getHeaders(), standardHeaders);
for (final Header responseHeader : esResponse.getHeaders()) {
String name = responseHeader.getName();
Expand Down Expand Up @@ -189,7 +213,41 @@ public void testGetWithBody() throws IOException {
bodyTest("GET");
}

private void bodyTest(String method) throws IOException {
/**
* Verify that credentials are sent on the first request with preemptive auth enabled (default when provided with credentials).
*/
public void testPreemptiveAuthEnabled() throws IOException {
final String[] methods = { "POST", "PUT", "GET", "DELETE" };

try (final RestClient restClient = createRestClient(true, true)) {
for (final String method : methods) {
final Response response = bodyTest(restClient, method);

assertThat(response.getHeader("Authorization"), startsWith("Basic"));
}
}
}

/**
* Verify that credentials are <em>not</em> sent on the first request with preemptive auth disabled.
*/
public void testPreemptiveAuthDisabled() throws IOException {
final String[] methods = { "POST", "PUT", "GET", "DELETE" };

try (final RestClient restClient = createRestClient(true, false)) {
for (final String method : methods) {
final Response response = bodyTest(restClient, method);

assertThat(response.getHeader("Authorization"), nullValue());
}
}
}

private Response bodyTest(final String method) throws IOException {
return bodyTest(restClient, method);
}

private Response bodyTest(final RestClient restClient, final String method) throws IOException {
String requestBody = "{ \"field\": \"value\" }";
StringEntity entity = new StringEntity(requestBody);
int statusCode = randomStatusCode(getRandom());
Expand All @@ -201,7 +259,9 @@ private void bodyTest(String method) throws IOException {
}
assertEquals(method, esResponse.getRequestLine().getMethod());
assertEquals(statusCode, esResponse.getStatusLine().getStatusCode());
assertEquals((pathPrefix.length() > 0 ? pathPrefix : "") + "/" + statusCode, esResponse.getRequestLine().getUri());
assertEquals(pathPrefix + "/" + statusCode, esResponse.getRequestLine().getUri());
assertEquals(requestBody, EntityUtils.toString(esResponse.getEntity()));

return esResponse;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@
import org.apache.http.client.methods.HttpPut;
import org.apache.http.client.methods.HttpTrace;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.conn.ConnectTimeoutException;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.auth.BasicScheme;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.http.message.BasicHttpResponse;
import org.apache.http.message.BasicStatusLine;
Expand Down Expand Up @@ -96,11 +98,13 @@ public class RestClientSingleHostTests extends RestClientTestCase {
public void createRestClient() throws IOException {
httpClient = mock(CloseableHttpAsyncClient.class);
when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class),
any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
any(HttpClientContext.class), any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
@Override
public Future<HttpResponse> answer(InvocationOnMock invocationOnMock) throws Throwable {
HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0];
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[2];
HttpClientContext context = (HttpClientContext) invocationOnMock.getArguments()[2];
assertThat(context.getAuthCache().get(httpHost), instanceOf(BasicScheme.class));
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[3];
HttpUriRequest request = (HttpUriRequest)requestProducer.generateRequest();
//return the desired status code or exception depending on the path
if (request.getURI().getPath().equals("/soe")) {
Expand Down Expand Up @@ -156,7 +160,7 @@ public void testInternalHttpRequest() throws Exception {
for (String httpMethod : getHttpMethods()) {
HttpUriRequest expectedRequest = performRandomRequest(httpMethod);
verify(httpClient, times(++times)).<HttpResponse>execute(requestArgumentCaptor.capture(),
any(HttpAsyncResponseConsumer.class), any(FutureCallback.class));
any(HttpAsyncResponseConsumer.class), any(HttpClientContext.class), any(FutureCallback.class));
HttpUriRequest actualRequest = (HttpUriRequest)requestArgumentCaptor.getValue().generateRequest();
assertEquals(expectedRequest.getURI(), actualRequest.getURI());
assertEquals(expectedRequest.getClass(), actualRequest.getClass());
Expand Down
Loading