diff --git a/build.gradle b/build.gradle index 5088b85..7989f35 100644 --- a/build.gradle +++ b/build.gradle @@ -19,12 +19,12 @@ repositories { dependencies { compile "org.jetbrains.kotlin:kotlin-stdlib-jdk8" compile "com.google.code.gson:gson:2.8.5" - compile "com.squareup.okhttp3:okhttp:3.10.0" + compile "com.squareup.okhttp3:okhttp:3.14.1" testCompile group: 'junit', name: 'junit', version: '4.12' - testCompile group: 'com.google.truth', name: 'truth', version: '0.42' - testCompile group: 'org.mockito', name: 'mockito-core', version: '2.19.1' + testCompile group: 'com.google.truth', name: 'truth', version: '0.44' + testCompile group: 'org.mockito', name: 'mockito-core', version: '2.27.0' testCompile group: 'com.nhaarman.mockitokotlin2', name: 'mockito-kotlin', version: '2.1.0' } diff --git a/src/main/kotlin/org/phoenixframework/Channel.kt b/src/main/kotlin/org/phoenixframework/Channel.kt new file mode 100644 index 0000000..a90ec1d --- /dev/null +++ b/src/main/kotlin/org/phoenixframework/Channel.kt @@ -0,0 +1,359 @@ +package org.phoenixframework + +import java.util.concurrent.ConcurrentLinkedQueue + +/** + * Represents a binding to a Channel event + */ +data class Binding( + val event: String, + val ref: Int, + val callback: (Message) -> Unit +) + +/** + * Represents a Channel bound to a given topic + */ +class Channel( + val topic: String, + var params: Payload, + internal val socket: Socket +) { + + //------------------------------------------------------------------------------ + // Channel Nested Enums + //------------------------------------------------------------------------------ + /** States of a Channel */ + enum class State() { + CLOSED, + ERRORED, + JOINED, + JOINING, + LEAVING + } + + /** Channel specific events */ + enum class Event(val value: String) { + HEARTBEAT("heartbeat"), + JOIN("phx_join"), + LEAVE("phx_leave"), + REPLY("phx_reply"), + ERROR("phx_error"), + CLOSE("phx_close"); + + companion object { + /** True if the event is one of Phoenix's channel lifecycle events */ + fun isLifecycleEvent(event: String): Boolean { + return when (event) { + JOIN.value, + LEAVE.value, + REPLY.value, + ERROR.value, + CLOSE.value -> true + else -> false + } + } + } + } + + //------------------------------------------------------------------------------ + // Channel Attributes + //------------------------------------------------------------------------------ + /** Current state of the Channel */ + internal var state: Channel.State + + /** Collection of event bindings. */ + internal val bindings: ConcurrentLinkedQueue + + /** Tracks event binding ref counters */ + internal var bindingRef: Int + + /** Timeout when attempting to join a Channel */ + internal var timeout: Long + + /** Set to true once the channel has attempted to join */ + var joinedOnce: Boolean + + /** Push to send then attempting to join */ + var joinPush: Push + + /** Buffer of Pushes that will be sent once the Channel's socket connects */ + var pushBuffer: MutableList + + /** Timer to attempt rejoins */ + var rejoinTimer: TimeoutTimer + + /** + * Optional onMessage hook that can be provided. Receives all event messages for specialized + * handling before dispatching to the Channel event callbacks. + */ + var onMessage: (Message) -> Message = { it } + + init { + this.state = State.CLOSED + this.bindings = ConcurrentLinkedQueue() + this.bindingRef = 0 + this.timeout = socket.timeout + this.joinedOnce = false + this.pushBuffer = mutableListOf() + this.rejoinTimer = TimeoutTimer( + scheduledExecutorService = socket.timerPool, + callback = { rejoinUntilConnected() }, + timerCalculation = Defaults.steppedBackOff) + + // Setup Push to be sent when joining + this.joinPush = Push( + channel = this, + event = Channel.Event.JOIN.value, + payload = params, + timeout = timeout) + + // Perform once the Channel has joined + this.joinPush.receive("ok") { + // Mark the Channel as joined + this.state = State.JOINED + + // Reset the timer, preventing it from attempting to join again + this.rejoinTimer.reset() + + // Send any buffered messages and clear the buffer + this.pushBuffer.forEach { it.send() } + this.pushBuffer.clear() + } + + // Perform if Channel timed out while attempting to join + this.joinPush.receive("timeout") { message -> + + // Only handle a timeout if the Channel is in the 'joining' state + if (!this.isJoining) return@receive + + this.socket.logItems("Channel: timeouts $topic, $joinRef after $timeout ms") + + // Send a Push to the server to leave the Channel + val leavePush = Push( + channel = this, + event = Channel.Event.LEAVE.value) + leavePush.send() + + // Mark the Channel as in an error and attempt to rejoin + this.state = State.ERRORED + this.joinPush.reset() + this.rejoinTimer.scheduleTimeout() + } + + // Clean up when the channel closes + this.onClose { + // Reset any timer that may be on-going + this.rejoinTimer.reset() + + // Log that the channel was left + this.socket.logItems("Channel: close $topic") + + // Mark the channel as closed and remove it from the socket + this.state = State.CLOSED + this.socket.remove(this) + } + + // Handles an error, attempts to rejoin + this.onError { + // Do not emit error if the channel is in the process of leaving + // or if it has already closed + if (this.isLeaving || this.isClosed) return@onError + + // Log that the channel received an error + this.socket.logItems("Channel: error $topic") + + // Mark the channel as errored and attempt to rejoin + this.state = State.ERRORED + this.rejoinTimer.scheduleTimeout() + } + + // Perform when the join reply is received + this.on(Event.REPLY) { message -> + this.trigger(replyEventName(message.ref), message.payload, message.ref, message.joinRef) + } + } + + //------------------------------------------------------------------------------ + // Public Properties + //------------------------------------------------------------------------------ + /** The ref sent during the join message. */ + val joinRef: String? get() = joinPush.ref + + /** @return True if the Channel can push messages */ + val canPush: Boolean + get() = this.socket.isConnected && this.isJoined + + /** @return: True if the Channel has been closed */ + val isClosed: Boolean + get() = state == State.CLOSED + + /** @return: True if the Channel experienced an error */ + val isErrored: Boolean + get() = state == State.ERRORED + + /** @return: True if the channel has joined */ + val isJoined: Boolean + get() = state == State.JOINED + + /** @return: True if the channel has requested to join */ + val isJoining: Boolean + get() = state == State.JOINING + + /** @return: True if the channel has requested to leave */ + val isLeaving: Boolean + get() = state == State.LEAVING + + //------------------------------------------------------------------------------ + // Public + //------------------------------------------------------------------------------ + fun join(timeout: Long = Defaults.TIMEOUT): Push { + // Ensure that `.join()` is called only once per Channel instance + if (joinedOnce) { + throw IllegalStateException( + "Tried to join channel multiple times. `join()` can only be called once per channel") + } + + // Join the channel + this.joinedOnce = true + this.rejoin(timeout) + return joinPush + } + + fun onClose(callback: (Message) -> Unit): Int { + return this.on(Event.CLOSE, callback) + } + + fun onError(callback: (Message) -> Unit): Int { + return this.on(Event.ERROR, callback) + } + + fun onMessage(callback: (Message) -> Message) { + this.onMessage = callback + } + + fun on(event: Channel.Event, callback: (Message) -> Unit): Int { + return this.on(event.value, callback) + } + + fun on(event: String, callback: (Message) -> Unit): Int { + val ref = bindingRef + this.bindingRef = ref + 1 + + this.bindings.add(Binding(event, ref, callback)) + return ref + } + + fun off(event: String, ref: Int? = null) { + this.bindings.removeAll { bind -> + bind.event == event && (ref == null || ref == bind.ref) + } + } + + fun push(event: String, payload: Payload, timeout: Long = Defaults.TIMEOUT): Push { + if (!joinedOnce) { + // If the Channel has not been joined, throw an exception + throw RuntimeException( + "Tried to push $event to $topic before joining. Use channel.join() before pushing events") + } + + val pushEvent = Push(this, event, payload, timeout) + + if (canPush) { + pushEvent.send() + } else { + pushEvent.startTimeout() + pushBuffer.add(pushEvent) + } + + return pushEvent + } + + fun leave(timeout: Long = Defaults.TIMEOUT): Push { + this.state = State.LEAVING + + // Perform the same behavior if the channel leaves successfully or not + val onClose: ((Message) -> Unit) = { + this.socket.logItems("Channel: leave $topic") + this.trigger(it) + } + + // Push event to send to the server + val leavePush = Push( + channel = this, + event = Event.LEAVE.value, + timeout = timeout) + + leavePush + .receive("ok", onClose) + .receive("timeout", onClose) + leavePush.send() + + // If the Channel cannot send push events, trigger a success locally + if (!canPush) leavePush.trigger("ok", hashMapOf()) + + return leavePush + } + + //------------------------------------------------------------------------------ + // Internal + //------------------------------------------------------------------------------ + /** Checks if a Message's event belongs to this Channel instance */ + internal fun isMember(message: Message): Boolean { + if (message.topic != this.topic) return false + + val isLifecycleEvent = Event.isLifecycleEvent(message.event) + + // If the message is a lifecycle event and it is not a join for this channel, drop the outdated message + if (message.joinRef != null && isLifecycleEvent && message.joinRef != this.joinRef) { + this.socket.logItems("Channel: Dropping outdated message. ${message.topic}") + return false + } + + return true + } + + internal fun trigger( + event: String, + payload: Payload = hashMapOf(), + ref: String = "", + joinRef: String? = null + ) { + this.trigger(Message(ref, topic, event, payload, joinRef)) + } + + internal fun trigger(message: Message) { + // Inform the onMessage hook of the message + val handledMessage = this.onMessage(message) + + // Inform all matching event bindings of the message + this.bindings + .filter { it.event == message.event } + .forEach { it.callback(handledMessage) } + } + + /** Create an event with a given ref */ + internal fun replyEventName(ref: String): String { + return "chan_reply_$ref" + } + + //------------------------------------------------------------------------------ + // Private + //------------------------------------------------------------------------------ + /** Will continually attempt to rejoin the Channel on a timer. */ + private fun rejoinUntilConnected() { + this.rejoinTimer.scheduleTimeout() + if (this.socket.isConnected) this.rejoin() + } + + /** Sends the Channel's joinPush to the Server */ + private fun sendJoin(timeout: Long) { + this.state = State.JOINING + this.joinPush.resend(timeout) + } + + /** Rejoins the Channel e.g. after a disconnect */ + private fun rejoin(timeout: Long = Defaults.TIMEOUT) { + this.sendJoin(timeout) + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/phoenixframework/Defaults.kt b/src/main/kotlin/org/phoenixframework/Defaults.kt new file mode 100644 index 0000000..6157833 --- /dev/null +++ b/src/main/kotlin/org/phoenixframework/Defaults.kt @@ -0,0 +1,26 @@ +package org.phoenixframework + +import com.google.gson.FieldNamingPolicy +import com.google.gson.Gson +import com.google.gson.GsonBuilder + +object Defaults { + + /** Default timeout of 10s */ + const val TIMEOUT: Long = 10_000 + + /** Default heartbeat interval of 30s */ + const val HEARTBEAT: Long = 30_000 + + /** Default reconnect algorithm. Reconnects after 1s, 2s, 5s and then 10s thereafter */ + val steppedBackOff: (Int) -> Long = { tries -> + if (tries > 3) 10000 else listOf(1000L, 2000L, 5000L)[tries - 1] + } + + /** The default Gson configuration to use when parsing messages */ + val gson: Gson + get() = GsonBuilder() + .setLenient() + .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) + .create() +} \ No newline at end of file diff --git a/src/main/kotlin/org/phoenixframework/Message.kt b/src/main/kotlin/org/phoenixframework/Message.kt new file mode 100644 index 0000000..053808f --- /dev/null +++ b/src/main/kotlin/org/phoenixframework/Message.kt @@ -0,0 +1,33 @@ +package org.phoenixframework + +import com.google.gson.annotations.SerializedName + +class Message( + /** The unique string ref. Empty if not present */ + @SerializedName("ref") + val ref: String = "", + + /** The message topic */ + @SerializedName("topic") + val topic: String = "", + + /** The message event name, for example "phx_join" or any other custom name */ + @SerializedName("event") + val event: String = "", + + /** The payload of the message */ + @SerializedName("payload") + val payload: Payload = HashMap(), + + /** The ref sent during a join event. Empty if not present. */ + @SerializedName("join_ref") + val joinRef: String? = null) { + + + /** + * Convenience var to access the message's payload's status. Equivalent + * to checking message.payload["status"] yourself + */ + val status: String? + get() = payload["status"] as? String +} diff --git a/src/main/kotlin/org/phoenixframework/PhxSocket.kt b/src/main/kotlin/org/phoenixframework/PhxSocket.kt index 1e6856a..d781542 100644 --- a/src/main/kotlin/org/phoenixframework/PhxSocket.kt +++ b/src/main/kotlin/org/phoenixframework/PhxSocket.kt @@ -11,11 +11,8 @@ import okhttp3.WebSocket import okhttp3.WebSocketListener import java.net.URL import java.util.Timer -import kotlin.collections.ArrayList -import kotlin.collections.HashMap import kotlin.concurrent.schedule -typealias Payload = Map /** Default timeout set to 10s */ const val DEFAULT_TIMEOUT: Long = 10000 @@ -23,8 +20,6 @@ const val DEFAULT_TIMEOUT: Long = 10000 /** Default heartbeat interval set to 30s */ const val DEFAULT_HEARTBEAT: Long = 30000 -/** The code used when the socket was closed without error */ -const val WS_CLOSE_NORMAL = 1000 /** The code used when the socket was closed after the heartbeat timer timed out */ const val WS_CLOSE_HEARTBEAT_ERROR = 5000 diff --git a/src/main/kotlin/org/phoenixframework/Presence.kt b/src/main/kotlin/org/phoenixframework/Presence.kt new file mode 100644 index 0000000..0b83b56 --- /dev/null +++ b/src/main/kotlin/org/phoenixframework/Presence.kt @@ -0,0 +1,285 @@ +package org.phoenixframework + +//------------------------------------------------------------------------------ +// Type Aliases +//------------------------------------------------------------------------------ +/** Meta details of a Presence. Just a dictionary of properties */ +typealias PresenceMeta = MutableMap + +/** A mapping of a String to an array of Metas. e.g. {"metas": [{id: 1}]} */ +typealias PresenceMap = MutableMap> + +/** A mapping of a Presence state to a mapping of Metas */ +typealias PresenceState = MutableMap + +/** + * Diff has keys "joins" and "leaves", pointing to a Presence.State each containing the users + * that joined and left. + */ +typealias PresenceDiff = MutableMap + +/** Closure signature of OnJoin callbacks */ +typealias OnJoin = (key: String, current: PresenceMap?, new: PresenceMap) -> Unit + +/** Closure signature for OnLeave callbacks */ +typealias OnLeave = (key: String, current: PresenceMap, left: PresenceMap) -> Unit + +/** Closure signature for OnSync callbacks */ +typealias OnSync = () -> Unit + +class Presence(channel: Channel, opts: Options = Options.defaults) { + + //------------------------------------------------------------------------------ + // Enums and Data classes + //------------------------------------------------------------------------------ + /** + * Custom options that can be provided when creating Presence + */ + data class Options(val events: Map) { + companion object { + + /** + * Default set of Options used when creating Presence. Uses the + * phoenix events "presence_state" and "presence_diff" + */ + val defaults: Options + get() = Options( + mapOf( + Events.STATE to "presence_state", + Events.DIFF to "presence_diff")) + } + } + + /** Collection of callbacks with default values */ + data class Caller( + var onJoin: OnJoin = { _, _, _ -> }, + var onLeave: OnLeave = { _, _, _ -> }, + var onSync: OnSync = {} + ) + + /** Presence Events of "state" and "diff" */ + enum class Events { + STATE, + DIFF + } + + //------------------------------------------------------------------------------ + // Properties + //------------------------------------------------------------------------------ + /** The channel the Presence belongs to */ + private val channel: Channel + + /** Caller to callback hooks */ + private val caller: Caller + + /** The state of the Presence */ + var state: PresenceState + private set + + /** Pending `join` and `leave` diffs that need to be synced */ + var pendingDiffs: MutableList + private set + + /** The channel's joinRef, set when state events occur */ + var joinRef: String? + private set + + /** True if the Presence has not yet initially synced */ + val isPendingSyncState: Boolean + get() = this.joinRef == null || (this.joinRef !== this.channel.joinRef) + + //------------------------------------------------------------------------------ + // Initialization + //------------------------------------------------------------------------------ + init { + this.state = mutableMapOf() + this.pendingDiffs = mutableListOf() + this.channel = channel + this.joinRef = null + this.caller = Presence.Caller() + + val stateEvent = opts.events[Events.STATE] + val diffEvent = opts.events[Events.DIFF] + + if (stateEvent != null && diffEvent != null) { + + this.channel.on(stateEvent) { message -> + val newState = message.payload.toMutableMap() as PresenceState + + this.joinRef = this.channel.joinRef + this.state = + Presence.syncState(state, newState, caller.onJoin, caller.onLeave) + + + this.pendingDiffs.forEach { diff -> + this.state = Presence.syncDiff(state, diff, caller.onJoin, caller.onLeave) + } + + this.pendingDiffs.clear() + this.caller.onSync() + } + + this.channel.on(diffEvent) { message -> + val diff = message.payload.toMutableMap() as PresenceDiff + if (isPendingSyncState) { + this.pendingDiffs.add(diff) + } else { + this.state = Presence.syncDiff(state, diff, caller.onJoin, caller.onLeave) + this.caller.onSync() + } + } + } + } + + //------------------------------------------------------------------------------ + // Callbacks + //------------------------------------------------------------------------------ + fun onJoin(callback: OnJoin) { + this.caller.onJoin = callback + } + + fun onLeave(callback: OnLeave) { + this.caller.onLeave = callback + } + + fun onSync(callback: OnSync) { + this.caller.onSync = callback + } + + //------------------------------------------------------------------------------ + // Listing + //------------------------------------------------------------------------------ + fun list(): List { + return this.listBy { it.value } + } + + fun listBy(transform: (Map.Entry) -> T): List { + return Presence.listBy(state, transform) + } + + fun filterBy(predicate: ((Map.Entry) -> Boolean)?): PresenceState { + return Presence.filter(state, predicate) + } + + //------------------------------------------------------------------------------ + // Syncing + //------------------------------------------------------------------------------ + companion object { + + /** + * Used to sync the list of presences on the server with the client's state. An optional + * `onJoin` and `onLeave` callback can be provided to react to changes in the client's local + * presences across disconnects and reconnects with the server. + * + */ + fun syncState( + currentState: PresenceState, + newState: PresenceState, + onJoin: OnJoin = { _, _, _ -> }, + onLeave: OnLeave = { _, _, _ -> } + ): PresenceState { + val state = currentState + val leaves: PresenceState = mutableMapOf() + val joins: PresenceState = mutableMapOf() + + state.forEach { key, presence -> + if (!newState.containsKey(key)) { + leaves[key] = presence + } + } + + newState.forEach { key, newPresence -> + state[key]?.let { currentPresence -> + val newRefs = newPresence["metas"]!!.map { meta -> meta["phx"] as String } + val curRefs = currentPresence["metas"]!!.map { meta -> meta["phx"] as String } + + val joinedMetas = newPresence["metas"]!!.filter { meta -> + curRefs.indexOf(meta["phx_ref"]) < 0 + } + val leftMetas = currentPresence["metas"]!!.filter { meta -> + newRefs.indexOf(meta["phx_ref"]) < 0 + } + + if (joinedMetas.isNotEmpty()) { + joins[key] = newPresence + joins[key]!!["metas"] = joinedMetas.toMutableList() + } + + if (leftMetas.isNotEmpty()) { + leaves[key] = currentPresence + leaves[key]!!["metas"] = leftMetas.toMutableList() + } + } ?: run { + joins[key] = newPresence + } + } + + val diff: PresenceDiff = mutableMapOf("joins" to joins, "leaves" to leaves) + return Presence.syncDiff(state, diff, onJoin, onLeave) + + } + + /** + * Used to sync a diff of presence join and leave events from the server, as they happen. + * Like `syncState`, `syncDiff` accepts optional `onJoin` and `onLeave` callbacks to react + * to a user joining or leaving from a device. + */ + fun syncDiff( + currentState: PresenceState, + diff: PresenceDiff, + onJoin: OnJoin = { _, _, _ -> }, + onLeave: OnLeave = { _, _, _ -> } + ): PresenceState { + val state = currentState + + // Sync the joined states and inform onJoin of new presence + diff["joins"]?.forEach { key, newPresence -> + val currentPresence = state[key] + state[key] = newPresence + + currentPresence?.let { curPresence -> + val joinedRefs = state[key]!!["metas"]!!.map { m -> m["phx_ref"] as String } + val curMetas = curPresence["metas"]!!.filter { m -> joinedRefs.indexOf(m["phx_ref"]) < 0 } + + state[key]!!["metas"]!!.addAll(0, curMetas) + } + + onJoin.invoke(key, currentPresence, newPresence) + } + + // Sync the left diff and inform onLeave of left presence + diff["leaves"]?.forEach { key, leftPresence -> + val curPresence = state[key] ?: return@forEach + + val refsToRemove = leftPresence["metas"]!!.map { it["phx_ref"] as String } + val keepMetas = + curPresence["metas"]!!.filter { m -> refsToRemove.indexOf(m["phx_ref"]) < 0 } + + curPresence["metas"] = keepMetas.toMutableList() + onLeave.invoke(key, curPresence, leftPresence) + + if (keepMetas.isNotEmpty()) { + state[key]!!["metas"] = keepMetas.toMutableList() + } else { + state.remove(key) + } + } + + return state + } + + fun filter( + presence: PresenceState, + predicate: ((Map.Entry) -> Boolean)? + ): PresenceState { + return presence.filter(predicate ?: { true }).toMutableMap() + } + + fun listBy( + presence: PresenceState, + transform: (Map.Entry) -> T + ): List { + return presence.map(transform) + } + } +} diff --git a/src/main/kotlin/org/phoenixframework/Push.kt b/src/main/kotlin/org/phoenixframework/Push.kt new file mode 100644 index 0000000..d0c11b6 --- /dev/null +++ b/src/main/kotlin/org/phoenixframework/Push.kt @@ -0,0 +1,175 @@ +package org.phoenixframework + +import java.util.concurrent.ScheduledFuture +import java.util.concurrent.TimeUnit + +/** + * A Push represents an attempt to send a payload through a Channel for a specific event. + */ +class Push( + /** The channel the Push is being sent through */ + val channel: Channel, + /** The event the Push is targeting */ + val event: String, + /** The message to be sent */ + var payload: Payload = mapOf(), + /** Duration before the message is considered timed out and failed to send */ + var timeout: Long = Defaults.TIMEOUT +) { + + /** The server's response to the Push */ + var receivedMessage: Message? = null + + /** The task to be triggered if the Push times out */ + var timeoutTask: ScheduledFuture<*>? = null + + /** Hooks into a Push. Where .receive("ok", callback(Payload)) are stored */ + var receiveHooks: MutableMap Unit)>> = HashMap() + + /** True if the Push has been sent */ + var sent: Boolean = false + + /** The reference ID of the Push */ + var ref: String? = null + + /** The event that is associated with the reference ID of the Push */ + var refEvent: String? = null + + //------------------------------------------------------------------------------ + // Public + //------------------------------------------------------------------------------ + /** + * Resets and sends the Push + * @param timeout Optional. The push timeout. Default is 10_000ms = 10s + */ + fun resend(timeout: Long = Defaults.TIMEOUT) { + this.timeout = timeout + this.reset() + this.send() + } + + /** + * Sends the Push. If it has already timed out then the call will be ignored. use + * `resend(timeout:)` in this case. + */ + fun send() { + if (hasReceived("timeout")) return + + this.startTimeout() + this.sent = true + // TODO: this.channel.socket.push + // TODO: weak reference? + } + + /** + * Receive a specific event when sending an Outbound message + * + * Example: + * channel + * .send("event", myPayload) + * .receive("error") { } + */ + fun receive(status: String, callback: (Message) -> Unit): Push { + // If the message has already be received, pass it to the callback + receivedMessage?.let { if (hasReceived(status)) callback(it) } + + if (receiveHooks[status] == null) { + // Create a new array of hooks if no previous hook is associated with status + receiveHooks[status] = arrayListOf(callback) + } else { + // A previous hook for this status already exists. Just append the new hook + receiveHooks[status]?.add(callback) + } + + return this + } + + //------------------------------------------------------------------------------ + // Internal + //------------------------------------------------------------------------------ + /** Resets the Push as it was after it was first initialized. */ + internal fun reset() { + this.cancelRefEvent() + this.ref = null + this.refEvent = null + this.receivedMessage = null + this.sent = false + } + + /** + * Triggers an event to be sent through the Push's parent Channel + */ + internal fun trigger(status: String, payload: Payload) { + this.refEvent?.let { refEvent -> + val mutPayload = payload.toMutableMap() + mutPayload["status"] = status + + this.channel.trigger(refEvent, mutPayload) + } + } + + /** + * Schedules a timeout task which will be triggered after a specific timeout is reached + */ + internal fun startTimeout() { + // Cancel any existing timeout before starting a new one + this.timeoutTask?.let { if (!it.isCancelled) this.cancelTimeout() } + + // Get the ref of the Push + val ref = this.channel.socket.makeRef() + val refEvent = this.channel.replyEventName(ref) + + this.ref = ref + this.refEvent = refEvent + + // Subscribe to a reply from the server when the Push is received + this.channel.on(refEvent) { message -> + this.cancelRefEvent() + this.cancelTimeout() + this.receivedMessage = message + + // Check if there is an event receive hook to be informed + message.status?.let { status -> matchReceive(status, message) } + } + + // Setup and start the Timer + this.timeoutTask = channel.socket.timerPool.schedule({ + this.trigger("timeout", hashMapOf()) + }, timeout, TimeUnit.MILLISECONDS) + } + + + //------------------------------------------------------------------------------ + // Private + //------------------------------------------------------------------------------ + /** + * Finds the receiveHook which needs to be informed of a status response and passes it the message + * + * @param status Status which was received. e.g. "ok", "error", etc. + * @param message Message to pass to receive hook + */ + private fun matchReceive(status: String, message: Message) { + receiveHooks[status]?.forEach { it(message) } + } + + /** Removes receive hook from Channel regarding this Push */ + private fun cancelRefEvent() { + this.refEvent?.let { /* TODO: this.channel.off(it) */ } + } + + /** Cancels any ongoing timeout task */ + private fun cancelTimeout() { + this.timeoutTask?.cancel(true) + this.timeoutTask = null + } + + + + /** + * @param status Status to check if it has been received + * @return True if the status has already been received by the Push + */ + private fun hasReceived(status: String): Boolean { + return receivedMessage?.status == status + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/phoenixframework/Socket.kt b/src/main/kotlin/org/phoenixframework/Socket.kt index 8832201..ee19905 100644 --- a/src/main/kotlin/org/phoenixframework/Socket.kt +++ b/src/main/kotlin/org/phoenixframework/Socket.kt @@ -1,6 +1,13 @@ package org.phoenixframework +import com.google.gson.Gson +import okhttp3.HttpUrl +import okhttp3.OkHttpClient +import okhttp3.Response +import java.net.URL +import java.util.concurrent.ScheduledFuture import java.util.concurrent.ScheduledThreadPoolExecutor +import java.util.concurrent.TimeUnit // Copyright (c) 2019 Daniel Rees // @@ -22,14 +29,424 @@ import java.util.concurrent.ScheduledThreadPoolExecutor // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -class Socket { +/** Alias for a JSON mapping */ +typealias Payload = Map - /** - * All timers associated with a socket will share the same pool. Used for every Channel or - * Push that is sent through or created by a Socket instance. Different Socket instances will - * create individual thread pools. - */ - private val timerPool = ScheduledThreadPoolExecutor(8) +/** Data class that holds callbacks assigned to the socket */ +internal data class StateChangeCallbacks( + val open: MutableList<() -> Unit> = ArrayList(), + val close: MutableList<() -> Unit> = ArrayList(), + val error: MutableList<(Throwable, Response?) -> Unit> = ArrayList(), + val message: MutableList<(Message) -> Unit> = ArrayList() +) { + /** Clears all stored callbacks */ + fun release() { + open.clear() + close.clear() + error.clear() + message.clear() + } +} +/** The code used when the socket was closed without error */ +const val WS_CLOSE_NORMAL = 1000 + +/** + * Connects to a Phoenix Server + */ +class Socket( + url: String, + params: Payload? = null, + private val gson: Gson = Defaults.gson, + private val client: OkHttpClient = OkHttpClient.Builder().build() +) { + + //------------------------------------------------------------------------------ + // Public Attributes + //------------------------------------------------------------------------------ + /** + * The string WebSocket endpoint (ie `"ws://example.com/socket"`, + * `"wss://example.com"`, etc.) that was passed to the Socket during + * initialization. The URL endpoint will be modified by the Socket to + * include `"/websocket"` if missing. + */ + val endpoint: String + + /** The fully qualified socket URL */ + val endpointUrl: URL + + /** + * The optional params to pass when connecting. Must be set when + * initializing the Socket. These will be appended to the URL. + */ + val params: Payload? = params + + /** Timeout to use when opening a connection */ + var timeout: Long = Defaults.TIMEOUT + + /** Interval between sending a heartbeat */ + var heartbeatInterval: Long = Defaults.HEARTBEAT + + /** Internval between socket reconnect attempts */ + var reconnectAfterMs: ((Int) -> Long) = Defaults.steppedBackOff + + /** The optional function to receive logs */ + var logger: ((String) -> Unit)? = null + + /** Disables heartbeats from being sent. Default is false. */ + var skipHeartbeat: Boolean = false + + //------------------------------------------------------------------------------ + // Internal Attributes + //------------------------------------------------------------------------------ + /** + * All timers associated with a socket will share the same pool. Used for every Channel or + * Push that is sent through or created by a Socket instance. Different Socket instances will + * create individual thread pools. + */ + internal val timerPool = ScheduledThreadPoolExecutor(8) + + //------------------------------------------------------------------------------ + // Private Attributes + //------------------------------------------------------------------------------ + /** Returns the type of transport to use. Potentially expose for custom transports */ + private val transport: (URL) -> Transport = { WebSocketTransport(it, client) } + + /** Collection of callbacks for socket state changes */ + private val stateChangeCallbacks: StateChangeCallbacks = StateChangeCallbacks() + + /** Collection of unclosed channels created by the Socket */ + private var channels: MutableList = ArrayList() + + /** Buffers messages that need to be sent once the socket has connected */ + private var sendBuffer: MutableList<() -> Unit> = ArrayList() + + /** Ref counter for messages */ + private var ref: Int = 0 + + /** Task to be triggered in the future to send a heartbeat message */ + private var heartbeatTask: ScheduledFuture<*>? = null + + /** Ref counter for the last heartbeat that was sent */ + private var pendingHeartbeatRef: String? = null + + /** Timer to use when attempting to reconnect */ + private var reconnectTimer: TimeoutTimer + + //------------------------------------------------------------------------------ + // Connection Attributes + //------------------------------------------------------------------------------ + /** The underlying WebSocket connection */ + private var connection: Transport? = null + + //------------------------------------------------------------------------------ + // Initialization + //------------------------------------------------------------------------------ + init { + // Silently replace web socket URLs with HTTP URLs. + var mutableUrl = url + if (url.regionMatches(0, "ws:", 0, 3, ignoreCase = true)) { + mutableUrl = "http:" + url.substring(3) + } else if (url.regionMatches(0, "wss:", 0, 4, ignoreCase = true)) { + mutableUrl = "https:" + url.substring(4) + } + + // Ensure that the URL ends with "/websocket" + if (!mutableUrl.contains("/websocket")) { + // Do not duplicate '/' in path + if (mutableUrl.last() != '/') { + mutableUrl += "/" + } + + // append "websocket" to the path + mutableUrl += "websocket" + } + + // If there are query params, append them now + var httpUrl = HttpUrl.parse(mutableUrl) ?: throw IllegalArgumentException("invalid url: $url") + params?.let { + val httpBuilder = httpUrl.newBuilder() + it.forEach { (key, value) -> + httpBuilder.addQueryParameter(key, value.toString()) + } + + httpUrl = httpBuilder.build() + } + + this.endpoint = mutableUrl + this.endpointUrl = httpUrl.url() + + // Create reconnect timer + this.reconnectTimer = TimeoutTimer( + scheduledExecutorService = timerPool, + timerCalculation = reconnectAfterMs, + callback = { + // log(socket attempting to reconnect) + // this.teardown() { this.connect() } + }) + } + + //------------------------------------------------------------------------------ + // Public Properties + //------------------------------------------------------------------------------ + /** @return The socket protocol being used. e.g. "wss", "ws" */ + val protocol: String + get() = when (endpointUrl.protocol) { + "https" -> "wss" + "http" -> "ws" + else -> endpointUrl.protocol + } + + /** @return True if the connection exists and is open */ + val isConnected: Boolean + get() = this.connection?.readyState == ReadyState.OPEN + + //------------------------------------------------------------------------------ + // Public + //------------------------------------------------------------------------------ + fun connect() { + // Do not attempt to connect if already connected + if (isConnected) return + + this.connection = this.transport(endpointUrl) + this.connection?.onOpen = { onConnectionOpened() } + this.connection?.onClose = { code -> onConnectionClosed(code) } + this.connection?.onError = { t, r -> onConnectionError(t, r) } + this.connection?.onMessage = { m -> onConnectionMessage(m) } + this.connection?.connect() + } + + fun disconnect( + code: Int = WS_CLOSE_NORMAL, + reason: String? = null, + callback: (() -> Unit)? = null + ) { + this.reconnectTimer.reset() + this.teardown(code, reason, callback) + + } + + fun onOpen(callback: (() -> Unit)) { + this.stateChangeCallbacks.open.add(callback) + } + + fun onClose(callback: () -> Unit) { + this.stateChangeCallbacks.close.add(callback) + } + + fun onError(callback: (Throwable, Response?) -> Unit) { + this.stateChangeCallbacks.error.add(callback) + } + + fun onMessage(callback: (Message) -> Unit) { + this.stateChangeCallbacks.message.add(callback) + } + + fun removeAllCallbacks() { + this.stateChangeCallbacks.release() + } + + fun channel(topic: String, params: Payload = mapOf()): Channel { + val channel = Channel(topic, params, this) + this.channels.add(channel) + + return channel + } + + fun remove(channel: Channel) { + this.channels.removeAll { it.joinRef == channel.joinRef } + } + + //------------------------------------------------------------------------------ + // Internal + //------------------------------------------------------------------------------ + internal fun push( + topic: String, + event: String, + payload: Payload, + ref: String? = null, + joinRef: String? = null + ) { + + val callback: (() -> Unit) = { + val body = mutableMapOf() + body["topic"] = topic + body["event"] = event + body["payload"] = payload + + ref?.let { body["ref"] = it } + joinRef?.let { body["join_ref"] = it } + + val data = gson.toJson(body) + connection?.let { transport -> + this.logItems("Push: Sending $data") + transport.send(data) + } + } + + if (isConnected) { + // If the socket is connected, then execute the callback immediately. + callback.invoke() + } else { + // If the socket is not connected, add the push to a buffer which will + // be sent immediately upon connection. + sendBuffer.add(callback) + } + } + + /** @return the next message ref, accounting for overflows */ + internal fun makeRef(): String { + this.ref = if (ref == Int.MAX_VALUE) 0 else ref + 1 + return ref.toString() + } + + fun logItems(body: String) { + logger?.let { + it(body) + } + } + + //------------------------------------------------------------------------------ + // Private + //------------------------------------------------------------------------------ + private fun teardown( + code: Int = WS_CLOSE_NORMAL, + reason: String? = null, + callback: (() -> Unit)? = null + ) { + // Disconnect the transport + this.connection?.onClose = null + this.connection?.disconnect(code, reason) + this.connection = null + + // Heartbeats are no longer needed + this.heartbeatTask?.cancel(true) + this.heartbeatTask = null + + // Since the connections onClose was null'd out, inform all state callbacks + // that the Socket has closed + this.stateChangeCallbacks.close.forEach { it.invoke() } + callback?.invoke() + } + + /** Triggers an error event to all connected Channels */ + private fun triggerChannelError() { + this.channels.forEach { it.trigger(Channel.Event.ERROR.value) } + } + + /** Send all messages that were buffered before the socket opened */ + private fun flushSendBuffer() { + if (isConnected && sendBuffer.isNotEmpty()) { + this.sendBuffer.forEach { it.invoke() } + this.sendBuffer.clear() + } + } + + //------------------------------------------------------------------------------ + // Heartbeat + //------------------------------------------------------------------------------ + private fun resetHeartbeat() { + // Clear anything related to the previous heartbeat + this.pendingHeartbeatRef = null + this.heartbeatTask?.cancel(true) + this.heartbeatTask = null + + // Do not start up the heartbeat timer if skipHeartbeat is true + if (skipHeartbeat) return + heartbeatTask = timerPool.schedule({ + + }, heartbeatInterval, TimeUnit.MILLISECONDS) + } + + private fun sendHeartbeat() { + // Do not send if the connection is closed + if (!isConnected) return + + // If there is a pending heartbeat ref, then the last heartbeat was + // never acknowledged by the server. Close the connection and attempt + // to reconnect. + pendingHeartbeatRef?.let { + pendingHeartbeatRef = null + logItems("Transport: Heartbeat timeout. Attempt to re-establish connection") + + // Disconnect the socket manually. Do not use `teardown` or + // `disconnect` as they will nil out the websocket delegate + this.connection?.disconnect(WS_CLOSE_NORMAL, "Heartbeat timed out") + return + } + + // The last heartbeat was acknowledged by the server. Send another one + this.pendingHeartbeatRef = this.makeRef() + this.push( + topic = "phoenix", + event = Channel.Event.HEARTBEAT.value, + payload = mapOf(), + ref = pendingHeartbeatRef) + } + + //------------------------------------------------------------------------------ + // Connection Transport Hooks + //------------------------------------------------------------------------------ + private fun onConnectionOpened() { + this.logItems("Transport: Connected to $endpoint") + + // Send any messages that were waiting for a connection + this.flushSendBuffer() + + // Reset how the socket tried to reconnect + this.reconnectTimer.reset() + + // Restart the heartbeat timer + this.resetHeartbeat() + + // Inform all onOpen callbacks that the Socket has opened + this.stateChangeCallbacks.open.forEach { it.invoke() } + } + + private fun onConnectionClosed(code: Int) { + this.logItems("Transport: close") + this.triggerChannelError() + + // Prevent the heartbeat from triggering if the socket closed + this.heartbeatTask?.cancel(true) + this.heartbeatTask = null + + // Inform callbacks the socket closed + this.stateChangeCallbacks.close.forEach { it.invoke() } + + // If there was a non-normal event when the connection closed, attempt + // to schedule a reconnect attempt + if (code != WS_CLOSE_NORMAL) { + reconnectTimer.scheduleTimeout() + } + } + + private fun onConnectionMessage(rawMessage: String) { + this.logItems("Receive: $rawMessage") + + // Parse the message as JSON + val message = gson.fromJson(rawMessage, Message::class.java) + + // Clear heartbeat ref, preventing a heartbeat timeout disconnect + if (message.ref == pendingHeartbeatRef) pendingHeartbeatRef = null + + // Dispatch the message to all channels that belong to the topic + this.channels + .filter { it.isMember(message) } + .forEach { it.trigger(message) } + + // Inform all onMessage callbacks of the message + this.stateChangeCallbacks.message.forEach { it.invoke(message) } + } + + private fun onConnectionError(t: Throwable, response: Response?) { + this.logItems("Transport: error $t") + + // Send an error to all channels + this.triggerChannelError() + + // Inform any state callbacks of the error + this.stateChangeCallbacks.error.forEach { it.invoke(t, response) } + } } \ No newline at end of file diff --git a/src/main/kotlin/org/phoenixframework/Transport.kt b/src/main/kotlin/org/phoenixframework/Transport.kt new file mode 100644 index 0000000..556b195 --- /dev/null +++ b/src/main/kotlin/org/phoenixframework/Transport.kt @@ -0,0 +1,86 @@ +package org.phoenixframework + +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.Response +import okhttp3.WebSocket +import okhttp3.WebSocketListener +import java.net.URL + +enum class ReadyState { + CONNECTING, + OPEN, + CLOSING, + CLOSED +} + +interface Transport { + + val readyState: ReadyState + + var onOpen: (() -> Unit)? + var onError: ((Throwable, Response?) -> Unit)? + var onMessage: ((String) -> Unit)? + var onClose: ((Int) -> Unit)? + + fun connect() + fun disconnect(code: Int, reason: String? = null) + fun send(data: String) +} + +class WebSocketTransport( + private val url: URL, + private val okHttpClient: OkHttpClient +) : + WebSocketListener(), + Transport { + + private var connection: WebSocket? = null + + override var readyState: ReadyState = ReadyState.CLOSED + override var onOpen: (() -> Unit)? = null + override var onError: ((Throwable, Response?) -> Unit)? = null + override var onMessage: ((String) -> Unit)? = null + override var onClose: ((Int) -> Unit)? = null + + override fun connect() { + this.readyState = ReadyState.CONNECTING + val request = Request.Builder().url(url).build() + connection = okHttpClient.newWebSocket(request, this) + } + + override fun disconnect(code: Int, reason: String?) { + connection?.close(code, reason) + connection = null + } + + override fun send(data: String) { + connection?.send(data) + } + + //------------------------------------------------------------------------------ + // WebSocket Listener + //------------------------------------------------------------------------------ + override fun onOpen(webSocket: WebSocket, response: Response) { + this.readyState = ReadyState.OPEN + this.onOpen?.invoke() + } + + override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { + this.readyState = ReadyState.CLOSED + this.onError?.invoke(t, response) + } + + override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { + this.readyState = ReadyState.CLOSING + } + + override fun onMessage(webSocket: WebSocket, text: String) { + this.onMessage?.invoke(text) + } + + override fun onClosed(webSocket: WebSocket, code: Int, reason: String) { + this.readyState = ReadyState.CLOSED + this.onClose?.invoke(code) + } +} \ No newline at end of file