@@ -3,14 +3,17 @@
 import re
 from collections import defaultdict
 from itertools import chain
-from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union

 import torch
+import torch.utils.checkpoint
 from torch import nn as nn
-from torch.utils.checkpoint import checkpoint
+
+from timm.layers import use_reentrant_ckpt
+

 __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
-           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
+           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq', 'checkpoint']


 def model_parameters(model: nn.Module, exclude_head: bool = False):
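The new `use_reentrant_ckpt` import reads a library-wide default for the checkpointing mode. As a rough sketch only, assuming it is backed by a simple module-level flag (names below are hypothetical; the actual timm.layers implementation may differ):

# Hypothetical sketch of a module-level toggle like use_reentrant_ckpt;
# not the actual timm.layers code.
_USE_REENTRANT_CKPT: bool = False  # assumed default: non-reentrant checkpointing

def set_reentrant_ckpt(enable: bool) -> None:
    # Change the library-wide default consulted when callers pass use_reentrant=None.
    global _USE_REENTRANT_CKPT
    _USE_REENTRANT_CKPT = enable

def use_reentrant_ckpt() -> bool:
    # Read the current library-wide default.
    return _USE_REENTRANT_CKPT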
@@ -183,13 +186,35 @@ def flatten_modules(
                 yield name, module


+def checkpoint(
+    function,
+    *args,
+    use_reentrant: Optional[bool] = None,
+    **kwargs,
+):
+    """Checkpoint wrapper fn.
+
+    A thin wrapper around torch.utils.checkpoint.checkpoint that defaults
+    use_reentrant to the value of use_reentrant_ckpt() when not specified.
+    """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
+    return torch.utils.checkpoint.checkpoint(
+        function,
+        *args,
+        use_reentrant=use_reentrant,
+        **kwargs,
+    )
+
+
 def checkpoint_seq(
         functions,
         x,
-        every=1,
-        flatten=False,
-        skip_last=False,
-        preserve_rng_state=True
+        every: int = 1,
+        flatten: bool = False,
+        skip_last: bool = False,
+        use_reentrant: Optional[bool] = None,
 ):
     r"""A helper function for checkpointing sequential models.

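For illustration, a minimal usage sketch of the `checkpoint` wrapper added above: with no explicit argument it falls back to the library default, and an explicit `use_reentrant` always wins. The import path and the module/tensor names are assumptions for the example.

import torch
from torch import nn
from timm.models._manipulate import checkpoint  # path assumed for the example

blk = nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)

# No argument: the wrapper resolves use_reentrant via use_reentrant_ckpt()
y = checkpoint(blk, x)

# An explicit argument overrides the library default
y = checkpoint(blk, x, use_reentrant=False)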
@@ -215,10 +240,9 @@ def checkpoint_seq(
         functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
         x: A Tensor that is input to :attr:`functions`
         every: checkpoint every-n functions (default: 1)
-        flatten (bool): flatten nn.Sequential of nn.Sequentials
-        skip_last (bool): skip checkpointing the last function in the sequence if True
-        preserve_rng_state (bool, optional, default=True): Omit stashing and restoring
-            the RNG state during each checkpoint.
+        flatten: flatten nn.Sequential of nn.Sequentials
+        skip_last: skip checkpointing the last function in the sequence if True
+        use_reentrant: Use re-entrant checkpointing (defaults to use_reentrant_ckpt() when None)

     Returns:
         Output of running :attr:`functions` sequentially on :attr:`*inputs`
@@ -227,6 +251,9 @@ def checkpoint_seq(
         >>> model = nn.Sequential(...)
         >>> input_var = checkpoint_seq(model, input_var, every=2)
     """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
     def run_function(start, end, functions):
         def forward(_x):
             for j in range(start, end + 1):
@@ -247,7 +274,11 @@ def forward(_x):
     end = -1
     for start in range(0, num_checkpointed, every):
         end = min(start + every - 1, num_checkpointed - 1)
-        x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
+        x = torch.utils.checkpoint.checkpoint(
+            run_function(start, end, functions),
+            x,
+            use_reentrant=use_reentrant,
+        )
     if skip_last:
         return run_function(end + 1, len(functions) - 1, functions)(x)
     return x
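And a hedged end-to-end sketch of `checkpoint_seq` with the new `use_reentrant` parameter, checkpointing pairs of layers and leaving the final segment uncheckpointed (import path assumed, model and tensor shapes invented for the example):

import torch
from torch import nn
from timm.models._manipulate import checkpoint_seq  # path assumed for the example

model = nn.Sequential(*[nn.Linear(32, 32) for _ in range(8)])
x = torch.randn(4, 32, requires_grad=True)

# every=2 groups layers in pairs per checkpoint segment; skip_last runs the
# final segment without checkpointing; use_reentrant=False forces the
# non-reentrant implementation regardless of the library default.
out = checkpoint_seq(model, x, every=2, skip_last=True, use_reentrant=False)
out.sum().backward()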