 from torch.utils.hooks import RemovableHandle
 
 
-# Check if the module backward hook can safely be used for the module that
-# produced this inputs / outputs mapping
-def _check_valid_module(inputs_grad_fn, outputs) -> bool:
-    def is_output_cloned(output_fn, input_grad_fn) -> bool:
-        """
-        Checks if the output has been cloned. This happens especially in the
-        case of layer DeepLift.
-        """
-        return (
-            output_fn[0].next_functions is not None
-            and output_fn[0].next_functions[0][0] == input_grad_fn
-        )
-
-    curr_fn = outputs.grad_fn
-    first_next = curr_fn.next_functions[0]
-    try:
-        # If `inputs` is the input to the network, then its grad_fn is None and
-        # the backward hook isn't computed for that input. That is why we need
-        # to check whether `inputs_grad_fn` is None.
-        return (
-            inputs_grad_fn is None
-            or first_next[0] == inputs_grad_fn
-            or is_output_cloned(first_next, inputs_grad_fn)
-        )
-    except IndexError:
-        return False
-
-
 class DeepLift(GradientAttribution):
     r"""
     Implements DeepLIFT algorithm based on the following paper:
@@ -112,10 +84,7 @@ def __init__(
         r"""
         Args:
 
-            model (nn.Module): The reference to PyTorch model instance. Model cannot
-                        contain any in-place nonlinear submodules; these are not
-                        supported by the register_full_backward_hook PyTorch API
-                        starting from PyTorch v1.9.
+            model (nn.Module): The reference to the PyTorch model instance.
             multiply_by_inputs (bool, optional): Indicates whether to factor
                         model inputs' multiplier in the final attribution scores.
                         In the literature this is also known as local vs global
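
For orientation, a minimal usage sketch of this class as exposed by captum.attr; the toy model, shapes, zero baselines, and target index are illustrative assumptions, not part of this diff:

    import torch
    import torch.nn as nn
    from captum.attr import DeepLift

    model = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), nn.Linear(3, 2)).eval()
    inputs = torch.randn(4, 3, requires_grad=True)
    baselines = torch.zeros(4, 3)  # reference inputs the deltas are measured against

    dl = DeepLift(model, multiply_by_inputs=True)
    # one attribution score per input feature, for output index 0
    attributions = dl.attribute(inputs, baselines=baselines, target=0)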
@@ -430,25 +399,6 @@ def _forward_pre_hook(
         """
         inputs = _format_tensor_into_tuples(inputs)
         module.input = inputs[0].clone().detach()
-        module.input_grad_fns = inputs[0].grad_fn  # type: ignore
-
-        def tensor_backward_hook(grad):
-            if module.saved_grad is None:
-                raise RuntimeError(
-                    """Module {} was detected as not correctly supporting the
-                    module backward hook. You should modify your hook to ignore
-                    the given grad_inputs (recompute them by hand if needed) and
-                    save the newly computed grad_inputs in module.saved_grad.
-                    See MaxPool1d as an example.""".format(
-                        module
-                    )
-                )
-            return module.saved_grad
-
-        # the hook is set by default, but it is used only in failure cases
-        # and is removed otherwise
-        handle = inputs[0].register_hook(tensor_backward_hook)
-        module.input_hook = handle
 
     def _forward_hook(
         self,
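
As background for the pre-hook above, a self-contained toy sketch of the standard PyTorch forward-pre-hook mechanics it relies on; the ReLU module and shapes are made up:

    import torch
    import torch.nn as nn

    relu = nn.ReLU()

    def pre_hook(module, inputs):
        # stash a detached copy of the (single) input for use in the backward pass
        module.input = inputs[0].clone().detach()

    handle = relu.register_forward_pre_hook(pre_hook)
    _ = relu(torch.randn(2, 3))
    assert relu.input.shape == (2, 3)
    handle.remove()  # hooks return a RemovableHandle; clean up when done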
@@ -462,30 +412,13 @@ def _forward_hook(
         """
         outputs = _format_tensor_into_tuples(outputs)
         module.output = outputs[0].clone().detach()
-        if not _check_valid_module(module.input_grad_fns, outputs[0]):
-            warnings.warn(
-                """An invalid module {} was detected. Saved gradients will
-                be used as the gradients of the module's input tensor.
-                See MaxPool1d as an example.""".format(
-                    module
-                )
-            )
-            module.is_invalid = True  # type: ignore
-            module.saved_grad = None  # type: ignore
-            self.forward_handles.append(cast(RemovableHandle, module.input_hook))
-        else:
-            module.is_invalid = False  # type: ignore
-            # remove the hook since there is no failure case
-            cast(RemovableHandle, module.input_hook).remove()
-            del module.input_hook
-            del module.input_grad_fns
 
     def _backward_hook(
         self,
         module: Module,
-        grad_input: Union[Tensor, Tuple[Tensor, ...]],
-        grad_output: Union[Tensor, Tuple[Tensor, ...]],
-    ):
+        grad_input: Tensor,
+        grad_output: Tensor,
+    ) -> Tensor:
         r"""
         `grad_input` is the gradient of the neuron with respect to its input;
         `grad_output` is the gradient of the neuron with respect to its output.
@@ -506,15 +439,14 @@ def _backward_hook(
                 "Please ensure that the module is being used only once in the "
                 "network.".format(module)
             )
-        multipliers = tuple(
-            SUPPORTED_NON_LINEAR[type(module)](
-                module,
-                module.input,
-                module.output,
-                grad_input,
-                grad_output,
-                eps=self.eps,
-            )
+
+        multipliers = SUPPORTED_NON_LINEAR[type(module)](
+            module,
+            module.input,
+            module.output,
+            grad_input,
+            grad_output,
+            eps=self.eps,
         )
         # remove all the properties that we set for the inputs and outputs
         del module.input
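
`SUPPORTED_NON_LINEAR` maps module types to the multiplier functions defined later in this file (`nonlinear`, `softmax`, the maxpool wrappers). A rough sketch of the shape of that table; the exact set of entries shown here is an assumption:

    import torch.nn as nn

    # illustrative only; the real table in this module covers more module types
    SUPPORTED_NON_LINEAR = {
        nn.ReLU: nonlinear,       # generic rescale rule
        nn.ELU: nonlinear,
        nn.Softmax: softmax,      # rescale plus zero-sum normalization
        nn.MaxPool2d: maxpool2d,  # a thin wrapper over the generic maxpool
    }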
@@ -545,10 +477,10 @@ def _register_hooks(
         # adds forward hook to leaf nodes that are non-linear
         forward_handle = module.register_forward_hook(self._forward_hook)
         pre_forward_handle = module.register_forward_pre_hook(self._forward_pre_hook)
-        backward_handle = _register_backward_hook(module, self._backward_hook, self)
+        backward_handles = _register_backward_hook(module, self._backward_hook, self)
         self.forward_handles.append(forward_handle)
         self.forward_handles.append(pre_forward_handle)
-        self.backward_handles.append(backward_handle)
+        self.backward_handles.extend(backward_handles)
 
     def _remove_hooks(self, extra_hooks_to_remove: List[RemovableHandle]) -> None:
         for handle in extra_hooks_to_remove:
@@ -627,9 +559,7 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
         r"""
         Args:
 
-            model (nn.Module): The reference to PyTorch model instance. Model cannot
-                        contain any in-place nonlinear submodules; these are not
-                        supported by the register_full_backward_hook PyTorch API.
+            model (nn.Module): The reference to the PyTorch model instance.
             multiply_by_inputs (bool, optional): Indicates whether to factor
                         model inputs' multiplier in the final attribution scores.
                         In the literature this is also known as local vs global
@@ -941,26 +871,18 @@ def nonlinear(
     grad_input: Tensor,
     grad_output: Tensor,
     eps: float = 1e-10,
-):
+) -> Tensor:
     r"""
     grad_input: (dLoss / dprev_layer_out, dLoss / wij, dLoss / bij)
     grad_output: (dLoss / dlayer_out)
     https://github.com/pytorch/pytorch/issues/12331
     """
     delta_in, delta_out = _compute_diffs(inputs, outputs)
 
-    new_grad_inp = list(grad_input)
-
-    # supported non-linear modules take only a single tensor as input, hence we
-    # access only the first element in `grad_input` and `grad_output`
-    new_grad_inp[0] = torch.where(
-        abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in
+    new_grad_inp = torch.where(
+        abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
     )
 
-    # If the module is invalid, save the newly computed gradients;
-    # the original grad_input will be overridden later in the Tensor hook.
-    if module.is_invalid:
-        module.saved_grad = new_grad_inp[0]
     return new_grad_inp
 
 
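A toy illustration of the rescale rule implemented by `nonlinear` above; all values are made up. Where `|delta_in| >= eps` the returned value is the DeepLIFT multiplier `grad_output * delta_out / delta_in`; otherwise the ordinary gradient is kept to avoid dividing by a near-zero delta:

    import torch

    eps = 1e-10
    delta_in = torch.tensor([2.0, 1e-12])   # input minus baseline input
    delta_out = torch.tensor([1.5, 0.0])    # output minus baseline output
    grad_input = torch.tensor([0.3, 0.3])   # fallback: the ordinary local gradient
    grad_output = torch.tensor([1.0, 1.0])

    new_grad_inp = torch.where(
        abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
    )
    # tensor([0.7500, 0.3000]): rescale rule for the first element,
    # plain gradient for the near-degenerate second element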
@@ -974,15 +896,14 @@ def softmax(
 ):
     delta_in, delta_out = _compute_diffs(inputs, outputs)
 
-    new_grad_inp = list(grad_input)
     grad_input_unnorm = torch.where(
-        abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in
+        abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
     )
     # normalizing
-    n = grad_input[0].numel()
+    n = grad_input.numel()
 
     # updating only the first half
-    new_grad_inp[0] = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
+    new_grad_inp = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
     return new_grad_inp
 
 
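The subtraction of `grad_input_unnorm.sum() / n` keeps the returned multipliers summing to zero, which matches the fact that each row of the softmax Jacobian sums to zero. A quick self-contained check (random values, names made up):

    import torch

    grad_input_unnorm = torch.randn(8)
    n = grad_input_unnorm.numel()
    new_grad_inp = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
    assert torch.isclose(new_grad_inp.sum(), torch.tensor(0.0), atol=1e-6)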
@@ -1073,7 +994,7 @@ def maxpool(
         module.ceil_mode,
         True,
     )
-    grad_output_updated = grad_output[0]
+    grad_output_updated = grad_output
     unpool_grad_out_delta, unpool_grad_out_ref_delta = torch.chunk(
         unpool_func(
             grad_output_updated * delta_out,
@@ -1089,20 +1010,7 @@ def maxpool(
     unpool_grad_out_delta = unpool_grad_out_delta + unpool_grad_out_ref_delta
     unpool_grad_out_delta = torch.cat(2 * [unpool_grad_out_delta])
 
-    # If the module is invalid, we need to recompute the grad_input
-    if module.is_invalid:
-        original_grad_input = grad_input
-        grad_input = (
-            unpool_func(
-                grad_output_updated,
-                indices,
-                module.kernel_size,
-                module.stride,
-                module.padding,
-                list(cast(torch.Size, module.input.shape)),
-            ),
-        )
-    if grad_input[0].shape != inputs.shape:
+    if grad_input.shape != inputs.shape:
         raise AssertionError(
             "A problem occurred during the maxpool module's backward pass. "
             "The gradients with respect to inputs include only a "
@@ -1118,13 +1026,7 @@ def maxpool(
     new_grad_inp = torch.where(
         abs(delta_in) < eps, grad_input, unpool_grad_out_delta / delta_in
     )
-    # If the module is invalid, save the newly computed gradients;
-    # the original grad_input will be overridden later in the Tensor hook.
-    if module.is_invalid:
-        module.saved_grad = new_grad_inp
-        return original_grad_input
-    else:
-        return (new_grad_inp,)
+    return new_grad_inp
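
For reference, the unpooling mechanics that `maxpool` leans on, shown on a toy 2d example; the tensors and the choice of `F.max_pool2d` are illustrative assumptions:

    import torch
    import torch.nn.functional as F

    x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
    pooled, indices = F.max_pool2d(x, kernel_size=2, return_indices=True)
    restored = F.max_unpool2d(pooled, indices, kernel_size=2, output_size=x.shape[-2:])
    # `restored` holds 4.0 at the argmax position and zeros elsewhere: the same
    # scatter that redistributes the deltas back to the pre-pool shape above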
 
 
 def _compute_diffs(inputs: Tensor, outputs: Tensor) -> Tuple[Tensor, Tensor]:
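
The body of `_compute_diffs` is cut off by this hunk. As a hedged sketch only: under the doubled-batch convention used above (the first half of the batch holds the inputs, the second half the baselines), such a helper plausibly looks like the following; this reconstruction is an assumption, not part of the diff:

    import torch
    from torch import Tensor
    from typing import Tuple

    def _compute_diffs_sketch(inputs: Tensor, outputs: Tensor) -> Tuple[Tensor, Tensor]:
        # assumed convention: batch = torch.cat([actual_input, baseline])
        input, input_ref = torch.chunk(inputs, 2)
        output, output_ref = torch.chunk(outputs, 2)
        delta_in = input - input_ref
        delta_out = output - output_ref
        # repeat so the deltas line up with the doubled batch again
        return torch.cat(2 * [delta_in]), torch.cat(2 * [delta_out])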