Skip to content

Commit a180a6f

Browse files
committed
migrate ktor tcp transport to new API
1 parent 87696ca commit a180a6f

File tree

13 files changed

+508
-66
lines changed

13 files changed

+508
-66
lines changed

rsocket-internal-io/api/rsocket-internal-io.api

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ public final class io/rsocket/kotlin/internal/io/ChannelsKt {
2525
public final class io/rsocket/kotlin/internal/io/ContextKt {
2626
public static final fun childContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
2727
public static final fun ensureActive (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;)V
28+
public static final fun launchCoroutine (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
29+
public static synthetic fun launchCoroutine$default (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
2830
public static final fun onCompletion (Lkotlinx/coroutines/Job;Lkotlin/jvm/functions/Function1;)Lkotlinx/coroutines/Job;
2931
public static final fun supervisorContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
3032
}

rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,13 @@ public inline fun CoroutineContext.ensureActive(onInactive: () -> Unit) {
3232
onInactive() // should not throw
3333
ensureActive() // will throw
3434
}
35+
36+
@Suppress("SuspendFunctionOnCoroutineScope")
37+
public suspend inline fun <T> CoroutineScope.launchCoroutine(
38+
context: CoroutineContext = EmptyCoroutineContext,
39+
crossinline block: suspend (CancellableContinuation<T>) -> Unit,
40+
): T = suspendCancellableCoroutine { cont ->
41+
val job = launch(context) { block(cont) }
42+
job.invokeOnCompletion { if (it != null && cont.isActive) cont.resumeWithException(it) }
43+
cont.invokeOnCancellation { job.cancel("launchCoroutine was cancelled", it) }
44+
}

rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,43 @@
1+
public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport : io/rsocket/kotlin/transport/RSocketTransport {
2+
public static final field Factory Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport$Factory;
3+
public abstract fun target (Lio/ktor/network/sockets/SocketAddress;)Lio/rsocket/kotlin/transport/RSocketClientTarget;
4+
public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketClientTarget;
5+
}
6+
7+
public final class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory {
8+
}
9+
10+
public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder {
11+
public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V
12+
public fun inheritDispatcher ()V
13+
public abstract fun selectorManager (Lio/ktor/network/selector/SelectorManager;Z)V
14+
public abstract fun selectorManagerDispatcher (Lkotlin/coroutines/CoroutineContext;)V
15+
public abstract fun socketOptions (Lkotlin/jvm/functions/Function1;)V
16+
}
17+
18+
public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance {
19+
public abstract fun getLocalAddress ()Lio/ktor/network/sockets/SocketAddress;
20+
}
21+
22+
public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport : io/rsocket/kotlin/transport/RSocketTransport {
23+
public static final field Factory Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport$Factory;
24+
public abstract fun target (Lio/ktor/network/sockets/SocketAddress;)Lio/rsocket/kotlin/transport/RSocketServerTarget;
25+
public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketServerTarget;
26+
public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport;Lio/ktor/network/sockets/SocketAddress;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget;
27+
public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport;Ljava/lang/String;IILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget;
28+
}
29+
30+
public final class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory {
31+
}
32+
33+
public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder {
34+
public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V
35+
public fun inheritDispatcher ()V
36+
public abstract fun selectorManager (Lio/ktor/network/selector/SelectorManager;Z)V
37+
public abstract fun selectorManagerDispatcher (Lkotlin/coroutines/CoroutineContext;)V
38+
public abstract fun socketOptions (Lkotlin/jvm/functions/Function1;)V
39+
}
40+
141
public final class io/rsocket/kotlin/transport/ktor/tcp/TcpClientTransportKt {
242
public static final fun TcpClientTransport (Lio/ktor/network/sockets/InetSocketAddress;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport;
343
public static final fun TcpClientTransport (Ljava/lang/String;ILkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport;
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright 2015-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.rsocket.kotlin.transport.ktor.tcp
18+
19+
import io.ktor.network.selector.*
20+
import io.ktor.network.sockets.*
21+
import io.rsocket.kotlin.internal.io.*
22+
import io.rsocket.kotlin.transport.*
23+
import kotlinx.coroutines.*
24+
import kotlin.coroutines.*
25+
26+
public sealed interface KtorTcpClientTransport : RSocketTransport {
27+
public fun target(remoteAddress: SocketAddress): RSocketClientTarget
28+
public fun target(host: String, port: Int): RSocketClientTarget
29+
30+
public companion object Factory :
31+
RSocketTransportFactory<KtorTcpClientTransport, KtorTcpClientTransportBuilder>(::KtorTcpClientTransportBuilderImpl)
32+
}
33+
34+
public sealed interface KtorTcpClientTransportBuilder : RSocketTransportBuilder<KtorTcpClientTransport> {
35+
public fun dispatcher(context: CoroutineContext)
36+
public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext)
37+
38+
public fun selectorManagerDispatcher(context: CoroutineContext)
39+
public fun selectorManager(manager: SelectorManager, manage: Boolean)
40+
41+
public fun socketOptions(block: SocketOptions.TCPClientSocketOptions.() -> Unit)
42+
43+
//TODO: TLS support
44+
}
45+
46+
private class KtorTcpClientTransportBuilderImpl : KtorTcpClientTransportBuilder {
47+
private var dispatcher: CoroutineContext = Dispatchers.IO
48+
private var selector: KtorTcpSelector? = null
49+
private var socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit = {}
50+
51+
override fun dispatcher(context: CoroutineContext) {
52+
check(context[Job] == null) { "Dispatcher shouldn't contain job" }
53+
this.dispatcher = context
54+
}
55+
56+
override fun socketOptions(block: SocketOptions.TCPClientSocketOptions.() -> Unit) {
57+
this.socketOptions = block
58+
}
59+
60+
override fun selectorManagerDispatcher(context: CoroutineContext) {
61+
check(context[Job] == null) { "Dispatcher shouldn't contain job" }
62+
this.selector = KtorTcpSelector.FromContext(context)
63+
}
64+
65+
override fun selectorManager(manager: SelectorManager, manage: Boolean) {
66+
this.selector = KtorTcpSelector.FromInstance(manager, manage)
67+
}
68+
69+
@RSocketTransportApi
70+
override fun buildTransport(context: CoroutineContext): KtorTcpClientTransport {
71+
val transportContext = context.supervisorContext() + dispatcher
72+
return KtorTcpClientTransportImpl(
73+
coroutineContext = transportContext,
74+
socketOptions = socketOptions,
75+
selectorManager = selector.createFor(transportContext)
76+
)
77+
}
78+
}
79+
80+
private class KtorTcpClientTransportImpl(
81+
override val coroutineContext: CoroutineContext,
82+
private val socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit,
83+
private val selectorManager: SelectorManager,
84+
) : KtorTcpClientTransport {
85+
override fun target(remoteAddress: SocketAddress): RSocketClientTarget = KtorTcpClientTargetImpl(
86+
coroutineContext = coroutineContext.supervisorContext(),
87+
socketOptions = socketOptions,
88+
selectorManager = selectorManager,
89+
remoteAddress = remoteAddress
90+
)
91+
92+
override fun target(host: String, port: Int): RSocketClientTarget = target(InetSocketAddress(host, port))
93+
}
94+
95+
private class KtorTcpClientTargetImpl(
96+
override val coroutineContext: CoroutineContext,
97+
private val socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit,
98+
private val selectorManager: SelectorManager,
99+
private val remoteAddress: SocketAddress,
100+
) : RSocketClientTarget {
101+
102+
@RSocketTransportApi
103+
override fun connectClient(handler: RSocketConnectionHandler): Job = launch {
104+
val socket = aSocket(selectorManager).tcp().connect(remoteAddress, socketOptions)
105+
handler.handleKtorTcpConnection(socket)
106+
}
107+
}
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright 2015-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.rsocket.kotlin.transport.ktor.tcp
18+
19+
import io.ktor.network.sockets.*
20+
import io.ktor.utils.io.*
21+
import io.ktor.utils.io.core.*
22+
import io.rsocket.kotlin.internal.io.*
23+
import io.rsocket.kotlin.transport.*
24+
import io.rsocket.kotlin.transport.internal.*
25+
import kotlinx.coroutines.*
26+
import kotlinx.coroutines.channels.*
27+
28+
@RSocketTransportApi
29+
internal suspend fun RSocketConnectionHandler.handleKtorTcpConnection(socket: Socket): Unit = coroutineScope {
30+
val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED)
31+
val inbound = channelForCloseable<ByteReadPacket>(Channel.BUFFERED)
32+
33+
val readerJob = launch {
34+
val input = socket.openReadChannel()
35+
try {
36+
while (true) inbound.send(input.readFrame() ?: break)
37+
input.cancel(null)
38+
} catch (cause: Throwable) {
39+
input.cancel(cause)
40+
throw cause
41+
}
42+
}.onCompletion { inbound.cancel() }
43+
44+
val writerJob = launch {
45+
val output = socket.openWriteChannel()
46+
try {
47+
while (true) {
48+
// we write all available frames here, and only after it flush
49+
// in this case, if there are several buffered frames we can send them in one go
50+
// avoiding unnecessary flushes
51+
output.writeFrame(outboundQueue.dequeueFrame() ?: break)
52+
while (true) output.writeFrame(outboundQueue.tryDequeueFrame() ?: break)
53+
output.flush()
54+
}
55+
output.close(null)
56+
} catch (cause: Throwable) {
57+
output.close(cause)
58+
throw cause
59+
}
60+
}.onCompletion { outboundQueue.cancel() }
61+
62+
try {
63+
handleConnection(KtorTcpConnection(outboundQueue, inbound))
64+
} finally {
65+
readerJob.cancel()
66+
outboundQueue.close() // will cause `writerJob` completion
67+
// even if it was cancelled, we still need to close socket and await it closure
68+
withContext(NonCancellable) {
69+
// await completion of read/write and then close socket
70+
readerJob.join()
71+
writerJob.join()
72+
// close socket
73+
socket.close()
74+
socket.socketContext.join()
75+
}
76+
}
77+
}
78+
79+
@RSocketTransportApi
80+
private class KtorTcpConnection(
81+
private val outboundQueue: PrioritizationFrameQueue,
82+
private val inbound: ReceiveChannel<ByteReadPacket>,
83+
) : RSocketSequentialConnection {
84+
override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend
85+
override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) {
86+
return outboundQueue.enqueueFrame(streamId, frame)
87+
}
88+
89+
override suspend fun receiveFrame(): ByteReadPacket? {
90+
return inbound.receiveCatching().getOrNull()
91+
}
92+
}
93+
94+
private suspend fun ByteWriteChannel.writeFrame(frame: ByteReadPacket) {
95+
val packet = buildPacket {
96+
writeInt24(frame.remaining.toInt())
97+
writePacket(frame)
98+
}
99+
try {
100+
writePacket(packet)
101+
} catch (cause: Throwable) {
102+
packet.close()
103+
throw cause
104+
}
105+
}
106+
107+
private suspend fun ByteReadChannel.readFrame(): ByteReadPacket? {
108+
val lengthPacket = readRemaining(3)
109+
if (lengthPacket.remaining == 0L) return null
110+
return readPacket(lengthPacket.readInt24())
111+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright 2015-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.rsocket.kotlin.transport.ktor.tcp
18+
19+
import io.ktor.network.selector.*
20+
import kotlinx.coroutines.*
21+
import kotlin.coroutines.*
22+
23+
internal sealed class KtorTcpSelector {
24+
class FromContext(val context: CoroutineContext) : KtorTcpSelector()
25+
class FromInstance(val selectorManager: SelectorManager, val manage: Boolean) : KtorTcpSelector()
26+
}
27+
28+
internal fun KtorTcpSelector?.createFor(parentContext: CoroutineContext): SelectorManager {
29+
val selectorManager: SelectorManager
30+
val manage: Boolean
31+
when (this) {
32+
null -> {
33+
selectorManager = SelectorManager(parentContext)
34+
manage = true
35+
}
36+
37+
is KtorTcpSelector.FromContext -> {
38+
selectorManager = SelectorManager(parentContext + context)
39+
manage = true
40+
}
41+
42+
is KtorTcpSelector.FromInstance -> {
43+
selectorManager = this.selectorManager
44+
manage = this.manage
45+
}
46+
}
47+
if (manage) Job(parentContext.job).invokeOnCompletion { selectorManager.close() }
48+
return selectorManager
49+
}

0 commit comments

Comments
 (0)