Skip to content

Commit dfcad54

Browse files
committed
Removed Clone trait in Context definition in actix_subscriptions
1 parent a6ebe43 commit dfcad54

File tree

3 files changed

+50
-63
lines changed

3 files changed

+50
-63
lines changed

juniper_actix/examples/actix_subscriptions.rs

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,17 @@
33
use actix_cors::Cors;
44
use actix_web::{middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer};
55
use futures::Stream;
6-
use juniper::{DefaultScalarValue, EmptyMutation, FieldError, RootNode};
6+
use juniper::{
7+
tests::{model::Database, schema::Query},
8+
DefaultScalarValue, EmptyMutation, FieldError, RootNode,
9+
};
710
use juniper_actix::{
811
graphiql_handler as gqli_handler, graphql_handler, playground_handler as play_handler,
912
subscriptions::{graphql_subscriptions as sub_handler, EmptySubscriptionHandler},
1013
};
1114
use juniper_subscriptions::Coordinator;
1215
use std::{pin::Pin, time::Duration};
1316

14-
pub struct Query;
15-
16-
#[juniper::graphql_object(Context = Database)]
17-
impl Query {
18-
fn hello_world() -> &str {
19-
"Hello World!"
20-
}
21-
}
2217
type Schema = RootNode<'static, Query, EmptyMutation<Database>, Subscription>;
2318
type MyCoordinator = Coordinator<
2419
'static,
@@ -33,23 +28,13 @@ type StringStream = Pin<Box<dyn Stream<Item = Result<String, FieldError>> + Send
3328

3429
struct Subscription;
3530

