diff --git a/ChatExample/app/libs/JavaPhoenixClient-0.3.0.jar b/ChatExample/app/libs/JavaPhoenixClient-0.3.0.jar deleted file mode 100644 index 4aa7d32..0000000 Binary files a/ChatExample/app/libs/JavaPhoenixClient-0.3.0.jar and /dev/null differ diff --git a/ChatExample/app/libs/JavaPhoenixClient-0.3.4.jar b/ChatExample/app/libs/JavaPhoenixClient-0.3.4.jar new file mode 100644 index 0000000..34f7511 Binary files /dev/null and b/ChatExample/app/libs/JavaPhoenixClient-0.3.4.jar differ diff --git a/ChatExample/app/src/main/AndroidManifest.xml b/ChatExample/app/src/main/AndroidManifest.xml index 5dbb22a..15f50d6 100644 --- a/ChatExample/app/src/main/AndroidManifest.xml +++ b/ChatExample/app/src/main/AndroidManifest.xml @@ -1,24 +1,26 @@ + xmlns:tools="http://schemas.android.com/tools" + package="com.github.dsrees.chatexample"> - + - - - - + + + + - - - - + + + + \ No newline at end of file diff --git a/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MainActivity.kt b/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MainActivity.kt index 3acdcb5..6e225dd 100644 --- a/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MainActivity.kt +++ b/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MainActivity.kt @@ -22,12 +22,12 @@ class MainActivity : AppCompatActivity() { // Use when connecting to https://github.com/dwyl/phoenix-chat-example - private val socket = Socket("https://phxchat.herokuapp.com/socket/websocket") - private val topic = "room:lobby" + // private val socket = Socket("https://phxchat.herokuapp.com/socket/websocket") + // private val topic = "room:lobby" // Use when connecting to local server -// private val socket = Socket("ws://10.0.2.2:4000/socket/websocket") -// private val topic = "rooms:lobby" + private val socket = Socket("ws://10.0.2.2:4000/socket/websocket") + private val topic = "rooms:lobby" private var lobbyChannel: Channel? = null diff --git a/src/main/kotlin/org/phoenixframework/Channel.kt b/src/main/kotlin/org/phoenixframework/Channel.kt index 4c97999..1d840f3 100644 --- a/src/main/kotlin/org/phoenixframework/Channel.kt +++ b/src/main/kotlin/org/phoenixframework/Channel.kt @@ -112,6 +112,9 @@ class Channel( /** Timer to attempt rejoins */ internal var rejoinTimer: TimeoutTimer + /** Refs if stateChange hooks */ + internal var stateChangeRefs: MutableList + /** * Optional onMessage hook that can be provided. Receives all event messages for specialized * handling before dispatching to the Channel event callbacks. @@ -125,6 +128,7 @@ class Channel( this.timeout = socket.timeout this.joinedOnce = false this.pushBuffer = mutableListOf() + this.stateChangeRefs = mutableListOf() this.rejoinTimer = TimeoutTimer( dispatchQueue = socket.dispatchQueue, timerCalculation = socket.rejoinAfterMs, @@ -133,10 +137,11 @@ class Channel( // Respond to socket events this.socket.onError { _, _-> this.rejoinTimer.reset() } + .apply { stateChangeRefs.add(this) } this.socket.onOpen { this.rejoinTimer.reset() if (this.isErrored) { this.rejoin() } - } + }.apply { stateChangeRefs.add(this) } // Setup Push to be sent when joining @@ -203,7 +208,14 @@ class Channel( this.socket.logItems("Channel: error $topic ${it.payload}") // If error was received while joining, then reset the Push - if (isJoining) { this.joinPush.reset() } + if (isJoining) { + // Make sure that the "phx_join" isn't buffered to send once the socket + // reconnects. The channel will send a new join event when the socket connects. + this.joinRef?.let { this.socket.removeFromSendBuffer(it) } + + // Reset the push to be used again later + this.joinPush.reset() + } // Mark the channel as errored and attempt to rejoin if socket is currently connected this.state = State.ERRORED @@ -414,6 +426,9 @@ class Channel( // Do not attempt to rejoin if the channel is in the process of leaving if (isLeaving) return + // Leave potentially duplicated channels + this.socket.leaveOpenTopic(this.topic) + // Send the joinPush this.sendJoin(timeout) } diff --git a/src/main/kotlin/org/phoenixframework/Socket.kt b/src/main/kotlin/org/phoenixframework/Socket.kt index 66bb7fc..305c93d 100644 --- a/src/main/kotlin/org/phoenixframework/Socket.kt +++ b/src/main/kotlin/org/phoenixframework/Socket.kt @@ -35,33 +35,42 @@ typealias Payload = Map /** Data class that holds callbacks assigned to the socket */ internal class StateChangeCallbacks { - var open: List<() -> Unit> = ArrayList() + var open: List Unit>> = ArrayList() private set - var close: List<() -> Unit> = ArrayList() + var close: List Unit>> = ArrayList() private set - var error: List<(Throwable, Response?) -> Unit> = ArrayList() + var error: List Unit>> = ArrayList() private set - var message: List<(Message) -> Unit> = ArrayList() + var message: List Unit>> = ArrayList() private set /** Safely adds an onOpen callback */ - fun onOpen(callback: () -> Unit) { - this.open = this.open + callback + fun onOpen(ref: String, callback: () -> Unit) { + this.open = this.open + Pair(ref, callback) } /** Safely adds an onClose callback */ - fun onClose(callback: () -> Unit) { - this.close = this.close + callback + fun onClose(ref: String, callback: () -> Unit) { + this.close = this.close + Pair(ref, callback) } /** Safely adds an onError callback */ - fun onError(callback: (Throwable, Response?) -> Unit) { - this.error = this.error + callback + fun onError(ref: String, callback: (Throwable, Response?) -> Unit) { + this.error = this.error + Pair(ref, callback) } /** Safely adds an onMessage callback */ - fun onMessage(callback: (Message) -> Unit) { - this.message = this.message + callback + fun onMessage(ref: String, callback: (Message) -> Unit) { + this.message = this.message + Pair(ref, callback) + } + + /** Clears any callbacks with the matching refs */ + fun release(refs: List) { + open = open.filter { refs.contains(it.first) } + close = close.filter { refs.contains(it.first) } + error = error.filter { refs.contains(it.first) } + message = message.filter { refs.contains(it.first) } + } /** Clears all stored callbacks */ @@ -151,8 +160,11 @@ class Socket( /** Collection of unclosed channels created by the Socket */ internal var channels: List = ArrayList() - /** Buffers messages that need to be sent once the socket has connected */ - internal var sendBuffer: MutableList<() -> Unit> = ArrayList() + /** + * Buffers messages that need to be sent once the socket has connected. It is an array of Pairs + * that contain the ref of the message to send and the callback that will send the message. + */ + internal var sendBuffer: MutableList Unit>> = ArrayList() /** Ref counter for messages */ internal var ref: Int = 0 @@ -273,20 +285,20 @@ class Socket( } - fun onOpen(callback: (() -> Unit)) { - this.stateChangeCallbacks.onOpen(callback) + fun onOpen(callback: (() -> Unit)): String { + return makeRef().apply { stateChangeCallbacks.onOpen(this, callback) } } - fun onClose(callback: () -> Unit) { - this.stateChangeCallbacks.onClose(callback) + fun onClose(callback: () -> Unit): String { + return makeRef().apply { stateChangeCallbacks.onClose(this, callback) } } - fun onError(callback: (Throwable, Response?) -> Unit) { - this.stateChangeCallbacks.onError(callback) + fun onError(callback: (Throwable, Response?) -> Unit): String { + return makeRef().apply { stateChangeCallbacks.onError(this, callback) } } - fun onMessage(callback: (Message) -> Unit) { - this.stateChangeCallbacks.onMessage(callback) + fun onMessage(callback: (Message) -> Unit): String { + return makeRef().apply { stateChangeCallbacks.onMessage(this, callback) } } fun removeAllCallbacks() { @@ -301,6 +313,8 @@ class Socket( } fun remove(channel: Channel) { + this.off(channel.stateChangeRefs) + // To avoid a ConcurrentModificationException, filter out the channels to be // removed instead of calling .remove() on the list, thus returning a new list // that does not contain the channel that was removed. @@ -308,6 +322,15 @@ class Socket( .filter { it.joinRef != channel.joinRef } } + /** + * Removes [onOpen], [onClose], [onError], and [onMessage] registrations by their [ref] value. + * + * @param refs List of refs to remove + */ + fun off(refs: List) { + this.stateChangeCallbacks.release(refs) + } + //------------------------------------------------------------------------------ // Internal //------------------------------------------------------------------------------ @@ -341,7 +364,7 @@ class Socket( } else { // If the socket is not connected, add the push to a buffer which will // be sent immediately upon connection. - sendBuffer.add(callback) + sendBuffer.add(Pair(ref, callback)) } } @@ -374,7 +397,7 @@ class Socket( // Since the connections onClose was null'd out, inform all state callbacks // that the Socket has closed - this.stateChangeCallbacks.close.forEach { it.invoke() } + this.stateChangeCallbacks.close.forEach { it.second.invoke() } callback?.invoke() } @@ -391,11 +414,27 @@ class Socket( /** Send all messages that were buffered before the socket opened */ internal fun flushSendBuffer() { if (isConnected && sendBuffer.isNotEmpty()) { - this.sendBuffer.forEach { it.invoke() } + this.sendBuffer.forEach { it.second.invoke() } this.sendBuffer.clear() } } + /** Removes an item from the send buffer with the matching ref */ + internal fun removeFromSendBuffer(ref: String) { + this.sendBuffer = this.sendBuffer + .filter { it.first != ref } + .toMutableList() + } + + internal fun leaveOpenTopic(topic: String) { + this.channels + .firstOrNull { it.topic == topic && (it.isJoined || it.isJoining) } + ?.let { + logItems("Transport: Leaving duplicate topic: [$topic]") + it.leave() + } + } + //------------------------------------------------------------------------------ // Heartbeat //------------------------------------------------------------------------------ @@ -469,7 +508,7 @@ class Socket( this.resetHeartbeat() // Inform all onOpen callbacks that the Socket has opened - this.stateChangeCallbacks.open.forEach { it.invoke() } + this.stateChangeCallbacks.open.forEach { it.second.invoke() } } internal fun onConnectionClosed(code: Int) { @@ -486,7 +525,7 @@ class Socket( } // Inform callbacks the socket closed - this.stateChangeCallbacks.close.forEach { it.invoke() } + this.stateChangeCallbacks.close.forEach { it.second.invoke() } } internal fun onConnectionMessage(rawMessage: String) { @@ -504,7 +543,7 @@ class Socket( .forEach { it.trigger(message) } // Inform all onMessage callbacks of the message - this.stateChangeCallbacks.message.forEach { it.invoke(message) } + this.stateChangeCallbacks.message.forEach { it.second.invoke(message) } } internal fun onConnectionError(t: Throwable, response: Response?) { @@ -514,7 +553,7 @@ class Socket( this.triggerChannelError() // Inform any state callbacks of the error - this.stateChangeCallbacks.error.forEach { it.invoke(t, response) } + this.stateChangeCallbacks.error.forEach { it.second.invoke(t, response) } } } \ No newline at end of file diff --git a/src/test/kotlin/org/phoenixframework/ChannelTest.kt b/src/test/kotlin/org/phoenixframework/ChannelTest.kt index 4963bfd..943fb05 100644 --- a/src/test/kotlin/org/phoenixframework/ChannelTest.kt +++ b/src/test/kotlin/org/phoenixframework/ChannelTest.kt @@ -193,7 +193,7 @@ class ChannelTest { @Test internal fun `triggers socket push with channel params`() { channel.join() - verify(socket).push("topic", "phx_join", kDefaultPayload, kDefaultRef, channel.joinRef) + verify(socket).push("topic", "phx_join", kDefaultPayload, "3", channel.joinRef) } @Test @@ -206,6 +206,22 @@ class ChannelTest { assertThat(joinPush.timeout).isEqualTo(newTimeout) } + @Test + internal fun `leaves existing duplicate topic on new join`() { + val socket = spy(Socket("wss://localhost:4000/socket")) + val channel = socket.channel("topic") + + channel.join().receive("ok") { + val newChannel = socket.channel("topic") + assertThat(channel.isJoined).isTrue() + newChannel.join() + + assertThat(channel.isJoined).isFalse() + } + + channel.joinPush.trigger("ok", kEmptyPayload) + } + @Nested @DisplayName("timeout behavior") inner class TimeoutBehavior { @@ -429,11 +445,11 @@ class ChannelTest { @Test internal fun `removes channel binding`() { - var bindings = channel.getBindings("chan_reply_1") + var bindings = channel.getBindings("chan_reply_3") assertThat(bindings).hasSize(1) receivesOk() - bindings = channel.getBindings("chan_reply_1") + bindings = channel.getBindings("chan_reply_3") assertThat(bindings).isEmpty() } @@ -583,7 +599,7 @@ class ChannelTest { @Test internal fun `removes channel binding`() { - var bindings = channel.getBindings("chan_reply_1") + var bindings = channel.getBindings("chan_reply_3") assertThat(bindings).hasSize(1) receivesError() @@ -656,12 +672,25 @@ class ChannelTest { fakeClock.tick(1000) verify(joinPush, times(1)).send() - channel.trigger("error") + channel.trigger(Channel.Event.ERROR) fakeClock.tick(1000) verify(joinPush, times(1)).send() } + @Test + internal fun `removes the joinPush message from sendBuffer`() { + val channel = Channel("topic", kDefaultPayload, socket) + val push = mock() + whenever(push.ref).thenReturn("10") + channel.joinPush = push + channel.state = Channel.State.JOINING + + channel.trigger(Channel.Event.ERROR) + verify(socket).removeFromSendBuffer("10") + verify(push).reset() + } + @Test internal fun `tries to rejoin with backoff`() { val mockTimer = mock() diff --git a/src/test/kotlin/org/phoenixframework/SocketTest.kt b/src/test/kotlin/org/phoenixframework/SocketTest.kt index 3150211..bd1397a 100644 --- a/src/test/kotlin/org/phoenixframework/SocketTest.kt +++ b/src/test/kotlin/org/phoenixframework/SocketTest.kt @@ -462,7 +462,7 @@ class SocketTest { verify(connection, never()).send(any()) assertThat(socket.sendBuffer).hasSize(2) - socket.sendBuffer.forEach { it.invoke() } + socket.sendBuffer.forEach { it.second.invoke() } verify(connection, times(2)).send(any()) } @@ -540,9 +540,9 @@ class SocketTest { @Test internal fun `invokes callbacks in buffer when connected`() { var oneCalled = 0 - socket.sendBuffer.add { oneCalled += 1 } + socket.sendBuffer.add(Pair("0", { oneCalled += 1 })) var twoCalled = 0 - socket.sendBuffer.add { twoCalled += 1 } + socket.sendBuffer.add(Pair("1", { twoCalled += 1 })) val threeCalled = 0 whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN) @@ -563,7 +563,7 @@ class SocketTest { @Test internal fun `empties send buffer`() { - socket.sendBuffer.add { } + socket.sendBuffer.add(Pair(null, {})) whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN) socket.connect() @@ -577,6 +577,31 @@ class SocketTest { /* End FlushSendBuffer */ } + @Nested + @DisplayName("removeFromSendBuffer") + inner class RemoveFromSendBuffer { + @Test + internal fun `removes a callback with matching ref`() { + var oneCalled = 0 + socket.sendBuffer.add(Pair("0", { oneCalled += 1 })) + var twoCalled = 0 + socket.sendBuffer.add(Pair("1", { twoCalled += 1 })) + + whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN) + + // connect + socket.connect() + + socket.removeFromSendBuffer("0") + + // sends once connected + socket.flushSendBuffer() + assertThat(oneCalled).isEqualTo(0) + assertThat(twoCalled).isEqualTo(1) + } + } + + @Nested @DisplayName("resetHeartbeat") inner class ResetHeartbeat { @@ -638,7 +663,7 @@ class SocketTest { @Test internal fun `flushes the send buffer`() { var oneCalled = 0 - socket.sendBuffer.add { oneCalled += 1 } + socket.sendBuffer.add(Pair("1", { oneCalled += 1 })) socket.onConnectionOpened() assertThat(oneCalled).isEqualTo(1)