diff --git a/src/main/kotlin/org/phoenixframework/Defaults.kt b/src/main/kotlin/org/phoenixframework/Defaults.kt index b6c12be..f410354 100644 --- a/src/main/kotlin/org/phoenixframework/Defaults.kt +++ b/src/main/kotlin/org/phoenixframework/Defaults.kt @@ -25,6 +25,8 @@ package org.phoenixframework import com.google.gson.FieldNamingPolicy import com.google.gson.Gson import com.google.gson.GsonBuilder +import okhttp3.HttpUrl +import java.net.URL object Defaults { @@ -36,7 +38,9 @@ object Defaults { /** Default reconnect algorithm for the socket */ val reconnectSteppedBackOff: (Int) -> Long = { tries -> - if (tries > 9) 5_000 else listOf(10L, 50L, 100L, 150L, 200L, 250L, 500L, 1_000L, 2_000L)[tries - 1] + if (tries > 9) 5_000 else listOf( + 10L, 50L, 100L, 150L, 200L, 250L, 500L, 1_000L, 2_000L + )[tries - 1] } /** Default rejoin algorithm for individual channels */ @@ -44,11 +48,46 @@ object Defaults { if (tries > 3) 10_000 else listOf(1_000L, 2_000L, 5_000L)[tries - 1] } - /** The default Gson configuration to use when parsing messages */ val gson: Gson get() = GsonBuilder() - .setLenient() - .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) - .create() + .setLenient() + .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) + .create() + + /** + * Takes an endpoint and a params closure given by the User and constructs a URL that + * is ready to be sent to the Socket connection. + * + * Will convert "ws://" and "wss://" to http/s which is what OkHttp expects. + * + * @throws IllegalArgumentException if [endpoint] is not a valid URL endpoint. + */ + internal fun buildEndpointUrl( + endpoint: String, + paramsClosure: PayloadClosure + ): URL { + var mutableUrl = endpoint + // Silently replace web socket URLs with HTTP URLs. + if (endpoint.regionMatches(0, "ws:", 0, 3, ignoreCase = true)) { + mutableUrl = "http:" + endpoint.substring(3) + } else if (endpoint.regionMatches(0, "wss:", 0, 4, ignoreCase = true)) { + mutableUrl = "https:" + endpoint.substring(4) + } + + // If there are query params, append them now + var httpUrl = + HttpUrl.parse(mutableUrl) ?: throw IllegalArgumentException("invalid url: $endpoint") + paramsClosure.invoke()?.let { + val httpBuilder = httpUrl.newBuilder() + it.forEach { (key, value) -> + httpBuilder.addQueryParameter(key, value.toString()) + } + + httpUrl = httpBuilder.build() + } + + // Store the URL that will be used to establish a connection + return httpUrl.url() + } } \ No newline at end of file diff --git a/src/main/kotlin/org/phoenixframework/Socket.kt b/src/main/kotlin/org/phoenixframework/Socket.kt index 6e208ad..be32a22 100644 --- a/src/main/kotlin/org/phoenixframework/Socket.kt +++ b/src/main/kotlin/org/phoenixframework/Socket.kt @@ -45,22 +45,34 @@ internal class StateChangeCallbacks { private set /** Safely adds an onOpen callback */ - fun onOpen(ref: String, callback: () -> Unit) { + fun onOpen( + ref: String, + callback: () -> Unit + ) { this.open = this.open + Pair(ref, callback) } /** Safely adds an onClose callback */ - fun onClose(ref: String, callback: () -> Unit) { + fun onClose( + ref: String, + callback: () -> Unit + ) { this.close = this.close + Pair(ref, callback) } /** Safely adds an onError callback */ - fun onError(ref: String, callback: (Throwable, Response?) -> Unit) { + fun onError( + ref: String, + callback: (Throwable, Response?) -> Unit + ) { this.error = this.error + Pair(ref, callback) } /** Safely adds an onMessage callback */ - fun onMessage(ref: String, callback: (Message) -> Unit) { + fun onMessage( + ref: String, + callback: (Message) -> Unit + ) { this.message = this.message + Pair(ref, callback) } @@ -87,12 +99,31 @@ const val WS_CLOSE_NORMAL = 1000 /** RFC 6455: indicates that the connection was closed abnormally */ const val WS_CLOSE_ABNORMAL = 1006 +/** + * A closure that will return an optional Payload + */ +typealias PayloadClosure = () -> Payload? + /** * Connects to a Phoenix Server */ + +/** + * A [Socket] which connects to a Phoenix Server. Takes a closure to allow for changing parameters + * to be sent to the server when connecting. + * + * ## Example + * ``` + * val socket = Socket("https://example.com/socket", { mapOf("token" to mAuthToken) }) + * ``` + * @param url Url to connect to such as https://example.com/socket + * @param paramsClosure Closure which allows to change parameters sent during connection. + * @param gson Default GSON Client to parse JSON. You can provide your own if needed. + * @param client Default OkHttpClient to connect with. You can provide your own if needed. + */ class Socket( url: String, - params: Payload? = null, + val paramsClosure: PayloadClosure, private val gson: Gson = Defaults.gson, private val client: OkHttpClient = OkHttpClient.Builder().build() ) { @@ -109,13 +140,8 @@ class Socket( val endpoint: String /** The fully qualified socket URL */ - val endpointUrl: URL - - /** - * The optional params to pass when connecting. Must be set when - * initializing the Socket. These will be appended to the URL. - */ - val params: Payload? = params + var endpointUrl: URL + private set /** Timeout to use when opening a connection */ var timeout: Long = Defaults.TIMEOUT @@ -189,6 +215,27 @@ class Socket( //------------------------------------------------------------------------------ // Initialization //------------------------------------------------------------------------------ + /** + * A [Socket] which connects to a Phoenix Server. Takes a constant parameter to be sent to the + * server when connecting. Defaults to null if excluded. + * + * ## Example + * ``` + * val socket = Socket("https://example.com/socket", mapOf("token" to mAuthToken)) + * ``` + * + * @param url Url to connect to such as https://example.com/socket + * @param params Constant parameters to send when connecting. Defaults to null + * @param gson Default GSON Client to parse JSON. You can provide your own if needed. + * @param client Default OkHttpClient to connect with. You can provide your own if needed. + */ + constructor( + url: String, + params: Payload? = null, + gson: Gson = Defaults.gson, + client: OkHttpClient = OkHttpClient.Builder().build() + ) : this(url, { params }, gson, client) + init { var mutableUrl = url @@ -206,35 +253,18 @@ class Socket( // Store the endpoint before changing the protocol this.endpoint = mutableUrl - // Silently replace web socket URLs with HTTP URLs. - if (url.regionMatches(0, "ws:", 0, 3, ignoreCase = true)) { - mutableUrl = "http:" + url.substring(3) - } else if (url.regionMatches(0, "wss:", 0, 4, ignoreCase = true)) { - mutableUrl = "https:" + url.substring(4) - } - - // If there are query params, append them now - var httpUrl = HttpUrl.parse(mutableUrl) ?: throw IllegalArgumentException("invalid url: $url") - params?.let { - val httpBuilder = httpUrl.newBuilder() - it.forEach { (key, value) -> - httpBuilder.addQueryParameter(key, value.toString()) - } - - httpUrl = httpBuilder.build() - } - - // Store the URL that will be used to establish a connection - this.endpointUrl = httpUrl.url() + // Store the URL that will be used to establish a connection. Could potentially be + // different at the time connect() is called based on a changing params closure. + this.endpointUrl = Defaults.buildEndpointUrl(this.endpoint, this.paramsClosure) // Create reconnect timer this.reconnectTimer = TimeoutTimer( - dispatchQueue = dispatchQueue, - timerCalculation = reconnectAfterMs, - callback = { - this.logItems("Socket attempting to reconnect") - this.teardown { this.connect() } - }) + dispatchQueue = dispatchQueue, + timerCalculation = reconnectAfterMs, + callback = { + this.logItems("Socket attempting to reconnect") + this.teardown { this.connect() } + }) } //------------------------------------------------------------------------------ @@ -262,6 +292,11 @@ class Socket( // Reset the clean close flag when attempting to connect this.closeWasClean = false + // Build the new endpointUrl with the params closure. The payload returned + // from the closure could be different such as a changing authToken. + this.endpointUrl = Defaults.buildEndpointUrl(this.endpoint, this.paramsClosure) + + // Now create the connection transport and attempt to connect this.connection = this.transport(endpointUrl) this.connection?.onOpen = { onConnectionOpened() } this.connection?.onClose = { code -> onConnectionClosed(code) } @@ -281,7 +316,6 @@ class Socket( // Reset any reconnects and teardown the socket connection this.reconnectTimer.reset() this.teardown(code, reason, callback) - } fun onOpen(callback: (() -> Unit)): String { @@ -304,7 +338,10 @@ class Socket( this.stateChangeCallbacks.release() } - fun channel(topic: String, params: Payload = mapOf()): Channel { + fun channel( + topic: String, + params: Payload = mapOf() + ): Channel { val channel = Channel(topic, params, this) this.channels = this.channels + channel @@ -318,7 +355,7 @@ class Socket( // removed instead of calling .remove() on the list, thus returning a new list // that does not contain the channel that was removed. this.channels = channels - .filter { it.joinRef != channel.joinRef } + .filter { it.joinRef != channel.joinRef } } /** @@ -449,7 +486,7 @@ class Socket( val period = heartbeatIntervalMs heartbeatTask = - dispatchQueue.queueAtFixedRate(delay, period, TimeUnit.MILLISECONDS) { sendHeartbeat() } + dispatchQueue.queueAtFixedRate(delay, period, TimeUnit.MILLISECONDS) { sendHeartbeat() } } internal fun sendHeartbeat() { @@ -471,10 +508,11 @@ class Socket( // The last heartbeat was acknowledged by the server. Send another one this.pendingHeartbeatRef = this.makeRef() this.push( - topic = "phoenix", - event = Channel.Event.HEARTBEAT.value, - payload = mapOf(), - ref = pendingHeartbeatRef) + topic = "phoenix", + event = Channel.Event.HEARTBEAT.value, + payload = mapOf(), + ref = pendingHeartbeatRef + ) } private fun abnormalClose(reason: String) { @@ -538,14 +576,17 @@ class Socket( // Dispatch the message to all channels that belong to the topic this.channels - .filter { it.isMember(message) } - .forEach { it.trigger(message) } + .filter { it.isMember(message) } + .forEach { it.trigger(message) } // Inform all onMessage callbacks of the message this.stateChangeCallbacks.message.forEach { it.second.invoke(message) } } - internal fun onConnectionError(t: Throwable, response: Response?) { + internal fun onConnectionError( + t: Throwable, + response: Response? + ) { this.logItems("Transport: error $t") // Send an error to all channels @@ -554,5 +595,4 @@ class Socket( // Inform any state callbacks of the error this.stateChangeCallbacks.error.forEach { it.second.invoke(t, response) } } - } diff --git a/src/test/kotlin/org/phoenixframework/ChannelTest.kt b/src/test/kotlin/org/phoenixframework/ChannelTest.kt index 943fb05..805bcee 100644 --- a/src/test/kotlin/org/phoenixframework/ChannelTest.kt +++ b/src/test/kotlin/org/phoenixframework/ChannelTest.kt @@ -155,7 +155,7 @@ class ChannelTest { @BeforeEach internal fun setUp() { - socket = spy(Socket(url ="https://localhost:4000/socket", client = okHttpClient)) + socket = spy(Socket(url = "https://localhost:4000/socket", client = okHttpClient)) socket.dispatchQueue = fakeClock channel = Channel("topic", kDefaultPayload, socket) } diff --git a/src/test/kotlin/org/phoenixframework/SocketTest.kt b/src/test/kotlin/org/phoenixframework/SocketTest.kt index 062ad1c..0aa43c7 100644 --- a/src/test/kotlin/org/phoenixframework/SocketTest.kt +++ b/src/test/kotlin/org/phoenixframework/SocketTest.kt @@ -48,7 +48,7 @@ class SocketTest { internal fun `sets defaults`() { val socket = Socket("wss://localhost:4000/socket") - assertThat(socket.params).isNull() + assertThat(socket.paramsClosure.invoke()).isNull() assertThat(socket.channels).isEmpty() assertThat(socket.sendBuffer).isEmpty() assertThat(socket.ref).isEqualTo(0) @@ -81,7 +81,7 @@ class SocketTest { socket.logger = { } socket.reconnectAfterMs = { 10 } - assertThat(socket.params).isEqualTo(mapOf("one" to 2)) + assertThat(socket.paramsClosure?.invoke()).isEqualTo(mapOf("one" to 2)) assertThat(socket.endpoint).isEqualTo("wss://localhost:4000/socket/websocket") assertThat(socket.timeout).isEqualTo(40_000) assertThat(socket.heartbeatIntervalMs).isEqualTo(60_000) @@ -94,32 +94,34 @@ class SocketTest { internal fun `constructs with a valid URL`() { // Test different schemes assertThat(Socket("http://localhost:4000/socket/websocket").endpointUrl.toString()) - .isEqualTo("http://localhost:4000/socket/websocket") + .isEqualTo("http://localhost:4000/socket/websocket") assertThat(Socket("https://localhost:4000/socket/websocket").endpointUrl.toString()) - .isEqualTo("https://localhost:4000/socket/websocket") + .isEqualTo("https://localhost:4000/socket/websocket") assertThat(Socket("ws://localhost:4000/socket/websocket").endpointUrl.toString()) - .isEqualTo("http://localhost:4000/socket/websocket") + .isEqualTo("http://localhost:4000/socket/websocket") assertThat(Socket("wss://localhost:4000/socket/websocket").endpointUrl.toString()) - .isEqualTo("https://localhost:4000/socket/websocket") + .isEqualTo("https://localhost:4000/socket/websocket") // test params val singleParam = hashMapOf("token" to "abc123") assertThat(Socket("ws://localhost:4000/socket/websocket", singleParam).endpointUrl.toString()) - .isEqualTo("http://localhost:4000/socket/websocket?token=abc123") + .isEqualTo("http://localhost:4000/socket/websocket?token=abc123") val multipleParams = hashMapOf("token" to "abc123", "user_id" to 1) assertThat( - Socket("http://localhost:4000/socket/websocket", multipleParams).endpointUrl.toString()) - .isEqualTo("http://localhost:4000/socket/websocket?user_id=1&token=abc123") + Socket("http://localhost:4000/socket/websocket", multipleParams).endpointUrl.toString() + ) + .isEqualTo("http://localhost:4000/socket/websocket?user_id=1&token=abc123") // test params with spaces val spacesParams = hashMapOf("token" to "abc 123", "user_id" to 1) assertThat( - Socket("wss://localhost:4000/socket/websocket", spacesParams).endpointUrl.toString()) - .isEqualTo("https://localhost:4000/socket/websocket?user_id=1&token=abc%20123") + Socket("wss://localhost:4000/socket/websocket", spacesParams).endpointUrl.toString() + ) + .isEqualTo("https://localhost:4000/socket/websocket?user_id=1&token=abc%20123") } /* End Constructor */ @@ -185,6 +187,28 @@ class SocketTest { assertThat(socket.connection).isNotNull() } + @Test + internal fun `accounts for changing parameters`() { + val transport = mock<(URL) -> Transport>() + whenever(transport.invoke(any())).thenReturn(connection) + + var token = "a" + val socket = Socket("wss://localhost:4000/socket", { mapOf("token" to token) }) + socket.transport = transport + + socket.connect() + argumentCaptor { + verify(transport).invoke(capture()) + assertThat(firstValue.query).isEqualTo("token=a") + + token = "b" + socket.disconnect() + socket.connect() + verify(transport, times(2)).invoke(capture()) + assertThat(lastValue.query).isEqualTo("token=b") + } + } + @Test internal fun `sets callbacks for connection`() { var open = 0 @@ -216,10 +240,10 @@ class SocketTest { assertThat(lastResponse).isNull() val data = mapOf( - "topic" to "topic", - "event" to "event", - "payload" to mapOf("go" to true), - "status" to "status" + "topic" to "topic", + "event" to "event", + "payload" to mapOf("go" to true), + "status" to "status" ) val json = Defaults.gson.toJson(data) @@ -259,10 +283,10 @@ class SocketTest { assertThat(lastResponse).isNull() val data = mapOf( - "topic" to "topic", - "event" to "event", - "payload" to mapOf("go" to true), - "status" to "status" + "topic" to "topic", + "event" to "event", + "payload" to mapOf("go" to true), + "status" to "status" ) val json = Defaults.gson.toJson(data) @@ -457,7 +481,7 @@ class SocketTest { socket.push("topic", "event", mapOf("one" to "two"), "ref", "join-ref") val expect = - "{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"ref\":\"ref\",\"join_ref\":\"join-ref\"}" + "{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"ref\":\"ref\",\"join_ref\":\"join-ref\"}" verify(connection).send(expect) } @@ -624,7 +648,6 @@ class SocketTest { } } - @Nested @DisplayName("resetHeartbeat") inner class ResetHeartbeat { @@ -658,14 +681,16 @@ class SocketTest { assertThat(socket.heartbeatTask).isNotNull() argumentCaptor<() -> Unit> { - verify(mockDispatchQueue).queueAtFixedRate(eq(5_000L), eq(5_000L), - eq(TimeUnit.MILLISECONDS), capture()) + verify(mockDispatchQueue).queueAtFixedRate( + eq(5_000L), eq(5_000L), + eq(TimeUnit.MILLISECONDS), capture() + ) // fire the task allValues.first().invoke() val expected = - "{\"topic\":\"phoenix\",\"event\":\"heartbeat\",\"payload\":{},\"ref\":\"1\"}" + "{\"topic\":\"phoenix\",\"event\":\"heartbeat\",\"payload\":{},\"ref\":\"1\"}" verify(connection).send(expected) } } @@ -937,7 +962,7 @@ class SocketTest { socket.channels = socket.channels.minus(otherChannel) val rawMessage = - "{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"status\":\"ok\"}" + "{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"status\":\"ok\"}" socket.onConnectionMessage(rawMessage) verify(targetChannel).trigger(message = any()) @@ -950,7 +975,7 @@ class SocketTest { socket.onMessage { message = it } val rawMessage = - "{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"status\":\"ok\"}" + "{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"status\":\"ok\"}" socket.onConnectionMessage(rawMessage) assertThat(message?.topic).isEqualTo("topic") @@ -962,7 +987,7 @@ class SocketTest { socket.pendingHeartbeatRef = "5" val rawMessage = - "{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"ref\":\"5\"}" + "{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"ref\":\"5\"}" socket.onConnectionMessage(rawMessage) assertThat(socket.pendingHeartbeatRef).isNull() } @@ -970,7 +995,6 @@ class SocketTest { /* End OnConnectionMessage */ } - @Nested @DisplayName("ConcurrentModificationException") inner class ConcurrentModificationExceptionTests { @@ -1015,7 +1039,7 @@ class SocketTest { internal fun `onError does not throw`() { var oneCalled = 0 var twoCalled = 0 - socket.onError { _, _-> + socket.onError { _, _ -> socket.onError { _, _ -> twoCalled += 1 } oneCalled += 1 } @@ -1061,6 +1085,4 @@ class SocketTest { /* End ConcurrentModificationExceptionTests */ } - - }