Skip to content

Commit c1f1f98

Browse files
mostroverkhovrobertroeser
authored andcommitted
Upstream bug fixes (#16)
* fix bug when RSocket requester streams does not get terminal signal on connection close * fix bug: RSocketClient requests after close were not throwing exception
1 parent ffa31ca commit c1f1f98

File tree

2 files changed

+163
-94
lines changed

2 files changed

+163
-94
lines changed

rsocket-core/src/main/java/io/rsocket/android/RSocketClient.kt

Lines changed: 90 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ import io.reactivex.Completable
2222
import io.reactivex.Flowable
2323
import io.reactivex.Single
2424
import io.reactivex.disposables.Disposable
25-
import io.reactivex.functions.Action
26-
import io.reactivex.functions.Consumer
2725
import io.reactivex.processors.FlowableProcessor
2826
import io.reactivex.processors.PublishProcessor
2927
import io.reactivex.processors.UnicastProcessor
@@ -63,6 +61,8 @@ internal class RSocketClient @JvmOverloads constructor(
6361
private val senders: IntObjectHashMap<LimitableRequestPublisher<*>> = IntObjectHashMap(256, 0.9f)
6462
private val receivers: IntObjectHashMap<Subscriber<Payload>> = IntObjectHashMap(256, 0.9f)
6563
private val missedAckCounter: AtomicInteger = AtomicInteger()
64+
@Volatile
65+
private var errorSignal: Throwable? = null
6666

6767
private val sendProcessor: FlowableProcessor<Frame> = PublishProcessor
6868
.create<Frame>()
@@ -99,70 +99,79 @@ internal class RSocketClient @JvmOverloads constructor(
9999
connection
100100
.receive()
101101
.doOnSubscribe { started.onComplete() }
102-
.subscribe({ handleIncomingFrames(it) },errorConsumer)
102+
.subscribe({ handleIncomingFrames(it) }, errorConsumer)
103103
}
104104

105105
private fun handleSendProcessorError(t: Throwable) {
106-
val (receivers, senders) = synchronized(this) {
107-
Pair(receivers.values, senders.values)
108-
}
109-
for (subscriber in receivers) {
110-
try {
111-
subscriber.onError(t)
112-
} catch (e: Throwable) {
113-
errorConsumer(e)
114-
}
115-
}
116-
117-
for (p in senders) {
118-
p.cancel()
106+
synchronized(this) {
107+
receivers.values.forEach { it.onError(t) }
108+
senders.values.forEach { it.cancel() }
119109
}
120110
}
121111

122112
private fun sendKeepAlive(ackTimeoutMs: Long, missedAcks: Int): Completable {
123113
return Completable.fromRunnable {
124-
val now = System.currentTimeMillis()
125-
if (now - timeLastTickSentMs > ackTimeoutMs) {
126-
val count = missedAckCounter.incrementAndGet()
127-
if (count >= missedAcks) {
128-
val message = String.format(
129-
"Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms",
130-
count, missedAcks, ackTimeoutMs)
131-
throw ConnectionException(message)
132-
}
133-
}
134-
135-
sendProcessor.onNext(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true))
114+
val now = System.currentTimeMillis()
115+
if (now - timeLastTickSentMs > ackTimeoutMs) {
116+
val count = missedAckCounter.incrementAndGet()
117+
if (count >= missedAcks) {
118+
val message = String.format(
119+
"Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms",
120+
count, missedAcks, ackTimeoutMs)
121+
throw ConnectionException(message)
136122
}
123+
}
124+
125+
sendProcessor.onNext(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true))
126+
}
137127
}
138128

139129
override fun fireAndForget(payload: Payload): Completable {
140130
val defer = Completable.fromRunnable {
141-
val streamId = streamIdSupplier.nextStreamId()
142-
val requestFrame = Frame.Request.from(
143-
streamId,
144-
FrameType.FIRE_AND_FORGET,
145-
payload,
146-
1)
147-
sendProcessor.onNext(requestFrame)
148-
}
131+
val streamId = streamIdSupplier.nextStreamId()
132+
val requestFrame = Frame.Request.from(
133+
streamId,
134+
FrameType.FIRE_AND_FORGET,
135+
payload,
136+
1)
137+
sendProcessor.onNext(requestFrame)
138+
}
149139

150-
return completeOnStart.andThen(defer)
140+
return errorSignal
141+
?.let { Completable.error(it) }
142+
?: completeOnStart.andThen(defer)
151143
}
152144

