1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use std:: future:: Future ;
18+ use std:: {
19+ future:: Future ,
20+ pin:: Pin ,
21+ task:: { Context , Poll } ,
22+ } ;
1923
20- use crate :: JoinSet ;
21- use tokio:: task:: JoinError ;
24+ use tokio:: task:: { JoinError , JoinHandle } ;
25+
26+ use crate :: trace_utils:: { trace_block, trace_future} ;
2227
2328/// Helper that provides a simple API to spawn a single task and join it.
2429/// Provides guarantees of aborting on `Drop` to keep it cancel-safe.
30+ /// Note that if the task was spawned with `spawn_blocking`, it will only be
31+ /// aborted if it hasn't started yet.
2532///
26- /// Technically, it's just a wrapper of `JoinSet` (with size=1) .
33+ /// Technically, it's just a wrapper of a `JoinHandle` overriding drop .
2734#[ derive( Debug ) ]
2835pub struct SpawnedTask < R > {
29- inner : JoinSet < R > ,
36+ inner : JoinHandle < R > ,
3037}
3138
3239impl < R : ' static > SpawnedTask < R > {
@@ -36,8 +43,9 @@ impl<R: 'static> SpawnedTask<R> {
3643 T : Send + ' static ,
3744 R : Send ,
3845 {
39- let mut inner = JoinSet :: new ( ) ;
40- inner. spawn ( task) ;
46+ // Ok to use spawn here as SpawnedTask handles aborting/cancelling the task on Drop
47+ #[ allow( clippy:: disallowed_methods) ]
48+ let inner = tokio:: task:: spawn ( trace_future ( task) ) ;
4149 Self { inner }
4250 }
4351
@@ -47,22 +55,21 @@ impl<R: 'static> SpawnedTask<R> {
4755 T : Send + ' static ,
4856 R : Send ,
4957 {
50- let mut inner = JoinSet :: new ( ) ;
51- inner. spawn_blocking ( task) ;
58+ // Ok to use spawn_blocking here as SpawnedTask handles aborting/cancelling the task on Drop
59+ #[ allow( clippy:: disallowed_methods) ]
60+ let inner = tokio:: task:: spawn_blocking ( trace_block ( task) ) ;
5261 Self { inner }
5362 }
5463
5564 /// Joins the task, returning the result of join (`Result<R, JoinError>`).
56- pub async fn join ( mut self ) -> Result < R , JoinError > {
57- self . inner
58- . join_next ( )
59- . await
60- . expect ( "`SpawnedTask` instance always contains exactly 1 task" )
65+ /// Same as awaiting the spawned task, but left for backwards compatibility.
66+ pub async fn join ( self ) -> Result < R , JoinError > {
67+ self . await
6168 }
6269
6370 /// Joins the task and unwinds the panic if it happens.
6471 pub async fn join_unwind ( self ) -> Result < R , JoinError > {
65- self . join ( ) . await . map_err ( |e| {
72+ self . await . map_err ( |e| {
6673 // `JoinError` can be caused either by panic or cancellation. We have to handle panics:
6774 if e. is_panic ( ) {
6875 std:: panic:: resume_unwind ( e. into_panic ( ) ) ;
@@ -77,17 +84,32 @@ impl<R: 'static> SpawnedTask<R> {
7784 }
7885}
7986
87+ impl < R > Future for SpawnedTask < R > {
88+ type Output = Result < R , JoinError > ;
89+
90+ fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
91+ Pin :: new ( & mut self . inner ) . poll ( cx)
92+ }
93+ }
94+
95+ impl < R > Drop for SpawnedTask < R > {
96+ fn drop ( & mut self ) {
97+ self . inner . abort ( ) ;
98+ }
99+ }
100+
80101#[ cfg( test) ]
81102mod tests {
82103 use super :: * ;
83104
84105 use std:: future:: { pending, Pending } ;
85106
86- use tokio:: runtime:: Runtime ;
107+ use tokio:: { runtime:: Runtime , sync :: oneshot } ;
87108
88109 #[ tokio:: test]
89110 async fn runtime_shutdown ( ) {
90111 let rt = Runtime :: new ( ) . unwrap ( ) ;
112+ #[ allow( clippy:: async_yields_async) ]
91113 let task = rt
92114 . spawn ( async {
93115 SpawnedTask :: spawn ( async {
@@ -119,4 +141,36 @@ mod tests {
119141 . await
120142 . ok ( ) ;
121143 }
144+
145+ #[ tokio:: test]
146+ async fn cancel_not_started_task ( ) {
147+ let ( sender, receiver) = oneshot:: channel :: < i32 > ( ) ;
148+ let task = SpawnedTask :: spawn ( async {
149+ // Shouldn't be reached.
150+ sender. send ( 42 ) . unwrap ( ) ;
151+ } ) ;
152+
153+ drop ( task) ;
154+
155+ // If the task was cancelled, the sender was also dropped,
156+ // and awaiting the receiver should result in an error.
157+ assert ! ( receiver. await . is_err( ) ) ;
158+ }
159+
160+ #[ tokio:: test]
161+ async fn cancel_ongoing_task ( ) {
162+ let ( sender, mut receiver) = tokio:: sync:: mpsc:: channel ( 1 ) ;
163+ let task = SpawnedTask :: spawn ( async move {
164+ sender. send ( 1 ) . await . unwrap ( ) ;
165+ // This line will never be reached because the channel has a buffer
166+ // of 1.
167+ sender. send ( 2 ) . await . unwrap ( ) ;
168+ } ) ;
169+ // Let the task start.
170+ assert_eq ! ( receiver. recv( ) . await . unwrap( ) , 1 ) ;
171+ drop ( task) ;
172+
173+ // The sender was dropped so we receive `None`.
174+ assert ! ( receiver. recv( ) . await . is_none( ) ) ;
175+ }
122176}
0 commit comments