diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubConnection.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubConnection.java index 8ff26762f6cd..cb6a085a4b01 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubConnection.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubConnection.java @@ -5,7 +5,6 @@ import java.io.StringReader; import java.lang.reflect.Type; -import java.lang.reflect.TypeVariable; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.*; @@ -34,39 +33,31 @@ public class HubConnection implements AutoCloseable { private static final List emptyArray = new ArrayList<>(); private static final int MAX_NEGOTIATE_ATTEMPTS = 100; - private String baseUrl; - private Transport transport; - private boolean customTransport = false; - private OnReceiveCallBack callback; private final CallbackMap handlers = new CallbackMap(); - private HubProtocol protocol; - private Boolean handshakeReceived = false; - private HubConnectionState hubConnectionState = HubConnectionState.DISCONNECTED; - private final Lock hubConnectionStateLock = new ReentrantLock(); - private List onClosedCallbackList; + private final HubProtocol protocol; private final boolean skipNegotiate; - private Single accessTokenProvider; - private Single redirectAccessTokenProvider; private final Map headers; - private final Map localHeaders = new HashMap<>(); - private ConnectionState connectionState = null; - private HttpClient httpClient; - private String stopError; - private Timer pingTimer = null; - private final AtomicLong nextServerTimeout = new AtomicLong(); - private final AtomicLong nextPingActivation = new AtomicLong(); - private long keepAliveInterval = 15*1000; - private long serverTimeout = 30*1000; - private long tickRate = 1000; - private CompletableSubject handshakeResponseSubject; - private long handshakeResponseTimeout = 15*1000; - private Map streamMap = new ConcurrentHashMap<>(); - private TransportEnum transportEnum = TransportEnum.ALL; - private String connectionId; private final int negotiateVersion = 1; private final Logger logger = LoggerFactory.getLogger(HubConnection.class); - private ScheduledExecutorService handshakeTimeout = null; - private Completable start; + private final HttpClient httpClient; + private final Transport customTransport; + private final OnReceiveCallBack callback; + private final Single accessTokenProvider; + private final TransportEnum transportEnum; + + // These are all user-settable properties + private String baseUrl; + private List onClosedCallbackList; + private long keepAliveInterval = 15 * 1000; + private long serverTimeout = 30 * 1000; + private long handshakeResponseTimeout = 15 * 1000; + + // Private property, modified for testing + private long tickRate = 1000; + + + // Holds all mutable state other than user-defined handlers and settable properties. + private final ReconnectingConnectionState state; /** * Sets the server timeout interval for the connection. @@ -110,7 +101,11 @@ public long getKeepAliveInterval() { * @return A string representing the the client's connectionId. */ public String getConnectionId() { - return this.connectionId; + ConnectionState state = this.state.getConnectionStateUnsynchronized(true); + if (state != null) { + return state.connectionId; + } + return null; } // For testing purposes @@ -119,16 +114,8 @@ void setTickRate(long tickRateInMilliseconds) { } // For testing purposes - Map getStreamMap() { - return this.streamMap; - } - - TransportEnum getTransportEnum() { - return this.transportEnum; - } - Transport getTransport() { - return transport; + return this.state.getConnectionState().transport; } HubConnection(String url, Transport transport, boolean skipNegotiate, HttpClient httpClient, HubProtocol protocol, @@ -138,6 +125,7 @@ Transport getTransport() { throw new IllegalArgumentException("A valid url is required."); } + this.state = new ReconnectingConnectionState(this.logger); this.baseUrl = url; this.protocol = protocol; @@ -154,10 +142,14 @@ Transport getTransport() { } if (transport != null) { - this.transport = transport; - this.customTransport = true; + this.transportEnum = TransportEnum.ALL; + this.customTransport = transport; } else if (transportEnum != null) { this.transportEnum = transportEnum; + this.customTransport = null; + } else { + this.transportEnum = TransportEnum.ALL; + this.customTransport = null; } if (handshakeResponseTimeout > 0) { @@ -167,124 +159,12 @@ Transport getTransport() { this.headers = headers; this.skipNegotiate = skipNegotiate; - this.callback = (payload) -> { - resetServerTimeout(); - if (!handshakeReceived) { - List handshakeByteList = new ArrayList(); - byte curr = payload.get(); - // Add the handshake to handshakeBytes, but not the record separator - while (curr != RECORD_SEPARATOR) { - handshakeByteList.add(curr); - curr = payload.get(); - } - int handshakeLength = handshakeByteList.size() + 1; - byte[] handshakeBytes = new byte[handshakeLength - 1]; - for (int i = 0; i < handshakeLength - 1; i++) { - handshakeBytes[i] = handshakeByteList.get(i); - } - // The handshake will always be a UTF8 Json string - String handshakeResponseString = new String(handshakeBytes, StandardCharsets.UTF_8); - HandshakeResponseMessage handshakeResponse; - try { - handshakeResponse = HandshakeProtocol.parseHandshakeResponse(handshakeResponseString); - } catch (RuntimeException ex) { - RuntimeException exception = new RuntimeException("An invalid handshake response was received from the server.", ex); - handshakeResponseSubject.onError(exception); - throw exception; - } - if (handshakeResponse.getHandshakeError() != null) { - String errorMessage = "Error in handshake " + handshakeResponse.getHandshakeError(); - logger.error(errorMessage); - RuntimeException exception = new RuntimeException(errorMessage); - handshakeResponseSubject.onError(exception); - throw exception; - } - handshakeReceived = true; - handshakeResponseSubject.onComplete(); - - // The payload only contained the handshake response so we can return. - if (!payload.hasRemaining()) { - return; - } - } - - List messages = protocol.parseMessages(payload, connectionState); - - for (HubMessage message : messages) { - logger.debug("Received message of type {}.", message.getMessageType()); - switch (message.getMessageType()) { - case INVOCATION_BINDING_FAILURE: - InvocationBindingFailureMessage msg = (InvocationBindingFailureMessage)message; - logger.error("Failed to bind arguments received in invocation '{}' of '{}'.", msg.getInvocationId(), msg.getTarget(), msg.getException()); - break; - case INVOCATION: - - InvocationMessage invocationMessage = (InvocationMessage) message; - List handlers = this.handlers.get(invocationMessage.getTarget()); - if (handlers != null) { - for (InvocationHandler handler : handlers) { - try { - handler.getAction().invoke(invocationMessage.getArguments()); - } catch (Exception e) { - logger.error("Invoking client side method '{}' failed:", invocationMessage.getTarget(), e); - } - } - } else { - logger.warn("Failed to find handler for '{}' method.", invocationMessage.getTarget()); - } - break; - case CLOSE: - logger.info("Close message received from server."); - CloseMessage closeMessage = (CloseMessage) message; - stop(closeMessage.getError()); - break; - case PING: - // We don't need to do anything in the case of a ping message. - break; - case COMPLETION: - CompletionMessage completionMessage = (CompletionMessage)message; - InvocationRequest irq = connectionState.tryRemoveInvocation(completionMessage.getInvocationId()); - if (irq == null) { - logger.warn("Dropped unsolicited Completion message for invocation '{}'.", completionMessage.getInvocationId()); - continue; - } - irq.complete(completionMessage); - break; - case STREAM_ITEM: - StreamItem streamItem = (StreamItem)message; - InvocationRequest streamInvocationRequest = connectionState.getInvocation(streamItem.getInvocationId()); - if (streamInvocationRequest == null) { - logger.warn("Dropped unsolicited Completion message for invocation '{}'.", streamItem.getInvocationId()); - continue; - } - - streamInvocationRequest.addItem(streamItem); - break; - case STREAM_INVOCATION: - case CANCEL_INVOCATION: - logger.error("This client does not support {} messages.", message.getMessageType()); - - throw new UnsupportedOperationException(String.format("The message type %s is not supported yet.", message.getMessageType())); - } - } - }; - } - - private void timeoutHandshakeResponse(long timeout, TimeUnit unit) { - handshakeTimeout = Executors.newSingleThreadScheduledExecutor(); - handshakeTimeout.schedule(() -> { - // If onError is called on a completed subject the global error handler is called - if (!(handshakeResponseSubject.hasComplete() || handshakeResponseSubject.hasThrowable())) - { - handshakeResponseSubject.onError( - new TimeoutException("Timed out waiting for the server to respond to the handshake message.")); - } - }, timeout, unit); + this.callback = (payload) -> ReceiveLoop(payload); } - private Single handleNegotiate(String url) { + private Single handleNegotiate(String url, Map localHeaders) { HttpRequest request = new HttpRequest(); - request.addHeaders(this.localHeaders); + request.addHeaders(localHeaders); return httpClient.post(Negotiate.resolveNegotiateUrl(url, this.negotiateVersion), request).map((response) -> { if (response.getStatusCode() != 200) { @@ -299,11 +179,7 @@ private Single handleNegotiate(String url) { } if (negotiateResponse.getAccessToken() != null) { - this.redirectAccessTokenProvider = Single.just(negotiateResponse.getAccessToken()); - // We know the Single is non blocking in this case - // It's fine to call blockingGet() on it. - String token = this.redirectAccessTokenProvider.blockingGet(); - this.localHeaders.put("Authorization", "Bearer " + token); + localHeaders.put("Authorization", "Bearer " + negotiateResponse.getAccessToken()); } return negotiateResponse; @@ -316,7 +192,7 @@ private Single handleNegotiate(String url) { * @return HubConnection state enum. */ public HubConnectionState getConnectionState() { - return hubConnectionState; + return this.state.getHubConnectionState(); } // For testing only @@ -333,7 +209,7 @@ public void setBaseUrl(String url) { throw new IllegalArgumentException("The HubConnection url must be a valid url."); } - if (hubConnectionState != HubConnectionState.DISCONNECTED) { + if (this.state.getHubConnectionState() != HubConnectionState.DISCONNECTED) { throw new IllegalStateException("The HubConnection must be in the disconnected state to change the url."); } @@ -348,46 +224,47 @@ public void setBaseUrl(String url) { public Completable start() { CompletableSubject localStart = CompletableSubject.create(); - hubConnectionStateLock.lock(); + this.state.lock.lock(); try { - if (hubConnectionState != HubConnectionState.DISCONNECTED) { - logger.debug("The connection is in the '{}' state. Waiting for in-progress start to complete or completing this start immediately.", hubConnectionState); - return start; + if (this.state.getHubConnectionState() != HubConnectionState.DISCONNECTED) { + logger.debug("The connection is in the '{}' state. Waiting for in-progress start to complete or completing this start immediately.", this.state.getHubConnectionState()); + return this.state.getConnectionStateUnsynchronized(false).startTask; } - hubConnectionState = HubConnectionState.CONNECTING; - start = localStart; + this.state.changeState(HubConnectionState.DISCONNECTED, HubConnectionState.CONNECTING); - handshakeResponseSubject = CompletableSubject.create(); - handshakeReceived = false; CompletableSubject tokenCompletable = CompletableSubject.create(); + Map localHeaders = new HashMap<>(); localHeaders.put(UserAgentHelper.getUserAgentName(), UserAgentHelper.createUserAgentString()); if (headers != null) { - this.localHeaders.putAll(headers); + localHeaders.putAll(headers); } + ConnectionState connectionState = new ConnectionState(this); + this.state.setConnectionState(connectionState); + connectionState.startTask = localStart; accessTokenProvider.subscribe(token -> { if (token != null && !token.isEmpty()) { - this.localHeaders.put("Authorization", "Bearer " + token); + localHeaders.put("Authorization", "Bearer " + token); } tokenCompletable.onComplete(); }, error -> { tokenCompletable.onError(error); }); - stopError = null; Single negotiate = null; if (!skipNegotiate) { - negotiate = tokenCompletable.andThen(Single.defer(() -> startNegotiate(baseUrl, 0))); + negotiate = tokenCompletable.andThen(Single.defer(() -> startNegotiate(baseUrl, 0, localHeaders))); } else { negotiate = tokenCompletable.andThen(Single.defer(() -> Single.just(new NegotiateResponse(baseUrl)))); } negotiate.flatMapCompletable(negotiateResponse -> { logger.debug("Starting HubConnection."); + Transport transport = customTransport; if (transport == null) { Single tokenProvider = negotiateResponse.getAccessToken() != null ? Single.just(negotiateResponse.getAccessToken()) : accessTokenProvider; - switch (transportEnum) { + switch (negotiateResponse.getChosenTransport()) { case LONG_POLLING: transport = new LongPollingTransport(localHeaders, httpClient, tokenProvider); break; @@ -396,81 +273,88 @@ public Completable start() { } } + connectionState.transport = transport; + transport.setOnReceive(this.callback); transport.setOnClose((message) -> stopConnection(message)); return transport.start(negotiateResponse.getFinalUrl()).andThen(Completable.defer(() -> { ByteBuffer handshake = HandshakeProtocol.createHandshakeRequestMessage( - new HandshakeRequestMessage(protocol.getName(), protocol.getVersion())); + new HandshakeRequestMessage(protocol.getName(), protocol.getVersion())); - connectionState = new ConnectionState(this); - - return transport.send(handshake).andThen(Completable.defer(() -> { - timeoutHandshakeResponse(handshakeResponseTimeout, TimeUnit.MILLISECONDS); - return handshakeResponseSubject.andThen(Completable.defer(() -> { - hubConnectionStateLock.lock(); + this.state.lock(); + try { + if (this.state.hubConnectionState != HubConnectionState.CONNECTING) { + return Completable.error(new RuntimeException("Connection closed while trying to connect.")); + } + return connectionState.transport.send(handshake).andThen(Completable.defer(() -> { + this.state.lock(); try { - hubConnectionState = HubConnectionState.CONNECTED; - logger.info("HubConnection started."); - resetServerTimeout(); - //Don't send pings if we're using long polling. - if (transportEnum != TransportEnum.LONG_POLLING) { - activatePingTimer(); + ConnectionState activeState = this.state.getConnectionStateUnsynchronized(true); + if (activeState != null && activeState == connectionState) { + connectionState.timeoutHandshakeResponse(handshakeResponseTimeout, TimeUnit.MILLISECONDS); + } else { + return Completable.error(new RuntimeException("Connection closed while sending handshake.")); } } finally { - hubConnectionStateLock.unlock(); + this.state.unlock(); } + return connectionState.handshakeResponseSubject.andThen(Completable.defer(() -> { + this.state.lock(); + try { + ConnectionState activeState = this.state.getConnectionStateUnsynchronized(true); + if (activeState == null || activeState != connectionState) { + return Completable.error(new RuntimeException("Connection closed while waiting for handshake.")); + } + this.state.changeState(HubConnectionState.CONNECTING, HubConnectionState.CONNECTED); + logger.info("HubConnection started."); + connectionState.resetServerTimeout(); + // Don't send pings if we're using long polling. + if (negotiateResponse.getChosenTransport() != TransportEnum.LONG_POLLING) { + connectionState.activatePingTimer(); + } + } finally { + this.state.unlock(); + } - return Completable.complete(); + return Completable.complete(); + })); })); - })); + } finally { + this.state.unlock(); + } })); // subscribe makes this a "hot" completable so this runs immediately }).subscribe(() -> { localStart.onComplete(); }, error -> { - hubConnectionStateLock.lock(); - hubConnectionState = HubConnectionState.DISCONNECTED; - hubConnectionStateLock.unlock(); + this.state.lock(); + try { + ConnectionState activeState = this.state.getConnectionStateUnsynchronized(true); + if (activeState == connectionState) { + this.state.changeState(HubConnectionState.CONNECTING, HubConnectionState.DISCONNECTED); + } + // this error is already logged and we want the user to see the original error + } catch (Exception ex) { + } finally { + this.state.unlock(); + } + localStart.onError(error); }); } finally { - hubConnectionStateLock.unlock(); + this.state.lock.unlock(); } return localStart; } - private void activatePingTimer() { - this.pingTimer = new Timer(); - this.pingTimer.schedule(new TimerTask() { - @Override - public void run() { - try { - if (System.currentTimeMillis() > nextServerTimeout.get()) { - stop("Server timeout elapsed without receiving a message from the server."); - return; - } - - if (System.currentTimeMillis() > nextPingActivation.get()) { - sendHubMessage(PingMessage.getInstance()); - } - } catch (Exception e) { - logger.warn("Error sending ping: {}.", e.getMessage()); - // The connection is probably in a bad or closed state now, cleanup the timer so - // it stops triggering - pingTimer.cancel(); - } - } - }, new Date(0), tickRate); - } - - private Single startNegotiate(String url, int negotiateAttempts) { - if (hubConnectionState != HubConnectionState.CONNECTING) { + private Single startNegotiate(String url, int negotiateAttempts, Map localHeaders) { + if (this.state.getHubConnectionState() != HubConnectionState.CONNECTING) { throw new RuntimeException("HubConnection trying to negotiate when not in the CONNECTING state."); } - return handleNegotiate(url).flatMap(response -> { + return handleNegotiate(url, localHeaders).flatMap(response -> { if (response.getRedirectUrl() != null && negotiateAttempts >= MAX_NEGOTIATE_ATTEMPTS) { throw new RuntimeException("Negotiate redirection limit exceeded."); } @@ -479,23 +363,26 @@ private Single startNegotiate(String url, int negotiateAttemp Set transports = response.getAvailableTransports(); if (this.transportEnum == TransportEnum.ALL) { if (transports.contains("WebSockets")) { - this.transportEnum = TransportEnum.WEBSOCKETS; + response.setChosenTransport(TransportEnum.WEBSOCKETS); } else if (transports.contains("LongPolling")) { - this.transportEnum = TransportEnum.LONG_POLLING; + response.setChosenTransport(TransportEnum.LONG_POLLING); } else { throw new RuntimeException("There were no compatible transports on the server."); } } else if (this.transportEnum == TransportEnum.WEBSOCKETS && !transports.contains("WebSockets") || (this.transportEnum == TransportEnum.LONG_POLLING && !transports.contains("LongPolling"))) { throw new RuntimeException("There were no compatible transports on the server."); + } else { + response.setChosenTransport(this.transportEnum); } String connectionToken = ""; if (response.getVersion() > 0) { - this.connectionId = response.getConnectionId(); + this.state.getConnectionState().connectionId = response.getConnectionId(); connectionToken = response.getConnectionToken(); } else { - connectionToken = this.connectionId = response.getConnectionId(); + connectionToken = response.getConnectionId(); + this.state.getConnectionState().connectionId = connectionToken; } String finalUrl = Utils.appendQueryString(url, "id=" + connectionToken); @@ -504,7 +391,7 @@ private Single startNegotiate(String url, int negotiateAttemp return Single.just(response); } - return startNegotiate(response.getRedirectUrl(), negotiateAttempts + 1); + return startNegotiate(response.getRedirectUrl(), negotiateAttempts + 1, localHeaders); }); } @@ -515,20 +402,23 @@ private Single startNegotiate(String url, int negotiateAttemp * @return A Completable that completes when the connection has been stopped. */ private Completable stop(String errorMessage) { - hubConnectionStateLock.lock(); + Transport transport; + this.state.lock(); try { - if (hubConnectionState == HubConnectionState.DISCONNECTED) { + if (this.state.getHubConnectionState() == HubConnectionState.DISCONNECTED) { return Completable.complete(); } if (errorMessage != null) { - stopError = errorMessage; + this.state.getConnectionStateUnsynchronized(false).stopError = errorMessage; logger.error("HubConnection disconnected with an error: {}.", errorMessage); } else { logger.debug("Stopping HubConnection."); } + + transport = this.state.getConnectionStateUnsynchronized(false).transport; } finally { - hubConnectionStateLock.unlock(); + this.state.unlock(); } Completable stop = transport.stop(); @@ -536,6 +426,84 @@ private Completable stop(String errorMessage) { return stop; } + private void ReceiveLoop(ByteBuffer payload) + { + List messages; + ConnectionState connectionState; + this.state.lock(); + try { + connectionState = this.state.getConnectionState(); + connectionState.resetServerTimeout(); + connectionState.handleHandshake(payload); + // The payload only contained the handshake response so we can return. + if (!payload.hasRemaining()) { + return; + } + + messages = protocol.parseMessages(payload, connectionState); + } finally { + this.state.unlock(); + } + + for (HubMessage message : messages) { + logger.debug("Received message of type {}.", message.getMessageType()); + switch (message.getMessageType()) { + case INVOCATION_BINDING_FAILURE: + InvocationBindingFailureMessage msg = (InvocationBindingFailureMessage)message; + logger.error("Failed to bind arguments received in invocation '{}' of '{}'.", msg.getInvocationId(), msg.getTarget(), msg.getException()); + break; + case INVOCATION: + + InvocationMessage invocationMessage = (InvocationMessage) message; + List handlers = this.handlers.get(invocationMessage.getTarget()); + if (handlers != null) { + for (InvocationHandler handler : handlers) { + try { + handler.getAction().invoke(invocationMessage.getArguments()); + } catch (Exception e) { + logger.error("Invoking client side method '{}' failed:", invocationMessage.getTarget(), e); + } + } + } else { + logger.warn("Failed to find handler for '{}' method.", invocationMessage.getTarget()); + } + break; + case CLOSE: + logger.info("Close message received from server."); + CloseMessage closeMessage = (CloseMessage) message; + stop(closeMessage.getError()); + break; + case PING: + // We don't need to do anything in the case of a ping message. + break; + case COMPLETION: + CompletionMessage completionMessage = (CompletionMessage)message; + InvocationRequest irq = connectionState.tryRemoveInvocation(completionMessage.getInvocationId()); + if (irq == null) { + logger.warn("Dropped unsolicited Completion message for invocation '{}'.", completionMessage.getInvocationId()); + continue; + } + irq.complete(completionMessage); + break; + case STREAM_ITEM: + StreamItem streamItem = (StreamItem)message; + InvocationRequest streamInvocationRequest = connectionState.getInvocation(streamItem.getInvocationId()); + if (streamInvocationRequest == null) { + logger.warn("Dropped unsolicited Completion message for invocation '{}'.", streamItem.getInvocationId()); + continue; + } + + streamInvocationRequest.addItem(streamItem); + break; + case STREAM_INVOCATION: + case CANCEL_INVOCATION: + logger.error("This client does not support {} messages.", message.getMessageType()); + + throw new UnsupportedOperationException(String.format("The message type %s is not supported yet.", message.getMessageType())); + } + } + } + /** * Stops a connection to the server. * @@ -547,46 +515,29 @@ public Completable stop() { private void stopConnection(String errorMessage) { RuntimeException exception = null; - hubConnectionStateLock.lock(); + this.state.lock(); try { // errorMessage gets passed in from the transport. An already existing stopError value // should take precedence. - if (stopError != null) { - errorMessage = stopError; + if (this.state.getConnectionStateUnsynchronized(false).stopError != null) { + errorMessage = this.state.getConnectionStateUnsynchronized(false).stopError; } if (errorMessage != null) { exception = new RuntimeException(errorMessage); logger.error("HubConnection disconnected with an error {}.", errorMessage); } + + ConnectionState connectionState = this.state.getConnectionStateUnsynchronized(true); if (connectionState != null) { connectionState.cancelOutstandingInvocations(exception); - connectionState = null; - } - - if (pingTimer != null) { - pingTimer.cancel(); - pingTimer = null; + connectionState.close(); + this.state.setConnectionState(null); } logger.info("HubConnection stopped."); - hubConnectionState = HubConnectionState.DISCONNECTED; - handshakeResponseSubject.onComplete(); - redirectAccessTokenProvider = null; - connectionId = null; - transportEnum = TransportEnum.ALL; - this.localHeaders.clear(); - this.streamMap.clear(); - - if (this.handshakeTimeout != null) { - this.handshakeTimeout.shutdownNow(); - this.handshakeTimeout = null; - } - - if (this.customTransport == false) { - this.transport = null; - } + this.state.changeState(HubConnectionState.CONNECTED, HubConnectionState.DISCONNECTED); } finally { - hubConnectionStateLock.unlock(); + this.state.unlock(); } // Do not run these callbacks inside the hubConnectionStateLock @@ -605,14 +556,14 @@ private void stopConnection(String errorMessage) { * @param args The arguments to be passed to the method. */ public void send(String method, Object... args) { - hubConnectionStateLock.lock(); + this.state.lock(); try { - if (hubConnectionState != HubConnectionState.CONNECTED) { + if (this.state.getHubConnectionState() != HubConnectionState.CONNECTED) { throw new RuntimeException("The 'send' method cannot be called if the connection is not active."); } sendInvocationMessage(method, args); } finally { - hubConnectionStateLock.unlock(); + this.state.unlock(); } } @@ -622,7 +573,8 @@ private void sendInvocationMessage(String method, Object[] args) { private void sendInvocationMessage(String method, Object[] args, String id, Boolean isStreamInvocation) { List streamIds = new ArrayList<>(); - args = checkUploadStream(args, streamIds); + List streams = new ArrayList<>(); + args = checkUploadStream(args, streamIds, streams); InvocationMessage invocationMessage; if (isStreamInvocation) { invocationMessage = new StreamInvocationMessage(null, id, method, args, streamIds); @@ -630,35 +582,35 @@ private void sendInvocationMessage(String method, Object[] args, String id, Bool invocationMessage = new InvocationMessage(null, id, method, args, streamIds); } - sendHubMessage(invocationMessage); - launchStreams(streamIds); + sendHubMessageWithLock(invocationMessage); + launchStreams(streamIds, streams); } - void launchStreams(List streamIds) { - if (streamMap.isEmpty()) { + void launchStreams(List streamIds, List streams) { + if (streams.isEmpty()) { return; } - for (String streamId: streamIds) { - Observable observable = this.streamMap.get(streamId); - observable.subscribe( - (item) -> sendHubMessage(new StreamItem(null, streamId, item)), + for (int i = 0; i < streamIds.size(); i++) { + String streamId = streamIds.get(i); + Observable stream = streams.get(i); + stream.subscribe( + (item) -> sendHubMessageWithLock(new StreamItem(null, streamId, item)), (error) -> { - sendHubMessage(new CompletionMessage(null, streamId, null, error.toString())); - this.streamMap.remove(streamId); + sendHubMessageWithLock(new CompletionMessage(null, streamId, null, error.toString())); }, () -> { - sendHubMessage(new CompletionMessage(null, streamId, null, null)); - this.streamMap.remove(streamId); + sendHubMessageWithLock(new CompletionMessage(null, streamId, null, null)); }); } } - Object[] checkUploadStream(Object[] args, List streamIds) { + Object[] checkUploadStream(Object[] args, List streamIds, List streams) { if (args == null) { return new Object[] { null }; } + ConnectionState connectionState = this.state.getConnectionState(); List params = new ArrayList<>(Arrays.asList(args)); for (Object arg: args) { if (arg instanceof Observable) { @@ -666,7 +618,7 @@ Object[] checkUploadStream(Object[] args, List streamIds) { Observable stream = (Observable)arg; String streamId = connectionState.getNextInvocationId(); streamIds.add(streamId); - this.streamMap.put(streamId, stream); + streams.add(stream); } } @@ -680,14 +632,14 @@ Object[] checkUploadStream(Object[] args, List streamIds) { * @param args The arguments used to invoke the server method. * @return A Completable that indicates when the invocation has completed. */ - @SuppressWarnings("unchecked") public Completable invoke(String method, Object... args) { - hubConnectionStateLock.lock(); + this.state.lock(); try { - if (hubConnectionState != HubConnectionState.CONNECTED) { + if (this.state.getHubConnectionState() != HubConnectionState.CONNECTED) { throw new RuntimeException("The 'invoke' method cannot be called if the connection is not active."); } + ConnectionState connectionState = this.state.getConnectionStateUnsynchronized(false); String id = connectionState.getNextInvocationId(); CompletableSubject subject = CompletableSubject.create(); @@ -705,7 +657,7 @@ public Completable invoke(String method, Object... args) { sendInvocationMessage(method, args, id, false); return subject; } finally { - hubConnectionStateLock.unlock(); + this.state.unlock(); } } @@ -739,12 +691,13 @@ public Single invoke(Type returnType, String method, Object... args) { @SuppressWarnings("unchecked") private Single invoke(Type returnType, Class returnClass, String method, Object... args) { - hubConnectionStateLock.lock(); + this.state.lock(); try { - if (hubConnectionState != HubConnectionState.CONNECTED) { + if (this.state.getHubConnectionState() != HubConnectionState.CONNECTED) { throw new RuntimeException("The 'invoke' method cannot be called if the connection is not active."); } + ConnectionState connectionState = this.state.getConnectionStateUnsynchronized(false); String id = connectionState.getNextInvocationId(); InvocationRequest irq = new InvocationRequest(returnType, id); connectionState.addInvocation(irq); @@ -763,7 +716,7 @@ private Single invoke(Type returnType, Class returnClass, String metho sendInvocationMessage(method, args, id, false); return subject; } finally { - hubConnectionStateLock.unlock(); + this.state.unlock(); } } @@ -798,12 +751,13 @@ public Observable stream(Type returnType, String method, Object ... args) private Observable stream(Type returnType, Class returnClass, String method, Object ... args) { String invocationId; InvocationRequest irq; - hubConnectionStateLock.lock(); + this.state.lock(); try { - if (hubConnectionState != HubConnectionState.CONNECTED) { + if (this.state.getHubConnectionState() != HubConnectionState.CONNECTED) { throw new RuntimeException("The 'stream' method cannot be called if the connection is not active."); } + ConnectionState connectionState = this.state.getConnectionStateUnsynchronized(false); invocationId = connectionState.getNextInvocationId(); irq = new InvocationRequest(returnType, invocationId); connectionState.addInvocation(irq); @@ -821,38 +775,37 @@ private Observable stream(Type returnType, Class returnClass, String m return observable.doOnDispose(() -> { if (subscriptionCount.decrementAndGet() == 0) { CancelInvocationMessage cancelInvocationMessage = new CancelInvocationMessage(null, invocationId); - sendHubMessage(cancelInvocationMessage); - if (connectionState != null) { - connectionState.tryRemoveInvocation(invocationId); - } + sendHubMessageWithLock(cancelInvocationMessage); + connectionState.tryRemoveInvocation(invocationId); subject.onComplete(); } }); } finally { - hubConnectionStateLock.unlock(); - } - } - - private void sendHubMessage(HubMessage message) { - ByteBuffer serializedMessage = protocol.writeMessage(message); - if (message.getMessageType() == HubMessageType.INVOCATION ) { - logger.debug("Sending {} message '{}'.", message.getMessageType().name(), ((InvocationMessage)message).getInvocationId()); - } else if (message.getMessageType() == HubMessageType.STREAM_INVOCATION) { - logger.debug("Sending {} message '{}'.", message.getMessageType().name(), ((StreamInvocationMessage)message).getInvocationId()); - } else { - logger.debug("Sending {} message.", message.getMessageType().name()); + this.state.unlock(); } - - transport.send(serializedMessage).subscribeWith(CompletableSubject.create()); - resetKeepAlive(); } - private void resetServerTimeout() { - this.nextServerTimeout.set(System.currentTimeMillis() + serverTimeout); - } + private void sendHubMessageWithLock(HubMessage message) { + this.state.lock(); + try { + if (this.state.getHubConnectionState() != HubConnectionState.CONNECTED) { + throw new RuntimeException("Trying to send and message while the connection is not active."); + } + ByteBuffer serializedMessage = protocol.writeMessage(message); + if (message.getMessageType() == HubMessageType.INVOCATION) { + logger.debug("Sending {} message '{}'.", message.getMessageType().name(), ((InvocationMessage)message).getInvocationId()); + } else if (message.getMessageType() == HubMessageType.STREAM_INVOCATION) { + logger.debug("Sending {} message '{}'.", message.getMessageType().name(), ((StreamInvocationMessage)message).getInvocationId()); + } else { + logger.debug("Sending {} message.", message.getMessageType().name()); + } - private void resetKeepAlive() { - this.nextPingActivation.set(System.currentTimeMillis() + keepAliveInterval); + ConnectionState connectionState = this.state.getConnectionStateUnsynchronized(false); + connectionState.transport.send(serializedMessage).subscribeWith(CompletableSubject.create()); + connectionState.resetKeepAlive(); + } finally { + this.state.unlock(); + } } /** @@ -1316,7 +1269,18 @@ private final class ConnectionState implements InvocationBinder { private final HubConnection connection; private final AtomicInteger nextId = new AtomicInteger(0); private final HashMap pendingInvocations = new HashMap<>(); - private final Lock lock = new ReentrantLock(); + private final AtomicLong nextServerTimeout = new AtomicLong(); + private final AtomicLong nextPingActivation = new AtomicLong(); + private Timer pingTimer = null; + private Boolean handshakeReceived = false; + private ScheduledExecutorService handshakeTimeout = null; + + public final Lock lock = new ReentrantLock(); + public final CompletableSubject handshakeResponseSubject = CompletableSubject.create(); + public Transport transport; + public String connectionId; + public String stopError; + public Completable startTask; public ConnectionState(HubConnection connection) { this.connection = connection; @@ -1376,6 +1340,98 @@ public InvocationRequest tryRemoveInvocation(String id) { } } + public void resetServerTimeout() { + this.nextServerTimeout.set(System.currentTimeMillis() + serverTimeout); + } + + public void resetKeepAlive() { + this.nextPingActivation.set(System.currentTimeMillis() + keepAliveInterval); + } + + public void activatePingTimer() { + this.pingTimer = new Timer(); + this.pingTimer.schedule(new TimerTask() { + @Override + public void run() { + try { + if (System.currentTimeMillis() > nextServerTimeout.get()) { + stop("Server timeout elapsed without receiving a message from the server."); + return; + } + + if (System.currentTimeMillis() > nextPingActivation.get()) { + sendHubMessageWithLock(PingMessage.getInstance()); + } + } catch (Exception e) { + logger.warn("Error sending ping: {}.", e.getMessage()); + // The connection is probably in a bad or closed state now, cleanup the timer so + // it stops triggering + pingTimer.cancel(); + } + } + }, new Date(0), tickRate); + } + + public void handleHandshake(ByteBuffer payload) { + if (!handshakeReceived) { + List handshakeByteList = new ArrayList(); + byte curr = payload.get(); + // Add the handshake to handshakeBytes, but not the record separator + while (curr != RECORD_SEPARATOR) { + handshakeByteList.add(curr); + curr = payload.get(); + } + int handshakeLength = handshakeByteList.size() + 1; + byte[] handshakeBytes = new byte[handshakeLength - 1]; + for (int i = 0; i < handshakeLength - 1; i++) { + handshakeBytes[i] = handshakeByteList.get(i); + } + // The handshake will always be a UTF8 Json string + String handshakeResponseString = new String(handshakeBytes, StandardCharsets.UTF_8); + HandshakeResponseMessage handshakeResponse; + try { + handshakeResponse = HandshakeProtocol.parseHandshakeResponse(handshakeResponseString); + } catch (RuntimeException ex) { + RuntimeException exception = new RuntimeException("An invalid handshake response was received from the server.", ex); + handshakeResponseSubject.onError(exception); + throw exception; + } + if (handshakeResponse.getHandshakeError() != null) { + String errorMessage = "Error in handshake " + handshakeResponse.getHandshakeError(); + logger.error(errorMessage); + RuntimeException exception = new RuntimeException(errorMessage); + handshakeResponseSubject.onError(exception); + throw exception; + } + handshakeReceived = true; + handshakeResponseSubject.onComplete(); + } + } + + public void timeoutHandshakeResponse(long timeout, TimeUnit unit) { + handshakeTimeout = Executors.newSingleThreadScheduledExecutor(); + handshakeTimeout.schedule(() -> { + // If onError is called on a completed subject the global error handler is called + if (!(handshakeResponseSubject.hasComplete() || handshakeResponseSubject.hasThrowable())) + { + handshakeResponseSubject.onError( + new TimeoutException("Timed out waiting for the server to respond to the handshake message.")); + } + }, timeout, unit); + } + + public void close() { + handshakeResponseSubject.onComplete(); + + if (pingTimer != null) { + pingTimer.cancel(); + } + + if (this.handshakeTimeout != null) { + this.handshakeTimeout.shutdownNow(); + } + } + @Override public Type getReturnType(String invocationId) { InvocationRequest irq = getInvocation(invocationId); @@ -1402,6 +1458,76 @@ public List getParameterTypes(String methodName) { } } + // We don't have reconnect yet, but this helps align the Java client with the .NET client + // and hopefully make it easier to implement reconnect in the future + private final class ReconnectingConnectionState { + private final Logger logger; + private final Lock lock = new ReentrantLock(); + private ConnectionState state; + private HubConnectionState hubConnectionState = HubConnectionState.DISCONNECTED; + + public ReconnectingConnectionState(Logger logger) { + this.logger = logger; + } + + public void setConnectionState(ConnectionState state) { + this.lock.lock(); + try { + this.state = state; + } finally { + this.lock.unlock(); + } + } + + public ConnectionState getConnectionStateUnsynchronized(Boolean allowNull) { + if (allowNull != true && this.state == null) { + throw new RuntimeException("Connection is not active."); + } + return this.state; + } + + public ConnectionState getConnectionState() { + this.lock.lock(); + try { + if (this.state == null) { + throw new RuntimeException("Connection is not active."); + } + return this.state; + } finally { + this.lock.unlock(); + } + } + + public HubConnectionState getHubConnectionState() { + return this.hubConnectionState; + } + + public void changeState(HubConnectionState from, HubConnectionState to) { + this.lock.lock(); + try { + logger.debug("The HubConnection is attempting to transition from the {} state to the {} state.", from, to); + if (this.hubConnectionState != from) { + logger.debug("The HubConnection failed to transition from the {} state to the {} state because it was actually in the {} state.", + from, to, this.hubConnectionState); + throw new RuntimeException(String.format("The HubConnection failed to transition from the '%s' state to the '%s' state because it was actually in the '%s' state.", + from, to, this.hubConnectionState)); + } + + this.hubConnectionState = to; + } finally { + this.lock.unlock(); + } + } + + public void lock() { + this.lock.lock(); + } + + public void unlock() { + this.lock.unlock(); + } + } + @Override public void close() { try { diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/NegotiateResponse.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/NegotiateResponse.java index bf09b37578af..a593c542fd93 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/NegotiateResponse.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/NegotiateResponse.java @@ -18,6 +18,7 @@ class NegotiateResponse { private String error; private String finalUrl; private int version; + private TransportEnum chosenTransport; public NegotiateResponse(JsonReader reader) { try { @@ -125,4 +126,12 @@ public String getConnectionToken() { public void setFinalUrl(String url) { this.finalUrl = url; } + + public TransportEnum getChosenTransport() { + return chosenTransport; + } + + public void setChosenTransport(TransportEnum chosenTransport) { + this.chosenTransport = chosenTransport; + } } diff --git a/src/SignalR/clients/java/signalr/test/src/main/java/com/microsoft/signalr/HubConnectionTest.java b/src/SignalR/clients/java/signalr/test/src/main/java/com/microsoft/signalr/HubConnectionTest.java index 5839dd2f6633..54314f818fd2 100644 --- a/src/SignalR/clients/java/signalr/test/src/main/java/com/microsoft/signalr/HubConnectionTest.java +++ b/src/SignalR/clients/java/signalr/test/src/main/java/com/microsoft/signalr/HubConnectionTest.java @@ -524,74 +524,6 @@ public void checkStreamUploadThroughSendWithArgs() { assertEquals("{\"type\":3,\"invocationId\":\"1\"}\u001E", TestUtils.byteBufferToString(messages[3])); } - @Test - public void streamMapIsClearedOnClose() { - MockTransport mockTransport = new MockTransport(); - HubConnection hubConnection = TestUtils.createHubConnection("http://example.com", mockTransport); - - hubConnection.start().timeout(30, TimeUnit.SECONDS).blockingAwait(); - - ReplaySubject stream = ReplaySubject.create(); - hubConnection.send("UploadStream", stream, 12); - - stream.onNext("FirstItem"); - ByteBuffer[] messages = mockTransport.getSentMessages(); - assertEquals("{\"type\":1,\"target\":\"UploadStream\",\"arguments\":[12],\"streamIds\":[\"1\"]}\u001E", TestUtils.byteBufferToString(messages[1])); - assertEquals("{\"type\":2,\"invocationId\":\"1\",\"item\":\"FirstItem\"}\u001E", TestUtils.byteBufferToString(messages[2])); - - stream.onComplete(); - messages = mockTransport.getSentMessages(); - assertEquals("{\"type\":3,\"invocationId\":\"1\"}\u001E", TestUtils.byteBufferToString(messages[3])); - - hubConnection.stop().timeout(30, TimeUnit.SECONDS).blockingAwait(); - - assertTrue(hubConnection.getStreamMap().isEmpty()); - } - - @Test - public void streamMapEntriesRemovedOnStreamClose() { - MockTransport mockTransport = new MockTransport(); - HubConnection hubConnection = TestUtils.createHubConnection("http://example.com", mockTransport); - - hubConnection.start().timeout(30, TimeUnit.SECONDS).blockingAwait(); - - ReplaySubject stream = ReplaySubject.create(); - hubConnection.send("UploadStream", stream, 12); - - ReplaySubject secondStream = ReplaySubject.create(); - hubConnection.send("SecondUploadStream", secondStream, 13); - - - stream.onNext("FirstItem"); - secondStream.onNext("SecondItem"); - ByteBuffer[] messages = mockTransport.getSentMessages(); - assertEquals("{\"type\":1,\"target\":\"UploadStream\",\"arguments\":[12],\"streamIds\":[\"1\"]}\u001E", TestUtils.byteBufferToString(messages[1])); - assertEquals("{\"type\":1,\"target\":\"SecondUploadStream\",\"arguments\":[13],\"streamIds\":[\"2\"]}\u001E", TestUtils.byteBufferToString(messages[2])); - assertEquals("{\"type\":2,\"invocationId\":\"1\",\"item\":\"FirstItem\"}\u001E", TestUtils.byteBufferToString(messages[3])); - assertEquals("{\"type\":2,\"invocationId\":\"2\",\"item\":\"SecondItem\"}\u001E", TestUtils.byteBufferToString(messages[4])); - - - assertEquals(2, hubConnection.getStreamMap().size()); - assertTrue(hubConnection.getStreamMap().keySet().contains("1")); - assertTrue(hubConnection.getStreamMap().keySet().contains("2")); - - // Verify that we clear the entry from the stream map after we clear the first stream. - stream.onComplete(); - assertEquals(1, hubConnection.getStreamMap().size()); - assertTrue(hubConnection.getStreamMap().keySet().contains("2")); - - secondStream.onError(new Exception("Exception")); - assertEquals(0, hubConnection.getStreamMap().size()); - assertTrue(hubConnection.getStreamMap().isEmpty()); - - messages = mockTransport.getSentMessages(); - assertEquals("{\"type\":3,\"invocationId\":\"1\"}\u001E", TestUtils.byteBufferToString(messages[5])); - assertEquals("{\"type\":3,\"invocationId\":\"2\",\"error\":\"java.lang.Exception: Exception\"}\u001E", TestUtils.byteBufferToString(messages[6])); - - hubConnection.stop().timeout(30, TimeUnit.SECONDS).blockingAwait(); - assertTrue(hubConnection.getStreamMap().isEmpty()); - } - @Test public void useSameSubjectMultipleTimes() { MockTransport mockTransport = new MockTransport(); @@ -3004,7 +2936,6 @@ public void ClientThatSelectsWebsocketsThrowsWhenWebsocketsAreNotAvailable() { .withHttpClient(client) .build(); - assertEquals(TransportEnum.WEBSOCKETS, hubConnection.getTransportEnum()); RuntimeException exception = assertThrows(RuntimeException.class, () -> hubConnection.start().timeout(30, TimeUnit.SECONDS).blockingAwait()); @@ -3023,13 +2954,52 @@ public void ClientThatSelectsLongPollingThrowsWhenLongPollingIsNotAvailable() { .withHttpClient(client) .build(); - assertEquals(TransportEnum.LONG_POLLING, hubConnection.getTransportEnum()); RuntimeException exception = assertThrows(RuntimeException.class, () -> hubConnection.start().timeout(30, TimeUnit.SECONDS).blockingAwait()); assertEquals(exception.getMessage(), "There were no compatible transports on the server."); } + @Test + public void ConnectionRestartDoesNotResetUserTransportEnum() { + AtomicInteger requestCount = new AtomicInteger(0); + AtomicReference blockGet = new AtomicReference(CompletableSubject.create()); + TestHttpClient client = new TestHttpClient() + .on("POST", (req) -> { + return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer(""))); + }) + .on("POST", "http://example.com/negotiate?negotiateVersion=1", + (req) -> Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}," + + "{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")))) + .on("GET", (req) -> { + if (requestCount.incrementAndGet() >= 3) { + blockGet.get().timeout(30, TimeUnit.SECONDS).blockingAwait(); + } + return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{}" + RECORD_SEPARATOR))); + }) + .on("DELETE", (req) -> { + blockGet.get().onComplete(); + return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer(""))); + }); + + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com") + .withTransport(TransportEnum.LONG_POLLING) + .withHttpClient(client) + .build(); + + hubConnection.start().timeout(30, TimeUnit.SECONDS).blockingAwait(); + assertTrue(hubConnection.getTransport() instanceof LongPollingTransport); + hubConnection.stop().timeout(30, TimeUnit.SECONDS).blockingAwait(); + + requestCount.set(0); + blockGet.set(CompletableSubject.create()); + hubConnection.start().timeout(30, TimeUnit.SECONDS).blockingAwait(); + assertTrue(hubConnection.getTransport() instanceof LongPollingTransport); + hubConnection.stop().timeout(30, TimeUnit.SECONDS).blockingAwait(); + } + @Test public void LongPollingTransportAccessTokenProviderThrowsOnInitialPoll() { TestHttpClient client = new TestHttpClient() @@ -3189,10 +3159,10 @@ public void stopWithoutObservingWithLongPollingTransportStops() { closed.onComplete(); }); - hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + hubConnection.start().timeout(30, TimeUnit.SECONDS).blockingAwait(); hubConnection.stop(); - closed.timeout(1, TimeUnit.SECONDS).blockingAwait(); + closed.timeout(30, TimeUnit.SECONDS).blockingAwait(); blockGet.onComplete(); assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); } @@ -3227,12 +3197,12 @@ public void hubConnectionClosesAndRunsOnClosedCallbackAfterCloseMessageWithLongP hubConnection.onClosed((ex) -> { closed.onComplete(); }); - hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + hubConnection.start().timeout(30, TimeUnit.SECONDS).blockingAwait(); assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); blockGet.onComplete(); - closed.timeout(1, TimeUnit.SECONDS).blockingAwait(); + closed.timeout(30, TimeUnit.SECONDS).blockingAwait(); assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); }