153145
override fun requestResponse(payload: Payload): Single<Payload> =
154-
handleRequestResponse(payload)
146+
errorSignal
147+
?.let { Single.error<Payload>(it) }
148+
?: handleRequestResponse(payload)
155149

156150
override fun requestStream(payload: Payload): Flowable<Payload> =
157-
handleRequestStream(payload).rebatchRequests(streamDemandLimit)
151+
errorSignal
152+
?.let { Flowable.error<Payload>(it) }
153+
?: handleRequestStream(payload).rebatchRequests(streamDemandLimit)
158154

159155
override fun requestChannel(payloads: Publisher<Payload>): Flowable<Payload> =
160-
handleChannel(
156+
errorSignal
157+
?.let { Flowable.error<Payload>(it) }
158+
?: handleChannel(
161159
Flowable.fromPublisher(payloads).rebatchRequests(streamDemandLimit),
162160
FrameType.REQUEST_CHANNEL
163161
).rebatchRequests(streamDemandLimit)
164162

165-
override fun metadataPush(payload: Payload): Completable {
163+
override fun metadataPush(payload: Payload): Completable =
164+
errorSignal
165+
?.let { Completable.error(it) }
166+
?: handleMetadataPush(payload)
167+
168+
override fun availability(): Double = connection.availability()
169+
170+
override fun close(): Completable = connection.close()
171+
172+
override fun onClose(): Completable = connection.onClose()
173+
174+
private fun handleMetadataPush(payload: Payload): Completable {
166175
val requestFrame = Frame.Request.from(
167176
0,
168177
FrameType.METADATA_PUSH,
@@ -172,44 +181,38 @@ internal class RSocketClient @JvmOverloads constructor(
172181
return Completable.complete()
173182
}
174183

175-
override fun availability(): Double = connection.availability()
176-
177-
override fun close(): Completable = connection.close()
178-
179-
override fun onClose(): Completable = connection.onClose()
180-
181184
private fun handleRequestStream(payload: Payload): Flowable<Payload> {
182185
return completeOnStart.andThen(
183186
Flowable.defer {
184-
val streamId = streamIdSupplier.nextStreamId()
185-
val receiver = UnicastProcessor.create<Payload>()
186-
synchronized(this) {
187-
receivers.put(streamId, receiver)
188-
}
187+
val streamId = streamIdSupplier.nextStreamId()
188+
val receiver = UnicastProcessor.create<Payload>()
189+
synchronized(this) {
190+
receivers.put(streamId, receiver)
191+
}
189192

190-
val first = AtomicBoolean(false)
193+
val first = AtomicBoolean(false)
191194

192-
receiver
193-
.doOnRequest{ l ->
194-
if (first.compareAndSet(false, true) && !receiver.isTerminated()) {
195-
val requestFrame = Frame.Request.from(streamId, FrameType.REQUEST_STREAM, payload, l)
196-
sendProcessor.onNext(requestFrame)
197-
} else if (contains(streamId)) {
198-
sendProcessor.onNext(Frame.RequestN.from(streamId, l))
199-
}
200-
}
201-
.doOnError { t ->
202-
if (contains(streamId) && !receiver.isTerminated()) {
203-
sendProcessor.onNext(Frame.Error.from(streamId, t))
204-
}
205-
}
206-
.doOnCancel {
207-
if (contains(streamId) && !receiver.isTerminated()) {
208-
sendProcessor.onNext(Frame.Cancel.from(streamId))
209-
}
210-
}
211-
.doFinally { removeReceiver(streamId) }
212-
})
195+
receiver
196+
.doOnRequest { l ->
197+
if (first.compareAndSet(false, true) && !receiver.isTerminated()) {
198+
val requestFrame = Frame.Request.from(streamId, FrameType.REQUEST_STREAM, payload, l)
199+
sendProcessor.onNext(requestFrame)
200+
} else if (contains(streamId)) {
201+
sendProcessor.onNext(Frame.RequestN.from(streamId, l))
202+
}
203+
}
204+
.doOnError { t ->
205+
if (contains(streamId) && !receiver.isTerminated()) {
206+
sendProcessor.onNext(Frame.Error.from(streamId, t))
207+
}
208+
}
209+
.doOnCancel {
210+
if (contains(streamId) && !receiver.isTerminated()) {
211+
sendProcessor.onNext(Frame.Cancel.from(streamId))
212+
}
213+
}
214+
.doFinally { removeReceiver(streamId) }
215+
})
213216
}
214217

215218
private fun handleRequestResponse(payload: Payload): Single<Payload> {
@@ -228,8 +231,8 @@ internal class RSocketClient @JvmOverloads constructor(
228231
sendProcessor.onNext(requestFrame)
229232

230233
receiver
231-
.doOnError{ t -> sendProcessor.onNext(Frame.Error.from(streamId, t)) }
232-
.doOnCancel{ sendProcessor.onNext(Frame.Cancel.from(streamId)) }
234+
.doOnError { t -> sendProcessor.onNext(Frame.Error.from(streamId, t)) }
235+
.doOnCancel { sendProcessor.onNext(Frame.Cancel.from(streamId)) }
233236
.doFinally { removeReceiver(streamId) }
234237
.firstOrError()
235238
}))
@@ -302,11 +305,11 @@ internal class RSocketClient @JvmOverloads constructor(
302305
requestFrames
303306
.doOnNext { sendProcessor.onNext(it) }
304307
.subscribe(
305-
{},
306-
{ t ->
307-
errorConsumer(t)
308-
receiver.onError(CancellationException("Disposed"))
309-
})
308+
{},
309+
{ t ->
310+
errorConsumer(t)
311+
receiver.onError(CancellationException("Disposed"))
312+
})
310313
} else {
311314
sendOneFrame(Frame.RequestN.from(streamId, l))
312315
}
@@ -330,23 +333,15 @@ internal class RSocketClient @JvmOverloads constructor(
330333
}
331334

