From b9105f1ebb95fabc0f51cb93c7d82e75facdf823 Mon Sep 17 00:00:00 2001 From: Jordao Rosario Date: Thu, 30 Apr 2020 15:42:20 -0300 Subject: [PATCH 1/4] Implementation of some subscriptions over ws utitilies For usage in the subscriptions handlers implementation --- juniper_actix/Cargo.toml | 7 + juniper_actix/examples/actix_subscriptions.rs | 206 +++++++ juniper_actix/src/lib.rs | 20 +- juniper_actix/src/subscriptions.rs | 550 ++++++++++++++++++ juniper_subscriptions/Cargo.toml | 6 +- juniper_subscriptions/src/lib.rs | 18 + juniper_subscriptions/src/ws_util.rs | 285 +++++++++ 7 files changed, 1087 insertions(+), 5 deletions(-) create mode 100644 juniper_actix/examples/actix_subscriptions.rs create mode 100644 juniper_actix/src/subscriptions.rs create mode 100644 juniper_subscriptions/src/ws_util.rs diff --git a/juniper_actix/Cargo.toml b/juniper_actix/Cargo.toml index 1285cc323..ac3355b55 100644 --- a/juniper_actix/Cargo.toml +++ b/juniper_actix/Cargo.toml @@ -8,6 +8,8 @@ documentation = "https://docs.rs/juniper_actix" repository = "https://github.com/graphql-rust/juniper" edition = "2018" +[features] +subscriptions = ["juniper_subscriptions"] [dependencies] actix = "0.9.0" @@ -16,6 +18,7 @@ actix-web = { version = "2.0.0", features = ["rustls"] } actix-web-actors = "2.0.0" futures = { version = "0.3.1", features = ["compat"] } juniper = { version = "0.14.2", path = "../juniper", default-features = false } +juniper_subscriptions = { path = "../juniper_subscriptions", optional = true, features = ["ws-subscriptions"]} tokio = { version = "0.2", features = ["time"] } serde = { version = "1.0.75", features = ["derive"] } serde_json = "1.0.24" @@ -30,3 +33,7 @@ tokio = { version = "0.2", features = ["rt-core", "macros", "blocking"] } actix-cors = "0.2.0" actix-identity = "0.2.0" bytes = "0.5.4" + +[[example]] +name="actix_subscriptions" +required-features=["subscriptions"] \ No newline at end of file diff --git a/juniper_actix/examples/actix_subscriptions.rs b/juniper_actix/examples/actix_subscriptions.rs new file mode 100644 index 000000000..6c984ff94 --- /dev/null +++ b/juniper_actix/examples/actix_subscriptions.rs @@ -0,0 +1,206 @@ +#![deny(warnings)] + +use actix_cors::Cors; +use actix_web::{middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer}; +use futures::Stream; +use juniper::{DefaultScalarValue, FieldError, RootNode}; +use juniper_actix::{ + graphiql_handler as gqli_handler, graphql_handler, playground_handler as play_handler, + subscriptions::{graphql_subscriptions as sub_handler, EmptySubscriptionHandler}, +}; +use juniper_subscriptions::Coordinator; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{SystemTime, UNIX_EPOCH}; +use std::{pin::Pin, time::Duration}; +use tokio::sync::broadcast::{channel, Receiver, Sender}; + +type Schema = RootNode<'static, Query, Mutation, Subscription>; +type MyCoordinator = + Coordinator<'static, Query, Mutation, Subscription, Context, DefaultScalarValue>; + +struct ChatRoom { + pub name: String, + pub channel: (Sender, Receiver), +} + +impl ChatRoom { + pub fn new(name: String) -> Self { + Self { + name, + channel: channel(16), + } + } +} + +struct Context { + pub chat_rooms: Arc>>, +} + +impl Context { + pub fn new(chat_rooms: Arc>>) -> Self { + Self { chat_rooms } + } +} + +impl juniper::Context for Context {} + +struct Query; + +#[juniper::graphql_object(Context = Context)] +impl Query { + pub fn chat_rooms(ctx: &Context) -> Vec { + ctx.chat_rooms + .lock() + .unwrap() + .iter() + .map(|(_, chat_room)| chat_room.name.clone()) + .collect() + } +} + +struct Mutation; + +#[juniper::graphql_object(Context = Context)] +impl Mutation { + pub fn send_message(room_name: String, msg: String, sender: String, ctx: &Context) -> bool { + ctx.chat_rooms + .lock() + .unwrap() + .get(&room_name) + .map(|chat_room| { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::new(0, 0)); + chat_room + .channel + .0 + .send(Msg { + sender, + value: msg, + date: format!("{}", now.as_secs()), + }) + .is_ok() + }) + .is_some() + } +} + +#[derive(juniper::GraphQLObject, Clone)] +struct Msg { + pub sender: String, + pub value: String, + pub date: String, +} + +type StringStream = Pin> + Send>>; + +type VecStringStream = Pin, FieldError>> + Send>>; + +struct Subscription; + +#[juniper::graphql_subscription(Context = Context)] +impl Subscription { + async fn hello_world() -> StringStream { + let mut counter = 0; + let stream = tokio::time::interval(Duration::from_secs(1)).map(move |_| { + counter += 1; + if counter % 2 == 0 { + Ok(String::from("World!")) + } else { + Ok(String::from("Hello")) + } + }); + + Box::pin(stream) + } + + async fn chat_room(room_name: String, ctx: &Context) -> VecStringStream { + let mut messages: Vec = Vec::new(); + let channel_rx = { + match ctx.chat_rooms.lock().unwrap().entry(room_name.clone()) { + Entry::Occupied(o) => o.get().channel.0.subscribe(), + Entry::Vacant(v) => v.insert(ChatRoom::new(room_name)).channel.0.subscribe(), + } + }; + let stream = channel_rx.map(move |msg| { + let msg = msg?; + messages.push(msg); + Ok(messages.clone()) + }); + Box::pin(stream) + } +} + +fn schema() -> Schema { + Schema::new(Query {}, Mutation {}, Subscription {}) +} + +async fn graphiql_handler() -> Result { + gqli_handler("/", Some("/subscriptions")).await +} +async fn playground_handler() -> Result { + play_handler("/", Some("/subscriptions")).await +} + +async fn graphql( + req: actix_web::HttpRequest, + payload: actix_web::web::Payload, + schema: web::Data, + chat_rooms: web::Data>>, +) -> Result { + let context = Context::new(chat_rooms.into_inner()); + graphql_handler(&schema, &context, req, payload).await +} + +async fn graphql_subscriptions( + coordinator: web::Data, + stream: web::Payload, + req: HttpRequest, + chat_rooms: web::Data>>, +) -> Result { + let context = Context::new(chat_rooms.into_inner()); + let handler: Option = None; + sub_handler( + coordinator, + context, + stream, + req, + handler, + Some(Duration::from_secs(5)), + ) + .await +} + +#[actix_rt::main] +async fn main() -> std::io::Result<()> { + ::std::env::set_var("RUST_LOG", "actix_web=info"); + env_logger::init(); + let chat_rooms: Mutex> = Mutex::new(HashMap::new()); + let chat_rooms = web::Data::new(chat_rooms); + let server = HttpServer::new(move || { + App::new() + .app_data(chat_rooms.clone()) + .data(schema()) + .data(juniper_subscriptions::Coordinator::new(schema())) + .wrap(middleware::Compress::default()) + .wrap(middleware::Logger::default()) + .wrap( + Cors::new() + .allowed_methods(vec!["POST", "GET"]) + .supports_credentials() + .max_age(3600) + .finish(), + ) + .service( + web::resource("/") + .route(web::post().to(graphql)) + .route(web::get().to(graphql)), + ) + .service(web::resource("/playground").route(web::get().to(playground_handler))) + .service(web::resource("/graphiql").route(web::get().to(graphiql_handler))) + .service(web::resource("/subscriptions").to(graphql_subscriptions)) + }); + server.bind("127.0.0.1:8080").unwrap().run().await +} diff --git a/juniper_actix/src/lib.rs b/juniper_actix/src/lib.rs index e85f536d0..2ed11a239 100644 --- a/juniper_actix/src/lib.rs +++ b/juniper_actix/src/lib.rs @@ -55,8 +55,17 @@ use juniper::{ }; use serde::Deserialize; -#[derive(Deserialize, Clone, PartialEq, Debug)] +/// this is the `juniper_actix` subscriptions handler implementation +/// does not fully support the GraphQL over WS[1] specification. +/// +/// *Note: this implementation is in an pre-alpha state.* +/// +/// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +#[cfg(feature = "subscriptions")] +pub mod subscriptions; + #[serde(deny_unknown_fields)] +#[derive(Deserialize, Clone, PartialEq, Debug)] struct GetGraphQLRequest { query: String, #[serde(rename = "operationName")] @@ -181,7 +190,8 @@ where Ok(response.content_type("application/json").body(gql_response)) } -/// Create a handler that replies with an HTML page containing GraphiQL. This does not handle routing, so you can mount it on any endpoint +/// Create a handler that replies with an HTML page containing GraphiQL. This does not handle +/// routing, so you can mount it on any endpoint /// /// For example: /// @@ -193,7 +203,8 @@ where /// # use actix_web::{web, App}; /// /// let app = App::new() -/// .route("/", web::get().to(|| graphiql_handler("/graphql", Some("/graphql/subscriptions")))); +/// .route("/", web::get().to(|| +/// graphiql_handler("/graphql", Some("/graphql/subscriptions")))); /// ``` #[allow(dead_code)] pub async fn graphiql_handler( @@ -206,7 +217,8 @@ pub async fn graphiql_handler( .body(html)) } -/// Create a handler that replies with an HTML page containing GraphQL Playground. This does not handle routing, so you cant mount it on any endpoint. +/// Create a handler that replies with an HTML page containing GraphQL Playground. This does not +/// handle routing, so you cant mount it on any endpoint. pub async fn playground_handler( graphql_endpoint_url: &str, subscriptions_endpoint_url: Option<&'static str>, diff --git a/juniper_actix/src/subscriptions.rs b/juniper_actix/src/subscriptions.rs new file mode 100644 index 000000000..b57e2c1b3 --- /dev/null +++ b/juniper_actix/src/subscriptions.rs @@ -0,0 +1,550 @@ +use actix::{ + Actor, ActorContext, ActorFuture, AsyncContext, Handler, Message, Recipient, SpawnHandle, + StreamHandler, WrapFuture, +}; +use actix_web::{error::PayloadError, web, web::Bytes, Error, HttpRequest, HttpResponse}; +use actix_web_actors::{ + ws, + ws::{handshake_with_protocols, WebsocketContext}, +}; +use futures::{Stream, StreamExt}; +use juniper::{http::GraphQLRequest, ScalarValue, SubscriptionCoordinator}; +use juniper_subscriptions::ws_util::GraphQLOverWebSocketMessage; +pub use juniper_subscriptions::ws_util::{ + EmptySubscriptionHandler, GraphQLPayload, SubscriptionState, SubscriptionStateHandler, + WsPayload, +}; +use juniper_subscriptions::Coordinator; +use serde::Serialize; +use std::ops::Deref; +use std::{ + collections::HashMap, + error::Error as StdError, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; +use tokio::time::Duration; + +/// Websocket Subscription Handler +/// +/// # Arguments +/// * `coordinator` - The Subscription Coordinator stored in the App State +/// * `context` - The Context that will be used by the Coordinator +/// * `stream` - The Stream used by the request to create the WebSocket +/// * `req` - The Initial Request sent by the Client +/// * `handler` - The SubscriptionStateHandler implementation that will be used in the Subscription. +/// * `ka_interval` - The Duration that will be used to interleave the keep alive messages sent by the server. The default value is 10 seconds. +pub async fn graphql_subscriptions( + coordinator: web::Data>, + context: Context, + stream: web::Payload, + req: HttpRequest, + handler: Option, + ka_interval: Option, +) -> Result +where + S: ScalarValue + Send + Sync + 'static, + Context: Send + Sync + 'static + std::marker::Unpin, + Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, + SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, + E: 'static + std::error::Error + std::marker::Unpin, +{ + start( + GraphQLWSSession { + coordinator: coordinator.into_inner(), + graphql_context: Arc::new(context), + map_req_id_to_spawn_handle: HashMap::new(), + has_started: Arc::new(AtomicBool::new(false)), + handler, + error_handler: std::marker::PhantomData, + ka_interval: ka_interval.unwrap_or_else(|| Duration::from_secs(10)), + }, + &req, + stream, + ) +} + +fn start( + actor: GraphQLWSSession, + req: &HttpRequest, + stream: T, +) -> Result +where + T: Stream> + 'static, + S: ScalarValue + Send + Sync + 'static, + Context: Send + Sync + 'static + std::marker::Unpin, + Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, + SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, + E: 'static + std::error::Error + std::marker::Unpin, +{ + let mut res = handshake_with_protocols(req, &["graphql-ws"])?; + Ok(res.streaming(WebsocketContext::create(actor, stream))) +} + +struct GraphQLWSSession +where + S: ScalarValue + Send + Sync + 'static, + Context: Send + Sync + 'static + std::marker::Unpin, + Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, + SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, + E: 'static + std::error::Error + std::marker::Unpin, +{ + pub map_req_id_to_spawn_handle: HashMap, + pub has_started: Arc, + pub graphql_context: Arc, + pub coordinator: Arc>, + pub handler: Option, + pub ka_interval: Duration, + error_handler: std::marker::PhantomData, +} + +impl Actor + for GraphQLWSSession +where + S: ScalarValue + Send + Sync + 'static, + Context: Send + Sync + 'static + std::marker::Unpin, + Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, + SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, + E: 'static + std::error::Error + std::marker::Unpin, +{ + type Context = ws::WebsocketContext< + GraphQLWSSession, + >; +} + +/// Internal Struct for handling Messages received from the subscriptions +#[derive(Message)] +#[rtype(result = "()")] +struct Msg(pub Option); + +impl Handler + for GraphQLWSSession +where + S: ScalarValue + Send + Sync + 'static, + Context: Send + Sync + 'static + std::marker::Unpin, + Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, + SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, + E: 'static + std::error::Error + std::marker::Unpin, +{ + type Result = (); + fn handle(&mut self, msg: Msg, ctx: &mut Self::Context) { + match msg.0 { + Some(msg) => ctx.text(msg), + None => ctx.close(None), + } + } +} + +#[allow(dead_code)] +impl + GraphQLWSSession +where + S: ScalarValue + Send + Sync + 'static, + Context: Send + Sync + 'static + std::marker::Unpin, + Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, + SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, + E: 'static + std::error::Error + std::marker::Unpin, +{ + fn gql_connection_ack() -> String { + let type_value = + serde_json::to_string(&GraphQLOverWebSocketMessage::ConnectionAck).unwrap(); + format!(r#"{{"type":{}, "payload": null }}"#, type_value) + } + + fn gql_connection_ka() -> String { + let type_value = + serde_json::to_string(&GraphQLOverWebSocketMessage::ConnectionKeepAlive).unwrap(); + format!(r#"{{"type":{}, "payload": null }}"#, type_value) + } + + fn gql_connection_error() -> String { + let type_value = + serde_json::to_string(&GraphQLOverWebSocketMessage::ConnectionError).unwrap(); + format!(r#"{{"type":{}, "payload": null }}"#, type_value) + } + fn gql_error(request_id: &String, err: T) -> String { + let type_value = serde_json::to_string(&GraphQLOverWebSocketMessage::Error).unwrap(); + format!( + r#"{{"type":{},"id":"{}","payload":{}}}"#, + type_value, + request_id, + serde_json::ser::to_string(&err) + .unwrap_or("Error deserializing GraphQLError".to_owned()) + ) + } + + fn gql_data(request_id: &String, response_text: String) -> String { + let type_value = serde_json::to_string(&GraphQLOverWebSocketMessage::Data).unwrap(); + format!( + r#"{{"type":{},"id":"{}","payload":{} }}"#, + type_value, request_id, response_text + ) + } + + fn gql_complete(request_id: &String) -> String { + let type_value = serde_json::to_string(&GraphQLOverWebSocketMessage::Complete).unwrap(); + format!( + r#"{{"type":{},"id":"{}","payload":null}}"#, + type_value, request_id + ) + } + + fn starting_handle( + result: ( + GraphQLRequest, + String, + Arc, + Arc>, + ), + actor: &mut Self, + ctx: &mut ws::WebsocketContext, + ) -> actix::fut::FutureWrap, Self> { + let (req, req_id, gql_context, coord) = result; + let addr = ctx.address(); + Self::handle_subscription(req, gql_context, req_id, coord, addr.recipient()) + .into_actor(actor) + } + + async fn handle_subscription( + req: GraphQLRequest, + graphql_context: Arc, + request_id: String, + coord: Arc>, + addr: Recipient, + ) { + let mut values_stream = { + let subscribe_result = coord.subscribe(&req, &graphql_context).await; + match subscribe_result { + Ok(s) => s, + Err(err) => { + let _ = addr.do_send(Msg(Some(Self::gql_error(&request_id, err)))); + let _ = addr.do_send(Msg(Some(Self::gql_complete(&request_id)))); + let _ = addr.do_send(Msg(None)); + return; + } + } + }; + + while let Some(response) = values_stream.next().await { + let request_id = request_id.clone(); + let response_text = serde_json::to_string(&response) + .unwrap_or("Error deserializing respone".to_owned()); + let _ = addr.do_send(Msg(Some(Self::gql_data(&request_id, response_text)))); + } + let _ = addr.do_send(Msg(Some(Self::gql_complete(&request_id)))); + } +} + +impl + StreamHandler> + for GraphQLWSSession +where + S: ScalarValue + Send + Sync + 'static, + Context: Send + Sync + 'static + std::marker::Unpin, + Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, + SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, + E: 'static + std::error::Error + std::marker::Unpin, +{ + fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { + let msg = match msg { + Err(_) => { + ctx.stop(); + return; + } + Ok(msg) => msg, + }; + let has_started = self.has_started.clone(); + let has_started_value = has_started.load(Ordering::Relaxed); + match msg { + ws::Message::Text(text) => { + let m = text.trim(); + let request: WsPayload = match serde_json::from_str(m) { + Ok(payload) => payload, + Err(_) => { + return; + } + }; + match request.type_name { + GraphQLOverWebSocketMessage::ConnectionInit => { + if let Some(handler) = &self.handler { + let state = SubscriptionState::OnConnection( + request.payload, + Arc::get_mut(&mut self.graphql_context).unwrap(), + ); + let on_connect_result = handler.handle(state); + if let Err(_err) = on_connect_result { + ctx.text(Self::gql_connection_error()); + ctx.stop(); + return; + } + } + ctx.text(Self::gql_connection_ack()); + ctx.text(Self::gql_connection_ka()); + has_started.store(true, Ordering::Relaxed); + ctx.run_interval(self.ka_interval, |actor, ctx| { + let no_request = actor.map_req_id_to_spawn_handle.len() == 0; + if no_request { + ctx.stop(); + } else { + ctx.text(Self::gql_connection_ka()); + } + }); + } + GraphQLOverWebSocketMessage::Start if has_started_value => { + let coordinator = self.coordinator.clone(); + + let payload = request + .graphql_payload::() + .expect("Could not deserialize payload"); + let request_id = request.id.unwrap_or("1".to_owned()); + let graphql_request = GraphQLRequest::<_>::new( + payload.query.expect("Could not deserialize query"), + None, + payload.variables, + ); + if let Some(handler) = &self.handler { + let state = + SubscriptionState::OnOperation(self.graphql_context.deref()); + handler.handle(state).unwrap(); + } + let context = self.graphql_context.clone(); + { + use std::collections::hash_map::Entry; + let req_id = request_id.clone(); + let future = + async move { (graphql_request, req_id, context, coordinator) } + .into_actor(self) + .then(Self::starting_handle); + match self.map_req_id_to_spawn_handle.entry(request_id) { + // Since there is another request being handle + // this just ignores the start of another request with this same + // request_id + Entry::Occupied(_o) => (), + Entry::Vacant(v) => { + v.insert(ctx.spawn(future)); + } + }; + } + } + GraphQLOverWebSocketMessage::Stop if has_started_value => { + let request_id = request.id.unwrap_or("1".to_owned()); + if let Some(handler) = &self.handler { + let context = self.graphql_context.deref(); + let state = SubscriptionState::OnOperationComplete(context); + handler.handle(state).unwrap(); + } + match self.map_req_id_to_spawn_handle.remove(&request_id) { + Some(spawn_handle) => { + ctx.cancel_future(spawn_handle); + ctx.text(Self::gql_complete(&request_id)); + } + None => { + // No request with this id was found in progress. + // since the Subscription Protocol Spec does not specify + // what occurs in this case im just considering the possibility + // of send a error. + } + } + } + GraphQLOverWebSocketMessage::ConnectionTerminate => { + if let Some(handler) = &self.handler { + let context = self.graphql_context.deref(); + let state = SubscriptionState::OnDisconnect(context); + handler.handle(state).unwrap(); + } + ctx.stop(); + } + _ => {} + } + } + ws::Message::Close(_) => { + if let Some(handler) = &self.handler { + let context = self.graphql_context.deref(); + let state = SubscriptionState::OnDisconnect(context); + handler.handle(state).unwrap(); + } + ctx.stop(); + } + _ => { + // Non Text messages are not allowed + ctx.stop(); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use actix_web::{test, App}; + use futures::StreamExt; + + #[actix_rt::test] + async fn expected_communication() { + use actix_web::HttpRequest; + use actix_web_actors::ws::{Frame, Message}; + use futures::{SinkExt, Stream}; + use juniper::{DefaultScalarValue, EmptyMutation, FieldError, RootNode}; + use juniper_subscriptions::Coordinator; + use std::{pin::Pin, time::Duration}; + + pub struct Query; + + #[juniper::graphql_object(Context = Database)] + impl Query { + fn hello_world() -> &str { + "Hello World!" + } + } + type Schema = RootNode<'static, Query, EmptyMutation, Subscription>; + type StringStream = Pin> + Send>>; + type MyCoordinator = Coordinator< + 'static, + Query, + EmptyMutation, + Subscription, + Database, + DefaultScalarValue, + >; + struct Subscription; + + #[derive(Clone)] + pub struct Database; + + impl juniper::Context for Database {} + + impl Database { + fn new() -> Self { + Self {} + } + } + + #[juniper::graphql_subscription(Context = Database)] + impl Subscription { + async fn hello_world() -> StringStream { + let mut counter = 0; + let stream = tokio::time::interval(Duration::from_secs(2)).map(move |_| { + counter += 1; + if counter % 2 == 0 { + Ok(String::from("World!")) + } else { + Ok(String::from("Hello")) + } + }); + Box::pin(stream) + } + } + + let schema: Schema = + RootNode::new(Query, EmptyMutation::::new(), Subscription {}); + + async fn gql_subscriptions( + coordinator: web::Data, + stream: web::Payload, + req: HttpRequest, + ) -> Result { + let context = Database::new(); + graphql_subscriptions( + coordinator, + context, + stream, + req, + Some(EmptySubscriptionHandler::default()), + None, + ) + .await + } + let coord = web::Data::new(juniper_subscriptions::Coordinator::new(schema)); + let mut app = test::start(move || { + App::new() + .app_data(coord.clone()) + .service(web::resource("/subscriptions").to(gql_subscriptions)) + }); + let mut ws = app.ws_at("/subscriptions").await.unwrap(); + let messages_to_be_sent = vec![ + String::from(r#"{"type":"connection_init","payload":{}}"#), + String::from( + r#"{"id":"1","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"#, + ), + String::from( + r#"{"id":"2","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"#, + ), + String::from(r#"{"id":"1","type":"stop"}"#), + String::from(r#"{"type":"connection_terminate"}"#), + ]; + let messages_to_be_received = vec![ + vec![ + Some(bytes::Bytes::from( + r#"{"type":"connection_ack", "payload": null }"#, + )), + Some(bytes::Bytes::from(r#"{"type":"ka", "payload": null }"#)), + ], + vec![Some(bytes::Bytes::from( + r#"{"type":"data","id":"1","payload":{"data":{"helloWorld":"Hello"}} }"#, + ))], + vec![Some(bytes::Bytes::from( + r#"{"type":"data","id":"2","payload":{"data":{"helloWorld":"Hello"}} }"#, + ))], + vec![Some(bytes::Bytes::from( + r#"{"type":"complete","id":"1","payload":null}"#, + ))], + vec![None], + ]; + + for (index, msg_to_be_sent) in messages_to_be_sent.into_iter().enumerate() { + let expected_msgs = messages_to_be_received.get(index).unwrap(); + ws.send(Message::Text(msg_to_be_sent)).await.unwrap(); + for expected_msg in expected_msgs { + let (item, ws_stream) = ws.into_future().await; + ws = ws_stream; + match expected_msg { + Some(expected_msg) => { + if let Some(Ok(Frame::Text(msg))) = item { + assert_eq!(msg, expected_msg); + } else { + assert!(false); + } + } + None => assert_eq!(item.is_none(), expected_msg.is_none()), + } + } + } + } +} diff --git a/juniper_subscriptions/Cargo.toml b/juniper_subscriptions/Cargo.toml index cff0f1615..4f9633985 100644 --- a/juniper_subscriptions/Cargo.toml +++ b/juniper_subscriptions/Cargo.toml @@ -8,10 +8,14 @@ documentation = "https://docs.rs/juniper_subscriptions" repository = "https://github.com/graphql-rust/juniper" edition = "2018" +[features] +ws-subscriptions = ["serde", "serde_json"] [dependencies] -futures = "0.3.1" +futures = "0.3" juniper = { version = "0.14.2", path = "../juniper", default-features = false } +serde = { version = "1.0", optional = true, features = ["derive"] } +serde_json = { version = "1.0", optional = true } [dev-dependencies] serde_json = "1.0" diff --git a/juniper_subscriptions/src/lib.rs b/juniper_subscriptions/src/lib.rs index ff6666e13..7b49b12be 100644 --- a/juniper_subscriptions/src/lib.rs +++ b/juniper_subscriptions/src/lib.rs @@ -19,6 +19,24 @@ use juniper::{ BoxFuture, ExecutionError, GraphQLError, GraphQLSubscriptionType, GraphQLTypeAsync, Object, ScalarValue, SubscriptionConnection, SubscriptionCoordinator, Value, ValuesStream, }; +/// Utilities for the implementation of subscriptions over WebSocket +/// +/// This module provides some utilities for the implementation of Subscriptions over +/// WebSocket, such as the [`GraphQLOverWebSocketMessage`] that contains the messages that +/// could be sent by the server or client and a [`SubscriptionStateHandler`] trait that allows +/// the user of the integration to handle some Subscription Life Cycle Events, its based on: +/// +/// - [Subscriptions Transport over WS][SubscriptionsWsProtocol] +/// - [GraphQL Subscriptions LifeCycle Events][GraphQLSubscriptionsLifeCycle] +/// +/// [SubscriptionsWsProtocol]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +/// [GraphQLSubscriptionsLifeCycle]: https://www.apollographql.com/docs/graphql-subscriptions/lifecycle-events/ +/// [`GraphQLOverWebSocketMessage`]: GraphQLOverWebSocketMessage +/// [`SubscriptionStateHandler`]: SubscriptionStateHandler +#[cfg(feature = "ws-subscriptions")] +pub mod ws_util; +#[cfg(feature = "ws-subscriptions")] +pub use ws_util::*; /// Simple [`SubscriptionCoordinator`] implementation: /// - contains the schema diff --git a/juniper_subscriptions/src/ws_util.rs b/juniper_subscriptions/src/ws_util.rs new file mode 100644 index 000000000..825894f28 --- /dev/null +++ b/juniper_subscriptions/src/ws_util.rs @@ -0,0 +1,285 @@ +use juniper::{InputValue, ScalarValue}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; + +/// Enum of Subscription Protocol Message Types over WS +/// to know more access [Subscriptions Transport over WS][SubscriptionsTransportWS] +/// +/// [SubscriptionsTransportWS]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum GraphQLOverWebSocketMessage { + /// Client -> Server + /// Client sends this message after plain websocket connection to start the communication + /// with the server + #[serde(rename = "connection_init")] + ConnectionInit, + /// Server -> Client + /// The server may responses with this message to the GQL_CONNECTION_INIT from client, + /// indicates the server accepted the connection. + #[serde(rename = "connection_ack")] + ConnectionAck, + /// Server -> Client + /// The server may responses with this message to the GQL_CONNECTION_INIT from client, + /// indicates the server rejected the connection. + #[serde(rename = "connection_error")] + ConnectionError, + /// Server -> Client + /// Server message that should be sent right after each GQL_CONNECTION_ACK processed + /// and then periodically to keep the client connection alive. + #[serde(rename = "ka")] + ConnectionKeepAlive, + /// Client -> Server + /// Client sends this message to terminate the connection. + #[serde(rename = "connection_terminate")] + ConnectionTerminate, + /// Client -> Server + /// Client sends this message to execute GraphQL operation + #[serde(rename = "start")] + Start, + /// Server -> Client + /// The server sends this message to transfer the GraphQL execution result from the + /// server to the client, this message is a response for GQL_START message. + #[serde(rename = "data")] + Data, + /// Server -> Client + /// Server sends this message upon a failing operation, before the GraphQL execution, + /// usually due to GraphQL validation errors (resolver errors are part of GQL_DATA message, + /// and will be added as errors array) + #[serde(rename = "error")] + Error, + /// Server -> Client + /// Server sends this message to indicate that a GraphQL operation is done, + /// and no more data will arrive for the specific operation. + #[serde(rename = "complete")] + Complete, + /// Client -> Server + /// Client sends this message in order to stop a running GraphQL operation execution + /// (for example: unsubscribe) + #[serde(rename = "stop")] + Stop, +} + +/// Empty SubscriptionLifeCycleHandler over WS +pub enum SubscriptionState<'a, Context> +where + Context: Send + Sync, +{ + /// The Subscription is at the init of the connection with the client after the + /// server receives the GQL_CONNECTION_INIT message. + OnConnection(Option, &'a mut Context), + /// The Subscription is at the start of a operation after the GQL_START message is + /// is received. + OnOperation(&'a Context), + /// The subscription is on the end of a operation before sending the GQL_COMPLETE + /// message to the client. + OnOperationComplete(&'a Context), + /// The Subscription is terminating the connection with the client. + OnDisconnect(&'a Context), +} + +/// Trait based on the SubscriptionServer [LifeCycleEvents][LifeCycleEvents] +/// +/// [LifeCycleEvents]: https://www.apollographql.com/docs/graphql-subscriptions/lifecycle-events/ +pub trait SubscriptionStateHandler +where + Context: Send + Sync, + E: std::error::Error, +{ + /// This function is called when the state of the Subscription changes + /// with the actual state. + fn handle(&self, _state: SubscriptionState) -> Result<(), E>; +} + +/// A Empty Subscription Handler +#[derive(Default)] +pub struct EmptySubscriptionHandler; + +impl SubscriptionStateHandler for EmptySubscriptionHandler +where + Context: Send + Sync, +{ + fn handle(&self, _state: SubscriptionState) -> Result<(), std::io::Error> { + Ok(()) + } +} + +/// Struct defining the message content sent or received by the server +#[derive(Deserialize, Serialize)] +pub struct WsPayload { + /// ID of the Subscription operation + pub id: Option, + /// Type of the Message + #[serde(rename(deserialize = "type"))] + pub type_name: GraphQLOverWebSocketMessage, + /// Payload of the Message + pub payload: Option, +} + +impl WsPayload { + /// Returns the transformation from the payload Value to a GraphQLPayload + pub fn graphql_payload(&self) -> Option> + where + S: ScalarValue + Send + Sync + 'static, + { + serde_json::from_value(self.payload.clone()?).ok() + } + /// Constructor + pub fn new( + id: Option, + type_name: GraphQLOverWebSocketMessage, + payload: Option, + ) -> Self { + Self { + id, + type_name, + payload, + } + } +} + +/// GraphQLPayload content sent by the client to the server +#[derive(Debug, Deserialize)] +#[serde(bound = "InputValue: Deserialize<'de>")] +pub struct GraphQLPayload +where + S: ScalarValue + Send + Sync + 'static, +{ + /// Variables for the Operation + pub variables: Option>, + /// Extensions + pub extensions: Option>, + /// Name of the Operation to be executed + #[serde(rename(deserialize = "operationName"))] + pub operation_name: Option, + /// Query value of the Operation + pub query: Option, +} + +#[cfg(test)] +pub mod tests { + use super::*; + use juniper::DefaultScalarValue; + use std::sync::atomic::{AtomicBool, Ordering}; + + #[derive(Default)] + struct Context { + pub user_id: Option, + pub has_connected: bool, + pub has_operated: AtomicBool, + pub has_completed_operation: AtomicBool, + pub has_disconnected: AtomicBool, + } + + #[derive(Deserialize)] + struct OnConnPayload { + #[serde(rename = "userId")] + pub user_id: Option, + } + + struct SubStateHandler; + + impl SubscriptionStateHandler for SubStateHandler { + fn handle(&self, state: SubscriptionState) -> Result<(), std::io::Error> { + match state { + SubscriptionState::OnConnection(payload, ctx) => { + if let Some(payload) = payload { + let result = serde_json::from_value::(payload); + if let Ok(payload) = result { + ctx.user_id = payload.user_id; + } + } + ctx.has_connected = true; + } + SubscriptionState::OnOperation(ctx) => { + ctx.has_operated.store(true, Ordering::Relaxed); + } + SubscriptionState::OnOperationComplete(ctx) => { + ctx.has_completed_operation.store(true, Ordering::Relaxed); + } + SubscriptionState::OnDisconnect(ctx) => { + ctx.has_disconnected.store(true, Ordering::Relaxed); + } + }; + Ok(()) + } + } + + const SUB_HANDLER: SubStateHandler = SubStateHandler {}; + + fn implementation_example(msg: &str, ctx: &mut Context) -> bool { + let ws_payload: WsPayload = serde_json::from_str(msg).unwrap(); + match ws_payload.type_name { + GraphQLOverWebSocketMessage::ConnectionInit => { + let state = SubscriptionState::OnConnection(ws_payload.payload, ctx); + SUB_HANDLER.handle(state).unwrap(); + true + } + GraphQLOverWebSocketMessage::ConnectionTerminate => { + let state = SubscriptionState::OnDisconnect(ctx); + SUB_HANDLER.handle(state).unwrap(); + true + } + GraphQLOverWebSocketMessage::Start => { + // Over here you can make usage of the subscriptions coordinator + // to get the connection related to the client request. This is just a + // testing example to show and verify usage of this module. + let _gql_payload: GraphQLPayload = + ws_payload.graphql_payload().unwrap(); + let state = SubscriptionState::OnOperation(ctx); + SUB_HANDLER.handle(state).unwrap(); + true + } + GraphQLOverWebSocketMessage::Stop => { + let state = SubscriptionState::OnOperationComplete(ctx); + SUB_HANDLER.handle(state).unwrap(); + true + } + _ => false, + } + } + + #[test] + fn on_connection() { + let mut ctx = Context::default(); + let type_value = + serde_json::to_string(&GraphQLOverWebSocketMessage::ConnectionInit).unwrap(); + let msg = format!( + r#"{{"type":{}, "payload": {{ "userId": "1" }} }}"#, + type_value + ); + assert!(implementation_example(&msg, &mut ctx)); + assert!(ctx.has_connected); + assert_eq!(ctx.user_id, Some(String::from("1"))); + } + + #[test] + fn on_operation() { + let mut ctx = Context::default(); + let type_value = serde_json::to_string(&GraphQLOverWebSocketMessage::Start).unwrap(); + let msg = format!(r#"{{"type":{}, "payload": {{}}, "id": "1" }}"#, type_value); + assert!(implementation_example(&msg, &mut ctx)); + assert!(ctx.has_operated.load(Ordering::Relaxed)); + } + + #[test] + fn on_operation_completed() { + let mut ctx = Context::default(); + let type_value = serde_json::to_string(&GraphQLOverWebSocketMessage::Stop).unwrap(); + let msg = format!(r#"{{"type":{}, "payload": null, "id": "1" }}"#, type_value); + assert!(implementation_example(&msg, &mut ctx)); + let has_completed = ctx.has_completed_operation.load(Ordering::Relaxed); + assert!(has_completed); + } + + #[test] + fn on_disconnect() { + let mut ctx = Context::default(); + let type_value = + serde_json::to_string(&GraphQLOverWebSocketMessage::ConnectionTerminate).unwrap(); + let msg = format!(r#"{{"type":{}, "payload": null, "id": "1" }}"#, type_value); + assert!(implementation_example(&msg, &mut ctx)); + let has_disconnected = ctx.has_disconnected.load(Ordering::Relaxed); + assert!(has_disconnected); + } +} From b13507c886f64de31691ddf2314ab7fe1b8432ad Mon Sep 17 00:00:00 2001 From: Jordao Rosario Date: Sat, 16 May 2020 11:22:26 -0300 Subject: [PATCH 2/4] Small improvement in juniper_actix subscription test --- juniper_actix/src/subscriptions.rs | 237 ++++++++++++++++------------- 1 file changed, 133 insertions(+), 104 deletions(-) diff --git a/juniper_actix/src/subscriptions.rs b/juniper_actix/src/subscriptions.rs index b57e2c1b3..6747944a8 100644 --- a/juniper_actix/src/subscriptions.rs +++ b/juniper_actix/src/subscriptions.rs @@ -413,124 +413,89 @@ where #[cfg(test)] mod tests { use super::*; + use actix_web::HttpRequest; use actix_web::{test, App}; + use actix_web_actors::ws::{Frame, Message}; use futures::StreamExt; + use futures::{SinkExt, Stream}; + use juniper::{ + tests::model::Database, tests::schema::Query, DefaultScalarValue, EmptyMutation, + FieldError, RootNode, + }; + use juniper_subscriptions::Coordinator; + use std::{pin::Pin, time::Duration}; + type Schema = RootNode<'static, Query, EmptyMutation, Subscription>; + type StringStream = Pin> + Send>>; + type MyCoordinator = Coordinator< + 'static, + Query, + EmptyMutation, + Subscription, + Database, + DefaultScalarValue, + >; + struct Subscription; - #[actix_rt::test] - async fn expected_communication() { - use actix_web::HttpRequest; - use actix_web_actors::ws::{Frame, Message}; - use futures::{SinkExt, Stream}; - use juniper::{DefaultScalarValue, EmptyMutation, FieldError, RootNode}; - use juniper_subscriptions::Coordinator; - use std::{pin::Pin, time::Duration}; - - pub struct Query; - - #[juniper::graphql_object(Context = Database)] - impl Query { - fn hello_world() -> &str { - "Hello World!" - } - } - type Schema = RootNode<'static, Query, EmptyMutation, Subscription>; - type StringStream = Pin> + Send>>; - type MyCoordinator = Coordinator< - 'static, - Query, - EmptyMutation, - Subscription, - Database, - DefaultScalarValue, - >; - struct Subscription; - - #[derive(Clone)] - pub struct Database; - - impl juniper::Context for Database {} - - impl Database { - fn new() -> Self { - Self {} - } + #[juniper::graphql_subscription(Context = Database)] + impl Subscription { + async fn hello_world() -> StringStream { + let mut counter = 0; + let stream = tokio::time::interval(Duration::from_secs(2)).map(move |_| { + counter += 1; + if counter % 2 == 0 { + Ok(String::from("World!")) + } else { + Ok(String::from("Hello")) + } + }); + Box::pin(stream) } + } - #[juniper::graphql_subscription(Context = Database)] - impl Subscription { - async fn hello_world() -> StringStream { - let mut counter = 0; - let stream = tokio::time::interval(Duration::from_secs(2)).map(move |_| { - counter += 1; - if counter % 2 == 0 { - Ok(String::from("World!")) - } else { - Ok(String::from("Hello")) - } - }); - Box::pin(stream) - } - } + async fn gql_subscriptions( + coordinator: web::Data, + stream: web::Payload, + req: HttpRequest, + ) -> Result { + let context = Database::new(); + graphql_subscriptions( + coordinator, + context, + stream, + req, + Some(EmptySubscriptionHandler::default()), + None, + ) + .await + } + fn test_server() -> test::TestServer { let schema: Schema = RootNode::new(Query, EmptyMutation::::new(), Subscription {}); - async fn gql_subscriptions( - coordinator: web::Data, - stream: web::Payload, - req: HttpRequest, - ) -> Result { - let context = Database::new(); - graphql_subscriptions( - coordinator, - context, - stream, - req, - Some(EmptySubscriptionHandler::default()), - None, - ) - .await - } let coord = web::Data::new(juniper_subscriptions::Coordinator::new(schema)); - let mut app = test::start(move || { + test::start(move || { App::new() .app_data(coord.clone()) .service(web::resource("/subscriptions").to(gql_subscriptions)) - }); - let mut ws = app.ws_at("/subscriptions").await.unwrap(); - let messages_to_be_sent = vec![ - String::from(r#"{"type":"connection_init","payload":{}}"#), - String::from( - r#"{"id":"1","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"#, - ), - String::from( - r#"{"id":"2","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"#, - ), - String::from(r#"{"id":"1","type":"stop"}"#), - String::from(r#"{"type":"connection_terminate"}"#), - ]; - let messages_to_be_received = vec![ - vec![ - Some(bytes::Bytes::from( - r#"{"type":"connection_ack", "payload": null }"#, - )), - Some(bytes::Bytes::from(r#"{"type":"ka", "payload": null }"#)), - ], - vec![Some(bytes::Bytes::from( - r#"{"type":"data","id":"1","payload":{"data":{"helloWorld":"Hello"}} }"#, - ))], - vec![Some(bytes::Bytes::from( - r#"{"type":"data","id":"2","payload":{"data":{"helloWorld":"Hello"}} }"#, - ))], - vec![Some(bytes::Bytes::from( - r#"{"type":"complete","id":"1","payload":null}"#, - ))], - vec![None], - ]; + }) + } + + fn received_msg(msg: &'static str) -> Option { + Some(bytes::Bytes::from(msg)) + } - for (index, msg_to_be_sent) in messages_to_be_sent.into_iter().enumerate() { - let expected_msgs = messages_to_be_received.get(index).unwrap(); - ws.send(Message::Text(msg_to_be_sent)).await.unwrap(); + async fn test_subscription( + msgs_to_send: Vec<&str>, + msgs_to_receive: Vec>>, + ) { + let mut app = test_server(); + let mut ws = app.ws_at("/subscriptions").await.unwrap(); + for (index, msg_to_be_sent) in msgs_to_send.into_iter().enumerate() { + let expected_msgs = msgs_to_receive.get(index).unwrap(); + ws.send(Message::Text(msg_to_be_sent.to_string())) + .await + .unwrap(); for expected_msg in expected_msgs { let (item, ws_stream) = ws.into_future().await; ws = ws_stream; @@ -547,4 +512,68 @@ mod tests { } } } + + #[actix_rt::test] + async fn basic_connection() { + let msgs_to_send = vec![ + r#"{"type":"connection_init","payload":{}}"#, + r#"{"type":"connection_terminate"}"#, + ]; + let msgs_to_receive = vec![ + vec![ + received_msg(r#"{"type":"connection_ack", "payload": null }"#), + received_msg(r#"{"type":"ka", "payload": null }"#), + ], + vec![None], + ]; + test_subscription(msgs_to_send, msgs_to_receive).await; + } + + #[actix_rt::test] + async fn basic_subscription() { + let msgs_to_send = vec![ + r#"{"type":"connection_init","payload":{}}"#, + r#"{"id":"1","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"#, + r#"{"type":"connection_terminate"}"#, + ]; + let msgs_to_receive = vec![ + vec![ + received_msg(r#"{"type":"connection_ack", "payload": null }"#), + received_msg(r#"{"type":"ka", "payload": null }"#), + ], + vec![received_msg( + r#"{"type":"data","id":"1","payload":{"data":{"helloWorld":"Hello"}} }"#, + )], + vec![None], + ]; + test_subscription(msgs_to_send, msgs_to_receive).await; + } + + #[actix_rt::test] + async fn conn_with_two_subscriptions() { + let msgs_to_send = vec![ + r#"{"type":"connection_init","payload":{}}"#, + r#"{"id":"1","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"#, + r#"{"id":"2","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"#, + r#"{"id":"1","type":"stop"}"#, + r#"{"type":"connection_terminate"}"#, + ]; + let msgs_to_receive = vec![ + vec![ + received_msg(r#"{"type":"connection_ack", "payload": null }"#), + received_msg(r#"{"type":"ka", "payload": null }"#), + ], + vec![received_msg( + r#"{"type":"data","id":"1","payload":{"data":{"helloWorld":"Hello"}} }"#, + )], + vec![received_msg( + r#"{"type":"data","id":"2","payload":{"data":{"helloWorld":"Hello"}} }"#, + )], + vec![received_msg( + r#"{"type":"complete","id":"1","payload":null}"#, + )], + vec![None], + ]; + test_subscription(msgs_to_send, msgs_to_receive).await; + } } From df43e2ff09b60ffc1308eea9dd0bcc028cf5b914 Mon Sep 17 00:00:00 2001 From: Jordao Rosario Date: Wed, 27 May 2020 21:30:04 -0300 Subject: [PATCH 3/4] Changes in the actix subscriptions implementation --- juniper_actix/examples/actix_subscriptions.rs | 77 +++-- juniper_actix/src/subscriptions.rs | 279 ++++++++---------- juniper_subscriptions/src/ws_util.rs | 27 +- 3 files changed, 172 insertions(+), 211 deletions(-) diff --git a/juniper_actix/examples/actix_subscriptions.rs b/juniper_actix/examples/actix_subscriptions.rs index 6c984ff94..08d35004f 100644 --- a/juniper_actix/examples/actix_subscriptions.rs +++ b/juniper_actix/examples/actix_subscriptions.rs @@ -6,7 +6,10 @@ use futures::Stream; use juniper::{DefaultScalarValue, FieldError, RootNode}; use juniper_actix::{ graphiql_handler as gqli_handler, graphql_handler, playground_handler as play_handler, - subscriptions::{graphql_subscriptions as sub_handler, EmptySubscriptionHandler}, + subscriptions::{ + graphql_subscriptions as sub_handler, GraphQLWSSession, SubscriptionState, + SubscriptionStateHandler, + }, }; use juniper_subscriptions::Coordinator; use std::collections::hash_map::Entry; @@ -23,6 +26,7 @@ type MyCoordinator = struct ChatRoom { pub name: String, pub channel: (Sender, Receiver), + pub msgs: Vec, } impl ChatRoom { @@ -30,6 +34,7 @@ impl ChatRoom { Self { name, channel: channel(16), + msgs: Vec::new(), } } } @@ -58,6 +63,14 @@ impl Query { .map(|(_, chat_room)| chat_room.name.clone()) .collect() } + + pub fn msgs_from_room(room_name: String, ctx: &Context) -> Option> { + ctx.chat_rooms + .lock() + .unwrap() + .get(&room_name) + .map(|chat_room| chat_room.msgs.clone()) + } } struct Mutation; @@ -68,20 +81,18 @@ impl Mutation { ctx.chat_rooms .lock() .unwrap() - .get(&room_name) + .get_mut(&room_name) .map(|chat_room| { let now = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap_or(Duration::new(0, 0)); - chat_room - .channel - .0 - .send(Msg { - sender, - value: msg, - date: format!("{}", now.as_secs()), - }) - .is_ok() + let msg = Msg { + sender, + value: msg, + date: format!("{}", now.as_secs()), + }; + chat_room.msgs.push(msg.clone()); + chat_room.channel.0.send(msg).is_ok() }) .is_some() } @@ -96,7 +107,7 @@ struct Msg { type StringStream = Pin> + Send>>; -type VecStringStream = Pin, FieldError>> + Send>>; +type MsgStream = Pin> + Send>>; struct Subscription; @@ -116,19 +127,14 @@ impl Subscription { Box::pin(stream) } - async fn chat_room(room_name: String, ctx: &Context) -> VecStringStream { - let mut messages: Vec = Vec::new(); + async fn new_messages(room_name: String, ctx: &Context) -> MsgStream { let channel_rx = { match ctx.chat_rooms.lock().unwrap().entry(room_name.clone()) { Entry::Occupied(o) => o.get().channel.0.subscribe(), Entry::Vacant(v) => v.insert(ChatRoom::new(room_name)).channel.0.subscribe(), } }; - let stream = channel_rx.map(move |msg| { - let msg = msg?; - messages.push(msg); - Ok(messages.clone()) - }); + let stream = channel_rx.map(|msg| Ok(msg?)); Box::pin(stream) } } @@ -138,7 +144,7 @@ fn schema() -> Schema { } async fn graphiql_handler() -> Result { - gqli_handler("/", Some("/subscriptions")).await + gqli_handler("/", Some("ws://localhost:8080/subscriptions")).await } async fn playground_handler() -> Result { play_handler("/", Some("/subscriptions")).await @@ -154,6 +160,24 @@ async fn graphql( graphql_handler(&schema, &context, req, payload).await } +#[derive(Default)] +struct HandlerExample {} + +impl SubscriptionStateHandler for HandlerExample +where + Context: Send + Sync, +{ + fn handle(&self, state: SubscriptionState) -> Result<(), Box> { + match state { + SubscriptionState::OnConnection(_, _) => println!("OnConnection"), + SubscriptionState::OnOperation(_) => println!("OnOperation"), + SubscriptionState::OnOperationComplete(_) => println!("OnOperationComplete"), + SubscriptionState::OnDisconnect(_) => println!("OnDisconnect"), + }; + Ok(()) + } +} + async fn graphql_subscriptions( coordinator: web::Data, stream: web::Payload, @@ -161,16 +185,9 @@ async fn graphql_subscriptions( chat_rooms: web::Data>>, ) -> Result { let context = Context::new(chat_rooms.into_inner()); - let handler: Option = None; - sub_handler( - coordinator, - context, - stream, - req, - handler, - Some(Duration::from_secs(5)), - ) - .await + let actor = GraphQLWSSession::new(coordinator.into_inner(), context); + let actor = actor.with_handler(HandlerExample::default()); + sub_handler(actor, stream, req).await } #[actix_rt::main] diff --git a/juniper_actix/src/subscriptions.rs b/juniper_actix/src/subscriptions.rs index 6747944a8..43b3550a6 100644 --- a/juniper_actix/src/subscriptions.rs +++ b/juniper_actix/src/subscriptions.rs @@ -2,17 +2,16 @@ use actix::{ Actor, ActorContext, ActorFuture, AsyncContext, Handler, Message, Recipient, SpawnHandle, StreamHandler, WrapFuture, }; -use actix_web::{error::PayloadError, web, web::Bytes, Error, HttpRequest, HttpResponse}; +use actix_web::{web, Error, HttpRequest, HttpResponse}; use actix_web_actors::{ ws, ws::{handshake_with_protocols, WebsocketContext}, }; -use futures::{Stream, StreamExt}; +use futures::StreamExt; use juniper::{http::GraphQLRequest, ScalarValue, SubscriptionCoordinator}; use juniper_subscriptions::ws_util::GraphQLOverWebSocketMessage; pub use juniper_subscriptions::ws_util::{ - EmptySubscriptionHandler, GraphQLPayload, SubscriptionState, SubscriptionStateHandler, - WsPayload, + GraphQLPayload, SubscriptionState, SubscriptionStateHandler, WsPayload, }; use juniper_subscriptions::Coordinator; use serde::Serialize; @@ -20,12 +19,9 @@ use std::ops::Deref; use std::{ collections::HashMap, error::Error as StdError, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::Arc, + time::{Duration, Instant}, }; -use tokio::time::Duration; /// Websocket Subscription Handler /// @@ -36,13 +32,10 @@ use tokio::time::Duration; /// * `req` - The Initial Request sent by the Client /// * `handler` - The SubscriptionStateHandler implementation that will be used in the Subscription. /// * `ka_interval` - The Duration that will be used to interleave the keep alive messages sent by the server. The default value is 10 seconds. -pub async fn graphql_subscriptions( - coordinator: web::Data>, - context: Context, +pub async fn graphql_subscriptions( + actor: GraphQLWSSession, stream: web::Payload, req: HttpRequest, - handler: Option, - ka_interval: Option, ) -> Result where S: ScalarValue + Send + Sync + 'static, @@ -53,47 +46,13 @@ where Mutation::TypeInfo: Send + Sync, Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, Subscription::TypeInfo: Send + Sync, - SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, - E: 'static + std::error::Error + std::marker::Unpin, { - start( - GraphQLWSSession { - coordinator: coordinator.into_inner(), - graphql_context: Arc::new(context), - map_req_id_to_spawn_handle: HashMap::new(), - has_started: Arc::new(AtomicBool::new(false)), - handler, - error_handler: std::marker::PhantomData, - ka_interval: ka_interval.unwrap_or_else(|| Duration::from_secs(10)), - }, - &req, - stream, - ) -} - -fn start( - actor: GraphQLWSSession, - req: &HttpRequest, - stream: T, -) -> Result -where - T: Stream> + 'static, - S: ScalarValue + Send + Sync + 'static, - Context: Send + Sync + 'static + std::marker::Unpin, - Query: juniper::GraphQLTypeAsync + Send + Sync + 'static, - Query::TypeInfo: Send + Sync, - Mutation: juniper::GraphQLTypeAsync + Send + Sync + 'static, - Mutation::TypeInfo: Send + Sync, - Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, - Subscription::TypeInfo: Send + Sync, - SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, - E: 'static + std::error::Error + std::marker::Unpin, -{ - let mut res = handshake_with_protocols(req, &["graphql-ws"])?; + let mut res = handshake_with_protocols(&req, &["graphql-ws"])?; Ok(res.streaming(WebsocketContext::create(actor, stream))) } -struct GraphQLWSSession +/// Actor for handling each WS Session +pub struct GraphQLWSSession where S: ScalarValue + Send + Sync + 'static, Context: Send + Sync + 'static + std::marker::Unpin, @@ -103,20 +62,16 @@ where Mutation::TypeInfo: Send + Sync, Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, Subscription::TypeInfo: Send + Sync, - SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, - E: 'static + std::error::Error + std::marker::Unpin, { - pub map_req_id_to_spawn_handle: HashMap, - pub has_started: Arc, - pub graphql_context: Arc, - pub coordinator: Arc>, - pub handler: Option, - pub ka_interval: Duration, - error_handler: std::marker::PhantomData, + map_req_id_to_spawn_handle: HashMap, + graphql_context: Arc, + coordinator: Arc>, + handler: Option + 'static + std::marker::Unpin>>, + hb: Instant, } -impl Actor - for GraphQLWSSession +impl Actor + for GraphQLWSSession where S: ScalarValue + Send + Sync + 'static, Context: Send + Sync + 'static + std::marker::Unpin, @@ -126,12 +81,13 @@ where Mutation::TypeInfo: Send + Sync, Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, Subscription::TypeInfo: Send + Sync, - SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, - E: 'static + std::error::Error + std::marker::Unpin, { - type Context = ws::WebsocketContext< - GraphQLWSSession, - >; + type Context = + ws::WebsocketContext>; + + fn started(&mut self, ctx: &mut Self::Context) { + self.hb(ctx); + } } /// Internal Struct for handling Messages received from the subscriptions @@ -139,8 +95,8 @@ where #[rtype(result = "()")] struct Msg(pub Option); -impl Handler - for GraphQLWSSession +impl Handler + for GraphQLWSSession where S: ScalarValue + Send + Sync + 'static, Context: Send + Sync + 'static + std::marker::Unpin, @@ -150,8 +106,6 @@ where Mutation::TypeInfo: Send + Sync, Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, Subscription::TypeInfo: Send + Sync, - SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, - E: 'static + std::error::Error + std::marker::Unpin, { type Result = (); fn handle(&mut self, msg: Msg, ctx: &mut Self::Context) { @@ -162,9 +116,14 @@ where } } +/// How often heartbeat pings are sent +const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); +/// How long before lack of client response causes a timeout +const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); + #[allow(dead_code)] -impl - GraphQLWSSession +impl + GraphQLWSSession where S: ScalarValue + Send + Sync + 'static, Context: Send + Sync + 'static + std::marker::Unpin, @@ -174,54 +133,70 @@ where Mutation::TypeInfo: Send + Sync, Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, Subscription::TypeInfo: Send + Sync, - SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, - E: 'static + std::error::Error + std::marker::Unpin, { - fn gql_connection_ack() -> String { - let type_value = - serde_json::to_string(&GraphQLOverWebSocketMessage::ConnectionAck).unwrap(); - format!(r#"{{"type":{}, "payload": null }}"#, type_value) + /// Creates a instance for usage in the graphql_subscription endpoint + pub fn new( + coord: Arc>, + ctx: Context, + ) -> Self { + Self { + coordinator: coord, + graphql_context: Arc::new(ctx), + map_req_id_to_spawn_handle: HashMap::new(), + handler: None, + hb: Instant::now(), + } + } + + /// Inserts a SubscriptionStateHandler in the Session + pub fn with_handler(self, handler: H) -> Self + where + H: SubscriptionStateHandler + 'static + std::marker::Unpin, + { + Self { + handler: Some(Box::new(handler)), + ..self + } } - fn gql_connection_ka() -> String { - let type_value = - serde_json::to_string(&GraphQLOverWebSocketMessage::ConnectionKeepAlive).unwrap(); - format!(r#"{{"type":{}, "payload": null }}"#, type_value) + fn gql_connection_ack() -> String { + let value = serde_json::json!({ "type": GraphQLOverWebSocketMessage::ConnectionAck }); + serde_json::to_string(&value).unwrap() } fn gql_connection_error() -> String { - let type_value = - serde_json::to_string(&GraphQLOverWebSocketMessage::ConnectionError).unwrap(); - format!(r#"{{"type":{}, "payload": null }}"#, type_value) + let value = serde_json::json!({ + "type": GraphQLOverWebSocketMessage::ConnectionError, + }); + serde_json::to_string(&value).unwrap() } fn gql_error(request_id: &String, err: T) -> String { - let type_value = serde_json::to_string(&GraphQLOverWebSocketMessage::Error).unwrap(); - format!( - r#"{{"type":{},"id":"{}","payload":{}}}"#, - type_value, - request_id, - serde_json::ser::to_string(&err) - .unwrap_or("Error deserializing GraphQLError".to_owned()) - ) + let value = serde_json::json!({ + "type": GraphQLOverWebSocketMessage::Error, + "id": request_id, + "payload": err + }); + serde_json::to_string(&value).unwrap() } - fn gql_data(request_id: &String, response_text: String) -> String { - let type_value = serde_json::to_string(&GraphQLOverWebSocketMessage::Data).unwrap(); - format!( - r#"{{"type":{},"id":"{}","payload":{} }}"#, - type_value, request_id, response_text - ) + fn gql_data(request_id: &String, payload: T) -> String { + let value = serde_json::json!({ + "type": GraphQLOverWebSocketMessage::Data, + "id": request_id, + "payload": payload + }); + serde_json::to_string(&value).unwrap() } fn gql_complete(request_id: &String) -> String { - let type_value = serde_json::to_string(&GraphQLOverWebSocketMessage::Complete).unwrap(); - format!( - r#"{{"type":{},"id":"{}","payload":null}}"#, - type_value, request_id - ) + let value = serde_json::json!({ + "type": GraphQLOverWebSocketMessage::Complete, + "id": request_id, + }); + serde_json::to_string(&value).unwrap() } - fn starting_handle( + fn starting_subscription( result: ( GraphQLRequest, String, @@ -259,17 +234,25 @@ where while let Some(response) = values_stream.next().await { let request_id = request_id.clone(); - let response_text = serde_json::to_string(&response) - .unwrap_or("Error deserializing respone".to_owned()); - let _ = addr.do_send(Msg(Some(Self::gql_data(&request_id, response_text)))); + let _ = addr.do_send(Msg(Some(Self::gql_data(&request_id, response)))); } let _ = addr.do_send(Msg(Some(Self::gql_complete(&request_id)))); } + + fn hb(&self, ctx: &mut ::Context) { + ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { + if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT { + ctx.stop(); + return; + } + ctx.ping(b""); + }); + } } -impl +impl StreamHandler> - for GraphQLWSSession + for GraphQLWSSession where S: ScalarValue + Send + Sync + 'static, Context: Send + Sync + 'static + std::marker::Unpin, @@ -279,8 +262,6 @@ where Mutation::TypeInfo: Send + Sync, Subscription: juniper::GraphQLSubscriptionType + Send + Sync + 'static, Subscription::TypeInfo: Send + Sync, - SubHandler: SubscriptionStateHandler + 'static + std::marker::Unpin, - E: 'static + std::error::Error + std::marker::Unpin, { fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { let msg = match msg { @@ -290,9 +271,14 @@ where } Ok(msg) => msg, }; - let has_started = self.has_started.clone(); - let has_started_value = has_started.load(Ordering::Relaxed); match msg { + ws::Message::Ping(msg) => { + self.hb = Instant::now(); + ctx.pong(&msg); + } + ws::Message::Pong(_) => { + self.hb = Instant::now(); + } ws::Message::Text(text) => { let m = text.trim(); let request: WsPayload = match serde_json::from_str(m) { @@ -309,25 +295,15 @@ where Arc::get_mut(&mut self.graphql_context).unwrap(), ); let on_connect_result = handler.handle(state); - if let Err(_err) = on_connect_result { + if let Err(_) = on_connect_result { ctx.text(Self::gql_connection_error()); ctx.stop(); return; } } ctx.text(Self::gql_connection_ack()); - ctx.text(Self::gql_connection_ka()); - has_started.store(true, Ordering::Relaxed); - ctx.run_interval(self.ka_interval, |actor, ctx| { - let no_request = actor.map_req_id_to_spawn_handle.len() == 0; - if no_request { - ctx.stop(); - } else { - ctx.text(Self::gql_connection_ka()); - } - }); } - GraphQLOverWebSocketMessage::Start if has_started_value => { + GraphQLOverWebSocketMessage::Start => { let coordinator = self.coordinator.clone(); let payload = request @@ -342,7 +318,7 @@ where if let Some(handler) = &self.handler { let state = SubscriptionState::OnOperation(self.graphql_context.deref()); - handler.handle(state).unwrap(); + handler.as_ref().handle(state).unwrap(); } let context = self.graphql_context.clone(); { @@ -351,7 +327,7 @@ where let future = async move { (graphql_request, req_id, context, coordinator) } .into_actor(self) - .then(Self::starting_handle); + .then(Self::starting_subscription); match self.map_req_id_to_spawn_handle.entry(request_id) { // Since there is another request being handle // this just ignores the start of another request with this same @@ -363,7 +339,7 @@ where }; } } - GraphQLOverWebSocketMessage::Stop if has_started_value => { + GraphQLOverWebSocketMessage::Stop => { let request_id = request.id.unwrap_or("1".to_owned()); if let Some(handler) = &self.handler { let context = self.graphql_context.deref(); @@ -394,7 +370,7 @@ where _ => {} } } - ws::Message::Close(_) => { + ws::Message::Binary(_) | ws::Message::Close(_) | ws::Message::Continuation(_) => { if let Some(handler) = &self.handler { let context = self.graphql_context.deref(); let state = SubscriptionState::OnDisconnect(context); @@ -402,10 +378,7 @@ where } ctx.stop(); } - _ => { - // Non Text messages are not allowed - ctx.stop(); - } + ws::Message::Nop => {} } } } @@ -458,15 +431,8 @@ mod tests { req: HttpRequest, ) -> Result { let context = Database::new(); - graphql_subscriptions( - coordinator, - context, - stream, - req, - Some(EmptySubscriptionHandler::default()), - None, - ) - .await + let actor = GraphQLWSSession::new(coordinator.into_inner(), context); + graphql_subscriptions(actor, stream, req).await } fn test_server() -> test::TestServer { @@ -520,10 +486,7 @@ mod tests { r#"{"type":"connection_terminate"}"#, ]; let msgs_to_receive = vec![ - vec![ - received_msg(r#"{"type":"connection_ack", "payload": null }"#), - received_msg(r#"{"type":"ka", "payload": null }"#), - ], + vec![received_msg(r#"{"type":"connection_ack"}"#)], vec![None], ]; test_subscription(msgs_to_send, msgs_to_receive).await; @@ -537,12 +500,9 @@ mod tests { r#"{"type":"connection_terminate"}"#, ]; let msgs_to_receive = vec![ - vec![ - received_msg(r#"{"type":"connection_ack", "payload": null }"#), - received_msg(r#"{"type":"ka", "payload": null }"#), - ], + vec![received_msg(r#"{"type":"connection_ack"}"#)], vec![received_msg( - r#"{"type":"data","id":"1","payload":{"data":{"helloWorld":"Hello"}} }"#, + r#"{"type":"data","id":"1","payload":{"data":{"helloWorld":"Hello"}}}"#, )], vec![None], ]; @@ -559,19 +519,14 @@ mod tests { r#"{"type":"connection_terminate"}"#, ]; let msgs_to_receive = vec![ - vec![ - received_msg(r#"{"type":"connection_ack", "payload": null }"#), - received_msg(r#"{"type":"ka", "payload": null }"#), - ], - vec![received_msg( - r#"{"type":"data","id":"1","payload":{"data":{"helloWorld":"Hello"}} }"#, - )], + vec![received_msg(r#"{"type":"connection_ack"}"#)], vec![received_msg( - r#"{"type":"data","id":"2","payload":{"data":{"helloWorld":"Hello"}} }"#, + r#"{"type":"data","id":"1","payload":{"data":{"helloWorld":"Hello"}}}"#, )], vec![received_msg( - r#"{"type":"complete","id":"1","payload":null}"#, + r#"{"type":"data","id":"2","payload":{"data":{"helloWorld":"Hello"}}}"#, )], + vec![received_msg(r#"{"type":"complete","id":"1"}"#)], vec![None], ]; test_subscription(msgs_to_send, msgs_to_receive).await; diff --git a/juniper_subscriptions/src/ws_util.rs b/juniper_subscriptions/src/ws_util.rs index 825894f28..df94fb03a 100644 --- a/juniper_subscriptions/src/ws_util.rs +++ b/juniper_subscriptions/src/ws_util.rs @@ -81,27 +81,13 @@ where /// Trait based on the SubscriptionServer [LifeCycleEvents][LifeCycleEvents] /// /// [LifeCycleEvents]: https://www.apollographql.com/docs/graphql-subscriptions/lifecycle-events/ -pub trait SubscriptionStateHandler +pub trait SubscriptionStateHandler where Context: Send + Sync, - E: std::error::Error, { /// This function is called when the state of the Subscription changes /// with the actual state. - fn handle(&self, _state: SubscriptionState) -> Result<(), E>; -} - -/// A Empty Subscription Handler -#[derive(Default)] -pub struct EmptySubscriptionHandler; - -impl SubscriptionStateHandler for EmptySubscriptionHandler -where - Context: Send + Sync, -{ - fn handle(&self, _state: SubscriptionState) -> Result<(), std::io::Error> { - Ok(()) - } + fn handle(&self, _state: SubscriptionState) -> Result<(), Box>; } /// Struct defining the message content sent or received by the server @@ -126,8 +112,8 @@ impl WsPayload { } /// Constructor pub fn new( - id: Option, type_name: GraphQLOverWebSocketMessage, + id: Option, payload: Option, ) -> Self { Self { @@ -179,8 +165,11 @@ pub mod tests { struct SubStateHandler; - impl SubscriptionStateHandler for SubStateHandler { - fn handle(&self, state: SubscriptionState) -> Result<(), std::io::Error> { + impl SubscriptionStateHandler for SubStateHandler { + fn handle( + &self, + state: SubscriptionState, + ) -> Result<(), Box> { match state { SubscriptionState::OnConnection(payload, ctx) => { if let Some(payload) = payload { From 97eee2b91f27f37be825f562770f5be1902b977f Mon Sep 17 00:00:00 2001 From: Jordao Rosario Date: Wed, 10 Jun 2020 19:25:45 -0300 Subject: [PATCH 4/4] Support for binary msgs that can be converted into text --- juniper_actix/src/subscriptions.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/juniper_actix/src/subscriptions.rs b/juniper_actix/src/subscriptions.rs index 43b3550a6..e397aea55 100644 --- a/juniper_actix/src/subscriptions.rs +++ b/juniper_actix/src/subscriptions.rs @@ -370,7 +370,12 @@ where _ => {} } } - ws::Message::Binary(_) | ws::Message::Close(_) | ws::Message::Continuation(_) => { + ws::Message::Binary(msg) => { + if let Ok(msg) = std::str::from_utf8(msg.as_ref()) { + StreamHandler::handle(self, Ok(ws::Message::Text(String::from(msg))), ctx); + } + }, + ws::Message::Close(_) | ws::Message::Continuation(_) => { if let Some(handler) = &self.handler { let context = self.graphql_context.deref(); let state = SubscriptionState::OnDisconnect(context);