diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt index 4c3f96668..24c36154b 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt @@ -18,6 +18,7 @@ package io.rsocket.kotlin.core import io.ktor.utils.io.core.* import io.rsocket.kotlin.* +import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.logging.* import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* @@ -97,30 +98,33 @@ private class ReconnectableRSocket( private val state: StateFlow, ) : RSocket { - private val reconnectHandler = state.mapNotNull { it.handleState { null } }.take(1) + private val reconnectHandler = state.mapNotNull { it.current() }.take(1) - //null pointer will never happen - private suspend fun currentRSocket(): RSocket = state.value.handleState { reconnectHandler.first() }!! + private suspend fun currentRSocket(closeable: Closeable): RSocket = closeable.closeOnError { currentRSocket() } - private inline fun ReconnectState.handleState(onReconnect: () -> RSocket?): RSocket? = when (this) { - is ReconnectState.Connected -> when { - rSocket.isActive -> rSocket //connection is ready to handle requests - else -> onReconnect() //reconnection - } + private suspend fun currentRSocket(): RSocket = state.value.current() ?: reconnectHandler.first() + + private fun ReconnectState.current(): RSocket? = when (this) { + is ReconnectState.Connected -> rSocket.takeIf(RSocket::isActive) //connection is ready to handle requests is ReconnectState.Failed -> throw error //connection failed - fail requests - ReconnectState.Connecting -> onReconnect() //reconnection + ReconnectState.Connecting -> null //reconnection } - private suspend inline fun execSuspend(operation: RSocket.() -> T): T = - currentRSocket().operation() + override suspend fun metadataPush(metadata: ByteReadPacket): Unit = + currentRSocket(metadata).metadataPush(metadata) + + override suspend fun fireAndForget(payload: Payload): Unit = + currentRSocket(payload).fireAndForget(payload) - private inline fun execFlow(crossinline operation: RSocket.() -> Flow): Flow = - flow { emitAll(currentRSocket().operation()) } + override suspend fun requestResponse(payload: Payload): Payload = + currentRSocket(payload).requestResponse(payload) - override suspend fun metadataPush(metadata: ByteReadPacket): Unit = execSuspend { metadataPush(metadata) } - override suspend fun fireAndForget(payload: Payload): Unit = execSuspend { fireAndForget(payload) } - override suspend fun requestResponse(payload: Payload): Payload = execSuspend { requestResponse(payload) } - override fun requestStream(payload: Payload): Flow = execFlow { requestStream(payload) } - override fun requestChannel(payloads: Flow): Flow = execFlow { requestChannel(payloads) } + override fun requestStream(payload: Payload): Flow = flow { + emitAll(currentRSocket(payload).requestStream(payload)) + } + + override fun requestChannel(payloads: Flow): Flow = flow { + emitAll(currentRSocket().requestChannel(payloads)) + } } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt index 809595ca4..a1b61e07c 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt @@ -17,6 +17,7 @@ package io.rsocket.kotlin.core import app.cash.turbine.* +import io.ktor.utils.io.core.* import io.rsocket.kotlin.* import io.rsocket.kotlin.logging.* import io.rsocket.kotlin.payload.* @@ -54,7 +55,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { val connect: suspend () -> RSocket = { if (first.value) { first.value = false - rrHandler(firstJob) + handler(firstJob) } else { error("Failed to connect") } @@ -89,7 +90,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { first.value = false error("Failed to connect") } else { - rrHandler(handlerJob) + handler(handlerJob) } } val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt -> @@ -114,7 +115,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { error("Failed to connect") } else { delay(200) //emulate connection establishment - rrHandler(Job()) + handler(Job()) } } val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt -> @@ -137,13 +138,13 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { when { first.value -> { first.value = false - rrHandler(firstJob) //first connection + handler(firstJob) //first connection } fails.value < 5 -> { delay(100) error("Failed to connect") } - else -> rrHandler(Job()) + else -> handler(Job()) } } val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt -> @@ -170,13 +171,13 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { when { first.value -> { first.value = false - streamHandler(firstJob) //first connection + handler(firstJob) //first connection } fails.value < 5 -> { delay(100) error("Failed to connect") } - else -> streamHandler(Job()) + else -> handler(Job()) } } val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt -> @@ -206,8 +207,52 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { assertEquals(5, fails.value) } - private fun rrHandler(job: Job): RSocket = RSocketRequestHandler(job) { requestResponse { it } } - private fun streamHandler(job: Job): RSocket = RSocketRequestHandler(job) { + @Test + fun testNoLeakMetadataPush() = testNoLeaksInteraction { metadataPush(it.data) } + + @Test + fun testNoLeakFireAndForget() = testNoLeaksInteraction { fireAndForget(it) } + + @Test + fun testNoLeakRequestResponse() = testNoLeaksInteraction { requestResponse(it) } + + @Test + fun testNoLeakRequestStream() = testNoLeaksInteraction { requestStream(it).collect() } + + private inline fun testNoLeaksInteraction(crossinline interaction: suspend RSocket.(payload: Payload) -> Unit) = test { + val firstJob = Job() + val connect: suspend () -> RSocket = { + if (first.compareAndSet(true, false)) { + handler(firstJob) + } else { + error("Failed to connect") + } + } + val rSocket = ReconnectableRSocket(logger, connect) { _, attempt -> + delay(100) + attempt < 5 + } + + rSocket.requestResponse(Payload.Empty) //first request to be sure, that connected + firstJob.cancelAndJoin() //cancel + + val p = payload("text") + assertFails { + rSocket.interaction(p) //test release on reconnecting + } + assertTrue(p.data.isEmpty) + + val p2 = payload("text") + assertFails { + rSocket.interaction(p2) //test release on failed + } + assertTrue(p2.data.isEmpty) + } + + private fun handler(job: Job): RSocket = RSocketRequestHandler(job) { + requestResponse { payload -> + payload + } requestStream { flow { repeat(5) {