diff --git a/src/main/kotlin/org/phoenixframework/Channel.kt b/src/main/kotlin/org/phoenixframework/Channel.kt index 5d357b9..a287702 100644 --- a/src/main/kotlin/org/phoenixframework/Channel.kt +++ b/src/main/kotlin/org/phoenixframework/Channel.kt @@ -127,8 +127,17 @@ class Channel( this.pushBuffer = mutableListOf() this.rejoinTimer = TimeoutTimer( dispatchQueue = socket.dispatchQueue, - callback = { rejoinUntilConnected() }, - timerCalculation = socket.reconnectAfterMs) + timerCalculation = socket.rejoinAfterMs, + callback = { if (socket.isConnected) rejoin() } + ) + + // Respond to socket events + this.socket.onError { _, _-> this.rejoinTimer.reset() } + this.socket.onOpen { + this.rejoinTimer.reset() + if (this.isErrored) { this.rejoin() } + } + // Setup Push to be sent when joining this.joinPush = Push( @@ -150,12 +159,15 @@ class Channel( this.pushBuffer.clear() } + // Perform if Channel errors while attempting to join + this.joinPush.receive("error") { + this.state = State.ERRORED + if (this.socket.isConnected) { this.rejoinTimer.scheduleTimeout() } + } + // Perform if Channel timed out while attempting to join this.joinPush.receive("timeout") { - - // Only handle a timeout if the Channel is in the 'joining' state - if (!this.isJoining) return@receive - + // Log the timeout this.socket.logItems("Channel: timeouts $topic, $joinRef after $timeout ms") // Send a Push to the server to leave the Channel @@ -165,10 +177,11 @@ class Channel( timeout = this.timeout) leavePush.send() - // Mark the Channel as in an error and attempt to rejoin + // Mark the Channel as in an error and attempt to rejoin if socket is connected this.state = State.ERRORED this.joinPush.reset() - this.rejoinTimer.scheduleTimeout() + + if (this.socket.isConnected) { this.rejoinTimer.scheduleTimeout() } } // Clean up when the channel closes @@ -177,7 +190,7 @@ class Channel( this.rejoinTimer.reset() // Log that the channel was left - this.socket.logItems("Channel: close $topic") + this.socket.logItems("Channel: close $topic $joinRef") // Mark the channel as closed and remove it from the socket this.state = State.CLOSED @@ -186,16 +199,15 @@ class Channel( // Handles an error, attempts to rejoin this.onError { - // Do not emit error if the channel is in the process of leaving - // or if it has already closed - if (this.isLeaving || this.isClosed) return@onError - // Log that the channel received an error - this.socket.logItems("Channel: error $topic") + this.socket.logItems("Channel: error $topic ${it.payload}") + + // If error was received while joining, then reset the Push + if (isJoining) { this.joinPush.reset() } - // Mark the channel as errored and attempt to rejoin + // Mark the channel as errored and attempt to rejoin if socket is currently connected this.state = State.ERRORED - this.rejoinTimer.scheduleTimeout() + if (socket.isConnected) { this.rejoinTimer.scheduleTimeout() } } // Perform when the join reply is received @@ -245,8 +257,9 @@ class Channel( } // Join the channel + this.timeout = timeout this.joinedOnce = true - this.rejoin(timeout) + this.rejoin() return joinPush } @@ -304,6 +317,9 @@ class Channel( // will return false, so instead store it _before_ starting the leave val canPush = this.canPush + // If attempting a rejoin during a leave, then reset, cancelling the rejoin + this.rejoinTimer.reset() + // Now set the state to leaving this.state = State.LEAVING @@ -384,12 +400,6 @@ class Channel( //------------------------------------------------------------------------------ // Private //------------------------------------------------------------------------------ - /** Will continually attempt to rejoin the Channel on a timer. */ - private fun rejoinUntilConnected() { - this.rejoinTimer.scheduleTimeout() - if (this.socket.isConnected) this.rejoin() - } - /** Sends the Channel's joinPush to the Server */ private fun sendJoin(timeout: Long) { this.state = State.JOINING @@ -398,6 +408,10 @@ class Channel( /** Rejoins the Channel e.g. after a disconnect */ private fun rejoin(timeout: Long = this.timeout) { + // Do not attempt to rejoin if the channel is in the process of leaving + if (isLeaving) return + + // Send the joinPush this.sendJoin(timeout) } } \ No newline at end of file diff --git a/src/main/kotlin/org/phoenixframework/Defaults.kt b/src/main/kotlin/org/phoenixframework/Defaults.kt index b57de50..b6c12be 100644 --- a/src/main/kotlin/org/phoenixframework/Defaults.kt +++ b/src/main/kotlin/org/phoenixframework/Defaults.kt @@ -34,11 +34,6 @@ object Defaults { /** Default heartbeat interval of 30s */ const val HEARTBEAT: Long = 30_000 - /** Default reconnect algorithm. Reconnects after 1s, 2s, 5s and then 10s thereafter */ - val steppedBackOff: (Int) -> Long = { tries -> - if (tries > 3) 10_000 else listOf(1_000L, 2_000L, 5_000L)[tries - 1] - } - /** Default reconnect algorithm for the socket */ val reconnectSteppedBackOff: (Int) -> Long = { tries -> if (tries > 9) 5_000 else listOf(10L, 50L, 100L, 150L, 200L, 250L, 500L, 1_000L, 2_000L)[tries - 1] diff --git a/src/test/kotlin/org/phoenixframework/ChannelTest.kt b/src/test/kotlin/org/phoenixframework/ChannelTest.kt index 2ff5b31..fa812ef 100644 --- a/src/test/kotlin/org/phoenixframework/ChannelTest.kt +++ b/src/test/kotlin/org/phoenixframework/ChannelTest.kt @@ -5,23 +5,27 @@ import com.nhaarman.mockitokotlin2.any import com.nhaarman.mockitokotlin2.eq import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.never +import com.nhaarman.mockitokotlin2.spy import com.nhaarman.mockitokotlin2.times import com.nhaarman.mockitokotlin2.verify -import com.nhaarman.mockitokotlin2.verifyZeroInteractions import com.nhaarman.mockitokotlin2.whenever +import okhttp3.OkHttpClient import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.DisplayName import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import org.mockito.Mock +import org.mockito.Mockito.verifyZeroInteractions import org.mockito.MockitoAnnotations -import org.mockito.stubbing.Answer import org.phoenixframework.queue.ManualDispatchQueue import org.phoenixframework.utilities.getBindings class ChannelTest { + + @Mock lateinit var okHttpClient: OkHttpClient + @Mock lateinit var socket: Socket @Mock lateinit var mockCallback: ((Message) -> Unit) @@ -29,28 +33,21 @@ class ChannelTest { private val kDefaultTimeout = 10_000L private val kDefaultPayload: Payload = mapOf("one" to "two") private val kEmptyPayload: Payload = mapOf() - private val reconnectAfterMs: (Int) -> Long = Defaults.steppedBackOff lateinit var fakeClock: ManualDispatchQueue lateinit var channel: Channel - var mutableRef = 0 - var mutableRefAnswer: Answer = Answer { - mutableRef += 1 - mutableRef.toString() - } - @BeforeEach internal fun setUp() { MockitoAnnotations.initMocks(this) - mutableRef = 0 fakeClock = ManualDispatchQueue() whenever(socket.dispatchQueue).thenReturn(fakeClock) whenever(socket.makeRef()).thenReturn(kDefaultRef) whenever(socket.timeout).thenReturn(kDefaultTimeout) - whenever(socket.reconnectAfterMs).thenReturn(reconnectAfterMs) + whenever(socket.reconnectAfterMs).thenReturn(Defaults.reconnectSteppedBackOff) + whenever(socket.rejoinAfterMs).thenReturn(Defaults.rejoinSteppedBackOff) channel = Channel("topic", kDefaultPayload, socket) } @@ -155,6 +152,14 @@ class ChannelTest { @Nested @DisplayName("join") inner class Join { + + @BeforeEach + internal fun setUp() { + socket = spy(Socket(url ="https://localhost:4000/socket", client = okHttpClient)) + socket.dispatchQueue = fakeClock + channel = Channel("topic", kDefaultPayload, socket) + } + @Test internal fun `sets state to joining`() { channel.join() @@ -201,143 +206,151 @@ class ChannelTest { assertThat(joinPush.timeout).isEqualTo(newTimeout) } - /* End Join */ - } - - @Nested - @DisplayName("timeout behavior") - inner class TimeoutBehavior { - @Test - internal fun `succeeds before timeout`() { - val joinPush = channel.joinPush - val timeout = channel.timeout + @Nested + @DisplayName("timeout behavior") + inner class TimeoutBehavior { - channel.join() - verify(socket).push(any(), any(), any(), any(), any()) + private lateinit var joinPush: Push - fakeClock.tick(timeout / 2) + private fun receiveSocketOpen() { + whenever(socket.isConnected).thenReturn(true) + socket.onConnectionOpened() + } - joinPush.trigger("ok", kEmptyPayload) - assertThat(channel.state).isEqualTo(Channel.State.JOINED) + @BeforeEach + internal fun setUp() { + joinPush = channel.joinPush + } - fakeClock.tick(timeout) - verify(socket, times(1)).push(any(), any(), any(), any(), any()) - } + @Test + internal fun `succeeds before timeout`() { + val timeout = channel.timeout - @Test - internal fun `retries with backoff after timeout`() { - var ref = 0 - whenever(socket.isConnected).thenReturn(true) - whenever(socket.makeRef()).thenAnswer { - ref += 1 - ref.toString() - } + socket.connect() + this.receiveSocketOpen() - val joinPush = channel.joinPush - val timeout = channel.timeout + channel.join() + verify(socket).push(any(), any(), any(), any(), any()) + assertThat(channel.timeout).isEqualTo(10_000) - channel.join() - verify(socket, times(1)).push(any(), any(), any(), any(), any()) + fakeClock.tick(100) - fakeClock.tick(timeout) // leave push sent to the server - verify(socket, times(2)).push(any(), any(), any(), any(), any()) + joinPush.trigger("ok", kEmptyPayload) + assertThat(channel.state).isEqualTo(Channel.State.JOINED) - fakeClock.tick(1_000) // begin stepped backoff - verify(socket, times(3)).push(any(), any(), any(), any(), any()) + fakeClock.tick(timeout) + verify(socket, times(1)).push(any(), any(), any(), any(), any()) + } - fakeClock.tick(2_000) - verify(socket, times(4)).push(any(), any(), any(), any(), any()) + @Test + internal fun `retries with backoff after timeout`() { + val timeout = channel.timeout - fakeClock.tick(5_000) - verify(socket, times(5)).push(any(), any(), any(), any(), any()) + socket.connect() + this.receiveSocketOpen() - fakeClock.tick(10_000) - verify(socket, times(6)).push(any(), any(), any(), any(), any()) + channel.join().receive("timeout", mockCallback) - joinPush.trigger("ok", kEmptyPayload) - assertThat(channel.state).isEqualTo(Channel.State.JOINED) + verify(socket, times(1)).push(any(), eq("phx_join"), any(), any(), any()) + verify(mockCallback, never()).invoke(any()) - fakeClock.tick(10_000) - verify(socket, times(6)).push(any(), any(), any(), any(), any()) - assertThat(channel.state).isEqualTo(Channel.State.JOINED) - } + fakeClock.tick(timeout) // leave pushed to server + verify(socket, times(1)).push(any(), eq("phx_leave"), any(), any(), any()) + verify(mockCallback, times(1)).invoke(any()) - @Test - internal fun `with socket and join delay`() { - whenever(socket.isConnected).thenReturn(false) - val joinPush = channel.joinPush + fakeClock.tick(timeout + 1000) // rejoin + verify(socket, times(2)).push(any(), eq("phx_join"), any(), any(), any()) + verify(socket, times(2)).push(any(), eq("phx_leave"), any(), any(), any()) + verify(mockCallback, times(2)).invoke(any()) - channel.join() - verify(socket, times(1)).push(any(), any(), any(), any(), any()) + fakeClock.tick(10_000) + joinPush.trigger("ok", kEmptyPayload) + verify(socket, times(3)).push(any(), eq("phx_join"), any(), any(), any()) + assertThat(channel.state).isEqualTo(Channel.State.JOINED) + } - // Open the socket after a delay - fakeClock.tick(9_000) - verify(socket, times(1)).push(any(), any(), any(), any(), any()) + @Test + internal fun `with socket and join delay`() { + val joinPush = channel.joinPush - // join request returns between timeouts - fakeClock.tick(1_000) + channel.join() + verify(socket, times(1)).push(any(), any(), any(), any(), any()) - whenever(socket.isConnected).thenReturn(true) - joinPush.trigger("ok", kEmptyPayload) + // Open the socket after a delay + fakeClock.tick(9_000) + verify(socket, times(1)).push(any(), any(), any(), any(), any()) - assertThat(channel.state).isEqualTo(Channel.State.ERRORED) + // join request returns between timeouts + fakeClock.tick(1_000) + socket.connect() - fakeClock.tick(1_000) - assertThat(channel.state).isEqualTo(Channel.State.JOINING) + assertThat(channel.state).isEqualTo(Channel.State.ERRORED) + this.receiveSocketOpen() + joinPush.trigger("ok", kEmptyPayload) - joinPush.trigger("ok", kEmptyPayload) - assertThat(channel.state).isEqualTo(Channel.State.JOINED) + fakeClock.tick(1_000) + assertThat(channel.state).isEqualTo(Channel.State.JOINED) - verify(socket, times(3)).push(any(), any(), any(), any(), any()) - } + verify(socket, times(3)).push(any(), any(), any(), any(), any()) + } - @Test - internal fun `with socket delay only`() { - whenever(socket.isConnected).thenReturn(false) - val joinPush = channel.joinPush + @Test + internal fun `with socket delay only`() { + val joinPush = channel.joinPush - channel.join() + channel.join() + assertThat(channel.state).isEqualTo(Channel.State.JOINING) - // connect socket after a delay - fakeClock.tick(6_000) - whenever(socket.isConnected).thenReturn(true) + // connect socket after a delay + fakeClock.tick(6_000) + socket.connect() - fakeClock.tick(4_000) - joinPush.trigger("ok", kEmptyPayload) + // open socket after delay + fakeClock.tick(5_000) + this.receiveSocketOpen() + joinPush.trigger("ok", kEmptyPayload) - fakeClock.tick(2_000) - assertThat(channel.state).isEqualTo(Channel.State.JOINING) + joinPush.trigger("ok", kEmptyPayload) + assertThat(channel.state).isEqualTo(Channel.State.JOINED) + } - joinPush.trigger("ok", kEmptyPayload) - assertThat(channel.state).isEqualTo(Channel.State.JOINED) + /* End TimeoutBehavior */ } - /* End TimeoutBehavior */ + /* End Join */ } @Nested @DisplayName("joinPush") inner class JoinPush { + private lateinit var joinPush: Push + /* setup */ @BeforeEach internal fun setUp() { + socket = spy(Socket("https://localhost:4000/socket")) + socket.dispatchQueue = fakeClock + whenever(socket.isConnected).thenReturn(true) - whenever(socket.makeRef()).thenAnswer(mutableRefAnswer) + + channel = Channel("topic", kDefaultPayload, socket) + joinPush = channel.joinPush + channel.join() } /* helper methods */ - private fun receivesOk(joinPush: Push) { + private fun receivesOk() { fakeClock.tick(joinPush.timeout / 2) joinPush.trigger("ok", mapOf("a" to "b")) } - private fun receivesTimeout(joinPush: Push) { - fakeClock.tick(joinPush.timeout) + private fun receivesTimeout() { + fakeClock.tick(joinPush.timeout * 2) } - private fun receivesError(joinPush: Push) { + private fun receivesError() { fakeClock.tick(joinPush.timeout / 2) joinPush.trigger("error", mapOf("a" to "b")) } @@ -347,32 +360,23 @@ class ChannelTest { inner class ReceivesOk { @Test internal fun `sets channel state to joined`() { - val joinPush = channel.joinPush - assertThat(channel.state).isNotEqualTo(Channel.State.JOINED) - receivesOk(joinPush) + receivesOk() assertThat(channel.state).isEqualTo(Channel.State.JOINED) } @Test internal fun `triggers receive(ok) callback after ok response`() { - val joinPush = channel.joinPush - - val mockCallback = mock<(Message) -> Unit>() joinPush.receive("ok", mockCallback) - receivesOk(joinPush) + receivesOk() verify(mockCallback, times(1)).invoke(any()) } @Test internal fun `triggers receive('ok') callback if ok response already received`() { - val joinPush = channel.joinPush - - receivesOk(joinPush) - - val mockCallback = mock<(Message) -> Unit>() + receivesOk() joinPush.receive("ok", mockCallback) verify(mockCallback, times(1)).invoke(any()) @@ -380,74 +384,61 @@ class ChannelTest { @Test internal fun `does not trigger other receive callbacks after ok response`() { - val joinPush = channel.joinPush - - val mockCallback = mock<(Message) -> Unit>() joinPush .receive("error", mockCallback) .receive("timeout", mockCallback) - receivesOk(joinPush) - receivesTimeout(joinPush) + receivesOk() + receivesTimeout() verify(mockCallback, times(0)).invoke(any()) } @Test internal fun `clears timeoutTimer workItem`() { - val joinPush = channel.joinPush - assertThat(joinPush.timeoutTask).isNotNull() val mockTimeoutTask = mock() joinPush.timeoutTask = mockTimeoutTask - receivesOk(joinPush) + receivesOk() verify(mockTimeoutTask).cancel() assertThat(joinPush.timeoutTask).isNull() } @Test internal fun `sets receivedMessage`() { - val joinPush = channel.joinPush - assertThat(joinPush.receivedMessage).isNull() - receivesOk(joinPush) + receivesOk() assertThat(joinPush.receivedMessage?.payload).isEqualTo(mapOf("status" to "ok", "a" to "b")) assertThat(joinPush.receivedMessage?.status).isEqualTo("ok") } @Test internal fun `removes channel binding`() { - val joinPush = channel.joinPush - var bindings = channel.getBindings("chan_reply_1") assertThat(bindings).hasSize(1) - receivesOk(joinPush) + receivesOk() bindings = channel.getBindings("chan_reply_1") assertThat(bindings).isEmpty() } @Test internal fun `resets channel rejoinTimer`() { - val joinPush = channel.joinPush - val mockRejoinTimer = mock() channel.rejoinTimer = mockRejoinTimer - receivesOk(joinPush) + receivesOk() verify(mockRejoinTimer, times(1)).reset() } @Test internal fun `sends and empties channel's buffered pushEvents`() { - val joinPush = channel.joinPush - val mockPush = mock() channel.pushBuffer.add(mockPush) - receivesOk(joinPush) + receivesOk() verify(mockPush).send() assertThat(channel.pushBuffer).isEmpty() } @@ -460,51 +451,52 @@ class ChannelTest { inner class ReceivesTimeout { @Test internal fun `sets channel state to errored`() { - val joinPush = channel.joinPush - - receivesTimeout(joinPush) - assertThat(channel.state).isEqualTo(Channel.State.ERRORED) + var timeoutReceived = false + joinPush.receive("timeout") { + timeoutReceived = true + assertThat(channel.state).isEqualTo(Channel.State.ERRORED) + } + + receivesTimeout() + assertThat(timeoutReceived).isTrue() } @Test internal fun `triggers receive('timeout') callback after ok response`() { - val joinPush = channel.joinPush - val mockCallback = mock<(Message) -> Unit>() joinPush.receive("timeout", mockCallback) - receivesTimeout(joinPush) + receivesTimeout() verify(mockCallback).invoke(any()) } @Test internal fun `does not trigger other receive callbacks after timeout response`() { - val joinPush = channel.joinPush - val mockOk = mock<(Message) -> Unit>() val mockError = mock<(Message) -> Unit>() - val mockTimeout = mock<(Message) -> Unit>() + var timeoutReceived = false + joinPush .receive("ok", mockOk) .receive("error", mockError) - .receive("timeout", mockTimeout) + .receive("timeout") { + verifyZeroInteractions(mockOk) + verifyZeroInteractions(mockError) + timeoutReceived = true + } - receivesTimeout(joinPush) - joinPush.trigger("ok", emptyMap()) + receivesTimeout() + receivesOk() - verifyZeroInteractions(mockOk) - verifyZeroInteractions(mockError) - verify(mockTimeout).invoke(any()) + assertThat(timeoutReceived).isTrue() } @Test internal fun `schedules rejoinTimer timeout`() { - val joinPush = channel.joinPush - val mockTimer = mock() channel.rejoinTimer = mockTimer - receivesTimeout(joinPush) + receivesTimeout() verify(mockTimer).scheduleTimeout() } @@ -516,22 +508,18 @@ class ChannelTest { inner class ReceivesError { @Test internal fun `triggers receive('error') callback after error response`() { - val joinPush = channel.joinPush - - val mockCallback = mock<(Message) -> Unit>() + assertThat(channel.state).isEqualTo(Channel.State.JOINING) joinPush.receive("error", mockCallback) - receivesError(joinPush) - verify(mockCallback).invoke(any()) + receivesError() + joinPush.trigger("error", kEmptyPayload) + verify(mockCallback, times(1)).invoke(any()) } @Test internal fun `triggers receive('error') callback if error response already received`() { - val joinPush = channel.joinPush - - receivesError(joinPush) + receivesError() - val mockCallback = mock<(Message) -> Unit>() joinPush.receive("error", mockCallback) verify(mockCallback).invoke(any()) @@ -539,27 +527,32 @@ class ChannelTest { @Test internal fun `does not trigger other receive callbacks after ok response`() { - val joinPush = channel.joinPush - - val mockCallback = mock<(Message) -> Unit>() + val mockOk = mock<(Message) -> Unit>() + val mockError = mock<(Message) -> Unit>() + val mockTimeout = mock<(Message) -> Unit>() joinPush - .receive("ok", mockCallback) - .receive("timeout", mockCallback) + .receive("ok", mockOk) + .receive("error") { + mockError.invoke(it) + channel.leave() + } + .receive("timeout", mockTimeout) + + receivesError() + receivesTimeout() - receivesError(joinPush) - receivesTimeout(joinPush) - verifyZeroInteractions(mockCallback) + verify(mockError, times(1)).invoke(any()) + verifyZeroInteractions(mockOk) + verifyZeroInteractions(mockTimeout) } @Test internal fun `clears timeoutTimer workItem`() { - val joinPush = channel.joinPush - val mockTask = mock() assertThat(joinPush.timeoutTask).isNotNull() joinPush.timeoutTask = mockTask - receivesError(joinPush) + receivesError() verify(mockTask).cancel() assertThat(joinPush.timeoutTask).isNull() @@ -567,11 +560,9 @@ class ChannelTest { @Test internal fun `sets receivedMessage`() { - val joinPush = channel.joinPush - assertThat(joinPush.receivedMessage).isNull() - receivesError(joinPush) + receivesError() assertThat(joinPush.receivedMessage).isNotNull() assertThat(joinPush.receivedMessage?.status).isEqualTo("error") assertThat(joinPush.receivedMessage?.payload?.get("a")).isEqualTo("b") @@ -579,32 +570,26 @@ class ChannelTest { @Test internal fun `removes channel binding`() { - val joinPush = channel.joinPush - var bindings = channel.getBindings("chan_reply_1") assertThat(bindings).hasSize(1) - receivesError(joinPush) + receivesError() bindings = channel.getBindings("chan_reply_1") assertThat(bindings).isEmpty() } @Test internal fun `does not sets channel state to joined`() { - val joinPush = channel.joinPush - - receivesError(joinPush) + receivesError() assertThat(channel.state).isNotEqualTo(Channel.State.JOINED) } @Test internal fun `does not trigger channel's buffered pushEvents`() { - val joinPush = channel.joinPush - val mockPush = mock() channel.pushBuffer.add(mockPush) - receivesError(joinPush) + receivesError() verifyZeroInteractions(mockPush) assertThat(channel.pushBuffer).hasSize(1) } @@ -618,12 +603,25 @@ class ChannelTest { @Nested @DisplayName("onError") inner class OnError { + + private lateinit var joinPush: Push + + /* setup */ @BeforeEach internal fun setUp() { + socket = spy(Socket("https://localhost:4000/socket")) + socket.dispatchQueue = fakeClock + whenever(socket.isConnected).thenReturn(true) + + channel = Channel("topic", kDefaultPayload, socket) + joinPush = channel.joinPush + channel.join() + joinPush.trigger("ok", kEmptyPayload) } + @Test internal fun `sets channel state to errored`() { assertThat(channel.state).isNotEqualTo(Channel.State.ERRORED) @@ -632,6 +630,25 @@ class ChannelTest { assertThat(channel.state).isEqualTo(Channel.State.ERRORED) } + @Test + internal fun `does not trigger redundant errors during backoff`() { + // Spy the channel's join push + joinPush = spy(channel.joinPush) + channel.joinPush = joinPush + + verify(joinPush, times(0)).send() + + channel.trigger(Channel.Event.ERROR) + + fakeClock.tick(1000) + verify(joinPush, times(1)).send() + + channel.trigger("error") + + fakeClock.tick(1000) + verify(joinPush, times(1)).send() + } + @Test internal fun `tries to rejoin with backoff`() { val mockTimer = mock() @@ -645,45 +662,50 @@ class ChannelTest { internal fun `does not rejoin if leaving channel`() { channel.state = Channel.State.LEAVING - val mockPush = mock() - channel.joinPush = mockPush + // Spy the joinPush + joinPush = spy(channel.joinPush) + channel.joinPush = joinPush - channel.trigger(Channel.Event.ERROR) + socket.onConnectionError(Throwable(), null) fakeClock.tick(1_000) - verify(mockPush, never()).send() + verify(joinPush, never()).send() fakeClock.tick(2_000) - verify(mockPush, never()).send() + verify(joinPush, never()).send() assertThat(channel.state).isEqualTo(Channel.State.LEAVING) } @Test - internal fun `does nothing if channel is closed`() { + internal fun `does not rejoin if channel is closed`() { channel.state = Channel.State.CLOSED - val mockPush = mock() - channel.joinPush = mockPush + // Spy the joinPush + joinPush = spy(channel.joinPush) + channel.joinPush = joinPush - channel.trigger(Channel.Event.ERROR) + socket.onConnectionError(Throwable(), null) fakeClock.tick(1_000) - verify(mockPush, never()).send() + verify(joinPush, never()).send() fakeClock.tick(2_000) - verify(mockPush, never()).send() + verify(joinPush, never()).send() assertThat(channel.state).isEqualTo(Channel.State.CLOSED) } @Test - internal fun `triggers additional callbacks`() { - val mockCallback = mock<(Message) -> Unit>() + internal fun `triggers additional callbacks after join`() { channel.onError(mockCallback) + joinPush.trigger("ok", kEmptyPayload) + + assertThat(channel.state).isEqualTo(Channel.State.JOINED) + verifyZeroInteractions(mockCallback) channel.trigger(Channel.Event.ERROR) - verify(mockCallback).invoke(any()) + verify(mockCallback, times(1)).invoke(any()) } /* End OnError */ @@ -692,12 +714,24 @@ class ChannelTest { @Nested @DisplayName("onClose") inner class OnClose { + + private lateinit var joinPush: Push + + /* setup */ @BeforeEach internal fun setUp() { + socket = spy(Socket("https://localhost:4000/socket")) + socket.dispatchQueue = fakeClock + whenever(socket.isConnected).thenReturn(true) + + channel = Channel("topic", kDefaultPayload, socket) + joinPush = channel.joinPush + channel.join() } + @Test internal fun `sets state to closed`() { assertThat(channel.state).isNotEqualTo(Channel.State.CLOSED) @@ -708,11 +742,17 @@ class ChannelTest { @Test internal fun `does not rejoin`() { - val mockPush = mock() - channel.joinPush = mockPush + // Spy the channel's join push + joinPush = spy(channel.joinPush) + channel.joinPush = joinPush channel.trigger(Channel.Event.CLOSE) - verify(mockPush, never()).send() + + fakeClock.tick(1_000) + verify(joinPush, never()).send() + + fakeClock.tick(2_000) + verify(joinPush, never()).send() } @Test @@ -725,18 +765,18 @@ class ChannelTest { } @Test - internal fun `removes self from socket`() { + internal fun `removes channel from socket`() { channel.trigger(Channel.Event.CLOSE) verify(socket).remove(channel) } @Test internal fun `triggers additional callbacks`() { - val mockCallback = mock<(Message) -> Unit>() channel.onClose(mockCallback) + verifyZeroInteractions(mockCallback) channel.trigger(Channel.Event.CLOSE) - verify(mockCallback).invoke(any()) + verify(mockCallback, times(1)).invoke(any()) } /* End OnClose */ diff --git a/src/test/kotlin/org/phoenixframework/DefaultsTest.kt b/src/test/kotlin/org/phoenixframework/DefaultsTest.kt new file mode 100644 index 0000000..eca6d0b --- /dev/null +++ b/src/test/kotlin/org/phoenixframework/DefaultsTest.kt @@ -0,0 +1,45 @@ +package org.phoenixframework + +import com.google.common.truth.Truth.assertThat +import org.junit.jupiter.api.Test + +internal class DefaultsTest { + + @Test + internal fun `default timeout is 10_000`() { + assertThat(Defaults.TIMEOUT).isEqualTo(10_000) + } + + @Test + internal fun `default heartbeat is 30_000`() { + assertThat(Defaults.HEARTBEAT).isEqualTo(30_000) + } + + @Test + internal fun `default reconnectAfterMs returns all values`() { + val reconnect = Defaults.reconnectSteppedBackOff + + assertThat(reconnect(1)).isEqualTo(10) + assertThat(reconnect(2)).isEqualTo(50) + assertThat(reconnect(3)).isEqualTo(100) + assertThat(reconnect(4)).isEqualTo(150) + assertThat(reconnect(5)).isEqualTo(200) + assertThat(reconnect(6)).isEqualTo(250) + assertThat(reconnect(7)).isEqualTo(500) + assertThat(reconnect(8)).isEqualTo(1_000) + assertThat(reconnect(9)).isEqualTo(2_000) + assertThat(reconnect(10)).isEqualTo(5_000) + assertThat(reconnect(11)).isEqualTo(5_000) + } + + @Test + internal fun `default rejoinAfterMs returns all values`() { + val reconnect = Defaults.rejoinSteppedBackOff + + assertThat(reconnect(1)).isEqualTo(1_000) + assertThat(reconnect(2)).isEqualTo(2_000) + assertThat(reconnect(3)).isEqualTo(5_000) + assertThat(reconnect(4)).isEqualTo(10_000) + assertThat(reconnect(5)).isEqualTo(10_000) + } +} \ No newline at end of file diff --git a/src/test/kotlin/org/phoenixframework/PresenceTest.kt b/src/test/kotlin/org/phoenixframework/PresenceTest.kt index 4c41780..780f897 100644 --- a/src/test/kotlin/org/phoenixframework/PresenceTest.kt +++ b/src/test/kotlin/org/phoenixframework/PresenceTest.kt @@ -38,6 +38,7 @@ class PresenceTest { whenever(socket.timeout).thenReturn(Defaults.TIMEOUT) whenever(socket.makeRef()).thenReturn("1") whenever(socket.reconnectAfterMs).thenReturn { 1_000 } + whenever(socket.rejoinAfterMs).thenReturn(Defaults.rejoinSteppedBackOff) whenever(socket.dispatchQueue).thenReturn(mock()) channel = Channel("topic", mapOf(), socket) diff --git a/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueue.kt b/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueue.kt index 9d474a8..c1ff3a6 100644 --- a/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueue.kt +++ b/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueue.kt @@ -23,20 +23,18 @@ class ManualDispatchQueue : DispatchQueue { // Filter all work items that are due to be fired and have not been // cancelled. Return early if there are no items to fire - var pastDueWorkItems = workItems.filter { it.isPastDue(advanceTo) && !it.isCancelled } + var pastDueWorkItems = workItems.filter { it.isPastDue(advanceTo) && !it.isCancelled }.sorted() // Keep looping until there are no more work items that are passed the advance to time while (pastDueWorkItems.isNotEmpty()) { - // Perform all work items that are due - pastDueWorkItems.forEach { - tickTime = it.deadline - it.perform() - } + val firstItem = pastDueWorkItems.first() + tickTime = firstItem.deadline + firstItem.perform() // Remove all work items that are past due or canceled workItems.removeAll { it.isPastDue(tickTime) || it.isCancelled } - pastDueWorkItems = workItems.filter { it.isPastDue(advanceTo) && !it.isCancelled } + pastDueWorkItems = workItems.filter { it.isPastDue(advanceTo) && !it.isCancelled }.sorted() } // Now that all work has been performed, advance the clock diff --git a/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueueTest.kt b/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueueTest.kt index 4a5705e..02c18cc 100644 --- a/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueueTest.kt +++ b/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueueTest.kt @@ -122,6 +122,53 @@ internal class ManualDispatchQueueTest { assertThat(queue.tickTime).isEqualTo(250) } + @Test + internal fun `triggers work in order of deadline`() { + var task200Called = false + var task100Called = false + + + val task200 = queue.queue(200, TimeUnit.MILLISECONDS) { + task200Called = true + } + + queue.queue(100, TimeUnit.MILLISECONDS) { + task100Called = true + task200.cancel() + } + + + queue.tick(300) + assertThat(task100Called).isTrue() + assertThat(task200Called).isFalse() + } + + @Test + internal fun `triggers inserted work in order of deadline`() { + var task500Called = false + var task200Called = false + var task100Called = false + + val task500 = queue.queue(500, TimeUnit.MILLISECONDS) { + task500Called = true + } + + queue.queue(200, TimeUnit.MILLISECONDS) { + task200Called = true + + queue.queue(100, TimeUnit.MILLISECONDS) { + task100Called = true + task500.cancel() + } + } + + queue.tick(600) + assertThat(task100Called).isTrue() + assertThat(task200Called).isTrue() + assertThat(task500Called).isFalse() + } + + @Test internal fun `does not triggers nested work that is scheduled outside of the tick`() { diff --git a/src/test/kotlin/org/phoenixframework/queue/ManualDispatchWorkItem.kt b/src/test/kotlin/org/phoenixframework/queue/ManualDispatchWorkItem.kt index 0d0a03f..42d47eb 100644 --- a/src/test/kotlin/org/phoenixframework/queue/ManualDispatchWorkItem.kt +++ b/src/test/kotlin/org/phoenixframework/queue/ManualDispatchWorkItem.kt @@ -6,7 +6,7 @@ class ManualDispatchWorkItem( private val runnable: () -> Unit, var deadline: Long, private val period: Long = 0 -) : DispatchWorkItem { +) : DispatchWorkItem, Comparable { private var performCount = 0 @@ -30,4 +30,8 @@ class ManualDispatchWorkItem( override fun cancel() { this.isCancelled = true } + + override fun compareTo(other: ManualDispatchWorkItem): Int { + return deadline.compareTo(other.deadline) + } }