diff --git a/.travis.yml b/.travis.yml
index 22da5ee..1800eea 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,4 +1,7 @@
language: java
+dist: trusty
+after_success:
+ - bash <(curl -s https://codecov.io/bash)
jdk:
- oraclejdk8
sudo: false
diff --git a/ChatExample/.gitignore b/ChatExample/.gitignore
new file mode 100644
index 0000000..2b75303
--- /dev/null
+++ b/ChatExample/.gitignore
@@ -0,0 +1,13 @@
+*.iml
+.gradle
+/local.properties
+/.idea/caches
+/.idea/libraries
+/.idea/modules.xml
+/.idea/workspace.xml
+/.idea/navEditor.xml
+/.idea/assetWizardSettings.xml
+.DS_Store
+/build
+/captures
+.externalNativeBuild
diff --git a/ChatExample/app/.gitignore b/ChatExample/app/.gitignore
new file mode 100644
index 0000000..796b96d
--- /dev/null
+++ b/ChatExample/app/.gitignore
@@ -0,0 +1 @@
+/build
diff --git a/ChatExample/app/build.gradle b/ChatExample/app/build.gradle
new file mode 100644
index 0000000..18dd1a7
--- /dev/null
+++ b/ChatExample/app/build.gradle
@@ -0,0 +1,53 @@
+apply plugin: 'com.android.application'
+
+apply plugin: 'kotlin-android'
+
+apply plugin: 'kotlin-android-extensions'
+
+android {
+ compileSdkVersion 30
+ defaultConfig {
+ applicationId "com.github.dsrees.chatexample"
+ minSdkVersion 19
+ targetSdkVersion 30
+ versionCode 1
+ versionName "1.0"
+ testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
+ }
+ buildTypes {
+ release {
+ minifyEnabled false
+ proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
+ }
+ }
+
+ compileOptions {
+ targetCompatibility = "8"
+ sourceCompatibility = "8"
+ }
+}
+
+dependencies {
+ /*
+ To update the JavaPhoenixClient, either use the latest dependency from mavenCentral()
+ OR run
+ `./gradlew jar`
+ and copy
+ `/build/lib/*.jar` to `/ChatExample/app/libs`
+ and comment out the mavenCentral() dependency
+ */
+ implementation fileTree(dir: 'libs', include: ['*.jar'])
+// implementation 'com.github.dsrees:JavaPhoenixClient:0.3.4'
+
+
+ implementation "com.google.code.gson:gson:2.8.5"
+ implementation "com.squareup.okhttp3:okhttp:3.12.2"
+
+
+ implementation"org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version"
+
+ implementation 'androidx.appcompat:appcompat:1.0.2'
+ implementation 'androidx.recyclerview:recyclerview:1.0.0'
+ implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
+
+}
diff --git a/ChatExample/app/libs/JavaPhoenixClient-0.7.0.jar b/ChatExample/app/libs/JavaPhoenixClient-0.7.0.jar
new file mode 100644
index 0000000..6dbc437
Binary files /dev/null and b/ChatExample/app/libs/JavaPhoenixClient-0.7.0.jar differ
diff --git a/ChatExample/app/proguard-rules.pro b/ChatExample/app/proguard-rules.pro
new file mode 100644
index 0000000..f1b4245
--- /dev/null
+++ b/ChatExample/app/proguard-rules.pro
@@ -0,0 +1,21 @@
+# Add project specific ProGuard rules here.
+# You can control the set of applied configuration files using the
+# proguardFiles setting in build.gradle.
+#
+# For more details, see
+# http://developer.android.com/guide/developing/tools/proguard.html
+
+# If your project uses WebView with JS, uncomment the following
+# and specify the fully qualified class name to the JavaScript interface
+# class:
+#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
+# public *;
+#}
+
+# Uncomment this to preserve the line number information for
+# debugging stack traces.
+#-keepattributes SourceFile,LineNumberTable
+
+# If you keep the line number information, uncomment this to
+# hide the original source file name.
+#-renamesourcefileattribute SourceFile
diff --git a/ChatExample/app/src/main/AndroidManifest.xml b/ChatExample/app/src/main/AndroidManifest.xml
new file mode 100644
index 0000000..15f50d6
--- /dev/null
+++ b/ChatExample/app/src/main/AndroidManifest.xml
@@ -0,0 +1,26 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MainActivity.kt b/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MainActivity.kt
new file mode 100644
index 0000000..6e225dd
--- /dev/null
+++ b/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MainActivity.kt
@@ -0,0 +1,140 @@
+package com.github.dsrees.chatexample
+
+import androidx.appcompat.app.AppCompatActivity
+import android.os.Bundle
+import android.util.Log
+import android.widget.ArrayAdapter
+import android.widget.Button
+import android.widget.EditText
+import androidx.recyclerview.widget.LinearLayoutManager
+import kotlinx.android.synthetic.main.activity_main.*
+import org.phoenixframework.Channel
+import org.phoenixframework.Socket
+
+class MainActivity : AppCompatActivity() {
+
+ companion object {
+ const val TAG = "MainActivity"
+ }
+
+ private val messagesAdapter = MessagesAdapter()
+ private val layoutManager = LinearLayoutManager(this)
+
+
+ // Use when connecting to https://github.com/dwyl/phoenix-chat-example
+ // private val socket = Socket("https://phxchat.herokuapp.com/socket/websocket")
+ // private val topic = "room:lobby"
+
+ // Use when connecting to local server
+ private val socket = Socket("ws://10.0.2.2:4000/socket/websocket")
+ private val topic = "rooms:lobby"
+
+ private var lobbyChannel: Channel? = null
+
+ private val username: String
+ get() = username_input.text.toString()
+
+ private val message: String
+ get() = message_input.text.toString()
+
+ override fun onCreate(savedInstanceState: Bundle?) {
+ super.onCreate(savedInstanceState)
+ setContentView(R.layout.activity_main)
+
+
+ layoutManager.stackFromEnd = true
+
+ messages_recycler_view.layoutManager = layoutManager
+ messages_recycler_view.adapter = messagesAdapter
+
+ socket.onOpen {
+ this.addText("Socket Opened")
+ runOnUiThread { connect_button.text = "Disconnect" }
+ }
+
+ socket.onClose {
+ this.addText("Socket Closed")
+ runOnUiThread { connect_button.text = "Connect" }
+ }
+
+ socket.onError { throwable, response ->
+ Log.e(TAG, "Socket Errored $response", throwable)
+ this.addText("Socket Error")
+ }
+
+ socket.logger = {
+ Log.d(TAG, "SOCKET $it")
+ }
+
+
+ connect_button.setOnClickListener {
+ if (socket.isConnected) {
+ this.disconnectAndLeave()
+ } else {
+ this.disconnectAndLeave()
+ this.connectAndJoin()
+ }
+ }
+
+ send_button.setOnClickListener { sendMessage() }
+ }
+
+ private fun sendMessage() {
+ val payload = mapOf("user" to username, "body" to message)
+ this.lobbyChannel?.push("new:msg", payload)
+ ?.receive("ok") { Log.d(TAG, "success $it") }
+ ?.receive("error") { Log.d(TAG, "error $it") }
+
+ message_input.text.clear()
+ }
+
+ private fun disconnectAndLeave() {
+ // Be sure the leave the channel or call socket.remove(lobbyChannel)
+ lobbyChannel?.leave()
+ socket.disconnect { this.addText("Socket Disconnected") }
+ }
+
+ private fun connectAndJoin() {
+ val channel = socket.channel(topic, mapOf("status" to "joining"))
+ channel.on("join") {
+ this.addText("You joined the room")
+ }
+
+ channel.on("new:msg") { message ->
+ val payload = message.payload
+ val username = payload["user"] as? String
+ val body = payload["body"]
+
+
+ if (username != null && body != null) {
+ this.addText("[$username] $body")
+ }
+ }
+
+ channel.on("user:entered") {
+ this.addText("[anonymous entered]")
+ }
+
+ this.lobbyChannel = channel
+ channel
+ .join()
+ .receive("ok") {
+ this.addText("Joined Channel")
+ }
+ .receive("error") {
+ this.addText("Failed to join channel: ${it.payload}")
+ }
+
+
+ this.socket.connect()
+ }
+
+ private fun addText(message: String) {
+ runOnUiThread {
+ this.messagesAdapter.add(message)
+ layoutManager.smoothScrollToPosition(messages_recycler_view, null, messagesAdapter.itemCount)
+ }
+
+ }
+
+}
diff --git a/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MessagesAdapter.kt b/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MessagesAdapter.kt
new file mode 100644
index 0000000..e99b294
--- /dev/null
+++ b/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MessagesAdapter.kt
@@ -0,0 +1,34 @@
+package com.github.dsrees.chatexample
+
+import android.view.LayoutInflater
+import android.view.View
+import android.view.ViewGroup
+import android.widget.TextView
+import androidx.recyclerview.widget.RecyclerView
+
+class MessagesAdapter : RecyclerView.Adapter() {
+
+ private var messages: MutableList = mutableListOf()
+
+ fun add(message: String) {
+ messages.add(message)
+ notifyItemInserted(messages.size)
+ }
+
+
+ override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): ViewHolder {
+ val view = LayoutInflater.from(parent.context).inflate(R.layout.item_message, parent, false)
+ return ViewHolder(view)
+ }
+
+ override fun getItemCount(): Int = messages.size
+
+ override fun onBindViewHolder(holder: ViewHolder, position: Int) {
+ holder.label.text = messages[position]
+ }
+
+ inner class ViewHolder(itemView: View) : RecyclerView.ViewHolder(itemView) {
+ val label: TextView = itemView.findViewById(R.id.item_message_label)
+ }
+
+}
\ No newline at end of file
diff --git a/ChatExample/app/src/main/res/drawable-v24/ic_launcher_foreground.xml b/ChatExample/app/src/main/res/drawable-v24/ic_launcher_foreground.xml
new file mode 100644
index 0000000..6348baa
--- /dev/null
+++ b/ChatExample/app/src/main/res/drawable-v24/ic_launcher_foreground.xml
@@ -0,0 +1,34 @@
+
+
+
+
+
+
+
+
+
+
+
diff --git a/ChatExample/app/src/main/res/drawable/ic_launcher_background.xml b/ChatExample/app/src/main/res/drawable/ic_launcher_background.xml
new file mode 100644
index 0000000..a0ad202
--- /dev/null
+++ b/ChatExample/app/src/main/res/drawable/ic_launcher_background.xml
@@ -0,0 +1,74 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/ChatExample/app/src/main/res/layout/activity_main.xml b/ChatExample/app/src/main/res/layout/activity_main.xml
new file mode 100644
index 0000000..bc78d34
--- /dev/null
+++ b/ChatExample/app/src/main/res/layout/activity_main.xml
@@ -0,0 +1,65 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/ChatExample/app/src/main/res/layout/item_message.xml b/ChatExample/app/src/main/res/layout/item_message.xml
new file mode 100644
index 0000000..da57d9a
--- /dev/null
+++ b/ChatExample/app/src/main/res/layout/item_message.xml
@@ -0,0 +1,6 @@
+
+
diff --git a/ChatExample/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml b/ChatExample/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml
new file mode 100644
index 0000000..bbd3e02
--- /dev/null
+++ b/ChatExample/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml
@@ -0,0 +1,5 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/ChatExample/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/ChatExample/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml
new file mode 100644
index 0000000..bbd3e02
--- /dev/null
+++ b/ChatExample/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml
@@ -0,0 +1,5 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/ChatExample/app/src/main/res/mipmap-hdpi/ic_launcher.png b/ChatExample/app/src/main/res/mipmap-hdpi/ic_launcher.png
new file mode 100644
index 0000000..898f3ed
Binary files /dev/null and b/ChatExample/app/src/main/res/mipmap-hdpi/ic_launcher.png differ
diff --git a/ChatExample/app/src/main/res/mipmap-hdpi/ic_launcher_round.png b/ChatExample/app/src/main/res/mipmap-hdpi/ic_launcher_round.png
new file mode 100644
index 0000000..dffca36
Binary files /dev/null and b/ChatExample/app/src/main/res/mipmap-hdpi/ic_launcher_round.png differ
diff --git a/ChatExample/app/src/main/res/mipmap-mdpi/ic_launcher.png b/ChatExample/app/src/main/res/mipmap-mdpi/ic_launcher.png
new file mode 100644
index 0000000..64ba76f
Binary files /dev/null and b/ChatExample/app/src/main/res/mipmap-mdpi/ic_launcher.png differ
diff --git a/ChatExample/app/src/main/res/mipmap-mdpi/ic_launcher_round.png b/ChatExample/app/src/main/res/mipmap-mdpi/ic_launcher_round.png
new file mode 100644
index 0000000..dae5e08
Binary files /dev/null and b/ChatExample/app/src/main/res/mipmap-mdpi/ic_launcher_round.png differ
diff --git a/ChatExample/app/src/main/res/mipmap-xhdpi/ic_launcher.png b/ChatExample/app/src/main/res/mipmap-xhdpi/ic_launcher.png
new file mode 100644
index 0000000..e5ed465
Binary files /dev/null and b/ChatExample/app/src/main/res/mipmap-xhdpi/ic_launcher.png differ
diff --git a/ChatExample/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png b/ChatExample/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png
new file mode 100644
index 0000000..14ed0af
Binary files /dev/null and b/ChatExample/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png differ
diff --git a/ChatExample/app/src/main/res/mipmap-xxhdpi/ic_launcher.png b/ChatExample/app/src/main/res/mipmap-xxhdpi/ic_launcher.png
new file mode 100644
index 0000000..b0907ca
Binary files /dev/null and b/ChatExample/app/src/main/res/mipmap-xxhdpi/ic_launcher.png differ
diff --git a/ChatExample/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png b/ChatExample/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png
new file mode 100644
index 0000000..d8ae031
Binary files /dev/null and b/ChatExample/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png differ
diff --git a/ChatExample/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png b/ChatExample/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png
new file mode 100644
index 0000000..2c18de9
Binary files /dev/null and b/ChatExample/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png differ
diff --git a/ChatExample/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png b/ChatExample/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png
new file mode 100644
index 0000000..beed3cd
Binary files /dev/null and b/ChatExample/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png differ
diff --git a/ChatExample/app/src/main/res/values/colors.xml b/ChatExample/app/src/main/res/values/colors.xml
new file mode 100644
index 0000000..69b2233
--- /dev/null
+++ b/ChatExample/app/src/main/res/values/colors.xml
@@ -0,0 +1,6 @@
+
+
+ #008577
+ #00574B
+ #D81B60
+
diff --git a/ChatExample/app/src/main/res/values/strings.xml b/ChatExample/app/src/main/res/values/strings.xml
new file mode 100644
index 0000000..dcd2a67
--- /dev/null
+++ b/ChatExample/app/src/main/res/values/strings.xml
@@ -0,0 +1,3 @@
+
+ ChatExample
+
diff --git a/ChatExample/app/src/main/res/values/styles.xml b/ChatExample/app/src/main/res/values/styles.xml
new file mode 100644
index 0000000..5885930
--- /dev/null
+++ b/ChatExample/app/src/main/res/values/styles.xml
@@ -0,0 +1,11 @@
+
+
+
+
+
+
diff --git a/ChatExample/app/src/main/res/xml/network_security_config.xml b/ChatExample/app/src/main/res/xml/network_security_config.xml
new file mode 100644
index 0000000..2c06eaa
--- /dev/null
+++ b/ChatExample/app/src/main/res/xml/network_security_config.xml
@@ -0,0 +1,6 @@
+
+
+
+ localhost
+
+
\ No newline at end of file
diff --git a/ChatExample/build.gradle b/ChatExample/build.gradle
new file mode 100644
index 0000000..5ff4b70
--- /dev/null
+++ b/ChatExample/build.gradle
@@ -0,0 +1,28 @@
+// Top-level build file where you can add configuration options common to all sub-projects/modules.
+
+buildscript {
+ ext.kotlin_version = '1.3.31'
+ repositories {
+ google()
+ mavenCentral()
+
+ }
+ dependencies {
+ classpath 'com.android.tools.build:gradle:3.4.1'
+ classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
+ // NOTE: Do not place your application dependencies here; they belong
+ // in the individual module build.gradle files
+ }
+}
+
+allprojects {
+ repositories {
+ google()
+ mavenCentral()
+
+ }
+}
+
+task clean(type: Delete) {
+ delete rootProject.buildDir
+}
diff --git a/ChatExample/gradle.properties b/ChatExample/gradle.properties
new file mode 100644
index 0000000..23339e0
--- /dev/null
+++ b/ChatExample/gradle.properties
@@ -0,0 +1,21 @@
+# Project-wide Gradle settings.
+# IDE (e.g. Android Studio) users:
+# Gradle settings configured through the IDE *will override*
+# any settings specified in this file.
+# For more details on how to configure your build environment visit
+# http://www.gradle.org/docs/current/userguide/build_environment.html
+# Specifies the JVM arguments used for the daemon process.
+# The setting is particularly useful for tweaking memory settings.
+org.gradle.jvmargs=-Xmx1536m
+# When configured, Gradle will run in incubating parallel mode.
+# This option should only be used with decoupled projects. More details, visit
+# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
+# org.gradle.parallel=true
+# AndroidX package structure to make it clearer which packages are bundled with the
+# Android operating system, and which are packaged with your app's APK
+# https://developer.android.com/topic/libraries/support-library/androidx-rn
+android.useAndroidX=true
+# Automatically convert third-party libraries to use AndroidX
+android.enableJetifier=true
+# Kotlin code style for this project: "official" or "obsolete":
+kotlin.code.style=official
diff --git a/ChatExample/gradle/wrapper/gradle-wrapper.jar b/ChatExample/gradle/wrapper/gradle-wrapper.jar
new file mode 100644
index 0000000..f6b961f
Binary files /dev/null and b/ChatExample/gradle/wrapper/gradle-wrapper.jar differ
diff --git a/ChatExample/gradle/wrapper/gradle-wrapper.properties b/ChatExample/gradle/wrapper/gradle-wrapper.properties
new file mode 100644
index 0000000..3c4d55d
--- /dev/null
+++ b/ChatExample/gradle/wrapper/gradle-wrapper.properties
@@ -0,0 +1,6 @@
+#Tue May 14 11:28:07 EDT 2019
+distributionBase=GRADLE_USER_HOME
+distributionPath=wrapper/dists
+zipStoreBase=GRADLE_USER_HOME
+zipStorePath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-5.1.1-all.zip
diff --git a/ChatExample/gradlew b/ChatExample/gradlew
new file mode 100755
index 0000000..cccdd3d
--- /dev/null
+++ b/ChatExample/gradlew
@@ -0,0 +1,172 @@
+#!/usr/bin/env sh
+
+##############################################################################
+##
+## Gradle start up script for UN*X
+##
+##############################################################################
+
+# Attempt to set APP_HOME
+# Resolve links: $0 may be a link
+PRG="$0"
+# Need this for relative symlinks.
+while [ -h "$PRG" ] ; do
+ ls=`ls -ld "$PRG"`
+ link=`expr "$ls" : '.*-> \(.*\)$'`
+ if expr "$link" : '/.*' > /dev/null; then
+ PRG="$link"
+ else
+ PRG=`dirname "$PRG"`"/$link"
+ fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS=""
+
+# Use the maximum available, or set MAX_FD != -1 to use that value.
+MAX_FD="maximum"
+
+warn () {
+ echo "$*"
+}
+
+die () {
+ echo
+ echo "$*"
+ echo
+ exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+nonstop=false
+case "`uname`" in
+ CYGWIN* )
+ cygwin=true
+ ;;
+ Darwin* )
+ darwin=true
+ ;;
+ MINGW* )
+ msys=true
+ ;;
+ NONSTOP* )
+ nonstop=true
+ ;;
+esac
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+ if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+ # IBM's JDK on AIX uses strange locations for the executables
+ JAVACMD="$JAVA_HOME/jre/sh/java"
+ else
+ JAVACMD="$JAVA_HOME/bin/java"
+ fi
+ if [ ! -x "$JAVACMD" ] ; then
+ die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+ fi
+else
+ JAVACMD="java"
+ which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Increase the maximum file descriptors if we can.
+if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
+ MAX_FD_LIMIT=`ulimit -H -n`
+ if [ $? -eq 0 ] ; then
+ if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+ MAX_FD="$MAX_FD_LIMIT"
+ fi
+ ulimit -n $MAX_FD
+ if [ $? -ne 0 ] ; then
+ warn "Could not set maximum file descriptor limit: $MAX_FD"
+ fi
+ else
+ warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+ fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+ GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin, switch paths to Windows format before running java
+if $cygwin ; then
+ APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+ CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+ JAVACMD=`cygpath --unix "$JAVACMD"`
+
+ # We build the pattern for arguments to be converted via cygpath
+ ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+ SEP=""
+ for dir in $ROOTDIRSRAW ; do
+ ROOTDIRS="$ROOTDIRS$SEP$dir"
+ SEP="|"
+ done
+ OURCYGPATTERN="(^($ROOTDIRS))"
+ # Add a user-defined pattern to the cygpath arguments
+ if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+ OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+ fi
+ # Now convert the arguments - kludge to limit ourselves to /bin/sh
+ i=0
+ for arg in "$@" ; do
+ CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+ CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
+
+ if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
+ eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+ else
+ eval `echo args$i`="\"$arg\""
+ fi
+ i=$((i+1))
+ done
+ case $i in
+ (0) set -- ;;
+ (1) set -- "$args0" ;;
+ (2) set -- "$args0" "$args1" ;;
+ (3) set -- "$args0" "$args1" "$args2" ;;
+ (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+ (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+ (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+ (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+ (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+ (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+ esac
+fi
+
+# Escape application args
+save () {
+ for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
+ echo " "
+}
+APP_ARGS=$(save "$@")
+
+# Collect all arguments for the java command, following the shell quoting and substitution rules
+eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
+
+# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
+if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
+ cd "$(dirname "$0")"
+fi
+
+exec "$JAVACMD" "$@"
diff --git a/ChatExample/gradlew.bat b/ChatExample/gradlew.bat
new file mode 100644
index 0000000..e95643d
--- /dev/null
+++ b/ChatExample/gradlew.bat
@@ -0,0 +1,84 @@
+@if "%DEBUG%" == "" @echo off
+@rem ##########################################################################
+@rem
+@rem Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+set DIRNAME=%~dp0
+if "%DIRNAME%" == "" set DIRNAME=.
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS=
+
+@rem Find java.exe
+if defined JAVA_HOME goto findJavaFromJavaHome
+
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if "%ERRORLEVEL%" == "0" goto init
+
+echo.
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto init
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:init
+@rem Get command-line arguments, handling Windows variants
+
+if not "%OS%" == "Windows_NT" goto win9xME_args
+
+:win9xME_args
+@rem Slurp the command line arguments.
+set CMD_LINE_ARGS=
+set _SKIP=2
+
+:win9xME_args_slurp
+if "x%~1" == "x" goto execute
+
+set CMD_LINE_ARGS=%*
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
+
+:end
+@rem End local scope for the variables with windows NT shell
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega
diff --git a/ChatExample/settings.gradle b/ChatExample/settings.gradle
new file mode 100644
index 0000000..e7b4def
--- /dev/null
+++ b/ChatExample/settings.gradle
@@ -0,0 +1 @@
+include ':app'
diff --git a/README.md b/README.md
index 61bcdf5..03d8d55 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,8 @@
# JavaPhoenixClient
-[  ](https://bintray.com/drees/java-phoenix-client/JavaPhoenixClient/_latestVersion)
+[](https://search.maven.org/search?q=g:%22com.github.dsrees%22%20AND%20a:%22JavaPhoenixClient%22)
[](https://travis-ci.com/dsrees/JavaPhoenixClient)
+[](https://codecov.io/gh/dsrees/JavaPhoenixClient)
JavaPhoenixClient is a Kotlin implementation of the [phoenix.js](https://hexdocs.pm/phoenix/js/) client used to manage Phoenix channels.
@@ -15,7 +16,7 @@ fun connectToChatRoom() {
// Create the Socket
val params = hashMapOf("token" to "abc123")
- val socket = PhxSocket("http://localhost:4000/socket/websocket", multipleParams)
+ val socket = Socket("http://localhost:4000/socket/websocket", params)
// Listen to events on the Socket
socket.logger = { Log.d("TAG", it) }
@@ -27,9 +28,9 @@ fun connectToChatRoom() {
// Join channels and listen to events
val chatroom = socket.channel("chatroom:general")
- chatroom.on("new_message") {
- // `it` is a PhxMessage object
- val payload = it.payload
+ chatroom.on("new_message") { message ->
+ val payload = message.payload
+ ...
}
chatroom.join()
@@ -38,40 +39,86 @@ fun connectToChatRoom() {
}
```
+
+If you need to provide dynamic parameters that can change between calls to `connect()`, then you can pass a closure to the constructor
+
+```kotlin
+
+// Create the Socket
+var authToken = "abc"
+val socket = Socket("http://localhost:4000/socket/websocket", { mapOf("token" to authToken) })
+
+// Connect with query parameters "?token=abc"
+socket.connect()
+
+
+// later in time, connect with query parameters "?token=xyz"
+authToken = "xyz"
+socket.connect() // or internal reconnect logic kicks in
+```
+
+
You can also inject your own OkHttp Client into the Socket to provide your own configuration
```kotlin
-// Create the Socket with a pre-configured OkHttp Client
+// Configure your own OkHttp Client
val client = OkHttpClient.Builder()
.connectTimeout(1000, TimeUnit.MILLISECONDS)
.build()
+// Create Socket with your custom instances
+val params = hashMapOf("token" to "abc123")
+val socket = Socket("http://localhost:4000/socket/websocket",
+ params,
+ client)
+```
+
+By default, the client use GSON to encode and decode JSON. If you prefer to manage this yourself, you
+can provide custom encode/decode functions in the `Socket` constructor.
+
+```kotlin
+
+// Configure your own GSON instance
+val gson = Gson.Builder().create()
+val encoder: EncodeClosure = {
+ // Encode a Map into JSON using your custom GSON instance or another JSON library
+ // of your choice (Moshi, etc)
+}
+val decoder: DecodeClosure = {
+ // Decode a JSON String into a `Message` object using your custom JSON library
+}
+
+// Create Socket with your custom instances
val params = hashMapOf("token" to "abc123")
-val socket = PhxSocket("http://localhost:4000/socket/websocket",
- multipleParams,
- client)
+val socket = Socket("http://localhost:4000/socket/websocket",
+ params,
+ encoder,
+ decoder)
```
+
+
+
### Installation
-JavaPhoenixClient is hosted on JCenter. You'll need to make sure you declare `jcenter()` as one of your repositories
+JavaPhoenixClient is hosted on MavenCentral. You'll need to make sure you declare `mavenCentral()` as one of your repositories
```
repositories {
- jcenter()
+ mavenCentral()
}
```
and then add the library. See [releases](https://github.com/dsrees/JavaPhoenixClient/releases) for the latest version
```$xslt
dependencies {
- implementation 'com.github.dsrees:JavaPhoenixClient:0.1.6'
+ implementation 'com.github.dsrees:JavaPhoenixClient:1.1.3'
}
```
### Feedback
-Please submit in issue if you have any problems!
+Please submit in issue if you have any problems or questions! PRs are also welcome.
This library is built to mirror the [phoenix.js](https://hexdocs.pm/phoenix/js/) and [SwiftPhoenixClient](https://github.com/davidstump/SwiftPhoenixClient) libraries.
diff --git a/RELEASING.md b/RELEASING.md
index 75a44b4..891816d 100644
--- a/RELEASING.md
+++ b/RELEASING.md
@@ -7,3 +7,6 @@ Release Process
4. Tag: `git tag -a X.Y.Z -m "Version X.Y.Z"`
5. Push: `git push && git push --tags`
6. Add the new release with notes (https://github.com/dsrees/JavaPhoenixClient/releases).
+ 7. Publish to Maven Central by running `./gradlew clean publish`. Can only be done by dsrees until CI setup
+ 8. Close the staging repo here: https://s01.oss.sonatype.org/#stagingRepositories
+ 9. Release the closed repo
diff --git a/build.gradle b/build.gradle
index e2907db..82d3fca 100644
--- a/build.gradle
+++ b/build.gradle
@@ -1,57 +1,75 @@
-buildscript { repositories { jcenter() } }
+buildscript {
+ repositories {
+ jcenter()
+ mavenCentral()
+ }
+
+ dependencies {
+ classpath 'com.vanniktech:gradle-maven-publish-plugin:0.14.2'
+ classpath 'org.jetbrains.dokka:dokka-gradle-plugin:1.4.30'
+ }
+}
+
plugins {
- id 'java'
- id 'org.jetbrains.kotlin.jvm' version '1.2.51'
- id 'nebula.project' version '4.0.1'
- id "nebula.maven-publish" version "7.2.4"
- id 'nebula.nebula-bintray' version '3.5.5'
+ id 'java'
+ id 'jacoco'
+ id 'org.jetbrains.kotlin.jvm' version '1.3.31'
}
+ext {
+ RELEASE_REPOSITORY_URL = "https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/"
+ SNAPSHOT_REPOSITORY_URL = "https://s01.oss.sonatype.org/content/repositories/snapshots/"
+}
+
+apply plugin: "org.jetbrains.dokka"
+apply plugin: "com.vanniktech.maven.publish"
+
group 'com.github.dsrees'
-version '0.1.6'
+version '1.1.3'
sourceCompatibility = 1.8
repositories {
- mavenCentral()
+ jcenter()
+ mavenCentral()
+}
+
+test {
+ // JUnit 5 Support
+ useJUnitPlatform()
+
+ // This allows us see tests execution progress in the output on the CI.
+ testLogging {
+ events 'passed', 'skipped', 'failed', 'standardOut', 'standardError'
+ exceptionFormat 'full'
+ }
+
}
dependencies {
- compile "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
- compile "com.google.code.gson:gson:2.8.5"
- compile "com.squareup.okhttp3:okhttp:3.10.0"
+ compile "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
+ compile "com.google.code.gson:gson:2.8.5"
+ compile "com.squareup.okhttp3:okhttp:3.12.2"
+ testImplementation 'org.junit.jupiter:junit-jupiter-api:5.3.1'
+ testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.3.1'
- testCompile group: 'junit', name: 'junit', version: '4.12'
- testCompile group: 'com.google.truth', name: 'truth', version: '0.42'
- testCompile group: 'org.mockito', name: 'mockito-core', version: '2.19.1'
+ testCompile group: 'com.google.truth', name: 'truth', version: '1.1.3'
+ testCompile group: 'org.mockito', name: 'mockito-core', version: '4.0.0'
+ testCompile group: 'org.mockito.kotlin', name: 'mockito-kotlin', version: '4.0.0'
+}
+jacocoTestReport {
+ reports {
+ xml.enabled true
+ html.enabled false
+ }
}
compileKotlin {
- kotlinOptions.jvmTarget = "1.8"
+ kotlinOptions.jvmTarget = "1.8"
}
compileTestKotlin {
- kotlinOptions.jvmTarget = "1.8"
-}
-
-bintray {
- user = System.getenv('bintrayUser')
- key = System.getenv('bintrayApiKey')
- dryRun = false
- publish = true
- pkg {
- repo = 'java-phoenix-client'
- name = 'JavaPhoenixClient'
- userOrg = user
- websiteUrl = 'https://github.com/dsrees/JavaPhoenixClient'
- issueTrackerUrl = 'https://github.com/dsrees/JavaPhoenixClient/issues'
- vcsUrl = 'https://github.com/dsrees/JavaPhoenixClient.git'
- licenses = ['MIT']
- version {
- name = project.version
- vcsTag = project.version
- }
- }
- publications = ['nebula']
+ kotlinOptions.jvmTarget = "1.8"
}
+
diff --git a/gradle.properties b/gradle.properties
new file mode 100644
index 0000000..65bd7af
--- /dev/null
+++ b/gradle.properties
@@ -0,0 +1,20 @@
+GROUP=com.github.dsrees
+POM_ARTIFACT_ID=JavaPhoenixClient
+VERSION_NAME=0.3.4
+
+POM_NAME=JavaPhoenixClient
+POM_DESCRIPTION=A phoenix channels client built for the JVM
+POM_INCEPTION_YEAR=2018
+
+POM_URL=https://github.com/dsrees/JavaPhoenixClient
+POM_SCM_URL=https://github.com/dsrees/JavaPhoenixClient.git
+POM_SCM_CONNECTION=scm:git:git://github.com/dsrees/JavaPhoenixClient.git
+POM_SCM_DEV_CONNECTION=scm:git:ssh://git@github.com/dsrees/JavaPhoenixClient.git
+
+POM_LICENCE_NAME=MIT License
+POM_LICENCE_URL=https://github.com/dsrees/JavaPhoenixClient/blob/master/LICENSE.md
+POM_LICENCE_DIST=repo
+
+POM_DEVELOPER_ID=dsrees
+POM_DEVELOPER_NAME=Daniel Rees
+POM_DEVELOPER_URL=https://github.com/dsrees/
\ No newline at end of file
diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties
index 347e866..d663d77 100644
--- a/gradle/wrapper/gradle-wrapper.properties
+++ b/gradle/wrapper/gradle-wrapper.properties
@@ -1,6 +1,6 @@
-#Fri Jul 13 09:35:17 EDT 2018
+#Thu Mar 18 20:34:38 EDT 2021
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-4.4-all.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.2-all.zip
diff --git a/script/deploy.sh b/script/deploy.sh
index 7acc031..8b5988e 100755
--- a/script/deploy.sh
+++ b/script/deploy.sh
@@ -5,5 +5,5 @@ if [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" != "" ]; then
./gradlew bintrayUpload --stacktrace
else
echo -e '#### Build for Test => Branch ['$TRAVIS_BRANCH'] Pull Request ['$TRAVIS_PULL_REQUEST'] ####'
- ./gradlew build
+ ./gradlew clean build jacocoTestReport
fi
\ No newline at end of file
diff --git a/src/main/kotlin/org/phoenixframework/Channel.kt b/src/main/kotlin/org/phoenixframework/Channel.kt
new file mode 100644
index 0000000..0794cfb
--- /dev/null
+++ b/src/main/kotlin/org/phoenixframework/Channel.kt
@@ -0,0 +1,437 @@
+/*
+ * Copyright (c) 2019 Daniel Rees
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+package org.phoenixframework
+
+import java.util.concurrent.ConcurrentLinkedQueue
+
+/**
+ * Represents a binding to a Channel event
+ */
+data class Binding(
+ val event: String,
+ val ref: Int,
+ val callback: (Message) -> Unit
+)
+
+/**
+ * Represents a Channel bound to a given topic
+ */
+class Channel(
+ val topic: String,
+ params: Payload,
+ internal val socket: Socket
+) {
+
+ //------------------------------------------------------------------------------
+ // Channel Nested Enums
+ //------------------------------------------------------------------------------
+ /** States of a Channel */
+ enum class State() {
+ CLOSED,
+ ERRORED,
+ JOINED,
+ JOINING,
+ LEAVING
+ }
+
+ /** Channel specific events */
+ enum class Event(val value: String) {
+ HEARTBEAT("heartbeat"),
+ JOIN("phx_join"),
+ LEAVE("phx_leave"),
+ REPLY("phx_reply"),
+ ERROR("phx_error"),
+ CLOSE("phx_close");
+
+ companion object {
+ /** True if the event is one of Phoenix's channel lifecycle events */
+ fun isLifecycleEvent(event: String): Boolean {
+ return when (event) {
+ JOIN.value,
+ LEAVE.value,
+ REPLY.value,
+ ERROR.value,
+ CLOSE.value -> true
+ else -> false
+ }
+ }
+ }
+ }
+
+ //------------------------------------------------------------------------------
+ // Channel Attributes
+ //------------------------------------------------------------------------------
+ /** Current state of the Channel */
+ internal var state: State
+
+ /** Collection of event bindings. */
+ internal val bindings: ConcurrentLinkedQueue
+
+ /** Tracks event binding ref counters */
+ internal var bindingRef: Int
+
+ /** Timeout when attempting to join a Channel */
+ internal var timeout: Long
+
+ /** Params passed in through constructions and provided to the JoinPush */
+ var params: Payload = params
+ set(value) {
+ joinPush.payload = value
+ field = value
+ }
+
+ /** Set to true once the channel has attempted to join */
+ internal var joinedOnce: Boolean
+
+ /** Push to send then attempting to join */
+ internal var joinPush: Push
+
+ /** Buffer of Pushes that will be sent once the Channel's socket connects */
+ internal var pushBuffer: MutableList
+
+ /** Timer to attempt rejoins */
+ internal var rejoinTimer: TimeoutTimer
+
+ /** Refs if stateChange hooks */
+ internal var stateChangeRefs: MutableList
+
+ /**
+ * Optional onMessage hook that can be provided. Receives all event messages for specialized
+ * handling before dispatching to the Channel event callbacks.
+ */
+ internal var onMessage: (Message) -> Message = { it }
+
+ init {
+ this.state = State.CLOSED
+ this.bindings = ConcurrentLinkedQueue()
+ this.bindingRef = 0
+ this.timeout = socket.timeout
+ this.joinedOnce = false
+ this.pushBuffer = mutableListOf()
+ this.stateChangeRefs = mutableListOf()
+ this.rejoinTimer = TimeoutTimer(
+ dispatchQueue = socket.dispatchQueue,
+ timerCalculation = socket.rejoinAfterMs,
+ callback = { if (socket.isConnected) rejoin() }
+ )
+
+ // Respond to socket events
+ this.socket.onError { _, _-> this.rejoinTimer.reset() }
+ .apply { stateChangeRefs.add(this) }
+ this.socket.onOpen {
+ this.rejoinTimer.reset()
+ if (this.isErrored) { this.rejoin() }
+ }.apply { stateChangeRefs.add(this) }
+
+
+ // Setup Push to be sent when joining
+ this.joinPush = Push(
+ channel = this,
+ event = Event.JOIN.value,
+ payload = params,
+ timeout = timeout)
+
+ // Perform once the Channel has joined
+ this.joinPush.receive("ok") {
+ // Mark the Channel as joined
+ this.state = State.JOINED
+
+ // Reset the timer, preventing it from attempting to join again
+ this.rejoinTimer.reset()
+
+ // Send any buffered messages and clear the buffer
+ this.pushBuffer.forEach { it.send() }
+ this.pushBuffer.clear()
+ }
+
+ // Perform if Channel errors while attempting to join
+ this.joinPush.receive("error") {
+ this.state = State.ERRORED
+ if (this.socket.isConnected) { this.rejoinTimer.scheduleTimeout() }
+ }
+
+ // Perform if Channel timed out while attempting to join
+ this.joinPush.receive("timeout") {
+ // Log the timeout
+ this.socket.logItems("Channel: timeouts $topic, $joinRef after $timeout ms")
+
+ // Send a Push to the server to leave the Channel
+ val leavePush = Push(
+ channel = this,
+ event = Event.LEAVE.value,
+ timeout = this.timeout)
+ leavePush.send()
+
+ // Mark the Channel as in an error and attempt to rejoin if socket is connected
+ this.state = State.ERRORED
+ this.joinPush.reset()
+
+ if (this.socket.isConnected) { this.rejoinTimer.scheduleTimeout() }
+ }
+
+ // Clean up when the channel closes
+ this.onClose {
+ // Reset any timer that may be on-going
+ this.rejoinTimer.reset()
+
+ // Log that the channel was left
+ this.socket.logItems("Channel: close $topic $joinRef")
+
+ // Mark the channel as closed and remove it from the socket
+ this.state = State.CLOSED
+ this.socket.remove(this)
+ }
+
+ // Handles an error, attempts to rejoin
+ this.onError {
+ // Log that the channel received an error
+ this.socket.logItems("Channel: error $topic ${it.payload}")
+
+ // If error was received while joining, then reset the Push
+ if (isJoining) {
+ // Make sure that the "phx_join" isn't buffered to send once the socket
+ // reconnects. The channel will send a new join event when the socket connects.
+ this.joinRef?.let { this.socket.removeFromSendBuffer(it) }
+
+ // Reset the push to be used again later
+ this.joinPush.reset()
+ }
+
+ // Mark the channel as errored and attempt to rejoin if socket is currently connected
+ this.state = State.ERRORED
+ if (socket.isConnected) { this.rejoinTimer.scheduleTimeout() }
+ }
+
+ // Perform when the join reply is received
+ this.on(Event.REPLY) { message ->
+ this.trigger(replyEventName(message.ref), message.rawPayload, message.ref, message.joinRef, message.payloadJson)
+ }
+ }
+
+ //------------------------------------------------------------------------------
+ // Public Properties
+ //------------------------------------------------------------------------------
+ /** The ref sent during the join message. */
+ val joinRef: String? get() = joinPush.ref
+
+ /** @return True if the Channel can push messages */
+ val canPush: Boolean
+ get() = this.socket.isConnected && this.isJoined
+
+ /** @return: True if the Channel has been closed */
+ val isClosed: Boolean
+ get() = state == State.CLOSED
+
+ /** @return: True if the Channel experienced an error */
+ val isErrored: Boolean
+ get() = state == State.ERRORED
+
+ /** @return: True if the channel has joined */
+ val isJoined: Boolean
+ get() = state == State.JOINED
+
+ /** @return: True if the channel has requested to join */
+ val isJoining: Boolean
+ get() = state == State.JOINING
+
+ /** @return: True if the channel has requested to leave */
+ val isLeaving: Boolean
+ get() = state == State.LEAVING
+
+ //------------------------------------------------------------------------------
+ // Public
+ //------------------------------------------------------------------------------
+ fun join(timeout: Long = this.timeout): Push {
+ // Ensure that `.join()` is called only once per Channel instance
+ if (joinedOnce) {
+ throw IllegalStateException(
+ "Tried to join channel multiple times. `join()` can only be called once per channel")
+ }
+
+ // Join the channel
+ this.timeout = timeout
+ this.joinedOnce = true
+ this.rejoin()
+ return joinPush
+ }
+
+ fun onClose(callback: (Message) -> Unit): Int {
+ return this.on(Event.CLOSE, callback)
+ }
+
+ fun onError(callback: (Message) -> Unit): Int {
+ return this.on(Event.ERROR, callback)
+ }
+
+ fun onMessage(callback: (Message) -> Message) {
+ this.onMessage = callback
+ }
+
+ fun on(event: Event, callback: (Message) -> Unit): Int {
+ return this.on(event.value, callback)
+ }
+
+ fun on(event: String, callback: (Message) -> Unit): Int {
+ val ref = bindingRef
+ this.bindingRef = ref + 1
+
+ this.bindings.add(Binding(event, ref, callback))
+ return ref
+ }
+
+ fun off(event: String, ref: Int? = null) {
+ this.bindings.removeAll { bind ->
+ bind.event == event && (ref == null || ref == bind.ref)
+ }
+ }
+
+ fun push(event: String, payload: Payload, timeout: Long = this.timeout): Push {
+ if (!joinedOnce) {
+ // If the Channel has not been joined, throw an exception
+ throw RuntimeException(
+ "Tried to push $event to $topic before joining. Use channel.join() before pushing events")
+ }
+
+ val pushEvent = Push(this, event, payload, timeout)
+
+ if (canPush) {
+ pushEvent.send()
+ } else {
+ pushEvent.startTimeout()
+ pushBuffer.add(pushEvent)
+ }
+
+ return pushEvent
+ }
+
+ fun leave(timeout: Long = this.timeout): Push {
+ // Can push is dependent upon state == JOINED. Once we set it to LEAVING, then canPush
+ // will return false, so instead store it _before_ starting the leave
+ val canPush = this.canPush
+
+ // If attempting a rejoin during a leave, then reset, cancelling the rejoin
+ this.rejoinTimer.reset()
+
+ // Prevent entering a rejoin loop if leaving a channel before joined
+ this.joinPush.cancelTimeout()
+
+ // Now set the state to leaving
+ this.state = State.LEAVING
+
+ // Perform the same behavior if the channel leaves successfully or not
+ val onClose: ((Message) -> Unit) = {
+ this.socket.logItems("Channel: leave $topic")
+ this.trigger(Event.CLOSE, mapOf("reason" to "leave"))
+ }
+
+ // Push event to send to the server
+ val leavePush = Push(
+ channel = this,
+ event = Event.LEAVE.value,
+ timeout = timeout)
+
+ leavePush
+ .receive("ok", onClose)
+ .receive("timeout", onClose)
+ leavePush.send()
+
+ // If the Channel cannot send push events, trigger a success locally
+ if (!canPush) leavePush.trigger("ok", hashMapOf())
+
+ return leavePush
+ }
+
+ //------------------------------------------------------------------------------
+ // Internal
+ //------------------------------------------------------------------------------
+ /** Checks if a Message's event belongs to this Channel instance */
+ internal fun isMember(message: Message): Boolean {
+ if (message.topic != this.topic) return false
+
+ val isLifecycleEvent = Event.isLifecycleEvent(message.event)
+
+ // If the message is a lifecycle event and it is not a join for this channel, drop the outdated message
+ if (message.joinRef != null && isLifecycleEvent && message.joinRef != this.joinRef) {
+ this.socket.logItems("Channel: Dropping outdated message. ${message.topic}")
+ return false
+ }
+
+ return true
+ }
+
+ internal fun trigger(
+ event: Event,
+ payload: Payload = hashMapOf(),
+ ref: String = "",
+ joinRef: String? = null,
+ payloadJson: String = ""
+ ) {
+ this.trigger(event.value, payload, ref, joinRef, payloadJson)
+ }
+
+ internal fun trigger(
+ event: String,
+ payload: Payload = hashMapOf(),
+ ref: String = "",
+ joinRef: String? = null,
+ payloadJson: String = ""
+ ) {
+ this.trigger(Message(joinRef, ref, topic, event, payload, payloadJson))
+ }
+
+ internal fun trigger(message: Message) {
+ // Inform the onMessage hook of the message
+ val handledMessage = this.onMessage(message)
+
+ // Inform all matching event bindings of the message
+ this.bindings
+ .filter { it.event == message.event }
+ .forEach { it.callback(handledMessage) }
+ }
+
+ /** Create an event with a given ref */
+ internal fun replyEventName(ref: String): String {
+ return "chan_reply_$ref"
+ }
+
+ //------------------------------------------------------------------------------
+ // Private
+ //------------------------------------------------------------------------------
+ /** Sends the Channel's joinPush to the Server */
+ private fun sendJoin(timeout: Long) {
+ this.state = State.JOINING
+ this.joinPush.resend(timeout)
+ }
+
+ /** Rejoins the Channel e.g. after a disconnect */
+ private fun rejoin(timeout: Long = this.timeout) {
+ // Do not attempt to rejoin if the channel is in the process of leaving
+ if (isLeaving) return
+
+ // Leave potentially duplicated channels
+ this.socket.leaveOpenTopic(this.topic)
+
+ // Send the joinPush
+ this.sendJoin(timeout)
+ }
+}
\ No newline at end of file
diff --git a/src/main/kotlin/org/phoenixframework/Defaults.kt b/src/main/kotlin/org/phoenixframework/Defaults.kt
new file mode 100644
index 0000000..e1b1aa8
--- /dev/null
+++ b/src/main/kotlin/org/phoenixframework/Defaults.kt
@@ -0,0 +1,167 @@
+/*
+ * Copyright (c) 2019 Daniel Rees
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+package org.phoenixframework
+
+import com.google.gson.FieldNamingPolicy
+import com.google.gson.Gson
+import com.google.gson.GsonBuilder
+import com.google.gson.JsonObject
+import com.google.gson.JsonParser
+import com.google.gson.reflect.TypeToken
+import okhttp3.HttpUrl
+import org.phoenixframework.Defaults.gson
+import java.net.URL
+import javax.swing.text.html.HTML.Tag.P
+
+object Defaults {
+
+ /** Default timeout of 10s */
+ const val TIMEOUT: Long = 10_000
+
+ /** Default heartbeat interval of 30s */
+ const val HEARTBEAT: Long = 30_000
+
+ /** Default JSON Serializer Version set to 2.0.0 */
+ const val VSN: String = "2.0.0"
+
+ /** Default reconnect algorithm for the socket */
+ val reconnectSteppedBackOff: (Int) -> Long = { tries ->
+ if (tries > 9) 5_000 else listOf(
+ 10L, 50L, 100L, 150L, 200L, 250L, 500L, 1_000L, 2_000L
+ )[tries - 1]
+ }
+
+ /** Default rejoin algorithm for individual channels */
+ val rejoinSteppedBackOff: (Int) -> Long = { tries ->
+ if (tries > 3) 10_000 else listOf(1_000L, 2_000L, 5_000L)[tries - 1]
+ }
+
+ /** The default Gson configuration to use when parsing messages */
+ val gson: Gson
+ get() = GsonBuilder()
+ .setLenient()
+ .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
+ .create()
+
+ /**
+ * Default JSON decoder, backed by GSON, that takes JSON and converts it
+ * into a Message object.
+ */
+ @Suppress("UNCHECKED_CAST")
+ val decode: DecodeClosure = { rawMessage ->
+
+ val parseValue: (String) -> String? = { value ->
+ when(value) {
+ "null" -> null
+ else -> value.replace("\"", "")
+ }
+ }
+
+ var message = rawMessage
+ message = message.removeRange(0, 1) // remove '['
+
+ val joinRef = message.takeWhile { it != ',' } // take "join ref", "null" or "\"5\""
+ message = message.removeRange(0, joinRef.length) // remove join ref
+ message = message.removeRange(0, 1) // remove ','
+
+ val ref = message.takeWhile { it != ',' } // take ref, "null" or "\"5\""
+ message = message.removeRange(0, ref.length) // remove ref
+ message = message.removeRange(0, 1) // remove ','
+
+ val topic = message.takeWhile { it != ',' } // take topic, "\"topic\""
+ message = message.removeRange(0, topic.length)
+ message = message.removeRange(0, 1) // remove ','
+
+ val event = message.takeWhile { it != ',' } // take event, "\"phx_reply\""
+ message = message.removeRange(0, event.length)
+ message = message.removeRange(0, 1) // remove ','
+
+ var remaining = message.removeRange(message.length - 1, message.length) // remove ']'
+
+ // Payload should now just be "{"message":"hello","from":"user_1"}" or
+ // "{"response": {"message":"hello","from":"user_1"}},"status":"ok"}", flatten.
+ val jsonObj = gson.fromJson(remaining, JsonObject::class.java)
+ val response = jsonObj.get("response")
+ val payload = response?.let { gson.toJson(response) } ?: remaining
+
+ val anyType = object : TypeToken>() {}.type
+ val result = gson.fromJson>(remaining, anyType)
+
+ // vsn=2.0.0 message structure
+ // [join_ref, ref, topic, event, payload]
+ Message(
+ joinRef = parseValue(joinRef),
+ ref = parseValue(ref) ?: "",
+ topic = parseValue(topic) ?: "",
+ event = parseValue(event) ?: "",
+ rawPayload = result,
+ payloadJson = payload
+ )
+ }
+
+ /**
+ * Default JSON encoder, backed by GSON, that takes a Map and
+ * converts it into a JSON String.
+ */
+ val encode: EncodeClosure = { payload ->
+ gson.toJson(payload)
+ }
+
+ /**
+ * Takes an endpoint and a params closure given by the User and constructs a URL that
+ * is ready to be sent to the Socket connection.
+ *
+ * Will convert "ws://" and "wss://" to http/s which is what OkHttp expects.
+ *
+ * @throws IllegalArgumentException if [endpoint] is not a valid URL endpoint.
+ */
+ internal fun buildEndpointUrl(
+ endpoint: String,
+ paramsClosure: PayloadClosure,
+ vsn: String
+ ): URL {
+ var mutableUrl = endpoint
+ // Silently replace web socket URLs with HTTP URLs.
+ if (endpoint.regionMatches(0, "ws:", 0, 3, ignoreCase = true)) {
+ mutableUrl = "http:" + endpoint.substring(3)
+ } else if (endpoint.regionMatches(0, "wss:", 0, 4, ignoreCase = true)) {
+ mutableUrl = "https:" + endpoint.substring(4)
+ }
+
+ // Add the VSN query parameter
+ var httpUrl = HttpUrl.parse(mutableUrl)
+ ?: throw IllegalArgumentException("invalid url: $endpoint")
+ val httpBuilder = httpUrl.newBuilder()
+ httpBuilder.addQueryParameter("vsn", vsn)
+
+ // Append any additional query params
+ paramsClosure.invoke()?.let {
+ it.forEach { (key, value) ->
+ httpBuilder.addQueryParameter(key, value.toString())
+ }
+ }
+
+ // Return the [URL] that will be used to establish a connection
+ return httpBuilder.build().url()
+ }
+}
\ No newline at end of file
diff --git a/src/main/kotlin/org/phoenixframework/DispatchQueue.kt b/src/main/kotlin/org/phoenixframework/DispatchQueue.kt
new file mode 100644
index 0000000..bd996f4
--- /dev/null
+++ b/src/main/kotlin/org/phoenixframework/DispatchQueue.kt
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2019 Daniel Rees
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+package org.phoenixframework
+
+import java.util.concurrent.ScheduledFuture
+import java.util.concurrent.ScheduledThreadPoolExecutor
+import java.util.concurrent.TimeUnit
+
+//------------------------------------------------------------------------------
+// Dispatch Queue Interfaces
+//------------------------------------------------------------------------------
+/**
+ * Interface which abstracts away scheduling future tasks, allowing fake instances
+ * to be injected and manipulated during tests
+ */
+interface DispatchQueue {
+ /** Queue a Runnable to be executed after a given time unit delay */
+ fun queue(delay: Long, unit: TimeUnit, runnable: () -> Unit): DispatchWorkItem
+
+ /**
+ * Creates and executes a periodic action that becomes enabled first after the given initial
+ * delay, and subsequently with the given period; that is, executions will commence after
+ * initialDelay, then initialDelay + period, then initialDelay + 2 * period, and so on.
+ */
+ fun queueAtFixedRate(delay: Long, period: Long, unit: TimeUnit, runnable: () -> Unit): DispatchWorkItem
+}
+
+/** Abstracts away a future task */
+interface DispatchWorkItem {
+ /** True if the work item has been cancelled */
+ val isCancelled: Boolean
+
+ /** Cancels the item from executing */
+ fun cancel()
+}
+
+//------------------------------------------------------------------------------
+// Scheduled Dispatch Queue
+//------------------------------------------------------------------------------
+/**
+ * A DispatchQueue that uses a ScheduledThreadPoolExecutor to schedule tasks to be executed
+ * in the future.
+ *
+ * Uses a default pool size of 8. Custom values can be provided during construction
+ */
+class ScheduledDispatchQueue(poolSize: Int = 8) : DispatchQueue {
+
+ private var scheduledThreadPoolExecutor = ScheduledThreadPoolExecutor(poolSize)
+
+ override fun queue(delay: Long, unit: TimeUnit, runnable: () -> Unit): DispatchWorkItem {
+ val scheduledFuture = scheduledThreadPoolExecutor.schedule(runnable, delay, unit)
+ return ScheduledDispatchWorkItem(scheduledFuture)
+ }
+
+ override fun queueAtFixedRate(
+ delay: Long,
+ period: Long,
+ unit: TimeUnit,
+ runnable: () -> Unit
+ ): DispatchWorkItem {
+ val scheduledFuture = scheduledThreadPoolExecutor.scheduleAtFixedRate(runnable, delay, period, unit)
+ return ScheduledDispatchWorkItem(scheduledFuture)
+ }
+}
+
+/**
+ * A DispatchWorkItem that wraps a ScheduledFuture<*> created by a ScheduledDispatchQueue
+ */
+class ScheduledDispatchWorkItem(private val scheduledFuture: ScheduledFuture<*>) : DispatchWorkItem {
+
+ override val isCancelled: Boolean
+ get() = this.scheduledFuture.isCancelled
+
+ override fun cancel() {
+ this.scheduledFuture.cancel(true)
+ }
+}
diff --git a/src/main/kotlin/org/phoenixframework/Message.kt b/src/main/kotlin/org/phoenixframework/Message.kt
new file mode 100644
index 0000000..ed177b0
--- /dev/null
+++ b/src/main/kotlin/org/phoenixframework/Message.kt
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2019 Daniel Rees
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+package org.phoenixframework
+
+
+data class Message(
+ /** The ref sent during a join event. Empty if not present. */
+ val joinRef: String? = null,
+
+ /** The unique string ref. Empty if not present */
+ val ref: String = "",
+
+ /** The message topic */
+ val topic: String = "",
+
+ /** The message event name, for example "phx_join" or any other custom name */
+ val event: String = "",
+
+ /** The raw payload of the message. It is recommended that you use `payload` instead. */
+ internal val rawPayload: Payload = HashMap(),
+
+ /** The payload, as a json string */
+ val payloadJson: String = ""
+) {
+
+ /** The payload of the message */
+ @Suppress("UNCHECKED_CAST")
+ val payload: Payload
+ get() = rawPayload["response"] as? Payload ?: rawPayload
+
+ /**
+ * Convenience var to access the message's payload's status. Equivalent
+ * to checking message.payload["status"] yourself
+ */
+ val status: String?
+ get() = rawPayload["status"] as? String
+}
diff --git a/src/main/kotlin/org/phoenixframework/PhxChannel.kt b/src/main/kotlin/org/phoenixframework/PhxChannel.kt
deleted file mode 100644
index 118ff24..0000000
--- a/src/main/kotlin/org/phoenixframework/PhxChannel.kt
+++ /dev/null
@@ -1,383 +0,0 @@
-package org.phoenixframework
-
-import java.lang.IllegalStateException
-import java.util.concurrent.ConcurrentHashMap
-import java.util.concurrent.ConcurrentLinkedQueue
-
-
-class PhxChannel(
- val topic: String,
- val params: Payload,
- val socket: PhxSocket
-) {
-
-
- /** Enumeration of the different states a channel can exist in */
- enum class PhxState(value: String) {
- CLOSED("closed"),
- ERRORED("errored"),
- JOINED("joined"),
- JOINING("joining"),
- LEAVING("leaving")
- }
-
- /** Enumeration of a variety of Channel specific events */
- enum class PhxEvent(val value: String) {
- HEARTBEAT("heartbeat"),
- JOIN("phx_join"),
- LEAVE("phx_leave"),
- REPLY("phx_reply"),
- ERROR("phx_error"),
- CLOSE("phx_close");
-
- companion object {
- fun isLifecycleEvent(event: String): Boolean {
- return when (event) {
- JOIN.value,
- LEAVE.value,
- REPLY.value,
- ERROR.value,
- CLOSE.value -> true
- else -> false
- }
- }
- }
- }
-
-
- var state: PhxChannel.PhxState
- val bindings: ConcurrentHashMap Unit>>>
- var bindingRef: Int
- var timeout: Long
- var joinedOnce: Boolean
- var joinPush: PhxPush
- var pushBuffer: MutableList
- var rejoinTimer: PhxTimer? = null
- var onMessage: (message: PhxMessage) -> PhxMessage = onMessage@{
- return@onMessage it
- }
-
-
- init {
- this.state = PhxChannel.PhxState.CLOSED
- this.bindings = ConcurrentHashMap()
- this.bindingRef = 0
- this.timeout = socket.timeout
- this.joinedOnce = false
- this.pushBuffer = ArrayList()
- this.joinPush = PhxPush(this,
- PhxEvent.JOIN.value, params, timeout)
-
- // Create the Rejoin Timer that will be used in rejoin attempts
- this.rejoinTimer = PhxTimer({
- this.rejoinTimer?.scheduleTimeout()
- if (this.socket.isConnected) {
- rejoin()
- }
- }, socket.reconnectAfterMs)
-
- // Perform once the Channel is joined
- this.joinPush.receive("ok") {
- this.state = PhxState.JOINED
- this.rejoinTimer?.reset()
- this.pushBuffer.forEach { it.send() }
- this.pushBuffer = ArrayList()
- }
-
- // Perform if the Push to join timed out
- this.joinPush.receive("timeout") {
- // Do not handle timeout if joining. Handled differently
- if (this.isJoining) {
- return@receive
- }
- this.socket.logItems("Channel: timeouts $topic, $joinRef after $timeout ms")
-
- val leavePush = PhxPush(this, PhxEvent.LEAVE.value, HashMap(), timeout)
- leavePush.send()
-
- this.state = PhxState.ERRORED
- this.joinPush.reset()
- this.rejoinTimer?.scheduleTimeout()
- }
-
- // Clean up when the channel closes
- this.onClose {
- this.rejoinTimer?.reset()
- this.socket.logItems("Channel: close $topic")
- this.state = PhxState.CLOSED
- this.socket.remove(this)
- }
-
- // Handles an error, attempts to rejoin
- this.onError {
- if (this.isLeaving || !this.isClosed) {
- this.socket.logItems("Channel: error $topic")
- this.state = PhxState.ERRORED
- this.rejoinTimer?.scheduleTimeout()
- }
- }
-
- // Handles when a reply from the server comes back
- this.on(PhxEvent.REPLY) {
- val replyEventName = this.replyEventName(it.ref)
- val replyMessage = PhxMessage(it.ref, it.topic, replyEventName, it.payload, it.joinRef)
- this.trigger(replyMessage)
- }
- }
-
-
- //------------------------------------------------------------------------------
- // Public
- //------------------------------------------------------------------------------
- /**
- * Joins the channel
- *
- * @param joinParams: Overrides the params given when channel was initialized
- * @param timeout: Overrides the default timeout
- * @return Push which receive hooks can be applied to
- */
- fun join(joinParams: Payload? = null, timeout: Long? = null): PhxPush {
- if (joinedOnce) {
- throw IllegalStateException("Tried to join channel multiple times. `join()` can only be called once per channel")
- }
-
- joinParams?.let {
- this.joinPush.updatePayload(joinParams)
- }
-
- this.joinedOnce = true
- this.rejoin(timeout)
- return joinPush
- }
-
- /**
- * Hook into channel close
- *
- * @param callback: Callback to be informed when the channel closes
- * @return the ref counter of the subscription
- */
- fun onClose(callback: (msg: PhxMessage) -> Unit): Int {
- return this.on(PhxEvent.CLOSE, callback)
- }
-
- /**
- * Hook into channel error
- *
- * @param callback: Callback to be informed when the channel errors
- * @return the ref counter of the subscription
- */
- fun onError(callback: (msg: PhxMessage) -> Unit): Int {
- return this.on(PhxEvent.ERROR, callback)
- }
-
- /**
- * Convenience method to take the Channel.Event enum. Same as channel.on(string)
- */
- fun on(event: PhxChannel.PhxEvent, callback: (PhxMessage) -> Unit): Int {
- return this.on(event.value, callback)
- }
-
- /**
- * Subscribes on channel events
- *
- * Subscription returns the ref counter which can be used later to
- * unsubscribe the exact event listener
- *
- * Example:
- * val ref1 = channel.on("event", do_stuff)
- * val ref2 = channel.on("event", do_other_stuff)
- * channel.off("event", ref1)
- *
- * This example will unsubscribe the "do_stuff" callback but not
- * the "do_other_stuff" callback.
- *
- * @param event: Name of the event to subscribe to
- * @param callback: Receives payload of the event
- * @return: The subscriptions ref counter
- */
- fun on(event: String, callback: (PhxMessage) -> Unit): Int {
- val ref = bindingRef
- this.bindingRef = ref + 1
-
- this.bindings.getOrPut(event) { ConcurrentLinkedQueue() }
- .add(ref to callback)
-
- return ref
- }
-
- /**
- * Unsubscribe from channel events. If ref counter is not provided, then
- * all subscriptions for the event will be removed.
- *
- * Example:
- * val ref1 = channel.on("event", do_stuff)
- * val ref2 = channel.on("event", do_other_stuff)
- * channel.off("event", ref1)
- *
- * This example will unsubscribe the "do_stuff" callback but not
- * the "do_other_stuff" callback.
- *
- * @param event: Event to unsubscribe from
- * @param ref: Optional. Ref counter returned when subscribed to event
- */
- fun off(event: String, ref: Int? = null) {
- // Remove any subscriptions that match the given event and ref ID. If no ref
- // ID is given, then remove all subscriptions for an event.
- if (ref != null) {
- this.bindings[event]?.removeIf{ ref == it.first }
- } else {
- this.bindings.remove(event)
- }
- }
-
- /**
- * Push a payload to the Channel
- *
- * @param event: Event to push
- * @param payload: Payload to push
- * @param timeout: Optional timeout. Default will be used
- * @return [PhxPush] that can be hooked into
- */
- fun push(event: String, payload: Payload, timeout: Long = DEFAULT_TIMEOUT): PhxPush {
- if (!joinedOnce) {
- // If the Channel has not been joined, throw an exception
- throw RuntimeException("Tried to push $event to $topic before joining. Use channel.join() before pushing events")
- }
-
- val pushEvent = PhxPush(this, event, payload, timeout)
- if (canPush) {
- pushEvent.send()
- } else {
- pushEvent.startTimeout()
- pushBuffer.add(pushEvent)
- }
-
- return pushEvent
- }
-
- /**
- * Leaves a channel
- *
- * Unsubscribe from server events and instructs Channel to terminate on Server
- *
- * Triggers .onClose() hooks
- *
- * To receive leave acknowledgements, use the receive hook to bind to the server ack
- *
- * Example:
- * channel.leave().receive("ok) { print("left channel") }
- *
- * @param timeout: Optional timeout. Default will be used
- */
- fun leave(timeout: Long = DEFAULT_TIMEOUT): PhxPush {
- this.state = PhxState.LEAVING
-
- val onClose: ((PhxMessage) -> Unit) = {
- this.socket.logItems("Channel: leave $topic")
- this.trigger(it)
- }
-
- val leavePush = PhxPush(this, PhxEvent.LEAVE.value, HashMap(), timeout)
- leavePush
- .receive("ok", onClose)
- .receive("timeout", onClose)
-
- leavePush.send()
- if (!canPush) {
- leavePush.trigger("ok", HashMap())
- }
-
- return leavePush
- }
-
- /**
- * Override message hook. Receives all events for specialized message
- * handling before dispatching to the channel callbacks
- *
- * @param callback: Callback which will receive the inbound message before
- * it is dispatched to other callbacks. Must return a Message object.
- */
- fun onMessage(callback: (message: PhxMessage) -> PhxMessage) {
- this.onMessage = callback
- }
-
-
- //------------------------------------------------------------------------------
- // Internal
- //------------------------------------------------------------------------------
- /** Checks if an event received by the socket belongs to the Channel */
- fun isMember(message: PhxMessage): Boolean {
- if (message.topic != this.topic) { return false }
-
- val isLifecycleEvent = PhxEvent.isLifecycleEvent(message.event)
-
- // If the message is a lifecycle event and it is not a join for this channel, drop the outdated message
- if (message.joinRef != null && isLifecycleEvent && message.joinRef != this.joinRef) {
- this.socket.logItems("Channel: Dropping outdated message. ${message.topic}")
- return false
- }
-
- return true
- }
-
- /** Sends the payload to join the Channel */
- fun sendJoin(timeout: Long) {
- this.state = PhxState.JOINING
- this.joinPush.resend(timeout)
-
- }
-
- /** Rejoins the Channel */
- fun rejoin(timeout: Long? = null) {
- this.sendJoin(timeout ?: this.timeout)
- }
-
- /**
- * Triggers an event to the correct event binding created by `channel.on("event")
- *
- * @param message: Message that was received that will be sent to the correct binding
- */
- fun trigger(message: PhxMessage) {
- val handledMessage = onMessage(message)
- this.bindings[message.event]?.forEach { it.second(handledMessage) }
- }
-
- /**
- * @param ref: The ref of the reply push event
- * @return the name of the event
- */
- fun replyEventName(ref: String): String {
- return "chan_reply_$ref"
- }
-
- /** The ref sent during the join message. */
- val joinRef: String
- get() = joinPush.ref ?: ""
-
- /**
- * @return True if the Channel can push messages, meaning the socket
- * is connected and the channel is joined
- */
- val canPush: Boolean
- get() = this.socket.isConnected && this.isJoined
-
- /** @return: True if the Channel has been closed */
- val isClosed: Boolean
- get() = state == PhxState.CLOSED
-
- /** @return: True if the Channel experienced an error */
- val isErrored: Boolean
- get() = state == PhxState.ERRORED
-
- /** @return: True if the channel has joined */
- val isJoined: Boolean
- get() = state == PhxState.JOINED
-
- /** @return: True if the channel has requested to join */
- val isJoining: Boolean
- get() = state == PhxState.JOINING
-
- /** @return: True if the channel has requested to leave */
- val isLeaving: Boolean
- get() = state == PhxState.LEAVING
-}
diff --git a/src/main/kotlin/org/phoenixframework/PhxMessage.kt b/src/main/kotlin/org/phoenixframework/PhxMessage.kt
deleted file mode 100644
index 8150472..0000000
--- a/src/main/kotlin/org/phoenixframework/PhxMessage.kt
+++ /dev/null
@@ -1,33 +0,0 @@
-package org.phoenixframework
-
-import com.google.gson.annotations.SerializedName
-
-data class PhxMessage(
- /** The unique string ref. Empty if not present */
- @SerializedName("ref")
- val ref: String = "",
-
- /** The message topic */
- @SerializedName("topic")
- val topic: String = "",
-
- /** The message event name, for example "phx_join" or any other custom name */
- @SerializedName("event")
- val event: String = "",
-
- /** The payload of the message */
- @SerializedName("payload")
- val payload: Payload = HashMap(),
-
- /** The ref sent during a join event. Empty if not present. */
- @SerializedName("join_ref")
- val joinRef: String? = null) {
-
-
- /**
- * Convenience var to access the message's payload's status. Equivalent
- * to checking message.payload["status"] yourself
- */
- val status: String?
- get() = payload["status"] as? String
-}
diff --git a/src/main/kotlin/org/phoenixframework/PhxPush.kt b/src/main/kotlin/org/phoenixframework/PhxPush.kt
deleted file mode 100644
index 5666119..0000000
--- a/src/main/kotlin/org/phoenixframework/PhxPush.kt
+++ /dev/null
@@ -1,192 +0,0 @@
-package org.phoenixframework
-
-import java.util.*
-import kotlin.collections.HashMap
-import kotlin.concurrent.schedule
-
-class PhxPush(
- val channel: PhxChannel,
- val event: String,
- var payload: Payload,
- var timeout: Long
-) {
-
- /** The server's response to the Push */
- var receivedMessage: PhxMessage? = null
-
- /** Timer which triggers a timeout event */
- var timeoutTimer: Timer? = null
-
- /** Hooks into a Push. Where .receive("ok", callback(Payload)) are stored */
- var receiveHooks: MutableMap Unit)>> = HashMap()
-
- /** True if the Push has been sent */
- var sent: Boolean = false
-
- /** The reference ID of the Push */
- var ref: String? = null
-
- /** The event that is associated with the reference ID of the Push */
- var refEvent: String? = null
-
-
- //------------------------------------------------------------------------------
- // Public
- //------------------------------------------------------------------------------
- /** Resend a Push */
- fun resend(timeout: Long = DEFAULT_TIMEOUT) {
- this.timeout = timeout
- this.reset()
- this.send()
- }
-
- /**
- * Receive a specific event when sending an Outbound message
- *
- * Example:
- * channel
- * .send("event", myPayload)
- * .receive("error") { }
- */
- fun receive(status: String, callback: (message: PhxMessage) -> Unit): PhxPush {
- // If the message has already be received, pass it to the callback
- receivedMessage?.let {
- if (hasReceivedStatus(status)) {
- callback(it)
- }
- }
-
- // Create a new array of hooks if no previous hook is associated with status
- if (receiveHooks[status] == null) {
- receiveHooks[status] = arrayListOf(callback)
- } else {
- // A previous hook for this status already exists. Just append the new hook
- receiveHooks[status]?.add(callback)
- }
-
- return this
- }
-
-
- /**
- * @param payload: New payload to be sent through with the Push
- */
- fun updatePayload(payload: Payload) {
- this.payload = payload
- }
-
- //------------------------------------------------------------------------------
- // Internal
- //------------------------------------------------------------------------------
- /**
- * Sends the Push through the socket
- */
- fun send() {
- if (hasReceivedStatus("timeout")) {
- return
- }
-
- this.startTimeout()
- this.sent = true
-
- this.channel.socket.push(
- this.channel.topic,
- this.event,
- this.payload,
- this.ref,
- this.channel.joinRef)
- }
-
- /** Resets the Push as it was after initialization */
- fun reset() {
- this.cancelRefEvent()
- this.ref = null
- this.refEvent = null
- this.receivedMessage = null
- this.sent = false
- }
-
- /**
- * Finds the receiveHook which needs to be informed of a status response
- *
- * @param status: Status to find the hook for
- * @param message: Message to send to the matched hook
- */
- fun matchReceive(status: String, message: PhxMessage) {
- receiveHooks[status]?.forEach { it(message) }
- }
-
- /**
- * Reverses the result of channel.on(event, callback) that spawned the Push
- */
- fun cancelRefEvent() {
- this.refEvent?.let {
- this.channel.off(it)
- }
- }
-
- /**
- * Cancels any ongoing Timeout timer
- */
- fun cancelTimeout() {
- this.timeoutTimer?.cancel()
- this.timeoutTimer = null
- }
-
- /**
- * Starts the Timer which will trigger a timeout after a specific delay
- * in milliseconds is reached.
- */
- fun startTimeout() {
- this.timeoutTimer?.cancel()
-
- val ref = this.channel.socket.makeRef()
- this.ref = ref
-
- val refEvent = this.channel.replyEventName(ref)
- this.refEvent = refEvent
-
- // If a response is received before the Timer triggers, cancel timer
- // and match the received event to it's corresponding hook.
- this.channel.on(refEvent) {
- this.cancelRefEvent()
- this.cancelTimeout()
- this.receivedMessage = it
-
- // Check if there is an event status available
- val message = it
- message.status?.let {
- this.matchReceive(it, message)
- }
- }
-
- // Start the timer. If the timer fires, then send a timeout event to the Push
- this.timeoutTimer = Timer()
- this.timeoutTimer?.schedule(timeout) {
- trigger("timeout", HashMap())
- }
- }
-
- /**
- * Checks if a status has already been received by the Push.
- *
- * @param status: Status to check
- * @return True if the Push has received the given status. False otherwise
- */
- fun hasReceivedStatus(status: String): Boolean {
- return receivedMessage?.status == status
- }
-
- /**
- * Triggers an event to be sent through the Channel
- */
- fun trigger(status: String, payload: Payload) {
- val mutPayload = payload.toMutableMap()
- mutPayload["status"] = status
-
- refEvent?.let {
- val message = PhxMessage(it, "", "", mutPayload)
- this.channel.trigger(message)
- }
- }
-}
diff --git a/src/main/kotlin/org/phoenixframework/PhxSocket.kt b/src/main/kotlin/org/phoenixframework/PhxSocket.kt
deleted file mode 100644
index 119f81b..0000000
--- a/src/main/kotlin/org/phoenixframework/PhxSocket.kt
+++ /dev/null
@@ -1,460 +0,0 @@
-package org.phoenixframework
-
-import com.google.gson.FieldNamingPolicy
-import com.google.gson.Gson
-import com.google.gson.GsonBuilder
-import okhttp3.HttpUrl
-import okhttp3.OkHttpClient
-import okhttp3.Request
-import okhttp3.Response
-import okhttp3.WebSocket
-import okhttp3.WebSocketListener
-import java.net.URL
-import java.util.Timer
-import kotlin.collections.ArrayList
-import kotlin.collections.HashMap
-import kotlin.concurrent.schedule
-
-typealias Payload = Map
-
-/** Default timeout set to 10s */
-const val DEFAULT_TIMEOUT: Long = 10000
-
-/** Default heartbeat interval set to 30s */
-const val DEFAULT_HEARTBEAT: Long = 30000
-
-/** The code used when the socket was closed without error */
-const val WS_CLOSE_NORMAL = 1000
-
-/** The code used when the socket was closed after the heartbeat timer timed out */
-const val WS_CLOSE_HEARTBEAT_ERROR = 5000
-
-open class PhxSocket(
- url: String,
- params: Payload? = null,
- private val client: OkHttpClient = OkHttpClient.Builder().build()
-) : WebSocketListener() {
-
- //------------------------------------------------------------------------------
- // Public Attributes
- //------------------------------------------------------------------------------
- /** Timeout to use when opening connections */
- var timeout: Long = DEFAULT_TIMEOUT
-
- /** Interval between sending a heartbeat */
- var heartbeatIntervalMs: Long = DEFAULT_HEARTBEAT
-
- /** Interval between socket reconnect attempts */
- var reconnectAfterMs: ((tries: Int) -> Long) = closure@{
- return@closure if (it >= 3) 100000 else longArrayOf(1000, 2000, 5000)[it]
- }
-
- /** Hook for custom logging into the client */
- var logger: ((msg: String) -> Unit)? = null
-
- /** Disable sending Heartbeats by setting to true */
- var skipHeartbeat: Boolean = false
-
- /**
- * Socket will attempt to reconnect if the Socket was closed. Will not
- * reconnect if the Socket errored (e.g. connection refused.) Default
- * is set to true
- */
- var autoReconnect: Boolean = true
-
-
- //------------------------------------------------------------------------------
- // Private Attributes
- //------------------------------------------------------------------------------
- /// Collection of callbacks for onOpen socket events
- private var onOpenCallbacks: MutableList<() -> Unit> = ArrayList()
-
- /// Collection of callbacks for onClose socket events
- private var onCloseCallbacks: MutableList<() -> Unit> = ArrayList()
-
- /// Collection of callbacks for onError socket events
- private var onErrorCallbacks: MutableList<(Throwable, Response?) -> Unit> = ArrayList()
-
- /// Collection of callbacks for onMessage socket events
- private var onMessageCallbacks: MutableList<(PhxMessage) -> Unit> = ArrayList()
-
- /// Collection on channels created for the Socket
- private var channels: MutableList = ArrayList()
-
- /// Buffers messages that need to be sent once the socket has connected
- private var sendBuffer: MutableList<() -> Unit> = ArrayList()
-
- /// Ref counter for messages
- private var ref: Int = 0
-
- /// Internal endpoint that the Socket is connecting to
- var endpoint: URL
-
- /// Timer that triggers sending new Heartbeat messages
- private var heartbeatTimer: Timer? = null
-
- /// Ref counter for the last heartbeat that was sent
- private var pendingHeartbeatRef: String? = null
-
- /// Timer to use when attempting to reconnect
- private var reconnectTimer: PhxTimer? = null
-
-
- private val gson: Gson = GsonBuilder()
- .setLenient()
- .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
- .create()
-
- private val request: Request
-
- /// WebSocket connection to the server
- private var connection: WebSocket? = null
-
-
- init {
-
- // Silently replace web socket URLs with HTTP URLs.
- var mutableUrl = url
- if (url.regionMatches(0, "ws:", 0, 3, ignoreCase = true)) {
- mutableUrl = "http:" + url.substring(3)
- } else if (url.regionMatches(0, "wss:", 0, 4, ignoreCase = true)) {
- mutableUrl = "https:" + url.substring(4)
- }
-
- var httpUrl = HttpUrl.parse(mutableUrl) ?: throw IllegalArgumentException("invalid url: $url")
-
- // If there are query params, append them now
- params?.let {
- val httpBuilder = httpUrl.newBuilder()
- it.forEach { (key, value) ->
- httpBuilder.addQueryParameter(key, value.toString())
- }
-
- httpUrl = httpBuilder.build()
- }
-
- reconnectTimer = PhxTimer(
- callback = {
- disconnect().also {
- connect()
- }
- },
- timerCalculation = reconnectAfterMs
- )
-
- // Hold reference to where the Socket is pointing to
- this.endpoint = httpUrl.url()
-
- // Create the request and client that will be used to connect to the WebSocket
- request = Request.Builder().url(httpUrl).build()
- }
-
-
- //------------------------------------------------------------------------------
- // Public
- //------------------------------------------------------------------------------
- /** True if the Socket is currently connected */
- val isConnected: Boolean
- get() = connection != null
-
- /**
- * Disconnects the Socket
- */
- fun disconnect(code: Int = WS_CLOSE_NORMAL) {
- connection?.close(WS_CLOSE_NORMAL, null)
- connection = null
-
- }
-
- /**
- * Connects the Socket. The params passed to the Socket on initialization
- * will be sent through the connection. If the Socket is already connected,
- * then this call will be ignored.
- */
- fun connect() {
- // Do not attempt to reconnect if already connected
- if (isConnected) return
- connection = client.newWebSocket(request, this)
- }
-
- /**
- * Registers a callback for connection open events
- *
- * Example:
- * socket.onOpen {
- * print("Socket Connection Opened")
- * }
- *
- * @param callback: Callback to register
- */
- fun onOpen(callback: () -> Unit) {
- this.onOpenCallbacks.add(callback)
- }
-
-
- /**
- * Registers a callback for connection close events
- *
- * Example:
- * socket.onClose {
- * print("Socket Connection Closed")
- * }
- *
- * @param callback: Callback to register
- */
- fun onClose(callback: () -> Unit) {
- this.onCloseCallbacks.add(callback)
- }
-
- /**
- * Registers a callback for connection error events
- *
- * Example:
- * socket.onError { error, response ->
- * print("Socket Connection Error")
- * }
- *
- * @param callback: Callback to register
- */
- fun onError(callback: (Throwable?, Response?) -> Unit) {
- this.onErrorCallbacks.add(callback)
- }
-
- /**
- * Registers a callback for connection message events
- *
- * Example:
- * socket.onMessage { [unowned self] (message) in
- * print("Socket Connection Message")
- * }
- *
- * @param callback: Callback to register
- */
- fun onMessage(callback: (PhxMessage) -> Unit) {
- this.onMessageCallbacks.add(callback)
- }
-
-
- /**
- * Releases all stored callback hooks (onError, onOpen, onClose, etc.) You should
- * call this method when you are finished when the Socket in order to release
- * any references held by the socket.
- */
- fun removeAllCallbacks() {
- this.onOpenCallbacks.clear()
- this.onCloseCallbacks.clear()
- this.onErrorCallbacks.clear()
- this.onMessageCallbacks.clear()
- }
-
- /**
- * Removes the Channel from the socket. This does not cause the channel to inform
- * the server that it is leaving so you should call channel.leave() first.
- */
- fun remove(channel: PhxChannel) {
- this.channels = channels
- .filter { it.joinRef != channel.joinRef }
- .toMutableList()
- }
-
- /**
- * Initializes a new Channel with the given topic
- *
- * Example:
- * val channel = socket.channel("rooms", params)
- */
- fun channel(topic: String, params: Payload? = null): PhxChannel {
- val channel = PhxChannel(topic, params ?: HashMap(), this)
- this.channels.add(channel)
- return channel
- }
-
- /**
- * Sends data through the Socket
- */
- open fun push(topic: String,
- event: String,
- payload: Payload,
- ref: String? = null,
- joinRef: String? = null) {
-
- val callback: (() -> Unit) = {
- val body: MutableMap = HashMap()
- body["topic"] = topic
- body["event"] = event
- body["payload"] = payload
-
- ref?.let { body["ref"] = it }
- joinRef?.let { body["join_ref"] = it }
-
- val data = gson.toJson(body)
- connection?.let {
- this.logItems("Push: Sending $data")
- it.send(data)
- }
- }
-
- // If the socket is connected, then execute the callback immediately
- if (isConnected) {
- callback()
- } else {
- // If the socket is not connected, add the push to a buffer which
- // will be sent immediately upon connection
- this.sendBuffer.add(callback)
- }
- }
-
-
- /**
- * @return the next message ref, accounting for overflows
- */
- open fun makeRef(): String {
- val newRef = this.ref + 1
- this.ref = if (newRef == Int.MAX_VALUE) 0 else newRef
-
- return newRef.toString()
- }
-
- //------------------------------------------------------------------------------
- // Internal
- //------------------------------------------------------------------------------
- fun logItems(body: String) {
- logger?.let {
- it(body)
- }
- }
-
-
- //------------------------------------------------------------------------------
- // Private
- //------------------------------------------------------------------------------
-
- /** Triggers a message when the socket is opened */
- private fun onConnectionOpened() {
- this.logItems("Transport: Connected to $endpoint")
- this.flushSendBuffer()
- this.reconnectTimer?.reset()
-
- // start sending heartbeats if enabled {
- if (!skipHeartbeat) startHeartbeatTimer()
-
- // Inform all onOpen callbacks that the Socket as opened
- this.onOpenCallbacks.forEach { it() }
- }
-
- /** Triggers a message when the socket is closed */
- private fun onConnectionClosed(code: Int) {
- this.logItems("Transport: close")
- this.triggerChannelError()
-
- // Terminate any ongoing heartbeats
- this.heartbeatTimer?.cancel()
-
- // Attempt to reconnect the socket. If the socket was closed normally,
- // then do not attempt to reconnect
- if (autoReconnect && code != WS_CLOSE_NORMAL) reconnectTimer?.scheduleTimeout()
-
- // Inform all onClose callbacks that the Socket closed
- this.onCloseCallbacks.forEach { it() }
- }
-
- /** Triggers a message when an error comes through the Socket */
- private fun onConnectionError(t: Throwable, response: Response?) {
- this.logItems("Transport: error")
-
- // Inform all onError callbacks that an error occurred
- this.onErrorCallbacks.forEach { it(t, response) }
-
- // Inform all channels that a socket error occurred
- this.triggerChannelError()
-
- // There was an error, violently cancel the connection. This is a safe operation
- // since the underlying WebSocket will no longer return messages to the Connection
- // after a Failure
- connection?.cancel()
- connection = null
- }
-
- /** Triggers a message to the correct Channel when it comes through the Socket */
- private fun onConnectionMessage(rawMessage: String) {
- this.logItems("Receive: $rawMessage")
-
- val message = gson.fromJson(rawMessage, PhxMessage::class.java)
-
- // Dispatch the message to all channels that belong to the topic
- this.channels
- .filter { it.isMember(message) }
- .forEach { it.trigger(message) }
-
- // Inform all onMessage callbacks of the message
- this.onMessageCallbacks.forEach { it(message) }
-
- // Check if this message was a pending heartbeat
- if (message.ref == pendingHeartbeatRef) {
- this.logItems("Received Pending Heartbeat")
- this.pendingHeartbeatRef = null
- }
- }
-
- /** Triggers an error event to all connected Channels */
- private fun triggerChannelError() {
- val errorMessage = PhxMessage(event = PhxChannel.PhxEvent.ERROR.value)
- this.channels.forEach { it.trigger(errorMessage) }
- }
-
- /** Send all messages that were buffered before the socket opened */
- private fun flushSendBuffer() {
- if (isConnected && sendBuffer.count() > 0) {
- this.sendBuffer.forEach { it() }
- this.sendBuffer.clear()
- }
- }
-
-
- //------------------------------------------------------------------------------
- // Timers
- //------------------------------------------------------------------------------
- /** Initializes a 30s */
- fun startHeartbeatTimer() {
- heartbeatTimer?.cancel()
- heartbeatTimer = null;
-
- heartbeatTimer = Timer()
- heartbeatTimer?.schedule(heartbeatIntervalMs, heartbeatIntervalMs) {
- if (!isConnected) return@schedule
-
- pendingHeartbeatRef?.let {
- pendingHeartbeatRef = null
- logItems("Transport: Heartbeat timeout. Attempt to re-establish connection")
- disconnect(WS_CLOSE_HEARTBEAT_ERROR)
- return@schedule
- }
-
- pendingHeartbeatRef = makeRef()
- push("phoenix", PhxChannel.PhxEvent.HEARTBEAT.value, HashMap(), pendingHeartbeatRef)
- }
- }
-
-
- //------------------------------------------------------------------------------
- // WebSocketListener
- //------------------------------------------------------------------------------
- override fun onOpen(webSocket: WebSocket?, response: Response?) {
- this.onConnectionOpened()
-
- }
-
- override fun onMessage(webSocket: WebSocket?, text: String?) {
- text?.let {
- this.onConnectionMessage(it)
- }
- }
-
- override fun onClosed(webSocket: WebSocket?, code: Int, reason: String?) {
- this.onConnectionClosed(code)
- }
-
- override fun onFailure(webSocket: WebSocket?, t: Throwable, response: Response?) {
- this.onConnectionError(t, response)
- }
-}
diff --git a/src/main/kotlin/org/phoenixframework/PhxTimer.kt b/src/main/kotlin/org/phoenixframework/PhxTimer.kt
deleted file mode 100644
index 9d0a597..0000000
--- a/src/main/kotlin/org/phoenixframework/PhxTimer.kt
+++ /dev/null
@@ -1,46 +0,0 @@
-package org.phoenixframework
-
-import java.util.*
-import kotlin.concurrent.schedule
-
-class PhxTimer(
- private val callback: () -> Unit,
- private val timerCalculation: (tries: Int) -> Long
-) {
-
- // The underlying Java timer
- private var timer: Timer? = null
- // How many tries the Timer has attempted
- private var tries: Int = 0
-
-
- /**
- * Resets the Timer, clearing the number of current tries and stops
- * any scheduled timeouts.
- */
- fun reset() {
- this.tries = 0
- this.clearTimer()
- }
-
- /** Cancels any previous timeouts and scheduled a new one */
- fun scheduleTimeout() {
- this.clearTimer()
-
- // Start up a new Timer
- val timeout = timerCalculation(tries)
- this.timer = Timer()
- this.timer?.schedule(timeout) {
- tries += 1
- callback()
- }
- }
-
- //------------------------------------------------------------------------------
- // Private
- //------------------------------------------------------------------------------
- private fun clearTimer() {
- this.timer?.cancel()
- this.timer = null
- }
-}
diff --git a/src/main/kotlin/org/phoenixframework/Presence.kt b/src/main/kotlin/org/phoenixframework/Presence.kt
new file mode 100644
index 0000000..3a1a9b8
--- /dev/null
+++ b/src/main/kotlin/org/phoenixframework/Presence.kt
@@ -0,0 +1,321 @@
+/*
+ * Copyright (c) 2019 Daniel Rees
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+package org.phoenixframework
+
+//------------------------------------------------------------------------------
+// Type Aliases
+//------------------------------------------------------------------------------
+/** Meta details of a Presence. Just a dictionary of properties */
+typealias PresenceMeta = Map
+
+/** A mapping of a String to an array of Metas. e.g. {"metas": [{id: 1}]} */
+typealias PresenceMap = MutableMap>
+
+/** A mapping of a Presence state to a mapping of Metas */
+typealias PresenceState = MutableMap
+
+/**
+ * Diff has keys "joins" and "leaves", pointing to a Presence.State each containing the users
+ * that joined and left.
+ */
+typealias PresenceDiff = MutableMap
+
+/** Closure signature of OnJoin callbacks */
+typealias OnJoin = (key: String, current: PresenceMap?, new: PresenceMap) -> Unit
+
+/** Closure signature for OnLeave callbacks */
+typealias OnLeave = (key: String, current: PresenceMap, left: PresenceMap) -> Unit
+
+/** Closure signature for OnSync callbacks */
+typealias OnSync = () -> Unit
+
+class Presence(channel: Channel, opts: Options = Options.defaults) {
+
+ //------------------------------------------------------------------------------
+ // Enums and Data classes
+ //------------------------------------------------------------------------------
+ /**
+ * Custom options that can be provided when creating Presence
+ */
+ data class Options(val events: Map) {
+ companion object {
+
+ /**
+ * Default set of Options used when creating Presence. Uses the
+ * phoenix events "presence_state" and "presence_diff"
+ */
+ val defaults: Options
+ get() = Options(
+ mapOf(
+ Events.STATE to "presence_state",
+ Events.DIFF to "presence_diff"))
+ }
+ }
+
+ /** Collection of callbacks with default values */
+ data class Caller(
+ var onJoin: OnJoin = { _, _, _ -> },
+ var onLeave: OnLeave = { _, _, _ -> },
+ var onSync: OnSync = {}
+ )
+
+ /** Presence Events of "state" and "diff" */
+ enum class Events {
+ STATE,
+ DIFF
+ }
+
+ //------------------------------------------------------------------------------
+ // Properties
+ //------------------------------------------------------------------------------
+ /** The channel the Presence belongs to */
+ internal val channel: Channel
+
+ /** Caller to callback hooks */
+ internal val caller: Caller
+
+ /** The state of the Presence */
+ var state: PresenceState
+ internal set
+
+ /** Pending `join` and `leave` diffs that need to be synced */
+ var pendingDiffs: MutableList
+ private set
+
+ /** The channel's joinRef, set when state events occur */
+ var joinRef: String?
+ private set
+
+ /** True if the Presence has not yet initially synced */
+ val isPendingSyncState: Boolean
+ get() = this.joinRef == null || (this.joinRef !== this.channel.joinRef)
+
+ //------------------------------------------------------------------------------
+ // Initialization
+ //------------------------------------------------------------------------------
+ init {
+ this.state = mutableMapOf()
+ this.pendingDiffs = mutableListOf()
+ this.channel = channel
+ this.joinRef = null
+ this.caller = Caller()
+
+ val stateEvent = opts.events[Events.STATE]
+ val diffEvent = opts.events[Events.DIFF]
+
+ if (stateEvent != null && diffEvent != null) {
+
+ this.channel.on(stateEvent) { message ->
+ val newState = message.rawPayload.toMutableMap() as PresenceState
+
+ this.joinRef = this.channel.joinRef
+ this.state =
+ Presence.syncState(state, newState, caller.onJoin, caller.onLeave)
+
+
+ this.pendingDiffs.forEach { diff ->
+ this.state = syncDiff(state, diff, caller.onJoin, caller.onLeave)
+ }
+
+ this.pendingDiffs.clear()
+ this.caller.onSync()
+ }
+
+ this.channel.on(diffEvent) { message ->
+ val diff = message.rawPayload.toMutableMap() as PresenceDiff
+ if (isPendingSyncState) {
+ this.pendingDiffs.add(diff)
+ } else {
+ this.state = syncDiff(state, diff, caller.onJoin, caller.onLeave)
+ this.caller.onSync()
+ }
+ }
+ }
+ }
+
+ //------------------------------------------------------------------------------
+ // Callbacks
+ //------------------------------------------------------------------------------
+ fun onJoin(callback: OnJoin) {
+ this.caller.onJoin = callback
+ }
+
+ fun onLeave(callback: OnLeave) {
+ this.caller.onLeave = callback
+ }
+
+ fun onSync(callback: OnSync) {
+ this.caller.onSync = callback
+ }
+
+ //------------------------------------------------------------------------------
+ // Listing
+ //------------------------------------------------------------------------------
+ fun list(): List {
+ return this.listBy { it.value }
+ }
+
+ fun listBy(transform: (Map.Entry) -> T): List {
+ return Presence.listBy(state, transform)
+ }
+
+ fun filterBy(predicate: ((Map.Entry) -> Boolean)?): PresenceState {
+ return Presence.filter(state, predicate)
+ }
+
+ //------------------------------------------------------------------------------
+ // Syncing
+ //------------------------------------------------------------------------------
+ companion object {
+
+ private fun cloneMap(map: PresenceMap): PresenceMap {
+ val clone: PresenceMap = mutableMapOf()
+ map.forEach { entry -> clone[entry.key] = entry.value.toList() }
+ return clone
+ }
+
+ private fun cloneState(state: PresenceState): PresenceState {
+ val clone: PresenceState = mutableMapOf()
+ state.forEach { entry -> clone[entry.key] = cloneMap(entry.value) }
+ return clone
+ }
+
+
+ /**
+ * Used to sync the list of presences on the server with the client's state. An optional
+ * `onJoin` and `onLeave` callback can be provided to react to changes in the client's local
+ * presences across disconnects and reconnects with the server.
+ *
+ */
+ fun syncState(
+ currentState: PresenceState,
+ newState: PresenceState,
+ onJoin: OnJoin = { _, _, _ -> },
+ onLeave: OnLeave = { _, _, _ -> }
+ ): PresenceState {
+ val state = cloneState(currentState)
+ val leaves: PresenceState = mutableMapOf()
+ val joins: PresenceState = mutableMapOf()
+
+ state.forEach { (key, presence) ->
+ if (!newState.containsKey(key)) {
+ leaves[key] = presence
+ }
+ }
+
+ newState.forEach { (key, newPresence) ->
+ state[key]?.let { currentPresence ->
+ val newRefs = newPresence["metas"]!!.map { meta -> meta["phx_ref"] as String }
+ val curRefs = currentPresence["metas"]!!.map { meta -> meta["phx_ref"] as String }
+
+ val joinedMetas = newPresence["metas"]!!.filter { meta ->
+ curRefs.indexOf(meta["phx_ref"]) < 0
+ }
+ val leftMetas = currentPresence["metas"]!!.filter { meta ->
+ newRefs.indexOf(meta["phx_ref"]) < 0
+ }
+
+ if (joinedMetas.isNotEmpty()) {
+ joins[key] = cloneMap(newPresence)
+ joins[key]!!["metas"] = joinedMetas
+ }
+
+ if (leftMetas.isNotEmpty()) {
+ leaves[key] = cloneMap(currentPresence)
+ leaves[key]!!["metas"] = leftMetas
+ }
+ } ?: run {
+ joins[key] = newPresence
+ }
+ }
+
+ val diff: PresenceDiff = mutableMapOf("joins" to joins, "leaves" to leaves)
+ return syncDiff(state, diff, onJoin, onLeave)
+
+ }
+
+ /**
+ * Used to sync a diff of presence join and leave events from the server, as they happen.
+ * Like `syncState`, `syncDiff` accepts optional `onJoin` and `onLeave` callbacks to react
+ * to a user joining or leaving from a device.
+ */
+ fun syncDiff(
+ currentState: PresenceState,
+ diff: PresenceDiff,
+ onJoin: OnJoin = { _, _, _ -> },
+ onLeave: OnLeave = { _, _, _ -> }
+ ): PresenceState {
+ val state = cloneState(currentState)
+
+ // Sync the joined states and inform onJoin of new presence
+ diff["joins"]?.forEach { (key, newPresence) ->
+ val currentPresence = state[key]
+ state[key] = cloneMap(newPresence)
+
+ currentPresence?.let { curPresence ->
+ val joinedRefs = state[key]!!["metas"]!!.map { m -> m["phx_ref"] as String }
+ val curMetas = curPresence["metas"]!!.filter { m -> joinedRefs.indexOf(m["phx_ref"]) < 0 }
+
+ // Data structures are immutable. Need to convert to a mutable copy,
+ // add the metas, and then reassign to the state
+ val mutableMetas = state[key]!!["metas"]!!.toMutableList()
+ mutableMetas.addAll(0, curMetas)
+
+ state[key]!!["metas"] = mutableMetas
+ }
+
+ onJoin.invoke(key, currentPresence, newPresence)
+ }
+
+ // Sync the left diff and inform onLeave of left presence
+ diff["leaves"]?.forEach { (key, leftPresence) ->
+ val curPresence = state[key] ?: return@forEach
+
+ val refsToRemove = leftPresence["metas"]!!.map { it["phx_ref"] as String }
+ curPresence["metas"] =
+ curPresence["metas"]!!.filter { m -> refsToRemove.indexOf(m["phx_ref"]) < 0 }
+
+ onLeave.invoke(key, curPresence, leftPresence)
+ if (curPresence["metas"]?.isEmpty() == true) {
+ state.remove(key)
+ }
+ }
+
+ return state
+ }
+
+ fun filter(
+ presence: PresenceState,
+ predicate: ((Map.Entry) -> Boolean)?
+ ): PresenceState {
+ return presence.filter(predicate ?: { true }).toMutableMap()
+ }
+
+ fun listBy(
+ presence: PresenceState,
+ transform: (Map.Entry) -> T
+ ): List {
+ return presence.map(transform)
+ }
+ }
+}
diff --git a/src/main/kotlin/org/phoenixframework/Push.kt b/src/main/kotlin/org/phoenixframework/Push.kt
new file mode 100644
index 0000000..5234205
--- /dev/null
+++ b/src/main/kotlin/org/phoenixframework/Push.kt
@@ -0,0 +1,188 @@
+/*
+ * Copyright (c) 2019 Daniel Rees
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+package org.phoenixframework
+
+import java.util.concurrent.TimeUnit
+
+/**
+ * A Push represents an attempt to send a payload through a Channel for a specific event.
+ */
+class Push(
+ /** The channel the Push is being sent through */
+ val channel: Channel,
+ /** The event the Push is targeting */
+ val event: String,
+ /** The message to be sent */
+ var payload: Payload = mapOf(),
+ /** Duration before the message is considered timed out and failed to send */
+ var timeout: Long = Defaults.TIMEOUT
+) {
+
+ /** The server's response to the Push */
+ var receivedMessage: Message? = null
+
+ /** The task to be triggered if the Push times out */
+ var timeoutTask: DispatchWorkItem? = null
+
+ /** Hooks into a Push. Where .receive("ok", callback(Payload)) are stored */
+ var receiveHooks: MutableMap Unit)>> = HashMap()
+
+ /** True if the Push has been sent */
+ var sent: Boolean = false
+
+ /** The reference ID of the Push */
+ var ref: String? = null
+
+ /** The event that is associated with the reference ID of the Push */
+ var refEvent: String? = null
+
+ //------------------------------------------------------------------------------
+ // Public
+ //------------------------------------------------------------------------------
+ /**
+ * Resets and sends the Push
+ * @param timeout Optional. The push timeout. Default is 10_000ms = 10s
+ */
+ fun resend(timeout: Long = Defaults.TIMEOUT) {
+ this.timeout = timeout
+ this.reset()
+ this.send()
+ }
+
+ /**
+ * Sends the Push. If it has already timed out then the call will be ignored. use
+ * `resend(timeout:)` in this case.
+ */
+ fun send() {
+ if (hasReceived("timeout")) return
+
+ this.startTimeout()
+ this.sent = true
+ this.channel.socket.push(channel.topic, event, payload, ref, channel.joinRef)
+ }
+
+ /**
+ * Receive a specific event when sending an Outbound message
+ *
+ * Example:
+ * channel
+ * .send("event", myPayload)
+ * .receive("error") { }
+ */
+ fun receive(status: String, callback: (Message) -> Unit): Push {
+ // If the message has already be received, pass it to the callback
+ receivedMessage?.let { if (hasReceived(status)) callback(it) }
+
+ // If a previous hook for this status already exists. Just append the new hook. If not, then
+ // create a new array of hooks if no previous hook is associated with status
+ receiveHooks[status] = receiveHooks[status]?.plus(callback) ?: arrayListOf(callback)
+
+ return this
+ }
+
+ //------------------------------------------------------------------------------
+ // Internal
+ //------------------------------------------------------------------------------
+ /** Resets the Push as it was after it was first initialized. */
+ internal fun reset() {
+ this.cancelRefEvent()
+ this.ref = null
+ this.refEvent = null
+ this.receivedMessage = null
+ this.sent = false
+ }
+
+ /**
+ * Triggers an event to be sent through the Push's parent Channel
+ */
+ internal fun trigger(status: String, payload: Payload) {
+ this.refEvent?.let { refEvent ->
+ val mutPayload = payload.toMutableMap()
+ mutPayload["status"] = status
+
+ this.channel.trigger(refEvent, mutPayload)
+ }
+ }
+
+ /**
+ * Schedules a timeout task which will be triggered after a specific timeout is reached
+ */
+ internal fun startTimeout() {
+ // Cancel any existing timeout before starting a new one
+ this.timeoutTask?.let { if (!it.isCancelled) this.cancelTimeout() }
+
+ // Get the ref of the Push
+ val ref = this.channel.socket.makeRef()
+ val refEvent = this.channel.replyEventName(ref)
+
+ this.ref = ref
+ this.refEvent = refEvent
+
+ // Subscribe to a reply from the server when the Push is received
+ this.channel.on(refEvent) { message ->
+ this.cancelRefEvent()
+ this.cancelTimeout()
+ this.receivedMessage = message
+
+ // Check if there is an event receive hook to be informed
+ message.status?.let { status -> matchReceive(status, message) }
+ }
+
+ // Setup and start the Timer
+ this.timeoutTask = channel.socket.dispatchQueue.queue(timeout, TimeUnit.MILLISECONDS) {
+ this.trigger("timeout", hashMapOf())
+ }
+ }
+
+ //------------------------------------------------------------------------------
+ // Private
+ //------------------------------------------------------------------------------
+ /**
+ * Finds the receiveHook which needs to be informed of a status response and passes it the message
+ *
+ * @param status Status which was received. e.g. "ok", "error", etc.
+ * @param message Message to pass to receive hook
+ */
+ private fun matchReceive(status: String, message: Message) {
+ receiveHooks[status]?.forEach { it(message) }
+ }
+
+ /** Removes receive hook from Channel regarding this Push */
+ private fun cancelRefEvent() {
+ this.refEvent?.let { this.channel.off(it) }
+ }
+
+ /** Cancels any ongoing timeout task */
+ internal fun cancelTimeout() {
+ this.timeoutTask?.cancel()
+ this.timeoutTask = null
+ }
+
+ /**
+ * @param status Status to check if it has been received
+ * @return True if the status has already been received by the Push
+ */
+ private fun hasReceived(status: String): Boolean {
+ return receivedMessage?.status == status
+ }
+}
\ No newline at end of file
diff --git a/src/main/kotlin/org/phoenixframework/Socket.kt b/src/main/kotlin/org/phoenixframework/Socket.kt
new file mode 100644
index 0000000..73f5e23
--- /dev/null
+++ b/src/main/kotlin/org/phoenixframework/Socket.kt
@@ -0,0 +1,608 @@
+/*
+ * Copyright (c) 2019 Daniel Rees
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+package org.phoenixframework
+
+import okhttp3.OkHttpClient
+import okhttp3.Response
+import java.net.URL
+import java.util.concurrent.TimeUnit
+
+/** Alias for a JSON mapping */
+typealias Payload = Map
+
+/** Data class that holds callbacks assigned to the socket */
+internal class StateChangeCallbacks {
+
+ var open: List Unit>> = ArrayList()
+ private set
+ var close: List Unit>> = ArrayList()
+ private set
+ var error: List Unit>> = ArrayList()
+ private set
+ var message: List Unit>> = ArrayList()
+ private set
+
+ /** Safely adds an onOpen callback */
+ fun onOpen(
+ ref: String,
+ callback: () -> Unit
+ ) {
+ this.open = this.open + Pair(ref, callback)
+ }
+
+ /** Safely adds an onClose callback */
+ fun onClose(
+ ref: String,
+ callback: () -> Unit
+ ) {
+ this.close = this.close + Pair(ref, callback)
+ }
+
+ /** Safely adds an onError callback */
+ fun onError(
+ ref: String,
+ callback: (Throwable, Response?) -> Unit
+ ) {
+ this.error = this.error + Pair(ref, callback)
+ }
+
+ /** Safely adds an onMessage callback */
+ fun onMessage(
+ ref: String,
+ callback: (Message) -> Unit
+ ) {
+ this.message = this.message + Pair(ref, callback)
+ }
+
+ /** Clears any callbacks with the matching refs */
+ fun release(refs: List) {
+ open = open.filterNot { refs.contains(it.first) }
+ close = close.filterNot { refs.contains(it.first) }
+ error = error.filterNot { refs.contains(it.first) }
+ message = message.filterNot { refs.contains(it.first) }
+ }
+
+ /** Clears all stored callbacks */
+ fun release() {
+ open = emptyList()
+ close = emptyList()
+ error = emptyList()
+ message = emptyList()
+ }
+}
+
+/** RFC 6455: indicates a normal closure */
+const val WS_CLOSE_NORMAL = 1000
+
+/** RFC 6455: indicates that the connection was closed abnormally */
+const val WS_CLOSE_ABNORMAL = 1006
+
+/**
+ * A closure that will return an optional Payload
+ */
+typealias PayloadClosure = () -> Payload?
+
+/** A closure that will encode a Map into a JSON String */
+typealias EncodeClosure = (Any) -> String
+
+/** A closure that will decode a JSON String into a [Message] */
+typealias DecodeClosure = (String) -> Message
+
+
+/**
+ * Connects to a Phoenix Server
+ */
+
+/**
+ * A [Socket] which connects to a Phoenix Server. Takes a closure to allow for changing parameters
+ * to be sent to the server when connecting.
+ *
+ * ## Example
+ * ```
+ * val socket = Socket("https://example.com/socket", { mapOf("token" to mAuthToken) })
+ * ```
+ * @param url Url to connect to such as https://example.com/socket
+ * @param paramsClosure Closure which allows to change parameters sent during connection.
+ * @param vsn JSON Serializer version to use. Defaults to 2.0.0
+ * @param encode Optional. Provide a custom JSON encoding implementation
+ * @param decode Optional. Provide a custom JSON decoding implementation
+ * @param client Default OkHttpClient to connect with. You can provide your own if needed.
+ */
+class Socket(
+ url: String,
+ val paramsClosure: PayloadClosure,
+ val vsn: String = Defaults.VSN,
+ private val encode: EncodeClosure = Defaults.encode,
+ private val decode: DecodeClosure = Defaults.decode,
+ private val client: OkHttpClient = OkHttpClient.Builder().build()
+) {
+
+ //------------------------------------------------------------------------------
+ // Public Attributes
+ //------------------------------------------------------------------------------
+ /**
+ * The string WebSocket endpoint (ie `"ws://example.com/socket"`,
+ * `"wss://example.com"`, etc.) that was passed to the Socket during
+ * initialization. The URL endpoint will be modified by the Socket to
+ * include `"/websocket"` if missing.
+ */
+ val endpoint: String
+
+ /** The fully qualified socket URL */
+ var endpointUrl: URL
+ private set
+
+ /** Timeout to use when opening a connection */
+ var timeout: Long = Defaults.TIMEOUT
+
+ /** Interval between sending a heartbeat, in ms */
+ var heartbeatIntervalMs: Long = Defaults.HEARTBEAT
+
+ /** Interval between socket reconnect attempts, in ms */
+ var reconnectAfterMs: ((Int) -> Long) = Defaults.reconnectSteppedBackOff
+
+ /** Interval between channel rejoin attempts, in ms */
+ var rejoinAfterMs: ((Int) -> Long) = Defaults.rejoinSteppedBackOff
+
+ /** The optional function to receive logs */
+ var logger: ((String) -> Unit)? = null
+
+ /** Disables heartbeats from being sent. Default is false. */
+ var skipHeartbeat: Boolean = false
+
+ //------------------------------------------------------------------------------
+ // Internal Attributes
+ //------------------------------------------------------------------------------
+ /**
+ * All timers associated with a socket will share the same pool. Used for every Channel or
+ * Push that is sent through or created by a Socket instance. Different Socket instances will
+ * create individual thread pools.
+ */
+// internal var timerPool: ScheduledExecutorService = ScheduledThreadPoolExecutor(8)
+ internal var dispatchQueue: DispatchQueue = ScheduledDispatchQueue()
+
+ //------------------------------------------------------------------------------
+ // Private Attributes
+ // these are marked as `internal` so that they can be accessed during tests
+ //------------------------------------------------------------------------------
+ /** Returns the type of transport to use. Potentially expose for custom transports */
+ internal var transport: (URL) -> Transport = { WebSocketTransport(it, client) }
+
+ /** Collection of callbacks for socket state changes */
+ internal val stateChangeCallbacks: StateChangeCallbacks = StateChangeCallbacks()
+
+ /** Collection of unclosed channels created by the Socket */
+ internal var channels: List = ArrayList()
+
+ /**
+ * Buffers messages that need to be sent once the socket has connected. It is an array of Pairs
+ * that contain the ref of the message to send and the callback that will send the message.
+ */
+ internal var sendBuffer: MutableList Unit>> = ArrayList()
+
+ /** Ref counter for messages */
+ internal var ref: Int = 0
+
+ /** Task to be triggered in the future to send a heartbeat message */
+ internal var heartbeatTask: DispatchWorkItem? = null
+
+ /** Ref counter for the last heartbeat that was sent */
+ internal var pendingHeartbeatRef: String? = null
+
+ /** Timer to use when attempting to reconnect */
+ internal var reconnectTimer: TimeoutTimer
+
+ /** True if the Socket closed cleaned. False if not (connection timeout, heartbeat, etc) */
+ internal var closeWasClean = false
+
+ //------------------------------------------------------------------------------
+ // Connection Attributes
+ //------------------------------------------------------------------------------
+ /** The underlying WebSocket connection */
+ internal var connection: Transport? = null
+
+ //------------------------------------------------------------------------------
+ // Initialization
+ //------------------------------------------------------------------------------
+ /**
+ * A [Socket] which connects to a Phoenix Server. Takes a constant parameter to be sent to the
+ * server when connecting. Defaults to null if excluded.
+ *
+ * ## Example
+ * ```
+ * val socket = Socket("https://example.com/socket", mapOf("token" to mAuthToken))
+ * ```
+ *
+ * @param url Url to connect to such as https://example.com/socket
+ * @param params Constant parameters to send when connecting. Defaults to null
+ * @param vsn JSON Serializer version to use. Defaults to 2.0.0
+ * @param encode Optional. Provide a custom JSON encoding implementation
+ * @param decode Optional. Provide a custom JSON decoding implementation
+ * @param client Default OkHttpClient to connect with. You can provide your own if needed.
+ */
+ constructor(
+ url: String,
+ params: Payload? = null,
+ vsn: String = Defaults.VSN,
+ encode: EncodeClosure = Defaults.encode,
+ decode: DecodeClosure = Defaults.decode,
+ client: OkHttpClient = OkHttpClient.Builder().build()
+ ) : this(url, { params }, vsn, encode, decode, client)
+
+ init {
+ var mutableUrl = url
+
+ // Ensure that the URL ends with "/websocket"
+ if (!mutableUrl.contains("/websocket")) {
+ // Do not duplicate '/' in path
+ if (mutableUrl.last() != '/') {
+ mutableUrl += "/"
+ }
+
+ // append "websocket" to the path
+ mutableUrl += "websocket"
+ }
+
+ // Store the endpoint before changing the protocol
+ this.endpoint = mutableUrl
+
+ // Store the URL that will be used to establish a connection. Could potentially be
+ // different at the time connect() is called based on a changing params closure.
+ this.endpointUrl = Defaults.buildEndpointUrl(this.endpoint, this.paramsClosure, this.vsn)
+
+ // Create reconnect timer
+ this.reconnectTimer = TimeoutTimer(
+ dispatchQueue = dispatchQueue,
+ timerCalculation = reconnectAfterMs,
+ callback = {
+ this.logItems("Socket attempting to reconnect")
+ this.teardown { this.connect() }
+ })
+ }
+
+ //------------------------------------------------------------------------------
+ // Public Properties
+ //------------------------------------------------------------------------------
+ /** @return The socket protocol being used. e.g. "wss", "ws" */
+ val protocol: String
+ get() = when (endpointUrl.protocol) {
+ "https" -> "wss"
+ "http" -> "ws"
+ else -> endpointUrl.protocol
+ }
+
+ /** @return True if the connection exists and is open */
+ val isConnected: Boolean
+ get() = this.connectionState == Transport.ReadyState.OPEN
+
+ /** @return The ready state of the connection. */
+ val connectionState: Transport.ReadyState
+ get() = this.connection?.readyState ?: Transport.ReadyState.CLOSED
+
+ //------------------------------------------------------------------------------
+ // Public
+ //------------------------------------------------------------------------------
+ fun connect() {
+ // Do not attempt to connect if already connected
+ if (isConnected) return
+
+ // Reset the clean close flag when attempting to connect
+ this.closeWasClean = false
+
+ // Build the new endpointUrl with the params closure. The payload returned
+ // from the closure could be different such as a changing authToken.
+ this.endpointUrl = Defaults.buildEndpointUrl(this.endpoint, this.paramsClosure, this.vsn)
+
+ // Now create the connection transport and attempt to connect
+ this.connection = this.transport(endpointUrl)
+ this.connection?.onOpen = { onConnectionOpened() }
+ this.connection?.onClose = { code -> onConnectionClosed(code) }
+ this.connection?.onError = { t, r -> onConnectionError(t, r) }
+ this.connection?.onMessage = { m -> onConnectionMessage(m) }
+ this.connection?.connect()
+ }
+
+ fun disconnect(
+ code: Int = WS_CLOSE_NORMAL,
+ reason: String? = null,
+ callback: (() -> Unit)? = null
+ ) {
+ // The socket was closed cleanly by the User
+ this.closeWasClean = true
+
+ // Reset any reconnects and teardown the socket connection
+ this.reconnectTimer.reset()
+ this.teardown(code, reason, callback)
+ }
+
+ fun onOpen(callback: (() -> Unit)): String {
+ return makeRef().apply { stateChangeCallbacks.onOpen(this, callback) }
+ }
+
+ fun onClose(callback: () -> Unit): String {
+ return makeRef().apply { stateChangeCallbacks.onClose(this, callback) }
+ }
+
+ fun onError(callback: (Throwable, Response?) -> Unit): String {
+ return makeRef().apply { stateChangeCallbacks.onError(this, callback) }
+ }
+
+ fun onMessage(callback: (Message) -> Unit): String {
+ return makeRef().apply { stateChangeCallbacks.onMessage(this, callback) }
+ }
+
+ fun removeAllCallbacks() {
+ this.stateChangeCallbacks.release()
+ }
+
+ fun channel(
+ topic: String,
+ params: Payload = mapOf()
+ ): Channel {
+ val channel = Channel(topic, params, this)
+ this.channels = this.channels + channel
+
+ return channel
+ }
+
+ fun remove(channel: Channel) {
+ this.off(channel.stateChangeRefs)
+
+ // To avoid a ConcurrentModificationException, filter out the channels to be
+ // removed instead of calling .remove() on the list, thus returning a new list
+ // that does not contain the channel that was removed.
+ this.channels = channels
+ .filter { it.joinRef != channel.joinRef }
+ }
+
+ /**
+ * Removes [onOpen], [onClose], [onError], and [onMessage] registrations by their [ref] value.
+ *
+ * @param refs List of refs to remove
+ */
+ fun off(refs: List) {
+ this.stateChangeCallbacks.release(refs)
+ }
+
+ //------------------------------------------------------------------------------
+ // Internal
+ //------------------------------------------------------------------------------
+ internal fun push(
+ topic: String,
+ event: String,
+ payload: Payload,
+ ref: String? = null,
+ joinRef: String? = null
+ ) {
+
+ val callback: (() -> Unit) = {
+ val body = listOf(joinRef, ref, topic, event, payload)
+ val data = this.encode(body)
+ connection?.let { transport ->
+ this.logItems("Push: Sending $data")
+ transport.send(data)
+ }
+ }
+
+ if (isConnected) {
+ // If the socket is connected, then execute the callback immediately.
+ callback.invoke()
+ } else {
+ // If the socket is not connected, add the push to a buffer which will
+ // be sent immediately upon connection.
+ sendBuffer.add(Pair(ref, callback))
+ }
+ }
+
+ /** @return the next message ref, accounting for overflows */
+ internal fun makeRef(): String {
+ this.ref = if (ref == Int.MAX_VALUE) 0 else ref + 1
+ return ref.toString()
+ }
+
+ fun logItems(body: String) {
+ logger?.invoke(body)
+ }
+
+ //------------------------------------------------------------------------------
+ // Private
+ //------------------------------------------------------------------------------
+ private fun teardown(
+ code: Int = WS_CLOSE_NORMAL,
+ reason: String? = null,
+ callback: (() -> Unit)? = null
+ ) {
+ // Disconnect the transport
+ this.connection?.onClose = null
+ this.connection?.disconnect(code, reason)
+ this.connection = null
+
+ // Heartbeats are no longer needed
+ this.heartbeatTask?.cancel()
+ this.heartbeatTask = null
+
+ // Since the connections onClose was null'd out, inform all state callbacks
+ // that the Socket has closed
+ this.stateChangeCallbacks.close.forEach { it.second.invoke() }
+ callback?.invoke()
+ }
+
+ /** Triggers an error event to all connected Channels */
+ private fun triggerChannelError() {
+ this.channels.forEach { channel ->
+ // Only trigger a channel error if it is in an "opened" state
+ if (!(channel.isErrored || channel.isLeaving || channel.isClosed)) {
+ channel.trigger(Channel.Event.ERROR.value)
+ }
+ }
+ }
+
+ /** Send all messages that were buffered before the socket opened */
+ internal fun flushSendBuffer() {
+ if (isConnected && sendBuffer.isNotEmpty()) {
+ this.sendBuffer.forEach { it.second.invoke() }
+ this.sendBuffer.clear()
+ }
+ }
+
+ /** Removes an item from the send buffer with the matching ref */
+ internal fun removeFromSendBuffer(ref: String) {
+ this.sendBuffer = this.sendBuffer
+ .filter { it.first != ref }
+ .toMutableList()
+ }
+
+ internal fun leaveOpenTopic(topic: String) {
+ this.channels
+ .firstOrNull { it.topic == topic && (it.isJoined || it.isJoining) }
+ ?.let {
+ logItems("Transport: Leaving duplicate topic: [$topic]")
+ it.leave()
+ }
+ }
+
+ //------------------------------------------------------------------------------
+ // Heartbeat
+ //------------------------------------------------------------------------------
+ internal fun resetHeartbeat() {
+ // Clear anything related to the previous heartbeat
+ this.pendingHeartbeatRef = null
+ this.heartbeatTask?.cancel()
+ this.heartbeatTask = null
+
+ // Do not start up the heartbeat timer if skipHeartbeat is true
+ if (skipHeartbeat) return
+ val delay = heartbeatIntervalMs
+ val period = heartbeatIntervalMs
+
+ heartbeatTask =
+ dispatchQueue.queueAtFixedRate(delay, period, TimeUnit.MILLISECONDS) { sendHeartbeat() }
+ }
+
+ internal fun sendHeartbeat() {
+ // Do not send if the connection is closed
+ if (!isConnected) return
+
+ // If there is a pending heartbeat ref, then the last heartbeat was
+ // never acknowledged by the server. Close the connection and attempt
+ // to reconnect.
+ pendingHeartbeatRef?.let {
+ pendingHeartbeatRef = null
+ logItems("Transport: Heartbeat timeout. Attempt to re-establish connection")
+
+ // Close the socket, flagging the closure as abnormal
+ this.abnormalClose("heartbeat timeout")
+ return
+ }
+
+ // The last heartbeat was acknowledged by the server. Send another one
+ this.pendingHeartbeatRef = this.makeRef()
+ this.push(
+ topic = "phoenix",
+ event = Channel.Event.HEARTBEAT.value,
+ payload = mapOf(),
+ ref = pendingHeartbeatRef
+ )
+ }
+
+ private fun abnormalClose(reason: String) {
+ this.closeWasClean = false
+
+ /*
+ We use NORMAL here since the client is the one determining to close the connection. However,
+ we keep a flag `closeWasClean` set to false so that the client knows that it should attempt
+ to reconnect.
+ */
+ this.connection?.disconnect(WS_CLOSE_NORMAL, reason)
+ }
+
+ //------------------------------------------------------------------------------
+ // Connection Transport Hooks
+ //------------------------------------------------------------------------------
+ internal fun onConnectionOpened() {
+ this.logItems("Transport: Connected to $endpoint")
+
+ // Reset the closeWasClean flag now that the socket has been connected
+ this.closeWasClean = false
+
+ // Send any messages that were waiting for a connection
+ this.flushSendBuffer()
+
+ // Reset how the socket tried to reconnect
+ this.reconnectTimer.reset()
+
+ // Restart the heartbeat timer
+ this.resetHeartbeat()
+
+ // Inform all onOpen callbacks that the Socket has opened
+ this.stateChangeCallbacks.open.forEach { it.second.invoke() }
+ }
+
+ internal fun onConnectionClosed(code: Int) {
+ this.logItems("Transport: close")
+ this.triggerChannelError()
+
+ // Prevent the heartbeat from triggering if the socket closed
+ this.heartbeatTask?.cancel()
+ this.heartbeatTask = null
+
+ // Only attempt to reconnect if the socket did not close normally
+ if (!this.closeWasClean) {
+ this.reconnectTimer.scheduleTimeout()
+ }
+
+ // Inform callbacks the socket closed
+ this.stateChangeCallbacks.close.forEach { it.second.invoke() }
+ }
+
+ internal fun onConnectionMessage(rawMessage: String) {
+ this.logItems("Receive: $rawMessage")
+
+ // Parse the message as JSON
+ val message = this.decode(rawMessage)
+
+ // Clear heartbeat ref, preventing a heartbeat timeout disconnect
+ if (message.ref == pendingHeartbeatRef) pendingHeartbeatRef = null
+
+ // Dispatch the message to all channels that belong to the topic
+ this.channels
+ .filter { it.isMember(message) }
+ .forEach { it.trigger(message) }
+
+ // Inform all onMessage callbacks of the message
+ this.stateChangeCallbacks.message.forEach { it.second.invoke(message) }
+ }
+
+ internal fun onConnectionError(
+ t: Throwable,
+ response: Response?
+ ) {
+ this.logItems("Transport: error $t")
+
+ // Send an error to all channels
+ this.triggerChannelError()
+
+ // Inform any state callbacks of the error
+ this.stateChangeCallbacks.error.forEach { it.second.invoke(t, response) }
+ }
+}
diff --git a/src/main/kotlin/org/phoenixframework/TimeoutTimer.kt b/src/main/kotlin/org/phoenixframework/TimeoutTimer.kt
new file mode 100644
index 0000000..1faec4b
--- /dev/null
+++ b/src/main/kotlin/org/phoenixframework/TimeoutTimer.kt
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2019 Daniel Rees
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+package org.phoenixframework
+
+import java.util.concurrent.TimeUnit
+
+/**
+ * A Timer class that schedules a callback to be called in the future. Can be configured
+ * to use a custom retry pattern, such as exponential backoff.
+ */
+class TimeoutTimer(
+ private val dispatchQueue: DispatchQueue,
+ private val callback: () -> Unit,
+ private val timerCalculation: (tries: Int) -> Long
+) {
+
+ /** How many tries the Timer has attempted */
+ private var tries: Int = 0
+
+ /** The task that has been scheduled to be executed */
+ private var workItem: DispatchWorkItem? = null
+
+ /**
+ * Resets the Timer, clearing the number of current tries and stops
+ * any scheduled timeouts.
+ */
+ fun reset() {
+ this.tries = 0
+ this.clearTimer()
+ }
+
+ /** Cancels any previous timeouts and scheduled a new one */
+ fun scheduleTimeout() {
+ this.clearTimer()
+
+ // Schedule a task to be performed after the calculated timeout in milliseconds
+ val timeout = timerCalculation(tries + 1)
+ this.workItem = dispatchQueue.queue(timeout, TimeUnit.MILLISECONDS) {
+ this.tries += 1
+ this.callback.invoke()
+ }
+ }
+
+ //------------------------------------------------------------------------------
+ // Private
+ //------------------------------------------------------------------------------
+ private fun clearTimer() {
+ // Cancel the task from completing, allowing it to fi
+ this.workItem?.cancel()
+ this.workItem = null
+ }
+}
\ No newline at end of file
diff --git a/src/main/kotlin/org/phoenixframework/Transport.kt b/src/main/kotlin/org/phoenixframework/Transport.kt
new file mode 100644
index 0000000..25c8818
--- /dev/null
+++ b/src/main/kotlin/org/phoenixframework/Transport.kt
@@ -0,0 +1,168 @@
+/*
+ * Copyright (c) 2019 Daniel Rees
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+package org.phoenixframework
+
+import okhttp3.OkHttpClient
+import okhttp3.Request
+import okhttp3.Response
+import okhttp3.WebSocket
+import okhttp3.WebSocketListener
+import java.net.URL
+
+/**
+ * Interface that defines different types of Transport layers. A default {@link WebSocketTransport}
+ * is provided which uses an OkHttp WebSocket to transport data between your Phoenix server.
+ *
+ * Future support may be added to provide your own custom Transport, such as a LongPoll
+ */
+interface Transport {
+
+ /** Available ReadyStates of a {@link Transport}. */
+ enum class ReadyState {
+
+ /** The Transport is connecting to the server */
+ CONNECTING,
+
+ /** The Transport is connected and open */
+ OPEN,
+
+ /** The Transport is closing */
+ CLOSING,
+
+ /** The Transport is closed */
+ CLOSED
+ }
+
+ /** The state of the Transport. See {@link ReadyState} */
+ val readyState: ReadyState
+
+ /** Called when the Transport opens */
+ var onOpen: (() -> Unit)?
+ /** Called when the Transport receives an error */
+ var onError: ((Throwable, Response?) -> Unit)?
+ /** Called each time the Transport receives a message */
+ var onMessage: ((String) -> Unit)?
+ /** Called when the Transport closes */
+ var onClose: ((Int) -> Unit)?
+
+ /** Connect to the server */
+ fun connect()
+
+ /**
+ * Disconnect from the Server
+ *
+ * @param code Status code as defined by Section 7.4 of RFC 6455 .
+ * @param reason Reason for shutting down or {@code null}.
+ */
+ fun disconnect(code: Int, reason: String? = null)
+
+ /**
+ * Sends text to the Server
+ */
+ fun send(data: String)
+}
+
+/**
+ * A WebSocket implementation of a Transport that uses a WebSocket to facilitate sending
+ * and receiving data.
+ *
+ * @param url: URL to connect to
+ * @param okHttpClient: Custom client that can be pre-configured before connecting
+ */
+class WebSocketTransport(
+ private val url: URL,
+ private val okHttpClient: OkHttpClient
+) :
+ WebSocketListener(),
+ Transport {
+
+ internal var connection: WebSocket? = null
+
+ override var readyState: Transport.ReadyState = Transport.ReadyState.CLOSED
+ override var onOpen: (() -> Unit)? = null
+ override var onError: ((Throwable, Response?) -> Unit)? = null
+ override var onMessage: ((String) -> Unit)? = null
+ override var onClose: ((Int) -> Unit)? = null
+
+ override fun connect() {
+ this.readyState = Transport.ReadyState.CONNECTING
+ val request = Request.Builder().url(url).build()
+ connection = okHttpClient.newWebSocket(request, this)
+ }
+
+ override fun disconnect(code: Int, reason: String?) {
+ connection?.close(code, reason)
+ connection = null
+ }
+
+ override fun send(data: String) {
+ connection?.send(data)
+ }
+
+ //------------------------------------------------------------------------------
+ // WebSocket Listener
+ //------------------------------------------------------------------------------
+ override fun onOpen(webSocket: WebSocket, response: Response) {
+ this.readyState = Transport.ReadyState.OPEN
+ this.onOpen?.invoke()
+ }
+
+ override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
+ // Set the state of the Transport as CLOSED since no more data will be received
+ this.readyState = Transport.ReadyState.CLOSED
+
+ // Invoke the onError callback, to inform of the error
+ this.onError?.invoke(t, response)
+
+ /*
+ According to the OkHttp documentation, `onFailure` will be
+
+ "Invoked when a web socket has been closed due to an error reading from or writing to the
+ network. Both outgoing and incoming messages may have been lost. No further calls to this
+ listener will be made."
+
+ This means `onClose` will never be called which will never kick off the socket reconnect
+ attempts.
+
+ The JS WebSocket class calls `onError` and then `onClose` which will then trigger
+ the reconnect logic inside of the PhoenixClient. In order to mimic this behavior and abstract
+ this detail of OkHttp away from the PhoenixClient, the `WebSocketTransport` class should
+ convert `onFailure` calls to an `onError` and `onClose` sequence.
+ */
+ this.onClose?.invoke(WS_CLOSE_ABNORMAL)
+ }
+
+ override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
+ this.readyState = Transport.ReadyState.CLOSING
+ }
+
+ override fun onMessage(webSocket: WebSocket, text: String) {
+ this.onMessage?.invoke(text)
+ }
+
+ override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
+ this.readyState = Transport.ReadyState.CLOSED
+ this.onClose?.invoke(code)
+ }
+}
\ No newline at end of file
diff --git a/src/test/kotlin/org/phoenixframework/ChannelTest.kt b/src/test/kotlin/org/phoenixframework/ChannelTest.kt
new file mode 100644
index 0000000..bd83c49
--- /dev/null
+++ b/src/test/kotlin/org/phoenixframework/ChannelTest.kt
@@ -0,0 +1,1167 @@
+package org.phoenixframework
+
+import com.google.common.truth.Truth.assertThat
+import org.mockito.kotlin.any
+import org.mockito.kotlin.eq
+import org.mockito.kotlin.mock
+import org.mockito.kotlin.never
+import org.mockito.kotlin.spy
+import org.mockito.kotlin.times
+import org.mockito.kotlin.verify
+import org.mockito.kotlin.whenever
+import okhttp3.OkHttpClient
+import org.junit.jupiter.api.AfterEach
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.DisplayName
+import org.junit.jupiter.api.Nested
+import org.junit.jupiter.api.Test
+import org.mockito.Mock
+import org.mockito.Mockito.verifyNoInteractions
+import org.mockito.MockitoAnnotations
+import org.phoenixframework.queue.ManualDispatchQueue
+import org.phoenixframework.utilities.getBindings
+
+class ChannelTest {
+
+
+ @Mock lateinit var okHttpClient: OkHttpClient
+
+ @Mock lateinit var socket: Socket
+ @Mock lateinit var mockCallback: ((Message) -> Unit)
+
+ private val kDefaultRef = "1"
+ private val kDefaultTimeout = 10_000L
+ private val kDefaultPayload: Payload = mapOf("one" to "two")
+ private val kEmptyPayload: Payload = mapOf()
+
+ lateinit var fakeClock: ManualDispatchQueue
+ lateinit var channel: Channel
+
+ @BeforeEach
+ internal fun setUp() {
+ MockitoAnnotations.initMocks(this)
+
+ fakeClock = ManualDispatchQueue()
+
+ whenever(socket.dispatchQueue).thenReturn(fakeClock)
+ whenever(socket.makeRef()).thenReturn(kDefaultRef)
+ whenever(socket.timeout).thenReturn(kDefaultTimeout)
+ whenever(socket.reconnectAfterMs).thenReturn(Defaults.reconnectSteppedBackOff)
+ whenever(socket.rejoinAfterMs).thenReturn(Defaults.rejoinSteppedBackOff)
+
+ channel = Channel("topic", kDefaultPayload, socket)
+ }
+
+ @AfterEach
+ internal fun tearDown() {
+ fakeClock.reset()
+ }
+
+ @Nested
+ @DisplayName("ChannelEvent")
+ inner class ChannelEvent {
+ @Test
+ internal fun `isLifecycleEvent returns true for lifecycle events`() {
+ assertThat(Channel.Event.isLifecycleEvent(Channel.Event.HEARTBEAT.value)).isFalse()
+ assertThat(Channel.Event.isLifecycleEvent(Channel.Event.JOIN.value)).isTrue()
+ assertThat(Channel.Event.isLifecycleEvent(Channel.Event.LEAVE.value)).isTrue()
+ assertThat(Channel.Event.isLifecycleEvent(Channel.Event.REPLY.value)).isTrue()
+ assertThat(Channel.Event.isLifecycleEvent(Channel.Event.ERROR.value)).isTrue()
+ assertThat(Channel.Event.isLifecycleEvent(Channel.Event.CLOSE.value)).isTrue()
+ assertThat(Channel.Event.isLifecycleEvent("random")).isFalse()
+ }
+
+ /* End ChannelEvent */
+ }
+
+ @Nested
+ @DisplayName("constructor")
+ inner class Constructor {
+ @Test
+ internal fun `sets defaults`() {
+ assertThat(channel.state).isEqualTo(Channel.State.CLOSED)
+ assertThat(channel.topic).isEqualTo("topic")
+ assertThat(channel.params["one"]).isEqualTo("two")
+ assertThat(channel.socket).isEqualTo(socket)
+ assertThat(channel.timeout).isEqualTo(10_000L)
+ assertThat(channel.joinedOnce).isFalse()
+ assertThat(channel.joinPush).isNotNull()
+ assertThat(channel.pushBuffer).isEmpty()
+ }
+
+ @Test
+ internal fun `sets up joinPush with literal params`() {
+ val joinPush = channel.joinPush
+
+ assertThat(joinPush.channel).isEqualTo(channel)
+ assertThat(joinPush.payload["one"]).isEqualTo("two")
+ assertThat(joinPush.event).isEqualTo("phx_join")
+ assertThat(joinPush.timeout).isEqualTo(10_000L)
+ }
+
+ /* End Constructor */
+ }
+
+ @Nested
+ @DisplayName("onMessage")
+ inner class OnMessage {
+ @Test
+ internal fun `returns message by default`() {
+ val message = channel.onMessage.invoke(Message(ref = "original"))
+ assertThat(message.ref).isEqualTo("original")
+ }
+
+ @Test
+ internal fun `can be overidden`() {
+ channel.onMessage { Message(ref = "changed") }
+
+ val message = channel.onMessage.invoke(Message(ref = "original"))
+ assertThat(message.ref).isEqualTo("changed")
+ }
+
+ /* End OnMessage */
+ }
+
+ @Nested
+ @DisplayName("join params")
+ inner class JoinParams {
+ @Test
+ internal fun `updating join params`() {
+ val params = mapOf("value" to 1)
+ val change = mapOf("value" to 2)
+
+ channel = Channel("topic", params, socket)
+ val joinPush = channel.joinPush
+
+ assertThat(joinPush.channel).isEqualTo(channel)
+ assertThat(joinPush.payload["value"]).isEqualTo(1)
+ assertThat(joinPush.event).isEqualTo("phx_join")
+ assertThat(joinPush.timeout).isEqualTo(10_000L)
+
+ channel.params = change
+ assertThat(joinPush.channel).isEqualTo(channel)
+ assertThat(joinPush.payload["value"]).isEqualTo(2)
+ assertThat(channel.params["value"]).isEqualTo(2)
+ assertThat(joinPush.event).isEqualTo("phx_join")
+ assertThat(joinPush.timeout).isEqualTo(10_000L)
+ }
+
+ /* End JoinParams */
+ }
+
+ @Nested
+ @DisplayName("join")
+ inner class Join {
+
+ @BeforeEach
+ internal fun setUp() {
+ socket = spy(Socket(url = "https://localhost:4000/socket", client = okHttpClient))
+ socket.dispatchQueue = fakeClock
+ channel = Channel("topic", kDefaultPayload, socket)
+ }
+
+ @Test
+ internal fun `sets state to joining`() {
+ channel.join()
+ assertThat(channel.state).isEqualTo(Channel.State.JOINING)
+ }
+
+ @Test
+ internal fun `sets joinedOnce to true`() {
+ assertThat(channel.joinedOnce).isFalse()
+
+ channel.join()
+ assertThat(channel.joinedOnce).isTrue()
+ }
+
+ @Test
+ internal fun `throws if attempting to join multiple times`() {
+ var exceptionThrown = false
+ try {
+ channel.join()
+ channel.join()
+ } catch (e: Exception) {
+ exceptionThrown = true
+ assertThat(e).isInstanceOf(IllegalStateException::class.java)
+ assertThat(e.message).isEqualTo(
+ "Tried to join channel multiple times. `join()` can only be called once per channel")
+ }
+
+ assertThat(exceptionThrown).isTrue()
+ }
+
+ @Test
+ internal fun `triggers socket push with channel params`() {
+ channel.join()
+ verify(socket).push("topic", "phx_join", kDefaultPayload, "3", channel.joinRef)
+ }
+
+ @Test
+ internal fun `can set timeout on joinPush`() {
+ val newTimeout = 20_000L
+ val joinPush = channel.joinPush
+
+ assertThat(joinPush.timeout).isEqualTo(kDefaultTimeout)
+ channel.join(newTimeout)
+ assertThat(joinPush.timeout).isEqualTo(newTimeout)
+ }
+
+ @Test
+ internal fun `leaves existing duplicate topic on new join`() {
+ val socket = spy(Socket("wss://localhost:4000/socket"))
+ val channel = socket.channel("topic")
+
+ channel.join().receive("ok") {
+ val newChannel = socket.channel("topic")
+ assertThat(channel.isJoined).isTrue()
+ newChannel.join()
+
+ assertThat(channel.isJoined).isFalse()
+ }
+
+ channel.joinPush.trigger("ok", kEmptyPayload)
+ }
+
+ @Nested
+ @DisplayName("timeout behavior")
+ inner class TimeoutBehavior {
+
+ private lateinit var joinPush: Push
+
+ private fun receiveSocketOpen() {
+ whenever(socket.isConnected).thenReturn(true)
+ socket.onConnectionOpened()
+ }
+
+ @BeforeEach
+ internal fun setUp() {
+ joinPush = channel.joinPush
+ }
+
+ @Test
+ internal fun `succeeds before timeout`() {
+ val timeout = channel.timeout
+
+ socket.connect()
+ this.receiveSocketOpen()
+
+ channel.join()
+ verify(socket).push(any(), any(), any(), any(), any())
+ assertThat(channel.timeout).isEqualTo(10_000)
+
+ fakeClock.tick(100)
+
+ joinPush.trigger("ok", kEmptyPayload)
+ assertThat(channel.state).isEqualTo(Channel.State.JOINED)
+
+ fakeClock.tick(timeout)
+ verify(socket, times(1)).push(any(), any(), any(), any(), any())
+ }
+
+ @Test
+ internal fun `retries with backoff after timeout`() {
+ val timeout = channel.timeout
+
+ socket.connect()
+ this.receiveSocketOpen()
+
+ channel.join().receive("timeout", mockCallback)
+
+ verify(socket, times(1)).push(any(), eq("phx_join"), any(), any(), any())
+ verify(mockCallback, never()).invoke(any())
+
+ fakeClock.tick(timeout) // leave pushed to server
+ verify(socket, times(1)).push(any(), eq("phx_leave"), any(), any(), any())
+ verify(mockCallback, times(1)).invoke(any())
+
+ fakeClock.tick(timeout + 1000) // rejoin
+ verify(socket, times(2)).push(any(), eq("phx_join"), any(), any(), any())
+ verify(socket, times(2)).push(any(), eq("phx_leave"), any(), any(), any())
+ verify(mockCallback, times(2)).invoke(any())
+
+ fakeClock.tick(10_000)
+ joinPush.trigger("ok", kEmptyPayload)
+ verify(socket, times(3)).push(any(), eq("phx_join"), any(), any(), any())
+ assertThat(channel.state).isEqualTo(Channel.State.JOINED)
+ }
+
+ @Test
+ internal fun `with socket and join delay`() {
+ val joinPush = channel.joinPush
+
+ channel.join()
+ verify(socket, times(1)).push(any(), any(), any(), any(), any())
+
+ // Open the socket after a delay
+ fakeClock.tick(9_000)
+ verify(socket, times(1)).push(any(), any(), any(), any(), any())
+
+ // join request returns between timeouts
+ fakeClock.tick(1_000)
+ socket.connect()
+
+ assertThat(channel.state).isEqualTo(Channel.State.ERRORED)
+ this.receiveSocketOpen()
+ joinPush.trigger("ok", kEmptyPayload)
+
+ fakeClock.tick(1_000)
+ assertThat(channel.state).isEqualTo(Channel.State.JOINED)
+
+ verify(socket, times(3)).push(any(), any(), any(), any(), any())
+ }
+
+ @Test
+ internal fun `with socket delay only`() {
+ val joinPush = channel.joinPush
+
+ channel.join()
+ assertThat(channel.state).isEqualTo(Channel.State.JOINING)
+
+ // connect socket after a delay
+ fakeClock.tick(6_000)
+ socket.connect()
+
+ // open socket after delay
+ fakeClock.tick(5_000)
+ this.receiveSocketOpen()
+ joinPush.trigger("ok", kEmptyPayload)
+
+ joinPush.trigger("ok", kEmptyPayload)
+ assertThat(channel.state).isEqualTo(Channel.State.JOINED)
+ }
+
+
+ @Test
+ internal fun `does not enter rejoin loop if leave is called before joined`() {
+ socket.connect()
+ this.receiveSocketOpen()
+
+ channel.join()
+ channel.leave()
+
+ fakeClock.tick(channel.timeout * 4)
+ verify(socket, times(2)).push(any(), any(), any(), any(), any())
+ }
+
+ /* End TimeoutBehavior */
+ }
+
+ /* End Join */
+ }
+
+ @Nested
+ @DisplayName("joinPush")
+ inner class JoinPush {
+
+ private lateinit var joinPush: Push
+
+ /* setup */
+ @BeforeEach
+ internal fun setUp() {
+ socket = spy(Socket("https://localhost:4000/socket"))
+ socket.dispatchQueue = fakeClock
+
+ whenever(socket.isConnected).thenReturn(true)
+
+ channel = Channel("topic", kDefaultPayload, socket)
+ joinPush = channel.joinPush
+
+ channel.join()
+ }
+
+ /* helper methods */
+ private fun receivesOk() {
+ fakeClock.tick(joinPush.timeout / 2)
+ joinPush.trigger("ok", mapOf("a" to "b"))
+ }
+
+ private fun receivesTimeout() {
+ fakeClock.tick(joinPush.timeout * 2)
+ }
+
+ private fun receivesError() {
+ fakeClock.tick(joinPush.timeout / 2)
+ joinPush.trigger("error", mapOf("a" to "b"))
+ }
+
+ @Nested
+ @DisplayName("receives 'ok'")
+ inner class ReceivesOk {
+ @Test
+ internal fun `sets channel state to joined`() {
+ assertThat(channel.state).isNotEqualTo(Channel.State.JOINED)
+
+ receivesOk()
+ assertThat(channel.state).isEqualTo(Channel.State.JOINED)
+ }
+
+ @Test
+ internal fun `triggers receive(ok) callback after ok response`() {
+ joinPush.receive("ok", mockCallback)
+
+ receivesOk()
+ verify(mockCallback, times(1)).invoke(any())
+ }
+
+ @Test
+ internal fun `triggers receive('ok') callback if ok response already received`() {
+ receivesOk()
+ joinPush.receive("ok", mockCallback)
+
+ verify(mockCallback, times(1)).invoke(any())
+ }
+
+ @Test
+ internal fun `does not trigger other receive callbacks after ok response`() {
+ joinPush
+ .receive("error", mockCallback)
+ .receive("timeout", mockCallback)
+
+ receivesOk()
+ receivesTimeout()
+ verify(mockCallback, times(0)).invoke(any())
+ }
+
+ @Test
+ internal fun `clears timeoutTimer workItem`() {
+ assertThat(joinPush.timeoutTask).isNotNull()
+
+ val mockTimeoutTask = mock()
+ joinPush.timeoutTask = mockTimeoutTask
+
+ receivesOk()
+ verify(mockTimeoutTask).cancel()
+ assertThat(joinPush.timeoutTask).isNull()
+ }
+
+ @Test
+ internal fun `sets receivedMessage`() {
+ assertThat(joinPush.receivedMessage).isNull()
+
+ receivesOk()
+ assertThat(joinPush.receivedMessage?.payload).isEqualTo(mapOf("status" to "ok", "a" to "b"))
+ assertThat(joinPush.receivedMessage?.status).isEqualTo("ok")
+ }
+
+ @Test
+ internal fun `removes channel binding`() {
+ var bindings = channel.getBindings("chan_reply_3")
+ assertThat(bindings).hasSize(1)
+
+ receivesOk()
+ bindings = channel.getBindings("chan_reply_3")
+ assertThat(bindings).isEmpty()
+ }
+
+ @Test
+ internal fun `resets channel rejoinTimer`() {
+ val mockRejoinTimer = mock()
+ channel.rejoinTimer = mockRejoinTimer
+
+ receivesOk()
+ verify(mockRejoinTimer, times(1)).reset()
+ }
+
+ @Test
+ internal fun `sends and empties channel's buffered pushEvents`() {
+ val mockPush = mock()
+ channel.pushBuffer.add(mockPush)
+
+ receivesOk()
+ verify(mockPush).send()
+ assertThat(channel.pushBuffer).isEmpty()
+ }
+
+ /* End ReceivesOk */
+ }
+
+ @Nested
+ @DisplayName("receives 'timeout'")
+ inner class ReceivesTimeout {
+ @Test
+ internal fun `sets channel state to errored`() {
+ var timeoutReceived = false
+ joinPush.receive("timeout") {
+ timeoutReceived = true
+ assertThat(channel.state).isEqualTo(Channel.State.ERRORED)
+ }
+
+ receivesTimeout()
+ assertThat(timeoutReceived).isTrue()
+ }
+
+ @Test
+ internal fun `triggers receive('timeout') callback after ok response`() {
+ val mockCallback = mock<(Message) -> Unit>()
+ joinPush.receive("timeout", mockCallback)
+
+ receivesTimeout()
+ verify(mockCallback).invoke(any())
+ }
+
+ @Test
+ internal fun `does not trigger other receive callbacks after timeout response`() {
+ val mockOk = mock<(Message) -> Unit>()
+ val mockError = mock<(Message) -> Unit>()
+ var timeoutReceived = false
+
+ joinPush
+ .receive("ok", mockOk)
+ .receive("error", mockError)
+ .receive("timeout") {
+ verifyNoInteractions(mockOk)
+ verifyNoInteractions(mockError)
+ timeoutReceived = true
+ }
+
+ receivesTimeout()
+ receivesOk()
+
+ assertThat(timeoutReceived).isTrue()
+ }
+
+ @Test
+ internal fun `schedules rejoinTimer timeout`() {
+ val mockTimer = mock()
+ channel.rejoinTimer = mockTimer
+
+ receivesTimeout()
+ verify(mockTimer).scheduleTimeout()
+ }
+
+ /* End ReceivesTimeout */
+ }
+
+ @Nested
+ @DisplayName("receives 'error'")
+ inner class ReceivesError {
+ @Test
+ internal fun `triggers receive('error') callback after error response`() {
+ assertThat(channel.state).isEqualTo(Channel.State.JOINING)
+ joinPush.receive("error", mockCallback)
+
+ receivesError()
+ joinPush.trigger("error", kEmptyPayload)
+ verify(mockCallback, times(1)).invoke(any())
+ }
+
+ @Test
+ internal fun `triggers receive('error') callback if error response already received`() {
+ receivesError()
+
+ joinPush.receive("error", mockCallback)
+
+ verify(mockCallback).invoke(any())
+ }
+
+ @Test
+ internal fun `does not trigger other receive callbacks after ok response`() {
+ val mockOk = mock<(Message) -> Unit>()
+ val mockError = mock<(Message) -> Unit>()
+ val mockTimeout = mock<(Message) -> Unit>()
+ joinPush
+ .receive("ok", mockOk)
+ .receive("error") {
+ mockError.invoke(it)
+ channel.leave()
+ }
+ .receive("timeout", mockTimeout)
+
+ receivesError()
+ receivesTimeout()
+
+ verify(mockError, times(1)).invoke(any())
+ verifyNoInteractions(mockOk)
+ verifyNoInteractions(mockTimeout)
+ }
+
+ @Test
+ internal fun `clears timeoutTimer workItem`() {
+ val mockTask = mock()
+ assertThat(joinPush.timeoutTask).isNotNull()
+
+ joinPush.timeoutTask = mockTask
+ receivesError()
+
+ verify(mockTask).cancel()
+ assertThat(joinPush.timeoutTask).isNull()
+ }
+
+ @Test
+ internal fun `sets receivedMessage`() {
+ assertThat(joinPush.receivedMessage).isNull()
+
+ receivesError()
+ assertThat(joinPush.receivedMessage).isNotNull()
+ assertThat(joinPush.receivedMessage?.status).isEqualTo("error")
+ assertThat(joinPush.receivedMessage?.payload?.get("a")).isEqualTo("b")
+ }
+
+ @Test
+ internal fun `removes channel binding`() {
+ var bindings = channel.getBindings("chan_reply_3")
+ assertThat(bindings).hasSize(1)
+
+ receivesError()
+ bindings = channel.getBindings("chan_reply_1")
+ assertThat(bindings).isEmpty()
+ }
+
+ @Test
+ internal fun `does not sets channel state to joined`() {
+ receivesError()
+ assertThat(channel.state).isNotEqualTo(Channel.State.JOINED)
+ }
+
+ @Test
+ internal fun `does not trigger channel's buffered pushEvents`() {
+ val mockPush = mock()
+ channel.pushBuffer.add(mockPush)
+
+ receivesError()
+ verifyNoInteractions(mockPush)
+ assertThat(channel.pushBuffer).hasSize(1)
+ }
+
+ /* End ReceivesError */
+ }
+
+ /* End JoinPush */
+ }
+
+ @Nested
+ @DisplayName("onError")
+ inner class OnError {
+
+ private lateinit var joinPush: Push
+
+ /* setup */
+ @BeforeEach
+ internal fun setUp() {
+ socket = spy(Socket("https://localhost:4000/socket"))
+ socket.dispatchQueue = fakeClock
+
+ whenever(socket.isConnected).thenReturn(true)
+
+ channel = Channel("topic", kDefaultPayload, socket)
+ joinPush = channel.joinPush
+
+ channel.join()
+ joinPush.trigger("ok", kEmptyPayload)
+ }
+
+
+ @Test
+ internal fun `sets channel state to errored`() {
+ assertThat(channel.state).isNotEqualTo(Channel.State.ERRORED)
+
+ channel.trigger(Channel.Event.ERROR)
+ assertThat(channel.state).isEqualTo(Channel.State.ERRORED)
+ }
+
+ @Test
+ internal fun `does not trigger redundant errors during backoff`() {
+ // Spy the channel's join push
+ joinPush = spy(channel.joinPush)
+ channel.joinPush = joinPush
+
+ verify(joinPush, times(0)).send()
+
+ channel.trigger(Channel.Event.ERROR)
+
+ fakeClock.tick(1000)
+ verify(joinPush, times(1)).send()
+
+ channel.trigger(Channel.Event.ERROR)
+
+ fakeClock.tick(1000)
+ verify(joinPush, times(1)).send()
+ }
+
+ @Test
+ internal fun `removes the joinPush message from sendBuffer`() {
+ val channel = Channel("topic", kDefaultPayload, socket)
+ val push = mock()
+ whenever(push.ref).thenReturn("10")
+ channel.joinPush = push
+ channel.state = Channel.State.JOINING
+
+ channel.trigger(Channel.Event.ERROR)
+ verify(socket).removeFromSendBuffer("10")
+ verify(push).reset()
+ }
+
+ @Test
+ internal fun `tries to rejoin with backoff`() {
+ val mockTimer = mock()
+ channel.rejoinTimer = mockTimer
+
+ channel.trigger(Channel.Event.ERROR)
+ verify(mockTimer).scheduleTimeout()
+ }
+
+ @Test
+ internal fun `does not rejoin if leaving channel`() {
+ channel.state = Channel.State.LEAVING
+
+ // Spy the joinPush
+ joinPush = spy(channel.joinPush)
+ channel.joinPush = joinPush
+
+ socket.onConnectionError(Throwable(), null)
+
+ fakeClock.tick(1_000)
+ verify(joinPush, never()).send()
+
+ fakeClock.tick(2_000)
+ verify(joinPush, never()).send()
+
+ assertThat(channel.state).isEqualTo(Channel.State.LEAVING)
+ }
+
+ @Test
+ internal fun `does not rejoin if channel is closed`() {
+ channel.state = Channel.State.CLOSED
+
+ // Spy the joinPush
+ joinPush = spy(channel.joinPush)
+ channel.joinPush = joinPush
+
+ socket.onConnectionError(Throwable(), null)
+
+ fakeClock.tick(1_000)
+ verify(joinPush, never()).send()
+
+ fakeClock.tick(2_000)
+ verify(joinPush, never()).send()
+
+ assertThat(channel.state).isEqualTo(Channel.State.CLOSED)
+ }
+
+ @Test
+ internal fun `triggers additional callbacks after join`() {
+ channel.onError(mockCallback)
+ joinPush.trigger("ok", kEmptyPayload)
+
+ assertThat(channel.state).isEqualTo(Channel.State.JOINED)
+ verifyNoInteractions(mockCallback)
+
+ channel.trigger(Channel.Event.ERROR)
+ verify(mockCallback, times(1)).invoke(any())
+ }
+
+ /* End OnError */
+ }
+
+ @Nested
+ @DisplayName("onClose")
+ inner class OnClose {
+
+ private lateinit var joinPush: Push
+
+ /* setup */
+ @BeforeEach
+ internal fun setUp() {
+ socket = spy(Socket("https://localhost:4000/socket"))
+ socket.dispatchQueue = fakeClock
+
+ whenever(socket.isConnected).thenReturn(true)
+
+ channel = Channel("topic", kDefaultPayload, socket)
+ joinPush = channel.joinPush
+
+ channel.join()
+ }
+
+
+ @Test
+ internal fun `sets state to closed`() {
+ assertThat(channel.state).isNotEqualTo(Channel.State.CLOSED)
+
+ channel.trigger(Channel.Event.CLOSE)
+ assertThat(channel.state).isEqualTo(Channel.State.CLOSED)
+ }
+
+ @Test
+ internal fun `does not rejoin`() {
+ // Spy the channel's join push
+ joinPush = spy(channel.joinPush)
+ channel.joinPush = joinPush
+
+ channel.trigger(Channel.Event.CLOSE)
+
+ fakeClock.tick(1_000)
+ verify(joinPush, never()).send()
+
+ fakeClock.tick(2_000)
+ verify(joinPush, never()).send()
+ }
+
+ @Test
+ internal fun `resets the rejoin timer`() {
+ val mockTimer = mock()
+ channel.rejoinTimer = mockTimer
+
+ channel.trigger(Channel.Event.CLOSE)
+ verify(mockTimer).reset()
+ }
+
+ @Test
+ internal fun `removes channel from socket`() {
+ channel.trigger(Channel.Event.CLOSE)
+ verify(socket).remove(channel)
+ }
+
+ @Test
+ internal fun `triggers additional callbacks`() {
+ channel.onClose(mockCallback)
+ verifyNoInteractions(mockCallback)
+
+ channel.trigger(Channel.Event.CLOSE)
+ verify(mockCallback, times(1)).invoke(any())
+ }
+
+ /* End OnClose */
+ }
+
+ @Nested
+ @DisplayName("canPush")
+ inner class CanPush {
+ @Test
+ internal fun `returns true when socket connected and channel joined`() {
+ channel.state = Channel.State.JOINED
+ whenever(socket.isConnected).thenReturn(true)
+
+ assertThat(channel.canPush).isTrue()
+ }
+
+ @Test
+ internal fun `otherwise returns false`() {
+ channel.state = Channel.State.JOINED
+ whenever(socket.isConnected).thenReturn(false)
+ assertThat(channel.canPush).isFalse()
+
+ channel.state = Channel.State.JOINING
+ whenever(socket.isConnected).thenReturn(true)
+ assertThat(channel.canPush).isFalse()
+
+ channel.state = Channel.State.JOINING
+ whenever(socket.isConnected).thenReturn(false)
+ assertThat(channel.canPush).isFalse()
+ }
+
+ /* End CanPush */
+ }
+
+ @Nested
+ @DisplayName("on(event, callback)")
+ inner class OnEventCallback {
+ @Test
+ internal fun `sets up callback for event`() {
+ channel.trigger(event = "event", ref = kDefaultRef)
+
+ channel.on("event", mockCallback)
+ channel.trigger(event = "event", ref = kDefaultRef)
+ verify(mockCallback, times(1)).invoke(any())
+ }
+
+ @Test
+ internal fun `other event callbacks are ignored`() {
+ val mockIgnoredCallback = mock<(Message) -> Unit>()
+
+ channel.on("ignored_event", mockIgnoredCallback)
+ channel.trigger(event = "event", ref = kDefaultRef)
+
+ channel.on("event", mockCallback)
+ channel.trigger(event = "event", ref = kDefaultRef)
+
+ verify(mockIgnoredCallback, never()).invoke(any())
+ }
+
+ @Test
+ internal fun `generates unique refs for callbacks`() {
+ val ref1 = channel.on("event1") {}
+ val ref2 = channel.on("event2") {}
+
+ assertThat(ref1).isNotEqualTo(ref2)
+ assertThat(ref1 + 1).isEqualTo(ref2)
+ }
+
+ /* End OnEventCallback */
+ }
+
+ @Nested
+ @DisplayName("off")
+ inner class Off {
+ @Test
+ internal fun `removes all callbacks for event`() {
+ val callback1 = mock<(Message) -> Unit>()
+ val callback2 = mock<(Message) -> Unit>()
+ val callback3 = mock<(Message) -> Unit>()
+
+ channel.on("event", callback1)
+ channel.on("event", callback2)
+ channel.on("other", callback3)
+
+ channel.off("event")
+ channel.trigger(event = "event", ref = kDefaultRef)
+ channel.trigger(event = "other", ref = kDefaultRef)
+
+ verifyNoInteractions(callback1)
+ verifyNoInteractions(callback2)
+ verify(callback3, times(1)).invoke(any())
+ }
+
+ @Test
+ internal fun `removes callback by ref`() {
+ val callback1 = mock<(Message) -> Unit>()
+ val callback2 = mock<(Message) -> Unit>()
+
+ val ref1 = channel.on("event", callback1)
+ channel.on("event", callback2)
+
+ channel.off("event", ref1)
+ channel.trigger(event = "event", ref = kDefaultRef)
+
+ verifyNoInteractions(callback1)
+ verify(callback2, times(1)).invoke(any())
+ }
+
+ /* End Off */
+ }
+
+ @Nested
+ @DisplayName("push")
+ inner class PushFunction {
+
+ @BeforeEach
+ internal fun setUp() {
+ whenever(socket.isConnected).thenReturn(true)
+ }
+
+ @Test
+ internal fun `sends push event when successfully joined`() {
+ channel.join().trigger("ok", kEmptyPayload)
+ channel.push("event", mapOf("foo" to "bar"))
+
+ verify(socket).push("topic", "event", mapOf("foo" to "bar"), channel.joinRef, kDefaultRef)
+ }
+
+ @Test
+ internal fun `enqueues push event to be sent once join has succeeded`() {
+ val joinPush = channel.join()
+ channel.push("event", mapOf("foo" to "bar"))
+
+ verify(socket, never()).push(any(), any(), eq(mapOf("foo" to "bar")), any(), any())
+
+ fakeClock.tick(channel.timeout / 2)
+ joinPush.trigger("ok", kEmptyPayload)
+
+ verify(socket).push(any(), any(), eq(mapOf("foo" to "bar")), any(), any())
+ }
+
+ @Test
+ internal fun `does not push if channel join times out`() {
+ val joinPush = channel.join()
+ channel.push("event", mapOf("foo" to "bar"))
+
+ verify(socket, never()).push(any(), any(), eq(mapOf("foo" to "bar")), any(), any())
+
+ fakeClock.tick(channel.timeout * 2)
+ joinPush.trigger("ok", kEmptyPayload)
+
+ verify(socket, never()).push(any(), any(), eq(mapOf("foo" to "bar")), any(), any())
+ }
+
+ @Test
+ internal fun `uses channel timeout by default`() {
+ channel.join().trigger("ok", kEmptyPayload)
+ channel
+ .push("event", mapOf("foo" to "bar"))
+ .receive("timeout", mockCallback)
+
+ fakeClock.tick(channel.timeout / 2)
+ verifyNoInteractions(mockCallback)
+
+ fakeClock.tick(channel.timeout)
+ verify(mockCallback).invoke(any())
+ }
+
+ @Test
+ internal fun `accepts timeout arg`() {
+ channel.join().trigger("ok", kEmptyPayload)
+ channel
+ .push("event", mapOf("foo" to "bar"), channel.timeout * 2)
+ .receive("timeout", mockCallback)
+
+ fakeClock.tick(channel.timeout)
+ verifyNoInteractions(mockCallback)
+
+ fakeClock.tick(channel.timeout * 2)
+ verify(mockCallback).invoke(any())
+ }
+
+ @Test
+ internal fun `does not time out after receiving 'ok'`() {
+ channel.join().trigger("ok", kEmptyPayload)
+ val push = channel
+ .push("event", mapOf("foo" to "bar"), channel.timeout * 2)
+ .receive("timeout", mockCallback)
+
+ fakeClock.tick(channel.timeout / 2)
+ verifyNoInteractions(mockCallback)
+
+ push.trigger("ok", kEmptyPayload)
+
+ fakeClock.tick(channel.timeout)
+ verifyNoInteractions(mockCallback)
+ }
+
+ @Test
+ internal fun `throws if channel has not been joined`() {
+ var exceptionThrown = false
+ try {
+ channel.push("event", kEmptyPayload)
+ } catch (e: Exception) {
+ exceptionThrown = true
+ assertThat(e.message).isEqualTo(
+ "Tried to push event to topic before joining. Use channel.join() before pushing events")
+ }
+
+ assertThat(exceptionThrown).isTrue()
+ }
+
+ /* End PushFunction */
+ }
+
+ @Nested
+ @DisplayName("leave")
+ inner class Leave {
+ @BeforeEach
+ internal fun setUp() {
+ whenever(socket.isConnected).thenReturn(true)
+ channel.join().trigger("ok", kEmptyPayload)
+ }
+
+ @Test
+ internal fun `unsubscribes from server events`() {
+ val joinRef = channel.joinRef
+ channel.leave()
+
+ verify(socket).push("topic", "phx_leave", emptyMap(), joinRef, kDefaultRef)
+ }
+
+ @Test
+ internal fun `closes channel on 'ok' from server`() {
+ channel.leave().trigger("ok", kEmptyPayload)
+ verify(socket).remove(channel)
+ }
+
+ @Test
+ internal fun `sets state to closed on 'ok' event`() {
+ channel.leave().trigger("ok", kEmptyPayload)
+ assertThat(channel.state).isEqualTo(Channel.State.CLOSED)
+ }
+
+ @Test
+ internal fun `sets state to leaving initially`() {
+ channel.leave()
+ assertThat(channel.state).isEqualTo(Channel.State.LEAVING)
+ }
+
+ @Test
+ internal fun `closes channel on timeout`() {
+ channel.leave()
+ assertThat(channel.state).isEqualTo(Channel.State.LEAVING)
+
+ fakeClock.tick(channel.timeout)
+ assertThat(channel.state).isEqualTo(Channel.State.CLOSED)
+ }
+
+ @Test
+ internal fun `triggers immediately if cannot push`() {
+ whenever(socket.isConnected).thenReturn(false)
+ channel.leave()
+ assertThat(channel.state).isEqualTo(Channel.State.CLOSED)
+ }
+
+ /* End Leave */
+ }
+
+ @Nested
+ @DisplayName("state accessors")
+ inner class StateAccessors {
+ @Test
+ fun `isClosed returns true if state is CLOSED`() {
+ channel.state = Channel.State.JOINED
+ assertThat(channel.isClosed).isFalse()
+
+ channel.state = Channel.State.CLOSED
+ assertThat(channel.isClosed).isTrue()
+ }
+
+ @Test
+ fun `isErrored returns true if state is ERRORED`() {
+ channel.state = Channel.State.JOINED
+ assertThat(channel.isErrored).isFalse()
+
+ channel.state = Channel.State.ERRORED
+ assertThat(channel.isErrored).isTrue()
+ }
+
+ @Test
+ fun `isJoined returns true if state is JOINED`() {
+ channel.state = Channel.State.JOINING
+ assertThat(channel.isJoined).isFalse()
+
+ channel.state = Channel.State.JOINED
+ assertThat(channel.isJoined).isTrue()
+ }
+
+ @Test
+ fun `isJoining returns true if state is JOINING`() {
+ channel.state = Channel.State.JOINED
+ assertThat(channel.isJoining).isFalse()
+
+ channel.state = Channel.State.JOINING
+ assertThat(channel.isJoining).isTrue()
+ }
+
+ @Test
+ fun `isLeaving returns true if state is LEAVING`() {
+ channel.state = Channel.State.JOINED
+ assertThat(channel.isLeaving).isFalse()
+
+ channel.state = Channel.State.LEAVING
+ assertThat(channel.isLeaving).isTrue()
+ }
+ /* End StateAccessors */
+ }
+
+ @Nested
+ @DisplayName("isMember")
+ inner class IsMember {
+
+ @Test
+ fun `returns false if topics are different`() {
+ val message = Message(topic = "other-topic")
+ assertThat(channel.isMember(message)).isFalse()
+ }
+
+ @Test
+ fun `drops outdated messages`() {
+ channel.joinPush.ref = "9"
+ val message = Message(topic = "topic", event = Channel.Event.LEAVE.value, joinRef = "7")
+ assertThat(channel.isMember(message)).isFalse()
+ }
+
+ @Test
+ fun `returns true if message belongs to channel`() {
+ val message = Message(topic = "topic", event = "msg:new")
+ assertThat(channel.isMember(message)).isTrue()
+ }
+
+ /* End IsMember */
+ }
+
+}
\ No newline at end of file
diff --git a/src/test/kotlin/org/phoenixframework/DefaultsTest.kt b/src/test/kotlin/org/phoenixframework/DefaultsTest.kt
new file mode 100644
index 0000000..bf2da9b
--- /dev/null
+++ b/src/test/kotlin/org/phoenixframework/DefaultsTest.kt
@@ -0,0 +1,142 @@
+package org.phoenixframework
+
+import com.google.common.truth.Truth.assertThat
+import org.junit.jupiter.api.Test
+
+internal class DefaultsTest {
+
+ @Test
+ internal fun `default timeout is 10_000`() {
+ assertThat(Defaults.TIMEOUT).isEqualTo(10_000)
+ }
+
+ @Test
+ internal fun `default heartbeat is 30_000`() {
+ assertThat(Defaults.HEARTBEAT).isEqualTo(30_000)
+ }
+
+ @Test
+ internal fun `default reconnectAfterMs returns all values`() {
+ val reconnect = Defaults.reconnectSteppedBackOff
+
+ assertThat(reconnect(1)).isEqualTo(10)
+ assertThat(reconnect(2)).isEqualTo(50)
+ assertThat(reconnect(3)).isEqualTo(100)
+ assertThat(reconnect(4)).isEqualTo(150)
+ assertThat(reconnect(5)).isEqualTo(200)
+ assertThat(reconnect(6)).isEqualTo(250)
+ assertThat(reconnect(7)).isEqualTo(500)
+ assertThat(reconnect(8)).isEqualTo(1_000)
+ assertThat(reconnect(9)).isEqualTo(2_000)
+ assertThat(reconnect(10)).isEqualTo(5_000)
+ assertThat(reconnect(11)).isEqualTo(5_000)
+ }
+
+ @Test
+ internal fun `default rejoinAfterMs returns all values`() {
+ val reconnect = Defaults.rejoinSteppedBackOff
+
+ assertThat(reconnect(1)).isEqualTo(1_000)
+ assertThat(reconnect(2)).isEqualTo(2_000)
+ assertThat(reconnect(3)).isEqualTo(5_000)
+ assertThat(reconnect(4)).isEqualTo(10_000)
+ assertThat(reconnect(5)).isEqualTo(10_000)
+ }
+
+ @Test
+ internal fun `decoder converts json array into message`() {
+ val v2Json = """
+ [null,null,"room:lobby","shout",{"message":"Hi","name":"Tester"}]
+ """.trimIndent()
+
+ val message = Defaults.decode(v2Json)
+ assertThat(message.joinRef).isNull()
+ assertThat(message.ref).isEqualTo("")
+ assertThat(message.topic).isEqualTo("room:lobby")
+ assertThat(message.event).isEqualTo("shout")
+ assertThat(message.payload).isEqualTo(mapOf("message" to "Hi", "name" to "Tester"))
+ }
+
+ @Test
+ internal fun `decoder provides raw json payload`() {
+ val v2Json = """
+ ["1","2","room:lobby","shout",{"message":"Hi","name":"Tester","count":15,"ratio":0.2}]
+ """.trimIndent()
+
+ val message = Defaults.decode(v2Json)
+ assertThat(message.joinRef).isEqualTo("1")
+ assertThat(message.ref).isEqualTo("2")
+ assertThat(message.topic).isEqualTo("room:lobby")
+ assertThat(message.event).isEqualTo("shout")
+ assertThat(message.payloadJson).isEqualTo("{\"message\":\"Hi\",\"name\":\"Tester\",\"count\":15,\"ratio\":0.2}")
+ assertThat(message.payload).isEqualTo(mapOf(
+ "message" to "Hi",
+ "name" to "Tester",
+ "count" to 15.0, // Note that this is a bug and should eventually be removed
+ "ratio" to 0.2
+ ))
+ }
+
+ @Test
+ internal fun `decoder decodes a status`() {
+ val v2Json = """
+ ["1","2","room:lobby","phx_reply",{"response":{"message":"Hi","name":"Tester","count":15,"ratio":0.2},"status":"ok"}]
+ """.trimIndent()
+
+ val message = Defaults.decode(v2Json)
+ assertThat(message.joinRef).isEqualTo("1")
+ assertThat(message.ref).isEqualTo("2")
+ assertThat(message.topic).isEqualTo("room:lobby")
+ assertThat(message.event).isEqualTo("phx_reply")
+ assertThat(message.payloadJson).isEqualTo("{\"message\":\"Hi\",\"name\":\"Tester\",\"count\":15,\"ratio\":0.2}")
+ assertThat(message.payload).isEqualTo(mapOf(
+ "message" to "Hi",
+ "name" to "Tester",
+ "count" to 15.0, // Note that this is a bug and should eventually be removed
+ "ratio" to 0.2
+ ))
+ }
+
+
+
+ @Test
+ internal fun `decoder decodes an error`() {
+ val v2Json = """
+ ["6","8","drivers:self","phx_reply",{"response":{"details":"invalid code specified"},"status":"error"}]
+ """.trimIndent()
+
+ val message = Defaults.decode(v2Json)
+ assertThat(message.payloadJson).isEqualTo("{\"details\":\"invalid code specified\"}")
+ assertThat(message.rawPayload).isEqualTo(mapOf(
+ "response" to mapOf(
+ "details" to "invalid code specified"
+ ),
+ "status" to "error"
+ ))
+ assertThat(message.payload).isEqualTo(mapOf(
+ "details" to "invalid code specified"
+ ))
+
+ }
+
+ @Test
+ internal fun `decoder decodes a non-json payload`() {
+ val v2Json = """
+ ["1","2","room:lobby","phx_reply",{"response":"hello","status":"ok"}]
+ """.trimIndent()
+
+ val message = Defaults.decode(v2Json)
+ assertThat(message.payloadJson).isEqualTo("\"hello\"")
+ assertThat(message.payload).isEqualTo(mapOf(
+ "response" to "hello",
+ "status" to "ok"
+ ))
+ }
+
+ @Test
+ internal fun `encode converts message into json`() {
+ val body = listOf(null, null, "topic", "event", mapOf("one" to "two"))
+ assertThat(Defaults.encode(body))
+ .isEqualTo("[null,null,\"topic\",\"event\",{\"one\":\"two\"}]")
+ }
+}
\ No newline at end of file
diff --git a/src/test/kotlin/org/phoenixframework/MessageTest.kt b/src/test/kotlin/org/phoenixframework/MessageTest.kt
new file mode 100644
index 0000000..2dcd321
--- /dev/null
+++ b/src/test/kotlin/org/phoenixframework/MessageTest.kt
@@ -0,0 +1,46 @@
+package org.phoenixframework
+
+import com.google.common.truth.Truth.assertThat
+import org.junit.jupiter.api.DisplayName
+import org.junit.jupiter.api.Nested
+import org.junit.jupiter.api.Test
+
+class MessageTest {
+
+ @Nested
+ @DisplayName("json parsing")
+ inner class JsonParsing {
+
+ @Test
+ internal fun `jsonParsing parses normal message`() {
+ val json = """
+ [null,"6","my-topic","update",{"user":"James S.","message":"This is a test"}]
+ """.trimIndent()
+
+ val message = Defaults.decode.invoke(json)
+
+ assertThat(message.ref).isEqualTo("6")
+ assertThat(message.topic).isEqualTo("my-topic")
+ assertThat(message.event).isEqualTo("update")
+ assertThat(message.payload).isEqualTo(mapOf("user" to "James S.", "message" to "This is a test"))
+ assertThat(message.joinRef).isNull()
+ assertThat(message.status).isNull()
+ }
+
+ @Test
+ internal fun `jsonParsing parses a reply`() {
+ val json = """
+ [null,"6","my-topic","phx_reply",{"response":{"user":"James S.","message":"This is a test"},"status": "ok"}]
+ """.trimIndent()
+
+ val message = Defaults.decode.invoke(json)
+
+ assertThat(message.ref).isEqualTo("6")
+ assertThat(message.topic).isEqualTo("my-topic")
+ assertThat(message.event).isEqualTo("phx_reply")
+ assertThat(message.payload).isEqualTo(mapOf("user" to "James S.", "message" to "This is a test"))
+ assertThat(message.joinRef).isNull()
+ assertThat(message.status).isEqualTo("ok")
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/test/kotlin/org/phoenixframework/PhxChannelTest.kt b/src/test/kotlin/org/phoenixframework/PhxChannelTest.kt
deleted file mode 100644
index 879fe68..0000000
--- a/src/test/kotlin/org/phoenixframework/PhxChannelTest.kt
+++ /dev/null
@@ -1,152 +0,0 @@
-package org.phoenixframework
-
-import com.google.common.truth.Truth.assertThat
-import org.junit.Before
-import org.junit.Test
-import org.mockito.Mockito
-import org.mockito.MockitoAnnotations
-import org.mockito.Spy
-import java.util.concurrent.CompletableFuture
-import java.util.concurrent.TimeUnit
-
-class PhxChannelTest {
-
- private val defaultRef = "1"
-
- @Spy
- var socket: PhxSocket = PhxSocket("http://localhost:4000/socket/websocket")
- lateinit var channel: PhxChannel
-
- @Before
- fun setUp() {
- MockitoAnnotations.initMocks(this)
- Mockito.doReturn(defaultRef).`when`(socket).makeRef()
-
- socket.timeout = 1234
- channel = PhxChannel("topic", hashMapOf("one" to "two"), socket)
- }
-
-
- //------------------------------------------------------------------------------
- // Constructor
- //------------------------------------------------------------------------------
- @Test
- fun `constructor sets defaults`() {
- assertThat(channel.isClosed).isTrue()
- assertThat(channel.topic).isEqualTo("topic")
- assertThat(channel.params["one"]).isEqualTo("two")
- assertThat(channel.socket).isEqualTo(socket)
- assertThat(channel.timeout).isEqualTo(1234)
- assertThat(channel.joinedOnce).isFalse()
- assertThat(channel.pushBuffer).isEmpty()
- }
-
- @Test
- fun `constructor sets up joinPush with params`() {
- val joinPush = channel.joinPush
-
- assertThat(joinPush.channel).isEqualTo(channel)
- assertThat(joinPush.payload["one"]).isEqualTo("two")
- assertThat(joinPush.event).isEqualTo(PhxChannel.PhxEvent.JOIN.value)
- assertThat(joinPush.timeout).isEqualTo(1234)
- }
-
-
- //------------------------------------------------------------------------------
- // Join
- //------------------------------------------------------------------------------
- @Test
- fun `it sets the state to joining`() {
- channel.join()
- assertThat(channel.isJoining).isTrue()
- }
-
- @Test
- fun `it updates the join parameters`() {
- channel.join(hashMapOf("one" to "three"))
-
- val joinPush = channel.joinPush
- assertThat(joinPush.payload["one"]).isEqualTo("three")
- }
-
- @Test
- fun `it sets joinedOnce to true`() {
- assertThat(channel.joinedOnce).isFalse()
-
- channel.join()
- assertThat(channel.joinedOnce).isTrue()
- }
-
- @Test(expected = IllegalStateException::class)
- fun `it throws if attempting to join multiple times`() {
- channel.join()
- channel.join()
- }
-
-
-
- //------------------------------------------------------------------------------
- // .off()
- //------------------------------------------------------------------------------
- @Test
- fun `it removes all callbacks for events`() {
- Mockito.doReturn(defaultRef).`when`(socket).makeRef()
-
- var aCalled = false
- var bCalled = false
- var cCalled = false
-
- channel.on("event") { aCalled = true }
- channel.on("event") { bCalled = true }
- channel.on("other") { cCalled = true }
-
- channel.off("event")
-
- channel.trigger(PhxMessage(event = "event", ref = defaultRef))
- channel.trigger(PhxMessage(event = "other", ref = defaultRef))
-
- assertThat(aCalled).isFalse()
- assertThat(bCalled).isFalse()
- assertThat(cCalled).isTrue()
- }
-
- @Test
- fun `it removes callbacks by its ref`() {
- var aCalled = false
- var bCalled = false
-
- val aRef = channel.on("event") { aCalled = true }
- channel.on("event") { bCalled = true }
-
-
- channel.off("event", aRef)
-
- channel.trigger(PhxMessage(event = "event", ref = defaultRef))
-
- assertThat(aCalled).isFalse()
- assertThat(bCalled).isTrue()
- }
-
- @Test
- fun `Issue 22`() {
- // This reproduces a concurrent modification exception. The original cause is most likely as follows:
- // 1. Push (And receive) messages very quickly
- // 2. PhxChannel.push, calls PhxPush.send()
- // 3. PhxPush calls startTimeout().
- // 4. PhxPush.startTimeout() calls this.channel.on(refEvent) - This modifies the bindings list
- // 5. any trigger (possibly from a timeout) can be iterating through the binding list that was modified in step 4.
-
- val f1 = CompletableFuture.runAsync {
- for (i in 0..1000) {
- channel.on("event-$i") { /** do nothing **/ }
- }
- }
- val f3 = CompletableFuture.runAsync {
- for (i in 0..1000) {
- channel.trigger(PhxMessage(event = "event-$i", ref = defaultRef))
- }
- }
-
- CompletableFuture.allOf(f1, f3).get(10, TimeUnit.SECONDS)
- }
-}
diff --git a/src/test/kotlin/org/phoenixframework/PhxMessageTest.kt b/src/test/kotlin/org/phoenixframework/PhxMessageTest.kt
deleted file mode 100644
index d8131d0..0000000
--- a/src/test/kotlin/org/phoenixframework/PhxMessageTest.kt
+++ /dev/null
@@ -1,22 +0,0 @@
-package org.phoenixframework
-
-import com.google.common.truth.Truth.assertThat
-import org.junit.Test
-
-class PhxMessageTest {
-
- @Test
- fun getStatus_returnsPayloadStatus() {
-
- val payload = hashMapOf("status" to "ok", "topic" to "chat:1")
-
- val message = PhxMessage("ref1", "chat:1", "event1", payload)
-
- assertThat(message.ref).isEqualTo("ref1")
- assertThat(message.topic).isEqualTo("chat:1")
- assertThat(message.event).isEqualTo("event1")
- assertThat(message.payload["topic"]).isEqualTo("chat:1")
- assertThat(message.status).isEqualTo("ok")
-
- }
-}
diff --git a/src/test/kotlin/org/phoenixframework/PhxSocketTest.kt b/src/test/kotlin/org/phoenixframework/PhxSocketTest.kt
deleted file mode 100644
index 317c8ef..0000000
--- a/src/test/kotlin/org/phoenixframework/PhxSocketTest.kt
+++ /dev/null
@@ -1,39 +0,0 @@
-package org.phoenixframework
-
-import com.google.common.truth.Truth.assertThat
-import org.junit.Test
-
-class PhxSocketTest {
-
- @Test
- fun init_buildsUrlProper() {
- assertThat(PhxSocket("http://localhost:4000/socket/websocket").endpoint.toString())
- .isEqualTo("http://localhost:4000/socket/websocket")
-
- assertThat(PhxSocket("https://localhost:4000/socket/websocket").endpoint.toString())
- .isEqualTo("https://localhost:4000/socket/websocket")
-
- assertThat(PhxSocket("ws://localhost:4000/socket/websocket").endpoint.toString())
- .isEqualTo("http://localhost:4000/socket/websocket")
-
- assertThat(PhxSocket("wss://localhost:4000/socket/websocket").endpoint.toString())
- .isEqualTo("https://localhost:4000/socket/websocket")
-
-
- // test params
- val singleParam = hashMapOf("token" to "abc123")
- assertThat(PhxSocket("ws://localhost:4000/socket/websocket", singleParam).endpoint.toString())
- .isEqualTo("http://localhost:4000/socket/websocket?token=abc123")
-
-
- val multipleParams = hashMapOf("token" to "abc123", "user_id" to 1)
- assertThat(PhxSocket("http://localhost:4000/socket/websocket", multipleParams).endpoint.toString())
- .isEqualTo("http://localhost:4000/socket/websocket?user_id=1&token=abc123")
-
-
- // test params with spaces
- val spacesParams = hashMapOf("token" to "abc 123", "user_id" to 1)
- assertThat(PhxSocket("wss://localhost:4000/socket/websocket", spacesParams).endpoint.toString())
- .isEqualTo("https://localhost:4000/socket/websocket?user_id=1&token=abc%20123")
- }
-}
diff --git a/src/test/kotlin/org/phoenixframework/PresenceTest.kt b/src/test/kotlin/org/phoenixframework/PresenceTest.kt
new file mode 100644
index 0000000..4eaee21
--- /dev/null
+++ b/src/test/kotlin/org/phoenixframework/PresenceTest.kt
@@ -0,0 +1,477 @@
+package org.phoenixframework
+
+import com.google.common.truth.Truth.assertThat
+import org.mockito.kotlin.mock
+import org.mockito.kotlin.whenever
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.DisplayName
+import org.junit.jupiter.api.Nested
+import org.junit.jupiter.api.Test
+import org.mockito.Mock
+import org.mockito.MockitoAnnotations
+import org.phoenixframework.utilities.getBindings
+
+class PresenceTest {
+
+ @Mock lateinit var socket: Socket
+
+ private val fixJoins: PresenceState = mutableMapOf(
+ "u1" to mutableMapOf("metas" to listOf(mapOf("id" to 1, "phx_ref" to "1.2"))))
+ private val fixLeaves: PresenceState = mutableMapOf(
+ "u2" to mutableMapOf("metas" to listOf(mapOf("id" to 2, "phx_ref" to "2"))))
+ private val fixState: PresenceState = mutableMapOf(
+ "u1" to mutableMapOf("metas" to listOf(mapOf("id" to 1, "phx_ref" to "1"))),
+ "u2" to mutableMapOf("metas" to listOf(mapOf("id" to 2, "phx_ref" to "2"))),
+ "u3" to mutableMapOf("metas" to listOf(mapOf("id" to 3, "phx_ref" to "3")))
+ )
+
+ private val listByFirst: (Map.Entry) -> PresenceMeta =
+ { it.value["metas"]!!.first() }
+
+ lateinit var channel: Channel
+ lateinit var presence: Presence
+
+ @BeforeEach
+ internal fun setUp() {
+ MockitoAnnotations.initMocks(this)
+
+ whenever(socket.timeout).thenReturn(Defaults.TIMEOUT)
+ whenever(socket.makeRef()).thenReturn("1")
+ whenever(socket.reconnectAfterMs).thenReturn { 1_000 }
+ whenever(socket.rejoinAfterMs).thenReturn(Defaults.rejoinSteppedBackOff)
+ whenever(socket.dispatchQueue).thenReturn(mock())
+
+ channel = Channel("topic", mapOf(), socket)
+ channel.joinPush.ref = "1"
+
+ presence = Presence(channel)
+ }
+
+ @Nested
+ @DisplayName("constructor")
+ inner class Constructor {
+
+ @Test
+ internal fun `sets defaults`() {
+ assertThat(presence.state).isEmpty()
+ assertThat(presence.pendingDiffs).isEmpty()
+ assertThat(presence.channel).isEqualTo(channel)
+ assertThat(presence.joinRef).isNull()
+ }
+
+ @Test
+ internal fun `binds to channel with default arguments`() {
+ assertThat(presence.channel.getBindings("presence_state")).hasSize(1)
+ assertThat(presence.channel.getBindings("presence_diff")).hasSize(1)
+ }
+
+ @Test
+ internal fun `binds to channel with custom options`() {
+ val channel = Channel("topic", mapOf(), socket)
+ val customOptions = Presence.Options(mapOf(
+ Presence.Events.STATE to "custom_state",
+ Presence.Events.DIFF to "custom_diff"))
+
+ val p = Presence(channel, customOptions)
+ assertThat(p.channel.getBindings("presence_state")).isEmpty()
+ assertThat(p.channel.getBindings("presence_diff")).isEmpty()
+ assertThat(p.channel.getBindings("custom_state")).hasSize(1)
+ assertThat(p.channel.getBindings("custom_diff")).hasSize(1)
+ }
+
+ @Test
+ internal fun `syncs state and diffs`() {
+ val user1: PresenceMap = mutableMapOf("metas" to mutableListOf(
+ mapOf("id" to 1, "phx_ref" to "1")))
+ val user2: PresenceMap = mutableMapOf("metas" to mutableListOf(
+ mapOf("id" to 2, "phx_ref" to "2")))
+ val newState: PresenceState = mutableMapOf("u1" to user1, "u2" to user2)
+
+ channel.trigger("presence_state", newState, "1")
+ val s = presence.listBy(listByFirst)
+ assertThat(s).hasSize(2)
+ assertThat(s[0]["id"]).isEqualTo(1)
+ assertThat(s[0]["phx_ref"]).isEqualTo("1")
+
+ assertThat(s[1]["id"]).isEqualTo(2)
+ assertThat(s[1]["phx_ref"]).isEqualTo("2")
+
+ channel.trigger("presence_diff",
+ mapOf("joins" to emptyMap(), "leaves" to mapOf("u1" to user1)))
+ val l = presence.listBy(listByFirst)
+ assertThat(l).hasSize(1)
+ assertThat(l[0]["id"]).isEqualTo(2)
+ assertThat(l[0]["phx_ref"]).isEqualTo("2")
+ }
+
+ @Test
+ internal fun `applies pending diff if state is not yet synced`() {
+ val onJoins = mutableListOf>()
+ val onLeaves = mutableListOf>()
+
+ presence.onJoin { key, current, new -> onJoins.add(Triple(key, current, new)) }
+ presence.onLeave { key, current, left -> onLeaves.add(Triple(key, current, left)) }
+
+ val user1 = mutableMapOf("metas" to mutableListOf(mutableMapOf("id" to 1, "phx_ref" to "1")))
+ val user2 = mutableMapOf("metas" to mutableListOf(mutableMapOf("id" to 2, "phx_ref" to "2")))
+ val user3 = mutableMapOf("metas" to mutableListOf(mutableMapOf("id" to 3, "phx_ref" to "3")))
+
+ val newState = mutableMapOf("u1" to user1, "u2" to user2)
+ val leaves = mapOf("u2" to user2)
+
+ val payload1 = mapOf("joins" to emptyMap(), "leaves" to leaves)
+ channel.trigger("presence_diff", payload1, "")
+
+ // There is no state
+ assertThat(presence.listBy(listByFirst)).isEmpty()
+
+ // pending diffs 1
+ assertThat(presence.pendingDiffs).hasSize(1)
+ assertThat(presence.pendingDiffs[0]["joins"]).isEmpty()
+ assertThat(presence.pendingDiffs[0]["leaves"]).isEqualTo(leaves)
+
+ channel.trigger("presence_state", newState, "")
+ assertThat(onLeaves).hasSize(1)
+ assertThat(onLeaves[0].first).isEqualTo("u2")
+ assertThat(onLeaves[0].second["metas"]).isEmpty()
+ assertThat(onLeaves[0].third["metas"]!![0]["id"]).isEqualTo(2)
+
+ val s = presence.listBy(listByFirst)
+ assertThat(s).hasSize(1)
+ assertThat(s[0]["id"]).isEqualTo(1)
+ assertThat(s[0]["phx_ref"]).isEqualTo("1")
+ assertThat(presence.pendingDiffs).isEmpty()
+
+ assertThat(onJoins).hasSize(2)
+ assertThat(onJoins[0].first).isEqualTo("u1")
+ assertThat(onJoins[0].second).isNull()
+ assertThat(onJoins[0].third["metas"]!![0]["id"]).isEqualTo(1)
+
+ assertThat(onJoins[1].first).isEqualTo("u2")
+ assertThat(onJoins[1].second).isNull()
+ assertThat(onJoins[1].third["metas"]!![0]["id"]).isEqualTo(2)
+
+ // disconnect then reconnect
+ assertThat(presence.isPendingSyncState).isFalse()
+ channel.joinPush.ref = "2"
+ assertThat(presence.isPendingSyncState).isTrue()
+
+
+ channel.trigger("presence_diff",
+ mapOf("joins" to mapOf(), "leaves" to mapOf("u1" to user1)))
+ val d = presence.listBy(listByFirst)
+ assertThat(d).hasSize(1)
+ assertThat(d[0]["id"]).isEqualTo(1)
+ assertThat(d[0]["phx_ref"]).isEqualTo("1")
+
+
+ channel.trigger("presence_state",
+ mapOf("u1" to user1, "u3" to user3))
+ val s2 = presence.listBy(listByFirst)
+ assertThat(s2).hasSize(1)
+ assertThat(s2[0]["id"]).isEqualTo(3)
+ assertThat(s2[0]["phx_ref"]).isEqualTo("3")
+ }
+ /* End Constructor */
+ }
+
+ @Nested
+ @DisplayName("syncState")
+ inner class SyncState {
+
+ @Test
+ internal fun `syncs empty state`() {
+ val newState: PresenceState = mutableMapOf(
+ "u1" to mutableMapOf("metas" to listOf(mapOf("id" to 1, "phx_ref" to "1"))))
+ var state: PresenceState = mutableMapOf()
+ val stateBefore = state
+
+ Presence.syncState(state, newState)
+ assertThat(state).isEqualTo(stateBefore)
+
+ state = Presence.syncState(state, newState)
+ assertThat(state).isEqualTo(newState)
+ }
+
+ @Test
+ internal fun `onJoins new presences and onLeaves left presences`() {
+ val newState = fixState
+ var state = mutableMapOf(
+ "u4" to mutableMapOf("metas" to listOf(mapOf("id" to 4, "phx_ref" to "4"))))
+
+ val joined: PresenceDiff = mutableMapOf()
+ val left: PresenceDiff = mutableMapOf()
+
+ val onJoin: OnJoin = { key, current, newPres ->
+ val joinState: PresenceState = mutableMapOf("newPres" to newPres)
+ current?.let { c -> joinState["current"] = c }
+
+ joined[key] = joinState
+ }
+
+ val onLeave: OnLeave = { key, current, leftPres ->
+ left[key] = mutableMapOf("current" to current, "leftPres" to leftPres)
+ }
+
+ val stateBefore = state
+ Presence.syncState(state, newState, onJoin, onLeave)
+ assertThat(state).isEqualTo(stateBefore)
+
+ state = Presence.syncState(state, newState, onJoin, onLeave)
+ assertThat(state).isEqualTo(newState)
+
+ // asset equality in joined
+ val joinedExpectation: PresenceDiff = mutableMapOf(
+ "u1" to mutableMapOf("newPres" to mutableMapOf(
+ "metas" to listOf(mapOf("id" to 1, "phx_ref" to "1")))),
+ "u2" to mutableMapOf("newPres" to mutableMapOf(
+ "metas" to listOf(mapOf("id" to 2, "phx_ref" to "2")))),
+ "u3" to mutableMapOf("newPres" to mutableMapOf(
+ "metas" to listOf(mapOf("id" to 3, "phx_ref" to "3"))))
+ )
+
+ assertThat(joined).isEqualTo(joinedExpectation)
+
+ // assert equality in left
+ val leftExpectation: PresenceDiff = mutableMapOf(
+ "u4" to mutableMapOf(
+ "current" to mutableMapOf(
+ "metas" to mutableListOf()),
+ "leftPres" to mutableMapOf(
+ "metas" to listOf(mapOf("id" to 4, "phx_ref" to "4"))))
+ )
+ assertThat(left).isEqualTo(leftExpectation)
+ }
+
+ @Test
+ internal fun `onJoins only newly added metas`() {
+ var state = mutableMapOf(
+ "u3" to mutableMapOf("metas" to listOf(mapOf("id" to 3, "phx_ref" to "3"))))
+ val newState = mutableMapOf(
+ "u3" to mutableMapOf("metas" to listOf(
+ mapOf("id" to 3, "phx_ref" to "3"),
+ mapOf("id" to 3, "phx_ref" to "3.new")
+ )))
+
+ val joined: PresenceDiff = mutableMapOf()
+ val left: PresenceDiff = mutableMapOf()
+
+ val onJoin: OnJoin = { key, current, newPres ->
+ val joinState: PresenceState = mutableMapOf("newPres" to newPres)
+ current?.let { c -> joinState["current"] = c }
+
+ joined[key] = joinState
+ }
+
+ val onLeave: OnLeave = { key, current, leftPres ->
+ left[key] = mutableMapOf("current" to current, "leftPres" to leftPres)
+ }
+
+ state = Presence.syncState(state, newState, onJoin, onLeave)
+ assertThat(state).isEqualTo(newState)
+
+ // asset equality in joined
+ val joinedExpectation: PresenceDiff = mutableMapOf(
+ "u3" to mutableMapOf(
+ "newPres" to mutableMapOf("metas" to listOf(mapOf("id" to 3, "phx_ref" to "3.new"))),
+ "current" to mutableMapOf("metas" to listOf(mapOf("id" to 3, "phx_ref" to "3")))
+
+ ))
+ assertThat(joined).isEqualTo(joinedExpectation)
+
+ // assert equality in left
+ assertThat(left).isEmpty()
+ }
+
+ @Test
+ internal fun `onLeaves only newly removed metas`() {
+ val newState = mutableMapOf(
+ "u3" to mutableMapOf("metas" to listOf(mapOf("id" to 3, "phx_ref" to "3"))))
+ var state = mutableMapOf(
+ "u3" to mutableMapOf("metas" to listOf(
+ mapOf("id" to 3, "phx_ref" to "3"),
+ mapOf("id" to 3, "phx_ref" to "3.left")
+ )))
+
+ val joined: PresenceDiff = mutableMapOf()
+ val left: PresenceDiff = mutableMapOf()
+
+ val onJoin: OnJoin = { key, current, newPres ->
+ val joinState: PresenceState = mutableMapOf("newPres" to newPres)
+ current?.let { c -> joinState["current"] = c }
+
+ joined[key] = joinState
+ }
+
+ val onLeave: OnLeave = { key, current, leftPres ->
+ left[key] = mutableMapOf("current" to current, "leftPres" to leftPres)
+ }
+
+ state = Presence.syncState(state, newState, onJoin, onLeave)
+ assertThat(state).isEqualTo(newState)
+
+ // asset equality in joined
+ val leftExpectation: PresenceDiff = mutableMapOf(
+ "u3" to mutableMapOf(
+ "leftPres" to mutableMapOf(
+ "metas" to listOf(mapOf("id" to 3, "phx_ref" to "3.left"))),
+ "current" to mutableMapOf("metas" to listOf(mapOf("id" to 3, "phx_ref" to "3")))
+
+ ))
+ assertThat(left).isEqualTo(leftExpectation)
+
+ // assert equality in left
+ assertThat(joined).isEmpty()
+ }
+
+ @Test
+ internal fun `syncs both joined and left metas`() {
+ val newState = mutableMapOf(
+ "u3" to mutableMapOf("metas" to listOf(
+ mapOf("id" to 3, "phx_ref" to "3"),
+ mapOf("id" to 3, "phx_ref" to "3.new")
+ )))
+
+ var state = mutableMapOf(
+ "u3" to mutableMapOf("metas" to listOf(
+ mapOf("id" to 3, "phx_ref" to "3"),
+ mapOf("id" to 3, "phx_ref" to "3.left")
+ )))
+
+ val joined: PresenceDiff = mutableMapOf()
+ val left: PresenceDiff = mutableMapOf()
+
+ val onJoin: OnJoin = { key, current, newPres ->
+ val joinState: PresenceState = mutableMapOf("newPres" to newPres)
+ current?.let { c -> joinState["current"] = c }
+
+ joined[key] = joinState
+ }
+
+ val onLeave: OnLeave = { key, current, leftPres ->
+ left[key] = mutableMapOf("current" to current, "leftPres" to leftPres)
+ }
+
+ state = Presence.syncState(state, newState, onJoin, onLeave)
+ assertThat(state).isEqualTo(newState)
+
+ // asset equality in joined
+ val joinedExpectation: PresenceDiff = mutableMapOf(
+ "u3" to mutableMapOf(
+ "newPres" to mutableMapOf("metas" to listOf(mapOf("id" to 3, "phx_ref" to "3.new"))),
+ "current" to mutableMapOf("metas" to listOf(
+ mapOf("id" to 3, "phx_ref" to "3"),
+ mapOf("id" to 3, "phx_ref" to "3.left")))
+ ))
+ assertThat(joined).isEqualTo(joinedExpectation)
+
+ // assert equality in left
+ val leftExpectation: PresenceDiff = mutableMapOf(
+ "u3" to mutableMapOf(
+ "leftPres" to mutableMapOf(
+ "metas" to listOf(mapOf("id" to 3, "phx_ref" to "3.left"))),
+ "current" to mutableMapOf("metas" to listOf(
+ mapOf("id" to 3, "phx_ref" to "3"),
+ mapOf("id" to 3, "phx_ref" to "3.new")))
+ ))
+ assertThat(left).isEqualTo(leftExpectation)
+ }
+
+ /* end SyncState */
+ }
+
+ @Nested
+ @DisplayName("syncDiff")
+ inner class SyncDiff {
+
+ @Test
+ internal fun `syncs empty state`() {
+ val joins: PresenceState = mutableMapOf(
+ "u1" to mutableMapOf("metas" to
+ listOf(mapOf("id" to 1, "phx_ref" to "1"))))
+ var state: PresenceState = mutableMapOf()
+
+ Presence.syncDiff(state, mutableMapOf("joins" to joins, "leaves" to mutableMapOf()))
+ assertThat(state).isEmpty()
+
+ state = Presence.syncDiff(state, mutableMapOf("joins" to joins, "leaves" to mutableMapOf()))
+ assertThat(state).isEqualTo(joins)
+ }
+
+ @Test
+ internal fun `removes presence when meta is empty and adds additional meta`() {
+ var state = fixState
+ val diff: PresenceDiff = mutableMapOf("joins" to fixJoins, "leaves" to fixLeaves)
+ state = Presence.syncDiff(state, diff)
+
+ val expectation: PresenceState = mutableMapOf(
+ "u1" to mutableMapOf("metas" to
+ listOf(
+ mapOf("id" to 1, "phx_ref" to "1"),
+ mapOf("id" to 1, "phx_ref" to "1.2")
+ )
+ ),
+ "u3" to mutableMapOf("metas" to
+ listOf(mapOf("id" to 3, "phx_ref" to "3"))
+ )
+ )
+
+ assertThat(state).isEqualTo(expectation)
+ }
+
+ @Test
+ internal fun `removes meta while leaving key if other metas exist`() {
+ var state = mutableMapOf(
+ "u1" to mutableMapOf("metas" to listOf(
+ mapOf("id" to 1, "phx_ref" to "1"),
+ mapOf("id" to 1, "phx_ref" to "1.2")
+ )))
+
+ val leaves = mutableMapOf(
+ "u1" to mutableMapOf("metas" to listOf(
+ mapOf("id" to 1, "phx_ref" to "1")
+ )))
+ val diff: PresenceDiff = mutableMapOf("joins" to mutableMapOf(), "leaves" to leaves)
+ state = Presence.syncDiff(state, diff)
+
+ val expectedState = mutableMapOf(
+ "u1" to mutableMapOf("metas" to listOf(mapOf("id" to 1, "phx_ref" to "1.2"))))
+ assertThat(state).isEqualTo(expectedState)
+ }
+
+ /* End SyncDiff */
+ }
+
+ @Nested
+ @DisplayName("listBy")
+ inner class ListBy {
+
+ @Test
+ internal fun `lists full presence by default`() {
+ presence.state = fixState
+
+ val listExpectation = listOf(
+ mapOf("metas" to listOf(mapOf("id" to 1, "phx_ref" to "1"))),
+ mapOf("metas" to listOf(mapOf("id" to 2, "phx_ref" to "2"))),
+ mapOf("metas" to listOf(mapOf("id" to 3, "phx_ref" to "3")))
+ )
+
+ assertThat(presence.list()).isEqualTo(listExpectation)
+ }
+
+ @Test
+ internal fun `lists with custom function`() {
+ val state: PresenceState = mutableMapOf(
+ "u1" to mutableMapOf("metas" to listOf(
+ mapOf("id" to 1, "phx_ref" to "1.first"),
+ mapOf("id" to 1, "phx_ref" to "1.second"))
+ )
+ )
+
+ presence.state = state
+ val listBy = presence.listBy { it.value["metas"]!!.first() }
+ assertThat(listBy).isEqualTo(listOf(mapOf("id" to 1, "phx_ref" to "1.first")))
+ }
+ }
+
+}
\ No newline at end of file
diff --git a/src/test/kotlin/org/phoenixframework/SocketTest.kt b/src/test/kotlin/org/phoenixframework/SocketTest.kt
new file mode 100644
index 0000000..5fc512f
--- /dev/null
+++ b/src/test/kotlin/org/phoenixframework/SocketTest.kt
@@ -0,0 +1,1077 @@
+package org.phoenixframework
+
+import com.google.common.truth.Truth.assertThat
+import org.mockito.kotlin.any
+import org.mockito.kotlin.argumentCaptor
+import org.mockito.kotlin.eq
+import org.mockito.kotlin.mock
+import org.mockito.kotlin.never
+import org.mockito.kotlin.spy
+import org.mockito.kotlin.times
+import org.mockito.kotlin.verify
+import org.mockito.kotlin.verifyNoInteractions
+import org.mockito.kotlin.whenever
+import okhttp3.OkHttpClient
+import okhttp3.Response
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.DisplayName
+import org.junit.jupiter.api.Nested
+import org.junit.jupiter.api.Test
+import org.mockito.Mock
+import org.mockito.MockitoAnnotations
+import java.net.URL
+import java.util.concurrent.TimeUnit
+
+class SocketTest {
+
+ @Mock lateinit var okHttpClient: OkHttpClient
+ @Mock lateinit var mockDispatchQueue: DispatchQueue
+
+ lateinit var connection: Transport
+ lateinit var socket: Socket
+
+ @BeforeEach
+ internal fun setUp() {
+ MockitoAnnotations.initMocks(this)
+
+ connection = spy(WebSocketTransport(URL("https://localhost:4000/socket"), okHttpClient))
+
+ socket = Socket("wss://localhost:4000/socket")
+ socket.transport = { connection }
+ socket.dispatchQueue = mockDispatchQueue
+ }
+
+ @Nested
+ @DisplayName("constructor")
+ inner class Constructor {
+ @Test
+ internal fun `sets defaults`() {
+ val socket = Socket("wss://localhost:4000/socket")
+
+ assertThat(socket.paramsClosure.invoke()).isNull()
+ assertThat(socket.channels).isEmpty()
+ assertThat(socket.sendBuffer).isEmpty()
+ assertThat(socket.ref).isEqualTo(0)
+ assertThat(socket.endpoint).isEqualTo("wss://localhost:4000/socket/websocket")
+ assertThat(socket.vsn).isEqualTo(Defaults.VSN)
+ assertThat(socket.stateChangeCallbacks.open).isEmpty()
+ assertThat(socket.stateChangeCallbacks.close).isEmpty()
+ assertThat(socket.stateChangeCallbacks.error).isEmpty()
+ assertThat(socket.stateChangeCallbacks.message).isEmpty()
+ assertThat(socket.timeout).isEqualTo(Defaults.TIMEOUT)
+ assertThat(socket.heartbeatIntervalMs).isEqualTo(Defaults.HEARTBEAT)
+ assertThat(socket.logger).isNull()
+ assertThat(socket.reconnectAfterMs(1)).isEqualTo(10)
+ assertThat(socket.reconnectAfterMs(2)).isEqualTo(50)
+ assertThat(socket.reconnectAfterMs(3)).isEqualTo(100)
+ assertThat(socket.reconnectAfterMs(4)).isEqualTo(150)
+ assertThat(socket.reconnectAfterMs(5)).isEqualTo(200)
+ assertThat(socket.reconnectAfterMs(6)).isEqualTo(250)
+ assertThat(socket.reconnectAfterMs(7)).isEqualTo(500)
+ assertThat(socket.reconnectAfterMs(8)).isEqualTo(1_000)
+ assertThat(socket.reconnectAfterMs(9)).isEqualTo(2_000)
+ assertThat(socket.reconnectAfterMs(10)).isEqualTo(5_000)
+ assertThat(socket.reconnectAfterMs(11)).isEqualTo(5_000)
+ }
+
+ @Test
+ internal fun `overrides some defaults`() {
+ val socket = Socket("wss://localhost:4000/socket/", mapOf("one" to 2))
+ socket.timeout = 40_000
+ socket.heartbeatIntervalMs = 60_000
+ socket.logger = { }
+ socket.reconnectAfterMs = { 10 }
+
+ assertThat(socket.paramsClosure?.invoke()).isEqualTo(mapOf("one" to 2))
+ assertThat(socket.endpoint).isEqualTo("wss://localhost:4000/socket/websocket")
+ assertThat(socket.timeout).isEqualTo(40_000)
+ assertThat(socket.heartbeatIntervalMs).isEqualTo(60_000)
+ assertThat(socket.logger).isNotNull()
+ assertThat(socket.reconnectAfterMs(1)).isEqualTo(10)
+ assertThat(socket.reconnectAfterMs(2)).isEqualTo(10)
+ }
+
+ @Test
+ internal fun `constructs with a valid URL`() {
+ // Test different schemes
+ assertThat(Socket("http://localhost:4000/socket/websocket").endpointUrl.toString())
+ .isEqualTo("http://localhost:4000/socket/websocket?vsn=2.0.0")
+
+ assertThat(Socket("https://localhost:4000/socket/websocket").endpointUrl.toString())
+ .isEqualTo("https://localhost:4000/socket/websocket?vsn=2.0.0")
+
+ assertThat(Socket("ws://localhost:4000/socket/websocket").endpointUrl.toString())
+ .isEqualTo("http://localhost:4000/socket/websocket?vsn=2.0.0")
+
+ assertThat(Socket("wss://localhost:4000/socket/websocket").endpointUrl.toString())
+ .isEqualTo("https://localhost:4000/socket/websocket?vsn=2.0.0")
+
+ // test params
+ val singleParam = hashMapOf("token" to "abc123")
+ assertThat(Socket("ws://localhost:4000/socket/websocket", singleParam).endpointUrl.toString())
+ .isEqualTo("http://localhost:4000/socket/websocket?vsn=2.0.0&token=abc123")
+
+ val multipleParams = hashMapOf("token" to "abc123", "user_id" to 1)
+ assertThat(
+ Socket("http://localhost:4000/socket/websocket", multipleParams).endpointUrl.toString()
+ )
+ .isEqualTo("http://localhost:4000/socket/websocket?vsn=2.0.0&user_id=1&token=abc123")
+
+ // test params with spaces
+ val spacesParams = hashMapOf("token" to "abc 123", "user_id" to 1)
+ assertThat(
+ Socket("wss://localhost:4000/socket/websocket", spacesParams).endpointUrl.toString()
+ )
+ .isEqualTo("https://localhost:4000/socket/websocket?vsn=2.0.0&user_id=1&token=abc%20123")
+ }
+
+ /* End Constructor */
+ }
+
+ @Nested
+ @DisplayName("protocol")
+ inner class Protocol {
+ @Test
+ internal fun `returns wss when protocol is https`() {
+ val socket = Socket("https://example.com/")
+ assertThat(socket.protocol).isEqualTo("wss")
+ }
+
+ @Test
+ internal fun `returns ws when protocol is http`() {
+ val socket = Socket("http://example.com/")
+ assertThat(socket.protocol).isEqualTo("ws")
+ }
+
+ @Test
+ internal fun `returns value if not https or http`() {
+ val socket = Socket("wss://example.com/")
+ assertThat(socket.protocol).isEqualTo("wss")
+ }
+
+ /* End Protocol */
+ }
+
+ @Nested
+ @DisplayName("isConnected")
+ inner class IsConnected {
+ @Test
+ internal fun `returns false if connection is null`() {
+ assertThat(socket.isConnected).isFalse()
+ }
+
+ @Test
+ internal fun `is false if state is not open`() {
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.CLOSING)
+
+ socket.connection = connection
+ assertThat(socket.isConnected).isFalse()
+ }
+
+ @Test
+ internal fun `is true if state open`() {
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN)
+
+ socket.connection = connection
+ assertThat(socket.isConnected).isTrue()
+ }
+
+ /* End IsConnected */
+ }
+
+ @Nested
+ @DisplayName("connect")
+ inner class Connect {
+ @Test
+ internal fun `establishes websocket connection with endpoint`() {
+ socket.connect()
+ assertThat(socket.connection).isNotNull()
+ }
+
+ @Test
+ internal fun `accounts for changing parameters`() {
+ val transport = mock<(URL) -> Transport>()
+ whenever(transport.invoke(any())).thenReturn(connection)
+
+ var token = "a"
+ val socket = Socket("wss://localhost:4000/socket", { mapOf("token" to token) })
+ socket.transport = transport
+
+ socket.connect()
+ argumentCaptor {
+ verify(transport).invoke(capture())
+ assertThat(firstValue.query).isEqualTo("vsn=2.0.0&token=a")
+
+ token = "b"
+ socket.disconnect()
+ socket.connect()
+ verify(transport, times(2)).invoke(capture())
+ assertThat(lastValue.query).isEqualTo("vsn=2.0.0&token=b")
+ }
+ }
+
+ @Test
+ internal fun `sets callbacks for connection`() {
+ var open = 0
+ socket.onOpen { open += 1 }
+
+ var close = 0
+ socket.onClose { close += 1 }
+
+ var lastError: Throwable? = null
+ var lastResponse: Response? = null
+ socket.onError { throwable, response ->
+ lastError = throwable
+ lastResponse = response
+ }
+
+ var lastMessage: Message? = null
+ socket.onMessage { lastMessage = it }
+
+ socket.connect()
+
+ socket.connection?.onOpen?.invoke()
+ assertThat(open).isEqualTo(1)
+
+ socket.connection?.onClose?.invoke(1000)
+ assertThat(close).isEqualTo(1)
+
+ socket.connection?.onError?.invoke(Throwable(), null)
+ assertThat(lastError).isNotNull()
+ assertThat(lastResponse).isNull()
+
+ val data = listOf(null, null, "topic", "event", mapOf("go" to true))
+
+ val json = Defaults.gson.toJson(data)
+ socket.connection?.onMessage?.invoke(json)
+ assertThat(lastMessage?.payload).isEqualTo(mapOf("go" to true))
+ }
+
+ @Test
+ internal fun `removes callbacks`() {
+ var open = 0
+ socket.onOpen { open += 1 }
+
+ var close = 0
+ socket.onClose { close += 1 }
+
+ var lastError: Throwable? = null
+ var lastResponse: Response? = null
+ socket.onError { throwable, response ->
+ lastError = throwable
+ lastResponse = response
+ }
+
+ var lastMessage: Message? = null
+ socket.onMessage { lastMessage = it }
+
+ socket.removeAllCallbacks()
+ socket.connect()
+
+ socket.connection?.onOpen?.invoke()
+ assertThat(open).isEqualTo(0)
+
+ socket.connection?.onClose?.invoke(1000)
+ assertThat(close).isEqualTo(0)
+
+ socket.connection?.onError?.invoke(Throwable(), null)
+ assertThat(lastError).isNull()
+ assertThat(lastResponse).isNull()
+
+ val data = listOf(null, null, "topic", "event", mapOf("go" to true))
+
+ val json = Defaults.gson.toJson(data)
+ socket.connection?.onMessage?.invoke(json)
+ assertThat(lastMessage?.payload).isNull()
+ }
+
+ @Test
+ internal fun `does not connect if already connected`() {
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN)
+ socket.connect()
+ socket.connect()
+
+ verify(connection, times(1)).connect()
+ }
+
+ /* End Connect */
+ }
+
+ @Nested
+ @DisplayName("disconnect")
+ inner class Disconnect {
+ @Test
+ internal fun `removes existing connection`() {
+ socket.connect()
+ socket.disconnect()
+
+ assertThat(socket.connection).isNull()
+ verify(connection).disconnect(WS_CLOSE_NORMAL)
+ }
+
+ @Test
+ internal fun `flags the socket as closed cleanly`() {
+ assertThat(socket.closeWasClean).isFalse()
+
+ socket.disconnect()
+ assertThat(socket.closeWasClean).isTrue()
+ }
+
+ @Test
+ internal fun `calls callback`() {
+ val mockCallback = mock<() -> Unit> {}
+
+ socket.disconnect(callback = mockCallback)
+ verify(mockCallback).invoke()
+ }
+
+ @Test
+ internal fun `calls connection close callback`() {
+ socket.connect()
+ socket.disconnect(10, "reason")
+ verify(connection).disconnect(10, "reason")
+ }
+
+ @Test
+ internal fun `resets reconnect timer`() {
+ val mockTimer = mock()
+ socket.reconnectTimer = mockTimer
+
+ socket.disconnect()
+ verify(mockTimer).reset()
+ }
+
+ @Test
+ internal fun `cancels and releases heartbeat timer`() {
+ val mockTask = mock()
+ socket.heartbeatTask = mockTask
+
+ socket.disconnect()
+ verify(mockTask).cancel()
+ assertThat(socket.heartbeatTask).isNull()
+ }
+
+ @Test
+ internal fun `does nothing if not connected`() {
+ socket.disconnect()
+ verifyNoInteractions(connection)
+ }
+
+ /* End Disconnect */
+ }
+
+ @Nested
+ @DisplayName("channel")
+ inner class NewChannel {
+ @Test
+ internal fun `returns channel with given topic and params`() {
+ val channel = socket.channel("topic", mapOf("one" to "two"))
+
+ assertThat(channel.socket).isEqualTo(socket)
+ assertThat(channel.topic).isEqualTo("topic")
+ assertThat(channel.params["one"]).isEqualTo("two")
+ }
+
+ @Test
+ internal fun `adds channel to socket's channel list`() {
+ assertThat(socket.channels).isEmpty()
+
+ val channel = socket.channel("topic", mapOf("one" to "two"))
+
+ assertThat(socket.channels).hasSize(1)
+ assertThat(socket.channels.first()).isEqualTo(channel)
+ }
+
+ /* End Channel */
+ }
+
+ @Nested
+ @DisplayName("remove")
+ inner class Remove {
+ @Test
+ internal fun `removes given channel from channels`() {
+ val channel1 = socket.channel("topic-1")
+ val channel2 = socket.channel("topic-2")
+
+ channel1.joinPush.ref = "1"
+ channel2.joinPush.ref = "2"
+
+ socket.remove(channel1)
+ assertThat(socket.channels).doesNotContain(channel1)
+ assertThat(socket.channels).contains(channel2)
+ }
+
+ @Test
+ internal fun `does not throw exception when iterating over channels`() {
+ val channel1 = socket.channel("topic-1")
+ val channel2 = socket.channel("topic-2")
+
+ channel1.joinPush.ref = "1"
+ channel2.joinPush.ref = "2"
+
+ channel1.join().trigger("ok", emptyMap())
+ channel2.join().trigger("ok", emptyMap())
+
+ var chan1Called = false
+ channel1.onError { chan1Called = true }
+
+ var chan2Called = false
+ channel2.onError {
+ chan2Called = true
+ socket.remove(channel2)
+ }
+
+ // This will trigger an iteration over the socket.channels list which will trigger
+ // channel2.onError. That callback will attempt to remove channel2 during iteration
+ // which would throw a ConcurrentModificationException if the socket.remove method
+ // is implemented incorrectly.
+ socket.onConnectionError(IllegalStateException(), null)
+
+ // Assert that both on all error's got called even when a channel was removed
+ assertThat(chan1Called).isTrue()
+ assertThat(chan2Called).isTrue()
+
+ assertThat(socket.channels).doesNotContain(channel2)
+ assertThat(socket.channels).contains(channel1)
+ }
+
+ /* End Remove */
+ }
+
+ @Nested
+ @DisplayName("release")
+ inner class Release {
+ @Test
+ internal fun `Clears any callbacks with the matching refs`() {
+ socket.stateChangeCallbacks.onOpen("1") {}
+ socket.stateChangeCallbacks.onOpen("2") {}
+ socket.stateChangeCallbacks.onClose("1") {}
+ socket.stateChangeCallbacks.onClose("2") {}
+ socket.stateChangeCallbacks.onError("1") { _: Throwable, _: Response? -> }
+ socket.stateChangeCallbacks.onError("2") { _: Throwable, _: Response? -> }
+ socket.stateChangeCallbacks.onMessage("1") { }
+ socket.stateChangeCallbacks.onMessage("2") { }
+
+ socket.stateChangeCallbacks.release(listOf("1"))
+
+ assertThat(socket.stateChangeCallbacks.open).doesNotContain("1")
+ assertThat(socket.stateChangeCallbacks.close).doesNotContain("1")
+ assertThat(socket.stateChangeCallbacks.error).doesNotContain("1")
+ assertThat(socket.stateChangeCallbacks.message).doesNotContain("1")
+ }
+ }
+
+ @Nested
+ @DisplayName("push")
+ inner class Push {
+ @Test
+ internal fun `sends data to connection when connected`() {
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN)
+
+ socket.connect()
+ socket.push("topic", "event", mapOf("one" to "two"), "ref", "join-ref")
+
+ val expected = "[\"join-ref\",\"ref\",\"topic\",\"event\",{\"one\":\"two\"}]"
+ verify(connection).send(expected)
+ }
+
+ @Test
+ internal fun `excludes ref information if not passed`() {
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN)
+
+ socket.connect()
+ socket.push("topic", "event", mapOf("one" to "two"))
+
+ val expected = "[null,null,\"topic\",\"event\",{\"one\":\"two\"}]"
+ verify(connection).send(expected)
+ }
+
+ @Test
+ internal fun `buffers data when not connected`() {
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.CLOSED)
+ socket.connect()
+
+ socket.push("topic", "event1", mapOf("one" to "two"))
+ verify(connection, never()).send(any())
+ assertThat(socket.sendBuffer).hasSize(1)
+
+ socket.push("topic", "event2", mapOf("one" to "two"))
+ verify(connection, never()).send(any())
+ assertThat(socket.sendBuffer).hasSize(2)
+
+ socket.sendBuffer.forEach { it.second.invoke() }
+ verify(connection, times(2)).send(any())
+ }
+
+ /* End Push */
+ }
+
+ @Nested
+ @DisplayName("makeRef")
+ inner class MakeRef {
+ @Test
+ internal fun `returns next message ref`() {
+ assertThat(socket.ref).isEqualTo(0)
+ assertThat(socket.makeRef()).isEqualTo("1")
+ assertThat(socket.ref).isEqualTo(1)
+ assertThat(socket.makeRef()).isEqualTo("2")
+ assertThat(socket.ref).isEqualTo(2)
+ }
+
+ @Test
+ internal fun `resets to 0 if it hits max int`() {
+ socket.ref = Int.MAX_VALUE
+
+ assertThat(socket.makeRef()).isEqualTo("0")
+ assertThat(socket.ref).isEqualTo(0)
+ }
+
+ /* End MakeRef */
+ }
+
+ @Nested
+ @DisplayName("sendHeartbeat")
+ inner class SendHeartbeat {
+
+ @BeforeEach
+ internal fun setUp() {
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN)
+ socket.connect()
+ }
+
+ @Test
+ internal fun `closes socket when heartbeat is not ack'd within heartbeat window`() {
+ socket.sendHeartbeat()
+ verify(connection, never()).disconnect(any(), any())
+ assertThat(socket.pendingHeartbeatRef).isNotNull()
+
+ socket.sendHeartbeat()
+ verify(connection).disconnect(WS_CLOSE_NORMAL, "heartbeat timeout")
+ assertThat(socket.pendingHeartbeatRef).isNull()
+ }
+
+ @Test
+ internal fun `pushes heartbeat data when connected`() {
+ socket.sendHeartbeat()
+
+ val expected = "[null,\"1\",\"phoenix\",\"heartbeat\",{}]"
+ assertThat(socket.pendingHeartbeatRef).isEqualTo(socket.ref.toString())
+ verify(connection).send(expected)
+ }
+
+ @Test
+ internal fun `does nothing when not connected`() {
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.CLOSED)
+ socket.sendHeartbeat()
+
+ verify(connection, never()).disconnect(any(), any())
+ verify(connection, never()).send(any())
+ }
+
+ /* End SendHeartbeat */
+ }
+
+ @Nested
+ @DisplayName("flushSendBuffer")
+ inner class FlushSendBuffer {
+ @Test
+ internal fun `invokes callbacks in buffer when connected`() {
+ var oneCalled = 0
+ socket.sendBuffer.add(Pair("0", { oneCalled += 1 }))
+ var twoCalled = 0
+ socket.sendBuffer.add(Pair("1", { twoCalled += 1 }))
+ val threeCalled = 0
+
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN)
+
+ // does nothing if not connected
+ socket.flushSendBuffer()
+ assertThat(oneCalled).isEqualTo(0)
+
+ // connect
+ socket.connect()
+
+ // sends once connected
+ socket.flushSendBuffer()
+ assertThat(oneCalled).isEqualTo(1)
+ assertThat(twoCalled).isEqualTo(1)
+ assertThat(threeCalled).isEqualTo(0)
+ }
+
+ @Test
+ internal fun `empties send buffer`() {
+ socket.sendBuffer.add(Pair(null, {}))
+
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN)
+ socket.connect()
+
+ assertThat(socket.sendBuffer).isNotEmpty()
+ socket.flushSendBuffer()
+
+ assertThat(socket.sendBuffer).isEmpty()
+ }
+
+ /* End FlushSendBuffer */
+ }
+
+ @Nested
+ @DisplayName("removeFromSendBuffer")
+ inner class RemoveFromSendBuffer {
+ @Test
+ internal fun `removes a callback with matching ref`() {
+ var oneCalled = 0
+ socket.sendBuffer.add(Pair("0", { oneCalled += 1 }))
+ var twoCalled = 0
+ socket.sendBuffer.add(Pair("1", { twoCalled += 1 }))
+
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN)
+
+ // connect
+ socket.connect()
+
+ socket.removeFromSendBuffer("0")
+
+ // sends once connected
+ socket.flushSendBuffer()
+ assertThat(oneCalled).isEqualTo(0)
+ assertThat(twoCalled).isEqualTo(1)
+ }
+ }
+
+ @Nested
+ @DisplayName("resetHeartbeat")
+ inner class ResetHeartbeat {
+ @Test
+ internal fun `clears any pending heartbeat`() {
+ socket.pendingHeartbeatRef = "1"
+ socket.resetHeartbeat()
+
+ assertThat(socket.pendingHeartbeatRef).isNull()
+ }
+
+ @Test
+ fun `does not schedule heartbeat if skipHeartbeat == true`() {
+ socket.skipHeartbeat = true
+ socket.resetHeartbeat()
+
+ verifyNoInteractions(mockDispatchQueue)
+ }
+
+ @Test
+ internal fun `creates a future heartbeat task`() {
+ val mockTask = mock()
+ whenever(mockDispatchQueue.queueAtFixedRate(any(), any(), any(), any())).thenReturn(mockTask)
+
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN)
+ socket.connect()
+ socket.heartbeatIntervalMs = 5_000
+
+ assertThat(socket.heartbeatTask).isNull()
+ socket.resetHeartbeat()
+
+ assertThat(socket.heartbeatTask).isNotNull()
+ argumentCaptor<() -> Unit> {
+ verify(mockDispatchQueue).queueAtFixedRate(
+ eq(5_000L), eq(5_000L),
+ eq(TimeUnit.MILLISECONDS), capture()
+ )
+
+ // fire the task
+ allValues.first().invoke()
+
+ val expected = "[null,\"1\",\"phoenix\",\"heartbeat\",{}]"
+ verify(connection).send(expected)
+ }
+ }
+
+ /* End ResetHeartbeat */
+ }
+
+ @Nested
+ @DisplayName("onConnectionOpened")
+ inner class OnConnectionOpened {
+
+ @BeforeEach
+ internal fun setUp() {
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN)
+ socket.connect()
+ }
+
+ @Test
+ internal fun `flushes the send buffer`() {
+ var oneCalled = 0
+ socket.sendBuffer.add(Pair("1", { oneCalled += 1 }))
+
+ socket.onConnectionOpened()
+ assertThat(oneCalled).isEqualTo(1)
+ assertThat(socket.sendBuffer).isEmpty()
+ }
+
+ @Test
+ internal fun `resets reconnect timer`() {
+ val mockTimer = mock()
+ socket.reconnectTimer = mockTimer
+
+ socket.onConnectionOpened()
+ verify(mockTimer).reset()
+ }
+
+ @Test
+ internal fun `resets the heartbeat`() {
+ val mockTask = mock()
+ socket.heartbeatTask = mockTask
+
+ socket.onConnectionOpened()
+ verify(mockTask).cancel()
+ verify(mockDispatchQueue).queueAtFixedRate(any(), any(), any(), any())
+ }
+
+ @Test
+ internal fun `invokes all onOpen callbacks`() {
+ var oneCalled = 0
+ socket.onOpen { oneCalled += 1 }
+ var twoCalled = 0
+ socket.onOpen { twoCalled += 1 }
+ var threeCalled = 0
+ socket.onClose { threeCalled += 1 }
+
+ socket.onConnectionOpened()
+ assertThat(oneCalled).isEqualTo(1)
+ assertThat(twoCalled).isEqualTo(1)
+ assertThat(threeCalled).isEqualTo(0)
+ }
+
+ /* End OnConnectionOpened */
+ }
+
+ @Nested
+ @DisplayName("onConnectionClosed")
+ inner class OnConnectionClosed {
+
+ private lateinit var mockTimer: TimeoutTimer
+
+ @BeforeEach
+ internal fun setUp() {
+ mockTimer = mock()
+ socket.reconnectTimer = mockTimer
+
+ whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN)
+ socket.connect()
+ }
+
+ @Test
+ internal fun `schedules reconnectTimer timeout if normal close`() {
+ socket.onConnectionClosed(WS_CLOSE_NORMAL)
+ verify(mockTimer).scheduleTimeout()
+ }
+
+ @Test
+ internal fun `does not schedule reconnectTimer timeout if normal close after explicit disconnect`() {
+ socket.disconnect()
+ verify(mockTimer, never()).scheduleTimeout()
+ }
+
+ @Test
+ internal fun `schedules reconnectTimer if not normal close`() {
+ socket.onConnectionClosed(1001)
+ verify(mockTimer).scheduleTimeout()
+ }
+
+ @Test
+ internal fun `schedules reconnectTimer timeout if connection cannot be made after a previous clean disconnect`() {
+ socket.disconnect()
+ socket.connect()
+
+ socket.onConnectionClosed(1001)
+ verify(mockTimer).scheduleTimeout()
+ }
+
+ @Test
+ internal fun `cancels heartbeat task`() {
+ val mockTask = mock()
+ socket.heartbeatTask = mockTask
+
+ socket.onConnectionClosed(1000)
+ verify(mockTask).cancel()
+ assertThat(socket.heartbeatTask).isNull()
+ }
+
+ @Test
+ internal fun `triggers onClose callbacks`() {
+ var oneCalled = 0
+ socket.onClose { oneCalled += 1 }
+ var twoCalled = 0
+ socket.onClose { twoCalled += 1 }
+ var threeCalled = 0
+ socket.onOpen { threeCalled += 1 }
+
+ socket.onConnectionClosed(1000)
+ assertThat(oneCalled).isEqualTo(1)
+ assertThat(twoCalled).isEqualTo(1)
+ assertThat(threeCalled).isEqualTo(0)
+ }
+
+ @Test
+ internal fun `triggers channel error if joining`() {
+ val channel = socket.channel("topic")
+ val spy = spy(channel)
+
+ // Use the spy instance instead of the Channel instance
+ socket.channels = socket.channels.minus(channel)
+ socket.channels = socket.channels.plus(spy)
+
+ spy.join()
+ assertThat(spy.state).isEqualTo(Channel.State.JOINING)
+
+ socket.onConnectionClosed(1001)
+ verify(spy).trigger("phx_error")
+ }
+
+ @Test
+ internal fun `triggers channel error if joined`() {
+ val channel = socket.channel("topic")
+ val spy = spy(channel)
+
+ // Use the spy instance instead of the Channel instance
+ socket.channels = socket.channels.minus(channel)
+ socket.channels = socket.channels.plus(spy)
+
+ spy.join().trigger("ok", emptyMap())
+
+ assertThat(channel.state).isEqualTo(Channel.State.JOINED)
+
+ socket.onConnectionClosed(1001)
+ verify(spy).trigger("phx_error")
+ }
+
+ @Test
+ internal fun `does not trigger channel error after leave`() {
+ val channel = socket.channel("topic")
+ val spy = spy(channel)
+
+ // Use the spy instance instead of the Channel instance
+ socket.channels = socket.channels.minus(channel)
+ socket.channels = socket.channels.plus(spy)
+
+ spy.join().trigger("ok", emptyMap())
+ spy.leave()
+
+ assertThat(channel.state).isEqualTo(Channel.State.CLOSED)
+
+ socket.onConnectionClosed(1001)
+ verify(spy, never()).trigger("phx_error")
+ }
+
+ /* End OnConnectionClosed */
+ }
+
+ @Nested
+ @DisplayName("onConnectionError")
+ inner class OnConnectionError {
+ @Test
+ internal fun `triggers onClose callbacks`() {
+ var oneCalled = 0
+ socket.onError { _, _ -> oneCalled += 1 }
+ var twoCalled = 0
+ socket.onError { _, _ -> twoCalled += 1 }
+ var threeCalled = 0
+ socket.onOpen { threeCalled += 1 }
+
+ socket.onConnectionError(Throwable(), null)
+ assertThat(oneCalled).isEqualTo(1)
+ assertThat(twoCalled).isEqualTo(1)
+ assertThat(threeCalled).isEqualTo(0)
+ }
+
+ @Test
+ internal fun `triggers channel error if joining`() {
+ val channel = socket.channel("topic")
+ val spy = spy(channel)
+
+ // Use the spy instance instead of the Channel instance
+ socket.channels = socket.channels.minus(channel)
+ socket.channels = socket.channels.plus(spy)
+
+ spy.join()
+ assertThat(spy.state).isEqualTo(Channel.State.JOINING)
+
+ socket.onConnectionError(Throwable(), null)
+ verify(spy).trigger("phx_error")
+ }
+
+ @Test
+ internal fun `triggers channel error if joined`() {
+ val channel = socket.channel("topic")
+ val spy = spy(channel)
+
+ // Use the spy instance instead of the Channel instance
+ socket.channels = socket.channels.minus(channel)
+ socket.channels = socket.channels.plus(spy)
+
+ spy.join().trigger("ok", emptyMap())
+
+ assertThat(channel.state).isEqualTo(Channel.State.JOINED)
+
+ socket.onConnectionError(Throwable(), null)
+ verify(spy).trigger("phx_error")
+ }
+
+ @Test
+ internal fun `does not trigger channel error after leave`() {
+ val channel = socket.channel("topic")
+ val spy = spy(channel)
+
+ // Use the spy instance instead of the Channel instance
+ socket.channels = socket.channels.minus(channel)
+ socket.channels = socket.channels.plus(spy)
+
+ spy.join().trigger("ok", emptyMap())
+ spy.leave()
+
+ assertThat(channel.state).isEqualTo(Channel.State.CLOSED)
+
+ socket.onConnectionError(Throwable(), null)
+ verify(spy, never()).trigger("phx_error")
+ }
+
+ /* End OnConnectionError */
+ }
+
+ @Nested
+ @DisplayName("onConnectionMessage")
+ inner class OnConnectionMessage {
+
+
+ @Test
+ internal fun `parses raw messages and triggers channel event`() {
+ val targetChannel = mock()
+ whenever(targetChannel.isMember(any())).thenReturn(true)
+ val otherChannel = mock()
+ whenever(otherChannel.isMember(any())).thenReturn(false)
+
+ socket.channels = socket.channels.plus(targetChannel)
+ socket.channels = socket.channels.minus(otherChannel)
+
+ val rawMessage = "[null,null,\"topic\",\"event\",{\"one\":\"two\",\"status\":\"ok\"}]"
+ socket.onConnectionMessage(rawMessage)
+
+ verify(targetChannel).trigger(message = any())
+ verify(otherChannel, never()).trigger(message = any())
+ }
+
+ @Test
+ internal fun `invokes onMessage callbacks`() {
+ var message: Message? = null
+ socket.onMessage { message = it }
+
+ val rawMessage = "[null,null,\"topic\",\"event\",{\"one\":\"two\",\"status\":\"ok\"}]"
+ socket.onConnectionMessage(rawMessage)
+
+ assertThat(message?.topic).isEqualTo("topic")
+ assertThat(message?.event).isEqualTo("event")
+ }
+
+ @Test
+ internal fun `clears pending heartbeat`() {
+ socket.pendingHeartbeatRef = "5"
+
+ val rawMessage = "[null,\"5\",\"topic\",\"event\",{\"one\":\"two\",\"status\":\"ok\"}]"
+ socket.onConnectionMessage(rawMessage)
+ assertThat(socket.pendingHeartbeatRef).isNull()
+ }
+
+ /* End OnConnectionMessage */
+ }
+
+ @Nested
+ @DisplayName("ConcurrentModificationException")
+ inner class ConcurrentModificationExceptionTests {
+
+ @Test
+ internal fun `onOpen does not throw`() {
+ var oneCalled = 0
+ var twoCalled = 0
+ socket.onOpen {
+ socket.onOpen { twoCalled += 1 }
+ oneCalled += 1
+ }
+
+ socket.onConnectionOpened()
+ assertThat(oneCalled).isEqualTo(1)
+ assertThat(twoCalled).isEqualTo(0)
+
+ socket.onConnectionOpened()
+ assertThat(oneCalled).isEqualTo(2)
+ assertThat(twoCalled).isEqualTo(1)
+ }
+
+ @Test
+ internal fun `onClose does not throw`() {
+ var oneCalled = 0
+ var twoCalled = 0
+ socket.onClose {
+ socket.onClose { twoCalled += 1 }
+ oneCalled += 1
+ }
+
+ socket.onConnectionClosed(1000)
+ assertThat(oneCalled).isEqualTo(1)
+ assertThat(twoCalled).isEqualTo(0)
+
+ socket.onConnectionClosed(1001)
+ assertThat(oneCalled).isEqualTo(2)
+ assertThat(twoCalled).isEqualTo(1)
+ }
+
+ @Test
+ internal fun `onError does not throw`() {
+ var oneCalled = 0
+ var twoCalled = 0
+ socket.onError { _, _ ->
+ socket.onError { _, _ -> twoCalled += 1 }
+ oneCalled += 1
+ }
+
+ socket.onConnectionError(Throwable(), null)
+ assertThat(oneCalled).isEqualTo(1)
+ assertThat(twoCalled).isEqualTo(0)
+
+ socket.onConnectionError(Throwable(), null)
+ assertThat(oneCalled).isEqualTo(2)
+ assertThat(twoCalled).isEqualTo(1)
+ }
+
+ @Test
+ internal fun `onMessage does not throw`() {
+ var oneCalled = 0
+ var twoCalled = 0
+ socket.onMessage {
+ socket.onMessage { twoCalled += 1 }
+ oneCalled += 1
+ }
+
+ val message = "[null,null,\"room:lobby\",\"shout\",{\"message\":\"Hi\",\"name\":\"Tester\"}]"
+ socket.onConnectionMessage(message)
+ assertThat(oneCalled).isEqualTo(1)
+ assertThat(twoCalled).isEqualTo(0)
+
+ socket.onConnectionMessage(message)
+ assertThat(oneCalled).isEqualTo(2)
+ assertThat(twoCalled).isEqualTo(1)
+ }
+
+ @Test
+ internal fun `does not throw when adding channel`() {
+ var oneCalled = 0
+ socket.onOpen {
+ val channel = socket.channel("foo")
+ oneCalled += 1
+ }
+
+ socket.onConnectionOpened()
+ assertThat(oneCalled).isEqualTo(1)
+ }
+
+ /* End ConcurrentModificationExceptionTests */
+ }
+}
diff --git a/src/test/kotlin/org/phoenixframework/TimeoutTimerTest.kt b/src/test/kotlin/org/phoenixframework/TimeoutTimerTest.kt
new file mode 100644
index 0000000..29ac4ea
--- /dev/null
+++ b/src/test/kotlin/org/phoenixframework/TimeoutTimerTest.kt
@@ -0,0 +1,67 @@
+package org.phoenixframework
+
+import com.google.common.truth.Truth
+import org.mockito.kotlin.any
+import org.mockito.kotlin.argumentCaptor
+import org.mockito.kotlin.eq
+import org.mockito.kotlin.times
+import org.mockito.kotlin.verify
+import org.mockito.kotlin.whenever
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.Test
+import org.mockito.Mock
+import org.mockito.MockitoAnnotations
+import org.phoenixframework.DispatchQueue
+import org.phoenixframework.DispatchWorkItem
+import org.phoenixframework.TimeoutTimer
+import java.util.concurrent.TimeUnit
+
+class TimeoutTimerTest {
+
+ @Mock lateinit var mockWorkItem: DispatchWorkItem
+ @Mock lateinit var mockDispatchQueue: DispatchQueue
+
+ private var callbackCallCount: Int = 0
+ private lateinit var timeoutTimer: TimeoutTimer
+
+ @BeforeEach
+ internal fun setUp() {
+ MockitoAnnotations.initMocks(this)
+
+ callbackCallCount = 0
+ whenever(mockDispatchQueue.queue(any(), any(), any())).thenReturn(mockWorkItem)
+
+ timeoutTimer = TimeoutTimer(
+ dispatchQueue = mockDispatchQueue,
+ callback = { callbackCallCount += 1 },
+ timerCalculation = { tries ->
+ if(tries > 3 ) 10000 else listOf(1000L, 2000L, 5000L)[tries -1]
+ })
+ }
+
+ @Test
+ internal fun `scheduleTimeout executes with backoff`() {
+ argumentCaptor<() -> Unit> {
+ timeoutTimer.scheduleTimeout()
+ verify(mockDispatchQueue).queue(eq(1000L), eq(TimeUnit.MILLISECONDS), capture())
+ lastValue.invoke()
+ Truth.assertThat(callbackCallCount).isEqualTo(1)
+
+ timeoutTimer.scheduleTimeout()
+ verify(mockDispatchQueue).queue(eq(2000L), eq(TimeUnit.MILLISECONDS), capture())
+ lastValue.invoke()
+ Truth.assertThat(callbackCallCount).isEqualTo(2)
+
+ timeoutTimer.scheduleTimeout()
+ verify(mockDispatchQueue).queue(eq(5000L), eq(TimeUnit.MILLISECONDS), capture())
+ lastValue.invoke()
+ Truth.assertThat(callbackCallCount).isEqualTo(3)
+
+ timeoutTimer.reset()
+ timeoutTimer.scheduleTimeout()
+ verify(mockDispatchQueue, times(2)).queue(eq(1000L), eq(TimeUnit.MILLISECONDS), capture())
+ lastValue.invoke()
+ Truth.assertThat(callbackCallCount).isEqualTo(4)
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/test/kotlin/org/phoenixframework/WebSocketTransportTest.kt b/src/test/kotlin/org/phoenixframework/WebSocketTransportTest.kt
new file mode 100644
index 0000000..381b5ed
--- /dev/null
+++ b/src/test/kotlin/org/phoenixframework/WebSocketTransportTest.kt
@@ -0,0 +1,178 @@
+package org.phoenixframework
+
+import com.google.common.truth.Truth.assertThat
+import org.mockito.kotlin.any
+import org.mockito.kotlin.mock
+import org.mockito.kotlin.verify
+import org.mockito.kotlin.whenever
+import okhttp3.OkHttpClient
+import okhttp3.Response
+import okhttp3.WebSocket
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.DisplayName
+import org.junit.jupiter.api.Nested
+import org.junit.jupiter.api.Test
+import org.mockito.Mock
+import org.mockito.MockitoAnnotations
+import java.net.SocketException
+import java.net.URL
+
+class WebSocketTransportTest {
+
+ @Mock lateinit var mockClient: OkHttpClient
+ @Mock lateinit var mockWebSocket: WebSocket
+ @Mock lateinit var mockResponse: Response
+
+ private lateinit var transport: WebSocketTransport
+
+ @BeforeEach
+ internal fun setUp() {
+ MockitoAnnotations.initMocks(this)
+
+ val url = URL("http://localhost:400/socket/websocket")
+ transport = WebSocketTransport(url, mockClient)
+ }
+
+ @Nested
+ @DisplayName("connect")
+ inner class Connect {
+ @Test
+ internal fun `sets ready state to CONNECTING and creates connection`() {
+ whenever(mockClient.newWebSocket(any(), any())).thenReturn(mockWebSocket)
+
+ transport.connect()
+ assertThat(transport.readyState).isEqualTo(Transport.ReadyState.CONNECTING)
+ assertThat(transport.connection).isNotNull()
+ }
+
+ /* End Connect */
+ }
+
+ @Nested
+ @DisplayName("disconnect")
+ inner class Disconnect {
+ @Test
+ internal fun `closes and releases connection`() {
+ transport.connection = mockWebSocket
+
+ transport.disconnect(10, "Test reason")
+ verify(mockWebSocket).close(10, "Test reason")
+ assertThat(transport.connection).isNull()
+ }
+
+ /* End Disconnect */
+ }
+
+ @Nested
+ @DisplayName("send")
+ inner class Send {
+ @Test
+ internal fun `sends text through the connection`() {
+ transport.connection = mockWebSocket
+
+ transport.send("some data")
+ verify(mockWebSocket).send("some data")
+ }
+
+ /* End Send */
+ }
+
+ @Nested
+ @DisplayName("onOpen")
+ inner class OnOpen {
+ @Test
+ internal fun `sets ready state to OPEN and invokes the onOpen callback`() {
+ val mockClosure = mock<() -> Unit>()
+ transport.onOpen = mockClosure
+
+ assertThat(transport.readyState).isEqualTo(Transport.ReadyState.CLOSED)
+
+ transport.onOpen(mockWebSocket, mockResponse)
+ assertThat(transport.readyState).isEqualTo(Transport.ReadyState.OPEN)
+ verify(mockClosure).invoke()
+ }
+
+ /* End OnOpen */
+ }
+
+
+ @Nested
+ @DisplayName("onFailure")
+ inner class OnFailure {
+ @Test
+ internal fun `sets ready state to CLOSED and invokes onError callback`() {
+ val mockClosure = mock<(Throwable, Response?) -> Unit>()
+ transport.onError = mockClosure
+
+ transport.readyState = Transport.ReadyState.CONNECTING
+
+ val throwable = Throwable()
+ transport.onFailure(mockWebSocket, throwable, mockResponse)
+ assertThat(transport.readyState).isEqualTo(Transport.ReadyState.CLOSED)
+ verify(mockClosure).invoke(throwable, mockResponse)
+ }
+
+ @Test
+ internal fun `also triggers onClose`() {
+ val mockOnError = mock<(Throwable, Response?) -> Unit>()
+ val mockOnClose = mock<(Int) -> Unit>()
+ transport.onClose = mockOnClose
+ transport.onError = mockOnError
+
+ val throwable = SocketException()
+ transport.onFailure(mockWebSocket, throwable, mockResponse)
+ verify(mockOnError).invoke(throwable, mockResponse)
+ verify(mockOnClose).invoke(1006)
+ }
+
+ /* End OnFailure */
+ }
+
+ @Nested
+ @DisplayName("onClosing")
+ inner class OnClosing {
+ @Test
+ internal fun `sets ready state to CLOSING`() {
+ transport.readyState = Transport.ReadyState.OPEN
+
+ transport.onClosing(mockWebSocket, 10, "reason")
+ assertThat(transport.readyState).isEqualTo(Transport.ReadyState.CLOSING)
+ }
+
+ /* End OnClosing */
+ }
+
+ @Nested
+ @DisplayName("onMessage")
+ inner class OnMessage {
+ @Test
+ internal fun `invokes onMessage closure`() {
+ val mockClosure = mock<(String) -> Unit>()
+ transport.onMessage = mockClosure
+
+ transport.onMessage(mockWebSocket, "text")
+ verify(mockClosure).invoke("text")
+ }
+
+ /* End OnMessage*/
+ }
+
+
+ @Nested
+ @DisplayName("onClosed")
+ inner class OnClosed {
+ @Test
+ internal fun `sets readyState to CLOSED and invokes closure`() {
+ val mockClosure = mock<(Int) -> Unit>()
+ transport.onClose = mockClosure
+
+ transport.readyState = Transport.ReadyState.CONNECTING
+
+ transport.onClosed(mockWebSocket, 10, "reason")
+ assertThat(transport.readyState).isEqualTo(Transport.ReadyState.CLOSED)
+ verify(mockClosure).invoke(10)
+ }
+
+ /* End OnClosed */
+ }
+}
\ No newline at end of file
diff --git a/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueue.kt b/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueue.kt
new file mode 100644
index 0000000..c1ff3a6
--- /dev/null
+++ b/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueue.kt
@@ -0,0 +1,73 @@
+package org.phoenixframework.queue
+
+import org.phoenixframework.DispatchQueue
+import org.phoenixframework.DispatchWorkItem
+import java.util.concurrent.TimeUnit
+
+class ManualDispatchQueue : DispatchQueue {
+
+ var tickTime: Long = 0
+ private val tickTimeUnit: TimeUnit = TimeUnit.MILLISECONDS
+ var workItems: MutableList = mutableListOf()
+
+ fun reset() {
+ this.tickTime = 0
+ this.workItems = mutableListOf()
+ }
+
+ fun tick(duration: Long, unit: TimeUnit = TimeUnit.MILLISECONDS) {
+ val durationInMs = tickTimeUnit.convert(duration, unit)
+
+ // calculate what time to advance to
+ val advanceTo = tickTime + durationInMs
+
+ // Filter all work items that are due to be fired and have not been
+ // cancelled. Return early if there are no items to fire
+ var pastDueWorkItems = workItems.filter { it.isPastDue(advanceTo) && !it.isCancelled }.sorted()
+
+ // Keep looping until there are no more work items that are passed the advance to time
+ while (pastDueWorkItems.isNotEmpty()) {
+
+ val firstItem = pastDueWorkItems.first()
+ tickTime = firstItem.deadline
+ firstItem.perform()
+
+ // Remove all work items that are past due or canceled
+ workItems.removeAll { it.isPastDue(tickTime) || it.isCancelled }
+ pastDueWorkItems = workItems.filter { it.isPastDue(advanceTo) && !it.isCancelled }.sorted()
+ }
+
+ // Now that all work has been performed, advance the clock
+ this.tickTime = advanceTo
+
+ }
+
+ override fun queue(delay: Long, unit: TimeUnit, runnable: () -> Unit): DispatchWorkItem {
+ // Converts the given unit and delay to the unit used by this class
+ val delayInMs = tickTimeUnit.convert(delay, unit)
+ val deadline = tickTime + delayInMs
+
+ val workItem = ManualDispatchWorkItem(runnable, deadline)
+ workItems.add(workItem)
+
+ return workItem
+ }
+
+ override fun queueAtFixedRate(
+ delay: Long,
+ period: Long,
+ unit: TimeUnit,
+ runnable: () -> Unit
+ ): DispatchWorkItem {
+
+ val delayInMs = tickTimeUnit.convert(delay, unit)
+ val periodInMs = tickTimeUnit.convert(period, unit)
+ val deadline = tickTime + delayInMs
+
+ val workItem =
+ ManualDispatchWorkItem(runnable, deadline, periodInMs)
+ workItems.add(workItem)
+
+ return workItem
+ }
+}
\ No newline at end of file
diff --git a/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueueTest.kt b/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueueTest.kt
new file mode 100644
index 0000000..02c18cc
--- /dev/null
+++ b/src/test/kotlin/org/phoenixframework/queue/ManualDispatchQueueTest.kt
@@ -0,0 +1,208 @@
+package org.phoenixframework.queue
+
+import com.google.common.truth.Truth.assertThat
+import org.junit.jupiter.api.AfterEach
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.Test
+import java.util.concurrent.TimeUnit
+
+internal class ManualDispatchQueueTest {
+
+
+ private lateinit var queue: ManualDispatchQueue
+
+ @BeforeEach
+ internal fun setUp() {
+ queue = ManualDispatchQueue()
+ }
+
+ @AfterEach
+ internal fun tearDown() {
+ queue.reset()
+ }
+
+ @Test
+ internal fun `reset the queue`() {
+ var task100Called = false
+ var task200Called = false
+ var task300Called = false
+
+ queue.queue(100, TimeUnit.MILLISECONDS) {
+ task100Called = true
+ }
+ queue.queue(200, TimeUnit.MILLISECONDS) {
+ task200Called = true
+ }
+ queue.queue(300, TimeUnit.MILLISECONDS) {
+ task300Called = true
+ }
+
+ queue.tick(250)
+
+ assertThat(queue.tickTime).isEqualTo(250)
+ assertThat(queue.workItems).hasSize(1)
+
+ queue.reset()
+ assertThat(queue.tickTime).isEqualTo(0)
+ assertThat(queue.workItems).isEmpty()
+ }
+
+ @Test
+ internal fun `triggers work that is passed due`() {
+ var task100Called = false
+ var task200Called = false
+ var task300Called = false
+
+ queue.queue(100, TimeUnit.MILLISECONDS) {
+ task100Called = true
+ }
+ queue.queue(200, TimeUnit.MILLISECONDS) {
+ task200Called = true
+ }
+ queue.queue(300, TimeUnit.MILLISECONDS) {
+ task300Called = true
+ }
+
+ queue.tick(100)
+ assertThat(task100Called).isTrue()
+
+ queue.tick(100)
+ assertThat(task200Called).isTrue()
+
+ queue.tick(50)
+ assertThat(task300Called).isFalse()
+ }
+
+ @Test
+ internal fun `triggers all work that is passed due`() {
+ var task100Called = false
+ var task200Called = false
+ var task300Called = false
+
+ queue.queue(100, TimeUnit.MILLISECONDS) {
+ task100Called = true
+ }
+ queue.queue(200, TimeUnit.MILLISECONDS) {
+ task200Called = true
+ }
+ queue.queue(300, TimeUnit.MILLISECONDS) {
+ task300Called = true
+ }
+
+ queue.tick(250)
+ assertThat(task100Called).isTrue()
+ assertThat(task200Called).isTrue()
+ assertThat(task300Called).isFalse()
+ }
+
+ @Test
+ internal fun `triggers work that is scheduled for a time that is after tick`() {
+ var task100Called = false
+ var task200Called = false
+ var task300Called = false
+
+ queue.queue(100, TimeUnit.MILLISECONDS) {
+ task100Called = true
+
+ queue.queue(100, TimeUnit.MILLISECONDS) {
+ task200Called = true
+ }
+
+ }
+
+ queue.queue(300, TimeUnit.MILLISECONDS) {
+ task300Called = true
+ }
+
+ queue.tick(250)
+ assertThat(task100Called).isTrue()
+ assertThat(task200Called).isTrue()
+ assertThat(task300Called).isFalse()
+
+ assertThat(queue.tickTime).isEqualTo(250)
+ }
+
+ @Test
+ internal fun `triggers work in order of deadline`() {
+ var task200Called = false
+ var task100Called = false
+
+
+ val task200 = queue.queue(200, TimeUnit.MILLISECONDS) {
+ task200Called = true
+ }
+
+ queue.queue(100, TimeUnit.MILLISECONDS) {
+ task100Called = true
+ task200.cancel()
+ }
+
+
+ queue.tick(300)
+ assertThat(task100Called).isTrue()
+ assertThat(task200Called).isFalse()
+ }
+
+ @Test
+ internal fun `triggers inserted work in order of deadline`() {
+ var task500Called = false
+ var task200Called = false
+ var task100Called = false
+
+ val task500 = queue.queue(500, TimeUnit.MILLISECONDS) {
+ task500Called = true
+ }
+
+ queue.queue(200, TimeUnit.MILLISECONDS) {
+ task200Called = true
+
+ queue.queue(100, TimeUnit.MILLISECONDS) {
+ task100Called = true
+ task500.cancel()
+ }
+ }
+
+ queue.tick(600)
+ assertThat(task100Called).isTrue()
+ assertThat(task200Called).isTrue()
+ assertThat(task500Called).isFalse()
+ }
+
+
+
+ @Test
+ internal fun `does not triggers nested work that is scheduled outside of the tick`() {
+ var task100Called = false
+ var task200Called = false
+ var task300Called = false
+
+ queue.queue(100, TimeUnit.MILLISECONDS) {
+ task100Called = true
+
+ queue.queue(100, TimeUnit.MILLISECONDS) {
+ task200Called = true
+
+ queue.queue(100, TimeUnit.MILLISECONDS) {
+ task300Called = true
+ }
+ }
+ }
+
+ queue.tick(250)
+ assertThat(task100Called).isTrue()
+ assertThat(task200Called).isTrue()
+ assertThat(task300Called).isFalse()
+ }
+
+ @Test
+ internal fun `queueAtFixedRate repeats work`() {
+ var repeatTaskCallCount = 0
+
+ queue.queueAtFixedRate(100, 100, TimeUnit.MILLISECONDS) {
+ repeatTaskCallCount += 1
+ }
+
+ queue.tick(500)
+ assertThat(repeatTaskCallCount).isEqualTo(5)
+ }
+}
\ No newline at end of file
diff --git a/src/test/kotlin/org/phoenixframework/queue/ManualDispatchWorkItem.kt b/src/test/kotlin/org/phoenixframework/queue/ManualDispatchWorkItem.kt
new file mode 100644
index 0000000..42d47eb
--- /dev/null
+++ b/src/test/kotlin/org/phoenixframework/queue/ManualDispatchWorkItem.kt
@@ -0,0 +1,37 @@
+package org.phoenixframework.queue
+
+import org.phoenixframework.DispatchWorkItem
+
+class ManualDispatchWorkItem(
+ private val runnable: () -> Unit,
+ var deadline: Long,
+ private val period: Long = 0
+) : DispatchWorkItem, Comparable {
+
+ private var performCount = 0
+
+ // Test
+ fun isPastDue(tickTime: Long): Boolean {
+ return this.deadline <= tickTime
+ }
+
+ fun perform() {
+ if (isCancelled) return
+ runnable.invoke()
+ performCount += 1
+
+ // If the task is repeatable, then schedule the next deadline after the given period
+ deadline += period
+ }
+
+ // DispatchWorkItem
+ override var isCancelled: Boolean = false
+
+ override fun cancel() {
+ this.isCancelled = true
+ }
+
+ override fun compareTo(other: ManualDispatchWorkItem): Int {
+ return deadline.compareTo(other.deadline)
+ }
+}
diff --git a/src/test/kotlin/org/phoenixframework/utilities/TestUtilities.kt b/src/test/kotlin/org/phoenixframework/utilities/TestUtilities.kt
new file mode 100644
index 0000000..015e95e
--- /dev/null
+++ b/src/test/kotlin/org/phoenixframework/utilities/TestUtilities.kt
@@ -0,0 +1,8 @@
+package org.phoenixframework.utilities
+
+import org.phoenixframework.Binding
+import org.phoenixframework.Channel
+
+fun Channel.getBindings(event: String): List {
+ return bindings.toList().filter { it.event == event }
+}
\ No newline at end of file
diff --git a/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker b/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker
new file mode 100644
index 0000000..1f0955d
--- /dev/null
+++ b/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker
@@ -0,0 +1 @@
+mock-maker-inline