diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt index 64e92c827..9ae322d85 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt @@ -41,6 +41,5 @@ internal suspend fun Connection.receiveFrame(): Frame = receive().readFrame(pool @OptIn(TransportApi::class) internal suspend fun Connection.sendFrame(frame: Frame) { - val packet = frame.toPacket(pool) - packet.closeOnError { send(packet) } + frame.toPacket(pool).closeOnError { send(it) } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt index fbb0a3f72..37989d720 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt @@ -22,6 +22,7 @@ import io.rsocket.kotlin.frame.io.* import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.logging.* import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* @OptIn(TransportApi::class, RSocketLoggingApi::class) public class RSocketConnector internal constructor( @@ -43,18 +44,34 @@ public class RSocketConnector internal constructor( private suspend fun connectOnce(transport: ClientTransport): RSocket { val connection = transport.connect().wrapConnection() - val connectionConfig = connectionConfigProvider() - - return connection.connect(isServer = false, interceptors, connectionConfig, acceptor) { - val setupFrame = SetupFrame( - version = Version.Current, - honorLease = false, - keepAlive = connectionConfig.keepAlive, - resumeToken = null, - payloadMimeType = connectionConfig.payloadMimeType, - payload = connectionConfig.setupPayload + val connectionConfig = try { + connectionConfigProvider() + } catch (cause: Throwable) { + connection.job.cancel("Connection config provider failed", cause) + throw cause + } + val setupFrame = SetupFrame( + version = Version.Current, + honorLease = false, + keepAlive = connectionConfig.keepAlive, + resumeToken = null, + payloadMimeType = connectionConfig.payloadMimeType, + payload = connectionConfig.setupPayload.copy() //copy needed, as it can be used in acceptor + ) + try { + val requester = connection.connect( + isServer = false, + interceptors = interceptors, + connectionConfig = connectionConfig, + acceptor = acceptor ) connection.sendFrame(setupFrame) + return requester + } catch (cause: Throwable) { + connectionConfig.setupPayload.release() + setupFrame.release() + connection.job.cancel("Connection establishment failed", cause) + throw cause } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorBuilder.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorBuilder.kt index 7c027b160..47c157659 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorBuilder.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorBuilder.kt @@ -103,7 +103,10 @@ public class RSocketConnectorBuilder internal constructor() { ) private companion object { - private val defaultAcceptor: ConnectionAcceptor = ConnectionAcceptor { EmptyRSocket() } + private val defaultAcceptor: ConnectionAcceptor = ConnectionAcceptor { + config.setupPayload.release() + EmptyRSocket() + } private class EmptyRSocket : RSocket { override val job: Job = Job() diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt index 3c94e03a7..d12ba1e2d 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt @@ -33,44 +33,40 @@ public class RSocketServer internal constructor( public fun bind( transport: ServerTransport, acceptor: ConnectionAcceptor, - ): T = transport.start { - val connection = it.wrapConnection() - val setupFrame = connection.validateSetup() - connection.start(setupFrame, acceptor) - connection.job.join() - } - - private suspend fun Connection.start(setupFrame: SetupFrame, acceptor: ConnectionAcceptor) { - val connectionConfig = ConnectionConfig( - keepAlive = setupFrame.keepAlive, - payloadMimeType = setupFrame.payloadMimeType, - setupPayload = setupFrame.payload - ) - try { - connect(isServer = true, interceptors, connectionConfig, acceptor) - } catch (e: Throwable) { - failSetup(RSocketError.Setup.Rejected(e.message ?: "Rejected by server acceptor")) - } - } + ): T = transport.start { it.wrapConnection().bind(acceptor).join() } - private suspend fun Connection.validateSetup(): SetupFrame { - val setupFrame = receiveFrame() - return when { + private suspend fun Connection.bind(acceptor: ConnectionAcceptor): Job = receiveFrame().closeOnError { setupFrame -> + when { setupFrame !is SetupFrame -> failSetup(RSocketError.Setup.Invalid("Invalid setup frame: ${setupFrame.type}")) setupFrame.version != Version.Current -> failSetup(RSocketError.Setup.Unsupported("Unsupported version: ${setupFrame.version}")) setupFrame.honorLease -> failSetup(RSocketError.Setup.Unsupported("Lease is not supported")) setupFrame.resumeToken != null -> failSetup(RSocketError.Setup.Unsupported("Resume is not supported")) - else -> setupFrame + else -> try { + connect( + isServer = true, + interceptors = interceptors, + connectionConfig = ConnectionConfig( + keepAlive = setupFrame.keepAlive, + payloadMimeType = setupFrame.payloadMimeType, + setupPayload = setupFrame.payload + ), + acceptor = acceptor + ) + job + } catch (e: Throwable) { + failSetup(RSocketError.Setup.Rejected(e.message ?: "Rejected by server acceptor")) + } } } - private fun Connection.wrapConnection(): Connection = - interceptors.wrapConnection(this) - .logging(loggerFactory.logger("io.rsocket.kotlin.frame")) - private suspend fun Connection.failSetup(error: RSocketError.Setup): Nothing { sendFrame(ErrorFrame(0, error)) job.cancel("Connection establishment failed", error) throw error } + + private fun Connection.wrapConnection(): Connection = + interceptors.wrapConnection(this) + .logging(loggerFactory.logger("io.rsocket.kotlin.frame")) + } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt index 8d6b76352..cfc79dcac 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt @@ -17,31 +17,29 @@ package io.rsocket.kotlin.internal import io.ktor.utils.io.core.* -import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlin.native.concurrent.* -internal inline fun Closeable.closeOnError(block: () -> T): T { +internal inline fun T.closeOnError(block: (T) -> R): R { try { - return block() + return block(this) } catch (e: Throwable) { close() throw e } } -internal fun ReceiveChannel<*>.cancelConsumed(cause: Throwable?) { - cancel(cause?.let { it as? CancellationException ?: CancellationException("Channel was consumed, consumer had failed", it) }) -} - @SharedImmutable private val onUndeliveredCloseable: (Closeable) -> Unit = Closeable::close @Suppress("FunctionName") internal fun SafeChannel(capacity: Int): Channel = Channel(capacity, onUndeliveredElement = onUndeliveredCloseable) -internal fun SendChannel.safeOffer(element: E) { - trySend(element) - .onFailure { element.close() } - .getOrThrow() //TODO +internal fun SendChannel.safeTrySend(element: E) { + trySend(element).onFailure { element.close() } +} + +internal fun Channel.fullClose(cause: Throwable?) { + close(cause) // close channel to provide right cause + cancel() // force call of onUndeliveredElement to release buffered elements } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt index a7fa760eb..b5fc409f2 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt @@ -16,22 +16,80 @@ package io.rsocket.kotlin.internal +import io.ktor.utils.io.core.* import io.rsocket.kotlin.* import io.rsocket.kotlin.core.* +import io.rsocket.kotlin.frame.* +import kotlinx.coroutines.* @OptIn(TransportApi::class) internal suspend inline fun Connection.connect( isServer: Boolean, interceptors: Interceptors, connectionConfig: ConnectionConfig, - acceptor: ConnectionAcceptor, - beforeStart: () -> Unit = {}, + acceptor: ConnectionAcceptor ): RSocket { - val state = RSocketState(this, connectionConfig.keepAlive) - val requester = RSocketRequester(state, StreamId(isServer)).let(interceptors::wrapRequester) - val connectionContext = ConnectionAcceptorContext(connectionConfig, requester) - val requestHandler = with(interceptors.wrapAcceptor(acceptor)) { connectionContext.accept() }.let(interceptors::wrapResponder) - beforeStart() - state.start(requestHandler) + val keepAliveHandler = KeepAliveHandler(connectionConfig.keepAlive) + val prioritizer = Prioritizer() + val streamsStorage = StreamsStorage(isServer) + val requestJob = SupervisorJob(job) + + requestJob.invokeOnCompletion { + prioritizer.close(it) + streamsStorage.cleanup(it) + connectionConfig.setupPayload.release() + } + + val requestScope = CoroutineScope(requestJob + Dispatchers.Unconfined + CoroutineExceptionHandler { _, _ -> }) + val connectionScope = CoroutineScope(job + Dispatchers.Unconfined + CoroutineExceptionHandler { _, _ -> }) + + val requester = interceptors.wrapRequester(RSocketRequester(job, prioritizer, streamsStorage, requestScope)) + val requestHandler = interceptors.wrapResponder( + with(interceptors.wrapAcceptor(acceptor)) { + ConnectionAcceptorContext(connectionConfig, requester).accept() + } + ) + + // link completing of connection and requestHandler + job.invokeOnCompletion { requestHandler.job.cancel("Connection closed", it) } + requestHandler.job.invokeOnCompletion { if (it != null) job.cancel("Request handler failed", it) } + + // start keepalive ticks + connectionScope.launch { + while (isActive) { + keepAliveHandler.tick() + prioritizer.send(KeepAliveFrame(true, 0, ByteReadPacket.Empty)) + } + } + + // start sending frames to connection + connectionScope.launch { + while (isActive) { + sendFrame(prioritizer.receive()) + } + } + + // start frame handling + connectionScope.launch { + val rSocketResponder = RSocketResponder(prioritizer, requestHandler, requestScope) + while (isActive) { + receiveFrame().closeOnError { frame -> + when (frame.streamId) { + 0 -> when (frame) { + is MetadataPushFrame -> rSocketResponder.handleMetadataPush(frame.metadata) + is ErrorFrame -> job.cancel("Error frame received on 0 stream", frame.throwable) + is KeepAliveFrame -> { + keepAliveHandler.mark() + if (frame.respond) prioritizer.send(KeepAliveFrame(false, 0, frame.data)) else Unit + } + is LeaseFrame -> frame.release().also { error("lease isn't implemented") } + else -> frame.release() + } + else -> streamsStorage.handleFrame(frame, rSocketResponder) + } + } + } + } + return requester } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/KeepAliveHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/KeepAliveHandler.kt index 072e3b235..cb428f94c 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/KeepAliveHandler.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/KeepAliveHandler.kt @@ -16,39 +16,23 @@ package io.rsocket.kotlin.internal -import io.ktor.utils.io.core.* import io.rsocket.kotlin.* -import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.keepalive.* import kotlinx.atomicfu.* import kotlinx.coroutines.* -internal class KeepAliveHandler( - private val keepAlive: KeepAlive, - private val offerFrame: (frame: Frame) -> Unit, -) { +internal class KeepAliveHandler(private val keepAlive: KeepAlive) { + private val lastMark = atomic(currentMillis()) // mark initial timestamp for keepalive - private val lastMark = atomic(currentMillis()) - - fun receive(frame: KeepAliveFrame) { + fun mark() { lastMark.value = currentMillis() - if (frame.respond) { - offerFrame(KeepAliveFrame(false, 0, frame.data)) - } } - fun startIn(scope: CoroutineScope) { - scope.launch { - while (isActive) { - delay(keepAlive.intervalMillis.toLong()) - if (currentMillis() - lastMark.value >= keepAlive.maxLifetimeMillis) { - //for K/N - scope.cancel("Keep alive failed", RSocketError.ConnectionError("No keep-alive for ${keepAlive.maxLifetimeMillis} ms")) - break - } - offerFrame(KeepAliveFrame(true, 0, ByteReadPacket.Empty)) - } - } + // return boolean because of native + suspend fun tick() { + delay(keepAlive.intervalMillis.toLong()) + if (currentMillis() - lastMark.value < keepAlive.maxLifetimeMillis) return + throw RSocketError.ConnectionError("No keep-alive for ${keepAlive.maxLifetimeMillis} ms") } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/LimitingFlowCollector.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt similarity index 75% rename from rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/LimitingFlowCollector.kt rename to rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt index 60ee453b4..0eeaad3dc 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/LimitingFlowCollector.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt @@ -14,21 +14,25 @@ * limitations under the License. */ -package io.rsocket.kotlin.internal.flow +package io.rsocket.kotlin.internal -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.payload.* import kotlinx.atomicfu.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.* import kotlin.coroutines.* -internal class LimitingFlowCollector( - private val state: RSocketState, - private val streamId: Int, - initial: Int, -) : FlowCollector { +internal suspend inline fun Flow.collectLimiting(limiter: Limiter, crossinline action: suspend (value: Payload) -> Unit) { + collect { payload -> + payload.closeOnError { + limiter.useRequest() + action(it) + } + } +} + +//TODO revisit 2 atomics +internal class Limiter(initial: Int) { private val requests = atomic(initial) private val awaiter = atomic?>(null) @@ -38,12 +42,7 @@ internal class LimitingFlowCollector( awaiter.getAndSet(null)?.takeIf(CancellableContinuation::isActive)?.resume(Unit) } - override suspend fun emit(value: Payload): Unit = value.closeOnError { - useRequest() - state.send(NextPayloadFrame(streamId, value)) - } - - private suspend fun useRequest() { + suspend fun useRequest() { if (requests.getAndDecrement() > 0) { currentCoroutineContext().ensureActive() } else { diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/LoggingConnection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/LoggingConnection.kt index e283e7ce2..750ad6381 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/LoggingConnection.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/LoggingConnection.kt @@ -37,7 +37,7 @@ private class LoggingConnection( } override suspend fun send(packet: ByteReadPacket) { - logger.debug { "Send: ${packet.dumpFrameToString()}" } + logger.debug { "Send: ${packet.dumpFrameToString()}" } delegate.send(packet) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt index 35ebfcd1e..54ee83496 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt @@ -20,30 +20,32 @@ import io.rsocket.kotlin.frame.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.selects.* +import kotlin.native.concurrent.* + +@SharedImmutable +private val selectFrame: suspend (Frame) -> Frame = { it } internal class Prioritizer { private val priorityChannel = SafeChannel(Channel.UNLIMITED) private val commonChannel = SafeChannel(Channel.UNLIMITED) - fun send(frame: Frame) { - commonChannel.safeOffer(frame) - } - - fun sendPrioritized(frame: Frame) { - priorityChannel.safeOffer(frame) + suspend fun send(frame: Frame) { + if (frame.type != FrameType.Cancel && frame.type != FrameType.Error) currentCoroutineContext().ensureActive() + val channel = if (frame.streamId == 0) priorityChannel else commonChannel + channel.send(frame) } suspend fun receive(): Frame { priorityChannel.tryReceive().onSuccess { return it } commonChannel.tryReceive().onSuccess { return it } return select { - priorityChannel.onReceive { it } - commonChannel.onReceive { it } + priorityChannel.onReceive(selectFrame) + commonChannel.onReceive(selectFrame) } } - fun cancel(error: CancellationException) { - priorityChannel.cancel(error) - commonChannel.cancel(error) + fun close(error: Throwable?) { + priorityChannel.fullClose(error) + commonChannel.fullClose(error) } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt index 929238925..cabcc9d5d 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt @@ -19,48 +19,122 @@ package io.rsocket.kotlin.internal import io.ktor.utils.io.core.* import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.flow.* +import io.rsocket.kotlin.internal.handler.* import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* +@OptIn(ExperimentalStreamsApi::class) internal class RSocketRequester( - private val state: RSocketState, - private val streamId: StreamId, + connectionJob: Job, + private val prioritizer: Prioritizer, + private val streamsStorage: StreamsStorage, + private val requestScope: CoroutineScope ) : RSocket { - override val job: Job get() = state.job + override val job: Job = connectionJob - override suspend fun metadataPush(metadata: ByteReadPacket): Unit = metadata.closeOnError { - job.ensureActive() - state.sendPrioritized(MetadataPushFrame(metadata)) + override suspend fun metadataPush(metadata: ByteReadPacket) { + ensureActiveOrRelease(metadata) + metadata.closeOnError { + prioritizer.send(MetadataPushFrame(metadata)) + } } - override suspend fun fireAndForget(payload: Payload): Unit = payload.closeOnError { - val streamId = createStream() - state.send(RequestFireAndForgetFrame(streamId, payload)) + override suspend fun fireAndForget(payload: Payload) { + ensureActiveOrRelease(payload) + + val streamId = streamsStorage.nextId() + try { + prioritizer.send(RequestFireAndForgetFrame(streamId, payload)) + } catch (cause: Throwable) { + payload.release() + if (job.isActive) prioritizer.send(CancelFrame(streamId)) //if cancelled during fragmentation + throw cause + } } - override suspend fun requestResponse(payload: Payload): Payload = with(state) { - payload.closeOnError { - val streamId = createStream() - val receiver = createReceiverFor(streamId) - send(RequestResponseFrame(streamId, payload)) - consumeReceiverFor(streamId) { - receiver.receive().payload //TODO fragmentation - } + override suspend fun requestResponse(payload: Payload): Payload { + ensureActiveOrRelease(payload) + + val streamId = streamsStorage.nextId() + + val deferred = CompletableDeferred() + val handler = RequesterRequestResponseFrameHandler(streamId, streamsStorage, deferred) + streamsStorage.save(streamId, handler) + + return handler.receiveOrCancel(streamId, payload) { + prioritizer.send(RequestResponseFrame(streamId, payload)) + deferred.await() } } - override fun requestStream(payload: Payload): Flow = RequestStreamRequesterFlow(payload, this, state) + override fun requestStream(payload: Payload): Flow = requestFlow { strategy, initialRequest -> + ensureActiveOrRelease(payload) - override fun requestChannel(initPayload: Payload, payloads: Flow): Flow = - RequestChannelRequesterFlow(initPayload, payloads, this, state) + val streamId = streamsStorage.nextId() - fun createStream(): Int { - job.ensureActive() - return nextStreamId() + val channel = SafeChannel(Channel.UNLIMITED) + val handler = RequesterRequestStreamFrameHandler(streamId, streamsStorage, channel) + streamsStorage.save(streamId, handler) + + handler.receiveOrCancel(streamId, payload) { + prioritizer.send(RequestStreamFrame(streamId, initialRequest, payload)) + emitAllWithRequestN(channel, strategy) { prioritizer.send(RequestNFrame(streamId, it)) } + } } - private fun nextStreamId(): Int = streamId.next(state.receivers) + override fun requestChannel(initPayload: Payload, payloads: Flow): Flow = requestFlow { strategy, initialRequest -> + ensureActiveOrRelease(initPayload) + + val streamId = streamsStorage.nextId() + + val channel = SafeChannel(Channel.UNLIMITED) + val limiter = Limiter(0) + val sender = Job(requestScope.coroutineContext.job) + val handler = RequesterRequestChannelFrameHandler(streamId, streamsStorage, limiter, sender, channel) + streamsStorage.save(streamId, handler) + handler.receiveOrCancel(streamId, initPayload) { + prioritizer.send(RequestChannelFrame(streamId, initialRequest, initPayload)) + //TODO lazy? + requestScope.launch(sender) { + handler.sendOrFail(streamId) { + payloads.collectLimiting(limiter) { prioritizer.send(NextPayloadFrame(streamId, it)) } + prioritizer.send(CompletePayloadFrame(streamId)) + } + } + emitAllWithRequestN(channel, strategy) { prioritizer.send(RequestNFrame(streamId, it)) } + } + } + + private suspend inline fun SendFrameHandler.sendOrFail(id: Int, block: () -> Unit) { + try { + block() + onSendComplete() + } catch (cause: Throwable) { + val isFailed = onSendFailed(cause) + if (job.isActive && isFailed) prioritizer.send(ErrorFrame(id, cause)) + throw cause + } + } + + private suspend inline fun ReceiveFrameHandler.receiveOrCancel(id: Int, payload: Payload, block: () -> T): T { + try { + val result = block() + onReceiveComplete() + return result + } catch (cause: Throwable) { + payload.release() + val isCancelled = onReceiveCancelled(cause) + if (job.isActive && isCancelled) prioritizer.send(CancelFrame(id)) + throw cause + } + } + + private fun ensureActiveOrRelease(closeable: Closeable) { + if (job.isActive) return + closeable.close() + job.ensureActive() + } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt index 6e92b40c4..f5e8a878c 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt @@ -16,82 +16,88 @@ package io.rsocket.kotlin.internal +import io.ktor.utils.io.core.* import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.flow.* +import io.rsocket.kotlin.internal.handler.* +import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* +@OptIn(ExperimentalStreamsApi::class) internal class RSocketResponder( - private val state: RSocketState, + private val prioritizer: Prioritizer, private val requestHandler: RSocket, + private val requestScope: CoroutineScope, ) { - fun handleMetadataPush(frame: MetadataPushFrame) { - state.launch { - requestHandler.metadataPush(frame.metadata) - }.invokeOnCompletion { - frame.release() + private fun Job.closeOnCompletion(closeable: Closeable): Job { + invokeOnCompletion { + closeable.close() } + return this } - fun handleFireAndForget(frame: RequestFrame) { - state.launch { - requestHandler.fireAndForget(frame.payload) - }.invokeOnCompletion { - frame.release() - } - } + fun handleMetadataPush(metadata: ByteReadPacket): Job = requestScope.launch { + requestHandler.metadataPush(metadata) + }.closeOnCompletion(metadata) - fun handlerRequestResponse(frame: RequestFrame): Unit = with(state) { - val streamId = frame.streamId - launchCancelable(streamId) { - val response = requestOrCancel(streamId) { - requestHandler.requestResponse(frame.payload) - } ?: return@launchCancelable - if (isActive) send(NextCompletePayloadFrame(streamId, response)) - }.invokeOnCompletion { - frame.release() + fun handleFireAndForget(payload: Payload, handler: ResponderFireAndForgetFrameHandler): Job = requestScope.launch { + try { + requestHandler.fireAndForget(payload) + } finally { + handler.onSendComplete() } - } + }.closeOnCompletion(payload) - fun handleRequestStream(initFrame: RequestFrame): Unit = with(state) { - val streamId = initFrame.streamId - launchCancelable(streamId) { - val response = requestOrCancel(streamId) { - requestHandler.requestStream(initFrame.payload) - } ?: return@launchCancelable - response.collectLimiting(streamId, initFrame.initialRequest) - }.invokeOnCompletion { - initFrame.release() + fun handleRequestResponse(payload: Payload, id: Int, handler: ResponderRequestResponseFrameHandler): Job = requestScope.launch { + handler.sendOrFail(id, payload) { + val response = requestHandler.requestResponse(payload) + prioritizer.send(NextCompletePayloadFrame(id, response)) } - } + }.closeOnCompletion(payload) - fun handleRequestChannel(initFrame: RequestFrame): Unit = with(state) { - val streamId = initFrame.streamId - val receiver = createReceiverFor(streamId) + fun handleRequestStream(payload: Payload, id: Int, handler: ResponderRequestStreamFrameHandler): Job = requestScope.launch { + handler.sendOrFail(id, payload) { + requestHandler.requestStream(payload).collectLimiting(handler.limiter) { prioritizer.send(NextPayloadFrame(id, it)) } + prioritizer.send(CompletePayloadFrame(id)) + } + }.closeOnCompletion(payload) - val request = RequestChannelResponderFlow(streamId, receiver, state) + fun handleRequestChannel(payload: Payload, id: Int, handler: ResponderRequestChannelFrameHandler): Job = requestScope.launch { + val payloads = requestFlow { strategy, initialRequest -> + handler.receiveOrCancel(id) { + prioritizer.send(RequestNFrame(id, initialRequest)) + emitAllWithRequestN(handler.channel, strategy) { prioritizer.send(RequestNFrame(id, it)) } + } + } + handler.sendOrFail(id, payload) { + requestHandler.requestChannel(payload, payloads).collectLimiting(handler.limiter) { prioritizer.send(NextPayloadFrame(id, it)) } + prioritizer.send(CompletePayloadFrame(id)) + } + }.closeOnCompletion(payload) - launchCancelable(streamId) { - val response = requestOrCancel(streamId) { - requestHandler.requestChannel(initFrame.payload, request) - } ?: return@launchCancelable - response.collectLimiting(streamId, initFrame.initialRequest) - }.invokeOnCompletion { - initFrame.release() - if (it != null) receiver.cancelConsumed(it) //TODO check it + private suspend inline fun SendFrameHandler.sendOrFail(id: Int, payload: Payload, block: () -> Unit) { + try { + block() + onSendComplete() + } catch (cause: Throwable) { + val isFailed = onSendFailed(cause) + if (currentCoroutineContext().isActive && isFailed) prioritizer.send(ErrorFrame(id, cause)) + throw cause + } finally { + payload.release() } } - private inline fun CoroutineScope.requestOrCancel(streamId: Int, block: () -> T): T? = + private suspend inline fun ReceiveFrameHandler.receiveOrCancel(id: Int, block: () -> Unit) { try { block() - } catch (e: Throwable) { - if (isActive) { - state.send(ErrorFrame(streamId, e)) - cancel("Request handling failed", e) //KLUDGE: can be related to IR, because using `throw` fails on JS IR and Native - } - null + onReceiveComplete() + } catch (cause: Throwable) { + val isCancelled = onReceiveCancelled(cause) + if (requestScope.isActive && isCancelled) prioritizer.send(CancelFrame(id)) + throw cause } + } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt deleted file mode 100644 index bf10a243d..000000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal - -import io.rsocket.kotlin.* -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.flow.* -import io.rsocket.kotlin.keepalive.* -import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* -import kotlinx.coroutines.flow.* - -@OptIn( - TransportApi::class, - ExperimentalStreamsApi::class -) -internal class RSocketState( - private val connection: Connection, - keepAlive: KeepAlive, -) { - val job get() = connection.job - private val requestJob = SupervisorJob(job) - private val requestScope = CoroutineScope(requestJob) - private val scope = CoroutineScope(job) - - private val prioritizer = Prioritizer() - val receivers: IntMap> = IntMap() - private val senders: IntMap = IntMap() - private val limits: IntMap = IntMap() - - private val keepAliveHandler = KeepAliveHandler(keepAlive, this::sendPrioritized) - - fun send(frame: Frame) { - prioritizer.send(frame) - } - - fun sendPrioritized(frame: Frame) { - prioritizer.sendPrioritized(frame) - } - - fun createReceiverFor(streamId: Int): ReceiveChannel { - val receiver = SafeChannel(Channel.UNLIMITED) - receivers[streamId] = receiver - return receiver - } - - inline fun consumeReceiverFor(streamId: Int, block: () -> R): R { - var cause: Throwable? = null - try { - return block() - } catch (e: Throwable) { - cause = e - throw e - } finally { - if (job.isActive && streamId in receivers) { - if (cause != null) send(CancelFrame(streamId)) - receivers.remove(streamId)?.cancelConsumed(cause) - } - } - } - - suspend fun collectStream( - streamId: Int, - receiver: ReceiveChannel, - strategy: RequestStrategy.Element, - collector: FlowCollector, - ): Unit = consumeReceiverFor(streamId) { - //TODO fragmentation - for (frame in receiver) { - if (frame.complete) return //TODO check next flag - collector.emitOrClose(frame.payload) - val next = strategy.nextRequest() - if (next > 0) send(RequestNFrame(streamId, next)) - } - } - - suspend inline fun Flow.collectLimiting( - streamId: Int, - initialRequest: Int, - crossinline onStart: () -> Unit = {}, - ): Unit = coroutineScope { - val limitingCollector = LimitingFlowCollector(this@RSocketState, streamId, initialRequest) - limits[streamId] = limitingCollector - try { - onStart() - limitingCollector.emitAll(this@collectLimiting) - send(CompletePayloadFrame(streamId)) - } catch (e: Throwable) { - limits.remove(streamId) - //if isn't active, then, that stream was cancelled, and so no need for error frame - if (isActive) send(ErrorFrame(streamId, e)) - cancel("Collect failed", e) //KLUDGE: can be related to IR, because using `throw` fails on JS IR and Native - } - } - - fun launch(block: suspend CoroutineScope.() -> Unit): Job = requestScope.launch(block = block) - - fun launchCancelable(streamId: Int, block: suspend CoroutineScope.() -> Unit): Job { - val job = launch(block) - job.invokeOnCompletion { if (job.isActive) senders.remove(streamId) } - senders[streamId] = job - return job - } - - private fun handleFrame(responder: RSocketResponder, frame: Frame) { - when (val streamId = frame.streamId) { - 0 -> when (frame) { - is ErrorFrame -> job.cancel("Error frame received on 0 stream", frame.throwable) - is KeepAliveFrame -> keepAliveHandler.receive(frame) - is LeaseFrame -> { - frame.release() - error("lease isn't implemented") - } - - is MetadataPushFrame -> responder.handleMetadataPush(frame) - else -> { - //TODO log - frame.release() - } - } - else -> when (frame) { - is RequestNFrame -> limits[streamId]?.updateRequests(frame.requestN) - is CancelFrame -> senders.remove(streamId)?.cancel() - is ErrorFrame -> receivers.remove(streamId)?.close(frame.throwable) - is RequestFrame -> when (frame.type) { - FrameType.Payload -> receivers[streamId]?.safeOffer(frame) ?: frame.release() - FrameType.RequestFnF -> responder.handleFireAndForget(frame) - FrameType.RequestResponse -> responder.handlerRequestResponse(frame) - FrameType.RequestStream -> responder.handleRequestStream(frame) - FrameType.RequestChannel -> responder.handleRequestChannel(frame) - else -> error("never happens") - } - else -> { - //TODO log - frame.release() - } - } - } - } - - fun start(requestHandler: RSocket) { - val responder = RSocketResponder(this, requestHandler) - keepAliveHandler.startIn(scope) - requestHandler.job.invokeOnCompletion { - // if request handler is completed successfully, via Job.complete() - // we don't need to cancel connection - if (it != null) job.cancel("Request handler failed", it) - } - - requestJob.invokeOnCompletion { error -> - val cancelError = CancellationException("Connection closed", error) - requestHandler.job.cancel(cancelError) - receivers.values().forEach { - it.cancel(cancelError) - } - senders.values().forEach { it.cancel(cancelError) } - receivers.clear() - limits.clear() - senders.clear() - prioritizer.cancel(cancelError) - } - scope.launch { - while (job.isActive) { - val frame = prioritizer.receive() - connection.sendFrame(frame) - } - } - scope.launch { - while (job.isActive) { - val frame = connection.receiveFrame() - frame.closeOnError { handleFrame(responder, frame) } - } - } - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RequestFlow.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RequestFlow.kt new file mode 100644 index 000000000..7cc3463e1 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RequestFlow.kt @@ -0,0 +1,78 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal + +import io.rsocket.kotlin.* +import io.rsocket.kotlin.payload.* +import kotlinx.atomicfu.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import kotlinx.coroutines.flow.* + +@ExperimentalStreamsApi +internal inline fun requestFlow( + crossinline block: suspend FlowCollector.(strategy: RequestStrategy.Element, initialRequest: Int) -> Unit +): Flow = object : RequestFlow() { + override suspend fun collect(collector: FlowCollector, strategy: RequestStrategy.Element, initialRequest: Int) { + collector.block(strategy, initialRequest) + } +} + +@ExperimentalStreamsApi +internal suspend inline fun FlowCollector.emitAllWithRequestN( + channel: ReceiveChannel, + strategy: RequestStrategy.Element, + crossinline onRequest: suspend (n: Int) -> Unit, +) { + val collector = object : RequestFlowCollector(this, strategy) { + override suspend fun onRequest(n: Int) { + @OptIn(ExperimentalCoroutinesApi::class) + if (!channel.isClosedForReceive) onRequest(n) + } + } + collector.emitAll(channel) +} + +@ExperimentalStreamsApi +internal abstract class RequestFlow : Flow { + private val consumed = atomic(false) + + @InternalCoroutinesApi + override suspend fun collect(collector: FlowCollector) { + check(!consumed.getAndSet(true)) { "RequestFlow can be collected just once" } + + val strategy = currentCoroutineContext().requestStrategy() + val initial = strategy.firstRequest() + collect(collector, strategy, initial) + } + + abstract suspend fun collect(collector: FlowCollector, strategy: RequestStrategy.Element, initialRequest: Int) +} + +@ExperimentalStreamsApi +internal abstract class RequestFlowCollector( + private val collector: FlowCollector, + private val strategy: RequestStrategy.Element, +) : FlowCollector { + override suspend fun emit(value: Payload): Unit = value.closeOnError { + collector.emit(value) + val next = strategy.nextRequest() + if (next > 0) onRequest(next) + } + + abstract suspend fun onRequest(n: Int) +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/StreamsStorage.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/StreamsStorage.kt new file mode 100644 index 000000000..206bf0b9e --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/StreamsStorage.kt @@ -0,0 +1,73 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal + +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.handler.* + +internal class StreamsStorage(private val isServer: Boolean) { + private val streamId: StreamId = StreamId(isServer) + private val handlers: IntMap = IntMap() + + fun nextId(): Int = streamId.next(handlers) + + fun save(id: Int, handler: FrameHandler) { + handlers[id] = handler + } + + fun remove(id: Int): FrameHandler? { + return handlers.remove(id) + } + + fun contains(id: Int): Boolean { + return id in handlers + } + + fun cleanup(error: Throwable?) { + val values = handlers.values() + handlers.clear() + values.forEach { + it.cleanup(error) + } + } + + fun handleFrame(frame: Frame, responder: RSocketResponder) { + val id = frame.streamId + when (frame) { + is RequestNFrame -> handlers[id]?.handleRequestN(frame.requestN) + is CancelFrame -> handlers[id]?.handleCancel() + is ErrorFrame -> handlers[id]?.handleError(frame.throwable) + is RequestFrame -> when { + frame.type == FrameType.Payload -> handlers[id]?.handleRequest(frame) ?: frame.release() // release on unknown stream id + isServer.xor(id % 2 != 0) -> frame.release() // request frame on wrong stream id + else -> { + val initialRequest = frame.initialRequest + val handler = when (frame.type) { + FrameType.RequestFnF -> ResponderFireAndForgetFrameHandler(id, this, responder) + FrameType.RequestResponse -> ResponderRequestResponseFrameHandler(id, this, responder) + FrameType.RequestStream -> ResponderRequestStreamFrameHandler(id, this, responder, initialRequest) + FrameType.RequestChannel -> ResponderRequestChannelFrameHandler(id, this, responder, initialRequest) + else -> error("Wrong request frame type") // should never happen + } + handlers[id] = handler + handler.handleRequest(frame) + } + } + else -> frame.release() + } + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlow.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlow.kt deleted file mode 100644 index 059be42cc..000000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlow.kt +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.flow - -import io.rsocket.kotlin.* -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.payload.* -import kotlinx.atomicfu.* -import kotlinx.coroutines.* -import kotlinx.coroutines.flow.* - -@OptIn(ExperimentalStreamsApi::class) -internal class RequestChannelRequesterFlow( - private val initPayload: Payload, - private val payloads: Flow, - private val requester: RSocketRequester, - private val state: RSocketState, -) : Flow { - private val consumed = atomic(false) - - @InternalCoroutinesApi - override suspend fun collect(collector: FlowCollector): Unit = with(state) { - check(!consumed.getAndSet(true)) { "RSocket.requestChannel can be collected just once" } - - val strategy = currentCoroutineContext().requestStrategy() - val initialRequest = strategy.firstRequest() - initPayload.closeOnError { - val streamId = requester.createStream() - val receiver = createReceiverFor(streamId) - val request = launchCancelable(streamId) { - payloads.collectLimiting(streamId, 0) { - send(RequestChannelFrame(streamId, initialRequest, initPayload)) - } - } - - request.invokeOnCompletion { - if (it != null && it !is CancellationException) receiver.cancelConsumed(it) - } - try { - collectStream(streamId, receiver, strategy, collector) - } catch (e: Throwable) { - if (e is CancellationException) request.cancel(e) - else request.cancel("Receiver failed", e) - throw e - } - } - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelResponderFlow.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelResponderFlow.kt deleted file mode 100644 index 71c54a9e2..000000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelResponderFlow.kt +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.flow - -import io.rsocket.kotlin.* -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.payload.* -import kotlinx.atomicfu.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* -import kotlinx.coroutines.flow.* - -@OptIn(ExperimentalStreamsApi::class) -internal class RequestChannelResponderFlow( - private val streamId: Int, - private val receiver: ReceiveChannel, - private val state: RSocketState, -) : Flow { - private val consumed = atomic(false) - - @InternalCoroutinesApi - override suspend fun collect(collector: FlowCollector): Unit = with(state) { - check(!consumed.getAndSet(true)) { "RSocket.requestChannel `payloads` can be collected just once" } - - val strategy = currentCoroutineContext().requestStrategy() - val initialRequest = strategy.firstRequest() - send(RequestNFrame(streamId, initialRequest)) - collectStream(streamId, receiver, strategy, collector) - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestStreamRequesterFlow.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestStreamRequesterFlow.kt deleted file mode 100644 index 278df0cfe..000000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestStreamRequesterFlow.kt +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.flow - -import io.rsocket.kotlin.* -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.payload.* -import kotlinx.atomicfu.* -import kotlinx.coroutines.* -import kotlinx.coroutines.flow.* - -@OptIn(ExperimentalStreamsApi::class) -internal class RequestStreamRequesterFlow( - private val payload: Payload, - private val requester: RSocketRequester, - private val state: RSocketState, -) : Flow { - private val consumed = atomic(false) - - @InternalCoroutinesApi - override suspend fun collect(collector: FlowCollector): Unit = with(state) { - check(!consumed.getAndSet(true)) { "RSocket.requestStream can be collected just once" } - - val strategy = currentCoroutineContext().requestStrategy() - val initialRequest = strategy.firstRequest() - payload.closeOnError { - val streamId = requester.createStream() - val receiver = createReceiverFor(streamId) - send(RequestStreamFrame(streamId, initialRequest, payload)) - collectStream(streamId, receiver, strategy, collector) - } - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt new file mode 100644 index 000000000..1dcb1402e --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt @@ -0,0 +1,94 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal.handler + +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* + +internal abstract class FrameHandler { + + fun handleRequest(frame: RequestFrame) { + if (frame.next || frame.type.isRequestType) handleNextFragment(frame) + if (frame.complete) handleComplete() + } + + private fun handleNextFragment(frame: RequestFrame) { + //TODO fragmentation will be here + handleNext(frame.payload) + } + + protected abstract fun handleNext(payload: Payload) + protected abstract fun handleComplete() + abstract fun handleError(cause: Throwable) + abstract fun handleCancel() + abstract fun handleRequestN(n: Int) + + abstract fun cleanup(cause: Throwable?) +} + +internal interface ReceiveFrameHandler { + fun onReceiveComplete() + fun onReceiveCancelled(cause: Throwable): Boolean // if true, then request is cancelled +} + +internal interface SendFrameHandler { + fun onSendComplete() + fun onSendFailed(cause: Throwable): Boolean // if true, then request is failed +} + +internal abstract class BaseRequesterFrameHandler : FrameHandler(), ReceiveFrameHandler { + override fun handleCancel() { + //should be called only for RC + } + + override fun handleRequestN(n: Int) { + //should be called only for RC + } +} + +internal abstract class BaseResponderFrameHandler : FrameHandler(), SendFrameHandler { + protected abstract var job: Job? + + protected abstract fun start(payload: Payload): Job + + final override fun handleNext(payload: Payload) { + if (job == null) job = start(payload) + else handleNextPayload(payload) + } + + protected open fun handleNextPayload(payload: Payload) { + //should be called only for RC + } + + override fun handleComplete() { + //should be called only for RC + } + + override fun handleError(cause: Throwable) { + //should be called only for RC + } +} + +internal expect abstract class ResponderFrameHandler() : BaseResponderFrameHandler { + override var job: Job? + //TODO fragmentation will be here +} + +internal expect abstract class RequesterFrameHandler() : BaseRequesterFrameHandler { + //TODO fragmentation will be here +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestChannelFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestChannelFrameHandler.kt new file mode 100644 index 000000000..154bf5936 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestChannelFrameHandler.kt @@ -0,0 +1,81 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal.handler + +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* + +internal class RequesterRequestChannelFrameHandler( + private val id: Int, + private val streamsStorage: StreamsStorage, + private val limiter: Limiter, + private val sender: Job, + private val channel: Channel, +) : RequesterFrameHandler(), SendFrameHandler { + + override fun handleNext(payload: Payload) { + channel.safeTrySend(payload) + } + + override fun handleComplete() { + channel.close() + } + + override fun handleError(cause: Throwable) { + streamsStorage.remove(id) + channel.fullClose(cause) + sender.cancel("Request failed", cause) + } + + override fun handleCancel() { + sender.cancel("Request cancelled") + } + + override fun handleRequestN(n: Int) { + limiter.updateRequests(n) + } + + override fun cleanup(cause: Throwable?) { + channel.fullClose(cause) + sender.cancel("Connection closed", cause) + } + + override fun onReceiveComplete() { + if (!sender.isActive) streamsStorage.remove(id) + } + + override fun onReceiveCancelled(cause: Throwable): Boolean { + val isCancelled = streamsStorage.remove(id) != null + if (isCancelled) sender.cancel("Request cancelled", cause) + return isCancelled + } + + @OptIn(ExperimentalCoroutinesApi::class) + override fun onSendComplete() { + if (channel.isClosedForSend) streamsStorage.remove(id) + } + + override fun onSendFailed(cause: Throwable): Boolean { + if (sender.isCancelled) return false + + val isFailed = streamsStorage.remove(id) != null + if (isFailed) channel.fullClose(cause) + return isFailed + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestResponseFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestResponseFrameHandler.kt new file mode 100644 index 000000000..b7838c1b1 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestResponseFrameHandler.kt @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal.handler + +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* + +internal class RequesterRequestResponseFrameHandler( + private val id: Int, + private val streamsStorage: StreamsStorage, + private val deferred: CompletableDeferred +) : RequesterFrameHandler() { + override fun handleNext(payload: Payload) { + deferred.complete(payload) + } + + override fun handleComplete() { + //ignore + } + + override fun handleError(cause: Throwable) { + streamsStorage.remove(id) + deferred.completeExceptionally(cause) + } + + override fun cleanup(cause: Throwable?) { + deferred.cancel("Connection closed", cause) + } + + override fun onReceiveComplete() { + streamsStorage.remove(id) + } + + override fun onReceiveCancelled(cause: Throwable): Boolean = streamsStorage.remove(id) != null +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestStreamFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestStreamFrameHandler.kt new file mode 100644 index 000000000..eba58a903 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestStreamFrameHandler.kt @@ -0,0 +1,51 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal.handler + +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.channels.* + +internal class RequesterRequestStreamFrameHandler( + private val id: Int, + private val streamsStorage: StreamsStorage, + private val channel: Channel +) : RequesterFrameHandler() { + + override fun handleNext(payload: Payload) { + channel.safeTrySend(payload) + } + + override fun handleComplete() { + channel.close() + } + + override fun handleError(cause: Throwable) { + streamsStorage.remove(id) + channel.fullClose(cause) + } + + override fun cleanup(cause: Throwable?) { + channel.fullClose(cause) + } + + override fun onReceiveComplete() { + streamsStorage.remove(id) + } + + override fun onReceiveCancelled(cause: Throwable): Boolean = streamsStorage.remove(id) != null +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderFireAndForgetFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderFireAndForgetFrameHandler.kt new file mode 100644 index 000000000..ae049e9ba --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderFireAndForgetFrameHandler.kt @@ -0,0 +1,49 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal.handler + +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* + +internal class ResponderFireAndForgetFrameHandler( + private val id: Int, + private val streamsStorage: StreamsStorage, + private val responder: RSocketResponder, +) : ResponderFrameHandler() { + + override fun start(payload: Payload): Job = responder.handleFireAndForget(payload, this) + + override fun handleCancel() { + streamsStorage.remove(id) + job?.cancel("Request cancelled") + } + + override fun handleRequestN(n: Int) { + //ignore + } + + override fun cleanup(cause: Throwable?) { + //ignore + } + + override fun onSendComplete() { + streamsStorage.remove(id) + } + + override fun onSendFailed(cause: Throwable): Boolean = false +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestChannelFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestChannelFrameHandler.kt new file mode 100644 index 000000000..0be3d7f15 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestChannelFrameHandler.kt @@ -0,0 +1,86 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal.handler + +import io.rsocket.kotlin.* +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* + +internal class ResponderRequestChannelFrameHandler( + private val id: Int, + private val streamsStorage: StreamsStorage, + private val responder: RSocketResponder, + initialRequest: Int +) : ResponderFrameHandler(), ReceiveFrameHandler { + val limiter = Limiter(initialRequest) + val channel = SafeChannel(Channel.UNLIMITED) + + @OptIn(ExperimentalStreamsApi::class) + override fun start(payload: Payload): Job = responder.handleRequestChannel(payload, id, this) + + override fun handleNextPayload(payload: Payload) { + channel.safeTrySend(payload) + } + + override fun handleComplete() { + channel.close() + } + + override fun handleError(cause: Throwable) { + streamsStorage.remove(id) + channel.fullClose(cause) + } + + override fun handleCancel() { + streamsStorage.remove(id) + val cancelError = CancellationException("Request cancelled") + channel.fullClose(cancelError) + job?.cancel(cancelError) + } + + override fun handleRequestN(n: Int) { + limiter.updateRequests(n) + } + + override fun cleanup(cause: Throwable?) { + channel.fullClose(cause) + } + + override fun onSendComplete() { + @OptIn(ExperimentalCoroutinesApi::class) + if (channel.isClosedForSend) streamsStorage.remove(id) + } + + override fun onSendFailed(cause: Throwable): Boolean { + val isFailed = streamsStorage.remove(id) != null + if (isFailed) channel.fullClose(cause) + return isFailed + } + + override fun onReceiveComplete() { + val job = this.job!! //always not null here + if (!job.isActive) streamsStorage.remove(id) + } + + override fun onReceiveCancelled(cause: Throwable): Boolean { + val job = this.job!! //always not null here + if (!streamsStorage.contains(id) && job.isActive) job.cancel("Request handling failed [Error frame]", cause) + return !job.isCancelled + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestResponseFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestResponseFrameHandler.kt new file mode 100644 index 000000000..77f2411b8 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestResponseFrameHandler.kt @@ -0,0 +1,49 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal.handler + +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* + +internal class ResponderRequestResponseFrameHandler( + private val id: Int, + private val streamsStorage: StreamsStorage, + private val responder: RSocketResponder +) : ResponderFrameHandler() { + + override fun start(payload: Payload): Job = responder.handleRequestResponse(payload, id, this) + + override fun handleCancel() { + streamsStorage.remove(id) + job?.cancel("Request cancelled") + } + + override fun handleRequestN(n: Int) { + //ignore + } + + override fun cleanup(cause: Throwable?) { + //ignore + } + + override fun onSendComplete() { + streamsStorage.remove(id) + } + + override fun onSendFailed(cause: Throwable): Boolean = streamsStorage.remove(id) != null +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestStreamFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestStreamFrameHandler.kt new file mode 100644 index 000000000..41b3ccfc6 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestStreamFrameHandler.kt @@ -0,0 +1,51 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal.handler + +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* + +internal class ResponderRequestStreamFrameHandler( + private val id: Int, + private val streamsStorage: StreamsStorage, + private val responder: RSocketResponder, + initialRequest: Int, +) : ResponderFrameHandler() { + val limiter = Limiter(initialRequest) + + override fun start(payload: Payload): Job = responder.handleRequestStream(payload, id, this) + + override fun handleCancel() { + streamsStorage.remove(id) + job?.cancel("Request cancelled") + } + + override fun handleRequestN(n: Int) { + limiter.updateRequests(n) + } + + override fun cleanup(cause: Throwable?) { + //ignore + } + + override fun onSendComplete() { + streamsStorage.remove(id) + } + + override fun onSendFailed(cause: Throwable): Boolean = streamsStorage.remove(id) != null +} diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/SetupRejectionTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt similarity index 62% rename from rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/SetupRejectionTest.kt rename to rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt index 75d560ec0..b46f7e5e3 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/SetupRejectionTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt @@ -16,6 +16,7 @@ package io.rsocket.kotlin +import io.ktor.utils.io.core.* import io.rsocket.kotlin.core.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.frame.io.* @@ -26,7 +27,7 @@ import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* import kotlin.test.* -class SetupRejectionTest : SuspendTest, TestWithLeakCheck { +class ConnectionEstablishmentTest : SuspendTest, TestWithLeakCheck { @Test fun responderRejectSetup() = test { val errorMessage = "error" @@ -43,7 +44,16 @@ class SetupRejectionTest : SuspendTest, TestWithLeakCheck { error(errorMessage) } - connection.sendToReceiver(SetupFrame(Version.Current, false, DefaultKeepAlive, null, DefaultPayloadMimeType, Payload.Empty)) + connection.sendToReceiver( + SetupFrame( + version = Version.Current, + honorLease = false, + keepAlive = DefaultKeepAlive, + resumeToken = null, + payloadMimeType = DefaultPayloadMimeType, + payload = payload("setup") //should be released + ) + ) assertFailsWith(RSocketError.Setup.Rejected::class, errorMessage) { deferred.await() } @@ -62,23 +72,24 @@ class SetupRejectionTest : SuspendTest, TestWithLeakCheck { assertEquals(errorMessage, error.message) } -// @Test -// fun requesterStreamsTerminatedOnZeroErrorFrame() = test { -// val errorMessage = "error" -// val connection = TestConnection() -// val requester = RSocketRequester(connection, StreamId.client(), KeepAlive(), RequestStrategy.Default, {}) -// val deferred = GlobalScope.async { requester.requestResponse(Payload.Empty) } -// delay(100) -// connection.sendToReceiver(ErrorFrame(0, RSocketError.ConnectionError(errorMessage))) -// assertFailsWith(errorMessage) { deferred.await() } -// } -// -// @Test -// fun requesterNewStreamsTerminatedAfterZeroErrorFrame() = test { -// val errorMessage = "error" -// val connection = TestConnection() -// val requester = RSocketRequester(connection, StreamId.client(), KeepAlive(), RequestStrategy.Default, {}) -// connection.sendToReceiver(ErrorFrame(0, RSocketError.ConnectionError(errorMessage))) -// assertFailsWith(errorMessage) { requester.requestResponse(Payload.Empty) } -// } + @Test + fun requesterReleaseSetupPayloadOnFailedAcceptor() = test { + val connection = TestConnection() + val p = payload("setup") + assertFailsWith(IllegalStateException::class, "failed") { + RSocketConnector { + connectionConfig { + setupPayload { p } + } + acceptor { + assertTrue(config.setupPayload.data.isNotEmpty) + assertTrue(p.data.isNotEmpty) + error("failed") + } + }.connect { connection } + } + println(p.data) + assertTrue(p.data.isEmpty) + } + } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt index 28691951a..0abe3f450 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt @@ -42,7 +42,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { private suspend fun start(handler: RSocket? = null): RSocket { val localServer = LocalServer(testJob) RSocketServer { - loggerFactory = NoopLogger + loggerFactory = LoggerFactory { PrintLogger.withLevel(LoggingLevel.DEBUG).logger("SERVER |$it") } }.bind(localServer) { handler ?: RSocketRequestHandler { requestResponse { it } @@ -59,7 +59,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } return RSocketConnector { - loggerFactory = NoopLogger + loggerFactory = LoggerFactory { PrintLogger.withLevel(LoggingLevel.DEBUG).logger("CLIENT |$it") } connectionConfig { keepAlive = KeepAlive(Duration.seconds(1000), Duration.seconds(1000)) } @@ -219,7 +219,10 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } }) val request = flow { error("test") } - requester.requestChannel(Payload.Empty, request).collect() + //TODO + kotlin.runCatching { + requester.requestChannel(Payload.Empty, request).collect() + }.also(::println) val e = error.await() assertTrue(e is RSocketError.ApplicationError) assertEquals("test", e.message) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt index 34dccb712..cdfbc6c8e 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt @@ -17,6 +17,7 @@ package io.rsocket.kotlin.internal import io.rsocket.kotlin.* +import io.rsocket.kotlin.core.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.keepalive.* import io.rsocket.kotlin.payload.* @@ -28,14 +29,21 @@ import kotlin.test.* import kotlin.time.* class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { - private lateinit var requester: RSocketRequester + private lateinit var requester: RSocket override suspend fun before() { super.before() - val state = RSocketState(connection, KeepAlive(Duration.seconds(1000), Duration.seconds(1000))) - requester = RSocketRequester(state, StreamId.client()) - state.start(RSocketRequestHandler { }) + requester = connection.connect( + isServer = false, + interceptors = InterceptorsBuilder().build(), + connectionConfig = ConnectionConfig( + keepAlive = KeepAlive(Duration.seconds(1000), Duration.seconds(1000)), + payloadMimeType = DefaultPayloadMimeType, + setupPayload = Payload.Empty + ), + acceptor = { RSocketRequestHandler { } } + ) } @Test @@ -127,7 +135,7 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { @Test fun testStreamRequestByFixed() = test { connection.test { - val flow = requester.requestStream(Payload.Empty).flowOn(PrefetchStrategy(2, 0)).take(4) + val flow = requester.requestStream(Payload.Empty).flowOn(PrefetchStrategy(3, 0)) expectNoEventsIn(200) flow.launchIn(connection) @@ -135,7 +143,7 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { expectFrame { frame -> assertTrue(frame is RequestFrame) assertEquals(FrameType.RequestStream, frame.type) - assertEquals(2, frame.initialRequest) + assertEquals(3, frame.initialRequest) } expectNoEventsIn(200) @@ -144,9 +152,12 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { expectNoEventsIn(200) connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) + expectNoEventsIn(200) + connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) + expectFrame { frame -> assertTrue(frame is RequestNFrame) - assertEquals(2, frame.requestN) + assertEquals(3, frame.requestN) } expectNoEventsIn(200) @@ -162,7 +173,7 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { @Test fun testStreamRequestBy() = test { connection.test { - val flow = requester.requestStream(Payload.Empty).flowOn(PrefetchStrategy(5, 2)).take(6) + val flow = requester.requestStream(Payload.Empty).flowOn(PrefetchStrategy(5, 2)) expectNoEventsIn(200) flow.launchIn(connection) @@ -200,6 +211,47 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { } } + @Test + fun testStreamRequestCancel() = test { + connection.test { + val flow = requester.requestStream(Payload.Empty).flowOn(PrefetchStrategy(1, 0)).take(3) + + expectNoEventsIn(200) + flow.launchIn(connection) + + expectFrame { frame -> + assertTrue(frame is RequestFrame) + assertEquals(FrameType.RequestStream, frame.type) + assertEquals(1, frame.initialRequest) + } + + expectNoEventsIn(200) + connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) + + expectFrame { frame -> + assertTrue(frame is RequestNFrame) + assertEquals(1, frame.requestN) + } + + expectNoEventsIn(200) + connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) + + expectFrame { frame -> + assertTrue(frame is RequestNFrame) + assertEquals(1, frame.requestN) + } + + expectNoEventsIn(200) + connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) + + expectFrame { frame -> + assertTrue(frame is CancelFrame) + } + + expectNoEventsIn(200) + } + } + @Test fun testHandleSetupException() = test { val errorMessage = "error" diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt index 6fabeb30b..d021bf32b 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt @@ -18,8 +18,10 @@ package io.rsocket.kotlin.keepalive import io.ktor.utils.io.core.* import io.rsocket.kotlin.* +import io.rsocket.kotlin.core.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.test.* import kotlinx.coroutines.* import kotlin.test.* @@ -27,12 +29,14 @@ import kotlin.time.* class KeepAliveTest : TestWithConnection(), TestWithLeakCheck { - private fun requester(keepAlive: KeepAlive = KeepAlive(Duration.milliseconds(100), Duration.seconds(1))): RSocket = run { - val state = RSocketState(connection, keepAlive) - val requester = RSocketRequester(state, StreamId.client()) - state.start(RSocketRequestHandler { }) - requester - } + private suspend fun requester( + keepAlive: KeepAlive = KeepAlive(Duration.milliseconds(100), Duration.seconds(1)) + ): RSocket = connection.connect( + isServer = false, + interceptors = InterceptorsBuilder().build(), + connectionConfig = ConnectionConfig(keepAlive, DefaultPayloadMimeType, Payload.Empty), + acceptor = { RSocketRequestHandler { } } + ) @Test fun requesterSendKeepAlive() = test { diff --git a/rsocket-core/src/jsMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt b/rsocket-core/src/jsMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt new file mode 100644 index 000000000..917419f0c --- /dev/null +++ b/rsocket-core/src/jsMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt @@ -0,0 +1,26 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal.handler + +import kotlinx.coroutines.* + +internal actual abstract class ResponderFrameHandler : BaseResponderFrameHandler() { + actual override var job: Job? = null +} + +internal actual abstract class RequesterFrameHandler : BaseRequesterFrameHandler() { +} diff --git a/rsocket-core/src/jvmMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt b/rsocket-core/src/jvmMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt new file mode 100644 index 000000000..917419f0c --- /dev/null +++ b/rsocket-core/src/jvmMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt @@ -0,0 +1,26 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal.handler + +import kotlinx.coroutines.* + +internal actual abstract class ResponderFrameHandler : BaseResponderFrameHandler() { + actual override var job: Job? = null +} + +internal actual abstract class RequesterFrameHandler : BaseRequesterFrameHandler() { +} diff --git a/rsocket-core/src/nativeMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt b/rsocket-core/src/nativeMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt new file mode 100644 index 000000000..51db7aebc --- /dev/null +++ b/rsocket-core/src/nativeMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt @@ -0,0 +1,27 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal.handler + +import kotlinx.atomicfu.* +import kotlinx.coroutines.* + +internal actual abstract class ResponderFrameHandler : BaseResponderFrameHandler() { + actual override var job: Job? by atomic(null) +} + +internal actual abstract class RequesterFrameHandler : BaseRequesterFrameHandler() { +} diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt index 9a7606ca3..bc3cae5af 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt @@ -106,7 +106,7 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { val request = flow { repeat(200_000) { emit(payload(it)) } } - val list = client.requestChannel(payload(0), request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).onEach { it.release() }.toList() + val list = client.requestChannel(payload(0), request).flowOn(PrefetchStrategy(10000, 0)).onEach { it.release() }.toList() assertEquals(200_000, list.size) }