Skip to content

Commit e25de1c

Browse files
committed
Adapt ktor websocket transport implementation and ktor plugins to the latest change in transport API and simplify some parts
1 parent 8d48118 commit e25de1c

File tree

6 files changed

+81
-60
lines changed

6 files changed

+81
-60
lines changed

ktor-plugins/ktor-client-rsocket/src/commonMain/kotlin/io/rsocket/kotlin/ktor/client/RSocketSupport.kt

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2015-2024 the original author or authors.
2+
* Copyright 2015-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -26,7 +26,6 @@ import io.rsocket.kotlin.*
2626
import io.rsocket.kotlin.core.*
2727
import io.rsocket.kotlin.transport.*
2828
import io.rsocket.kotlin.transport.ktor.websocket.internal.*
29-
import kotlinx.coroutines.*
3029
import kotlin.coroutines.*
3130

3231
private val RSocketSupportConfigKey = AttributeKey<RSocketSupportConfig.Internal>("RSocketSupportConfig")
@@ -66,9 +65,7 @@ private class RSocketSupportTarget(
6665
override val coroutineContext: CoroutineContext get() = client.coroutineContext
6766

6867
@RSocketTransportApi
69-
override fun connectClient(handler: RSocketConnectionHandler): Job = launch {
70-
client.webSocket(request) {
71-
handler.handleKtorWebSocketConnection(this)
72-
}
68+
override suspend fun connectClient(): RSocketConnection {
69+
return KtorWebSocketConnection(client.webSocketSession(request))
7370
}
7471
}

ktor-plugins/ktor-server-rsocket/src/commonMain/kotlin/io/rsocket/kotlin/ktor/server/RSocketSupport.kt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2015-2024 the original author or authors.
2+
* Copyright 2015-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import io.rsocket.kotlin.*
2424
import io.rsocket.kotlin.core.*
2525
import io.rsocket.kotlin.transport.*
2626
import io.rsocket.kotlin.transport.ktor.websocket.internal.*
27+
import kotlinx.coroutines.*
2728

2829
private val RSocketSupportConfigKey = AttributeKey<RSocketSupportConfig.Internal>("RSocketSupportConfig")
2930

@@ -54,8 +55,8 @@ internal fun Route.rSocketHandler(acceptor: ConnectionAcceptor): suspend Default
5455
val config = application.attributes.getOrNull(RSocketSupportConfigKey)
5556
?: error("Plugin RSocketSupport is not installed. Consider using `install(RSocketSupport)` in server config first.")
5657

57-
val handler = config.server.createHandler(acceptor)
5858
return {
59-
handler.handleKtorWebSocketConnection(this)
59+
config.server.acceptConnection(acceptor, KtorWebSocketConnection(this))
60+
awaitCancellation()
6061
}
6162
}

rsocket-transports/ktor-websocket-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport.kt

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2015-2024 the original author or authors.
2+
* Copyright 2015-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -79,15 +79,15 @@ private class KtorWebSocketClientTransportBuilderImpl : KtorWebSocketClientTrans
7979
install(WebSockets, webSocketsConfig)
8080
}
8181
// only dispatcher of a client is used - it looks like it's Dispatchers.IO now
82-
val newContext = context.supervisorContext() + (httpClient.coroutineContext[ContinuationInterceptor] ?: EmptyCoroutineContext)
83-
val newJob = newContext.job
82+
val transportContext = context.supervisorContext() + Dispatchers.Default
83+
val transportJob = transportContext.job
8484
val httpClientJob = httpClient.coroutineContext.job
8585

86-
httpClientJob.invokeOnCompletion { newJob.cancel("HttpClient closed", it) }
87-
newJob.invokeOnCompletion { httpClientJob.cancel("KtorWebSocketClientTransport closed", it) }
86+
httpClientJob.invokeOnCompletion { transportJob.cancel("HttpClient closed", it) }
87+
transportJob.invokeOnCompletion { httpClientJob.cancel("KtorWebSocketClientTransport closed", it) }
8888

