1313//!
1414//! [`ChannelManager`]: crate::ln::channelmanager::ChannelManager
1515
16+ use alloc:: sync:: Arc ;
1617use core:: mem;
1718use core:: time:: Duration ;
1819use sync:: { Condvar , Mutex } ;
1920
2021#[ cfg( any( test, feature = "std" ) ) ]
2122use std:: time:: Instant ;
2223
24+ use core:: future:: Future as StdFuture ;
25+ use core:: task:: { Context , Poll } ;
26+ use core:: pin:: Pin ;
27+
28+ use prelude:: * ;
29+
2330/// Used to signal to one of many waiters that the condition they're waiting on has happened.
2431pub ( crate ) struct Notifier {
25- /// Users won't access the lock directly, but rather wait on its bool using
26- /// `wait_timeout` and `wait`.
27- lock : ( Mutex < bool > , Condvar ) ,
32+ notify_pending : Mutex < ( bool , Option < Arc < Mutex < FutureState > > > ) > ,
33+ condvar : Condvar ,
2834}
2935
3036impl Notifier {
3137 pub ( crate ) fn new ( ) -> Self {
3238 Self {
33- lock : ( Mutex :: new ( false ) , Condvar :: new ( ) ) ,
39+ notify_pending : Mutex :: new ( ( false , None ) ) ,
40+ condvar : Condvar :: new ( ) ,
3441 }
3542 }
3643
3744 pub ( crate ) fn wait ( & self ) {
3845 loop {
39- let & ( ref mtx, ref cvar) = & self . lock ;
40- let mut guard = mtx. lock ( ) . unwrap ( ) ;
41- if * guard {
42- * guard = false ;
46+ let mut guard = self . notify_pending . lock ( ) . unwrap ( ) ;
47+ if guard. 0 {
48+ guard. 0 = false ;
4349 return ;
4450 }
45- guard = cvar . wait ( guard) . unwrap ( ) ;
46- let result = * guard;
51+ guard = self . condvar . wait ( guard) . unwrap ( ) ;
52+ let result = guard. 0 ;
4753 if result {
48- * guard = false ;
54+ guard. 0 = false ;
4955 return
5056 }
5157 }
@@ -55,22 +61,21 @@ impl Notifier {
5561 pub ( crate ) fn wait_timeout ( & self , max_wait : Duration ) -> bool {
5662 let current_time = Instant :: now ( ) ;
5763 loop {
58- let & ( ref mtx, ref cvar) = & self . lock ;
59- let mut guard = mtx. lock ( ) . unwrap ( ) ;
60- if * guard {
61- * guard = false ;
64+ let mut guard = self . notify_pending . lock ( ) . unwrap ( ) ;
65+ if guard. 0 {
66+ guard. 0 = false ;
6267 return true ;
6368 }
64- guard = cvar . wait_timeout ( guard, max_wait) . unwrap ( ) . 0 ;
69+ guard = self . condvar . wait_timeout ( guard, max_wait) . unwrap ( ) . 0 ;
6570 // Due to spurious wakeups that can happen on `wait_timeout`, here we need to check if the
6671 // desired wait time has actually passed, and if not then restart the loop with a reduced wait
6772 // time. Note that this logic can be highly simplified through the use of
6873 // `Condvar::wait_while` and `Condvar::wait_timeout_while`, if and when our MSRV is raised to
6974 // 1.42.0.
7075 let elapsed = current_time. elapsed ( ) ;
71- let result = * guard;
76+ let result = guard. 0 ;
7277 if result || elapsed >= max_wait {
73- * guard = false ;
78+ guard. 0 = false ;
7479 return result;
7580 }
7681 match max_wait. checked_sub ( elapsed) {
@@ -82,29 +87,128 @@ impl Notifier {
8287
8388 /// Wake waiters, tracking that wake needs to occur even if there are currently no waiters.
8489 pub ( crate ) fn notify ( & self ) {
85- let & ( ref persist_mtx, ref cnd) = & self . lock ;
86- let mut lock = persist_mtx. lock ( ) . unwrap ( ) ;
87- * lock = true ;
90+ let mut lock = self . notify_pending . lock ( ) . unwrap ( ) ;
91+ lock. 0 = true ;
92+ if let Some ( future_state) = lock. 1 . take ( ) {
93+ future_state. lock ( ) . unwrap ( ) . complete ( ) ;
94+ }
8895 mem:: drop ( lock) ;
89- cnd. notify_all ( ) ;
96+ self . condvar . notify_all ( ) ;
97+ }
98+
99+ /// Gets a [`Future`] that will get woken up with any waiters
100+ pub ( crate ) fn get_future ( & self ) -> Future {
101+ let mut lock = self . notify_pending . lock ( ) . unwrap ( ) ;
102+ if lock. 0 {
103+ Future {
104+ state : Arc :: new ( Mutex :: new ( FutureState {
105+ callbacks : Vec :: new ( ) ,
106+ complete : false ,
107+ } ) )
108+ }
109+ } else if let Some ( existing_state) = & lock. 1 {
110+ Future { state : Arc :: clone ( & existing_state) }
111+ } else {
112+ let state = Arc :: new ( Mutex :: new ( FutureState {
113+ callbacks : Vec :: new ( ) ,
114+ complete : false ,
115+ } ) ) ;
116+ lock. 1 = Some ( Arc :: clone ( & state) ) ;
117+ Future { state }
118+ }
90119 }
91120
92121 #[ cfg( any( test, feature = "_test_utils" ) ) ]
93122 pub fn notify_pending ( & self ) -> bool {
94- let & ( ref mtx, _) = & self . lock ;
95- let guard = mtx. lock ( ) . unwrap ( ) ;
96- * guard
123+ self . notify_pending . lock ( ) . unwrap ( ) . 0
124+ }
125+ }
126+
127+ /// A callback which is called when a [`Future`] completes.
128+ ///
129+ /// Note that this MUST NOT call back into LDK directly, it must instead schedule actions to be
130+ /// taken later. Rust users should use the [`std::future::Future`] implementation for [`Future`]
131+ /// instead.
132+ ///
133+ /// Note that the [`std::future::Future`] implementation may only work for runtimes which schedule
134+ /// futures when they receive a wake, rather than immediately executing them.
135+ pub trait FutureCallback : Send {
136+ /// The method which is called.
137+ fn call ( & self ) ;
138+ }
139+
140+ impl < F : Fn ( ) + Send > FutureCallback for F {
141+ fn call ( & self ) { ( self ) ( ) ; }
142+ }
143+
144+ pub ( crate ) struct FutureState {
145+ callbacks : Vec < Box < dyn FutureCallback > > ,
146+ complete : bool ,
147+ }
148+
149+ impl FutureState {
150+ fn complete ( & mut self ) {
151+ for callback in self . callbacks . drain ( ..) {
152+ callback. call ( ) ;
153+ }
154+ self . complete = true ;
155+ }
156+ }
157+
158+ /// A simple future which can complete once, and calls some callback(s) when it does so.
159+ pub struct Future {
160+ state : Arc < Mutex < FutureState > > ,
161+ }
162+
163+ impl Future {
164+ /// Registers a callback to be called upon completion of this future. If the future has already
165+ /// completed, the callback will be called immediately.
166+ pub fn register_callback ( & self , callback : Box < dyn FutureCallback > ) {
167+ let mut state = self . state . lock ( ) . unwrap ( ) ;
168+ if state. complete {
169+ mem:: drop ( state) ;
170+ callback. call ( ) ;
171+ } else {
172+ state. callbacks . push ( callback) ;
173+ }
174+ }
175+ }
176+
177+ mod std_future {
178+ use core:: task:: Waker ;
179+ pub struct StdWaker ( pub Waker ) ;
180+ impl super :: FutureCallback for StdWaker {
181+ fn call ( & self ) { self . 0 . wake_by_ref ( ) }
182+ }
183+ }
184+
185+ /// (C-not exported) as Rust Futures aren't usable in language bindings.
186+ impl < ' a > StdFuture for Future {
187+ type Output = ( ) ;
188+
189+ fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
190+ let mut state = self . state . lock ( ) . unwrap ( ) ;
191+ if state. complete {
192+ Poll :: Ready ( ( ) )
193+ } else {
194+ let waker = cx. waker ( ) . clone ( ) ;
195+ state. callbacks . push ( Box :: new ( std_future:: StdWaker ( waker) ) ) ;
196+ Poll :: Pending
197+ }
97198 }
98199}
99200
100201#[ cfg( test) ]
101202mod tests {
203+ use super :: * ;
204+ use core:: sync:: atomic:: { AtomicBool , Ordering } ;
205+ use core:: future:: Future as FutureTrait ;
206+ use core:: task:: { Context , Poll , RawWaker , RawWakerVTable , Waker } ;
207+
102208 #[ cfg( feature = "std" ) ]
103209 #[ test]
104210 fn test_wait_timeout ( ) {
105- use super :: * ;
106211 use sync:: Arc ;
107- use core:: sync:: atomic:: { AtomicBool , Ordering } ;
108212 use std:: thread;
109213
110214 let persistence_notifier = Arc :: new ( Notifier :: new ( ) ) ;
@@ -114,10 +218,9 @@ mod tests {
114218 let exit_thread_clone = exit_thread. clone ( ) ;
115219 thread:: spawn ( move || {
116220 loop {
117- let & ( ref persist_mtx, ref cnd) = & thread_notifier. lock ;
118- let mut lock = persist_mtx. lock ( ) . unwrap ( ) ;
119- * lock = true ;
120- cnd. notify_all ( ) ;
221+ let mut lock = thread_notifier. notify_pending . lock ( ) . unwrap ( ) ;
222+ lock. 0 = true ;
223+ thread_notifier. condvar . notify_all ( ) ;
121224
122225 if exit_thread_clone. load ( Ordering :: SeqCst ) {
123226 break
@@ -146,4 +249,84 @@ mod tests {
146249 }
147250 }
148251 }
252+
253+ #[ test]
254+ fn test_future_callbacks ( ) {
255+ let future = Future {
256+ state : Arc :: new ( Mutex :: new ( FutureState {
257+ callbacks : Vec :: new ( ) ,
258+ complete : false ,
259+ } ) )
260+ } ;
261+ let callback = Arc :: new ( AtomicBool :: new ( false ) ) ;
262+ let callback_ref = Arc :: clone ( & callback) ;
263+ future. register_callback ( Box :: new ( move || assert ! ( !callback_ref. fetch_or( true , Ordering :: SeqCst ) ) ) ) ;
264+
265+ assert ! ( !callback. load( Ordering :: SeqCst ) ) ;
266+ future. state . lock ( ) . unwrap ( ) . complete ( ) ;
267+ assert ! ( callback. load( Ordering :: SeqCst ) ) ;
268+ future. state . lock ( ) . unwrap ( ) . complete ( ) ;
269+ }
270+
271+ #[ test]
272+ fn test_pre_completed_future_callbacks ( ) {
273+ let future = Future {
274+ state : Arc :: new ( Mutex :: new ( FutureState {
275+ callbacks : Vec :: new ( ) ,
276+ complete : false ,
277+ } ) )
278+ } ;
279+ future. state . lock ( ) . unwrap ( ) . complete ( ) ;
280+
281+ let callback = Arc :: new ( AtomicBool :: new ( false ) ) ;
282+ let callback_ref = Arc :: clone ( & callback) ;
283+ future. register_callback ( Box :: new ( move || assert ! ( !callback_ref. fetch_or( true , Ordering :: SeqCst ) ) ) ) ;
284+
285+ assert ! ( callback. load( Ordering :: SeqCst ) ) ;
286+ assert ! ( future. state. lock( ) . unwrap( ) . callbacks. is_empty( ) ) ;
287+ }
288+
289+ // Rather annoyingly, there's no safe way in Rust std to construct a Waker despite it being
290+ // totally possible to construct from a trait implementation (though somewhat less effecient
291+ // compared to a raw VTable). Instead, we have to write out a lot of boilerplate to build a
292+ // waker, which we do here with a trivial Arc<AtomicBool> data element to track woke-ness.
293+ const WAKER_V_TABLE : RawWakerVTable = RawWakerVTable :: new ( waker_clone, wake, wake_by_ref, drop) ;
294+ unsafe fn wake_by_ref ( ptr : * const ( ) ) { let p = ptr as * const Arc < AtomicBool > ; assert ! ( !( * p) . fetch_or( true , Ordering :: SeqCst ) ) ; }
295+ unsafe fn drop ( ptr : * const ( ) ) { let p = ptr as * mut Arc < AtomicBool > ; Box :: from_raw ( p) ; }
296+ unsafe fn wake ( ptr : * const ( ) ) { wake_by_ref ( ptr) ; drop ( ptr) ; }
297+ unsafe fn waker_clone ( ptr : * const ( ) ) -> RawWaker {
298+ let p = ptr as * const Arc < AtomicBool > ;
299+ RawWaker :: new ( Box :: into_raw ( Box :: new ( Arc :: clone ( & * p) ) ) as * const ( ) , & WAKER_V_TABLE )
300+ }
301+
302+ fn create_waker ( ) -> ( Arc < AtomicBool > , Waker ) {
303+ let a = Arc :: new ( AtomicBool :: new ( false ) ) ;
304+ let waker = unsafe { Waker :: from_raw ( waker_clone ( ( & a as * const Arc < AtomicBool > ) as * const ( ) ) ) } ;
305+ ( a, waker)
306+ }
307+
308+ #[ test]
309+ fn test_future ( ) {
310+ let mut future = Future {
311+ state : Arc :: new ( Mutex :: new ( FutureState {
312+ callbacks : Vec :: new ( ) ,
313+ complete : false ,
314+ } ) )
315+ } ;
316+ let mut second_future = Future { state : Arc :: clone ( & future. state ) } ;
317+
318+ let ( woken, waker) = create_waker ( ) ;
319+ assert_eq ! ( Pin :: new( & mut future) . poll( & mut Context :: from_waker( & waker) ) , Poll :: Pending ) ;
320+ assert ! ( !woken. load( Ordering :: SeqCst ) ) ;
321+
322+ let ( second_woken, second_waker) = create_waker ( ) ;
323+ assert_eq ! ( Pin :: new( & mut second_future) . poll( & mut Context :: from_waker( & second_waker) ) , Poll :: Pending ) ;
324+ assert ! ( !second_woken. load( Ordering :: SeqCst ) ) ;
325+
326+ future. state . lock ( ) . unwrap ( ) . complete ( ) ;
327+ assert ! ( woken. load( Ordering :: SeqCst ) ) ;
328+ assert ! ( second_woken. load( Ordering :: SeqCst ) ) ;
329+ assert_eq ! ( Pin :: new( & mut future) . poll( & mut Context :: from_waker( & waker) ) , Poll :: Ready ( ( ) ) ) ;
330+ assert_eq ! ( Pin :: new( & mut second_future) . poll( & mut Context :: from_waker( & second_waker) ) , Poll :: Ready ( ( ) ) ) ;
331+ }
149332}
0 commit comments