332335
private fun cleanup() {
336+
errorSignal = CLOSED_CHANNEL_EXCEPTION
333337

334-
var subscribers: Collection<Subscriber<Payload>>
335-
var publishers: Collection<LimitableRequestPublisher<*>>
336-
val (subs, pubs) = synchronized(this) {
337-
338-
subscribers = receivers.values
339-
publishers = senders.values
338+
synchronized(this) {
339+
receivers.values.forEach { cleanUpSubscriber(it) }
340+
senders.values.forEach { cleanUpLimitableRequestPublisher(it) }
340341

341-
senders.clear()
342342
receivers.clear()
343-
344-
Pair(subscribers,publishers)
343+
senders.clear()
345344
}
346-
347-
subs.forEach { cleanUpSubscriber(it) }
348-
pubs.forEach { cleanUpLimitableRequestPublisher(it) }
349-
350345
keepAliveSendSub?.dispose()
351346
}
352347

@@ -485,5 +480,6 @@ internal class RSocketClient @JvmOverloads constructor(
485480
private val CLOSED_CHANNEL_EXCEPTION = noStacktrace(ClosedChannelException())
486481
private val DEFAULT_STREAM_WINDOW = 128
487482
}
483+
488484
private fun <T> UnicastProcessor<T>.isTerminated(): Boolean = hasComplete() || hasThrowable()
489485
}

rsocket-core/src/test/java/io/rsocket/android/RSocketClientTest.kt

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
package io.rsocket.android
1919

