@@ -22,8 +22,6 @@ import io.reactivex.Completable
2222import io.reactivex.Flowable
2323import io.reactivex.Single
2424import io.reactivex.disposables.Disposable
25- import io.reactivex.functions.Action
26- import io.reactivex.functions.Consumer
2725import io.reactivex.processors.FlowableProcessor
2826import io.reactivex.processors.PublishProcessor
2927import io.reactivex.processors.UnicastProcessor
@@ -63,6 +61,8 @@ internal class RSocketClient @JvmOverloads constructor(
6361 private val senders: IntObjectHashMap <LimitableRequestPublisher <* >> = IntObjectHashMap (256 , 0.9f )
6462 private val receivers: IntObjectHashMap <Subscriber <Payload >> = IntObjectHashMap (256 , 0.9f )
6563 private val missedAckCounter: AtomicInteger = AtomicInteger ()
64+ @Volatile
65+ private var errorSignal: Throwable ? = null
6666
6767 private val sendProcessor: FlowableProcessor <Frame > = PublishProcessor
6868 .create<Frame >()
@@ -99,70 +99,79 @@ internal class RSocketClient @JvmOverloads constructor(
9999 connection
100100 .receive()
101101 .doOnSubscribe { started.onComplete() }
102- .subscribe({ handleIncomingFrames(it) },errorConsumer)
102+ .subscribe({ handleIncomingFrames(it) }, errorConsumer)
103103 }
104104
105105 private fun handleSendProcessorError (t : Throwable ) {
106- val (receivers, senders) = synchronized(this ) {
107- Pair (receivers.values, senders.values)
108- }
109- for (subscriber in receivers) {
110- try {
111- subscriber.onError(t)
112- } catch (e: Throwable ) {
113- errorConsumer(e)
114- }
115- }
116-
117- for (p in senders) {
118- p.cancel()
106+ synchronized(this ) {
107+ receivers.values.forEach { it.onError(t) }
108+ senders.values.forEach { it.cancel() }
119109 }
120110 }
121111
122112 private fun sendKeepAlive (ackTimeoutMs : Long , missedAcks : Int ): Completable {
123113 return Completable .fromRunnable {
124- val now = System .currentTimeMillis()
125- if (now - timeLastTickSentMs > ackTimeoutMs) {
126- val count = missedAckCounter.incrementAndGet()
127- if (count >= missedAcks) {
128- val message = String .format(
129- " Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms" ,
130- count, missedAcks, ackTimeoutMs)
131- throw ConnectionException (message)
132- }
133- }
134-
135- sendProcessor.onNext(Frame .Keepalive .from(Unpooled .EMPTY_BUFFER , true ))
114+ val now = System .currentTimeMillis()
115+ if (now - timeLastTickSentMs > ackTimeoutMs) {
116+ val count = missedAckCounter.incrementAndGet()
117+ if (count >= missedAcks) {
118+ val message = String .format(
119+ " Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms" ,
120+ count, missedAcks, ackTimeoutMs)
121+ throw ConnectionException (message)
136122 }
123+ }
124+
125+ sendProcessor.onNext(Frame .Keepalive .from(Unpooled .EMPTY_BUFFER , true ))
126+ }
137127 }
138128
139129 override fun fireAndForget (payload : Payload ): Completable {
140130 val defer = Completable .fromRunnable {
141- val streamId = streamIdSupplier.nextStreamId()
142- val requestFrame = Frame .Request .from(
143- streamId,
144- FrameType .FIRE_AND_FORGET ,
145- payload,
146- 1 )
147- sendProcessor.onNext(requestFrame)
148- }
131+ val streamId = streamIdSupplier.nextStreamId()
132+ val requestFrame = Frame .Request .from(
133+ streamId,
134+ FrameType .FIRE_AND_FORGET ,
135+ payload,
136+ 1 )
137+ sendProcessor.onNext(requestFrame)
138+ }
149139
150- return completeOnStart.andThen(defer)
140+ return errorSignal
141+ ?.let { Completable .error(it) }
142+ ? : completeOnStart.andThen(defer)
151143 }
152144
153145 override fun requestResponse (payload : Payload ): Single <Payload > =
154- handleRequestResponse(payload)
146+ errorSignal
147+ ?.let { Single .error<Payload >(it) }
148+ ? : handleRequestResponse(payload)
155149
156150 override fun requestStream (payload : Payload ): Flowable <Payload > =
157- handleRequestStream(payload).rebatchRequests(streamDemandLimit)
151+ errorSignal
152+ ?.let { Flowable .error<Payload >(it) }
153+ ? : handleRequestStream(payload).rebatchRequests(streamDemandLimit)
158154
159155 override fun requestChannel (payloads : Publisher <Payload >): Flowable <Payload > =
160- handleChannel(
156+ errorSignal
157+ ?.let { Flowable .error<Payload >(it) }
158+ ? : handleChannel(
161159 Flowable .fromPublisher(payloads).rebatchRequests(streamDemandLimit),
162160 FrameType .REQUEST_CHANNEL
163161 ).rebatchRequests(streamDemandLimit)
164162
165- override fun metadataPush (payload : Payload ): Completable {
163+ override fun metadataPush (payload : Payload ): Completable =
164+ errorSignal
165+ ?.let { Completable .error(it) }
166+ ? : handleMetadataPush(payload)
167+
168+ override fun availability (): Double = connection.availability()
169+
170+ override fun close (): Completable = connection.close()
171+
172+ override fun onClose (): Completable = connection.onClose()
173+
174+ private fun handleMetadataPush (payload : Payload ): Completable {
166175 val requestFrame = Frame .Request .from(
167176 0 ,
168177 FrameType .METADATA_PUSH ,
@@ -172,44 +181,38 @@ internal class RSocketClient @JvmOverloads constructor(
172181 return Completable .complete()
173182 }
174183
175- override fun availability (): Double = connection.availability()
176-
177- override fun close (): Completable = connection.close()
178-
179- override fun onClose (): Completable = connection.onClose()
180-
181184 private fun handleRequestStream (payload : Payload ): Flowable <Payload > {
182185 return completeOnStart.andThen(
183186 Flowable .defer {
184- val streamId = streamIdSupplier.nextStreamId()
185- val receiver = UnicastProcessor .create<Payload >()
186- synchronized(this ) {
187- receivers.put(streamId, receiver)
188- }
187+ val streamId = streamIdSupplier.nextStreamId()
188+ val receiver = UnicastProcessor .create<Payload >()
189+ synchronized(this ) {
190+ receivers.put(streamId, receiver)
191+ }
189192
190- val first = AtomicBoolean (false )
193+ val first = AtomicBoolean (false )
191194
192- receiver
193- .doOnRequest{ l ->
194- if (first.compareAndSet(false , true ) && ! receiver.isTerminated()) {
195- val requestFrame = Frame .Request .from(streamId, FrameType .REQUEST_STREAM , payload, l)
196- sendProcessor.onNext(requestFrame)
197- } else if (contains(streamId)) {
198- sendProcessor.onNext(Frame .RequestN .from(streamId, l))
199- }
200- }
201- .doOnError { t ->
202- if (contains(streamId) && ! receiver.isTerminated()) {
203- sendProcessor.onNext(Frame .Error .from(streamId, t))
204- }
205- }
206- .doOnCancel {
207- if (contains(streamId) && ! receiver.isTerminated()) {
208- sendProcessor.onNext(Frame .Cancel .from(streamId))
209- }
210- }
211- .doFinally { removeReceiver(streamId) }
212- })
195+ receiver
196+ .doOnRequest { l ->
197+ if (first.compareAndSet(false , true ) && ! receiver.isTerminated()) {
198+ val requestFrame = Frame .Request .from(streamId, FrameType .REQUEST_STREAM , payload, l)
199+ sendProcessor.onNext(requestFrame)
200+ } else if (contains(streamId)) {
201+ sendProcessor.onNext(Frame .RequestN .from(streamId, l))
202+ }
203+ }
204+ .doOnError { t ->
205+ if (contains(streamId) && ! receiver.isTerminated()) {
206+ sendProcessor.onNext(Frame .Error .from(streamId, t))
207+ }
208+ }
209+ .doOnCancel {
210+ if (contains(streamId) && ! receiver.isTerminated()) {
211+ sendProcessor.onNext(Frame .Cancel .from(streamId))
212+ }
213+ }
214+ .doFinally { removeReceiver(streamId) }
215+ })
213216 }
214217
215218 private fun handleRequestResponse (payload : Payload ): Single <Payload > {
@@ -228,8 +231,8 @@ internal class RSocketClient @JvmOverloads constructor(
228231 sendProcessor.onNext(requestFrame)
229232
230233 receiver
231- .doOnError{ t -> sendProcessor.onNext(Frame .Error .from(streamId, t)) }
232- .doOnCancel{ sendProcessor.onNext(Frame .Cancel .from(streamId)) }
234+ .doOnError { t -> sendProcessor.onNext(Frame .Error .from(streamId, t)) }
235+ .doOnCancel { sendProcessor.onNext(Frame .Cancel .from(streamId)) }
233236 .doFinally { removeReceiver(streamId) }
234237 .firstOrError()
235238 }))
@@ -302,11 +305,11 @@ internal class RSocketClient @JvmOverloads constructor(
302305 requestFrames
303306 .doOnNext { sendProcessor.onNext(it) }
304307 .subscribe(
305- {},
306- { t ->
307- errorConsumer(t)
308- receiver.onError(CancellationException (" Disposed" ))
309- })
308+ {},
309+ { t ->
310+ errorConsumer(t)
311+ receiver.onError(CancellationException (" Disposed" ))
312+ })
310313 } else {
311314 sendOneFrame(Frame .RequestN .from(streamId, l))
312315 }
@@ -330,23 +333,15 @@ internal class RSocketClient @JvmOverloads constructor(
330333 }
331334
332335 private fun cleanup () {
336+ errorSignal = CLOSED_CHANNEL_EXCEPTION
333337
334- var subscribers: Collection <Subscriber <Payload >>
335- var publishers: Collection <LimitableRequestPublisher <* >>
336- val (subs, pubs) = synchronized(this ) {
337-
338- subscribers = receivers.values
339- publishers = senders.values
338+ synchronized(this ) {
339+ receivers.values.forEach { cleanUpSubscriber(it) }
340+ senders.values.forEach { cleanUpLimitableRequestPublisher(it) }
340341
341- senders.clear()
342342 receivers.clear()
343-
344- Pair (subscribers,publishers)
343+ senders.clear()
345344 }
346-
347- subs.forEach { cleanUpSubscriber(it) }
348- pubs.forEach { cleanUpLimitableRequestPublisher(it) }
349-
350345 keepAliveSendSub?.dispose()
351346 }
352347
@@ -485,5 +480,6 @@ internal class RSocketClient @JvmOverloads constructor(
485480 private val CLOSED_CHANNEL_EXCEPTION = noStacktrace(ClosedChannelException ())
486481 private val DEFAULT_STREAM_WINDOW = 128
487482 }
483+
488484 private fun <T > UnicastProcessor<T>.isTerminated (): Boolean = hasComplete() || hasThrowable()
489485}
0 commit comments