diff --git a/src/main/kotlin/org/phoenixframework/Socket.kt b/src/main/kotlin/org/phoenixframework/Socket.kt index d10d5ee..92ece3a 100644 --- a/src/main/kotlin/org/phoenixframework/Socket.kt +++ b/src/main/kotlin/org/phoenixframework/Socket.kt @@ -277,7 +277,12 @@ class Socket( } fun remove(channel: Channel) { - this.channels.removeAll { it.joinRef == channel.joinRef } + // To avoid a ConcurrentModificationException, filter out the channels to be + // removed instead of calling .remove() on the list, thus returning a new list + // that does not contain the channel that was removed. + this.channels = channels + .filter { it.joinRef != channel.joinRef } + .toMutableList() } //------------------------------------------------------------------------------ diff --git a/src/test/kotlin/org/phoenixframework/SocketTest.kt b/src/test/kotlin/org/phoenixframework/SocketTest.kt index 3b74612..4ca29c1 100644 --- a/src/test/kotlin/org/phoenixframework/SocketTest.kt +++ b/src/test/kotlin/org/phoenixframework/SocketTest.kt @@ -386,6 +386,41 @@ class SocketTest { assertThat(socket.channels).contains(channel2) } + @Test + internal fun `does not throw exception when iterating over channels`() { + val channel1 = socket.channel("topic-1") + val channel2 = socket.channel("topic-2") + + channel1.joinPush.ref = "1" + channel2.joinPush.ref = "2" + + channel1.join().trigger("ok", emptyMap()) + channel2.join().trigger("ok", emptyMap()) + + + var chan1Called = false + channel1.onError { chan1Called = true } + + var chan2Called = false + channel2.onError { + chan2Called = true + socket.remove(channel2) + } + + // This will trigger an iteration over the socket.channels list which will trigger + // channel2.onError. That callback will attempt to remove channel2 during iteration + // which would throw a ConcurrentModificationException if the socket.remove method + // is implemented incorrectly. + socket.onConnectionError(IllegalStateException(), null) + + // Assert that both on all error's got called even when a channel was removed + assertThat(chan1Called).isTrue() + assertThat(chan2Called).isTrue() + + assertThat(socket.channels).doesNotContain(channel2) + assertThat(socket.channels).contains(channel1) + } + /* End Remove */ }