diff --git a/juniper_actix/Cargo.toml b/juniper_actix/Cargo.toml index 1285cc323..ac3355b55 100644 --- a/juniper_actix/Cargo.toml +++ b/juniper_actix/Cargo.toml @@ -8,6 +8,8 @@ documentation = "https://docs.rs/juniper_actix" repository = "https://github.com/graphql-rust/juniper" edition = "2018" +[features] +subscriptions = ["juniper_subscriptions"] [dependencies] actix = "0.9.0" @@ -16,6 +18,7 @@ actix-web = { version = "2.0.0", features = ["rustls"] } actix-web-actors = "2.0.0" futures = { version = "0.3.1", features = ["compat"] } juniper = { version = "0.14.2", path = "../juniper", default-features = false } +juniper_subscriptions = { path = "../juniper_subscriptions", optional = true, features = ["ws-subscriptions"]} tokio = { version = "0.2", features = ["time"] } serde = { version = "1.0.75", features = ["derive"] } serde_json = "1.0.24" @@ -30,3 +33,7 @@ tokio = { version = "0.2", features = ["rt-core", "macros", "blocking"] } actix-cors = "0.2.0" actix-identity = "0.2.0" bytes = "0.5.4" + +[[example]] +name="actix_subscriptions" +required-features=["subscriptions"] \ No newline at end of file diff --git a/juniper_actix/examples/actix_subscriptions.rs b/juniper_actix/examples/actix_subscriptions.rs new file mode 100644 index 000000000..08d35004f --- /dev/null +++ b/juniper_actix/examples/actix_subscriptions.rs @@ -0,0 +1,223 @@ +#![deny(warnings)] + +use actix_cors::Cors; +use actix_web::{middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer}; +use futures::Stream; +use juniper::{DefaultScalarValue, FieldError, RootNode}; +use juniper_actix::{ + graphiql_handler as gqli_handler, graphql_handler, playground_handler as play_handler, + subscriptions::{ + graphql_subscriptions as sub_handler, GraphQLWSSession, SubscriptionState, + SubscriptionStateHandler, + }, +}; +use juniper_subscriptions::Coordinator; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{SystemTime, UNIX_EPOCH}; +use std::{pin::Pin, time::Duration}; +use tokio::sync::broadcast::{channel, Receiver, Sender}; + +type Schema = RootNode<'static, Query, Mutation, Subscription>; +type MyCoordinator = + Coordinator<'static, Query, Mutation, Subscription, Context, DefaultScalarValue>; + +struct ChatRoom { + pub name: String, + pub channel: (Sender, Receiver), + pub msgs: Vec, +} + +impl ChatRoom { + pub fn new(name: String) -> Self { + Self { + name, + channel: channel(16), + msgs: Vec::new(), + } + } +} + +struct Context { + pub chat_rooms: Arc>>, +} + +impl Context { + pub fn new(chat_rooms: Arc>>) -> Self { + Self { chat_rooms } + } +} + +impl juniper::Context for Context {} + +struct Query; + +#[juniper::graphql_object(Context = Context)] +impl Query { + pub fn chat_rooms(ctx: &Context) -> Vec { + ctx.chat_rooms + .lock() + .unwrap() + .iter() + .map(|(_, chat_room)| chat_room.name.clone()) + .collect() + } + + pub fn msgs_from_room(room_name: String, ctx: &Context) -> Option> { + ctx.chat_rooms + .lock() + .unwrap() + .get(&room_name) + .map(|chat_room| chat_room.msgs.clone()) + } +} + +struct Mutation; + +#[juniper::graphql_object(Context = Context)] +impl Mutation { + pub fn send_message(room_name: String, msg: String, sender: String, ctx: &Context) -> bool { + ctx.chat_rooms + .lock() + .unwrap() + .get_mut(&room_name) + .map(|chat_room| { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::new(0, 0)); + let msg = Msg { + sender, + value: msg, + date: format!("{}", now.as_secs()), + }; + chat_room.msgs.push(msg.clone()); + chat_room.channel.0.send(msg).is_ok() + }) + .is_some() + } +} + +#[derive(juniper::GraphQLObject, Clone)] +struct Msg { + pub sender: String, + pub value: String, + pub date: String, +} + +type StringStream = Pin> + Send>>; + +type MsgStream = Pin> + Send>>; + +struct Subscription; + +#[juniper::graphql_subscription(Context = Context)] +impl Subscription { + async fn hello_world() -> StringStream { + let mut counter = 0; + let stream = tokio::time::interval(Duration::from_secs(1)).map(move |_| { + counter += 1; + if counter % 2 == 0 { + Ok(String::from("World!")) + } else { + Ok(String::from("Hello")) + } + }); + + Box::pin(stream) + } + + async fn new_messages(room_name: String, ctx: &Context) -> MsgStream { + let channel_rx = { + match ctx.chat_rooms.lock().unwrap().entry(room_name.clone()) { + Entry::Occupied(o) => o.get().channel.0.subscribe(), + Entry::Vacant(v) => v.insert(ChatRoom::new(room_name)).channel.0.subscribe(), + } + }; + let stream = channel_rx.map(|msg| Ok(msg?)); + Box::pin(stream) + } +} + +fn schema() -> Schema { + Schema::new(Query {}, Mutation {}, Subscription {}) +} + +async fn graphiql_handler() -> Result { + gqli_handler("/", Some("ws://localhost:8080/subscriptions")).await +} +async fn playground_handler() -> Result { + play_handler("/", Some("/subscriptions")).await +} + +async fn graphql( + req: actix_web::HttpRequest, + payload: actix_web::web::Payload, + schema: web::Data, + chat_rooms: web::Data>>, +) -> Result { + let context = Context::new(chat_rooms.into_inner()); + graphql_handler(&schema, &context, req, payload).await +} + +#[derive(Default)] +struct HandlerExample {} + +impl SubscriptionStateHandler for HandlerExample +where + Context: Send + Sync, +{ + fn handle(&self, state: SubscriptionState) -> Result<(), Box> { + match state { + SubscriptionState::OnConnection(_, _) => println!("OnConnection"), + SubscriptionState::OnOperation(_) => println!("OnOperation"), + SubscriptionState::OnOperationComplete(_) => println!("OnOperationComplete"), + SubscriptionState::OnDisconnect(_) => println!("OnDisconnect"), + }; + Ok(()) + } +} + +async fn graphql_subscriptions( + coordinator: web::Data, + stream: web::Payload, + req: HttpRequest, + chat_rooms: web::Data>>, +) -> Result { + let context = Context::new(chat_rooms.into_inner()); + let actor = GraphQLWSSession::new(coordinator.into_inner(), context); + let actor = actor.with_handler(HandlerExample::default()); + sub_handler(actor, stream, req).await +} + +#[actix_rt::main] +async fn main() -> std::io::Result<()> { + ::std::env::set_var("RUST_LOG", "actix_web=info"); + env_logger::init(); + let chat_rooms: Mutex> = Mutex::new(HashMap::new()); + let chat_rooms = web::Data::new(chat_rooms); + let server = HttpServer::new(move || { + App::new() + .app_data(chat_rooms.clone()) + .data(schema()) + .data(juniper_subscriptions::Coordinator::new(schema())) + .wrap(middleware::Compress::default()) + .wrap(middleware::Logger::default()) + .wrap( + Cors::new() + .allowed_methods(vec!["POST", "GET"]) + .supports_credentials() + .max_age(3600) + .finish(), + ) + .service( + web::resource("/") + .route(web::post().to(graphql)) + .route(web::get().to(graphql)), + ) + .service(web::resource("/playground").route(web::get().to(playground_handler))) + .service(web::resource("/graphiql").route(web::get().to(graphiql_handler))) + .service(web::resource("/subscriptions").to(graphql_subscriptions)) + }); + server.bind("127.0.0.1:8080").unwrap().run().await +} diff --git a/juniper_actix/src/lib.rs b/juniper_actix/src/lib.rs index e85f536d0..2ed11a239 100644 --- a/juniper_actix/src/lib.rs +++ b/juniper_actix/src/lib.rs @@ -55,8 +55,17 @@ use juniper::{ }; use serde::Deserialize; -#[derive(Deserialize, Clone, PartialEq, Debug)] +/// this is the `juniper_actix` subscriptions handler implementation +/// does not fully support the GraphQL over WS[1] specification. +/// +/// *Note: this implementation is in an pre-alpha state.* +/// +/// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +#[cfg(feature = "subscriptions")] +pub mod subscriptions; + #[serde(deny_unknown_fields)] +#[derive(Deserialize, Clone, PartialEq, Debug)] struct GetGraphQLRequest { query: String, #[serde(rename = "operationName")] @@ -181,7 +190,8 @@ where Ok(response.content_type("application/json").body(gql_response)) } -/// Create a handler that replies with an HTML page containing GraphiQL. This does not handle routing, so you can mount it on any endpoint +/// Create a handler that replies with an HTML page containing GraphiQL. This does not handle +/// routing, so you can mount it on any endpoint /// /// For example: /// @@ -193,7 +203,8 @@ where /// # use actix_web::{web, App}; /// /// let app = App::new() -/// .route("/", web::get().to(|| graphiql_handler("/graphql", Some("/graphql/subscriptions")))); +/// .route("/", web::get().to(|| +/// graphiql_handler("/graphql", Some("/graphql/subscriptions")))); /// ``` #[allow(dead_code)] pub async fn graphiql_handler( @@ -206,7 +217,8 @@ pub async fn graphiql_handler( .body(html)) } -/// Create a handler that replies with an HTML page containing GraphQL Playground. This does not handle routing, so you cant mount it on any endpoint. +/// Create a handler that replies with an HTML page containing GraphQL Playground. This does not +/// handle routing, so you cant mount it on any endpoint. pub async fn playground_handler( graphql_endpoint_url: &str, subscriptions_endpoint_url: Option<&'static str>, diff --git a/juniper_actix/src/subscriptions.rs b/juniper_actix/src/subscriptions.rs new file mode 100644 index 000000000..e397aea55 --- /dev/null +++ b/juniper_actix/src/subscriptions.rs @@ -0,0 +1,539 @@ +use actix::{ + Actor, ActorContext, ActorFuture, AsyncContext, Handler, Message, Recipient, SpawnHandle, + StreamHandler, WrapFuture, +}; +use actix_web::{web, Error, HttpRequest, HttpResponse}; +use actix_web_actors::{ + ws, + ws::{handshake_with_protocols, WebsocketContext}, +}; +use futures::StreamExt; +use juniper::{http::GraphQLRequest, ScalarValue, SubscriptionCoordinator}; +use juniper_subscriptions::ws_util::GraphQLOverWebSocketMessage; +pub use juniper_subscriptions::ws_util::{ + GraphQLPayload, SubscriptionState, SubscriptionStateHandler, WsPayload, +}; +use juniper_subscriptions::Coordinator; +use serde::Serialize; +use std::ops::Deref; +use std::{ + collections::HashMap, + error::Error as StdError, + sync::Arc, + time::{Duration, Instant}, +}; + +/// Websocket Subscription Handler +/// +/// # Arguments +/// * `coordinator` - The Subscription Coordinator stored in the App State +/// * `context` - The Context that will be used by the Coordinator +/// * `stream` - The Stream used by the request to create the WebSocket +/// * `req` - The Initial Request sent by the Client +/// * `handler` - The SubscriptionStateHandler implementation that will be used in the Subscription. +/// * `ka_interval` - The Duration that will be used to interleave the keep alive messages sent by the server. The default value is 10 seconds. +pub async fn graphql_subscriptions( + actor: GraphQLWSSession, + stream: web::Payload, + req: HttpRequest, +) -> Result +where + S: ScalarValue + Send + Sync + 'static, + Context: Send + Sync + 'static + std::marker::Unpin, + Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, +{ + let mut res = handshake_with_protocols(&req, &["graphql-ws"])?; + Ok(res.streaming(WebsocketContext::create(actor, stream))) +} + +/// Actor for handling each WS Session +pub struct GraphQLWSSession +where + S: ScalarValue + Send + Sync + 'static, + Context: Send + Sync + 'static + std::marker::Unpin, + Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, +{ + map_req_id_to_spawn_handle: HashMap, + graphql_context: Arc, + coordinator: Arc>, + handler: Option + 'static + std::marker::Unpin>>, + hb: Instant, +} + +impl Actor + for GraphQLWSSession +where + S: ScalarValue + Send + Sync + 'static, + Context: Send + Sync + 'static + std::marker::Unpin, + Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, +{ + type Context = + ws::WebsocketContext>; + + fn started(&mut self, ctx: &mut Self::Context) { + self.hb(ctx); + } +} + +/// Internal Struct for handling Messages received from the subscriptions +#[derive(Message)] +#[rtype(result = "()")] +struct Msg(pub Option); + +impl Handler + for GraphQLWSSession +where + S: ScalarValue + Send + Sync + 'static, + Context: Send + Sync + 'static + std::marker::Unpin, + Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, +{ + type Result = (); + fn handle(&mut self, msg: Msg, ctx: &mut Self::Context) { + match msg.0 { + Some(msg) => ctx.text(msg), + None => ctx.close(None), + } + } +} + +/// How often heartbeat pings are sent +const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); +/// How long before lack of client response causes a timeout +const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); + +#[allow(dead_code)] +impl + GraphQLWSSession +where + S: ScalarValue + Send + Sync + 'static, + Context: Send + Sync + 'static + std::marker::Unpin, + Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, +{ + /// Creates a instance for usage in the graphql_subscription endpoint + pub fn new( + coord: Arc>, + ctx: Context, + ) -> Self { + Self { + coordinator: coord, + graphql_context: Arc::new(ctx), + map_req_id_to_spawn_handle: HashMap::new(), + handler: None, + hb: Instant::now(), + } + } + + /// Inserts a SubscriptionStateHandler in the Session + pub fn with_handler(self, handler: H) -> Self + where + H: SubscriptionStateHandler + 'static + std::marker::Unpin, + { + Self { + handler: Some(Box::new(handler)), + ..self + } + } + + fn gql_connection_ack() -> String { + let value = serde_json::json!({ "type": GraphQLOverWebSocketMessage::ConnectionAck }); + serde_json::to_string(&value).unwrap() + } + + fn gql_connection_error() -> String { + let value = serde_json::json!({ + "type": GraphQLOverWebSocketMessage::ConnectionError, + }); + serde_json::to_string(&value).unwrap() + } + fn gql_error(request_id: &String, err: T) -> String { + let value = serde_json::json!({ + "type": GraphQLOverWebSocketMessage::Error, + "id": request_id, + "payload": err + }); + serde_json::to_string(&value).unwrap() + } + + fn gql_data(request_id: &String, payload: T) -> String { + let value = serde_json::json!({ + "type": GraphQLOverWebSocketMessage::Data, + "id": request_id, + "payload": payload + }); + serde_json::to_string(&value).unwrap() + } + + fn gql_complete(request_id: &String) -> String { + let value = serde_json::json!({ + "type": GraphQLOverWebSocketMessage::Complete, + "id": request_id, + }); + serde_json::to_string(&value).unwrap() + } + + fn starting_subscription( + result: ( + GraphQLRequest, + String, + Arc, + Arc>, + ), + actor: &mut Self, + ctx: &mut ws::WebsocketContext, + ) -> actix::fut::FutureWrap, Self> { + let (req, req_id, gql_context, coord) = result; + let addr = ctx.address(); + Self::handle_subscription(req, gql_context, req_id, coord, addr.recipient()) + .into_actor(actor) + } + + async fn handle_subscription( + req: GraphQLRequest, + graphql_context: Arc, + request_id: String, + coord: Arc>, + addr: Recipient, + ) { + let mut values_stream = { + let subscribe_result = coord.subscribe(&req, &graphql_context).await; + match subscribe_result { + Ok(s) => s, + Err(err) => { + let _ = addr.do_send(Msg(Some(Self::gql_error(&request_id, err)))); + let _ = addr.do_send(Msg(Some(Self::gql_complete(&request_id)))); + let _ = addr.do_send(Msg(None)); + return; + } + } + }; + + while let Some(response) = values_stream.next().await { + let request_id = request_id.clone(); + let _ = addr.do_send(Msg(Some(Self::gql_data(&request_id, response)))); + } + let _ = addr.do_send(Msg(Some(Self::gql_complete(&request_id)))); + } + + fn hb(&self, ctx: &mut ::Context) { + ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { + if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT { + ctx.stop(); + return; + } + ctx.ping(b""); + }); + } +} + +impl + StreamHandler> + for GraphQLWSSession +where + S: ScalarValue + Send + Sync + 'static, + Context: Send + Sync + 'static + std::marker::Unpin, + Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, +{ + fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { + let msg = match msg { + Err(_) => { + ctx.stop(); + return; + } + Ok(msg) => msg, + }; + match msg { + ws::Message::Ping(msg) => { + self.hb = Instant::now(); + ctx.pong(&msg); + } + ws::Message::Pong(_) => { + self.hb = Instant::now(); + } + ws::Message::Text(text) => { + let m = text.trim(); + let request: WsPayload = match serde_json::from_str(m) { + Ok(payload) => payload, + Err(_) => { + return; + } + }; + match request.type_name { + GraphQLOverWebSocketMessage::ConnectionInit => { + if let Some(handler) = &self.handler { + let state = SubscriptionState::OnConnection( + request.payload, + Arc::get_mut(&mut self.graphql_context).unwrap(), + ); + let on_connect_result = handler.handle(state); + if let Err(_) = on_connect_result { + ctx.text(Self::gql_connection_error()); + ctx.stop(); + return; + } + } + ctx.text(Self::gql_connection_ack()); + } + GraphQLOverWebSocketMessage::Start => { + let coordinator = self.coordinator.clone(); + + let payload = request + .graphql_payload::() + .expect("Could not deserialize payload"); + let request_id = request.id.unwrap_or("1".to_owned()); + let graphql_request = GraphQLRequest::<_>::new( + payload.query.expect("Could not deserialize query"), + None, + payload.variables, + ); + if let Some(handler) = &self.handler { + let state = + SubscriptionState::OnOperation(self.graphql_context.deref()); + handler.as_ref().handle(state).unwrap(); + } + let context = self.graphql_context.clone(); + { + use std::collections::hash_map::Entry; + let req_id = request_id.clone(); + let future = + async move { (graphql_request, req_id, context, coordinator) } + .into_actor(self) + .then(Self::starting_subscription); + match self.map_req_id_to_spawn_handle.entry(request_id) { + // Since there is another request being handle + // this just ignores the start of another request with this same + // request_id + Entry::Occupied(_o) => (), + Entry::Vacant(v) => { + v.insert(ctx.spawn(future)); + } + }; + } + } + GraphQLOverWebSocketMessage::Stop => { + let request_id = request.id.unwrap_or("1".to_owned()); + if let Some(handler) = &self.handler { + let context = self.graphql_context.deref(); + let state = SubscriptionState::OnOperationComplete(context); + handler.handle(state).unwrap(); + } + match self.map_req_id_to_spawn_handle.remove(&request_id) { + Some(spawn_handle) => { + ctx.cancel_future(spawn_handle); + ctx.text(Self::gql_complete(&request_id)); + } + None => { + // No request with this id was found in progress. + // since the Subscription Protocol Spec does not specify + // what occurs in this case im just considering the possibility + // of send a error. + } + } + } + GraphQLOverWebSocketMessage::ConnectionTerminate => { + if let Some(handler) = &self.handler { + let context = self.graphql_context.deref(); + let state = SubscriptionState::OnDisconnect(context); + handler.handle(state).unwrap(); + } + ctx.stop(); + } + _ => {} + } + } + ws::Message::Binary(msg) => { + if let Ok(msg) = std::str::from_utf8(msg.as_ref()) { + StreamHandler::handle(self, Ok(ws::Message::Text(String::from(msg))), ctx); + } + }, + ws::Message::Close(_) | ws::Message::Continuation(_) => { + if let Some(handler) = &self.handler { + let context = self.graphql_context.deref(); + let state = SubscriptionState::OnDisconnect(context); + handler.handle(state).unwrap(); + } + ctx.stop(); + } + ws::Message::Nop => {} + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use actix_web::HttpRequest; + use actix_web::{test, App}; + use actix_web_actors::ws::{Frame, Message}; + use futures::StreamExt; + use futures::{SinkExt, Stream}; + use juniper::{ + tests::model::Database, tests::schema::Query, DefaultScalarValue, EmptyMutation, + FieldError, RootNode, + }; + use juniper_subscriptions::Coordinator; + use std::{pin::Pin, time::Duration}; + type Schema = RootNode<'static, Query, EmptyMutation, Subscription>; + type StringStream = Pin> + Send>>; + type MyCoordinator = Coordinator< + 'static, + Query, + EmptyMutation, + Subscription, + Database, + DefaultScalarValue, + >; + struct Subscription; + + #[juniper::graphql_subscription(Context = Database)] + impl Subscription { + async fn hello_world() -> StringStream { + let mut counter = 0; + let stream = tokio::time::interval(Duration::from_secs(2)).map(move |_| { + counter += 1; + if counter % 2 == 0 { + Ok(String::from("World!")) + } else { + Ok(String::from("Hello")) + } + }); + Box::pin(stream) + } + } + + async fn gql_subscriptions( + coordinator: web::Data, + stream: web::Payload, + req: HttpRequest, + ) -> Result { + let context = Database::new(); + let actor = GraphQLWSSession::new(coordinator.into_inner(), context); + graphql_subscriptions(actor, stream, req).await + } + + fn test_server() -> test::TestServer { + let schema: Schema = + RootNode::new(Query, EmptyMutation::::new(), Subscription {}); + + let coord = web::Data::new(juniper_subscriptions::Coordinator::new(schema)); + test::start(move || { + App::new() + .app_data(coord.clone()) + .service(web::resource("/subscriptions").to(gql_subscriptions)) + }) + } + + fn received_msg(msg: &'static str) -> Option { + Some(bytes::Bytes::from(msg)) + } + + async fn test_subscription( + msgs_to_send: Vec<&str>, + msgs_to_receive: Vec>>, + ) { + let mut app = test_server(); + let mut ws = app.ws_at("/subscriptions").await.unwrap(); + for (index, msg_to_be_sent) in msgs_to_send.into_iter().enumerate() { + let expected_msgs = msgs_to_receive.get(index).unwrap(); + ws.send(Message::Text(msg_to_be_sent.to_string())) + .await + .unwrap(); + for expected_msg in expected_msgs { + let (item, ws_stream) = ws.into_future().await; + ws = ws_stream; + match expected_msg { + Some(expected_msg) => { + if let Some(Ok(Frame::Text(msg))) = item { + assert_eq!(msg, expected_msg); + } else { + assert!(false); + } + } + None => assert_eq!(item.is_none(), expected_msg.is_none()), + } + } + } + } + + #[actix_rt::test] + async fn basic_connection() { + let msgs_to_send = vec![ + r#"{"type":"connection_init","payload":{}}"#, + r#"{"type":"connection_terminate"}"#, + ]; + let msgs_to_receive = vec![ + vec![received_msg(r#"{"type":"connection_ack"}"#)], + vec![None], + ]; + test_subscription(msgs_to_send, msgs_to_receive).await; + } + + #[actix_rt::test] + async fn basic_subscription() { + let msgs_to_send = vec![ + r#"{"type":"connection_init","payload":{}}"#, + r#"{"id":"1","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"#, + r#"{"type":"connection_terminate"}"#, + ]; + let msgs_to_receive = vec![ + vec![received_msg(r#"{"type":"connection_ack"}"#)], + vec![received_msg( + r#"{"type":"data","id":"1","payload":{"data":{"helloWorld":"Hello"}}}"#, + )], + vec![None], + ]; + test_subscription(msgs_to_send, msgs_to_receive).await; + } + + #[actix_rt::test] + async fn conn_with_two_subscriptions() { + let msgs_to_send = vec![ + r#"{"type":"connection_init","payload":{}}"#, + r#"{"id":"1","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"#, + r#"{"id":"2","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"#, + r#"{"id":"1","type":"stop"}"#, + r#"{"type":"connection_terminate"}"#, + ]; + let msgs_to_receive = vec![ + vec![received_msg(r#"{"type":"connection_ack"}"#)], + vec![received_msg( + r#"{"type":"data","id":"1","payload":{"data":{"helloWorld":"Hello"}}}"#, + )], + vec![received_msg( + r#"{"type":"data","id":"2","payload":{"data":{"helloWorld":"Hello"}}}"#, + )], + vec![received_msg(r#"{"type":"complete","id":"1"}"#)], + vec![None], + ]; + test_subscription(msgs_to_send, msgs_to_receive).await; + } +} diff --git a/juniper_subscriptions/Cargo.toml b/juniper_subscriptions/Cargo.toml index cff0f1615..4f9633985 100644 --- a/juniper_subscriptions/Cargo.toml +++ b/juniper_subscriptions/Cargo.toml @@ -8,10 +8,14 @@ documentation = "https://docs.rs/juniper_subscriptions" repository = "https://github.com/graphql-rust/juniper" edition = "2018" +[features] +ws-subscriptions = ["serde", "serde_json"] [dependencies] -futures = "0.3.1" +futures = "0.3" juniper = { version = "0.14.2", path = "../juniper", default-features = false } +serde = { version = "1.0", optional = true, features = ["derive"] } +serde_json = { version = "1.0", optional = true } [dev-dependencies] serde_json = "1.0" diff --git a/juniper_subscriptions/src/lib.rs b/juniper_subscriptions/src/lib.rs index ff6666e13..7b49b12be 100644 --- a/juniper_subscriptions/src/lib.rs +++ b/juniper_subscriptions/src/lib.rs @@ -19,6 +19,24 @@ use juniper::{ BoxFuture, ExecutionError, GraphQLError, GraphQLSubscriptionType, GraphQLTypeAsync, Object, ScalarValue, SubscriptionConnection, SubscriptionCoordinator, Value, ValuesStream, }; +/// Utilities for the implementation of subscriptions over WebSocket +/// +/// This module provides some utilities for the implementation of Subscriptions over +/// WebSocket, such as the [`GraphQLOverWebSocketMessage`] that contains the messages that +/// could be sent by the server or client and a [`SubscriptionStateHandler`] trait that allows +/// the user of the integration to handle some Subscription Life Cycle Events, its based on: +/// +/// - [Subscriptions Transport over WS][SubscriptionsWsProtocol] +/// - [GraphQL Subscriptions LifeCycle Events][GraphQLSubscriptionsLifeCycle] +/// +/// [SubscriptionsWsProtocol]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +/// [GraphQLSubscriptionsLifeCycle]: https://www.apollographql.com/docs/graphql-subscriptions/lifecycle-events/ +/// [`GraphQLOverWebSocketMessage`]: GraphQLOverWebSocketMessage +/// [`SubscriptionStateHandler`]: SubscriptionStateHandler +#[cfg(feature = "ws-subscriptions")] +pub mod ws_util; +#[cfg(feature = "ws-subscriptions")] +pub use ws_util::*; /// Simple [`SubscriptionCoordinator`] implementation: /// - contains the schema diff --git a/juniper_subscriptions/src/ws_util.rs b/juniper_subscriptions/src/ws_util.rs new file mode 100644 index 000000000..df94fb03a --- /dev/null +++ b/juniper_subscriptions/src/ws_util.rs @@ -0,0 +1,274 @@ +use juniper::{InputValue, ScalarValue}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; + +/// Enum of Subscription Protocol Message Types over WS +/// to know more access [Subscriptions Transport over WS][SubscriptionsTransportWS] +/// +/// [SubscriptionsTransportWS]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum GraphQLOverWebSocketMessage { + /// Client -> Server + /// Client sends this message after plain websocket connection to start the communication + /// with the server + #[serde(rename = "connection_init")] + ConnectionInit, + /// Server -> Client + /// The server may responses with this message to the GQL_CONNECTION_INIT from client, + /// indicates the server accepted the connection. + #[serde(rename = "connection_ack")] + ConnectionAck, + /// Server -> Client + /// The server may responses with this message to the GQL_CONNECTION_INIT from client, + /// indicates the server rejected the connection. + #[serde(rename = "connection_error")] + ConnectionError, + /// Server -> Client + /// Server message that should be sent right after each GQL_CONNECTION_ACK processed + /// and then periodically to keep the client connection alive. + #[serde(rename = "ka")] + ConnectionKeepAlive, + /// Client -> Server + /// Client sends this message to terminate the connection. + #[serde(rename = "connection_terminate")] + ConnectionTerminate, + /// Client -> Server + /// Client sends this message to execute GraphQL operation + #[serde(rename = "start")] + Start, + /// Server -> Client + /// The server sends this message to transfer the GraphQL execution result from the + /// server to the client, this message is a response for GQL_START message. + #[serde(rename = "data")] + Data, + /// Server -> Client + /// Server sends this message upon a failing operation, before the GraphQL execution, + /// usually due to GraphQL validation errors (resolver errors are part of GQL_DATA message, + /// and will be added as errors array) + #[serde(rename = "error")] + Error, + /// Server -> Client + /// Server sends this message to indicate that a GraphQL operation is done, + /// and no more data will arrive for the specific operation. + #[serde(rename = "complete")] + Complete, + /// Client -> Server + /// Client sends this message in order to stop a running GraphQL operation execution + /// (for example: unsubscribe) + #[serde(rename = "stop")] + Stop, +} + +/// Empty SubscriptionLifeCycleHandler over WS +pub enum SubscriptionState<'a, Context> +where + Context: Send + Sync, +{ + /// The Subscription is at the init of the connection with the client after the + /// server receives the GQL_CONNECTION_INIT message. + OnConnection(Option, &'a mut Context), + /// The Subscription is at the start of a operation after the GQL_START message is + /// is received. + OnOperation(&'a Context), + /// The subscription is on the end of a operation before sending the GQL_COMPLETE + /// message to the client. + OnOperationComplete(&'a Context), + /// The Subscription is terminating the connection with the client. + OnDisconnect(&'a Context), +} + +/// Trait based on the SubscriptionServer [LifeCycleEvents][LifeCycleEvents] +/// +/// [LifeCycleEvents]: https://www.apollographql.com/docs/graphql-subscriptions/lifecycle-events/ +pub trait SubscriptionStateHandler +where + Context: Send + Sync, +{ + /// This function is called when the state of the Subscription changes + /// with the actual state. + fn handle(&self, _state: SubscriptionState) -> Result<(), Box>; +} + +/// Struct defining the message content sent or received by the server +#[derive(Deserialize, Serialize)] +pub struct WsPayload { + /// ID of the Subscription operation + pub id: Option, + /// Type of the Message + #[serde(rename(deserialize = "type"))] + pub type_name: GraphQLOverWebSocketMessage, + /// Payload of the Message + pub payload: Option, +} + +impl WsPayload { + /// Returns the transformation from the payload Value to a GraphQLPayload + pub fn graphql_payload(&self) -> Option> + where + S: ScalarValue + Send + Sync + 'static, + { + serde_json::from_value(self.payload.clone()?).ok() + } + /// Constructor + pub fn new( + type_name: GraphQLOverWebSocketMessage, + id: Option, + payload: Option, + ) -> Self { + Self { + id, + type_name, + payload, + } + } +} + +/// GraphQLPayload content sent by the client to the server +#[derive(Debug, Deserialize)] +#[serde(bound = "InputValue: Deserialize<'de>")] +pub struct GraphQLPayload +where + S: ScalarValue + Send + Sync + 'static, +{ + /// Variables for the Operation + pub variables: Option>, + /// Extensions + pub extensions: Option>, + /// Name of the Operation to be executed + #[serde(rename(deserialize = "operationName"))] + pub operation_name: Option, + /// Query value of the Operation + pub query: Option, +} + +#[cfg(test)] +pub mod tests { + use super::*; + use juniper::DefaultScalarValue; + use std::sync::atomic::{AtomicBool, Ordering}; + + #[derive(Default)] + struct Context { + pub user_id: Option, + pub has_connected: bool, + pub has_operated: AtomicBool, + pub has_completed_operation: AtomicBool, + pub has_disconnected: AtomicBool, + } + + #[derive(Deserialize)] + struct OnConnPayload { + #[serde(rename = "userId")] + pub user_id: Option, + } + + struct SubStateHandler; + + impl SubscriptionStateHandler for SubStateHandler { + fn handle( + &self, + state: SubscriptionState, + ) -> Result<(), Box> { + match state { + SubscriptionState::OnConnection(payload, ctx) => { + if let Some(payload) = payload { + let result = serde_json::from_value::(payload); + if let Ok(payload) = result { + ctx.user_id = payload.user_id; + } + } + ctx.has_connected = true; + } + SubscriptionState::OnOperation(ctx) => { + ctx.has_operated.store(true, Ordering::Relaxed); + } + SubscriptionState::OnOperationComplete(ctx) => { + ctx.has_completed_operation.store(true, Ordering::Relaxed); + } + SubscriptionState::OnDisconnect(ctx) => { + ctx.has_disconnected.store(true, Ordering::Relaxed); + } + }; + Ok(()) + } + } + + const SUB_HANDLER: SubStateHandler = SubStateHandler {}; + + fn implementation_example(msg: &str, ctx: &mut Context) -> bool { + let ws_payload: WsPayload = serde_json::from_str(msg).unwrap(); + match ws_payload.type_name { + GraphQLOverWebSocketMessage::ConnectionInit => { + let state = SubscriptionState::OnConnection(ws_payload.payload, ctx); + SUB_HANDLER.handle(state).unwrap(); + true + } + GraphQLOverWebSocketMessage::ConnectionTerminate => { + let state = SubscriptionState::OnDisconnect(ctx); + SUB_HANDLER.handle(state).unwrap(); + true + } + GraphQLOverWebSocketMessage::Start => { + // Over here you can make usage of the subscriptions coordinator + // to get the connection related to the client request. This is just a + // testing example to show and verify usage of this module. + let _gql_payload: GraphQLPayload = + ws_payload.graphql_payload().unwrap(); + let state = SubscriptionState::OnOperation(ctx); + SUB_HANDLER.handle(state).unwrap(); + true + } + GraphQLOverWebSocketMessage::Stop => { + let state = SubscriptionState::OnOperationComplete(ctx); + SUB_HANDLER.handle(state).unwrap(); + true + } + _ => false, + } + } + + #[test] + fn on_connection() { + let mut ctx = Context::default(); + let type_value = + serde_json::to_string(&GraphQLOverWebSocketMessage::ConnectionInit).unwrap(); + let msg = format!( + r#"{{"type":{}, "payload": {{ "userId": "1" }} }}"#, + type_value + ); + assert!(implementation_example(&msg, &mut ctx)); + assert!(ctx.has_connected); + assert_eq!(ctx.user_id, Some(String::from("1"))); + } + + #[test] + fn on_operation() { + let mut ctx = Context::default(); + let type_value = serde_json::to_string(&GraphQLOverWebSocketMessage::Start).unwrap(); + let msg = format!(r#"{{"type":{}, "payload": {{}}, "id": "1" }}"#, type_value); + assert!(implementation_example(&msg, &mut ctx)); + assert!(ctx.has_operated.load(Ordering::Relaxed)); + } + + #[test] + fn on_operation_completed() { + let mut ctx = Context::default(); + let type_value = serde_json::to_string(&GraphQLOverWebSocketMessage::Stop).unwrap(); + let msg = format!(r#"{{"type":{}, "payload": null, "id": "1" }}"#, type_value); + assert!(implementation_example(&msg, &mut ctx)); + let has_completed = ctx.has_completed_operation.load(Ordering::Relaxed); + assert!(has_completed); + } + + #[test] + fn on_disconnect() { + let mut ctx = Context::default(); + let type_value = + serde_json::to_string(&GraphQLOverWebSocketMessage::ConnectionTerminate).unwrap(); + let msg = format!(r#"{{"type":{}, "payload": null, "id": "1" }}"#, type_value); + assert!(implementation_example(&msg, &mut ctx)); + let has_disconnected = ctx.has_disconnected.load(Ordering::Relaxed); + assert!(has_disconnected); + } +}