2020
import io.reactivex.Completable
21+
import io.reactivex.Flowable
22+
import io.reactivex.Single
23+
import io.reactivex.internal.observers.BlockingMultiObserver
2124
import io.reactivex.processors.PublishProcessor
2225
import io.reactivex.subscribers.TestSubscriber
2326
import io.rsocket.android.exceptions.ApplicationException
@@ -33,6 +36,7 @@ import org.junit.Test
3336
import org.junit.rules.ExternalResource
3437
import org.junit.runner.Description
3538
import org.junit.runners.model.Statement
39+
import java.nio.channels.ClosedChannelException
3640
import java.util.concurrent.TimeUnit
3741

3842
class RSocketClientTest {
@@ -171,6 +175,75 @@ class RSocketClientTest {
171175
assertThat("Stream ID reused.", streamId2, not(equalTo(streamId1)))
172176
}
173177

178+
@Test(timeout = 3_000)
179+
fun requestErrorOnConnectionClose() {
180+
Completable.timer(100, TimeUnit.MILLISECONDS)
181+
.andThen { rule.conn.close() }.subscribe()
182+
val requestStream = rule.client.requestStream(PayloadImpl("test"))
183+
val subs = TestSubscriber.create<Payload>()
184+
requestStream.blockingSubscribe(subs)
185+
subs.assertNoValues()
186+
subs.assertError { it is ClosedChannelException }
187+
}
188+
189+
@Test(timeout = 5_000)
190+
fun streamErrorAfterConnectionClose() {
191+
assertFlowableError { it.requestStream(PayloadImpl("test")) }
192+
}
193+
194+
@Test(timeout = 5_000)
195+
fun reqStreamErrorAfterConnectionClose() {
196+
assertFlowableError { it.requestStream(PayloadImpl("test")) }
197+
}
198+
199+
@Test(timeout = 5_000)
200+
fun reqChannelErrorAfterConnectionClose() {
201+
assertFlowableError { it.requestChannel(Flowable.just(PayloadImpl("test"))) }
202+
}
203+
204+
@Test(timeout = 5_000)
205+
fun reqResponseErrorAfterConnectionClose() {
206+
assertSingleError { it.requestResponse(PayloadImpl("test")) }
207+
}
208+
209+
@Test(timeout = 5_000)
210+
fun fnfErrorAfterConnectionClose() {
211+
assertCompletableError { it.fireAndForget(PayloadImpl("test")) }
212+
}
213+
214+
@Test(timeout = 5_000)
215+
fun metadataPushAfterConnectionClose() {
216+
assertCompletableError { it.metadataPush(PayloadImpl("test")) }
217+
}
218+
219+
private fun assertFlowableError(f: (RSocket) -> Flowable<Payload>) {
220+
rule.conn.close().subscribe()
221+
val subs = TestSubscriber.create<Payload>()
222+
f(rule.client).delaySubscription(100, TimeUnit.MILLISECONDS).blockingSubscribe(subs)
223+
subs.assertNoValues()
224+
subs.assertError { it is ClosedChannelException }
225+
}
226+
227+
private fun assertCompletableError(f: (RSocket) -> Completable) {
228+
rule.conn.close().subscribe()
229+
val requestStream = Completable
230+
.timer(100, TimeUnit.MILLISECONDS)
231+
.andThen(f(rule.client))
232+
val err = requestStream.blockingGet()
233+
assertThat("error is not ClosedChannelException",
234+
err is ClosedChannelException)
235+
}
236+
237+
private fun assertSingleError(f: (RSocket) -> Single<Payload>) {
238+
rule.conn.close().subscribe()
239+
val response = f(rule.client).delaySubscription(100, TimeUnit.MILLISECONDS)
240+
val subs = BlockingMultiObserver<Payload>()
241+
response.subscribe(subs)
242+
val err = subs.blockingGetError()
243+
assertThat("error is not ClosedChannelException", err is ClosedChannelException)
244+
}
245+
246+
174247
class ClientSocketRule : ExternalResource() {
175248
lateinit var sender: PublishProcessor<Frame>
176249
lateinit var receiver: PublishProcessor<Frame>

0 commit comments

Comments
 (0)