Commit c755007

Merge branch 'master' into optim-wip
2 parents 5a5b285 + bb64aef

File tree

140 files changed: +6831 −891 lines


.gitattributes

Lines changed: 1 addition & 0 deletions

```diff
@@ -1 +1,2 @@
 tutorials/* linguist-documentation
+*.pt binary
```
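The new attribute can be sanity-checked with git itself. A throwaway-repo sketch (hypothetical file names; assumes git is on PATH):

```shell
# Create a scratch repo with the updated .gitattributes and ask git how it
# treats a serialized model file.
repo=$(mktemp -d)
cd "$repo"
git init -q
printf 'tutorials/* linguist-documentation\n*.pt binary\n' > .gitattributes
attr=$(git check-attr binary model.pt)
echo "$attr"   # the `binary` macro should be reported as set for *.pt files
```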

CONTRIBUTING.md

Lines changed: 25 additions & 8 deletions

```diff
@@ -1,6 +1,23 @@
 # Contributing to Captum

-We want to make contributing to Captum as easy and transparent as possible.
+Thank you for your interest in contributing to Captum! We want to make contributing to Captum as easy and transparent as possible.
+Before you begin writing code, it is important that you share your intention to contribute with the team, based on the type of contribution:
+
+1. You want to propose and implement a new algorithm, add a new feature or fix a bug. This can be both code and documentation proposals.
+    1. For all non-outstanding features, bug-fixes and algorithms in the Captum issue list (https://github.com/pytorch/captum/issues), please create an issue first.
+    2. If the implementation requires API or any other major code changes (new files, packages or algorithms), we will likely request a design document to review and discuss the design and implementation before making changes. An example design document for LIME can be found here (https://github.com/pytorch/captum/issues/467).
+    3. Once we agree that the plan looks good, or have confirmed that the change is small enough not to require a detailed design discussion, go ahead and implement it!
+
+2. You want to implement a feature or bug-fix for an outstanding issue.
+    1. Search for your issue in the Captum issue list (https://github.com/pytorch/captum/issues).
+    2. Pick an issue and comment that you'd like to work on the feature or bug-fix.
+    3. If you need more context on a particular issue, please ask and we'll be happy to help.
+
+Once you implement and test your feature or bug-fix, please submit a Pull Request to https://github.com/pytorch/captum (https://github.com/pytorch/pytorch).
+
+This document covers some of the technical aspects of contributing to Captum. More details on what we are looking for in contributions can be found in the [Contributing Guidelines](https://captum.ai/docs/contribution_guide).


 ## Development installation
```
````diff
@@ -29,7 +46,7 @@ from the repository root. No additional configuration should be needed (see the
 [black documentation](https://black.readthedocs.io/en/stable/installation_and_usage.html#usage)
 for advanced usage).

-Captum also uses [isort](https://github.com/timothycrosley/isort) to sort imports
+Captum also uses [isort](https://github.com/timothycrosley/isort) to sort imports
 alphabetically and separate into sections. isort is installed easily via
 pip using `pip install isort`, and run locally by calling
 ```bash
@@ -45,18 +62,18 @@ CircleCI will fail on your PR if it does not adhere to the black or flake8 forma

 Captum is fully typed using python 3.6+
 [type hints](https://www.python.org/dev/peps/pep-0484/).
-We expect any contributions to also use proper type annotations, and we enforce
-consistency of these in our continuous integration tests.
+We expect any contributions to also use proper type annotations, and we enforce
+consistency of these in our continuous integration tests.

-To type check your code locally, install [mypy](https://github.com/python/mypy),
+To type check your code locally, install [mypy](https://github.com/python/mypy),
 which can be done with pip using `pip install "mypy>=0.760"`.
 Then run this script from the repository root:
 ```bash
 ./scripts/run_mypy.sh
 ```
-Note that we expect mypy to have version 0.760 or higher, and when type checking, use PyTorch 1.4 or
-higher due to fixes to PyTorch type hints available in 1.4. We also use the Literal feature which is
-available only in Python 3.8 or above. If type-checking using a previous version of Python, you will
+Note that we expect mypy to have version 0.760 or higher, and when type checking, use PyTorch 1.4 or
+higher due to fixes to PyTorch type hints available in 1.4. We also use the Literal feature, which is
+available only in Python 3.8 or above. If type-checking using a previous version of Python, you will
 need to install the typing-extensions package, which can be done with pip using `pip install typing-extensions`.

 #### Unit Tests
````
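The `Literal` note above boils down to a version-gated import. A minimal torch-free sketch (the `Rule` alias and `pick_rule` are hypothetical, not Captum code):

```python
# On Python 3.8+ Literal ships in typing; older interpreters need the
# typing-extensions backport (`pip install typing-extensions`).
import sys

if sys.version_info >= (3, 8):
    from typing import Literal
else:
    from typing_extensions import Literal

Rule = Literal["rescale"]  # mypy rejects any other string at call sites

def pick_rule(rule: Rule) -> str:
    return rule

print(pick_rule("rescale"))
```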

README.md

Lines changed: 7 additions & 6 deletions

```diff
@@ -175,12 +175,12 @@ Convergence Delta: tensor([2.3842e-07, -4.7684e-07])
 The algorithm outputs an attribution score for each input element and a
 convergence delta. The lower the absolute value of the convergence delta, the better
 the approximation. If we choose not to return delta,
-we can simply not provide `return_convergence_delta` input
+we can simply not provide the `return_convergence_delta` input
 argument. The absolute value of the returned deltas can be interpreted as an
 approximation error for each input sample.
 It can also serve as a proxy of how accurate the integral approximation for given
 inputs and baselines is.
-If the approximation error is large, we can try larger number of integral
+If the approximation error is large, we can try a larger number of integral
 approximation steps by setting `n_steps` to a larger value. Not all algorithms
 return approximation error. Those which do, though, compute it based on the
 completeness property of the algorithms.
```
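The `n_steps` remark above has a simple numerical analogue: integral-approximation error shrinks as the number of steps grows. A torch-free sketch (not Captum code) using a midpoint Riemann sum:

```python
# Approximate the integral of f(x) = 3x^2 on [0, 1]; the exact value is 1.0,
# so |riemann(f, n) - 1.0| plays the role of the convergence delta.
def riemann(f, n_steps):
    dx = 1.0 / n_steps
    return sum(f((i + 0.5) * dx) for i in range(n_steps)) * dx

f = lambda x: 3.0 * x * x
for n in (5, 50, 500):
    print(n, abs(riemann(f, n) - 1.0))  # error decreases as n grows
```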
```diff
@@ -224,7 +224,7 @@ in order to get per example average delta.


 Below is an example of how we can apply `DeepLift` and `DeepLiftShap` on the
-`ToyModel` described above. Current implementation of DeepLift supports only
+`ToyModel` described above. The current implementation of DeepLift supports only the
 `Rescale` rule.
 For more details on alternative implementations, please see the [DeepLift paper](https://arxiv.org/abs/1704.02685).

```
````diff
@@ -286,7 +286,7 @@ In order to smooth and improve the quality of the attributions we can run
 to smoothen the attributions by aggregating them for multiple noisy
 samples that were generated by adding gaussian noise.

-Here is an example how we can use `NoiseTunnel` with `IntegratedGradients`.
+Here is an example of how we can use `NoiseTunnel` with `IntegratedGradients`.

 ```python
 ig = IntegratedGradients(model)
````
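The smoothing idea behind `NoiseTunnel` can be illustrated without Captum or torch: averaging a noisy attribution over many gaussian-perturbed samples stabilizes the estimate. A hedged sketch with a stand-in attribution function:

```python
import random

random.seed(0)

def attribution(x):
    # Stand-in for one gradient-based attribution; its noise-free value is 2*x.
    return 2.0 * x + random.gauss(0.0, 0.5)

def noise_tunnel(x, n_samples, stdev=0.1):
    # Average attributions over inputs perturbed with gaussian noise,
    # mimicking the smoothgrad-style aggregation described above.
    samples = (attribution(x + random.gauss(0.0, stdev)) for _ in range(n_samples))
    return sum(samples) / n_samples

print(noise_tunnel(1.0, 500))  # close to the noise-free value 2.0
```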
````diff
@@ -324,7 +324,7 @@ In this case, we choose to analyze the first neuron in the linear layer.

 ```python
 nc = NeuronConductance(model, model.lin1)
-attributions = nc.attribute(input, neuron_index=1, target=0)
+attributions = nc.attribute(input, neuron_selector=1, target=0)
 print('Neuron Attributions:', attributions)
 ```
 Output
````
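The rename from `neuron_index` to `neuron_selector` reflects that the argument now accepts an integer, a tuple of per-dimension indices/slices, or a callable. A hedged, torch-free sketch of that dispatch (`select_neuron` is a hypothetical helper mirroring `_verify_select_neuron` in `captum/_utils/common.py`, with nested lists standing in for tensors):

```python
def select_neuron(layer_output, selector):
    if callable(selector):             # custom aggregation over the output
        return selector(layer_output)
    if isinstance(selector, tuple):    # tuple of per-dimension indices
        out = layer_output
        for index in selector:
            out = out[index]
        return out
    return layer_output[selector]      # plain integer index

acts = [[0.1, 0.9], [0.3, 0.7]]        # fake 2x2 layer activations
print(select_neuron(acts, 1))                   # [0.3, 0.7]
print(select_neuron(acts, (0, 1)))              # 0.9
print(select_neuron(acts, lambda o: o[1][0]))   # 0.3
```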
````diff
@@ -338,7 +338,7 @@ It is an extension of path integrated gradients for hidden layers and holds the
 completeness property as well.

 It doesn't attribute the contribution scores to the input features
-but shows the importance of each neuron in selected layer.
+but shows the importance of each neuron in the selected layer.
 ```python
 lc = LayerConductance(model, model.lin1)
 attributions, delta = lc.attribute(input, baselines=baseline, target=0, return_convergence_delta=True)
````
```diff
@@ -433,6 +433,7 @@ Image Classification Models and Saliency Maps, K. Simonyan, et. al. 2014](https:
 * `Occlusion`: [Visualizing and Understanding Convolutional Networks](https://arxiv.org/abs/1311.2901)
 * `Shapley Value`: [A value for n-person games. Contributions to the Theory of Games 2.28 (1953): 307-317](https://apps.dtic.mil/dtic/tr/fulltext/u2/604084.pdf)
 * `Shapley Value Sampling`: [Polynomial calculation of the Shapley value based on sampling](https://www.sciencedirect.com/science/article/pii/S0305054808000804)
+* `Infidelity and Sensitivity`: [On the (In)fidelity and Sensitivity for Explanations](https://arxiv.org/abs/1901.09392)

 More details about the above-mentioned [algorithms](https://captum.ai/docs/algorithms) and their pros and cons can be found on our [web-site](https://captum.ai/docs/algorithms_comparison_matrix).
```
captum/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,3 +1,3 @@
 #!/usr/bin/env python3

-__version__ = "0.3.0"
+__version__ = "0.4.0"
```

captum/_utils/common.py

Lines changed: 72 additions & 1 deletion

```diff
@@ -9,10 +9,11 @@
 from torch import Tensor, device
 from torch.nn import Module

-from .._utils.typing import (
+from captum._utils.typing import (
     BaselineType,
     Literal,
     TargetType,
+    TensorOrTupleOfTensorsGeneric,
     TupleOrTensorOrBoolGeneric,
 )
```

```diff
@@ -353,6 +354,43 @@ def _format_output(
     return output if is_inputs_tuple else output[0]


+@typing.overload
+def _format_outputs(
+    is_multiple_inputs: Literal[False], outputs: List[Tuple[Tensor, ...]]
+) -> Union[Tensor, Tuple[Tensor, ...]]:
+    ...
+
+
+@typing.overload
+def _format_outputs(
+    is_multiple_inputs: Literal[True], outputs: List[Tuple[Tensor, ...]]
+) -> List[Union[Tensor, Tuple[Tensor, ...]]]:
+    ...
+
+
+@typing.overload
+def _format_outputs(
+    is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
+) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
+    ...
+
+
+def _format_outputs(
+    is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
+) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
+    assert isinstance(outputs, list), "Outputs must be a list"
+    assert is_multiple_inputs or len(outputs) == 1, (
+        "outputs should contain multiple inputs or have a single output, "
+        f"however the number of outputs is: {len(outputs)}"
+    )
+
+    return (
+        [_format_output(len(output) > 1, output) for output in outputs]
+        if is_multiple_inputs
+        else _format_output(len(outputs[0]) > 1, outputs[0])
+    )
+
+
 def _run_forward(
     forward_func: Callable,
     inputs: Union[Tensor, Tuple[Tensor, ...]],
```
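The overload pattern added here lets mypy narrow the return type from the `is_multiple_inputs` flag while a single definition runs at runtime. A torch-free sketch of the same idiom (`first_or_all` is a hypothetical function; plain ints stand in for tensor tuples):

```python
import typing
from typing import List, Union

# The Literal[True]/Literal[False] overloads exist only for the type
# checker; only the final, undecorated definition executes.
@typing.overload
def first_or_all(take_all: typing.Literal[True], items: List[int]) -> List[int]: ...
@typing.overload
def first_or_all(take_all: typing.Literal[False], items: List[int]) -> int: ...
def first_or_all(take_all: bool, items: List[int]) -> Union[int, List[int]]:
    return items if take_all else items[0]

print(first_or_all(True, [1, 2, 3]))   # mypy infers List[int]
print(first_or_all(False, [1, 2, 3]))  # mypy infers int
```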
```diff
@@ -417,6 +455,15 @@ def _select_targets(output: Tensor, target: TargetType) -> Tensor:
     raise AssertionError("Target type %r is not valid." % target)


+def _contains_slice(target: Union[int, Tuple[Union[int, slice], ...]]) -> bool:
+    if isinstance(target, tuple):
+        for index in target:
+            if isinstance(index, slice):
+                return True
+        return False
+    return isinstance(target, slice)
+
+
 def _verify_select_column(
     output: Tensor, target: Union[int, Tuple[Union[int, slice], ...]]
 ) -> Tensor:
```
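The new helper's behavior is easy to exercise in plain Python. The same logic as above, with the loop condensed to `any()`:

```python
def contains_slice(target):
    # A tuple target contains a slice if any of its indices is a slice;
    # a bare target is checked directly.
    if isinstance(target, tuple):
        return any(isinstance(index, slice) for index in target)
    return isinstance(target, slice)

print(contains_slice(3))                 # False
print(contains_slice((1, slice(0, 4))))  # True
print(contains_slice(slice(None)))       # True
```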
```diff
@@ -427,6 +474,24 @@ def _verify_select_column(
     return output[(slice(None), *target)]


+def _verify_select_neuron(
+    layer_output: Tuple[Tensor, ...],
+    selector: Union[int, Tuple[Union[int, slice], ...], Callable],
+) -> Tensor:
+    if callable(selector):
+        return selector(layer_output if len(layer_output) > 1 else layer_output[0])
+
+    assert len(layer_output) == 1, (
+        "Cannot select neuron index from layer with multiple tensors, "
+        "consider providing a neuron selector function instead."
+    )
+
+    selected_neurons = _verify_select_column(layer_output[0], selector)
+    if _contains_slice(selector):
+        return selected_neurons.reshape(selected_neurons.shape[0], -1).sum(1)
+    return selected_neurons
+
+
 def _extract_device(
     module: Module,
     hook_inputs: Union[None, Tensor, Tuple[Tensor, ...]],
```
```diff
@@ -520,3 +585,9 @@ def _sort_key_list(
         "devices with computed tensors."

     return out_list
+
+
+def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor:
+    if isinstance(inp, Tensor):
+        return inp.flatten()
+    return torch.cat([single_inp.flatten() for single_inp in inp])
```
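A torch-free analogue of the `_flatten_tensor_or_tuple` dispatch (a nested list stands in for a Tensor, a tuple of them for a tuple of Tensors; names are hypothetical):

```python
from itertools import chain

def flat(nested):
    # Recursively flatten one nested-list "tensor".
    out = []
    for item in nested:
        if isinstance(item, list):
            out.extend(flat(item))
        else:
            out.append(item)
    return out

def flatten_one_or_tuple(inp):
    if isinstance(inp, list):          # single "tensor"
        return flat(inp)
    # tuple case: flatten each element and concatenate, as torch.cat does
    return list(chain.from_iterable(flat(t) for t in inp))

print(flatten_one_or_tuple([[1, 2], [3, 4]]))        # [1, 2, 3, 4]
print(flatten_one_or_tuple(([[1], [2]], [[3, 4]])))  # [1, 2, 3, 4]
```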

0 commit comments
