@@ -231,7 +231,7 @@ pub mod subscriptions {
231
231
} ;
232
232
use futures:: { Stream , StreamExt } ;
233
233
use juniper:: { http:: GraphQLRequest , InputValue , ScalarValue , SubscriptionCoordinator } ;
234
- use juniper_subscriptions:: { message_types:: * , Coordinator , SubscriptionLifecycleHandler } ;
234
+ use juniper_subscriptions:: { message_types:: * , MessageTypes , Coordinator , SubscriptionStateHandler , SubscriptionState } ;
235
235
use serde:: { Deserialize , Serialize } ;
236
236
use std:: {
237
237
collections:: HashMap ,
@@ -243,8 +243,8 @@ pub mod subscriptions {
243
243
} ;
244
244
use tokio:: time:: Duration ;
245
245
246
- fn start < Query , Mutation , Subscription , Context , S , SubHandler , T > (
247
- actor : GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler > ,
246
+ fn start < Query , Mutation , Subscription , Context , S , SubHandler , T , E > (
247
+ actor : GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler , E > ,
248
248
req : & HttpRequest ,
249
249
stream : T ,
250
250
) -> Result < HttpResponse , Error >
@@ -259,7 +259,8 @@ pub mod subscriptions {
259
259
Subscription :
260
260
juniper:: GraphQLSubscriptionType < S , Context = Context > + Send + Sync + ' static ,
261
261
Subscription :: TypeInfo : Send + Sync ,
262
- SubHandler : SubscriptionLifecycleHandler < Context > + ' static + std:: marker:: Unpin ,
262
+ SubHandler : SubscriptionStateHandler < Context , E > + ' static + std:: marker:: Unpin ,
263
+ E : ' static + std:: error:: Error + std:: marker:: Unpin
263
264
{
264
265
let mut res = handshake_with_protocols ( req, & [ "graphql-ws" ] ) ?;
265
266
Ok ( res. streaming ( WebsocketContext :: create ( actor, stream) ) )
@@ -273,6 +274,7 @@ pub mod subscriptions {
273
274
Context ,
274
275
S ,
275
276
SubHandler ,
277
+ E
276
278
> (
277
279
coordinator : web:: Data < Coordinator < ' static , Query , Mutation , Subscription , Context , S > > ,
278
280
context : Context ,
@@ -290,7 +292,8 @@ pub mod subscriptions {
290
292
Subscription :
291
293
juniper:: GraphQLSubscriptionType < S , Context = Context > + Send + Sync + ' static ,
292
294
Subscription :: TypeInfo : Send + Sync ,
293
- SubHandler : SubscriptionLifecycleHandler < Context > + ' static + std:: marker:: Unpin ,
295
+ SubHandler : SubscriptionStateHandler < Context , E > + ' static + std:: marker:: Unpin ,
296
+ E : ' static + std:: error:: Error + std:: marker:: Unpin
294
297
{
295
298
start (
296
299
GraphQLWSSession {
@@ -299,13 +302,14 @@ pub mod subscriptions {
299
302
map_req_id_to_spawn_handle : HashMap :: new ( ) ,
300
303
has_started : Arc :: new ( AtomicBool :: new ( false ) ) ,
301
304
handler,
305
+ error_handler : std:: marker:: PhantomData
302
306
} ,
303
307
& req,
304
308
stream,
305
309
)
306
310
}
307
311
308
- struct GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler >
312
+ struct GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler , E >
309
313
where
310
314
S : ScalarValue + Send + Sync + ' static ,
311
315
Context : Clone + Send + Sync + ' static + std:: marker:: Unpin ,
@@ -316,17 +320,19 @@ pub mod subscriptions {
316
320
Subscription :
317
321
juniper:: GraphQLSubscriptionType < S , Context = Context > + Send + Sync + ' static ,
318
322
Subscription :: TypeInfo : Send + Sync ,
319
- SubHandler : SubscriptionLifecycleHandler < Context > + ' static + std:: marker:: Unpin ,
323
+ SubHandler : SubscriptionStateHandler < Context , E > + ' static + std:: marker:: Unpin ,
324
+ E : ' static + std:: error:: Error + std:: marker:: Unpin
320
325
{
321
326
pub map_req_id_to_spawn_handle : HashMap < String , SpawnHandle > ,
322
327
pub has_started : Arc < AtomicBool > ,
323
328
pub graphql_context : Context ,
324
329
pub coordinator : Arc < Coordinator < ' static , Query , Mutation , Subscription , Context , S > > ,
325
330
pub handler : Option < SubHandler > ,
331
+ error_handler : std:: marker:: PhantomData < E >
326
332
}
327
333
328
- impl < Query , Mutation , Subscription , Context , S , SubHandler > Actor
329
- for GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler >
334
+ impl < Query , Mutation , Subscription , Context , S , SubHandler , E > Actor
335
+ for GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler , E >
330
336
where
331
337
S : ScalarValue + Send + Sync + ' static ,
332
338
Context : Clone + Send + Sync + ' static + std:: marker:: Unpin ,
@@ -337,16 +343,17 @@ pub mod subscriptions {
337
343
Subscription :
338
344
juniper:: GraphQLSubscriptionType < S , Context = Context > + Send + Sync + ' static ,
339
345
Subscription :: TypeInfo : Send + Sync ,
340
- SubHandler : SubscriptionLifecycleHandler < Context > + ' static + std:: marker:: Unpin ,
346
+ SubHandler : SubscriptionStateHandler < Context , E > + ' static + std:: marker:: Unpin ,
347
+ E : ' static + std:: error:: Error + std:: marker:: Unpin
341
348
{
342
349
type Context = ws:: WebsocketContext <
343
- GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler > ,
350
+ GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler , E > ,
344
351
> ;
345
352
}
346
353
347
354
#[ allow( dead_code) ]
348
- impl < Query , Mutation , Subscription , Context , S , SubHandler >
349
- GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler >
355
+ impl < Query , Mutation , Subscription , Context , S , SubHandler , E >
356
+ GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler , E >
350
357
where
351
358
S : ScalarValue + Send + Sync + ' static ,
352
359
Context : Clone + Send + Sync + ' static + std:: marker:: Unpin ,
@@ -357,7 +364,8 @@ pub mod subscriptions {
357
364
Subscription :
358
365
juniper:: GraphQLSubscriptionType < S , Context = Context > + Send + Sync + ' static ,
359
366
Subscription :: TypeInfo : Send + Sync ,
360
- SubHandler : SubscriptionLifecycleHandler < Context > + ' static + std:: marker:: Unpin ,
367
+ SubHandler : SubscriptionStateHandler < Context , E > + ' static + std:: marker:: Unpin ,
368
+ E : ' static + std:: error:: Error + std:: marker:: Unpin
361
369
{
362
370
fn gql_connection_ack ( ) -> String {
363
371
format ! ( r#"{{"type":"{}", "payload": null }}"# , GQL_CONNECTION_ACK )
@@ -454,9 +462,9 @@ pub mod subscriptions {
454
462
}
455
463
}
456
464
457
- impl < Query , Mutation , Subscription , Context , S , SubHandler >
465
+ impl < Query , Mutation , Subscription , Context , S , SubHandler , E >
458
466
StreamHandler < Result < ws:: Message , ws:: ProtocolError > >
459
- for GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler >
467
+ for GraphQLWSSession < Query , Mutation , Subscription , Context , S , SubHandler , E >
460
468
where
461
469
S : ScalarValue + Send + Sync + ' static ,
462
470
Context : Clone + Send + Sync + ' static + std:: marker:: Unpin ,
@@ -467,7 +475,8 @@ pub mod subscriptions {
467
475
Subscription :
468
476
juniper:: GraphQLSubscriptionType < S , Context = Context > + Send + Sync + ' static ,
469
477
Subscription :: TypeInfo : Send + Sync ,
470
- SubHandler : SubscriptionLifecycleHandler < Context > + ' static + std:: marker:: Unpin ,
478
+ SubHandler : SubscriptionStateHandler < Context , E > + ' static + std:: marker:: Unpin ,
479
+ E : ' static + std:: error:: Error + std:: marker:: Unpin
471
480
{
472
481
fn handle ( & mut self , msg : Result < ws:: Message , ws:: ProtocolError > , ctx : & mut Self :: Context ) {
473
482
let msg = match msg {
@@ -482,12 +491,18 @@ pub mod subscriptions {
482
491
match msg {
483
492
ws:: Message :: Text ( text) => {
484
493
let m = text. trim ( ) ;
485
- let request: WsPayload < S > = serde_json:: from_str ( m) . expect ( "Invalid WsPayload" ) ;
486
- match request. type_name . as_str ( ) {
487
- GQL_CONNECTION_INIT => {
494
+ let request: WsPayload < S > = match serde_json:: from_str ( m) {
495
+ Ok ( payload) => payload,
496
+ Err ( _) => { return ; }
497
+ } ;
498
+ match request. type_ {
499
+ MessageTypes :: GqlConnectionInit => {
488
500
if let Some ( handler) = & self . handler {
489
- let on_connect_result =
490
- handler. on_connect ( m, & mut self . graphql_context ) ;
501
+ let state = SubscriptionState :: OnConnection (
502
+ Some ( String :: from ( m) ) ,
503
+ & mut self . graphql_context
504
+ ) ;
505
+ let on_connect_result = handler. handle ( state) ;
491
506
if let Err ( _err) = on_connect_result {
492
507
ctx. text ( Self :: gql_connection_error ( ) ) ;
493
508
ctx. stop ( ) ;
@@ -505,8 +520,8 @@ pub mod subscriptions {
505
520
ctx. text ( Self :: gql_connection_ka ( ) ) ;
506
521
}
507
522
} ) ;
508
- }
509
- GQL_START if has_started_value => {
523
+ } ,
524
+ MessageTypes :: GqlStart if has_started_value => {
510
525
let coordinator = self . coordinator . clone ( ) ;
511
526
let mut context = self . graphql_context . clone ( ) ;
512
527
let payload = request. payload . expect ( "Could not deserialize payload" ) ;
@@ -517,7 +532,8 @@ pub mod subscriptions {
517
532
payload. variables ,
518
533
) ;
519
534
if let Some ( handler) = & self . handler {
520
- handler. on_operation ( & mut context) ;
535
+ let state = SubscriptionState :: OnOperation ( & mut context) ;
536
+ handler. handle ( state) . unwrap ( ) ;
521
537
}
522
538
{
523
539
use std:: collections:: hash_map:: Entry ;
@@ -537,10 +553,13 @@ pub mod subscriptions {
537
553
} ;
538
554
}
539
555
}
540
- GQL_STOP if has_started_value => {
556
+ MessageTypes :: GqlStop if has_started_value => {
541
557
let request_id = request. id . unwrap_or ( "1" . to_owned ( ) ) ;
542
558
if let Some ( handler) = & self . handler {
543
- handler. on_operation_complete ( & self . graphql_context ) ;
559
+ let state = SubscriptionState :: OnOperationComplete (
560
+ & self . graphql_context
561
+ ) ;
562
+ handler. handle ( state) . unwrap ( ) ;
544
563
}
545
564
match self . map_req_id_to_spawn_handle . remove ( & request_id) {
546
565
Some ( spawn_handle) => {
@@ -558,19 +577,21 @@ pub mod subscriptions {
558
577
// ))
559
578
}
560
579
}
561
- }
562
- GQL_CONNECTION_TERMINATE if has_started_value => {
580
+ } ,
581
+ MessageTypes :: GqlConnectionTerminate => {
563
582
if let Some ( handler) = & self . handler {
564
- handler. on_disconnect ( & self . graphql_context ) ;
583
+ let state = SubscriptionState :: OnDisconnect ( & self . graphql_context ) ;
584
+ handler. handle ( state) . unwrap ( ) ;
565
585
}
566
586
ctx. stop ( ) ;
567
- }
587
+ } ,
568
588
_ => { }
569
589
}
570
590
}
571
591
ws:: Message :: Close ( _) => {
572
592
if let Some ( handler) = & self . handler {
573
- handler. on_disconnect ( & self . graphql_context ) ;
593
+ let state = SubscriptionState :: OnDisconnect ( & self . graphql_context ) ;
594
+ handler. handle ( state) . unwrap ( ) ;
574
595
}
575
596
ctx. stop ( ) ;
576
597
}
@@ -610,7 +631,7 @@ pub mod subscriptions {
610
631
{
611
632
id : Option < String > ,
612
633
#[ serde( rename( deserialize = "type" ) ) ]
613
- type_name : String ,
634
+ type_ : MessageTypes ,
614
635
payload : Option < GraphQLPayload < S > > ,
615
636
}
616
637
@@ -867,7 +888,7 @@ mod tests {
867
888
use actix_web_actors:: ws:: { Frame , Message } ;
868
889
use futures:: { SinkExt , Stream } ;
869
890
use juniper:: { DefaultScalarValue , EmptyMutation , FieldError , RootNode } ;
870
- use juniper_subscriptions:: { Coordinator , EmptySubscriptionLifecycleHandler } ;
891
+ use juniper_subscriptions:: { Coordinator , EmptySubscriptionHandler } ;
871
892
use std:: { pin:: Pin , time:: Duration } ;
872
893
873
894
pub struct Query ;
@@ -933,7 +954,7 @@ mod tests {
933
954
context,
934
955
stream,
935
956
req,
936
- EmptySubscriptionLifecycleHandler :: new ( ) ,
957
+ Some ( EmptySubscriptionHandler :: default ( ) ) ,
937
958
)
938
959
}
939
960
. await
0 commit comments