Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
language: java
dist: trusty
after_success:
- bash <(curl -s https://codecov.io/bash)
jdk:
Expand Down
12 changes: 4 additions & 8 deletions src/main/kotlin/org/phoenixframework/Push.kt
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class Push(
var timeoutTask: DispatchWorkItem? = null

/** Hooks into a Push. Where .receive("ok", callback(Payload)) are stored */
var receiveHooks: MutableMap<String, MutableList<((message: Message) -> Unit)>> = HashMap()
var receiveHooks: MutableMap<String, List<((message: Message) -> Unit)>> = HashMap()

/** True if the Push has been sent */
var sent: Boolean = false
Expand Down Expand Up @@ -93,13 +93,9 @@ class Push(
// If the message has already be received, pass it to the callback
receivedMessage?.let { if (hasReceived(status)) callback(it) }

if (receiveHooks[status] == null) {
// Create a new array of hooks if no previous hook is associated with status
receiveHooks[status] = arrayListOf(callback)
} else {
// A previous hook for this status already exists. Just append the new hook
receiveHooks[status]?.add(callback)
}
// If a previous hook for this status already exists. Just append the new hook. If not, then
// create a new array of hooks if no previous hook is associated with status
receiveHooks[status] = receiveHooks[status]?.copyAndAdd(callback) ?: arrayListOf(callback)

return this
}
Expand Down
68 changes: 50 additions & 18 deletions src/main/kotlin/org/phoenixframework/Socket.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,61 @@ import java.util.concurrent.TimeUnit
typealias Payload = Map<String, Any>

/** Data class that holds callbacks assigned to the socket */
internal data class StateChangeCallbacks(
val open: MutableList<() -> Unit> = ArrayList(),
val close: MutableList<() -> Unit> = ArrayList(),
val error: MutableList<(Throwable, Response?) -> Unit> = ArrayList(),
val message: MutableList<(Message) -> Unit> = ArrayList()
) {
internal class StateChangeCallbacks {

var open: List<() -> Unit> = ArrayList()
private set
var close: List<() -> Unit> = ArrayList()
private set
var error: List<(Throwable, Response?) -> Unit> = ArrayList()
private set
var message: List<(Message) -> Unit> = ArrayList()
private set

/** Safely adds an onOpen callback */
fun onOpen(callback: () -> Unit) {
this.open = this.open.copyAndAdd(callback)
}

/** Safely adds an onClose callback */
fun onClose(callback: () -> Unit) {
this.close = this.close.copyAndAdd(callback)
}

/** Safely adds an onError callback */
fun onError(callback: (Throwable, Response?) -> Unit) {
this.error = this.error.copyAndAdd(callback)
}

/** Safely adds an onMessage callback */
fun onMessage(callback: (Message) -> Unit) {
this.message = this.message.copyAndAdd(callback)
}

/** Clears all stored callbacks */
fun release() {
open.clear()
close.clear()
error.clear()
message.clear()
open = emptyList()
close = emptyList()
error = emptyList()
message = emptyList()
}
}

/** Converts the List to a MutableList, adds the value, and then returns as a read-only List */
fun <T> List<T>.copyAndAdd(value: T): List<T> {
val temp = this.toMutableList()
temp.add(value)

return temp
}


/** RFC 6455: indicates a normal closure */
const val WS_CLOSE_NORMAL = 1000

/** RFC 6455: indicates that the connection was closed abnormally */
const val WS_CLOSE_ABNORMAL = 1006


/**
* Connects to a Phoenix Server
*/
Expand Down Expand Up @@ -125,7 +158,7 @@ class Socket(
internal val stateChangeCallbacks: StateChangeCallbacks = StateChangeCallbacks()

/** Collection of unclosed channels created by the Socket */
internal var channels: MutableList<Channel> = ArrayList()
internal var channels: List<Channel> = ArrayList()

/** Buffers messages that need to be sent once the socket has connected */
internal var sendBuffer: MutableList<() -> Unit> = ArrayList()
Expand Down Expand Up @@ -250,19 +283,19 @@ class Socket(
}

fun onOpen(callback: (() -> Unit)) {
this.stateChangeCallbacks.open.add(callback)
this.stateChangeCallbacks.onOpen(callback)
}

fun onClose(callback: () -> Unit) {
this.stateChangeCallbacks.close.add(callback)
this.stateChangeCallbacks.onClose(callback)
}

fun onError(callback: (Throwable, Response?) -> Unit) {
this.stateChangeCallbacks.error.add(callback)
this.stateChangeCallbacks.onError(callback)
}

fun onMessage(callback: (Message) -> Unit) {
this.stateChangeCallbacks.message.add(callback)
this.stateChangeCallbacks.onMessage(callback)
}

fun removeAllCallbacks() {
Expand All @@ -271,7 +304,7 @@ class Socket(

fun channel(topic: String, params: Payload = mapOf()): Channel {
val channel = Channel(topic, params, this)
this.channels.add(channel)
this.channels = this.channels.copyAndAdd(channel)

return channel
}
Expand All @@ -282,7 +315,6 @@ class Socket(
// that does not contain the channel that was removed.
this.channels = channels
.filter { it.joinRef != channel.joinRef }
.toMutableList()
}

//------------------------------------------------------------------------------
Expand Down
123 changes: 108 additions & 15 deletions src/test/kotlin/org/phoenixframework/SocketTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
import org.mockito.Mock
import org.mockito.MockitoAnnotations
import org.phoenixframework.utilities.copyAndRemove
import java.net.URL
import java.util.concurrent.TimeUnit

Expand Down Expand Up @@ -397,7 +398,6 @@ class SocketTest {
channel1.join().trigger("ok", emptyMap())
channel2.join().trigger("ok", emptyMap())


var chan1Called = false
channel1.onError { chan1Called = true }

Expand Down Expand Up @@ -756,8 +756,8 @@ class SocketTest {
val spy = spy(channel)

// Use the spy instance instead of the Channel instance
socket.channels.remove(channel)
socket.channels.add(spy)
socket.channels = socket.channels.copyAndRemove(channel)
socket.channels = socket.channels.copyAndAdd(spy)

spy.join()
assertThat(spy.state).isEqualTo(Channel.State.JOINING)
Expand All @@ -772,8 +772,8 @@ class SocketTest {
val spy = spy(channel)

// Use the spy instance instead of the Channel instance
socket.channels.remove(channel)
socket.channels.add(spy)
socket.channels = socket.channels.copyAndRemove(channel)
socket.channels = socket.channels.copyAndAdd(spy)

spy.join().trigger("ok", emptyMap())

Expand All @@ -789,8 +789,8 @@ class SocketTest {
val spy = spy(channel)

// Use the spy instance instead of the Channel instance
socket.channels.remove(channel)
socket.channels.add(spy)
socket.channels = socket.channels.copyAndRemove(channel)
socket.channels = socket.channels.copyAndAdd(spy)

spy.join().trigger("ok", emptyMap())
spy.leave()
Expand Down Expand Up @@ -828,8 +828,8 @@ class SocketTest {
val spy = spy(channel)

// Use the spy instance instead of the Channel instance
socket.channels.remove(channel)
socket.channels.add(spy)
socket.channels = socket.channels.copyAndRemove(channel)
socket.channels = socket.channels.copyAndAdd(spy)

spy.join()
assertThat(spy.state).isEqualTo(Channel.State.JOINING)
Expand All @@ -844,8 +844,8 @@ class SocketTest {
val spy = spy(channel)

// Use the spy instance instead of the Channel instance
socket.channels.remove(channel)
socket.channels.add(spy)
socket.channels = socket.channels.copyAndRemove(channel)
socket.channels = socket.channels.copyAndAdd(spy)

spy.join().trigger("ok", emptyMap())

Expand All @@ -861,8 +861,8 @@ class SocketTest {
val spy = spy(channel)

// Use the spy instance instead of the Channel instance
socket.channels.remove(channel)
socket.channels.add(spy)
socket.channels = socket.channels.copyAndRemove(channel)
socket.channels = socket.channels.copyAndAdd(spy)

spy.join().trigger("ok", emptyMap())
spy.leave()
Expand All @@ -886,8 +886,8 @@ class SocketTest {
val otherChannel = mock<Channel>()
whenever(otherChannel.isMember(any())).thenReturn(false)

socket.channels.add(targetChannel)
socket.channels.add(otherChannel)
socket.channels = socket.channels.copyAndAdd(targetChannel)
socket.channels = socket.channels.copyAndRemove(otherChannel)

val rawMessage =
"{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"status\":\"ok\"}"
Expand Down Expand Up @@ -923,4 +923,97 @@ class SocketTest {
/* End OnConnectionMessage */
}


@Nested
@DisplayName("ConcurrentModificationException")
inner class ConcurrentModificationExceptionTests {

@Test
internal fun `onOpen does not throw`() {
var oneCalled = 0
var twoCalled = 0
socket.onOpen {
socket.onOpen { twoCalled += 1 }
oneCalled += 1
}

socket.onConnectionOpened()
assertThat(oneCalled).isEqualTo(1)
assertThat(twoCalled).isEqualTo(0)

socket.onConnectionOpened()
assertThat(oneCalled).isEqualTo(2)
assertThat(twoCalled).isEqualTo(1)
}

@Test
internal fun `onClose does not throw`() {
var oneCalled = 0
var twoCalled = 0
socket.onClose {
socket.onClose { twoCalled += 1 }
oneCalled += 1
}

socket.onConnectionClosed(1000)
assertThat(oneCalled).isEqualTo(1)
assertThat(twoCalled).isEqualTo(0)

socket.onConnectionClosed(1001)
assertThat(oneCalled).isEqualTo(2)
assertThat(twoCalled).isEqualTo(1)
}

@Test
internal fun `onError does not throw`() {
var oneCalled = 0
var twoCalled = 0
socket.onError { _, _->
socket.onError { _, _ -> twoCalled += 1 }
oneCalled += 1
}

socket.onConnectionError(Throwable(), null)
assertThat(oneCalled).isEqualTo(1)
assertThat(twoCalled).isEqualTo(0)

socket.onConnectionError(Throwable(), null)
assertThat(oneCalled).isEqualTo(2)
assertThat(twoCalled).isEqualTo(1)
}

@Test
internal fun `onMessage does not throw`() {
var oneCalled = 0
var twoCalled = 0
socket.onMessage {
socket.onMessage { twoCalled += 1 }
oneCalled += 1
}

socket.onConnectionMessage("{\"status\":\"ok\"}")
assertThat(oneCalled).isEqualTo(1)
assertThat(twoCalled).isEqualTo(0)

socket.onConnectionMessage("{\"status\":\"ok\"}")
assertThat(oneCalled).isEqualTo(2)
assertThat(twoCalled).isEqualTo(1)
}

@Test
internal fun `does not throw when adding channel`() {
var oneCalled = 0
socket.onOpen {
val channel = socket.channel("foo")
oneCalled += 1
}

socket.onConnectionOpened()
assertThat(oneCalled).isEqualTo(1)
}

/* End ConcurrentModificationExceptionTests */
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,12 @@ import org.phoenixframework.Channel

fun Channel.getBindings(event: String): List<Binding> {
return bindings.toList().filter { it.event == event }
}

/** Converts the List to a MutableList, removes the value, and then returns as a read-only List */
fun <T> List<T>.copyAndRemove(value: T): List<T> {
val temp = this.toMutableList()
temp.remove(value)

return temp
}