From 31e1c45826f9e31d49ad5853b3359dd25b2c454c Mon Sep 17 00:00:00 2001 From: Christopher Brown Date: Sat, 25 Jul 2020 06:02:49 -0400 Subject: [PATCH 01/10] update pre-existing juniper_warp::subscriptions --- juniper_warp/src/lib.rs | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs index 291ba011a..dbb821870 100644 --- a/juniper_warp/src/lib.rs +++ b/juniper_warp/src/lib.rs @@ -403,11 +403,20 @@ pub mod subscriptions { use anyhow::anyhow; use futures::{channel::mpsc, Future, StreamExt as _, TryFutureExt as _, TryStreamExt as _}; - use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _}; + use juniper::{ + http::GraphQLRequest, ExecutionError, InputValue, ScalarValue, + SubscriptionCoordinator as _, Value, + }; use juniper_subscriptions::Coordinator; use serde::{Deserialize, Serialize}; use warp::ws::Message; + #[derive(Serialize)] + struct DataPayload<'a, S: ScalarValue> { + data: &'a Value, + errors: &'a Vec>, + } + /// Listen to incoming messages and do one of the following: /// - execute subscription and return values from stream /// - stop stream and close ws connection @@ -531,15 +540,19 @@ pub mod subscriptions { }; values_stream - .take_while(move |response| { + .take_while(move |(data, errors)| { let request_id = request_id.clone(); let should_stop = state.should_stop.load(Ordering::Relaxed) || got_close_signal.load(Ordering::Relaxed); if !should_stop { - let mut response_text = serde_json::to_string( - &response, - ) - .unwrap_or("Error deserializing response".to_owned()); + let mut response_text = + serde_json::to_string(&DataPayload { + data, + errors, + }) + .unwrap_or( + "Error deserializing response".to_owned(), + ); response_text = format!( r#"{{"type":"data","id":"{}","payload":{} }}"#, From 3dd6b6b769610f9c58ea954aba5e4dd83325a396 Mon Sep 17 00:00:00 2001 From: Christopher Brown Date: Sat, 25 Jul 2020 18:06:38 -0400 Subject: [PATCH 02/10] initial draft --- Cargo.toml | 1 + juniper_graphql_ws/Cargo.toml | 19 + juniper_graphql_ws/src/client_message.rs | 115 ++++ juniper_graphql_ws/src/lib.rs | 785 +++++++++++++++++++++++ juniper_graphql_ws/src/server_message.rs | 155 +++++ 5 files changed, 1075 insertions(+) create mode 100644 juniper_graphql_ws/Cargo.toml create mode 100644 juniper_graphql_ws/src/client_message.rs create mode 100644 juniper_graphql_ws/src/lib.rs create mode 100644 juniper_graphql_ws/src/server_message.rs diff --git a/Cargo.toml b/Cargo.toml index 79429a10e..d37670be7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ members = [ "juniper_rocket", "juniper_rocket_async", "juniper_subscriptions", + "juniper_graphql_ws", "juniper_warp", "juniper_actix", ] diff --git a/juniper_graphql_ws/Cargo.toml b/juniper_graphql_ws/Cargo.toml new file mode 100644 index 000000000..6f15ca303 --- /dev/null +++ b/juniper_graphql_ws/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "juniper_graphql_ws" +version = "0.1.0" +authors = ["Christopher Brown "] +license = "BSD-2-Clause" +description = "Graphql-ws protocol implementation for Juniper" +documentation = "https://docs.rs/juniper_graphql_ws" +repository = "https://github.com/graphql-rust/juniper" +keywords = ["graphql-ws"] +edition = "2018" + +[dependencies] +juniper = { version = "0.14.2", path = "../juniper", default-features = false } +juniper_subscriptions = { path = "../juniper_subscriptions" } +serde = { version = "1.0.8", features = ["derive"] } +tokio = { version = "0.2", features = ["macros", "rt-core", "time"] } + +[dev-dependencies] +serde_json = { version = "1.0.2" } diff --git a/juniper_graphql_ws/src/client_message.rs b/juniper_graphql_ws/src/client_message.rs new file mode 100644 index 000000000..bee6d4697 --- /dev/null +++ b/juniper_graphql_ws/src/client_message.rs @@ -0,0 +1,115 @@ +use juniper::{ScalarValue, Variables}; + +#[derive(Debug, Deserialize, PartialEq)] +#[serde(bound(deserialize = "S: ScalarValue"))] +#[serde(rename_all = "camelCase")] +pub struct StartPayload { + pub query: String, + #[serde(default)] + pub variables: Variables, + pub operation_name: Option, +} + +/// ClientMessage defines the message types that clients can send. +#[derive(Debug, Deserialize, PartialEq)] +#[serde(bound(deserialize = "S: ScalarValue"))] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum ClientMessage { + /// ConnectionInit is sent by the client upon connecting. + ConnectionInit { + #[serde(default)] + payload: Variables, + }, + /// Start messages are used to execute a GraphQL operation. + Start { + id: String, + payload: StartPayload, + }, + /// Stop messages are used to unsubscribe from a subscription. + Stop { id: String }, + /// ConnectionTerminate is used to terminate the connection. + ConnectionTerminate, +} + +#[cfg(test)] +mod test { + use super::*; + use juniper::{DefaultScalarValue, InputValue}; + + #[test] + fn test_deserialization() { + type ClientMessage = super::ClientMessage; + + assert_eq!( + ClientMessage::ConnectionInit { + payload: [("foo".to_string(), InputValue::scalar("bar"))] + .iter() + .cloned() + .collect(), + }, + serde_json::from_str(r##"{"type": "connection_init", "payload": {"foo": "bar"}}"##) + .unwrap(), + ); + + assert_eq!( + ClientMessage::ConnectionInit { + payload: Variables::default(), + }, + serde_json::from_str(r##"{"type": "connection_init"}"##).unwrap(), + ); + + assert_eq!( + ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "query MyQuery { __typename }".to_string(), + variables: [("foo".to_string(), InputValue::scalar("bar"))] + .iter() + .cloned() + .collect(), + operation_name: Some("MyQuery".to_string()), + }, + }, + serde_json::from_str( + r##"{"type": "start", "id": "foo", "payload": { + "query": "query MyQuery { __typename }", + "variables": { + "foo": "bar" + }, + "operationName": "MyQuery" + }}"## + ) + .unwrap(), + ); + + assert_eq!( + ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "query MyQuery { __typename }".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }, + serde_json::from_str( + r##"{"type": "start", "id": "foo", "payload": { + "query": "query MyQuery { __typename }" + }}"## + ) + .unwrap(), + ); + + assert_eq!( + ClientMessage::Stop { + id: "foo".to_string() + }, + serde_json::from_str(r##"{"type": "stop", "id": "foo"}"##).unwrap(), + ); + + assert_eq!( + ClientMessage::ConnectionTerminate, + serde_json::from_str(r##"{"type": "connection_terminate"}"##).unwrap(), + ); + } +} diff --git a/juniper_graphql_ws/src/lib.rs b/juniper_graphql_ws/src/lib.rs new file mode 100644 index 000000000..85495d699 --- /dev/null +++ b/juniper_graphql_ws/src/lib.rs @@ -0,0 +1,785 @@ +#[macro_use] +extern crate serde; + +mod client_message; +pub use client_message::*; + +mod server_message; +pub use server_message::*; + +use juniper::{ + futures::{ + channel::oneshot, + future::{self, BoxFuture, Either, Future, FutureExt, TryFutureExt}, + stream::{self, BoxStream, SelectAll, StreamExt}, + task::{Context, Poll}, + Stream, + }, + GraphQLError, GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, RuleError, ScalarValue, + Variables, +}; +use std::{ + collections::HashMap, convert::{Infallible, TryInto}, error::Error, marker::PhantomPinned, + pin::Pin, sync::Arc, time::Duration, +}; + +struct ExecutionParams { + start_payload: StartPayload, + config: Arc>, + schema: S, +} + +/// Schema defines the requirements for schemas that can be used for operations. Typically this is +/// just an Arc. +pub trait Schema: Unpin + Clone + Send + Sync + 'static { + type Context: Unpin + Send + Sync; + type ScalarValue: ScalarValue + Send + Sync; + type QueryTypeInfo: Send + Sync; + type Query: GraphQLTypeAsync + + Send; + type MutationTypeInfo: Send + Sync; + type Mutation: GraphQLTypeAsync< + Self::ScalarValue, + Context = Self::Context, + TypeInfo = Self::MutationTypeInfo, + > + Send; + type SubscriptionTypeInfo: Send + Sync; + type Subscription: GraphQLSubscriptionType< + Self::ScalarValue, + Context = Self::Context, + TypeInfo = Self::SubscriptionTypeInfo, + > + Send; + + fn root_node( + &self, + ) -> &RootNode<'static, Self::Query, Self::Mutation, Self::Subscription, Self::ScalarValue>; +} + +impl Schema + for Arc> +where + QueryT: GraphQLTypeAsync + Send + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync + Send + 'static, + MutationT::TypeInfo: Send + Sync, + SubscriptionT: GraphQLSubscriptionType + Send + 'static, + SubscriptionT::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync, + S: ScalarValue + Send + Sync + 'static, +{ + type Context = CtxT; + type ScalarValue = S; + type QueryTypeInfo = QueryT::TypeInfo; + type Query = QueryT; + type MutationTypeInfo = MutationT::TypeInfo; + type Mutation = MutationT; + type SubscriptionTypeInfo = SubscriptionT::TypeInfo; + type Subscription = SubscriptionT; + + fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> { + self + } +} + +/// ConnectionConfig is used to configure the connection once the client sends the ConnectionInit +/// message. +pub struct ConnectionConfig { + context: CtxT, + max_in_flight_operations: usize, + keep_alive_interval: Duration, +} + +impl ConnectionConfig { + /// Constructs the configuration required for a connection to be accepted. + pub fn new(context: CtxT) -> Self { + Self { + context, + max_in_flight_operations: 0, + keep_alive_interval: Duration::from_secs(30), + } + } + + /// Specifies the maximum number of in-flight operations that a connection can have. If this + /// number is exceeded, attempting to start more will result in an error. By default, there is + /// no limit to in-flight operations. + pub fn with_max_in_flight_operations(mut self, max: usize) -> Self { + self.max_in_flight_operations = max; + self + } + + /// Specifies the interval at which to send keep-alives. Specifying a zero duration will + /// disable keep-alives. By default, keep-alives are sent every + /// 30 seconds. + pub fn with_keep_alive_interval(mut self, interval: Duration) -> Self { + self.keep_alive_interval = interval; + self + } +} + +impl Init for ConnectionConfig { + type Error = Infallible; + type Future = future::Ready>; + + fn init(self, _params: Variables) -> Self::Future { + future::ready(Ok(self)) + } +} + +enum Reaction { + ServerMessage(ServerMessage), + Activate { + config: ConnectionConfig, + schema: S, + }, + EndStream, +} + +impl Reaction { + /// Converts the reaction into a one-item stream. + fn to_stream(self) -> BoxStream<'static, Self> { + stream::once(future::ready(self)).boxed() + } +} + +/// Init defines the requirements for types that can provide connection configurations when +/// ConnectionInit messages are received. It is automatically implemented for closures that meet +/// the requirements. +pub trait Init: Unpin + 'static { + type Error: Error; + type Future: Future, Self::Error>> + Send + 'static; + + fn init(self, params: Variables) -> Self::Future; +} + +impl Init for F +where + S: Schema, + F: FnOnce(Variables) -> Fut + Unpin + 'static, + Fut: Future, E>> + Send + 'static, + E: Error, +{ + type Error = E; + type Future = Fut; + + fn init(self, params: Variables) -> Fut { + self(params) + } +} + +enum ConnectionState> { + /// PreInit is the state before a ConnectionInit message has been accepted. + PreInit { init: I, schema: S }, + /// Initializing is the state after a ConnectionInit message has been received, but before the + /// init future has resolved. + Initializing, + /// Active is the state after a ConnectionInit message has been accepted. + Active { + config: Arc>, + stoppers: HashMap>, + schema: S, + }, +} + +impl> ConnectionState { + // Each message we receive results in a stream of zero or more reactions. For example, a + // ConnectionTerminate message results in a one-item stream with the EndStream reaction. + fn handle_message( + &mut self, + msg: ClientMessage, + ) -> BoxStream<'static, Reaction> { + if let ClientMessage::ConnectionTerminate = msg { + return Reaction::EndStream.to_stream(); + } + + match self { + Self::PreInit { .. } => match msg { + ClientMessage::ConnectionInit { payload } => { + match std::mem::replace(self, Self::Initializing) { + Self::PreInit { init, schema } => init + .init(payload) + .map(|r| match r { + Ok(config) => { + let keep_alive_interval = config.keep_alive_interval; + + let mut s = stream::iter(vec![ + Reaction::Activate { config, schema }, + Reaction::ServerMessage(ServerMessage::ConnectionAck), + ]) + .boxed(); + + if keep_alive_interval > Duration::from_secs(0) { + s = s + .chain( + Reaction::ServerMessage( + ServerMessage::ConnectionKeepAlive, + ) + .to_stream(), + ) + .boxed(); + s = s + .chain(stream::unfold((), move |_| async move { + tokio::time::delay_for(keep_alive_interval).await; + Some(( + Reaction::ServerMessage( + ServerMessage::ConnectionKeepAlive, + ), + (), + )) + })) + .boxed(); + } + + s + } + Err(e) => stream::iter(vec![ + Reaction::ServerMessage(ServerMessage::ConnectionError { + payload: ConnectionErrorPayload { + message: e.to_string(), + }, + }), + Reaction::EndStream, + ]) + .boxed(), + }) + .into_stream() + .flatten() + .boxed(), + _ => unreachable!(), + } + } + _ => stream::empty().boxed(), + }, + Self::Initializing => stream::empty().boxed(), + Self::Active { + config, + stoppers, + schema, + } => { + match msg { + ClientMessage::Start { id, payload } => { + if stoppers.contains_key(&id) { + // We already have an operation with this id, so we can't start a new + // one. + return stream::empty().boxed(); + } + + // Go ahead and prune canceled stoppers before adding a new one. + stoppers.retain(|_, tx| !tx.is_canceled()); + + if config.max_in_flight_operations > 0 + && stoppers.len() >= config.max_in_flight_operations + { + // Too many in-flight operations. Just send back a validation error. + return stream::iter(vec![ + Reaction::ServerMessage(ServerMessage::Error { + id: id.clone(), + payload: GraphQLError::ValidationError(vec![RuleError::new( + "Too many in-flight operations.", + &[], + )]) + .into(), + }), + Reaction::ServerMessage(ServerMessage::Complete { id }), + ]) + .boxed(); + } + + // Create a channel that we can use to cancel the operation. + let (tx, rx) = oneshot::channel::<()>(); + stoppers.insert(id.clone(), tx); + + // Create the operation stream. This stream will emit Data and Error + // messages, but will not emit Complete – that part is up to us. + let s = Self::start( + id.clone(), + ExecutionParams { + start_payload: payload, + config: config.clone(), + schema: schema.clone(), + }, + ) + .into_stream() + .flatten(); + + // Combine this with our oneshot channel so that the stream ends if the + // oneshot is ever fired. + let s = stream::unfold((rx, s.boxed()), |(rx, mut s)| async move { + let next = match future::select(rx, s.next()).await { + Either::Left(_) => None, + Either::Right((r, rx)) => r.map(|r| (r, rx)), + }; + next.map(|(r, rx)| (r, (rx, s))) + }); + + // Once the stream ends, send the Complete message. + let s = s.chain( + Reaction::ServerMessage(ServerMessage::Complete { id }).to_stream(), + ); + + s.boxed() + } + ClientMessage::Stop { id } => { + stoppers.remove(&id); + stream::empty().boxed() + } + _ => stream::empty().boxed(), + } + } + } + } + + async fn start(id: String, params: ExecutionParams) -> BoxStream<'static, Reaction> { + // TODO: This could be made more efficient if juniper exposed functionality to allow us to + // parse and validate the query, determine whether it's a subscription, and then execute + // it. For now, the query gets parsed and validated twice. + + let params = Arc::new(params); + + // Try to execute this as a query or mutation. + match juniper::execute( + ¶ms.start_payload.query, + params + .start_payload + .operation_name + .as_ref() + .map(|s| s.as_str()), + params.schema.root_node(), + ¶ms.start_payload.variables, + ¶ms.config.context, + ) + .await + { + Ok((data, errors)) => { + return Reaction::ServerMessage(ServerMessage::Data { + id: id.clone(), + payload: DataPayload { data, errors }, + }) + .to_stream(); + } + Err(GraphQLError::IsSubscription) => {} + Err(e) => { + return Reaction::ServerMessage(ServerMessage::Error { + id: id.clone(), + payload: unsafe { ErrorPayload::new_unchecked(Box::new(params.clone()), e) }, + }) + .to_stream() + } + } + + // Try to execute as a subscription. + SubscriptionStart::new(id, params.clone()).boxed() + } +} + +struct InterruptableStream { + stream: S, + rx: oneshot::Receiver<()>, +} + +impl Stream for InterruptableStream { + type Item = S::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match Pin::new(&mut self.rx).poll(cx) { + Poll::Ready(_) => return Poll::Ready(None), + Poll::Pending => {} + } + Pin::new(&mut self.stream).poll_next(cx) + } +} + +/// SubscriptionStartState is the state for a subscription operation. +enum SubscriptionStartState { + /// Init is the start before being polled for the first time. + Init { id: String }, + /// ResolvingIntoStream is the state after being polled for the first time. In this state, + /// we're parsing, validating, and getting the actual event stream. + ResolvingIntoStream { + id: String, + future: BoxFuture< + 'static, + Result< + juniper_subscriptions::Connection<'static, S::ScalarValue>, + GraphQLError<'static>, + >, + >, + }, + /// Streaming is the state after we've successfully obtained the event stream for the + /// subscription. In this state, we're just forwarding events back to the client. + Streaming { + id: String, + stream: juniper_subscriptions::Connection<'static, S::ScalarValue>, + }, + /// Terminated is the state once we're all done. + Terminated, +} + +/// SubscriptionStart is the stream for a subscription operation. +struct SubscriptionStart { + params: Arc>, + state: SubscriptionStartState, + _marker: PhantomPinned, +} + +impl SubscriptionStart { + fn new(id: String, params: Arc>) -> Pin> { + Box::pin(Self { + params, + state: SubscriptionStartState::Init { id }, + _marker: PhantomPinned, + }) + } +} + +impl Stream for SubscriptionStart { + type Item = Reaction; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let (params, state) = unsafe { + // XXX: The execution parameters are referenced by state and must not be modified. + // Modifying state is fine though. + let inner = self.get_unchecked_mut(); + (&inner.params, &mut inner.state) + }; + + loop { + match state { + SubscriptionStartState::Init { id } => { + // XXX: resolve_into_stream returns a Future that references the execution + // parameters, and the returned stream also references them. We can guarantee + // that everything has the same lifetime in this self-referential struct. + let params = Arc::as_ptr(params); + *state = SubscriptionStartState::ResolvingIntoStream { + id: id.clone(), + future: unsafe { + juniper::resolve_into_stream( + &(*params).start_payload.query, + (*params) + .start_payload + .operation_name + .as_ref() + .map(|s| s.as_str()), + (*params).schema.root_node(), + &(*params).start_payload.variables, + &(*params).config.context, + ) + } + .map_ok(|(stream, errors)| { + juniper_subscriptions::Connection::from_stream(stream, errors) + }) + .boxed(), + }; + } + SubscriptionStartState::ResolvingIntoStream { + ref id, + ref mut future, + } => match future.as_mut().poll(cx) { + Poll::Ready(r) => match r { + Ok(stream) => { + *state = SubscriptionStartState::Streaming { + id: id.clone(), + stream, + } + } + Err(e) => { + return Poll::Ready(Some(Reaction::ServerMessage( + ServerMessage::Error { + id: id.clone(), + payload: unsafe { + ErrorPayload::new_unchecked(Box::new(params.clone()), e) + }, + }, + ))); + } + }, + Poll::Pending => return Poll::Pending, + }, + SubscriptionStartState::Streaming { + ref id, + ref mut stream, + } => match Pin::new(stream).poll_next(cx) { + Poll::Ready(Some((data, errors))) => { + return Poll::Ready(Some(Reaction::ServerMessage(ServerMessage::Data { + id: id.clone(), + payload: DataPayload { data, errors }, + }))); + } + Poll::Ready(None) => { + *state = SubscriptionStartState::Terminated; + return Poll::Ready(None); + } + Poll::Pending => return Poll::Pending, + }, + SubscriptionStartState::Terminated => return Poll::Ready(None), + } + } + } +} + +pub fn serve(stream: St, schema: S, init: I) -> Serve +where + St: Stream + Unpin, + StT: TryInto, Error = StE>, + StE: Error, + S: Schema, + I: Init, +{ + Serve { + stream, + reactions: SelectAll::new(), + state: ConnectionState::PreInit { init, schema }, + } +} + +/// Stream for the serve function. +pub struct Serve> { + stream: St, + reactions: SelectAll>>, + state: ConnectionState, +} + +impl Stream for Serve +where + St: Stream + Unpin, + StT: TryInto, Error = StE>, + StE: Error, + S: Schema, + I: Init, +{ + type Item = ServerMessage; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + // Poll the connection for new incoming messages. + loop { + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(Some(msg)) => { + // We have a new message. Try to parse it and add the reaction stream. + let reactions = match msg.try_into() { + Ok(msg) => self.state.handle_message(msg), + Err(e) => { + // If we weren't able to parse the message, just send back an error and + // carry on. + Reaction::ServerMessage(ServerMessage::ConnectionError { + payload: ConnectionErrorPayload { + message: e.to_string(), + }, + }) + .to_stream() + } + }; + self.reactions.push(reactions); + } + Poll::Ready(None) => { + // The connection stream has ended, so we should end too. + return Poll::Ready(None); + } + Poll::Pending => break, + } + } + + // Poll the reactions for new outgoing messages. + loop { + if !self.reactions.is_empty() { + match Pin::new(&mut self.reactions).poll_next(cx) { + Poll::Ready(Some(reaction)) => match reaction { + Reaction::ServerMessage(msg) => return Poll::Ready(Some(msg)), + Reaction::Activate { config, schema } => { + self.state = ConnectionState::Active { + config: Arc::new(config), + stoppers: HashMap::new(), + schema, + } + } + Reaction::EndStream => return Poll::Ready(None), + }, + Poll::Ready(None) => { + // In rare cases, the reaction stream may terminate. For example, this will + // happen if the first message we receive does not require any reaction. Just + // recreate it in that case. + self.reactions = SelectAll::new(); + return Poll::Pending; + } + Poll::Pending => return Poll::Pending, + } + } else { + return Poll::Pending; + } + } + } +} + +#[cfg(test)] +mod test { + use std::convert::Infallible; + use super::*; + use juniper::{ + futures::channel::mpsc, DefaultScalarValue, EmptyMutation, FieldResult, InputValue, + RootNode, Value, + }; + + struct Context(i32); + + struct Query; + + #[juniper::graphql_object(Context=Context)] + impl Query { + /// context just resolves to the current context. + async fn context(context: &Context) -> i32 { + context.0 + } + } + + struct Subscription; + + #[juniper::graphql_subscription(Context=Context)] + impl Subscription { + /// context emits the current context once, then never emits anything else. + async fn context(context: &Context) -> BoxStream<'static, FieldResult> { + stream::once(future::ready(Ok(context.0))) + .chain( + tokio::time::delay_for(Duration::from_secs(10000)) + .map(|_| unreachable!()) + .into_stream(), + ) + .boxed() + } + } + + type ClientMessage = super::ClientMessage; + type ServerMessage = super::ServerMessage; + + fn new_test_schema() -> Arc, Subscription>> { + Arc::new(RootNode::new(Query, EmptyMutation::new(), Subscription)) + } + + #[tokio::test] + async fn test_query() { + let (tx, rx) = mpsc::unbounded::(); + let mut rx = serve( + rx, + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + tx.unbounded_send(ClientMessage::ConnectionInit { + payload: Variables::default(), + }) + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, rx.next().await.unwrap()); + + tx.unbounded_send(ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "{context}".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .unwrap(); + + assert_eq!( + ServerMessage::Data { + id: "foo".to_string(), + payload: DataPayload { + data: Value::Object( + [("context", Value::Scalar(DefaultScalarValue::Int(1)))] + .iter() + .cloned() + .collect() + ), + errors: vec![], + }, + }, + rx.next().await.unwrap() + ); + + assert_eq!( + ServerMessage::Complete { + id: "foo".to_string(), + }, + rx.next().await.unwrap() + ); + } + + #[tokio::test] + async fn test_subscription() { + let (tx, rx) = mpsc::unbounded::(); + let mut rx = serve( + rx, + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + tx.unbounded_send(ClientMessage::ConnectionInit { + payload: Variables::default(), + }) + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, rx.next().await.unwrap()); + + tx.unbounded_send(ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "subscription Foo {context}".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .unwrap(); + + assert_eq!( + ServerMessage::Data { + id: "foo".to_string(), + payload: DataPayload { + data: Value::Object( + [("context", Value::scalar(1))] + .iter() + .cloned() + .collect() + ), + errors: vec![], + }, + }, + rx.next().await.unwrap() + ); + + tx.unbounded_send(ClientMessage::Stop { + id: "foo".to_string(), + }) + .unwrap(); + + assert_eq!( + ServerMessage::Complete { + id: "foo".to_string(), + }, + rx.next().await.unwrap() + ); + } + + #[tokio::test] + async fn test_init_params_ok() { + let (tx, rx) = mpsc::unbounded::(); + let mut rx = serve( + rx, + new_test_schema(), + |params: Variables| async move { + assert_eq!(params.get("foo"), Some(&InputValue::scalar("bar"))); + Ok(ConnectionConfig::new(Context(1))) as Result<_, Infallible> + }, + ); + + tx.unbounded_send(ClientMessage::ConnectionInit { + payload: [( + "foo".to_string(), + InputValue::scalar("bar".to_string()) + )] + .iter() + .cloned() + .collect(), + }) + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, rx.next().await.unwrap()); + } +} diff --git a/juniper_graphql_ws/src/server_message.rs b/juniper_graphql_ws/src/server_message.rs new file mode 100644 index 000000000..91172a8db --- /dev/null +++ b/juniper_graphql_ws/src/server_message.rs @@ -0,0 +1,155 @@ +use juniper::{ExecutionError, GraphQLError, ScalarValue, Value}; +use serde::{Serialize, Serializer}; +use std::{any::Any, fmt, marker::PhantomPinned}; + +#[derive(Debug, Serialize, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct ConnectionErrorPayload { + pub message: String, +} + +#[derive(Debug, Serialize, PartialEq)] +#[serde(bound(serialize = "S: ScalarValue"))] +#[serde(rename_all = "camelCase")] +pub struct DataPayload { + pub data: Value, + pub errors: Vec>, +} + +// XXX: Think carefully before deriving traits. This is self-referential (error references +// _execution_params). +pub struct ErrorPayload { + _execution_params: Option>, + error: GraphQLError<'static>, + _marker: PhantomPinned, +} + +impl ErrorPayload { + /// For this to be okay, the caller must guarantee that the error can only reference data from + /// execution_params and that execution_params has not been modified or moved. + pub(crate) unsafe fn new_unchecked<'a>( + execution_params: Box, + error: GraphQLError<'a>, + ) -> Self { + Self { + _execution_params: Some(execution_params), + error: std::mem::transmute(error), + _marker: PhantomPinned, + } + } +} + +impl fmt::Debug for ErrorPayload { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.error.fmt(f) + } +} + +impl PartialEq for ErrorPayload { + fn eq(&self, other: &Self) -> bool { + self.error.eq(&other.error) + } +} + +impl Serialize for ErrorPayload { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.error.serialize(serializer) + } +} + +impl From> for ErrorPayload { + fn from(error: GraphQLError<'static>) -> Self { + Self { + _execution_params: None, + error, + _marker: PhantomPinned, + } + } +} + +/// ServerMessage defines the message types that servers can send. +#[derive(Debug, Serialize, PartialEq)] +#[serde(bound(serialize = "S: ScalarValue"))] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum ServerMessage { + /// ConnectionError is used when the server rejects a connection based on the client's ConnectionInit + /// message or when the server encounters a protocol error such as not being able to parse a + /// client's message. + ConnectionError { payload: ConnectionErrorPayload }, + /// ConnectionAck is sent in response to a client's ConnectionInit message if the server accepted a + /// connection. + ConnectionAck, + /// Data contains the result of a query, mutation, or subscription event. + Data { id: String, payload: DataPayload }, + /// Error contains an error that occurs before execution, such as validation errors. + Error { id: String, payload: ErrorPayload }, + /// Complete indicates that no more data will be sent for the given operation. + Complete { id: String }, + /// ConnectionKeepAlive is sent periodically after accepting a connection. + #[serde(rename = "ka")] + ConnectionKeepAlive, +} + +#[cfg(test)] +mod test { + use super::*; + use juniper::DefaultScalarValue; + + #[test] + fn test_serialization() { + type ServerMessage = super::ServerMessage; + + assert_eq!( + serde_json::to_string(&ServerMessage::ConnectionError { + payload: ConnectionErrorPayload { + message: "foo".to_string(), + }, + }) + .unwrap(), + r##"{"type":"connection_error","payload":{"message":"foo"}}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(), + r##"{"type":"connection_ack"}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::Data { + id: "foo".to_string(), + payload: DataPayload { + data: Value::null(), + errors: vec![], + }, + }) + .unwrap(), + r##"{"type":"data","id":"foo","payload":{"data":null,"errors":[]}}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::Error { + id: "foo".to_string(), + payload: GraphQLError::UnknownOperationName.into(), + }) + .unwrap(), + r##"{"type":"error","id":"foo","payload":[{"message":"Unknown operation"}]}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::Complete { + id: "foo".to_string(), + }) + .unwrap(), + r##"{"type":"complete","id":"foo"}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::ConnectionKeepAlive).unwrap(), + r##"{"type":"ka"}"##, + ); + } +} From ac7ea1f76b6b770adfe79f57bfe1abd2b51ae0f0 Mon Sep 17 00:00:00 2001 From: Christopher Brown Date: Sun, 26 Jul 2020 00:36:59 -0400 Subject: [PATCH 03/10] finish up, update example --- examples/warp_subscriptions/Cargo.toml | 6 +- examples/warp_subscriptions/src/main.rs | 35 +- juniper_graphql_ws/src/client_message.rs | 18 +- juniper_graphql_ws/src/lib.rs | 477 +++++++++++++++-------- juniper_graphql_ws/src/schema.rs | 131 +++++++ juniper_graphql_ws/src/server_message.rs | 50 ++- juniper_warp/Cargo.toml | 4 +- juniper_warp/src/lib.rs | 298 ++++---------- 8 files changed, 613 insertions(+), 406 deletions(-) create mode 100644 juniper_graphql_ws/src/schema.rs diff --git a/examples/warp_subscriptions/Cargo.toml b/examples/warp_subscriptions/Cargo.toml index 7fc8fb4a1..5c69129eb 100644 --- a/examples/warp_subscriptions/Cargo.toml +++ b/examples/warp_subscriptions/Cargo.toml @@ -13,6 +13,6 @@ serde_json = "1.0" tokio = { version = "0.2", features = ["rt-core", "macros"] } warp = "0.2.1" -juniper = { git = "https://github.com/graphql-rust/juniper" } -juniper_subscriptions = { git = "https://github.com/graphql-rust/juniper" } -juniper_warp = { git = "https://github.com/graphql-rust/juniper", features = ["subscriptions"] } +juniper = { path = "../../juniper" } +juniper_graphql_ws = { path = "../../juniper_graphql_ws" } +juniper_warp = { path = "../../juniper_warp", features = ["subscriptions"] } diff --git a/examples/warp_subscriptions/src/main.rs b/examples/warp_subscriptions/src/main.rs index f0f9f7373..77194e5e4 100644 --- a/examples/warp_subscriptions/src/main.rs +++ b/examples/warp_subscriptions/src/main.rs @@ -2,10 +2,10 @@ use std::{env, pin::Pin, sync::Arc, time::Duration}; -use futures::{Future, FutureExt as _, Stream}; +use futures::{FutureExt as _, Stream}; use juniper::{DefaultScalarValue, EmptyMutation, FieldError, RootNode}; -use juniper_subscriptions::Coordinator; -use juniper_warp::{playground_filter, subscriptions::graphql_subscriptions}; +use juniper_graphql_ws::ConnectionConfig; +use juniper_warp::{playground_filter, subscriptions::serve_graphql_ws}; use warp::{http::Response, Filter}; #[derive(Clone)] @@ -134,6 +134,15 @@ fn schema() -> Schema { Schema::new(Query, EmptyMutation::new(), Subscription) } +async fn on_upgrade(ws: warp::ws::WebSocket, root_node: Arc) { + serve_graphql_ws(ws, root_node, ConnectionConfig::new(Context {})) + .map(|r| { + if let Err(e) = r { + println!("Websocket error: {}", e); + } + }).await; +} + #[tokio::main] async fn main() { env::set_var("RUST_LOG", "warp_subscriptions"); @@ -151,28 +160,16 @@ async fn main() { let qm_state = warp::any().map(move || Context {}); let qm_graphql_filter = juniper_warp::make_graphql_filter(qm_schema, qm_state.boxed()); - let sub_state = warp::any().map(move || Context {}); - let coordinator = Arc::new(juniper_subscriptions::Coordinator::new(schema())); + let root_node = Arc::new(schema()); log::info!("Listening on 127.0.0.1:8080"); let routes = (warp::path("subscriptions") .and(warp::ws()) - .and(sub_state.clone()) - .and(warp::any().map(move || Arc::clone(&coordinator))) .map( - |ws: warp::ws::Ws, - ctx: Context, - coordinator: Arc>| { - ws.on_upgrade(|websocket| -> Pin + Send>> { - graphql_subscriptions(websocket, coordinator, ctx) - .map(|r| { - if let Err(e) = r { - println!("Websocket error: {}", e); - } - }) - .boxed() - }) + move |ws: warp::ws::Ws| { + let root_node = root_node.clone(); + ws.on_upgrade(move |websocket| on_upgrade(websocket, root_node.clone())) }, )) .map(|reply| { diff --git a/juniper_graphql_ws/src/client_message.rs b/juniper_graphql_ws/src/client_message.rs index bee6d4697..1e20caef1 100644 --- a/juniper_graphql_ws/src/client_message.rs +++ b/juniper_graphql_ws/src/client_message.rs @@ -1,12 +1,19 @@ use juniper::{ScalarValue, Variables}; +/// The payload for a client's "start" message. This triggers execution of a query, mutation, or +/// subscription. #[derive(Debug, Deserialize, PartialEq)] #[serde(bound(deserialize = "S: ScalarValue"))] #[serde(rename_all = "camelCase")] pub struct StartPayload { + /// The document body. pub query: String, + + /// The optional variables. #[serde(default)] pub variables: Variables, + + /// The optional operation name (required if the document contains multiple operations). pub operation_name: Option, } @@ -18,16 +25,25 @@ pub struct StartPayload { pub enum ClientMessage { /// ConnectionInit is sent by the client upon connecting. ConnectionInit { + /// Optional parameters of any type sent from the client. These are often used for + /// authentication. #[serde(default)] payload: Variables, }, /// Start messages are used to execute a GraphQL operation. Start { + /// The id of the operation. This can be anything, but must be unique. If there are other + /// in-flight operations with the same id, the message will be ignored or cause an error. id: String, + + /// The query, variables, and operation name. payload: StartPayload, }, /// Stop messages are used to unsubscribe from a subscription. - Stop { id: String }, + Stop { + /// The id of the operation to stop. + id: String, + }, /// ConnectionTerminate is used to terminate the connection. ConnectionTerminate, } diff --git a/juniper_graphql_ws/src/lib.rs b/juniper_graphql_ws/src/lib.rs index 85495d699..ba7f5cb36 100644 --- a/juniper_graphql_ws/src/lib.rs +++ b/juniper_graphql_ws/src/lib.rs @@ -1,3 +1,14 @@ +/*! + +# juniper_graphql_ws + +This crate contains an implementation of the graphql-ws protocol, as used by Apollo. + +*/ + +#![deny(missing_docs)] +#![deny(warnings)] + #[macro_use] extern crate serde; @@ -7,20 +18,28 @@ pub use client_message::*; mod server_message; pub use server_message::*; +mod schema; +pub use schema::*; + use juniper::{ futures::{ channel::oneshot, future::{self, BoxFuture, Either, Future, FutureExt, TryFutureExt}, stream::{self, BoxStream, SelectAll, StreamExt}, - task::{Context, Poll}, - Stream, + task::{Context, Poll, Waker}, + Sink, Stream, }, - GraphQLError, GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, RuleError, ScalarValue, + GraphQLError, RuleError, ScalarValue, Variables, }; use std::{ - collections::HashMap, convert::{Infallible, TryInto}, error::Error, marker::PhantomPinned, - pin::Pin, sync::Arc, time::Duration, + collections::HashMap, + convert::{Infallible, TryInto}, + error::Error, + marker::PhantomPinned, + pin::Pin, + sync::Arc, + time::Duration, }; struct ExecutionParams { @@ -29,58 +48,6 @@ struct ExecutionParams { schema: S, } -/// Schema defines the requirements for schemas that can be used for operations. Typically this is -/// just an Arc. -pub trait Schema: Unpin + Clone + Send + Sync + 'static { - type Context: Unpin + Send + Sync; - type ScalarValue: ScalarValue + Send + Sync; - type QueryTypeInfo: Send + Sync; - type Query: GraphQLTypeAsync - + Send; - type MutationTypeInfo: Send + Sync; - type Mutation: GraphQLTypeAsync< - Self::ScalarValue, - Context = Self::Context, - TypeInfo = Self::MutationTypeInfo, - > + Send; - type SubscriptionTypeInfo: Send + Sync; - type Subscription: GraphQLSubscriptionType< - Self::ScalarValue, - Context = Self::Context, - TypeInfo = Self::SubscriptionTypeInfo, - > + Send; - - fn root_node( - &self, - ) -> &RootNode<'static, Self::Query, Self::Mutation, Self::Subscription, Self::ScalarValue>; -} - -impl Schema - for Arc> -where - QueryT: GraphQLTypeAsync + Send + 'static, - QueryT::TypeInfo: Send + Sync, - MutationT: GraphQLTypeAsync + Send + 'static, - MutationT::TypeInfo: Send + Sync, - SubscriptionT: GraphQLSubscriptionType + Send + 'static, - SubscriptionT::TypeInfo: Send + Sync, - CtxT: Unpin + Send + Sync, - S: ScalarValue + Send + Sync + 'static, -{ - type Context = CtxT; - type ScalarValue = S; - type QueryTypeInfo = QueryT::TypeInfo; - type Query = QueryT; - type MutationTypeInfo = MutationT::TypeInfo; - type Mutation = MutationT; - type SubscriptionTypeInfo = SubscriptionT::TypeInfo; - type Subscription = SubscriptionT; - - fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> { - self - } -} - /// ConnectionConfig is used to configure the connection once the client sends the ConnectionInit /// message. pub struct ConnectionConfig { @@ -116,11 +83,11 @@ impl ConnectionConfig { } } -impl Init for ConnectionConfig { +impl Init for ConnectionConfig { type Error = Infallible; type Future = future::Ready>; - fn init(self, _params: Variables) -> Self::Future { + fn init(self, _params: Variables) -> Self::Future { future::ready(Ok(self)) } } @@ -142,31 +109,36 @@ impl Reaction { } /// Init defines the requirements for types that can provide connection configurations when -/// ConnectionInit messages are received. It is automatically implemented for closures that meet -/// the requirements. -pub trait Init: Unpin + 'static { +/// ConnectionInit messages are received. Implementations are provided for `ConnectionConfig` and +/// closures that meet the requirements. +pub trait Init: Unpin + 'static { + /// The error that is returned on failure. The formatted error will be used as the contents of + /// the "message" field sent back to the client. type Error: Error; - type Future: Future, Self::Error>> + Send + 'static; - fn init(self, params: Variables) -> Self::Future; + /// The future configuration type. + type Future: Future, Self::Error>> + Send + 'static; + + /// Returns a future for the configuration to use. + fn init(self, params: Variables) -> Self::Future; } -impl Init for F +impl Init for F where - S: Schema, - F: FnOnce(Variables) -> Fut + Unpin + 'static, - Fut: Future, E>> + Send + 'static, + S: ScalarValue, + F: FnOnce(Variables) -> Fut + Unpin + 'static, + Fut: Future, E>> + Send + 'static, E: Error, { type Error = E; type Future = Fut; - fn init(self, params: Variables) -> Fut { + fn init(self, params: Variables) -> Fut { self(params) } } -enum ConnectionState> { +enum ConnectionState> { /// PreInit is the state before a ConnectionInit message has been accepted. PreInit { init: I, schema: S }, /// Initializing is the state after a ConnectionInit message has been received, but before the @@ -180,7 +152,7 @@ enum ConnectionState> { }, } -impl> ConnectionState { +impl> ConnectionState { // Each message we receive results in a stream of zero or more reactions. For example, a // ConnectionTerminate message results in a one-item stream with the EndStream reaction. fn handle_message( @@ -516,65 +488,97 @@ impl Stream for SubscriptionStart { } } -pub fn serve(stream: St, schema: S, init: I) -> Serve +/// Implements the graphql-ws protocol. This is a sink for `TryInto` and a stream of +/// `ServerMessage`. +pub struct Connection> { + reactions: SelectAll>>, + state: ConnectionState, + is_closed: bool, + stream_waker: Option, +} + +impl Connection where - St: Stream + Unpin, - StT: TryInto, Error = StE>, - StE: Error, S: Schema, - I: Init, + I: Init, { - Serve { - stream, - reactions: SelectAll::new(), - state: ConnectionState::PreInit { init, schema }, + /// Creates a new connection, which is a sink for `TryInto` and a stream of `ServerMessage`. + /// + /// The `schema` argument should typically be an `Arc>`. + /// + /// The `init` argument is used to provide the context and additional configuration for + /// connections. This can be a `ConnectionConfig` if the context and configuration are already + /// known, or it can be a closure that gets executed asynchronously when the client sends the + /// ConnectionInit message. Using a closure allows you to perform authentication based on the + /// parameters provided by the client. + pub fn new(schema: S, init: I) -> Self { + Self { + reactions: SelectAll::new(), + state: ConnectionState::PreInit { init, schema }, + is_closed: false, + stream_waker: None, + } } } -/// Stream for the serve function. -pub struct Serve> { - stream: St, - reactions: SelectAll>>, - state: ConnectionState, +impl Sink for Connection +where + T: TryInto>, + T::Error: Error, + S: Schema, + I: Init, +{ + type Error = Infallible; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + // We're always ready for new messages. + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + let reactions = match item.try_into() { + Ok(msg) => self.state.handle_message(msg), + Err(e) => { + // If we weren't able to parse the message, send back an error. + Reaction::ServerMessage(ServerMessage::ConnectionError { + payload: ConnectionErrorPayload { + message: e.to_string(), + }, + }) + .to_stream() + } + }; + self.reactions.push(reactions); + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + // Flushing an item doesn't really have any meaning here. Just return okay. + Poll::Ready(Ok(())) + } + + fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + // Close the stream too. No need to wait for it though. + self.is_closed = true; + if let Some(waker) = self.stream_waker.take() { + waker.wake(); + } + Poll::Ready(Ok(())) + } } -impl Stream for Serve +impl Stream for Connection where - St: Stream + Unpin, - StT: TryInto, Error = StE>, - StE: Error, S: Schema, - I: Init, + I: Init, { type Item = ServerMessage; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - // Poll the connection for new incoming messages. - loop { - match Pin::new(&mut self.stream).poll_next(cx) { - Poll::Ready(Some(msg)) => { - // We have a new message. Try to parse it and add the reaction stream. - let reactions = match msg.try_into() { - Ok(msg) => self.state.handle_message(msg), - Err(e) => { - // If we weren't able to parse the message, just send back an error and - // carry on. - Reaction::ServerMessage(ServerMessage::ConnectionError { - payload: ConnectionErrorPayload { - message: e.to_string(), - }, - }) - .to_stream() - } - }; - self.reactions.push(reactions); - } - Poll::Ready(None) => { - // The connection stream has ended, so we should end too. - return Poll::Ready(None); - } - Poll::Pending => break, - } + self.stream_waker = Some(cx.waker().clone()); + + if self.is_closed { + return Poll::Ready(None); } // Poll the reactions for new outgoing messages. @@ -610,12 +614,13 @@ where #[cfg(test)] mod test { - use std::convert::Infallible; use super::*; use juniper::{ - futures::channel::mpsc, DefaultScalarValue, EmptyMutation, FieldResult, InputValue, - RootNode, Value, + futures::sink::SinkExt, + parser::{ParseError, Spanning, Token}, + DefaultScalarValue, EmptyMutation, FieldResult, InputValue, RootNode, Value, }; + use std::{convert::Infallible, io}; struct Context(i32); @@ -633,6 +638,14 @@ mod test { #[juniper::graphql_subscription(Context=Context)] impl Subscription { + /// never never emits anything. + async fn never(context: &Context) -> BoxStream<'static, FieldResult> { + tokio::time::delay_for(Duration::from_secs(10000)) + .map(|_| unreachable!()) + .into_stream() + .boxed() + } + /// context emits the current context once, then never emits anything else. async fn context(context: &Context) -> BoxStream<'static, FieldResult> { stream::once(future::ready(Ok(context.0))) @@ -654,21 +667,20 @@ mod test { #[tokio::test] async fn test_query() { - let (tx, rx) = mpsc::unbounded::(); - let mut rx = serve( - rx, + let mut conn = Connection::new( new_test_schema(), ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), ); - tx.unbounded_send(ClientMessage::ConnectionInit { + conn.send(ClientMessage::ConnectionInit { payload: Variables::default(), }) + .await .unwrap(); - assert_eq!(ServerMessage::ConnectionAck, rx.next().await.unwrap()); + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); - tx.unbounded_send(ClientMessage::Start { + conn.send(ClientMessage::Start { id: "foo".to_string(), payload: StartPayload { query: "{context}".to_string(), @@ -676,6 +688,7 @@ mod test { operation_name: None, }, }) + .await .unwrap(); assert_eq!( @@ -691,34 +704,33 @@ mod test { errors: vec![], }, }, - rx.next().await.unwrap() + conn.next().await.unwrap() ); assert_eq!( ServerMessage::Complete { id: "foo".to_string(), }, - rx.next().await.unwrap() + conn.next().await.unwrap() ); } #[tokio::test] - async fn test_subscription() { - let (tx, rx) = mpsc::unbounded::(); - let mut rx = serve( - rx, + async fn test_subscriptions() { + let mut conn = Connection::new( new_test_schema(), ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), ); - tx.unbounded_send(ClientMessage::ConnectionInit { + conn.send(ClientMessage::ConnectionInit { payload: Variables::default(), }) + .await .unwrap(); - assert_eq!(ServerMessage::ConnectionAck, rx.next().await.unwrap()); + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); - tx.unbounded_send(ClientMessage::Start { + conn.send(ClientMessage::Start { id: "foo".to_string(), payload: StartPayload { query: "subscription Foo {context}".to_string(), @@ -726,60 +738,209 @@ mod test { operation_name: None, }, }) + .await .unwrap(); assert_eq!( ServerMessage::Data { id: "foo".to_string(), payload: DataPayload { - data: Value::Object( - [("context", Value::scalar(1))] - .iter() - .cloned() - .collect() - ), + data: Value::Object([("context", Value::scalar(1))].iter().cloned().collect()), errors: vec![], }, }, - rx.next().await.unwrap() + conn.next().await.unwrap() ); - tx.unbounded_send(ClientMessage::Stop { + conn.send(ClientMessage::Start { + id: "bar".to_string(), + payload: StartPayload { + query: "subscription Bar {context}".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .await + .unwrap(); + + assert_eq!( + ServerMessage::Data { + id: "bar".to_string(), + payload: DataPayload { + data: Value::Object([("context", Value::scalar(1))].iter().cloned().collect()), + errors: vec![], + }, + }, + conn.next().await.unwrap() + ); + + conn.send(ClientMessage::Stop { id: "foo".to_string(), }) + .await .unwrap(); assert_eq!( ServerMessage::Complete { id: "foo".to_string(), }, - rx.next().await.unwrap() + conn.next().await.unwrap() ); } #[tokio::test] async fn test_init_params_ok() { - let (tx, rx) = mpsc::unbounded::(); - let mut rx = serve( - rx, + let mut conn = Connection::new(new_test_schema(), |params: Variables| async move { + assert_eq!(params.get("foo"), Some(&InputValue::scalar("bar"))); + Ok(ConnectionConfig::new(Context(1))) as Result<_, Infallible> + }); + + conn.send(ClientMessage::ConnectionInit { + payload: [("foo".to_string(), InputValue::scalar("bar".to_string()))] + .iter() + .cloned() + .collect(), + }) + .await + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + } + + #[tokio::test] + async fn test_init_params_error() { + let mut conn = Connection::new(new_test_schema(), |params: Variables| async move { + assert_eq!(params.get("foo"), Some(&InputValue::scalar("bar"))); + Err(io::Error::new(io::ErrorKind::Other, "init error")) + }); + + conn.send(ClientMessage::ConnectionInit { + payload: [("foo".to_string(), InputValue::scalar("bar".to_string()))] + .iter() + .cloned() + .collect(), + }) + .await + .unwrap(); + + assert_eq!( + ServerMessage::ConnectionError { + payload: ConnectionErrorPayload { + message: "init error".to_string(), + }, + }, + conn.next().await.unwrap() + ); + } + + #[tokio::test] + async fn test_max_in_flight_operations() { + let mut conn = Connection::new( new_test_schema(), - |params: Variables| async move { - assert_eq!(params.get("foo"), Some(&InputValue::scalar("bar"))); - Ok(ConnectionConfig::new(Context(1))) as Result<_, Infallible> + ConnectionConfig::new(Context(1)) + .with_keep_alive_interval(Duration::from_secs(0)) + .with_max_in_flight_operations(1), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: Variables::default(), + }) + .await + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + + conn.send(ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "subscription Foo {never}".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .await + .unwrap(); + + conn.send(ClientMessage::Start { + id: "bar".to_string(), + payload: StartPayload { + query: "subscription Bar {never}".to_string(), + variables: Variables::default(), + operation_name: None, }, + }) + .await + .unwrap(); + + match conn.next().await.unwrap() { + ServerMessage::Error { id, .. } => { + assert_eq!(id, "bar"); + } + msg @ _ => panic!("expected error, got: {:?}", msg), + } + } + + #[tokio::test] + async fn test_parse_error() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), ); - tx.unbounded_send(ClientMessage::ConnectionInit { - payload: [( - "foo".to_string(), - InputValue::scalar("bar".to_string()) - )] - .iter() - .cloned() - .collect(), + conn.send(ClientMessage::ConnectionInit { + payload: Variables::default(), }) + .await .unwrap(); - assert_eq!(ServerMessage::ConnectionAck, rx.next().await.unwrap()); + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + + conn.send(ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "asd".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .await + .unwrap(); + + match conn.next().await.unwrap() { + ServerMessage::Error { id, payload } => { + assert_eq!(id, "foo"); + match payload.graphql_error() { + GraphQLError::ParseError(Spanning { + item: ParseError::UnexpectedToken(Token::Name("asd")), + .. + }) => {} + p @ _ => panic!("expected graphql parse error, got: {:?}", p), + } + } + msg @ _ => panic!("expected error, got: {:?}", msg), + } + } + + #[tokio::test] + async fn test_keep_alives() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_millis(20)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: Variables::default(), + }) + .await + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + + for _ in 0..10 { + assert_eq!( + ServerMessage::ConnectionKeepAlive, + conn.next().await.unwrap() + ); + } } } diff --git a/juniper_graphql_ws/src/schema.rs b/juniper_graphql_ws/src/schema.rs new file mode 100644 index 000000000..68d282f0b --- /dev/null +++ b/juniper_graphql_ws/src/schema.rs @@ -0,0 +1,131 @@ +use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue}; +use std::sync::Arc; + +/// Schema defines the requirements for schemas that can be used for operations. Typically this is +/// just an `Arc>` and you should not have to implement it yourself. +pub trait Schema: Unpin + Clone + Send + Sync + 'static { + /// The context type. + type Context: Unpin + Send + Sync; + + /// The scalar value type. + type ScalarValue: ScalarValue + Send + Sync; + + /// The query type info. + type QueryTypeInfo: Send + Sync; + + /// The query type. + type Query: GraphQLTypeAsync + + Send; + + /// The mutation type info. + type MutationTypeInfo: Send + Sync; + + /// The mutation type. + type Mutation: GraphQLTypeAsync< + Self::ScalarValue, + Context = Self::Context, + TypeInfo = Self::MutationTypeInfo, + > + Send; + + /// The subscription type info. + type SubscriptionTypeInfo: Send + Sync; + + /// The subscription type. + type Subscription: GraphQLSubscriptionType< + Self::ScalarValue, + Context = Self::Context, + TypeInfo = Self::SubscriptionTypeInfo, + > + Send; + + /// Returns the root node for the schema. + fn root_node( + &self, + ) -> &RootNode<'static, Self::Query, Self::Mutation, Self::Subscription, Self::ScalarValue>; +} + +/// This exists as a work-around for this issue: https://github.com/rust-lang/rust/issues/64552 +/// +/// It can be used in generators where using Arc directly would result in an error. +// TODO: Remove this once that issue is resolved. +#[doc(hidden)] +pub struct ArcSchema( + pub Arc>, +) +where + QueryT: GraphQLTypeAsync + Send + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync + Send + 'static, + MutationT::TypeInfo: Send + Sync, + SubscriptionT: GraphQLSubscriptionType + Send + 'static, + SubscriptionT::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync, + S: ScalarValue + Send + Sync + 'static; + +impl Clone + for ArcSchema +where + QueryT: GraphQLTypeAsync + Send + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync + Send + 'static, + MutationT::TypeInfo: Send + Sync, + SubscriptionT: GraphQLSubscriptionType + Send + 'static, + SubscriptionT::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync, + S: ScalarValue + Send + Sync + 'static, +{ + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Schema + for ArcSchema +where + QueryT: GraphQLTypeAsync + Send + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync + Send + 'static, + MutationT::TypeInfo: Send + Sync, + SubscriptionT: GraphQLSubscriptionType + Send + 'static, + SubscriptionT::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, +{ + type Context = CtxT; + type ScalarValue = S; + type QueryTypeInfo = QueryT::TypeInfo; + type Query = QueryT; + type MutationTypeInfo = MutationT::TypeInfo; + type Mutation = MutationT; + type SubscriptionTypeInfo = SubscriptionT::TypeInfo; + type Subscription = SubscriptionT; + + fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> { + &self.0 + } +} + +impl Schema + for Arc> +where + QueryT: GraphQLTypeAsync + Send + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync + Send + 'static, + MutationT::TypeInfo: Send + Sync, + SubscriptionT: GraphQLSubscriptionType + Send + 'static, + SubscriptionT::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync, + S: ScalarValue + Send + Sync + 'static, +{ + type Context = CtxT; + type ScalarValue = S; + type QueryTypeInfo = QueryT::TypeInfo; + type Query = QueryT; + type MutationTypeInfo = MutationT::TypeInfo; + type Mutation = MutationT; + type SubscriptionTypeInfo = SubscriptionT::TypeInfo; + type Subscription = SubscriptionT; + + fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> { + self + } +} diff --git a/juniper_graphql_ws/src/server_message.rs b/juniper_graphql_ws/src/server_message.rs index 91172a8db..3c3531649 100644 --- a/juniper_graphql_ws/src/server_message.rs +++ b/juniper_graphql_ws/src/server_message.rs @@ -2,20 +2,31 @@ use juniper::{ExecutionError, GraphQLError, ScalarValue, Value}; use serde::{Serialize, Serializer}; use std::{any::Any, fmt, marker::PhantomPinned}; +/// The payload for errors that are not associated with a GraphQL operation. #[derive(Debug, Serialize, PartialEq)] #[serde(rename_all = "camelCase")] pub struct ConnectionErrorPayload { + /// The error message. pub message: String, } +/// Sent after execution of an operation. For queries and mutations, this is sent to the client +/// once. For subscriptions, this is sent for every event in the event stream. #[derive(Debug, Serialize, PartialEq)] #[serde(bound(serialize = "S: ScalarValue"))] #[serde(rename_all = "camelCase")] pub struct DataPayload { + /// The result data. pub data: Value, + + /// The errors that have occurred during execution. Note that parse and validation errors are + /// not included here. They are sent via Error messages. pub errors: Vec>, } +/// A payload for errors that can happen before execution. Errors that happen during execution are +/// instead sent to the client via `DataPayload`. `ErrorPayload` is a wrapper for an owned +/// `GraphQLError`. // XXX: Think carefully before deriving traits. This is self-referential (error references // _execution_params). pub struct ErrorPayload { @@ -37,6 +48,11 @@ impl ErrorPayload { _marker: PhantomPinned, } } + + /// Returns the contained GraphQLError. + pub fn graphql_error<'a>(&'a self) -> &GraphQLError<'a> { + &self.error + } } impl fmt::Debug for ErrorPayload { @@ -76,19 +92,39 @@ impl From> for ErrorPayload { #[serde(rename_all = "snake_case")] #[serde(tag = "type")] pub enum ServerMessage { - /// ConnectionError is used when the server rejects a connection based on the client's ConnectionInit - /// message or when the server encounters a protocol error such as not being able to parse a - /// client's message. - ConnectionError { payload: ConnectionErrorPayload }, + /// ConnectionError is used for errors that are not associated with a GraphQL operation. For + /// example, this will be used when: + /// + /// * The server is unable to parse a client's message. + /// * The client's initialization parameters are rejected. + ConnectionError { + /// The error that occurred. + payload: ConnectionErrorPayload, + }, /// ConnectionAck is sent in response to a client's ConnectionInit message if the server accepted a /// connection. ConnectionAck, /// Data contains the result of a query, mutation, or subscription event. - Data { id: String, payload: DataPayload }, + Data { + /// The id of the operation that the data is for. + id: String, + + /// The data and errors that occurred during execution. + payload: DataPayload, + }, /// Error contains an error that occurs before execution, such as validation errors. - Error { id: String, payload: ErrorPayload }, + Error { + /// The id of the operation that triggered this error. + id: String, + + /// The error(s). + payload: ErrorPayload, + }, /// Complete indicates that no more data will be sent for the given operation. - Complete { id: String }, + Complete { + /// The id of the operation that has completed. + id: String, + }, /// ConnectionKeepAlive is sent periodically after accepting a connection. #[serde(rename = "ka")] ConnectionKeepAlive, diff --git a/juniper_warp/Cargo.toml b/juniper_warp/Cargo.toml index cf14ae322..f2fcb5b50 100644 --- a/juniper_warp/Cargo.toml +++ b/juniper_warp/Cargo.toml @@ -9,7 +9,7 @@ repository = "https://github.com/graphql-rust/juniper" edition = "2018" [features] -subscriptions = ["juniper_subscriptions"] +subscriptions = ["juniper_graphql_ws"] [dependencies] bytes = "0.5" @@ -17,7 +17,7 @@ anyhow = "1.0" thiserror = "1.0" futures = "0.3.1" juniper = { version = "0.14.2", path = "../juniper", default-features = false } -juniper_subscriptions = { path = "../juniper_subscriptions", optional = true } +juniper_graphql_ws = { path = "../juniper_graphql_ws", optional = true } serde = { version = "1.0.75", features = ["derive"] } serde_json = "1.0.24" tokio = { version = "0.2", features = ["blocking", "rt-core"] } diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs index dbb821870..7c5c516e2 100644 --- a/juniper_warp/src/lib.rs +++ b/juniper_warp/src/lib.rs @@ -393,237 +393,103 @@ fn playground_response( /// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md #[cfg(feature = "subscriptions")] pub mod subscriptions { - use std::{ - collections::HashMap, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, + use juniper::{ + futures::{ + future::{self, Either}, + sink::SinkExt, + stream::StreamExt, }, + GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue, }; + use juniper_graphql_ws::{ArcSchema, ClientMessage, Connection, Init}; + use std::{convert::Infallible, fmt, sync::Arc}; - use anyhow::anyhow; - use futures::{channel::mpsc, Future, StreamExt as _, TryFutureExt as _, TryStreamExt as _}; - use juniper::{ - http::GraphQLRequest, ExecutionError, InputValue, ScalarValue, - SubscriptionCoordinator as _, Value, - }; - use juniper_subscriptions::Coordinator; - use serde::{Deserialize, Serialize}; - use warp::ws::Message; - - #[derive(Serialize)] - struct DataPayload<'a, S: ScalarValue> { - data: &'a Value, - errors: &'a Vec>, + struct Message(warp::ws::Message); + + impl std::convert::TryFrom for ClientMessage { + type Error = serde_json::Error; + + fn try_from(msg: Message) -> serde_json::Result { + serde_json::from_slice(msg.0.as_bytes()) + } } - /// Listen to incoming messages and do one of the following: - /// - execute subscription and return values from stream - /// - stop stream and close ws connection - #[allow(dead_code)] - pub fn graphql_subscriptions( - websocket: warp::ws::WebSocket, - coordinator: Arc>, - context: CtxT, - ) -> impl Future> + Send - where - Query: juniper::GraphQLTypeAsync + Send + 'static, - Query::TypeInfo: Send + Sync, - Mutation: juniper::GraphQLTypeAsync + Send + 'static, - Mutation::TypeInfo: Send + Sync, - Subscription: juniper::GraphQLSubscriptionType + Send + 'static, - Subscription::TypeInfo: Send + Sync, - CtxT: Send + Sync + 'static, - S: ScalarValue + Send + Sync + 'static, - { - let (sink_tx, sink_rx) = websocket.split(); - let (ws_tx, ws_rx) = mpsc::unbounded(); - tokio::task::spawn( - ws_rx - .take_while(|v: &Option<_>| futures::future::ready(v.is_some())) - .map(|x| x.unwrap()) - .forward(sink_tx), - ); + /// Errors that can happen while serving a connection. + #[derive(Debug)] + pub enum Error { + /// Errors that can happen in Warp while serving a connection. + Warp(warp::Error), - let context = Arc::new(context); - let got_close_signal = Arc::new(AtomicBool::new(false)); - let got_close_signal2 = got_close_signal.clone(); + /// Errors that can happen while serializing outgoing messages. Note that errors that occur + /// while deserializing internal messages are handled internally by the protocol. + Serde(serde_json::Error), + } - struct SubscriptionState { - should_stop: AtomicBool, + impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Warp(e) => write!(f, "warp error: {}", e), + Self::Serde(e) => write!(f, "serde error: {}", e), + } } - let subscription_states = HashMap::>::new(); + } - sink_rx - .map_err(move |e| { - got_close_signal2.store(true, Ordering::Relaxed); - anyhow!("Websocket error: {}", e) - }) - .try_fold(subscription_states, move |mut subscription_states, msg| { - let coordinator = coordinator.clone(); - let context = context.clone(); - let got_close_signal = got_close_signal.clone(); - let ws_tx = ws_tx.clone(); - - async move { - if msg.is_close() { - return Ok(subscription_states); - } - - let msg = msg - .to_str() - .map_err(|_| anyhow!("Non-text messages are not accepted"))?; - let request: WsPayload = serde_json::from_str(msg) - .map_err(|e| anyhow!("Invalid WsPayload: {}", e))?; - - match request.type_name.as_str() { - "connection_init" => {} - "start" => { - if got_close_signal.load(Ordering::Relaxed) { - return Ok(subscription_states); - } - - let request_id = request.id.clone().unwrap_or("1".to_owned()); - - if let Some(existing) = subscription_states.get(&request_id) { - existing.should_stop.store(true, Ordering::Relaxed); - } - let state = Arc::new(SubscriptionState { - should_stop: AtomicBool::new(false), - }); - subscription_states.insert(request_id.clone(), state.clone()); - - let ws_tx = ws_tx.clone(); - - if let Some(ref payload) = request.payload { - if payload.query.is_none() { - return Err(anyhow!("Query not found")); - } - } else { - return Err(anyhow!("Payload not found")); - } - - tokio::task::spawn(async move { - let payload = request.payload.unwrap(); - - let graphql_request = GraphQLRequest::::new( - payload.query.unwrap(), - None, - payload.variables, - ); - - let values_stream = match coordinator - .subscribe(&graphql_request, &context) - .await - { - Ok(s) => s, - Err(err) => { - let _ = - ws_tx.unbounded_send(Some(Ok(Message::text(format!( - r#"{{"type":"error","id":"{}","payload":{}}}"#, - request_id, - serde_json::ser::to_string(&err).unwrap_or( - "Error deserializing GraphQLError".to_owned() - ) - ))))); - - let close_message = format!( - r#"{{"type":"complete","id":"{}","payload":null}}"#, - request_id - ); - let _ = ws_tx - .unbounded_send(Some(Ok(Message::text(close_message)))); - // close channel - let _ = ws_tx.unbounded_send(None); - return; - } - }; - - values_stream - .take_while(move |(data, errors)| { - let request_id = request_id.clone(); - let should_stop = state.should_stop.load(Ordering::Relaxed) - || got_close_signal.load(Ordering::Relaxed); - if !should_stop { - let mut response_text = - serde_json::to_string(&DataPayload { - data, - errors, - }) - .unwrap_or( - "Error deserializing response".to_owned(), - ); - - response_text = format!( - r#"{{"type":"data","id":"{}","payload":{} }}"#, - request_id, response_text - ); - - let _ = ws_tx.unbounded_send(Some(Ok(Message::text( - response_text, - )))); - } - - async move { !should_stop } - }) - .for_each(|_| async {}) - .await; - }); - } - "stop" => { - let request_id = request.id.unwrap_or("1".to_owned()); - if let Some(existing) = subscription_states.get(&request_id) { - existing.should_stop.store(true, Ordering::Relaxed); - subscription_states.remove(&request_id); - } - - let close_message = format!( - r#"{{"type":"complete","id":"{}","payload":null}}"#, - request_id - ); - let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message)))); - - // close channel - let _ = ws_tx.unbounded_send(None); - } - _ => {} - } - - Ok(subscription_states) - } - }) - .map_ok(|_| ()) + impl std::error::Error for Error {} + + impl From for Error { + fn from(err: warp::Error) -> Self { + Self::Warp(err) + } } - #[derive(Deserialize)] - #[serde(bound = "GraphQLPayload: Deserialize<'de>")] - struct WsPayload - where - S: ScalarValue + Send + Sync, - { - id: Option, - #[serde(rename(deserialize = "type"))] - type_name: String, - payload: Option>, + impl From for Error { + fn from(_err: Infallible) -> Self { + unreachable!() + } } - #[derive(Debug, Deserialize)] - #[serde(bound = "InputValue: Deserialize<'de>")] - struct GraphQLPayload + /// Serves the graphql-ws protocol over a WebSocket connection. + /// + /// The `init` argument is used to provide the context and additional configuration for + /// connections. This can be a `juniper_graphql_ws::ConnectionConfig` if the context and + /// configuration are already known, or it can be a closure that gets executed asynchronously + /// when the client sends the ConnectionInit message. Using a closure allows you to perform + /// authentication based on the parameters provided by the client. + pub async fn serve_graphql_ws( + websocket: warp::ws::WebSocket, + root_node: Arc>, + init: I, + ) -> Result<(), Error> where - S: ScalarValue + Send + Sync, + Query: GraphQLTypeAsync + Send + 'static, + Query::TypeInfo: Send + Sync, + Mutation: GraphQLTypeAsync + Send + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: GraphQLSubscriptionType + Send + 'static, + Subscription::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, + I: Init, { - variables: Option>, - extensions: Option>, - #[serde(rename(deserialize = "operationName"))] - operaton_name: Option, - query: Option, - } - - #[derive(Serialize)] - struct Output { - data: String, - variables: String, + let (ws_tx, ws_rx) = websocket.split(); + let (s_tx, s_rx) = Connection::new(ArcSchema(root_node), init).split(); + + let ws_rx = ws_rx.map(|r| r.map(|msg| Message(msg))); + let s_rx = s_rx.map(|msg| { + serde_json::to_string(&msg) + .map(|t| warp::ws::Message::text(t)) + .map_err(|e| Error::Serde(e)) + }); + + match future::select( + ws_rx.forward(s_tx.sink_err_into()), + s_rx.forward(ws_tx.sink_err_into()), + ) + .await + { + Either::Left((r, _)) => r.map_err(|e| e.into()), + Either::Right((r, _)) => r, + } } } From eae43007376fff6aa3a3d7af9b13fc6d657abc3d Mon Sep 17 00:00:00 2001 From: Christopher Brown Date: Sun, 26 Jul 2020 01:56:02 -0400 Subject: [PATCH 04/10] polish + timing test --- examples/warp_subscriptions/src/main.rs | 27 +- juniper_graphql_ws/src/lib.rs | 386 ++++++++++++++---------- juniper_warp/src/lib.rs | 2 +- 3 files changed, 239 insertions(+), 176 deletions(-) diff --git a/examples/warp_subscriptions/src/main.rs b/examples/warp_subscriptions/src/main.rs index 77194e5e4..0d4f31a67 100644 --- a/examples/warp_subscriptions/src/main.rs +++ b/examples/warp_subscriptions/src/main.rs @@ -134,15 +134,6 @@ fn schema() -> Schema { Schema::new(Query, EmptyMutation::new(), Subscription) } -async fn on_upgrade(ws: warp::ws::WebSocket, root_node: Arc) { - serve_graphql_ws(ws, root_node, ConnectionConfig::new(Context {})) - .map(|r| { - if let Err(e) = r { - println!("Websocket error: {}", e); - } - }).await; -} - #[tokio::main] async fn main() { env::set_var("RUST_LOG", "warp_subscriptions"); @@ -166,12 +157,18 @@ async fn main() { let routes = (warp::path("subscriptions") .and(warp::ws()) - .map( - move |ws: warp::ws::Ws| { - let root_node = root_node.clone(); - ws.on_upgrade(move |websocket| on_upgrade(websocket, root_node.clone())) - }, - )) + .map(move |ws: warp::ws::Ws| { + let root_node = root_node.clone(); + ws.on_upgrade(move |websocket| async move { + serve_graphql_ws(websocket, root_node, ConnectionConfig::new(Context {})) + .map(|r| { + if let Err(e) = r { + println!("Websocket error: {}", e); + } + }) + .await + }) + })) .map(|reply| { // TODO#584: remove this workaround warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws") diff --git a/juniper_graphql_ws/src/lib.rs b/juniper_graphql_ws/src/lib.rs index ba7f5cb36..1d73067f8 100644 --- a/juniper_graphql_ws/src/lib.rs +++ b/juniper_graphql_ws/src/lib.rs @@ -29,8 +29,7 @@ use juniper::{ task::{Context, Poll, Waker}, Sink, Stream, }, - GraphQLError, RuleError, ScalarValue, - Variables, + GraphQLError, RuleError, ScalarValue, Variables, }; use std::{ collections::HashMap, @@ -94,10 +93,6 @@ impl Init for ConnectionC enum Reaction { ServerMessage(ServerMessage), - Activate { - config: ConnectionConfig, - schema: S, - }, EndStream, } @@ -141,162 +136,164 @@ where enum ConnectionState> { /// PreInit is the state before a ConnectionInit message has been accepted. PreInit { init: I, schema: S }, - /// Initializing is the state after a ConnectionInit message has been received, but before the - /// init future has resolved. - Initializing, /// Active is the state after a ConnectionInit message has been accepted. Active { config: Arc>, stoppers: HashMap>, schema: S, }, + /// Terminated is the state after a ConnectionInit message has been rejected. + Terminated, } impl> ConnectionState { // Each message we receive results in a stream of zero or more reactions. For example, a // ConnectionTerminate message results in a one-item stream with the EndStream reaction. - fn handle_message( - &mut self, + async fn handle_message( + self, msg: ClientMessage, - ) -> BoxStream<'static, Reaction> { + ) -> (Self, BoxStream<'static, Reaction>) { if let ClientMessage::ConnectionTerminate = msg { - return Reaction::EndStream.to_stream(); + return (self, Reaction::EndStream.to_stream()); } match self { - Self::PreInit { .. } => match msg { - ClientMessage::ConnectionInit { payload } => { - match std::mem::replace(self, Self::Initializing) { - Self::PreInit { init, schema } => init - .init(payload) - .map(|r| match r { - Ok(config) => { - let keep_alive_interval = config.keep_alive_interval; - - let mut s = stream::iter(vec![ - Reaction::Activate { config, schema }, - Reaction::ServerMessage(ServerMessage::ConnectionAck), - ]) - .boxed(); - - if keep_alive_interval > Duration::from_secs(0) { - s = s - .chain( - Reaction::ServerMessage( - ServerMessage::ConnectionKeepAlive, - ) - .to_stream(), - ) - .boxed(); - s = s - .chain(stream::unfold((), move |_| async move { - tokio::time::delay_for(keep_alive_interval).await; - Some(( - Reaction::ServerMessage( - ServerMessage::ConnectionKeepAlive, - ), - (), - )) - })) - .boxed(); - } - - s - } - Err(e) => stream::iter(vec![ - Reaction::ServerMessage(ServerMessage::ConnectionError { - payload: ConnectionErrorPayload { - message: e.to_string(), - }, - }), - Reaction::EndStream, - ]) - .boxed(), - }) - .into_stream() - .flatten() - .boxed(), - _ => unreachable!(), + Self::PreInit { init, schema } => match msg { + ClientMessage::ConnectionInit { payload } => match init.init(payload).await { + Ok(config) => { + let keep_alive_interval = config.keep_alive_interval; + + let mut s = stream::iter(vec![Reaction::ServerMessage( + ServerMessage::ConnectionAck, + )]) + .boxed(); + + if keep_alive_interval > Duration::from_secs(0) { + s = s + .chain( + Reaction::ServerMessage(ServerMessage::ConnectionKeepAlive) + .to_stream(), + ) + .boxed(); + s = s + .chain(stream::unfold((), move |_| async move { + tokio::time::delay_for(keep_alive_interval).await; + Some(( + Reaction::ServerMessage(ServerMessage::ConnectionKeepAlive), + (), + )) + })) + .boxed(); + } + + ( + Self::Active { + config: Arc::new(config), + stoppers: HashMap::new(), + schema, + }, + s, + ) } - } - _ => stream::empty().boxed(), + Err(e) => ( + Self::Terminated, + stream::iter(vec![ + Reaction::ServerMessage(ServerMessage::ConnectionError { + payload: ConnectionErrorPayload { + message: e.to_string(), + }, + }), + Reaction::EndStream, + ]) + .boxed(), + ), + }, + _ => (Self::PreInit { init, schema }, stream::empty().boxed()), }, - Self::Initializing => stream::empty().boxed(), Self::Active { config, - stoppers, + mut stoppers, schema, } => { - match msg { + let reactions = match msg { ClientMessage::Start { id, payload } => { if stoppers.contains_key(&id) { // We already have an operation with this id, so we can't start a new // one. - return stream::empty().boxed(); - } - - // Go ahead and prune canceled stoppers before adding a new one. - stoppers.retain(|_, tx| !tx.is_canceled()); - - if config.max_in_flight_operations > 0 - && stoppers.len() >= config.max_in_flight_operations - { - // Too many in-flight operations. Just send back a validation error. - return stream::iter(vec![ - Reaction::ServerMessage(ServerMessage::Error { - id: id.clone(), - payload: GraphQLError::ValidationError(vec![RuleError::new( - "Too many in-flight operations.", - &[], - )]) - .into(), - }), - Reaction::ServerMessage(ServerMessage::Complete { id }), - ]) - .boxed(); + stream::empty().boxed() + } else { + // Go ahead and prune canceled stoppers before adding a new one. + stoppers.retain(|_, tx| !tx.is_canceled()); + + if config.max_in_flight_operations > 0 + && stoppers.len() >= config.max_in_flight_operations + { + // Too many in-flight operations. Just send back a validation error. + stream::iter(vec![ + Reaction::ServerMessage(ServerMessage::Error { + id: id.clone(), + payload: GraphQLError::ValidationError(vec![ + RuleError::new("Too many in-flight operations.", &[]), + ]) + .into(), + }), + Reaction::ServerMessage(ServerMessage::Complete { id }), + ]) + .boxed() + } else { + // Create a channel that we can use to cancel the operation. + let (tx, rx) = oneshot::channel::<()>(); + stoppers.insert(id.clone(), tx); + + // Create the operation stream. This stream will emit Data and Error + // messages, but will not emit Complete – that part is up to us. + let s = Self::start( + id.clone(), + ExecutionParams { + start_payload: payload, + config: config.clone(), + schema: schema.clone(), + }, + ) + .into_stream() + .flatten(); + + // Combine this with our oneshot channel so that the stream ends if the + // oneshot is ever fired. + let s = stream::unfold((rx, s.boxed()), |(rx, mut s)| async move { + let next = match future::select(rx, s.next()).await { + Either::Left(_) => None, + Either::Right((r, rx)) => r.map(|r| (r, rx)), + }; + next.map(|(r, rx)| (r, (rx, s))) + }); + + // Once the stream ends, send the Complete message. + let s = s.chain( + Reaction::ServerMessage(ServerMessage::Complete { id }) + .to_stream(), + ); + + s.boxed() + } } - - // Create a channel that we can use to cancel the operation. - let (tx, rx) = oneshot::channel::<()>(); - stoppers.insert(id.clone(), tx); - - // Create the operation stream. This stream will emit Data and Error - // messages, but will not emit Complete – that part is up to us. - let s = Self::start( - id.clone(), - ExecutionParams { - start_payload: payload, - config: config.clone(), - schema: schema.clone(), - }, - ) - .into_stream() - .flatten(); - - // Combine this with our oneshot channel so that the stream ends if the - // oneshot is ever fired. - let s = stream::unfold((rx, s.boxed()), |(rx, mut s)| async move { - let next = match future::select(rx, s.next()).await { - Either::Left(_) => None, - Either::Right((r, rx)) => r.map(|r| (r, rx)), - }; - next.map(|(r, rx)| (r, (rx, s))) - }); - - // Once the stream ends, send the Complete message. - let s = s.chain( - Reaction::ServerMessage(ServerMessage::Complete { id }).to_stream(), - ); - - s.boxed() } ClientMessage::Stop { id } => { stoppers.remove(&id); stream::empty().boxed() } _ => stream::empty().boxed(), - } + }; + ( + Self::Active { + config, + stoppers, + schema, + }, + reactions, + ) } + Self::Terminated => (self, stream::empty().boxed()), } } @@ -488,13 +485,22 @@ impl Stream for SubscriptionStart { } } +enum ConnectionSinkState> { + Ready { + state: ConnectionState, + }, + HandlingMessage { + result: BoxFuture<'static, (ConnectionState, BoxStream<'static, Reaction>)>, + }, + Closed, +} + /// Implements the graphql-ws protocol. This is a sink for `TryInto` and a stream of /// `ServerMessage`. pub struct Connection> { reactions: SelectAll>>, - state: ConnectionState, - is_closed: bool, stream_waker: Option, + sink_state: ConnectionSinkState, } impl Connection @@ -514,9 +520,10 @@ where pub fn new(schema: S, init: I) -> Self { Self { reactions: SelectAll::new(), - state: ConnectionState::PreInit { init, schema }, - is_closed: false, stream_waker: None, + sink_state: ConnectionSinkState::Ready { + state: ConnectionState::PreInit { init, schema }, + }, } } } @@ -526,41 +533,63 @@ where T: TryInto>, T::Error: Error, S: Schema, - I: Init, + I: Init + Send, { type Error = Infallible; - fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { - // We're always ready for new messages. - Poll::Ready(Ok(())) + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match &mut self.sink_state { + ConnectionSinkState::Ready { .. } => Poll::Ready(Ok(())), + ConnectionSinkState::HandlingMessage { ref mut result } => { + match Pin::new(result).poll(cx) { + Poll::Ready((state, reactions)) => { + self.reactions.push(reactions); + self.sink_state = ConnectionSinkState::Ready { state }; + Poll::Ready(Ok(())) + } + Poll::Pending => Poll::Pending, + } + } + ConnectionSinkState::Closed => panic!("poll_ready called after close"), + } } - fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { - let reactions = match item.try_into() { - Ok(msg) => self.state.handle_message(msg), - Err(e) => { - // If we weren't able to parse the message, send back an error. - Reaction::ServerMessage(ServerMessage::ConnectionError { - payload: ConnectionErrorPayload { - message: e.to_string(), + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + let s = self.get_mut(); + let state = &mut s.sink_state; + *state = match std::mem::replace(state, ConnectionSinkState::Closed) { + ConnectionSinkState::Ready { state } => { + match item.try_into() { + Ok(msg) => ConnectionSinkState::HandlingMessage { + result: state.handle_message(msg).boxed(), }, - }) - .to_stream() + Err(e) => { + // If we weren't able to parse the message, send back an error. + s.reactions.push( + Reaction::ServerMessage(ServerMessage::ConnectionError { + payload: ConnectionErrorPayload { + message: e.to_string(), + }, + }) + .to_stream(), + ); + ConnectionSinkState::Ready { state } + } + } } + _ => panic!("start_send called when not ready"), }; - self.reactions.push(reactions); Ok(()) } - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { - // Flushing an item doesn't really have any meaning here. Just return okay. - Poll::Ready(Ok(())) + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + >::poll_ready(self, cx) } fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { - // Close the stream too. No need to wait for it though. - self.is_closed = true; + self.sink_state = ConnectionSinkState::Closed; if let Some(waker) = self.stream_waker.take() { + // Wake up the stream so it can close too. waker.wake(); } Poll::Ready(Ok(())) @@ -577,7 +606,7 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.stream_waker = Some(cx.waker().clone()); - if self.is_closed { + if let ConnectionSinkState::Closed = self.sink_state { return Poll::Ready(None); } @@ -587,13 +616,6 @@ where match Pin::new(&mut self.reactions).poll_next(cx) { Poll::Ready(Some(reaction)) => match reaction { Reaction::ServerMessage(msg) => return Poll::Ready(Some(msg)), - Reaction::Activate { config, schema } => { - self.state = ConnectionState::Active { - config: Arc::new(config), - stoppers: HashMap::new(), - schema, - } - } Reaction::EndStream => return Poll::Ready(None), }, Poll::Ready(None) => { @@ -943,4 +965,48 @@ mod test { ); } } + + #[tokio::test] + async fn test_slow_init() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: Variables::default(), + }) + .await + .unwrap(); + + // If we send the start message before the init is handled, we should still get results. + conn.send(ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "{context}".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .await + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + + assert_eq!( + ServerMessage::Data { + id: "foo".to_string(), + payload: DataPayload { + data: Value::Object( + [("context", Value::Scalar(DefaultScalarValue::Int(1)))] + .iter() + .cloned() + .collect() + ), + errors: vec![], + }, + }, + conn.next().await.unwrap() + ); + } } diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs index 7c5c516e2..06898304f 100644 --- a/juniper_warp/src/lib.rs +++ b/juniper_warp/src/lib.rs @@ -469,7 +469,7 @@ pub mod subscriptions { Subscription::TypeInfo: Send + Sync, CtxT: Unpin + Send + Sync + 'static, S: ScalarValue + Send + Sync + 'static, - I: Init, + I: Init + Send, { let (ws_tx, ws_rx) = websocket.split(); let (s_tx, s_rx) = Connection::new(ArcSchema(root_node), init).split(); From ad9ebe44188c22e7ab8a6b402574dd8fb7029618 Mon Sep 17 00:00:00 2001 From: Christopher Brown Date: Sun, 26 Jul 2020 03:04:26 -0400 Subject: [PATCH 05/10] fix pre-existing bug --- juniper_graphql_ws/src/lib.rs | 58 ++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/juniper_graphql_ws/src/lib.rs b/juniper_graphql_ws/src/lib.rs index 1d73067f8..ac45e2d14 100644 --- a/juniper_graphql_ws/src/lib.rs +++ b/juniper_graphql_ws/src/lib.rs @@ -640,7 +640,7 @@ mod test { use juniper::{ futures::sink::SinkExt, parser::{ParseError, Spanning, Token}, - DefaultScalarValue, EmptyMutation, FieldResult, InputValue, RootNode, Value, + DefaultScalarValue, EmptyMutation, FieldError, FieldResult, InputValue, RootNode, Value, }; use std::{convert::Infallible, io}; @@ -678,6 +678,20 @@ mod test { ) .boxed() } + + /// error emits an error once, then never emits anything else. + async fn error(context: &Context) -> BoxStream<'static, FieldResult> { + stream::once(future::ready(Err(FieldError::new( + "field error", + Value::null(), + )))) + .chain( + tokio::time::delay_for(Duration::from_secs(10000)) + .map(|_| unreachable!()) + .into_stream(), + ) + .boxed() + } } type ClientMessage = super::ClientMessage; @@ -1009,4 +1023,46 @@ mod test { conn.next().await.unwrap() ); } + + #[tokio::test] + async fn test_subscription_field_error() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: Variables::default(), + }) + .await + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + + conn.send(ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "subscription Foo {error}".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .await + .unwrap(); + + match conn.next().await.unwrap() { + ServerMessage::Data { + id, + payload: DataPayload { data, errors }, + } => { + assert_eq!(id, "foo"); + assert_eq!( + data, + Value::Object([("error", Value::null())].iter().cloned().collect()) + ); + assert_eq!(errors.len(), 1); + } + msg @ _ => panic!("expected data, got: {:?}", msg), + } + } } From cae5a5a776639cf2fdf56fe02ffac08c0ed5eb77 Mon Sep 17 00:00:00 2001 From: Christopher Brown Date: Wed, 29 Jul 2020 01:30:30 -0400 Subject: [PATCH 06/10] rebase updates --- juniper_graphql_ws/src/lib.rs | 7 +++++-- juniper_subscriptions/src/lib.rs | 16 +++++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/juniper_graphql_ws/src/lib.rs b/juniper_graphql_ws/src/lib.rs index ac45e2d14..4ea87202b 100644 --- a/juniper_graphql_ws/src/lib.rs +++ b/juniper_graphql_ws/src/lib.rs @@ -467,10 +467,13 @@ impl Stream for SubscriptionStart { ref id, ref mut stream, } => match Pin::new(stream).poll_next(cx) { - Poll::Ready(Some((data, errors))) => { + Poll::Ready(Some(output)) => { return Poll::Ready(Some(Reaction::ServerMessage(ServerMessage::Data { id: id.clone(), - payload: DataPayload { data, errors }, + payload: DataPayload { + data: output.data, + errors: output.errors, + }, }))); } Poll::Ready(None) => { diff --git a/juniper_subscriptions/src/lib.rs b/juniper_subscriptions/src/lib.rs index 0e78279ca..3418c0551 100644 --- a/juniper_subscriptions/src/lib.rs +++ b/juniper_subscriptions/src/lib.rs @@ -222,19 +222,25 @@ where } if filled_count == obj_len { + let mut errors = vec![]; filled_count = 0; let new_vec = (0..obj_len).map(|_| None).collect::>(); let ready_vec = std::mem::replace(&mut ready_vec, new_vec); let ready_vec_iterator = ready_vec.into_iter().map(|el| { let (name, val) = el.unwrap(); - if let Ok(value) = val { - (name, value) - } else { - (name, Value::Null) + match val { + Ok(value) => (name, value), + Err(e) => { + errors.push(e); + (name, Value::Null) + } } }); let obj = Object::from_iter(ready_vec_iterator); - Poll::Ready(Some(ExecutionOutput::from_data(Value::Object(obj)))) + Poll::Ready(Some(ExecutionOutput { + data: Value::Object(obj), + errors, + })) } else { Poll::Pending } From c2c631fc563ca49fbacc4f3fba90881423c532db Mon Sep 17 00:00:00 2001 From: Christopher Brown Date: Wed, 29 Jul 2020 02:16:53 -0400 Subject: [PATCH 07/10] address comments --- juniper_graphql_ws/Cargo.toml | 2 +- juniper_graphql_ws/src/lib.rs | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/juniper_graphql_ws/Cargo.toml b/juniper_graphql_ws/Cargo.toml index 6f15ca303..8bf19ee76 100644 --- a/juniper_graphql_ws/Cargo.toml +++ b/juniper_graphql_ws/Cargo.toml @@ -6,7 +6,7 @@ license = "BSD-2-Clause" description = "Graphql-ws protocol implementation for Juniper" documentation = "https://docs.rs/juniper_graphql_ws" repository = "https://github.com/graphql-rust/juniper" -keywords = ["graphql-ws"] +keywords = ["graphql-ws", "juniper", "graphql", "apollo"] edition = "2018" [dependencies] diff --git a/juniper_graphql_ws/src/lib.rs b/juniper_graphql_ws/src/lib.rs index 4ea87202b..7cf73eecb 100644 --- a/juniper_graphql_ws/src/lib.rs +++ b/juniper_graphql_ws/src/lib.rs @@ -2,7 +2,7 @@ # juniper_graphql_ws -This crate contains an implementation of the graphql-ws protocol, as used by Apollo. +This crate contains an implementation of the [graphql-ws protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/263844b5c1a850c1e29814564eb62cb587e5eaaf/PROTOCOL.md), as used by Apollo. */ @@ -329,9 +329,10 @@ impl> ConnectionState { Err(e) => { return Reaction::ServerMessage(ServerMessage::Error { id: id.clone(), + // e only references data owned by params. The new ErrorPayload will continue to keep that data alive. payload: unsafe { ErrorPayload::new_unchecked(Box::new(params.clone()), e) }, }) - .to_stream() + .to_stream(); } } @@ -454,6 +455,7 @@ impl Stream for SubscriptionStart { return Poll::Ready(Some(Reaction::ServerMessage( ServerMessage::Error { id: id.clone(), + // e only references data owned by params. The new ErrorPayload will continue to keep that data alive. payload: unsafe { ErrorPayload::new_unchecked(Box::new(params.clone()), e) }, From d4e39c525a2e682cd09716382ea33c84490638b6 Mon Sep 17 00:00:00 2001 From: Christopher Brown Date: Wed, 29 Jul 2020 02:23:54 -0400 Subject: [PATCH 08/10] add release.toml --- juniper_graphql_ws/release.toml | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 juniper_graphql_ws/release.toml diff --git a/juniper_graphql_ws/release.toml b/juniper_graphql_ws/release.toml new file mode 100644 index 000000000..98e705946 --- /dev/null +++ b/juniper_graphql_ws/release.toml @@ -0,0 +1,8 @@ +no-dev-version = true +pre-release-commit-message = "Release {{crate_name}} {{version}}" +pro-release-commit-message = "Bump {{crate_name}} version to {{next_version}}" +tag-message = "Release {{crate_name}} {{version}}" +upload-doc = false +pre-release-replacements = [ + {file="src/lib.rs", search="docs.rs/juniper_graphql_ws/[a-z0-9\\.-]+", replace="docs.rs/juniper_graphql_ws/{{version}}"}, +] From 6bd46906a92d2ca77a39e2ab0f1f6e01c06ff32f Mon Sep 17 00:00:00 2001 From: Christopher Brown Date: Wed, 29 Jul 2020 02:35:19 -0400 Subject: [PATCH 09/10] makefile and initial changelog --- juniper_graphql_ws/CHANGELOG.md | 3 +++ juniper_graphql_ws/Makefile.toml | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 juniper_graphql_ws/CHANGELOG.md create mode 100644 juniper_graphql_ws/Makefile.toml diff --git a/juniper_graphql_ws/CHANGELOG.md b/juniper_graphql_ws/CHANGELOG.md new file mode 100644 index 000000000..052324725 --- /dev/null +++ b/juniper_graphql_ws/CHANGELOG.md @@ -0,0 +1,3 @@ +# master + +- Initial Release diff --git a/juniper_graphql_ws/Makefile.toml b/juniper_graphql_ws/Makefile.toml new file mode 100644 index 000000000..ba858470d --- /dev/null +++ b/juniper_graphql_ws/Makefile.toml @@ -0,0 +1,20 @@ +[env] +CARGO_MAKE_CARGO_ALL_FEATURES = "" + +[tasks.build-verbose] +condition = { rust_version = { min = "1.29.0" } } + +[tasks.build-verbose.windows] +condition = { rust_version = { min = "1.29.0" }, env = { "TARGET" = "x86_64-pc-windows-msvc" } } + +[tasks.test-verbose] +condition = { rust_version = { min = "1.29.0" } } + +[tasks.test-verbose.windows] +condition = { rust_version = { min = "1.29.0" }, env = { "TARGET" = "x86_64-pc-windows-msvc" } } + +[tasks.ci-coverage-flow] +condition = { rust_version = { min = "1.29.0" } } + +[tasks.ci-coverage-flow.windows] +disabled = true From fd6c442cea4aa920e1bbb9b5c7dcff8cc03b0354 Mon Sep 17 00:00:00 2001 From: Christopher Brown Date: Wed, 29 Jul 2020 02:54:17 -0400 Subject: [PATCH 10/10] add new Cargo.toml to juniper/release.toml --- juniper/release.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/juniper/release.toml b/juniper/release.toml index 723911494..ab15f4e50 100644 --- a/juniper/release.toml +++ b/juniper/release.toml @@ -30,6 +30,8 @@ pre-release-replacements = [ {file="../juniper_warp/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""}, # Subscriptions {file="../juniper_subscriptions/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""}, + # GraphQL-WS + {file="../juniper_graphql_ws/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""}, # Actix-Web {file="../juniper_actix/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""}, {file="../juniper_actix/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""},