From b5b37c6c922789c2af7c1e3d6d16452d71642f74 Mon Sep 17 00:00:00 2001 From: Yuriy Kulikov Date: Tue, 29 Mar 2022 16:16:31 +0200 Subject: [PATCH 1/3] changed LimitingFlowCollector.requests to AtomicLong This avoids Int overflow when client is misbehaving and is sending multiple RequestN frames with n=Int.MAX_VALUE This closes #213 --- .../io/rsocket/kotlin/internal/Limiter.kt | 55 +++- .../io/rsocket/kotlin/core/RSocketTest.kt | 47 +++ .../internal/RSocketResponderRequestNTest.kt | 270 ++++++++++++++++++ 3 files changed, 362 insertions(+), 10 deletions(-) create mode 100644 rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt index dbf01562d..e5d40801f 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt @@ -32,27 +32,62 @@ internal suspend inline fun Flow.collectLimiting(limiter: Limiter, cros } } -//TODO revisit 2 atomics and sync object +/** + * Maintains the amount of requests which the client is ready to consume and + * prevents sending further updates by suspending the sending coroutine + * if this amount reaches 0. + * + * ### Operation + * + * Each [useRequest] call decrements the maintained requests amount. + * Calling coroutine is suspended when this amount reaches 0. + * The coroutine is resumed when [updateRequests] is called. + * + * ### Unbounded mode + * + * Limiter enters an unbounded mode when: + * * [Limiter] is created passing `Int.MAX_VALUE` as `initial` + * * client sends a `RequestN` frame with `Int.MAX_VALUE` + * * Internal Long counter overflows + * + * In unbounded mode Limiter will assume that the client + * is able to process requests without limitations, all further + * [updateRequests] will be NOP and [useRequest] will never suspend. + */ internal class Limiter(initial: Int) : SynchronizedObject() { - private val requests = atomic(initial) - private val awaiter = atomic?>(null) + private val requests: AtomicLong = atomic(initial.toLong()) + private val unbounded: AtomicBoolean = atomic(initial == Int.MAX_VALUE) + private var awaiter: CancellableContinuation? = null fun updateRequests(n: Int) { - if (n <= 0) return + if (n <= 0 || unbounded.value) return synchronized(this) { - requests += n - awaiter.getAndSet(null)?.takeIf(CancellableContinuation::isActive)?.resume(Unit) + val updatedRequests = requests.value + n.toLong() + if (updatedRequests < 0) { + unbounded.value = true + requests.value = Long.MAX_VALUE + } else { + requests.value = updatedRequests + } + + if (awaiter?.isActive == true) { + awaiter?.resume(Unit) + awaiter = null + } } } suspend fun useRequest() { - if (requests.getAndDecrement() > 0) { + if (unbounded.value || requests.decrementAndGet() >= 0) { currentCoroutineContext().ensureActive() } else { - suspendCancellableCoroutine { + suspendCancellableCoroutine { continuation -> synchronized(this) { - awaiter.value = it - if (requests.value >= 0 && it.isActive) it.resume(Unit) + if (requests.value >= 0 && continuation.isActive) { + continuation.resume(Unit) + } else { + this.awaiter = continuation + } } } } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt index 52c9f3bbe..0d0c7aa3c 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt @@ -23,6 +23,7 @@ import io.rsocket.kotlin.keepalive.* import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.test.* import io.rsocket.kotlin.transport.local.* +import kotlinx.atomicfu.atomic import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* @@ -192,6 +193,52 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { assertTrue(channel.receiveCatching().isClosed) } + @Test + fun testStreamInitialUnbounded() = test { + val requester = start(RSocketRequestHandler { + requestStream { + (0..9).asFlow().map { + payload(it.toString()) + } + } + }) + requester.requestStream(payload("HELLO")) + .flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)) + .test { + repeat(10) { + awaitItem().close() + } + awaitComplete() + } + } + + @Test + fun testStreamRequestNUnbounded() = test { + class UnboundedAfterNStrategy(private val initial: Int) : RequestStrategy { + override fun provide(): RequestStrategy.Element = Element() + inner class Element : RequestStrategy.Element { + private val requested = atomic(initial) + override suspend fun firstRequest(): Int = initial + override suspend fun nextRequest(): Int { + val requestUnbounded = requested.getAndDecrement() == 0 + return if (requestUnbounded) Int.MAX_VALUE else 0 + } + } + } + + start(RSocketRequestHandler { + requestStream { + (0..9).asFlow().map { payload(it.toString()) } + } + }) + .requestStream(payload("HELLO")) + .flowOn(UnboundedAfterNStrategy(initial = 5)) + .test { + repeat(10) { awaitItem().close() } + awaitComplete() + } + } + @Test fun testChannel() = test { val awaiter = Job() diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt new file mode 100644 index 000000000..7d080fc9a --- /dev/null +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt @@ -0,0 +1,270 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal + +import app.cash.turbine.FlowTurbine +import io.rsocket.kotlin.* +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.frame.io.Version +import io.rsocket.kotlin.keepalive.DefaultKeepAlive +import io.rsocket.kotlin.payload.DefaultPayloadMimeType +import io.rsocket.kotlin.payload.buildPayload +import io.rsocket.kotlin.payload.data +import io.rsocket.kotlin.test.TestExceptionHandler +import io.rsocket.kotlin.test.TestServer +import io.rsocket.kotlin.test.TestWithLeakCheck +import io.rsocket.kotlin.test.payload +import io.rsocket.kotlin.transport.ServerTransport +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onEach +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +class RSocketResponderRequestNTest : TestWithLeakCheck, TestWithConnection() { + private val testJob: Job = Job() + + private suspend fun start(handler: RSocket) { + val serverTransport = ServerTransport { accept -> + GlobalScope.async { accept(connection) } + } + + val scope = CoroutineScope(Dispatchers.Unconfined + testJob + TestExceptionHandler) + @Suppress("DeferredResultUnused") + TestServer().bindIn(scope, serverTransport) { + config.setupPayload.close() + handler + } + } + + override suspend fun after() { + super.after() + testJob.cancelAndJoin() + } + + private val setupFrame + get() = SetupFrame( + version = Version.Current, + honorLease = false, + keepAlive = DefaultKeepAlive, + resumeToken = null, + payloadMimeType = DefaultPayloadMimeType, + payload = payload("setup"), + ) + + @Test + fun testStreamInitialEnoughToConsume() = test(timeout = 10.seconds) { + start( + RSocketRequestHandler { + requestStream { payload -> + payload.close() + (0..9).asFlow().map { buildPayload { data("$it") } } + } + } + ) + + connection.test { + connection.sendToReceiver(setupFrame) + + connection.sendToReceiver(RequestStreamFrame(initialRequestN = 16, streamId = 1, payload = payload("request"))) + + awaitAndReleasePayloadFrames(amount = 10) + awaitCompleteFrame() + expectNoEventsIn(200) + } + } + + @Test + fun testStreamSuspendWhenNoRequestsLeft() = test(timeout = 10.seconds) { + var lastSent = -1 + start( + RSocketRequestHandler { + requestStream { payload -> + payload.close() + (0..9).asFlow() + .onEach { lastSent = it } + .map { buildPayload { data("$it") } } + } + } + ) + + connection.test { + connection.sendToReceiver(setupFrame) + + connection.sendToReceiver(RequestStreamFrame(initialRequestN = 3, streamId = 1, payload = payload("request"))) + + awaitAndReleasePayloadFrames(amount = 3) + expectNoEventsIn(200) + assertEquals(3, lastSent) + } + } + + @Test + fun testStreamRequestNFrameResumesOperation() = test(timeout = 10.seconds) { + start( + RSocketRequestHandler { + requestStream { payload -> + payload.close() + (0..15).asFlow().map { buildPayload { data("$it") } } + } + } + ) + connection.test { + connection.sendToReceiver(setupFrame) + + connection.sendToReceiver(RequestStreamFrame(initialRequestN = 3, streamId = 1, payload = payload("request"))) + awaitAndReleasePayloadFrames(amount = 3) + expectNoEventsIn(200) + + connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = 5)) + awaitAndReleasePayloadFrames(amount = 5) + expectNoEventsIn(200) + + connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = 5)) + awaitAndReleasePayloadFrames(amount = 5) + expectNoEventsIn(200) + } + } + + @Test + fun testStreamInitialUnbounded() = test(timeout = 10.seconds) { + start( + RSocketRequestHandler { + requestStream { payload -> + payload.close() + (0..19).asFlow().map { buildPayload { data("$it") } } + } + } + ) + connection.test { + connection.sendToReceiver(setupFrame) + + connection.sendToReceiver(RequestStreamFrame(initialRequestN = Int.MAX_VALUE, streamId = 1, payload = payload("request"))) + + awaitAndReleasePayloadFrames(amount = 20) + awaitCompleteFrame() + expectNoEventsIn(200) + } + } + + @Test + fun testStreamRequestNUnbounded() = test(timeout = 10.seconds) { + val total = 20 + start( + RSocketRequestHandler { + requestStream { payload -> + payload.close() + (0 until total).asFlow().map { buildPayload { data("$it") } } + } + } + ) + connection.test { + connection.sendToReceiver(setupFrame) + + val firstRequest = 3 + connection.sendToReceiver(RequestStreamFrame(initialRequestN = firstRequest, streamId = 1, payload = payload("request"))) + awaitAndReleasePayloadFrames(amount = firstRequest) + expectNoEventsIn(200) + + connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE)) + awaitAndReleasePayloadFrames(amount = total - firstRequest) + awaitCompleteFrame() + expectNoEventsIn(200) + } + } + + @Test + fun testStreamRequestNUnboundedWithOverflow() = test(timeout = 10.seconds) { + val latch = Channel(1) + start( + RSocketRequestHandler { + requestStream { payload -> + payload.close() + latch.receive() + // make sure limiter has got the RequestNFrame before emitting the values + delay(200) + (0..19).asFlow().map { buildPayload { data("$it") } } + } + } + ) + connection.test { + connection.sendToReceiver(setupFrame) + + connection.sendToReceiver(RequestStreamFrame(initialRequestN = Int.MAX_VALUE, streamId = 1, payload = payload("request"))) + connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE)) + latch.send(Unit) + + awaitAndReleasePayloadFrames(amount = 20) + awaitCompleteFrame() + expectNoEventsIn(200) + } + } + + + @Test + fun testStreamRequestNUnboundedSummingUpToOverflow() = test(timeout = 10.seconds) { + val latch = Channel(1) + start( + RSocketRequestHandler { + requestStream { payload -> + payload.close() + latch.receive() + // make sure limiter has got the RequestNFrame before emitting the values + delay(200) + (0..19).asFlow().map { buildPayload { data("$it") } } + } + } + ) + + connection.test { + connection.sendToReceiver(setupFrame) + + connection.sendToReceiver(RequestStreamFrame(initialRequestN = 5, streamId = 1, payload = payload("request"))) + connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE / 3)) + connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE / 3)) + connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE / 3)) + connection.sendToReceiver(RequestNFrame(streamId = 1, requestN = Int.MAX_VALUE / 3)) + latch.send(Unit) + + awaitAndReleasePayloadFrames(amount = 20) + awaitCompleteFrame() + expectNoEventsIn(200) + } + } + + private suspend fun FlowTurbine.awaitAndReleasePayloadFrames(amount: Int) { + repeat(amount) { + awaitFrame { frame -> + assertTrue(frame is RequestFrame) + assertEquals(FrameType.Payload, frame.type) + frame.payload.close() + } + } + } + + private suspend fun FlowTurbine.awaitCompleteFrame() { + awaitFrame { frame -> + assertTrue(frame is RequestFrame) + assertEquals(FrameType.Payload, frame.type) + assertTrue(frame.complete, "Frame should be complete") + } + } +} From 242deed181f12cefeefc84bd01c7b9d077833670 Mon Sep 17 00:00:00 2001 From: Yuriy Kulikov Date: Tue, 29 Mar 2022 19:03:15 +0200 Subject: [PATCH 2/3] Removed Limiter.unbounded Related to #213 --- .../io/rsocket/kotlin/internal/Limiter.kt | 16 ++------- .../internal/RSocketResponderRequestNTest.kt | 33 ++++--------------- 2 files changed, 8 insertions(+), 41 deletions(-) diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt index e5d40801f..52b68e24c 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt @@ -43,28 +43,16 @@ internal suspend inline fun Flow.collectLimiting(limiter: Limiter, cros * Calling coroutine is suspended when this amount reaches 0. * The coroutine is resumed when [updateRequests] is called. * - * ### Unbounded mode - * - * Limiter enters an unbounded mode when: - * * [Limiter] is created passing `Int.MAX_VALUE` as `initial` - * * client sends a `RequestN` frame with `Int.MAX_VALUE` - * * Internal Long counter overflows - * - * In unbounded mode Limiter will assume that the client - * is able to process requests without limitations, all further - * [updateRequests] will be NOP and [useRequest] will never suspend. */ internal class Limiter(initial: Int) : SynchronizedObject() { private val requests: AtomicLong = atomic(initial.toLong()) - private val unbounded: AtomicBoolean = atomic(initial == Int.MAX_VALUE) private var awaiter: CancellableContinuation? = null fun updateRequests(n: Int) { - if (n <= 0 || unbounded.value) return + if (n <= 0) return synchronized(this) { val updatedRequests = requests.value + n.toLong() if (updatedRequests < 0) { - unbounded.value = true requests.value = Long.MAX_VALUE } else { requests.value = updatedRequests @@ -78,7 +66,7 @@ internal class Limiter(initial: Int) : SynchronizedObject() { } suspend fun useRequest() { - if (unbounded.value || requests.decrementAndGet() >= 0) { + if (requests.decrementAndGet() >= 0) { currentCoroutineContext().ensureActive() } else { suspendCancellableCoroutine { continuation -> diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt index 7d080fc9a..0d073f4b3 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt @@ -71,7 +71,7 @@ class RSocketResponderRequestNTest : TestWithLeakCheck, TestWithConnection() { ) @Test - fun testStreamInitialEnoughToConsume() = test(timeout = 10.seconds) { + fun testStreamInitialEnoughToConsume() = test { start( RSocketRequestHandler { requestStream { payload -> @@ -93,7 +93,7 @@ class RSocketResponderRequestNTest : TestWithLeakCheck, TestWithConnection() { } @Test - fun testStreamSuspendWhenNoRequestsLeft() = test(timeout = 10.seconds) { + fun testStreamSuspendWhenNoRequestsLeft() = test { var lastSent = -1 start( RSocketRequestHandler { @@ -118,7 +118,7 @@ class RSocketResponderRequestNTest : TestWithLeakCheck, TestWithConnection() { } @Test - fun testStreamRequestNFrameResumesOperation() = test(timeout = 10.seconds) { + fun testStreamRequestNFrameResumesOperation() = test { start( RSocketRequestHandler { requestStream { payload -> @@ -145,28 +145,7 @@ class RSocketResponderRequestNTest : TestWithLeakCheck, TestWithConnection() { } @Test - fun testStreamInitialUnbounded() = test(timeout = 10.seconds) { - start( - RSocketRequestHandler { - requestStream { payload -> - payload.close() - (0..19).asFlow().map { buildPayload { data("$it") } } - } - } - ) - connection.test { - connection.sendToReceiver(setupFrame) - - connection.sendToReceiver(RequestStreamFrame(initialRequestN = Int.MAX_VALUE, streamId = 1, payload = payload("request"))) - - awaitAndReleasePayloadFrames(amount = 20) - awaitCompleteFrame() - expectNoEventsIn(200) - } - } - - @Test - fun testStreamRequestNUnbounded() = test(timeout = 10.seconds) { + fun testStreamRequestNEnoughToComplete() = test { val total = 20 start( RSocketRequestHandler { @@ -192,7 +171,7 @@ class RSocketResponderRequestNTest : TestWithLeakCheck, TestWithConnection() { } @Test - fun testStreamRequestNUnboundedWithOverflow() = test(timeout = 10.seconds) { + fun testStreamRequestNAttemptedIntOverflow() = test { val latch = Channel(1) start( RSocketRequestHandler { @@ -220,7 +199,7 @@ class RSocketResponderRequestNTest : TestWithLeakCheck, TestWithConnection() { @Test - fun testStreamRequestNUnboundedSummingUpToOverflow() = test(timeout = 10.seconds) { + fun testStreamRequestNSummingUpToOverflow() = test { val latch = Channel(1) start( RSocketRequestHandler { From 64e61be6f0bda169ace299cb94ecfe33f1a2b2f5 Mon Sep 17 00:00:00 2001 From: Yuriy Kulikov Date: Wed, 30 Mar 2022 10:24:08 +0200 Subject: [PATCH 3/3] Removed UnboundedAfterNStrategy --- .../io/rsocket/kotlin/core/RSocketTest.kt | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt index 0d0c7aa3c..fa948300a 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt @@ -194,7 +194,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } @Test - fun testStreamInitialUnbounded() = test { + fun testStreamInitialMaxValue() = test { val requester = start(RSocketRequestHandler { requestStream { (0..9).asFlow().map { @@ -213,26 +213,14 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } @Test - fun testStreamRequestNUnbounded() = test { - class UnboundedAfterNStrategy(private val initial: Int) : RequestStrategy { - override fun provide(): RequestStrategy.Element = Element() - inner class Element : RequestStrategy.Element { - private val requested = atomic(initial) - override suspend fun firstRequest(): Int = initial - override suspend fun nextRequest(): Int { - val requestUnbounded = requested.getAndDecrement() == 0 - return if (requestUnbounded) Int.MAX_VALUE else 0 - } - } - } - + fun testStreamRequestN() = test { start(RSocketRequestHandler { requestStream { (0..9).asFlow().map { payload(it.toString()) } } }) .requestStream(payload("HELLO")) - .flowOn(UnboundedAfterNStrategy(initial = 5)) + .flowOn(PrefetchStrategy(5, 3)) .test { repeat(10) { awaitItem().close() } awaitComplete()