 import inspect
 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.utils.prune as pytorch_prune
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 _PYTORCH_PRUNING_FUNCTIONS = {
@@ -246,14 +246,18 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytor
     def _wrap_pruning_fn(pruning_fn, **kwargs):
         return partial(pruning_fn, **kwargs)
 
-    def make_pruning_permanent(self):
-        """ Makes ``parameters_to_prune`` current pruning permanent. """
-        for module, param_name in self._parameters_to_prune:
-            try:
-                pytorch_prune.remove(module, param_name)
-            except ValueError:
-                # pruning already made permanent
-                pass
+    def make_pruning_permanent(self, pl_module: LightningModule):
+        """
+        Removes pruning buffers from any pruned modules
+
+        Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180
+        """
+        for _, module in pl_module.named_modules():
+            for k in list(module._forward_pre_hooks):
+                hook = module._forward_pre_hooks[k]
+                if isinstance(hook, pytorch_prune.BasePruningMethod):
+                    hook.remove(module)
+                    del module._forward_pre_hooks[k]
 
     def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
         trained = getattr(module, tensor_name)
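
For context on the rewritten `make_pruning_permanent` above: `torch.nn.utils.prune` implements pruning by registering a `BasePruningMethod` forward pre-hook on the module and keeping `<name>_orig` / `<name>_mask` tensors around; removing the hook folds the mask back into a plain parameter. A minimal standalone sketch of that mechanism (illustration only, not part of this diff):

    import torch.nn as nn
    import torch.nn.utils.prune as pytorch_prune

    layer = nn.Linear(4, 4)
    pytorch_prune.l1_unstructured(layer, name="weight", amount=0.5)

    # while pruning is active, the module carries extra state
    assert hasattr(layer, "weight_orig") and hasattr(layer, "weight_mask")

    # fold the mask back in, the same way the new make_pruning_permanent does
    for k in list(layer._forward_pre_hooks):
        hook = layer._forward_pre_hooks[k]
        if isinstance(hook, pytorch_prune.BasePruningMethod):
            hook.remove(layer)  # `weight` becomes a plain, permanently pruned parameter
            del layer._forward_pre_hooks[k]

    assert not hasattr(layer, "weight_orig") and not hasattr(layer, "weight_mask")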
@@ -351,7 +355,7 @@ def _log_sparsity_stats(
351355 f" { curr_mask_zeros } ({ curr_mask_zeros / curr_mask_size :.2%} )"
352356 )
353357
-    def on_before_accelerator_backend_setup(self, trainer, pl_module):
+    def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule):
         parameters_to_prune = self.sanitize_parameters_to_prune(
             pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
         )
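
The `parameters_to_prune` argument sanitized here is iterated elsewhere in the file as `(module, parameter_name)` pairs. A hedged sketch of building such a list for the linear layers of a model (the model and names are illustrative only):

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

    # (module, parameter_name) pairs, the shape the callback iterates over
    parameters_to_prune = [
        (m, "weight") for m in model.modules() if isinstance(m, nn.Linear)
    ]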
@@ -367,7 +371,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module):
                 self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
                 self._original_layers[id_]["names"].append((i, name))
 
-    def on_train_epoch_end(self, trainer, pl_module, *args):
+    def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs):
         current_epoch = trainer.current_epoch
         prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
         amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
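
As the two `isinstance(..., Callable)` checks above show, both `_apply_pruning` and `amount` can be given either as fixed values or as callables of the current epoch. A hedged sketch of an epoch-dependent schedule a user could pass as `amount` (the keyword comes from this diff; the schedule itself is just an example):

    def pruning_amount(epoch: int) -> float:
        # no pruning for the first 10 epochs, then ramp up the pruned fraction
        if epoch < 10:
            return 0.0
        if epoch < 20:
            return 0.1
        return 0.25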
@@ -381,13 +385,20 @@ def on_train_epoch_end(self, trainer, pl_module, *args):
         ):
             self.apply_lottery_ticket_hypothesis()
 
-    def on_train_end(self, *args):
+    def on_train_end(self, trainer, pl_module: LightningModule):
         if self._make_pruning_permanent:
-            self.make_pruning_permanent()
+            rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.")
+            self.make_pruning_permanent(pl_module)
 
-    def on_save_checkpoint(self, *args):
+    def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
         if self._make_pruning_permanent:
-            self.make_pruning_permanent()
+            rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
+            prev_device = pl_module.device
+            # prune a copy so training can continue with the same buffers
+            copy = deepcopy(pl_module.to("cpu"))
+            self.make_pruning_permanent(copy)
+            checkpoint["state_dict"] = copy.state_dict()
+            pl_module.to(prev_device)
 
     @staticmethod
     def sanitize_parameters_to_prune(
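
Taken together with the checkpoint hook above: with `make_pruning_permanent=True`, `on_save_checkpoint` now prunes a CPU copy and stores its `state_dict` in the checkpoint, so the live module keeps its `_orig`/`_mask` buffers for further training, while `on_train_end` makes the live module itself permanent. A hedged end-to-end usage sketch (`MyLightningModule` is hypothetical, and the constructor keywords are assumed from this diff and the callback's public API):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelPruning

    model = MyLightningModule()  # hypothetical user model

    pruning = ModelPruning(
        pruning_fn="l1_unstructured",
        amount=0.1,                    # a fixed fraction; an epoch-dependent callable also works
        make_pruning_permanent=True,   # checkpoints receive a permanently pruned state_dict
    )

    trainer = Trainer(max_epochs=30, callbacks=[pruning])
    trainer.fit(model)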