Skip to content

Commit 0deb488

Browse files
committed
Initial implementation of requested changes
1 parent 4dd336e commit 0deb488

File tree

4 files changed

+128
-94
lines changed

4 files changed

+128
-94
lines changed

juniper_actix/examples/actix_subscriptions.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use juniper_actix::{
88
graphiql_handler as gqli_handler, graphql_handler, playground_handler as play_handler,
99
subscriptions::graphql_subscriptions as sub_handler,
1010
};
11-
use juniper_subscriptions::{Coordinator, EmptySubscriptionLifecycleHandler};
11+
use juniper_subscriptions::{Coordinator, EmptySubscriptionHandler};
1212
use std::{pin::Pin, time::Duration};
1313

1414
pub struct Query;
@@ -87,9 +87,7 @@ async fn graphql_subscriptions(
8787
req: HttpRequest,
8888
) -> Result<HttpResponse, Error> {
8989
let context = Database::new();
90-
let handler: Option<EmptySubscriptionLifecycleHandler> =
91-
Some(EmptySubscriptionLifecycleHandler {});
92-
unsafe { sub_handler(coordinator, context, stream, req, handler) }.await
90+
unsafe { sub_handler(coordinator, context, stream, req, Some(EmptySubscriptionHandler::default())) }.await
9391
}
9492

9593
#[actix_rt::main]

