1515
1616use alloc:: sync:: Arc ;
1717use core:: mem;
18- use crate :: sync:: { Condvar , Mutex } ;
18+ use crate :: sync:: { Condvar , Mutex , MutexGuard } ;
1919
2020use crate :: prelude:: * ;
2121
@@ -41,9 +41,22 @@ impl Notifier {
4141 }
4242 }
4343
44+ fn propagate_future_state_to_notify_flag ( & self ) -> MutexGuard < ( bool , Option < Arc < Mutex < FutureState > > > ) > {
45+ let mut lock = self . notify_pending . lock ( ) . unwrap ( ) ;
46+ if let Some ( existing_state) = & lock. 1 {
47+ if existing_state. lock ( ) . unwrap ( ) . callbacks_made {
48+ // If the existing futurestate has completed and actually made callbacks, consider
49+ // the notification flag to have been cleared and reset the future state.
50+ lock. 1 . take ( ) ;
51+ lock. 0 = false ;
52+ }
53+ }
54+ lock
55+ }
56+
4457 pub ( crate ) fn wait ( & self ) {
4558 loop {
46- let mut guard = self . notify_pending . lock ( ) . unwrap ( ) ;
59+ let mut guard = self . propagate_future_state_to_notify_flag ( ) ;
4760 if guard. 0 {
4861 guard. 0 = false ;
4962 return ;
@@ -61,7 +74,7 @@ impl Notifier {
6174 pub ( crate ) fn wait_timeout ( & self , max_wait : Duration ) -> bool {
6275 let current_time = Instant :: now ( ) ;
6376 loop {
64- let mut guard = self . notify_pending . lock ( ) . unwrap ( ) ;
77+ let mut guard = self . propagate_future_state_to_notify_flag ( ) ;
6578 if guard. 0 {
6679 guard. 0 = false ;
6780 return true ;
@@ -88,17 +101,8 @@ impl Notifier {
88101 /// Wake waiters, tracking that wake needs to occur even if there are currently no waiters.
89102 pub ( crate ) fn notify ( & self ) {
90103 let mut lock = self . notify_pending . lock ( ) . unwrap ( ) ;
91- let mut future_probably_generated_calls = false ;
92- if let Some ( future_state) = lock. 1 . take ( ) {
93- future_probably_generated_calls |= future_state. lock ( ) . unwrap ( ) . complete ( ) ;
94- future_probably_generated_calls |= Arc :: strong_count ( & future_state) > 1 ;
95- }
96- if future_probably_generated_calls {
97- // If a future made some callbacks or has not yet been drop'd (i.e. the state has more
98- // than the one reference we hold), assume the user was notified and skip setting the
99- // notification-required flag. This will not cause the `wait` functions above to return
100- // and avoid any future `Future`s starting in a completed state.
101- return ;
104+ if let Some ( future_state) = & lock. 1 {
105+ future_state. lock ( ) . unwrap ( ) . complete ( ) ;
102106 }
103107 lock. 0 = true ;
104108 mem:: drop ( lock) ;
@@ -107,20 +111,14 @@ impl Notifier {
107111
108112 /// Gets a [`Future`] that will get woken up with any waiters
109113 pub ( crate ) fn get_future ( & self ) -> Future {
110- let mut lock = self . notify_pending . lock ( ) . unwrap ( ) ;
111- if lock. 0 {
112- Future {
113- state : Arc :: new ( Mutex :: new ( FutureState {
114- callbacks : Vec :: new ( ) ,
115- complete : true ,
116- } ) )
117- }
118- } else if let Some ( existing_state) = & lock. 1 {
114+ let mut lock = self . propagate_future_state_to_notify_flag ( ) ;
115+ if let Some ( existing_state) = & lock. 1 {
119116 Future { state : Arc :: clone ( & existing_state) }
120117 } else {
121118 let state = Arc :: new ( Mutex :: new ( FutureState {
122119 callbacks : Vec :: new ( ) ,
123- complete : false ,
120+ complete : lock. 0 ,
121+ callbacks_made : false ,
124122 } ) ) ;
125123 lock. 1 = Some ( Arc :: clone ( & state) ) ;
126124 Future { state }
@@ -153,17 +151,16 @@ impl<F: Fn() + Send> FutureCallback for F {
153151pub ( crate ) struct FutureState {
154152 callbacks : Vec < Box < dyn FutureCallback > > ,
155153 complete : bool ,
154+ callbacks_made : bool ,
156155}
157156
158157impl FutureState {
159- fn complete ( & mut self ) -> bool {
160- let mut made_calls = false ;
158+ fn complete ( & mut self ) {
161159 for callback in self . callbacks . drain ( ..) {
162160 callback. call ( ) ;
163- made_calls = true ;
161+ self . callbacks_made = true ;
164162 }
165163 self . complete = true ;
166- made_calls
167164 }
168165}
169166
@@ -180,6 +177,7 @@ impl Future {
180177 pub fn register_callback ( & self , callback : Box < dyn FutureCallback > ) {
181178 let mut state = self . state . lock ( ) . unwrap ( ) ;
182179 if state. complete {
180+ state. callbacks_made = true ;
183181 mem:: drop ( state) ;
184182 callback. call ( ) ;
185183 } else {
@@ -283,6 +281,28 @@ mod tests {
283281 assert ! ( !callback. load( Ordering :: SeqCst ) ) ;
284282 }
285283
284+ #[ test]
285+ fn new_future_wipes_notify_bit ( ) {
286+ // Previously, if we were only using the `Future` interface to learn when a `Notifier` has
287+ // been notified, we'd never mark the notifier as not-awaiting-notify if a `Future` is
288+ // fetched after the notify bit has been set.
289+ let notifier = Notifier :: new ( ) ;
290+ notifier. notify ( ) ;
291+
292+ let callback = Arc :: new ( AtomicBool :: new ( false ) ) ;
293+ let callback_ref = Arc :: clone ( & callback) ;
294+ notifier. get_future ( ) . register_callback ( Box :: new ( move || assert ! ( !callback_ref. fetch_or( true , Ordering :: SeqCst ) ) ) ) ;
295+ assert ! ( callback. load( Ordering :: SeqCst ) ) ;
296+
297+ let callback = Arc :: new ( AtomicBool :: new ( false ) ) ;
298+ let callback_ref = Arc :: clone ( & callback) ;
299+ notifier. get_future ( ) . register_callback ( Box :: new ( move || assert ! ( !callback_ref. fetch_or( true , Ordering :: SeqCst ) ) ) ) ;
300+ assert ! ( !callback. load( Ordering :: SeqCst ) ) ;
301+
302+ notifier. notify ( ) ;
303+ assert ! ( callback. load( Ordering :: SeqCst ) ) ;
304+ }
305+
286306 #[ cfg( feature = "std" ) ]
287307 #[ test]
288308 fn test_wait_timeout ( ) {
@@ -334,6 +354,7 @@ mod tests {
334354 state : Arc :: new ( Mutex :: new ( FutureState {
335355 callbacks : Vec :: new ( ) ,
336356 complete : false ,
357+ callbacks_made : false ,
337358 } ) )
338359 } ;
339360 let callback = Arc :: new ( AtomicBool :: new ( false ) ) ;
@@ -352,6 +373,7 @@ mod tests {
352373 state : Arc :: new ( Mutex :: new ( FutureState {
353374 callbacks : Vec :: new ( ) ,
354375 complete : false ,
376+ callbacks_made : false ,
355377 } ) )
356378 } ;
357379 future. state . lock ( ) . unwrap ( ) . complete ( ) ;
@@ -389,6 +411,7 @@ mod tests {
389411 state : Arc :: new ( Mutex :: new ( FutureState {
390412 callbacks : Vec :: new ( ) ,
391413 complete : false ,
414+ callbacks_made : false ,
392415 } ) )
393416 } ;
394417 let mut second_future = Future { state : Arc :: clone ( & future. state ) } ;
0 commit comments