@@ -15,6 +15,7 @@ pub use juniper_subscriptions::ws_util::{
15
15
} ;
16
16
use juniper_subscriptions:: Coordinator ;
17
17
use serde:: Serialize ;
18
+ use std:: ops:: Deref ;
18
19
use std:: {
19
20
collections:: HashMap ,
20
21
error:: Error as StdError ,
@@ -33,7 +34,7 @@ fn start<Query, Mutation, Subscription, Context, S, SubHandler, T, E>(
33
34
where
34
35
T : Stream < Item = Result < Bytes , PayloadError > > + ' static ,
35
36
S : ScalarValue + Send + Sync + ' static ,
36
- Context : Clone + Send + Sync + ' static + std:: marker:: Unpin ,
37
+ Context : Send + Sync + ' static + std:: marker:: Unpin ,
37
38
Query : juniper:: GraphQLTypeAsync < S , Context = Context > + Send + Sync + ' static ,
38
39
Query :: TypeInfo : Send + Sync ,
39
40
Mutation : juniper:: GraphQLTypeAsync < S , Context = Context > + Send + Sync + ' static ,
@@ -65,7 +66,7 @@ pub async unsafe fn graphql_subscriptions<
65
66
) -> Result < HttpResponse , Error >
66
67
where
67
68
S : ScalarValue + Send + Sync + ' static ,
68
- Context : Clone + Send + Sync + ' static + std:: marker:: Unpin ,
69
+ Context : Send + Sync + ' static + std:: marker:: Unpin ,
69
70
Query : juniper:: GraphQLTypeAsync < S , Context = Context > + Send + Sync + ' static ,
70
71
Query :: TypeInfo : Send + Sync ,
71
72
Mutation : juniper:: GraphQLTypeAsync < S , Context = Context > + Send + Sync + ' static ,
78
79
start (
79
80
GraphQLWSSession {
80
81
coordinator : coordinator. into_inner ( ) ,
81
- graphql_context : context,
82
+ graphql_context : Arc :: new ( context) ,
82
83
map_req_id_to_spawn_handle : HashMap :: new ( ) ,
83
84
has_started : Arc :: new ( AtomicBool :: new ( false ) ) ,
84
85
handler,
92
93
struct GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler , E >
93
94
where
94
95
S : ScalarValue + Send + Sync + ' static ,
95
- Context : Clone + Send + Sync + ' static + std:: marker:: Unpin ,
96
+ Context : Send + Sync + ' static + std:: marker:: Unpin ,
96
97
Query : juniper:: GraphQLTypeAsync < S , Context = Context > + Send + Sync + ' static ,
97
98
Query :: TypeInfo : Send + Sync ,
98
99
Mutation : juniper:: GraphQLTypeAsync < S , Context = Context > + Send + Sync + ' static ,
@@ -104,7 +105,7 @@ where
104
105
{
105
106
pub map_req_id_to_spawn_handle : HashMap < String , SpawnHandle > ,
106
107
pub has_started : Arc < AtomicBool > ,
107
- pub graphql_context : Context ,
108
+ pub graphql_context : Arc < Context > ,
108
109
pub coordinator : Arc < Coordinator < ' static , Query , Mutation , Subscription , Context , S > > ,
109
110
pub handler : Option < SubHandler > ,
110
111
error_handler : std:: marker:: PhantomData < E > ,
@@ -114,7 +115,7 @@ impl<Query, Mutation, Subscription, Context, S, SubHandler, E> Actor
114
115
for GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler , E >
115
116
where
116
117
S : ScalarValue + Send + Sync + ' static ,
117
- Context : Clone + Send + Sync + ' static + std:: marker:: Unpin ,
118
+ Context : Send + Sync + ' static + std:: marker:: Unpin ,
118
119
Query : juniper:: GraphQLTypeAsync < S , Context = Context > + Send + Sync + ' static ,
119
120
Query :: TypeInfo : Send + Sync ,
120
121
Mutation : juniper:: GraphQLTypeAsync < S , Context = Context > + Send + Sync + ' static ,
@@ -134,7 +135,7 @@ impl<Query, Mutation, Subscription, Context, S, SubHandler, E>
134
135
GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler , E >
135
136
where
136
137
S : ScalarValue + Send + Sync + ' static ,
137
- Context : Clone + Send + Sync + ' static + std:: marker:: Unpin ,
138
+ Context : Send + Sync + ' static + std:: marker:: Unpin ,
138
139
Query : juniper:: GraphQLTypeAsync < S , Context = Context > + Send + Sync + ' static ,
139
140
Query :: TypeInfo : Send + Sync ,
140
141
Mutation : juniper:: GraphQLTypeAsync < S , Context = Context > + Send + Sync + ' static ,
@@ -192,7 +193,7 @@ where
192
193
result : (
193
194
GraphQLRequest < S > ,
194
195
String ,
195
- Context ,
196
+ Arc < Context > ,
196
197
Arc < Coordinator < ' static , Query , Mutation , Subscription , Context , S > > ,
197
198
) ,
198
199
actor : & mut Self ,
@@ -205,7 +206,7 @@ where
205
206
206
207
async fn handle_subscription (
207
208
req : GraphQLRequest < S > ,
208
- graphql_context : Context ,
209
+ graphql_context : Arc < Context > ,
209
210
request_id : String ,
210
211
coord : Arc < Coordinator < ' static , Query , Mutation , Subscription , Context , S > > ,
211
212
ctx : * mut ws:: WebsocketContext < Self > ,
@@ -255,7 +256,7 @@ impl<Query, Mutation, Subscription, Context, S, SubHandler, E>
255
256
for GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler , E >
256
257
where
257
258
S : ScalarValue + Send + Sync + ' static ,
258
- Context : Clone + Send + Sync + ' static + std:: marker:: Unpin ,
259
+ Context : Send + Sync + ' static + std:: marker:: Unpin ,
259
260
Query : juniper:: GraphQLTypeAsync < S , Context = Context > + Send + Sync + ' static ,
260
261
Query :: TypeInfo : Send + Sync ,
261
262
Mutation : juniper:: GraphQLTypeAsync < S , Context = Context > + Send + Sync + ' static ,
@@ -289,7 +290,7 @@ where
289
290
if let Some ( handler) = & self . handler {
290
291
let state = SubscriptionState :: OnConnection (
291
292
request. payload ,
292
- & mut self . graphql_context ,
293
+ Arc :: get_mut ( & mut self . graphql_context ) . unwrap ( ) ,
293
294
) ;
294
295
let on_connect_result = handler. handle ( state) ;
295
296
if let Err ( _err) = on_connect_result {
@@ -312,7 +313,7 @@ where
312
313
}
313
314
GraphQLOverWebSocketMessage :: Start if has_started_value => {
314
315
let coordinator = self . coordinator . clone ( ) ;
315
- let mut context = self . graphql_context . clone ( ) ;
316
+
316
317
let payload = request
317
318
. graphql_payload :: < S > ( )
318
319
. expect ( "Could not deserialize payload" ) ;
@@ -323,9 +324,12 @@ where
323
324
payload. variables ,
324
325
) ;
325
326
if let Some ( handler) = & self . handler {
326
- let state = SubscriptionState :: OnOperation ( & mut context) ;
327
+ let state = SubscriptionState :: OnOperation (
328
+ self . graphql_context . deref ( ) ,
329
+ ) ;
327
330
handler. handle ( state) . unwrap ( ) ;
328
331
}
332
+ let context = self . graphql_context . clone ( ) ;
329
333
{
330
334
use std:: collections:: hash_map:: Entry ;
331
335
let req_id = request_id. clone ( ) ;
@@ -347,8 +351,8 @@ where
347
351
GraphQLOverWebSocketMessage :: Stop if has_started_value => {
348
352
let request_id = request. id . unwrap_or ( "1" . to_owned ( ) ) ;
349
353
if let Some ( handler) = & self . handler {
350
- let state =
351
- SubscriptionState :: OnOperationComplete ( & self . graphql_context ) ;
354
+ let context = self . graphql_context . deref ( ) ;
355
+ let state = SubscriptionState :: OnOperationComplete ( context ) ;
352
356
handler. handle ( state) . unwrap ( ) ;
353
357
}
354
358
match self . map_req_id_to_spawn_handle . remove ( & request_id) {
@@ -366,7 +370,8 @@ where
366
370
}
367
371
GraphQLOverWebSocketMessage :: ConnectionTerminate => {
368
372
if let Some ( handler) = & self . handler {
369
- let state = SubscriptionState :: OnDisconnect ( & self . graphql_context ) ;
373
+ let context = self . graphql_context . deref ( ) ;
374
+ let state = SubscriptionState :: OnDisconnect ( context) ;
370
375
handler. handle ( state) . unwrap ( ) ;
371
376
}
372
377
ctx. stop ( ) ;
@@ -376,7 +381,8 @@ where
376
381
}
377
382
ws:: Message :: Close ( _) => {
378
383
if let Some ( handler) = & self . handler {
379
- let state = SubscriptionState :: OnDisconnect ( & self . graphql_context ) ;
384
+ let context = self . graphql_context . deref ( ) ;
385
+ let state = SubscriptionState :: OnDisconnect ( context) ;
380
386
handler. handle ( state) . unwrap ( ) ;
381
387
}
382
388
ctx. stop ( ) ;
@@ -483,6 +489,9 @@ mod tests {
483
489
String :: from(
484
490
r#"{"id":"1","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"# ,
485
491
) ,
492
+ String :: from(
493
+ r#"{"id":"2","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"# ,
494
+ ) ,
486
495
String :: from( r#"{"id":"1","type":"stop"}"# ) ,
487
496
String :: from( r#"{"type":"connection_terminate"}"# ) ,
488
497
] ;
@@ -496,6 +505,9 @@ mod tests {
496
505
vec![ Some ( bytes:: Bytes :: from(
497
506
r#"{"type":"data","id":"1","payload":{"data":{"helloWorld":"Hello"}} }"# ,
498
507
) ) ] ,
508
+ vec![ Some ( bytes:: Bytes :: from(
509
+ r#"{"type":"data","id":"2","payload":{"data":{"helloWorld":"Hello"}} }"# ,
510
+ ) ) ] ,
499
511
vec![ Some ( bytes:: Bytes :: from(
500
512
r#"{"type":"complete","id":"1","payload":null}"# ,
501
513
) ) ] ,
0 commit comments