juniper_actix/src/lib.rs

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ pub mod subscriptions {
231231
};
232232
use futures::{Stream, StreamExt};
233233
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator};
234-
use juniper_subscriptions::{message_types::*, Coordinator, SubscriptionLifecycleHandler};
234+
use juniper_subscriptions::{message_types::*, MessageTypes, Coordinator, SubscriptionStateHandler, SubscriptionState};
235235
use serde::{Deserialize, Serialize};
236236
use std::{
237237
collections::HashMap,
@@ -243,8 +243,8 @@ pub mod subscriptions {
243243
};
244244
use tokio::time::Duration;
245245

246-
fn start<Query, Mutation, Subscription, Context, S, SubHandler, T>(
247-
actor: GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler>,
246+
fn start<Query, Mutation, Subscription, Context, S, SubHandler, T, E>(
247+
actor: GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler, E>,
248248
req: &HttpRequest,
249249
stream: T,
250250
) -> Result<HttpResponse, Error>
@@ -259,7 +259,8 @@ pub mod subscriptions {
259259
Subscription:
260260
juniper::GraphQLSubscriptionType<S, Context = Context> + Send + Sync + 'static,
261261
Subscription::TypeInfo: Send + Sync,
262-
SubHandler: SubscriptionLifecycleHandler<Context> + 'static + std::marker::Unpin,
262+
SubHandler: SubscriptionStateHandler<Context, E> + 'static + std::marker::Unpin,
263+
E: 'static + std::error::Error + std::marker::Unpin
263264
{
264265
let mut res = handshake_with_protocols(req, &["graphql-ws"])?;
265266
Ok(res.streaming(WebsocketContext::create(actor, stream)))
@@ -273,6 +274,7 @@ pub mod subscriptions {
273274
Context,
274275
S,
275276
SubHandler,
277+
E
276278
>(
277279
coordinator: web::Data<Coordinator<'static, Query, Mutation, Subscription, Context, S>>,
278280
context: Context,
@@ -290,7 +292,8 @@ pub mod subscriptions {
290292
Subscription:
291293
juniper::GraphQLSubscriptionType<S, Context = Context> + Send + Sync + 'static,
292294
Subscription::TypeInfo: Send + Sync,
293-
SubHandler: SubscriptionLifecycleHandler<Context> + 'static + std::marker::Unpin,
295+
SubHandler: SubscriptionStateHandler<Context, E> + 'static + std::marker::Unpin,
296+
E: 'static + std::error::Error + std::marker::Unpin
294297
{
295298
start(
296299
GraphQLWSSession {
@@ -299,13 +302,14 @@ pub mod subscriptions {
299302
map_req_id_to_spawn_handle: HashMap::new(),
300303
has_started: Arc::new(AtomicBool::new(false)),
301304
handler,
305+
error_handler: std::marker::PhantomData
302306
},
303307
&req,
304308
stream,
305309
)
306310
}
307311

308-
struct GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler>
312+
struct GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler, E>
309313
where
310314
S: ScalarValue + Send + Sync + 'static,
311315
Context: Clone + Send + Sync + 'static + std::marker::Unpin,
@@ -316,17 +320,19 @@ pub mod subscriptions {
316320
Subscription:
317321
juniper::GraphQLSubscriptionType<S, Context = Context> + Send + Sync + 'static,
318322
Subscription::TypeInfo: Send + Sync,
319-
SubHandler: SubscriptionLifecycleHandler<Context> + 'static + std::marker::Unpin,
323+
SubHandler: SubscriptionStateHandler<Context, E> + 'static + std::marker::Unpin,
324+
E: 'static + std::error::Error + std::marker::Unpin
320325
{
321326
pub map_req_id_to_spawn_handle: HashMap<String, SpawnHandle>,
322327
pub has_started: Arc<AtomicBool>,
323328
pub graphql_context: Context,
324329
pub coordinator: Arc<Coordinator<'static, Query, Mutation, Subscription, Context, S>>,
325330
pub handler: Option<SubHandler>,
331+
error_handler: std::marker::PhantomData<E>
326332
}
327333

328-
impl<Query, Mutation, Subscription, Context, S, SubHandler> Actor
329-
for GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler>
334+
impl<Query, Mutation, Subscription, Context, S, SubHandler, E> Actor
335+
for GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler, E>
330336
where
331337
S: ScalarValue + Send + Sync + 'static,
332338
Context: Clone + Send + Sync + 'static + std::marker::Unpin,
@@ -337,16 +343,17 @@ pub mod subscriptions {
337343
Subscription:
338344
juniper::GraphQLSubscriptionType<S, Context = Context> + Send + Sync + 'static,
339345
Subscription::TypeInfo: Send + Sync,
340-
SubHandler: SubscriptionLifecycleHandler<Context> + 'static + std::marker::Unpin,
346+
SubHandler: SubscriptionStateHandler<Context, E> + 'static + std::marker::Unpin,
347+
E: 'static + std::error::Error + std::marker::Unpin
341348
{
342349
type Context = ws::WebsocketContext<
343-
GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler>,
350+
GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler, E>,
344351
>;
345352
}
346353

347354
#[allow(dead_code)]
348-
impl<Query, Mutation, Subscription, Context, S, SubHandler>
349-
GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler>
355+
impl<Query, Mutation, Subscription, Context, S, SubHandler, E>
356+
GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler, E>
350357
where
351358
S: ScalarValue + Send + Sync + 'static,
352359
Context: Clone + Send + Sync + 'static + std::marker::Unpin,
@@ -357,7 +364,8 @@ pub mod subscriptions {
357364
Subscription:
358365
juniper::GraphQLSubscriptionType<S, Context = Context> + Send + Sync + 'static,
359366
Subscription::TypeInfo: Send + Sync,
360-
SubHandler: SubscriptionLifecycleHandler<Context> + 'static + std::marker::Unpin,
367+
SubHandler: SubscriptionStateHandler<Context, E> + 'static + std::marker::Unpin,
368+
E: 'static + std::error::Error + std::marker::Unpin
361369
{
362370
fn gql_connection_ack() -> String {
363371
format!(r#"{{"type":"{}", "payload": null }}"#, GQL_CONNECTION_ACK)
@@ -454,9 +462,9 @@ pub mod subscriptions {
454462
}
455463
}
456464

457-
impl<Query, Mutation, Subscription, Context, S, SubHandler>
465+
impl<Query, Mutation, Subscription, Context, S, SubHandler, E>
458466
StreamHandler<Result<ws::Message, ws::ProtocolError>>
459-
for GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler>
467+
for GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler, E>
460468
where
461469
S: ScalarValue + Send + Sync + 'static,
462470
Context: Clone + Send + Sync + 'static + std::marker::Unpin,
@@ -467,7 +475,8 @@ pub mod subscriptions {
467475
Subscription:
468476
juniper::GraphQLSubscriptionType<S, Context = Context> + Send + Sync + 'static,
469477
Subscription::TypeInfo: Send + Sync,
470-
SubHandler: SubscriptionLifecycleHandler<Context> + 'static + std::marker::Unpin,
478+
SubHandler: SubscriptionStateHandler<Context, E> + 'static + std::marker::Unpin,
479+
E: 'static + std::error::Error + std::marker::Unpin
471480
{
472481
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
473482
let msg = match msg {
@@ -482,12 +491,18 @@ pub mod subscriptions {
482491
match msg {
483492
ws::Message::Text(text) => {
484493
let m = text.trim();
485-
let request: WsPayload<S> = serde_json::from_str(m).expect("Invalid WsPayload");
486-
match request.type_name.as_str() {
487-
GQL_CONNECTION_INIT => {
494+
let request: WsPayload<S> = match serde_json::from_str(m) {
495+
Ok(payload) => payload,
496+
Err(_) => { return; }
497+
};
498+
match request.type_ {
499+
MessageTypes::GqlConnectionInit => {
488500
if let Some(handler) = &self.handler {
489-
let on_connect_result =
490-
handler.on_connect(m, &mut self.graphql_context);
501+
let state = SubscriptionState::OnConnection(
502+
Some(String::from(m)),
503+
&mut self.graphql_context
504+
);
505+
let on_connect_result = handler.handle(state);
491506
if let Err(_err) = on_connect_result {
492507
ctx.text(Self::gql_connection_error());
493508
ctx.stop();
@@ -505,8 +520,8 @@ pub mod subscriptions {
505520
ctx.text(Self::gql_connection_ka());
506521
}
507522
});
508-
}
509-
GQL_START if has_started_value => {
523+
},
524+
MessageTypes::GqlStart if has_started_value => {
510525
let coordinator = self.coordinator.clone();
511526
let mut context = self.graphql_context.clone();
512527
let payload = request.payload.expect("Could not deserialize payload");
@@ -517,7 +532,8 @@ pub mod subscriptions {
517532
payload.variables,
518533
);
519534
if let Some(handler) = &self.handler {
520-
handler.on_operation(&mut context);
535+
let state = SubscriptionState::OnOperation(&mut context);
536+
handler.handle(state).unwrap();
521537
}
522538
{
523539
use std::collections::hash_map::Entry;
@@ -537,10 +553,13 @@ pub mod subscriptions {
537553
};
538554
}
539555
}
540-
GQL_STOP if has_started_value => {
556+
MessageTypes::GqlStop if has_started_value => {
541557
let request_id = request.id.unwrap_or("1".to_owned());
542558
if let Some(handler) = &self.handler {
543-
handler.on_operation_complete(&self.graphql_context);
559+
let state = SubscriptionState::OnOperationComplete(
560+
&self.graphql_context
561+
);
562+
handler.handle(state).unwrap();
544563
}
545564
match self.map_req_id_to_spawn_handle.remove(&request_id) {
546565
Some(spawn_handle) => {
@@ -558,19 +577,21 @@ pub mod subscriptions {
558577
// ))
559578
}
560579
}
561-
}
562-
GQL_CONNECTION_TERMINATE if has_started_value => {
580+
},
581+
MessageTypes::GqlConnectionTerminate => {
563582
if let Some(handler) = &self.handler {
564-
handler.on_disconnect(&self.graphql_context);
583+
let state = SubscriptionState::OnDisconnect(&self.graphql_context);
584+
handler.handle(state).unwrap();
565585
}
566586
ctx.stop();
567-
}
587+
},
568588
_ => {}
569589
}
570590
}
571591
ws::Message::Close(_) => {
572592
if let Some(handler) = &self.handler {
573-
handler.on_disconnect(&self.graphql_context);
593+
let state = SubscriptionState::OnDisconnect(&self.graphql_context);
594+
handler.handle(state).unwrap();
574595
}
575596
ctx.stop();
576597
}
@@ -610,7 +631,7 @@ pub mod subscriptions {
610631
{
611632
id: Option<String>,
612633
#[serde(rename(deserialize = "type"))]
613-
type_name: String,
634+
type_: MessageTypes,
614635
payload: Option<GraphQLPayload<S>>,
615636
}
616637

@@ -867,7 +888,7 @@ mod tests {
867888
use actix_web_actors::ws::{Frame, Message};
868889
use futures::{SinkExt, Stream};
869890
use juniper::{DefaultScalarValue, EmptyMutation, FieldError, RootNode};
870-
use juniper_subscriptions::{Coordinator, EmptySubscriptionLifecycleHandler};
891+
use juniper_subscriptions::{Coordinator, EmptySubscriptionHandler};
871892
use std::{pin::Pin, time::Duration};
872893

873894
pub struct Query;
@@ -933,7 +954,7 @@ mod tests {
933954
context,
934955
stream,
935956
req,
936-
EmptySubscriptionLifecycleHandler::new(),
957+
Some(EmptySubscriptionHandler::default()),
937958
)
938959
}
939960
.await

juniper_subscriptions/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ edition = "2018"
1212
[dependencies]
1313
futures = "0.3.1"
1414
juniper = { version = "0.14.2", path = "../juniper", default-features = false }
15+
serde = { version = "1.0.8" }
16+
serde_derive = { version = "1.0.2" }
17+
serde_json = { version="1.0.2", optional = true }
1518

1619
[dev-dependencies]
17-
serde_json = "1.0"
1820
tokio = { version = "0.2", features = ["rt-core", "macros"] }

0 commit comments

Comments
 (0)