Skip to content

Commit b1f814f

Browse files
committed
Allow providing params via a closure
1 parent edd2bf3 commit b1f814f

File tree

4 files changed

+147
-57
lines changed

4 files changed

+147
-57
lines changed

src/main/kotlin/org/phoenixframework/Defaults.kt

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ package org.phoenixframework
2525
import com.google.gson.FieldNamingPolicy
2626
import com.google.gson.Gson
2727
import com.google.gson.GsonBuilder
28+
import okhttp3.HttpUrl
29+
import java.net.URL
2830

2931
object Defaults {
3032

@@ -51,4 +53,38 @@ object Defaults {
5153
.setLenient()
5254
.setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
5355
.create()
56+
57+
/**
58+
* Takes an endpoint and a params closure given by the User and constructs a URL that
59+
* is ready to be sent to the Socket connection.
60+
*
61+
* Will convert "ws://" and "wss://" to http/s which is what OkHttp expects.
62+
*
63+
* @throws IllegalArgumentException if [endpoint] is not a valid URL endpoint.
64+
*/
65+
internal fun buildEndpointUrl(endpoint: String, paramsClosure: PayloadClosure?): URL {
66+
var mutableUrl = endpoint
67+
// Silently replace web socket URLs with HTTP URLs.
68+
if (endpoint.regionMatches(0, "ws:", 0, 3, ignoreCase = true)) {
69+
mutableUrl = "http:" + endpoint.substring(3)
70+
} else if (endpoint.regionMatches(0, "wss:", 0, 4, ignoreCase = true)) {
71+
mutableUrl = "https:" + endpoint.substring(4)
72+
}
73+
74+
// If there are query params, append them now
75+
var httpUrl = HttpUrl.parse(mutableUrl) ?: throw IllegalArgumentException("invalid url: $endpoint")
76+
paramsClosure?.invoke()?.let {
77+
val httpBuilder = httpUrl.newBuilder()
78+
it.forEach { (key, value) ->
79+
httpBuilder.addQueryParameter(key, value.toString())
80+
}
81+
82+
httpUrl = httpBuilder.build()
83+
}
84+
85+
// Store the URL that will be used to establish a connection
86+
return httpUrl.url()
87+
}
88+
89+
5490
}

src/main/kotlin/org/phoenixframework/Socket.kt

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,32 @@ const val WS_CLOSE_NORMAL = 1000
8787
/** RFC 6455: indicates that the connection was closed abnormally */
8888
const val WS_CLOSE_ABNORMAL = 1006
8989

90+
/**
91+
* A closure that will return an optional Payload
92+
*/
93+
typealias PayloadClosure = () -> Payload?
94+
9095
/**
9196
* Connects to a Phoenix Server
9297
*/
98+
99+
100+
/**
101+
* A [Socket] which connects to a Phoenix Server. Takes a closure to allow for changing parameters
102+
* to be sent to the server when connecting.
103+
*
104+
* ## Example
105+
* ```
106+
* val socket = Socket("https://example.com/socket", { mapOf("token" to mAuthToken) })
107+
* ```
108+
* @param url Url to connect to such as https://example.com/socket
109+
* @param paramsClosure Closure which allows to change parameters sent during connection.
110+
* @param gson Default GSON Client to parse JSON. You can provide your own if needed.
111+
* @param client Default OkHttpClient to connect with. You can provide your own if needed.
112+
*/
93113
class Socket(
94114
url: String,
95-
params: Payload? = null,
115+
paramsClosure: PayloadClosure?,
96116
private val gson: Gson = Defaults.gson,
97117
private val client: OkHttpClient = OkHttpClient.Builder().build()
98118
) {
@@ -109,13 +129,14 @@ class Socket(
109129
val endpoint: String
110130

111131
/** The fully qualified socket URL */
112-
val endpointUrl: URL
132+
var endpointUrl: URL
133+
private set
113134

114135
/**
115-
* The optional params to pass when connecting. Must be set when
116-
* initializing the Socket. These will be appended to the URL.
136+
* A closure that returns the optional params to pass when connecting. Must
137+
* be set when initializing the Socket. These will be appended to the URL.
117138
*/
118-
val params: Payload? = params
139+
val paramsClosure: PayloadClosure? = paramsClosure
119140

120141
/** Timeout to use when opening a connection */
121142
var timeout: Long = Defaults.TIMEOUT
@@ -189,6 +210,27 @@ class Socket(
189210
//------------------------------------------------------------------------------
190211
// Initialization
191212
//------------------------------------------------------------------------------
213+
/**
214+
* A [Socket] which connects to a Phoenix Server. Takes a constant parameter to be sent to the
215+
* server when connecting. Defaults to null if excluded.
216+
*
217+
* ## Example
218+
* ```
219+
* val socket = Socket("https://example.com/socket", mapOf("token" to mAuthToken))
220+
* ```
221+
*
222+
* @param url Url to connect to such as https://example.com/socket
223+
* @param params Constant parameters to send when connecting. Defaults to null
224+
* @param gson Default GSON Client to parse JSON. You can provide your own if needed.
225+
* @param client Default OkHttpClient to connect with. You can provide your own if needed.
226+
*/
227+
constructor(
228+
url: String,
229+
params: Payload? = null,
230+
gson: Gson = Defaults.gson,
231+
client: OkHttpClient = OkHttpClient.Builder().build()
232+
): this(url, params?.let { { it } }, gson, client)
233+
192234
init {
193235
var mutableUrl = url
194236

@@ -203,29 +245,13 @@ class Socket(
203245
mutableUrl += "websocket"
204246
}
205247

248+
206249
// Store the endpoint before changing the protocol
207250
this.endpoint = mutableUrl
208251

209-
// Silently replace web socket URLs with HTTP URLs.
210-
if (url.regionMatches(0, "ws:", 0, 3, ignoreCase = true)) {
211-
mutableUrl = "http:" + url.substring(3)
212-
} else if (url.regionMatches(0, "wss:", 0, 4, ignoreCase = true)) {
213-
mutableUrl = "https:" + url.substring(4)
214-
}
215-
216-
// If there are query params, append them now
217-
var httpUrl = HttpUrl.parse(mutableUrl) ?: throw IllegalArgumentException("invalid url: $url")
218-
params?.let {
219-
val httpBuilder = httpUrl.newBuilder()
220-
it.forEach { (key, value) ->
221-
httpBuilder.addQueryParameter(key, value.toString())
222-
}
223-
224-
httpUrl = httpBuilder.build()
225-
}
226-
227-
// Store the URL that will be used to establish a connection
228-
this.endpointUrl = httpUrl.url()
252+
// Store the URL that will be used to establish a connection. Could potentially be
253+
// different at the time connect() is called based on a changing params closure.
254+
this.endpointUrl = Defaults.buildEndpointUrl(this.endpoint, this.paramsClosure)
229255

230256
// Create reconnect timer
231257
this.reconnectTimer = TimeoutTimer(
@@ -262,6 +288,12 @@ class Socket(
262288
// Reset the clean close flag when attempting to connect
263289
this.closeWasClean = false
264290

291+
// Build the new endpointUrl with the params closure. The payload returned
292+
// from the closure could have changed after the socket attempts to reconnect,
293+
// i.e. and authToken.
294+
this.endpointUrl = Defaults.buildEndpointUrl(this.endpoint, this.paramsClosure)
295+
296+
// Now create the connection transport and attempt to connect
265297
this.connection = this.transport(endpointUrl)
266298
this.connection?.onOpen = { onConnectionOpened() }
267299
this.connection?.onClose = { code -> onConnectionClosed(code) }

src/test/kotlin/org/phoenixframework/ChannelTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class ChannelTest {
155155

156156
@BeforeEach
157157
internal fun setUp() {
158-
socket = spy(Socket(url ="https://localhost:4000/socket", client = okHttpClient))
158+
socket = spy(Socket(url = "https://localhost:4000/socket", client = okHttpClient))
159159
socket.dispatchQueue = fakeClock
160160
channel = Channel("topic", kDefaultPayload, socket)
161161
}

src/test/kotlin/org/phoenixframework/SocketTest.kt

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class SocketTest {
4848
internal fun `sets defaults`() {
4949
val socket = Socket("wss://localhost:4000/socket")
5050

51-
assertThat(socket.params).isNull()
51+
assertThat(socket.paramsClosure).isNull()
5252
assertThat(socket.channels).isEmpty()
5353
assertThat(socket.sendBuffer).isEmpty()
5454
assertThat(socket.ref).isEqualTo(0)
@@ -81,7 +81,7 @@ class SocketTest {
8181
socket.logger = { }
8282
socket.reconnectAfterMs = { 10 }
8383

84-
assertThat(socket.params).isEqualTo(mapOf("one" to 2))
84+
assertThat(socket.paramsClosure?.invoke()).isEqualTo(mapOf("one" to 2))
8585
assertThat(socket.endpoint).isEqualTo("wss://localhost:4000/socket/websocket")
8686
assertThat(socket.timeout).isEqualTo(40_000)
8787
assertThat(socket.heartbeatIntervalMs).isEqualTo(60_000)
@@ -94,32 +94,34 @@ class SocketTest {
9494
internal fun `constructs with a valid URL`() {
9595
// Test different schemes
9696
assertThat(Socket("http://localhost:4000/socket/websocket").endpointUrl.toString())
97-
.isEqualTo("http://localhost:4000/socket/websocket")
97+
.isEqualTo("http://localhost:4000/socket/websocket")
9898

9999
assertThat(Socket("https://localhost:4000/socket/websocket").endpointUrl.toString())
100-
.isEqualTo("https://localhost:4000/socket/websocket")
100+
.isEqualTo("https://localhost:4000/socket/websocket")
101101

102102
assertThat(Socket("ws://localhost:4000/socket/websocket").endpointUrl.toString())
103-
.isEqualTo("http://localhost:4000/socket/websocket")
103+
.isEqualTo("http://localhost:4000/socket/websocket")
104104

105105
assertThat(Socket("wss://localhost:4000/socket/websocket").endpointUrl.toString())
106-
.isEqualTo("https://localhost:4000/socket/websocket")
106+
.isEqualTo("https://localhost:4000/socket/websocket")
107107

108108
// test params
109109
val singleParam = hashMapOf("token" to "abc123")
110110
assertThat(Socket("ws://localhost:4000/socket/websocket", singleParam).endpointUrl.toString())
111-
.isEqualTo("http://localhost:4000/socket/websocket?token=abc123")
111+
.isEqualTo("http://localhost:4000/socket/websocket?token=abc123")
112112

113113
val multipleParams = hashMapOf("token" to "abc123", "user_id" to 1)
114114
assertThat(
115-
Socket("http://localhost:4000/socket/websocket", multipleParams).endpointUrl.toString())
116-
.isEqualTo("http://localhost:4000/socket/websocket?user_id=1&token=abc123")
115+
Socket("http://localhost:4000/socket/websocket", multipleParams).endpointUrl.toString()
116+
)
117+
.isEqualTo("http://localhost:4000/socket/websocket?user_id=1&token=abc123")
117118

118119
// test params with spaces
119120
val spacesParams = hashMapOf("token" to "abc 123", "user_id" to 1)
120121
assertThat(
121-
Socket("wss://localhost:4000/socket/websocket", spacesParams).endpointUrl.toString())
122-
.isEqualTo("https://localhost:4000/socket/websocket?user_id=1&token=abc%20123")
122+
Socket("wss://localhost:4000/socket/websocket", spacesParams).endpointUrl.toString()
123+
)
124+
.isEqualTo("https://localhost:4000/socket/websocket?user_id=1&token=abc%20123")
123125
}
124126

125127
/* End Constructor */
@@ -185,6 +187,28 @@ class SocketTest {
185187
assertThat(socket.connection).isNotNull()
186188
}
187189

190+
@Test
191+
internal fun `accounts for changing parameters`() {
192+
val transport = mock<(URL) -> Transport>()
193+
whenever(transport.invoke(any())).thenReturn(connection)
194+
195+
var token = "a"
196+
val socket = Socket("wss://localhost:4000/socket", { mapOf("token" to token) })
197+
socket.transport = transport
198+
199+
socket.connect()
200+
argumentCaptor<URL> {
201+
verify(transport).invoke(capture())
202+
assertThat(firstValue.query).isEqualTo("token=a")
203+
204+
token = "b"
205+
socket.disconnect()
206+
socket.connect()
207+
verify(transport, times(2)).invoke(capture())
208+
assertThat(lastValue.query).isEqualTo("token=b")
209+
}
210+
}
211+
188212
@Test
189213
internal fun `sets callbacks for connection`() {
190214
var open = 0
@@ -216,10 +240,10 @@ class SocketTest {
216240
assertThat(lastResponse).isNull()
217241

218242
val data = mapOf(
219-
"topic" to "topic",
220-
"event" to "event",
221-
"payload" to mapOf("go" to true),
222-
"status" to "status"
243+
"topic" to "topic",
244+
"event" to "event",
245+
"payload" to mapOf("go" to true),
246+
"status" to "status"
223247
)
224248

225249
val json = Defaults.gson.toJson(data)
@@ -259,10 +283,10 @@ class SocketTest {
259283
assertThat(lastResponse).isNull()
260284

261285
val data = mapOf(
262-
"topic" to "topic",
263-
"event" to "event",
264-
"payload" to mapOf("go" to true),
265-
"status" to "status"
286+
"topic" to "topic",
287+
"event" to "event",
288+
"payload" to mapOf("go" to true),
289+
"status" to "status"
266290
)
267291

268292
val json = Defaults.gson.toJson(data)
@@ -457,7 +481,7 @@ class SocketTest {
457481
socket.push("topic", "event", mapOf("one" to "two"), "ref", "join-ref")
458482

459483
val expect =
460-
"{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"ref\":\"ref\",\"join_ref\":\"join-ref\"}"
484+
"{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"ref\":\"ref\",\"join_ref\":\"join-ref\"}"
461485
verify(connection).send(expect)
462486
}
463487

@@ -624,7 +648,6 @@ class SocketTest {
624648
}
625649
}
626650

627-
628651
@Nested
629652
@DisplayName("resetHeartbeat")
630653
inner class ResetHeartbeat {
@@ -658,14 +681,16 @@ class SocketTest {
658681

659682
assertThat(socket.heartbeatTask).isNotNull()
660683
argumentCaptor<() -> Unit> {
661-
verify(mockDispatchQueue).queueAtFixedRate(eq(5_000L), eq(5_000L),
662-
eq(TimeUnit.MILLISECONDS), capture())
684+
verify(mockDispatchQueue).queueAtFixedRate(
685+
eq(5_000L), eq(5_000L),
686+
eq(TimeUnit.MILLISECONDS), capture()
687+
)
663688

664689
// fire the task
665690
allValues.first().invoke()
666691

667692
val expected =
668-
"{\"topic\":\"phoenix\",\"event\":\"heartbeat\",\"payload\":{},\"ref\":\"1\"}"
693+
"{\"topic\":\"phoenix\",\"event\":\"heartbeat\",\"payload\":{},\"ref\":\"1\"}"
669694
verify(connection).send(expected)
670695
}
671696
}
@@ -937,7 +962,7 @@ class SocketTest {
937962
socket.channels = socket.channels.minus(otherChannel)
938963

939964
val rawMessage =
940-
"{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"status\":\"ok\"}"
965+
"{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"status\":\"ok\"}"
941966
socket.onConnectionMessage(rawMessage)
942967

943968
verify(targetChannel).trigger(message = any())
@@ -950,7 +975,7 @@ class SocketTest {
950975
socket.onMessage { message = it }
951976

952977
val rawMessage =
953-
"{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"status\":\"ok\"}"
978+
"{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"status\":\"ok\"}"
954979
socket.onConnectionMessage(rawMessage)
955980

956981
assertThat(message?.topic).isEqualTo("topic")
@@ -962,15 +987,14 @@ class SocketTest {
962987
socket.pendingHeartbeatRef = "5"
963988

964989
val rawMessage =
965-
"{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"ref\":\"5\"}"
990+
"{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"ref\":\"5\"}"
966991
socket.onConnectionMessage(rawMessage)
967992
assertThat(socket.pendingHeartbeatRef).isNull()
968993
}
969994

970995
/* End OnConnectionMessage */
971996
}
972997

973-
974998
@Nested
975999
@DisplayName("ConcurrentModificationException")
9761000
inner class ConcurrentModificationExceptionTests {
@@ -1015,7 +1039,7 @@ class SocketTest {
10151039
internal fun `onError does not throw`() {
10161040
var oneCalled = 0
10171041
var twoCalled = 0
1018-
socket.onError { _, _->
1042+
socket.onError { _, _ ->
10191043
socket.onError { _, _ -> twoCalled += 1 }
10201044
oneCalled += 1
10211045
}
@@ -1061,6 +1085,4 @@ class SocketTest {
10611085

10621086
/* End ConcurrentModificationExceptionTests */
10631087
}
1064-
1065-
10661088
}

0 commit comments

Comments
 (0)