8989
return KtorWebSocketClientTransportImpl(
90-
coroutineContext = newContext,
90+
coroutineContext = transportContext,
9191
httpClient = httpClient,
9292
)
9393
}
@@ -98,7 +98,7 @@ private class KtorWebSocketClientTransportImpl(
9898
private val httpClient: HttpClient,
9999
) : KtorWebSocketClientTransport {
100100
override fun target(request: HttpRequestBuilder.() -> Unit): RSocketClientTarget = KtorWebSocketClientTargetImpl(
101-
coroutineContext = coroutineContext,
101+
coroutineContext = coroutineContext.supervisorContext(),
102102
httpClient = httpClient,
103103
request = request
104104
)
@@ -136,12 +136,17 @@ private class KtorWebSocketClientTargetImpl(
136136
private val httpClient: HttpClient,
137137
private val request: HttpRequestBuilder.() -> Unit,
138138
) : RSocketClientTarget {
139-
140139
@RSocketTransportApi
141-
override fun connectClient(handler: RSocketConnectionHandler): Job = launch {
142-
httpClient.webSocket(request) {
143-
handler.handleKtorWebSocketConnection(this)
140+
override suspend fun connectClient(): RSocketConnection {
141+
currentCoroutineContext().ensureActive()
142+
coroutineContext.ensureActive()
143+
144+
val session = httpClient.webSocketSession(request)
145+
val handle = coroutineContext.job.invokeOnCompletion {
146+
session.cancel("Transport was cancelled", it)
144147
}
148+
session.coroutineContext.job.invokeOnCompletion { handle.dispose() }
149+
return KtorWebSocketConnection(session)
145150
}
146151
}
147152

rsocket-transports/ktor-websocket-internal/api/rsocket-transport-ktor-websocket-internal.api

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
public final class io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnectionKt {
2-
public static final fun handleKtorWebSocketConnection (Lio/rsocket/kotlin/transport/RSocketConnectionHandler;Lio/ktor/websocket/WebSocketSession;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
1+
public final class io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnection : io/rsocket/kotlin/transport/RSocketSequentialConnection {
2+
public fun <init> (Lio/ktor/websocket/WebSocketSession;)V
3+
public fun getCoroutineContext ()Lkotlin/coroutines/CoroutineContext;
4+
public fun receiveFrame (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
5+
public fun sendFrame (ILkotlinx/io/Buffer;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
36
}
47

58
public final class io/rsocket/kotlin/transport/ktor/websocket/internal/WebSocketConnection : io/rsocket/kotlin/Connection, kotlinx/coroutines/CoroutineScope {
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2015-2024 the original author or authors.
2+
* Copyright 2015-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -21,43 +21,54 @@ import io.rsocket.kotlin.internal.io.*
2121
import io.rsocket.kotlin.transport.*
2222
import io.rsocket.kotlin.transport.internal.*
2323
import kotlinx.coroutines.*
24-
import kotlinx.coroutines.channels.*
2524
import kotlinx.io.*
25+
import kotlin.coroutines.*
2626

2727
@RSocketTransportApi
28-
public suspend fun RSocketConnectionHandler.handleKtorWebSocketConnection(webSocketSession: WebSocketSession): Unit = coroutineScope {
29-
val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED)
28+
public class KtorWebSocketConnection(
29+
private val session: WebSocketSession,
30+
) : RSocketSequentialConnection {
31+
private val outboundQueue = PrioritizationFrameQueue()
32+
override val coroutineContext: CoroutineContext get() = session.coroutineContext
3033

31-
val senderJob = launch {
32-
while (true) webSocketSession.send(outboundQueue.dequeueFrame()?.readByteArray() ?: break)
33-
}.onCompletion { outboundQueue.cancel() }
34+
init {
35+
@OptIn(DelicateCoroutinesApi::class)
36+
launch(start = CoroutineStart.ATOMIC) {
37+
val outboundJob = launch {
38+
nonCancellable {
39+
try {
40+
while (true) {
41+
session.send(outboundQueue.dequeueFrame()?.readByteArray() ?: break)
42+
}
43+
} catch (cause: Throwable) {
44+
session.outgoing.close(cause)
45+
throw cause
46+
} finally {
47+
outboundQueue.cancel()
48+
}
49+
}
50+
}
3451

35-
try {
36-
handleConnection(KtorWebSocketConnection(outboundQueue, webSocketSession.incoming))
37-
} finally {
38-
webSocketSession.incoming.cancel()
39-
outboundQueue.close()
40-
withContext(NonCancellable) {
41-
senderJob.join() // await all frames sent
42-
webSocketSession.close()
43-
webSocketSession.coroutineContext.job.join()
52+
try {
53+
awaitCancellation()
54+
} finally {
55+
nonCancellable {
56+
session.incoming.cancel()
57+
outboundQueue.close()
58+
outboundJob.join()
59+
// await socket completion
60+
session.close()
61+
}
62+
}
4463
}
4564
}
46-
}
47-
48-
@RSocketTransportApi
49-
private class KtorWebSocketConnection(
50-
private val outboundQueue: PrioritizationFrameQueue,
51-
private val inbound: ReceiveChannel<Frame>,
52-
) : RSocketSequentialConnection {
53-
override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend
5465

5566
override suspend fun sendFrame(streamId: Int, frame: Buffer) {
5667
return outboundQueue.enqueueFrame(streamId, frame)
5768
}
5869

5970
override suspend fun receiveFrame(): Buffer? {
60-
val frame = inbound.receiveCatching().getOrNull() ?: return null
71+
val frame = session.incoming.receiveCatching().getOrNull() ?: return null
6172
return Buffer().apply { write(frame.data) }
6273
}
6374
}

rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport.kt

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,7 @@ private class KtorWebSocketServerTransportBuilderImpl : KtorWebSocketServerTrans
9292

9393
@RSocketTransportApi
9494
override fun buildTransport(context: CoroutineContext): KtorWebSocketServerTransport = KtorWebSocketServerTransportImpl(
95-
// we always add IO - as it's the best choice here, server will use it's own dispatcher anyway
96-
coroutineContext = context.supervisorContext() + Dispatchers.IoCompatible,
95+
coroutineContext = context.supervisorContext() + Dispatchers.Default,
9796
factory = requireNotNull(httpServerFactory) { "httpEngine is required" },
9897
webSocketsConfig = webSocketsConfig,
9998
)
@@ -151,12 +150,12 @@ private class KtorWebSocketServerTargetImpl(
151150
) : RSocketServerTarget<KtorWebSocketServerInstance> {
152151

153152
@RSocketTransportApi
154-
override suspend fun startServer(handler: RSocketConnectionHandler): KtorWebSocketServerInstance {
153+
override suspend fun startServer(onConnection: (RSocketConnection) -> Unit): KtorWebSocketServerInstance {
155154
currentCoroutineContext().ensureActive()
156155
coroutineContext.ensureActive()
157156

158157
val serverContext = coroutineContext.childContext()
159-
val embeddedServer = createServer(handler, serverContext)
158+
val embeddedServer = createServer(serverContext, onConnection)
160159
val resolvedConnectors = startServer(embeddedServer, serverContext)
161160

162161
return KtorWebSocketServerInstanceImpl(
@@ -170,8 +169,8 @@ private class KtorWebSocketServerTargetImpl(
170169
// parentCoroutineContext is the context of server instance
171170
@RSocketTransportApi
172171
private fun createServer(
173-
handler: RSocketConnectionHandler,
174172
serverContext: CoroutineContext,
173+
onConnection: (RSocketConnection) -> Unit,
175174
): EmbeddedServer<*, *> {
176175
val config = serverConfig {
177176
val target = this@KtorWebSocketServerTargetImpl
@@ -180,7 +179,8 @@ private class KtorWebSocketServerTargetImpl(
180179
install(WebSockets, webSocketsConfig)
181180
routing {
182181
webSocket(target.path, target.protocol) {
183-
handler.handleKtorWebSocketConnection(this)
182+
onConnection(KtorWebSocketConnection(this))
183+
awaitCancellation()
184184
}
185185
}
186186
}
@@ -191,20 +191,24 @@ private class KtorWebSocketServerTargetImpl(
191191
private suspend fun startServer(
192192
embeddedServer: EmbeddedServer<*, *>,
193193
serverContext: CoroutineContext,
194-
): List<EngineConnectorConfig> = launchCoroutine(serverContext + Dispatchers.IoCompatible) { cont ->
195-
embeddedServer.startSuspend()
196-
launch(serverContext + Dispatchers.IoCompatible) {
194+
): List<EngineConnectorConfig> {
195+
@OptIn(DelicateCoroutinesApi::class)
196+
val serverJob = launch(serverContext, start = CoroutineStart.ATOMIC) {
197197
try {
198+
currentCoroutineContext().ensureActive() // because of atomic start
199+
embeddedServer.startSuspend()
198200
awaitCancellation()
199201
} finally {
200-
withContext(NonCancellable) {
202+
nonCancellable {
201203
embeddedServer.stopSuspend()
202204
}
203205
}
204206
}
205-
cont.resume(embeddedServer.engine.resolvedConnectors()) { cause, _, _ ->
206-
// will cause stopping of the server
207-
serverContext.job.cancel("Cancelled", cause)
207+
return try {
208+
embeddedServer.engine.resolvedConnectors()
209+
} catch (cause: Throwable) {
210+
serverJob.cancel("Starting server cancelled", cause)
211+
throw cause
208212
}
209213
}
210214
}

0 commit comments

Comments
 (0)