36-
#[derive(Clone)]
37-
pub struct Database;
38-
39-
impl juniper::Context for Database {}
40-
41-
impl Database {
42-
fn new() -> Self {
43-
Self {}
44-
}
45-
}
46-
4731
#[juniper::graphql_subscription(Context = Database)]
4832
impl Subscription {
4933
async fn hello_world() -> StringStream {
5034
let mut counter = 0;
51-
let stream = tokio::time::interval(Duration::from_secs(5)).map(move |_| {
35+
let stream = tokio::time::interval(Duration::from_secs(1)).map(move |_| {
5236
counter += 1;
37+
5338
if counter % 2 == 0 {
5439
Ok(String::from("World!"))
5540
} else {
@@ -87,16 +72,8 @@ async fn graphql_subscriptions(
8772
req: HttpRequest,
8873
) -> Result<HttpResponse, Error> {
8974
let context = Database::new();
90-
unsafe {
91-
sub_handler(
92-
coordinator,
93-
context,
94-
stream,
95-
req,
96-
Some(EmptySubscriptionHandler::default()),
97-
)
98-
}
99-
.await
75+
let handler: Option<EmptySubscriptionHandler> = None;
76+
unsafe { sub_handler(coordinator, context, stream, req, handler) }.await
10077
}
10178

10279
#[actix_rt::main]

juniper_actix/src/subscriptions.rs

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pub use juniper_subscriptions::ws_util::{
1515
};
1616
use juniper_subscriptions::Coordinator;
1717
use serde::Serialize;
18+
use std::ops::Deref;
1819
use std::{
1920
collections::HashMap,
2021
error::Error as StdError,
@@ -33,7 +34,7 @@ fn start<Query, Mutation, Subscription, Context, S, SubHandler, T, E>(
3334
where
3435
T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
3536
S: ScalarValue + Send + Sync + 'static,
36-
Context: Clone + Send + Sync + 'static + std::marker::Unpin,
37+
Context: Send + Sync + 'static + std::marker::Unpin,
3738
Query: juniper::GraphQLTypeAsync<S, Context = Context> + Send + Sync + 'static,
3839
Query::TypeInfo: Send + Sync,
3940
Mutation: juniper::GraphQLTypeAsync<S, Context = Context> + Send + Sync + 'static,
@@ -65,7 +66,7 @@ pub async unsafe fn graphql_subscriptions<
6566
) -> Result<HttpResponse, Error>
6667
where
6768
S: ScalarValue + Send + Sync + 'static,
68-
Context: Clone + Send + Sync + 'static + std::marker::Unpin,
69+
Context: Send + Sync + 'static + std::marker::Unpin,
6970
Query: juniper::GraphQLTypeAsync<S, Context = Context> + Send + Sync + 'static,
7071
Query::TypeInfo: Send + Sync,
7172
Mutation: juniper::GraphQLTypeAsync<S, Context = Context> + Send + Sync + 'static,
@@ -78,7 +79,7 @@ where
7879
start(
7980
GraphQLWSSession {
8081
coordinator: coordinator.into_inner(),
81-
graphql_context: context,
82+
graphql_context: Arc::new(context),
8283
map_req_id_to_spawn_handle: HashMap::new(),
8384
has_started: Arc::new(AtomicBool::new(false)),
8485
handler,
@@ -92,7 +93,7 @@ where
9293
struct GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler, E>
9394
where
9495
S: ScalarValue + Send + Sync + 'static,
95-
Context: Clone + Send + Sync + 'static + std::marker::Unpin,
96+
Context: Send + Sync + 'static + std::marker::Unpin,
9697
Query: juniper::GraphQLTypeAsync<S, Context = Context> + Send + Sync + 'static,
9798
Query::TypeInfo: Send + Sync,
9899
Mutation: juniper::GraphQLTypeAsync<S, Context = Context> + Send + Sync + 'static,
@@ -104,7 +105,7 @@ where
104105
{
105106
pub map_req_id_to_spawn_handle: HashMap<String, SpawnHandle>,
106107
pub has_started: Arc<AtomicBool>,
107-
pub graphql_context: Context,
108+
pub graphql_context: Arc<Context>,
108109
pub coordinator: Arc<Coordinator<'static, Query, Mutation, Subscription, Context, S>>,
109110
pub handler: Option<SubHandler>,
110111
error_handler: std::marker::PhantomData<E>,
@@ -114,7 +115,7 @@ impl<Query, Mutation, Subscription, Context, S, SubHandler, E> Actor
114115
for GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler, E>
115116
where
116117
S: ScalarValue + Send + Sync + 'static,
117-
Context: Clone + Send + Sync + 'static + std::marker::Unpin,
118+
Context: Send + Sync + 'static + std::marker::Unpin,
118119
Query: juniper::GraphQLTypeAsync<S, Context = Context> + Send + Sync + 'static,
119120
Query::TypeInfo: Send + Sync,
120121
Mutation: juniper::GraphQLTypeAsync<S, Context = Context> + Send + Sync + 'static,
@@ -134,7 +135,7 @@ impl<Query, Mutation, Subscription, Context, S, SubHandler, E>
134135
GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler, E>
135136
where
136137
S: ScalarValue + Send + Sync + 'static,
137-
Context: Clone + Send + Sync + 'static + std::marker::Unpin,
138+
Context: Send + Sync + 'static + std::marker::Unpin,
138139
Query: juniper::GraphQLTypeAsync<S, Context = Context> + Send + Sync + 'static,
139140
Query::TypeInfo: Send + Sync,
140141
Mutation: juniper::GraphQLTypeAsync<S, Context = Context> + Send + Sync + 'static,
@@ -192,7 +193,7 @@ where
192193
result: (
193194
GraphQLRequest<S>,
194195
String,
195-
Context,
196+
Arc<Context>,
196197
Arc<Coordinator<'static, Query, Mutation, Subscription, Context, S>>,
197198
),
198199
actor: &mut Self,
@@ -205,7 +206,7 @@ where
205206

206207
async fn handle_subscription(
207208
req: GraphQLRequest<S>,
208-
graphql_context: Context,
209+
graphql_context: Arc<Context>,
209210
request_id: String,
210211
coord: Arc<Coordinator<'static, Query, Mutation, Subscription, Context, S>>,
211212
ctx: *mut ws::WebsocketContext<Self>,
@@ -255,7 +256,7 @@ impl<Query, Mutation, Subscription, Context, S, SubHandler, E>
255256
for GraphQLWSSession<Query, Mutation, Subscription, Context, S, SubHandler, E>
256257
where
257258
S: ScalarValue + Send + Sync + 'static,
258-
Context: Clone + Send + Sync + 'static + std::marker::Unpin,
259+
Context: Send + Sync + 'static + std::marker::Unpin,
259260
Query: juniper::GraphQLTypeAsync<S, Context = Context> + Send + Sync + 'static,
260261
Query::TypeInfo: Send + Sync,
261262
Mutation: juniper::GraphQLTypeAsync<S, Context = Context> + Send + Sync + 'static,
@@ -289,7 +290,7 @@ where
289290
if let Some(handler) = &self.handler {
290291
let state = SubscriptionState::OnConnection(
291292
request.payload,
292-
&mut self.graphql_context,
293+
Arc::get_mut(&mut self.graphql_context).unwrap(),
293294
);
294295
let on_connect_result = handler.handle(state);
295296
if let Err(_err) = on_connect_result {
@@ -312,7 +313,7 @@ where
312313
}
313314
GraphQLOverWebSocketMessage::Start if has_started_value => {
314315
let coordinator = self.coordinator.clone();
315-
let mut context = self.graphql_context.clone();
316+
316317
let payload = request
317318
.graphql_payload::<S>()
318319
.expect("Could not deserialize payload");
@@ -323,9 +324,12 @@ where
323324
payload.variables,
324325
);
325326
if let Some(handler) = &self.handler {
326-
let state = SubscriptionState::OnOperation(&mut context);
327+
let state = SubscriptionState::OnOperation(
328+
self.graphql_context.deref(),
329+
);
327330
handler.handle(state).unwrap();
328331
}
332+
let context = self.graphql_context.clone();
329333
{
330334
use std::collections::hash_map::Entry;
331335
let req_id = request_id.clone();
@@ -347,8 +351,8 @@ where
347351
GraphQLOverWebSocketMessage::Stop if has_started_value => {
348352
let request_id = request.id.unwrap_or("1".to_owned());
349353
if let Some(handler) = &self.handler {
350-
let state =
351-
SubscriptionState::OnOperationComplete(&self.graphql_context);
354+
let context = self.graphql_context.deref();
355+
let state = SubscriptionState::OnOperationComplete(context);
352356
handler.handle(state).unwrap();
353357
}
354358
match self.map_req_id_to_spawn_handle.remove(&request_id) {
@@ -366,7 +370,8 @@ where
366370
}
367371
GraphQLOverWebSocketMessage::ConnectionTerminate => {
368372
if let Some(handler) = &self.handler {
369-
let state = SubscriptionState::OnDisconnect(&self.graphql_context);
373+
let context = self.graphql_context.deref();
374+
let state = SubscriptionState::OnDisconnect(context);
370375
handler.handle(state).unwrap();
371376
}
372377
ctx.stop();
@@ -376,7 +381,8 @@ where
376381
}
377382
ws::Message::Close(_) => {
378383
if let Some(handler) = &self.handler {
379-
let state = SubscriptionState::OnDisconnect(&self.graphql_context);
384+
let context = self.graphql_context.deref();
385+
let state = SubscriptionState::OnDisconnect(context);
380386
handler.handle(state).unwrap();
381387
}
382388
ctx.stop();
@@ -483,6 +489,9 @@ mod tests {
483489
String::from(
484490
r#"{"id":"1","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"#,
485491
),
492+
String::from(
493+
r#"{"id":"2","type":"start","payload":{"variables":{},"extensions":{},"operationName":"hello","query":"subscription hello { helloWorld}"}}"#,
494+
),
486495
String::from(r#"{"id":"1","type":"stop"}"#),
487496
String::from(r#"{"type":"connection_terminate"}"#),
488497
];
@@ -496,6 +505,9 @@ mod tests {
496505
vec![Some(bytes::Bytes::from(
497506
r#"{"type":"data","id":"1","payload":{"data":{"helloWorld":"Hello"}} }"#,
498507
))],
508+
vec![Some(bytes::Bytes::from(
509+
r#"{"type":"data","id":"2","payload":{"data":{"helloWorld":"Hello"}} }"#,
510+
))],
499511
vec![Some(bytes::Bytes::from(
500512
r#"{"type":"complete","id":"1","payload":null}"#,
501513
))],

juniper_subscriptions/src/ws_util.rs

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ where
7070
OnConnection(Option<Value>, &'a mut Context),
7171
/// The Subscription is at the start of a operation after the GQL_START message is
7272
/// is received.
73-
OnOperation(&'a mut Context),
73+
OnOperation(&'a Context),
7474
/// The subscription is on the end of a operation before sending the GQL_COMPLETE
7575
/// message to the client.
7676
OnOperationComplete(&'a Context),
@@ -160,15 +160,15 @@ where
160160
pub mod tests {
161161
use super::*;
162162
use juniper::DefaultScalarValue;
163-
use std::sync::Mutex;
163+
use std::sync::atomic::{AtomicBool, Ordering};
164164

165165
#[derive(Default)]
166166
struct Context {
167167
pub user_id: Option<String>,
168168
pub has_connected: bool,
169-
pub has_operated: bool,
170-
pub has_completed_operation: Mutex<bool>,
171-
pub has_disconnected: Mutex<bool>,
169+
pub has_operated: AtomicBool,
170+
pub has_completed_operation: AtomicBool,
171+
pub has_disconnected: AtomicBool,
172172
}
173173

174174
#[derive(Deserialize)]
@@ -192,15 +192,13 @@ pub mod tests {
192192
ctx.has_connected = true;
193193
}
194194
SubscriptionState::OnOperation(ctx) => {
195-
ctx.has_operated = true;
195+
ctx.has_operated.store(true, Ordering::Relaxed);
196196
}
197197
SubscriptionState::OnOperationComplete(ctx) => {
198-
let mut has_completed = ctx.has_completed_operation.lock().unwrap();
199-
*has_completed = true;
198+
ctx.has_completed_operation.store(true, Ordering::Relaxed);
200199
}
201200
SubscriptionState::OnDisconnect(ctx) => {
202-
let mut has_disconnected = ctx.has_disconnected.lock().unwrap();
203-
*has_disconnected = true;
201+
ctx.has_disconnected.store(true, Ordering::Relaxed);
204202
}
205203
};
206204
Ok(())
@@ -261,7 +259,7 @@ pub mod tests {
261259
let type_value = serde_json::to_string(&GraphQLOverWebSocketMessage::Start).unwrap();
262260
let msg = format!(r#"{{"type":{}, "payload": {{}}, "id": "1" }}"#, type_value);
263261
assert!(implementation_example(&msg, &mut ctx));
264-
assert!(ctx.has_operated);
262+
assert!(ctx.has_operated.load(Ordering::Relaxed));
265263
}
266264

267265
#[test]
@@ -270,8 +268,8 @@ pub mod tests {
270268
let type_value = serde_json::to_string(&GraphQLOverWebSocketMessage::Stop).unwrap();
271269
let msg = format!(r#"{{"type":{}, "payload": null, "id": "1" }}"#, type_value);
272270
assert!(implementation_example(&msg, &mut ctx));
273-
let has_completed = ctx.has_completed_operation.lock().unwrap();
274-
assert!(*has_completed);
271+
let has_completed = ctx.has_completed_operation.load(Ordering::Relaxed);
272+
assert!(has_completed);
275273
}
276274

277275
#[test]
@@ -281,7 +279,7 @@ pub mod tests {
281279
serde_json::to_string(&GraphQLOverWebSocketMessage::ConnectionTerminate).unwrap();
282280
let msg = format!(r#"{{"type":{}, "payload": null, "id": "1" }}"#, type_value);
283281
assert!(implementation_example(&msg, &mut ctx));
284-
let has_disconnected = ctx.has_disconnected.lock().unwrap();
285-
assert!(*has_disconnected);
282+
let has_disconnected = ctx.has_disconnected.load(Ordering::Relaxed);
283+
assert!(has_disconnected);
286284
}
287285
}

0 commit comments

Comments
 (0)