Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions src/main/kotlin/org/phoenixframework/Defaults.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ package org.phoenixframework
import com.google.gson.FieldNamingPolicy
import com.google.gson.Gson
import com.google.gson.GsonBuilder
import okhttp3.HttpUrl
import java.net.URL

object Defaults {

Expand All @@ -36,19 +38,56 @@ object Defaults {

/** Default reconnect algorithm for the socket */
val reconnectSteppedBackOff: (Int) -> Long = { tries ->
if (tries > 9) 5_000 else listOf(10L, 50L, 100L, 150L, 200L, 250L, 500L, 1_000L, 2_000L)[tries - 1]
if (tries > 9) 5_000 else listOf(
10L, 50L, 100L, 150L, 200L, 250L, 500L, 1_000L, 2_000L
)[tries - 1]
}

/** Default rejoin algorithm for individual channels */
val rejoinSteppedBackOff: (Int) -> Long = { tries ->
if (tries > 3) 10_000 else listOf(1_000L, 2_000L, 5_000L)[tries - 1]
}


/** The default Gson configuration to use when parsing messages */
val gson: Gson
get() = GsonBuilder()
.setLenient()
.setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
.create()
.setLenient()
.setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
.create()

/**
* Takes an endpoint and a params closure given by the User and constructs a URL that
* is ready to be sent to the Socket connection.
*
* Will convert "ws://" and "wss://" to http/s which is what OkHttp expects.
*
* @throws IllegalArgumentException if [endpoint] is not a valid URL endpoint.
*/
internal fun buildEndpointUrl(
endpoint: String,
paramsClosure: PayloadClosure
): URL {
var mutableUrl = endpoint
// Silently replace web socket URLs with HTTP URLs.
if (endpoint.regionMatches(0, "ws:", 0, 3, ignoreCase = true)) {
mutableUrl = "http:" + endpoint.substring(3)
} else if (endpoint.regionMatches(0, "wss:", 0, 4, ignoreCase = true)) {
mutableUrl = "https:" + endpoint.substring(4)
}

// If there are query params, append them now
var httpUrl =
HttpUrl.parse(mutableUrl) ?: throw IllegalArgumentException("invalid url: $endpoint")
paramsClosure.invoke()?.let {
val httpBuilder = httpUrl.newBuilder()
it.forEach { (key, value) ->
httpBuilder.addQueryParameter(key, value.toString())
}

httpUrl = httpBuilder.build()
}

// Store the URL that will be used to establish a connection
return httpUrl.url()
}
}
140 changes: 90 additions & 50 deletions src/main/kotlin/org/phoenixframework/Socket.kt
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,34 @@ internal class StateChangeCallbacks {
private set

/** Safely adds an onOpen callback */
fun onOpen(ref: String, callback: () -> Unit) {
fun onOpen(
ref: String,
callback: () -> Unit
) {
this.open = this.open + Pair(ref, callback)
}

/** Safely adds an onClose callback */
fun onClose(ref: String, callback: () -> Unit) {
fun onClose(
ref: String,
callback: () -> Unit
) {
this.close = this.close + Pair(ref, callback)
}

/** Safely adds an onError callback */
fun onError(ref: String, callback: (Throwable, Response?) -> Unit) {
fun onError(
ref: String,
callback: (Throwable, Response?) -> Unit
) {
this.error = this.error + Pair(ref, callback)
}

/** Safely adds an onMessage callback */
fun onMessage(ref: String, callback: (Message) -> Unit) {
fun onMessage(
ref: String,
callback: (Message) -> Unit
) {
this.message = this.message + Pair(ref, callback)
}

Expand All @@ -87,12 +99,31 @@ const val WS_CLOSE_NORMAL = 1000
/** RFC 6455: indicates that the connection was closed abnormally */
const val WS_CLOSE_ABNORMAL = 1006

/**
* A closure that will return an optional Payload
*/
typealias PayloadClosure = () -> Payload?

/**
* Connects to a Phoenix Server
*/

/**
* A [Socket] which connects to a Phoenix Server. Takes a closure to allow for changing parameters
* to be sent to the server when connecting.
*
* ## Example
* ```
* val socket = Socket("https://example.com/socket", { mapOf("token" to mAuthToken) })
* ```
* @param url Url to connect to such as https://example.com/socket
* @param paramsClosure Closure which allows to change parameters sent during connection.
* @param gson Default GSON Client to parse JSON. You can provide your own if needed.
* @param client Default OkHttpClient to connect with. You can provide your own if needed.
*/
class Socket(
url: String,
params: Payload? = null,
val paramsClosure: PayloadClosure,
private val gson: Gson = Defaults.gson,
private val client: OkHttpClient = OkHttpClient.Builder().build()
) {
Expand All @@ -109,13 +140,8 @@ class Socket(
val endpoint: String

/** The fully qualified socket URL */
val endpointUrl: URL

/**
* The optional params to pass when connecting. Must be set when
* initializing the Socket. These will be appended to the URL.
*/
val params: Payload? = params
var endpointUrl: URL
private set

/** Timeout to use when opening a connection */
var timeout: Long = Defaults.TIMEOUT
Expand Down Expand Up @@ -189,6 +215,27 @@ class Socket(
//------------------------------------------------------------------------------
// Initialization
//------------------------------------------------------------------------------
/**
* A [Socket] which connects to a Phoenix Server. Takes a constant parameter to be sent to the
* server when connecting. Defaults to null if excluded.
*
* ## Example
* ```
* val socket = Socket("https://example.com/socket", mapOf("token" to mAuthToken))
* ```
*
* @param url Url to connect to such as https://example.com/socket
* @param params Constant parameters to send when connecting. Defaults to null
* @param gson Default GSON Client to parse JSON. You can provide your own if needed.
* @param client Default OkHttpClient to connect with. You can provide your own if needed.
*/
constructor(
url: String,
params: Payload? = null,
gson: Gson = Defaults.gson,
client: OkHttpClient = OkHttpClient.Builder().build()
) : this(url, { params }, gson, client)

init {
var mutableUrl = url

Expand All @@ -206,35 +253,18 @@ class Socket(
// Store the endpoint before changing the protocol
this.endpoint = mutableUrl

// Silently replace web socket URLs with HTTP URLs.
if (url.regionMatches(0, "ws:", 0, 3, ignoreCase = true)) {
mutableUrl = "http:" + url.substring(3)
} else if (url.regionMatches(0, "wss:", 0, 4, ignoreCase = true)) {
mutableUrl = "https:" + url.substring(4)
}

// If there are query params, append them now
var httpUrl = HttpUrl.parse(mutableUrl) ?: throw IllegalArgumentException("invalid url: $url")
params?.let {
val httpBuilder = httpUrl.newBuilder()
it.forEach { (key, value) ->
httpBuilder.addQueryParameter(key, value.toString())
}

httpUrl = httpBuilder.build()
}

// Store the URL that will be used to establish a connection
this.endpointUrl = httpUrl.url()
// Store the URL that will be used to establish a connection. Could potentially be
// different at the time connect() is called based on a changing params closure.
this.endpointUrl = Defaults.buildEndpointUrl(this.endpoint, this.paramsClosure)

// Create reconnect timer
this.reconnectTimer = TimeoutTimer(
dispatchQueue = dispatchQueue,
timerCalculation = reconnectAfterMs,
callback = {
this.logItems("Socket attempting to reconnect")
this.teardown { this.connect() }
})
dispatchQueue = dispatchQueue,
timerCalculation = reconnectAfterMs,
callback = {
this.logItems("Socket attempting to reconnect")
this.teardown { this.connect() }
})
}

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -262,6 +292,11 @@ class Socket(
// Reset the clean close flag when attempting to connect
this.closeWasClean = false

// Build the new endpointUrl with the params closure. The payload returned
// from the closure could be different such as a changing authToken.
this.endpointUrl = Defaults.buildEndpointUrl(this.endpoint, this.paramsClosure)

// Now create the connection transport and attempt to connect
this.connection = this.transport(endpointUrl)
this.connection?.onOpen = { onConnectionOpened() }
this.connection?.onClose = { code -> onConnectionClosed(code) }
Expand All @@ -281,7 +316,6 @@ class Socket(
// Reset any reconnects and teardown the socket connection
this.reconnectTimer.reset()
this.teardown(code, reason, callback)

}

fun onOpen(callback: (() -> Unit)): String {
Expand All @@ -304,7 +338,10 @@ class Socket(
this.stateChangeCallbacks.release()
}

fun channel(topic: String, params: Payload = mapOf()): Channel {
fun channel(
topic: String,
params: Payload = mapOf()
): Channel {
val channel = Channel(topic, params, this)
this.channels = this.channels + channel

Expand All @@ -318,7 +355,7 @@ class Socket(
// removed instead of calling .remove() on the list, thus returning a new list
// that does not contain the channel that was removed.
this.channels = channels
.filter { it.joinRef != channel.joinRef }
.filter { it.joinRef != channel.joinRef }
}

/**
Expand Down Expand Up @@ -449,7 +486,7 @@ class Socket(
val period = heartbeatIntervalMs

heartbeatTask =
dispatchQueue.queueAtFixedRate(delay, period, TimeUnit.MILLISECONDS) { sendHeartbeat() }
dispatchQueue.queueAtFixedRate(delay, period, TimeUnit.MILLISECONDS) { sendHeartbeat() }
}

internal fun sendHeartbeat() {
Expand All @@ -471,10 +508,11 @@ class Socket(
// The last heartbeat was acknowledged by the server. Send another one
this.pendingHeartbeatRef = this.makeRef()
this.push(
topic = "phoenix",
event = Channel.Event.HEARTBEAT.value,
payload = mapOf(),
ref = pendingHeartbeatRef)
topic = "phoenix",
event = Channel.Event.HEARTBEAT.value,
payload = mapOf(),
ref = pendingHeartbeatRef
)
}

private fun abnormalClose(reason: String) {
Expand Down Expand Up @@ -538,14 +576,17 @@ class Socket(

// Dispatch the message to all channels that belong to the topic
this.channels
.filter { it.isMember(message) }
.forEach { it.trigger(message) }
.filter { it.isMember(message) }
.forEach { it.trigger(message) }

// Inform all onMessage callbacks of the message
this.stateChangeCallbacks.message.forEach { it.second.invoke(message) }
}

internal fun onConnectionError(t: Throwable, response: Response?) {
internal fun onConnectionError(
t: Throwable,
response: Response?
) {
this.logItems("Transport: error $t")

// Send an error to all channels
Expand All @@ -554,5 +595,4 @@ class Socket(
// Inform any state callbacks of the error
this.stateChangeCallbacks.error.forEach { it.second.invoke(t, response) }
}

}
2 changes: 1 addition & 1 deletion src/test/kotlin/org/phoenixframework/ChannelTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class ChannelTest {

@BeforeEach
internal fun setUp() {
socket = spy(Socket(url ="https://localhost:4000/socket", client = okHttpClient))
socket = spy(Socket(url = "https://localhost:4000/socket", client = okHttpClient))
socket.dispatchQueue = fakeClock
channel = Channel("topic", kDefaultPayload, socket)
}
Expand Down
Loading