from torch.utils.hooks import RemovableHandle


-# Check if module backward hook can safely be used for the module that produced
-# this inputs / outputs mapping
-def _check_valid_module(inputs_grad_fn, outputs) -> bool:
-    def is_output_cloned(output_fn, input_grad_fn) -> bool:
-        """
-        Checks if the output has been cloned. This happens especially in case of
-        layer deeplift.
-        """
-        return (
-            output_fn[0].next_functions is not None
-            and output_fn[0].next_functions[0][0] == input_grad_fn
-        )
-
-    curr_fn = outputs.grad_fn
-    first_next = curr_fn.next_functions[0]
-    try:
-        # if `inputs` in the input to the network then the grad_fn is None and
-        # for that input backward_hook isn't computed. That's the reason why we
-        # need to check on `inputs_grad_fns[first_next[1]]` being None.
-        return (
-            inputs_grad_fn is None
-            or first_next[0] == inputs_grad_fn
-            or is_output_cloned(first_next, inputs_grad_fn)
-        )
-    except IndexError:
-        return False
-
-
class DeepLift(GradientAttribution):
    r"""
    Implements DeepLIFT algorithm based on the following paper:
@@ -112,10 +84,7 @@ def __init__(
        r"""
        Args:

-            model (nn.Module): The reference to PyTorch model instance. Model cannot
-                        contain any in-place nonlinear submodules; these are not
-                        supported by the register_full_backward_hook PyTorch API
-                        starting from PyTorch v1.9.
+            model (nn.Module): The reference to PyTorch model instance.
            multiply_by_inputs (bool, optional): Indicates whether to factor
                        model inputs' multiplier in the final attribution scores.
                        In the literature this is also known as local vs global
@@ -430,25 +399,6 @@ def _forward_pre_hook(
        """
        inputs = _format_tensor_into_tuples(inputs)
        module.input = inputs[0].clone().detach()
-        module.input_grad_fns = inputs[0].grad_fn  # type: ignore
-
-        def tensor_backward_hook(grad):
-            if module.saved_grad is None:
-                raise RuntimeError(
-                    """Module {} was detected as not supporting correctly module
-                        backward hook. You should modify your hook to ignore the given
-                        grad_inputs (recompute them by hand if needed) and save the
-                        newly computed grad_inputs in module.saved_grad. See MaxPool1d
-                        as an example.""".format(
-                        module
-                    )
-                )
-            return module.saved_grad
-
-        # the hook is set by default but it will be used only for
-        # failure cases and will be removed otherwise
-        handle = inputs[0].register_hook(tensor_backward_hook)
-        module.input_hook = handle

    def _forward_hook(
        self,
@@ -462,30 +412,13 @@ def _forward_hook(
        """
        outputs = _format_tensor_into_tuples(outputs)
        module.output = outputs[0].clone().detach()
-        if not _check_valid_module(module.input_grad_fns, outputs[0]):
-            warnings.warn(
-                """An invalid module {} is detected. Saved gradients will
-                be used as the gradients of the module's input tensor.
-                See MaxPool1d as an example.""".format(
-                    module
-                )
-            )
-            module.is_invalid = True  # type: ignore
-            module.saved_grad = None  # type: ignore
-            self.forward_handles.append(cast(RemovableHandle, module.input_hook))
-        else:
-            module.is_invalid = False  # type: ignore
-            # removing the hook if there is no failure case
-            cast(RemovableHandle, module.input_hook).remove()
-            del module.input_hook
-            del module.input_grad_fns

    def _backward_hook(
        self,
        module: Module,
-        grad_input: Union[Tensor, Tuple[Tensor, ...]],
-        grad_output: Union[Tensor, Tuple[Tensor, ...]],
-    ):
+        grad_input: Tensor,
+        grad_output: Tensor,
+    ) -> Tensor:
        r"""
        `grad_input` is the gradient of the neuron with respect to its input
        `grad_output` is the gradient of the neuron with respect to its output
@@ -506,15 +439,14 @@ def _backward_hook(
                "Please, ensure that module is being used only once in the "
                "network.".format(module)
            )
-        multipliers = tuple(
-            SUPPORTED_NON_LINEAR[type(module)](
-                module,
-                module.input,
-                module.output,
-                grad_input,
-                grad_output,
-                eps=self.eps,
-            )
+
+        multipliers = SUPPORTED_NON_LINEAR[type(module)](
+            module,
+            module.input,
+            module.output,
+            grad_input,
+            grad_output,
+            eps=self.eps,
        )
        # remove all the properies that we set for the inputs and output
        del module.input
@@ -545,10 +477,10 @@ def _register_hooks(
        # adds forward hook to leaf nodes that are non-linear
        forward_handle = module.register_forward_hook(self._forward_hook)
        pre_forward_handle = module.register_forward_pre_hook(self._forward_pre_hook)
-        backward_handle = _register_backward_hook(module, self._backward_hook, self)
+        backward_handles = _register_backward_hook(module, self._backward_hook, self)
        self.forward_handles.append(forward_handle)
        self.forward_handles.append(pre_forward_handle)
-        self.backward_handles.append(backward_handle)
+        self.backward_handles.extend(backward_handles)

    def _remove_hooks(self, extra_hooks_to_remove: List[RemovableHandle]) -> None:
        for handle in extra_hooks_to_remove:
@@ -627,9 +559,7 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
        r"""
        Args:

-            model (nn.Module): The reference to PyTorch model instance. Model cannot
-                        contain any in-place nonlinear submodules; these are not
-                        supported by the register_full_backward_hook PyTorch API.
+            model (nn.Module): The reference to PyTorch model instance.
            multiply_by_inputs (bool, optional): Indicates whether to factor
                        model inputs' multiplier in the final attribution scores.
                        In the literature this is also known as local vs global
@@ -941,26 +871,18 @@ def nonlinear(
    grad_input: Tensor,
    grad_output: Tensor,
    eps: float = 1e-10,
-):
+) -> Tensor:
    r"""
    grad_input: (dLoss / dprev_layer_out, dLoss / wij, dLoss / bij)
    grad_output: (dLoss / dlayer_out)
    https://github.com/pytorch/pytorch/issues/12331
    """
    delta_in, delta_out = _compute_diffs(inputs, outputs)

-    new_grad_inp = list(grad_input)
-
-    # supported non-linear modules take only single tensor as input hence accessing
-    # only the first element in `grad_input` and `grad_output`
-    new_grad_inp[0] = torch.where(
-        abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in
+    new_grad_inp = torch.where(
+        abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
    )

-    # If the module is invalid, save the newly computed gradients
-    # The original_grad_input will be overridden later in the Tensor hook
-    if module.is_invalid:
-        module.saved_grad = new_grad_inp[0]
    return new_grad_inp

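
The simplified `nonlinear` applies DeepLIFT's rescale rule tensor-to-tensor: wherever |Δ_in| ≥ eps the gradient is replaced by grad_output · Δ_out / Δ_in, and the plain gradient is kept otherwise to avoid dividing by a near-zero difference. A self-contained numeric sketch with hypothetical values (not captum's API):

```python
import torch

eps = 1e-10
delta_in = torch.tensor([2.0, 1e-12])   # input - reference input
delta_out = torch.tensor([1.5, 0.0])    # output - reference output
grad_input = torch.tensor([1.0, 1.0])   # gradient w.r.t. the module input
grad_output = torch.tensor([1.0, 1.0])  # gradient w.r.t. the module output

# Same torch.where as in the hunk above: rescale where the input delta is
# large enough, otherwise fall back to the unmodified gradient.
new_grad_inp = torch.where(
    abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
)
print(new_grad_inp)  # tensor([0.7500, 1.0000])
```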
@@ -974,15 +896,14 @@ def softmax(
):
    delta_in, delta_out = _compute_diffs(inputs, outputs)

-    new_grad_inp = list(grad_input)
    grad_input_unnorm = torch.where(
-        abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in
+        abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
    )
    # normalizing
-    n = grad_input[0].numel()
+    n = grad_input.numel()

    # updating only the first half
-    new_grad_inp[0] = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
+    new_grad_inp = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
    return new_grad_inp

@@ -1073,7 +994,7 @@ def maxpool(
            module.ceil_mode,
            True,
        )
-        grad_output_updated = grad_output[0]
+        grad_output_updated = grad_output
        unpool_grad_out_delta, unpool_grad_out_ref_delta = torch.chunk(
            unpool_func(
                grad_output_updated * delta_out,
@@ -1089,20 +1010,7 @@ def maxpool(
    unpool_grad_out_delta = unpool_grad_out_delta + unpool_grad_out_ref_delta
    unpool_grad_out_delta = torch.cat(2 * [unpool_grad_out_delta])

-    # If the module is invalid, we need to recompute the grad_input
-    if module.is_invalid:
-        original_grad_input = grad_input
-        grad_input = (
-            unpool_func(
-                grad_output_updated,
-                indices,
-                module.kernel_size,
-                module.stride,
-                module.padding,
-                list(cast(torch.Size, module.input.shape)),
-            ),
-        )
-    if grad_input[0].shape != inputs.shape:
+    if grad_input.shape != inputs.shape:
        raise AssertionError(
            "A problem occurred during maxpool modul's backward pass. "
            "The gradients with respect to inputs include only a "
@@ -1118,13 +1026,7 @@ def maxpool(
    new_grad_inp = torch.where(
        abs(delta_in) < eps, grad_input[0], unpool_grad_out_delta / delta_in
    )
-    # If the module is invalid, save the newly computed gradients
-    # The original_grad_input will be overridden later in the Tensor hook
-    if module.is_invalid:
-        module.saved_grad = new_grad_inp
-        return original_grad_input
-    else:
-        return (new_grad_inp,)
+    return new_grad_inp


def _compute_diffs(inputs: Tensor, outputs: Tensor) -> Tuple[Tensor, Tensor]: