From e5039c9aa67b2eac8c70ca3d570887a56601a1dc Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Tue, 12 Mar 2024 23:04:15 -0700 Subject: [PATCH 01/16] Lint and import fixes --- captum/insights/attr_vis/features.py | 2 +- setup.py | 2 +- tests/attr/helpers/conductance_reference.py | 2 +- tests/attr/layer/test_layer_lrp.py | 2 +- tests/attr/test_guided_grad_cam.py | 2 +- tests/attr/test_interpretable_input.py | 2 +- tests/utils/test_linear_model.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/captum/insights/attr_vis/features.py b/captum/insights/attr_vis/features.py index 8a1104c08d..fac17f8e80 100644 --- a/captum/insights/attr_vis/features.py +++ b/captum/insights/attr_vis/features.py @@ -8,7 +8,7 @@ from captum._utils.common import safe_div from captum.attr._utils import visualization as viz from captum.insights.attr_vis._utils.transforms import format_transforms -from torch._tensor import Tensor +from torch import Tensor FeatureOutput = namedtuple("FeatureOutput", "name base modified type contribution") diff --git a/setup.py b/setup.py index fd9d83730d..b4968da4b6 100755 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ def report(*args): INSIGHTS_REQUIRES + TEST_REQUIRES + [ - "black==22.3.0", + "black==24.2.0", "flake8", "sphinx", "sphinx-autodoc-typehints", diff --git a/tests/attr/helpers/conductance_reference.py b/tests/attr/helpers/conductance_reference.py index c09ab99f41..6706d431b6 100644 --- a/tests/attr/helpers/conductance_reference.py +++ b/tests/attr/helpers/conductance_reference.py @@ -10,7 +10,7 @@ from captum.attr._utils.approximation_methods import approximation_parameters from captum.attr._utils.attribution import LayerAttribution from captum.attr._utils.common import _reshape_and_sum -from torch._tensor import Tensor +from torch import Tensor """ Note: This implementation of conductance follows the procedure described in the original diff --git a/tests/attr/layer/test_layer_lrp.py b/tests/attr/layer/test_layer_lrp.py index b17b8cbcf6..d7877e6225 100644 --- a/tests/attr/layer/test_layer_lrp.py +++ b/tests/attr/layer/test_layer_lrp.py @@ -9,7 +9,7 @@ from tests.helpers.basic import assertTensorAlmostEqual, BaseTest from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv, SimpleLRPModel -from torch._tensor import Tensor +from torch import Tensor def _get_basic_config() -> Tuple[BasicModel_ConvNet_One_Conv, Tensor]: diff --git a/tests/attr/test_guided_grad_cam.py b/tests/attr/test_guided_grad_cam.py index b2affe6611..fa1f1ff0a0 100644 --- a/tests/attr/test_guided_grad_cam.py +++ b/tests/attr/test_guided_grad_cam.py @@ -8,7 +8,7 @@ from captum.attr._core.guided_grad_cam import GuidedGradCam from tests.helpers.basic import assertTensorAlmostEqual, BaseTest from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv -from torch._tensor import Tensor +from torch import Tensor from torch.nn import Module diff --git a/tests/attr/test_interpretable_input.py b/tests/attr/test_interpretable_input.py index ec0cbb90a7..8dc256cb36 100644 --- a/tests/attr/test_interpretable_input.py +++ b/tests/attr/test_interpretable_input.py @@ -4,7 +4,7 @@ from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput from parameterized import parameterized from tests.helpers.basic import assertTensorAlmostEqual, BaseTest -from torch._tensor import Tensor +from torch import Tensor class DummyTokenizer: diff --git a/tests/utils/test_linear_model.py b/tests/utils/test_linear_model.py index a72262a716..9b3f494a1a 100644 --- 
a/tests/utils/test_linear_model.py +++ b/tests/utils/test_linear_model.py @@ -9,7 +9,7 @@ SGDRidge, ) from tests.helpers.basic import assertTensorAlmostEqual, BaseTest -from torch._tensor import Tensor +from torch import Tensor def _evaluate(test_data, classifier) -> Dict[str, float]: From 975ddea85d67491b0ddca87a6732179a0131d253 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 13 Mar 2024 00:18:17 -0700 Subject: [PATCH 02/16] Typing fixes --- tests/attr/models/test_pytext.py | 2 +- tests/concept/test_tcav.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/attr/models/test_pytext.py b/tests/attr/models/test_pytext.py index e8b0da52d2..b790826e9e 100644 --- a/tests/attr/models/test_pytext.py +++ b/tests/attr/models/test_pytext.py @@ -8,7 +8,6 @@ from typing import Dict, List, NoReturn, Optional import torch -from pytext.data.data_handler import CommonMetadata HAS_PYTEXT = True try: @@ -29,6 +28,7 @@ from pytext.models.doc_model import DocModel_Deprecated # @manual=//pytext:main_lib from pytext.models.embeddings.word_embedding import WordEmbedding from pytext.models.representations.bilstm_doc_attention import BiLSTMDocAttention + from pytext.data.data_handler import CommonMetadata except ImportError: HAS_PYTEXT = False diff --git a/tests/concept/test_tcav.py b/tests/concept/test_tcav.py index 17dec34abc..247f365a5c 100644 --- a/tests/concept/test_tcav.py +++ b/tests/concept/test_tcav.py @@ -14,7 +14,6 @@ Iterator, List, Set, - SupportsIndex, Tuple, Union, ) @@ -174,7 +173,7 @@ def __init__( self, get_tensor_from_filename_func: Callable, path: str, - num_samples: SupportsIndex = 100, + num_samples: int = 100, ) -> None: r""" Args: From 89dd9d1c0a249f1dc0a2ada6e83bc6f7dd086d16 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 13 Mar 2024 00:47:08 -0700 Subject: [PATCH 03/16] Fix --- scripts/install_via_pip.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/install_via_pip.sh b/scripts/install_via_pip.sh index 27c8cab625..d1b70a34d0 100755 --- a/scripts/install_via_pip.sh +++ b/scripts/install_via_pip.sh @@ -5,6 +5,7 @@ set -e PYTORCH_NIGHTLY=false DEPLOY=false CHOSEN_TORCH_VERSION=-1 +INSTALL_MODE='test' while getopts 'ndfv:' flag; do case "${flag}" in @@ -12,6 +13,7 @@ while getopts 'ndfv:' flag; do d) DEPLOY=true ;; f) FRAMEWORKS=true ;; v) CHOSEN_TORCH_VERSION=${OPTARG};; + m) INSTALL_MODE=${OPTARG};; *) echo "usage: $0 [-n] [-d] [-f] [-v version]" >&2 exit 1 ;; esac @@ -38,7 +40,7 @@ export TERM=xterm pip install --upgrade pip --progress-bar off # install captum with dev deps -pip install -e .[dev] --progress-bar off +pip install -e .[INSTALL_MODE] --progress-bar off BUILD_INSIGHTS=1 python setup.py develop # install other frameworks if asked for and make sure this is before pytorch From 89ab7a53abd8aa983bc61e4ed26043757742490b Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 13 Mar 2024 01:22:35 -0700 Subject: [PATCH 04/16] Fix pytext --- tests/attr/models/test_pytext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/attr/models/test_pytext.py b/tests/attr/models/test_pytext.py index b790826e9e..ba76611106 100644 --- a/tests/attr/models/test_pytext.py +++ b/tests/attr/models/test_pytext.py @@ -143,7 +143,7 @@ def _create_dummy_model(self): self._create_dummy_meta_data(), ) - def _create_dummy_meta_data(self) -> CommonMetadata: + def _create_dummy_meta_data(self): text_field_meta = FieldMeta() text_field_meta.vocab = VocabStub() text_field_meta.vocab_size = 4 From 
7271b98733fa21f606e18dfccec3149a643e9081 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 13 Mar 2024 01:43:01 -0700 Subject: [PATCH 05/16] Fix --- scripts/install_via_pip.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/install_via_pip.sh b/scripts/install_via_pip.sh index d1b70a34d0..d1aaf0894e 100755 --- a/scripts/install_via_pip.sh +++ b/scripts/install_via_pip.sh @@ -5,7 +5,7 @@ set -e PYTORCH_NIGHTLY=false DEPLOY=false CHOSEN_TORCH_VERSION=-1 -INSTALL_MODE='test' +INSTALL_MODE=test while getopts 'ndfv:' flag; do case "${flag}" in @@ -14,7 +14,7 @@ while getopts 'ndfv:' flag; do f) FRAMEWORKS=true ;; v) CHOSEN_TORCH_VERSION=${OPTARG};; m) INSTALL_MODE=${OPTARG};; - *) echo "usage: $0 [-n] [-d] [-f] [-v version]" >&2 + *) echo "usage: $0 [-n] [-d] [-f] [-v version] [-m install_mode]" >&2 exit 1 ;; esac done From b4499bed3e078f7c17b65140731eedb5178e2533 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 13 Mar 2024 02:34:51 -0700 Subject: [PATCH 06/16] Fix --- scripts/install_via_pip.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/install_via_pip.sh b/scripts/install_via_pip.sh index d1aaf0894e..4f69eb984f 100755 --- a/scripts/install_via_pip.sh +++ b/scripts/install_via_pip.sh @@ -40,7 +40,7 @@ export TERM=xterm pip install --upgrade pip --progress-bar off # install captum with dev deps -pip install -e .[INSTALL_MODE] --progress-bar off +pip install -e .[test] --progress-bar off BUILD_INSIGHTS=1 python setup.py develop # install other frameworks if asked for and make sure this is before pytorch From e3679278031f7ad277a615103b161b68523265e9 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 13 Mar 2024 02:46:10 -0700 Subject: [PATCH 07/16] Fix --- scripts/install_via_pip.sh | 4 +--- setup.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/scripts/install_via_pip.sh b/scripts/install_via_pip.sh index 4f69eb984f..2978de2456 100755 --- a/scripts/install_via_pip.sh +++ b/scripts/install_via_pip.sh @@ -5,7 +5,6 @@ set -e PYTORCH_NIGHTLY=false DEPLOY=false CHOSEN_TORCH_VERSION=-1 -INSTALL_MODE=test while getopts 'ndfv:' flag; do case "${flag}" in @@ -13,7 +12,6 @@ while getopts 'ndfv:' flag; do d) DEPLOY=true ;; f) FRAMEWORKS=true ;; v) CHOSEN_TORCH_VERSION=${OPTARG};; - m) INSTALL_MODE=${OPTARG};; *) echo "usage: $0 [-n] [-d] [-f] [-v version] [-m install_mode]" >&2 exit 1 ;; esac @@ -40,7 +38,7 @@ export TERM=xterm pip install --upgrade pip --progress-bar off # install captum with dev deps -pip install -e .[test] --progress-bar off +pip install -e .[dev] --progress-bar off BUILD_INSIGHTS=1 python setup.py develop # install other frameworks if asked for and make sure this is before pytorch diff --git a/setup.py b/setup.py index b4968da4b6..31b55c485f 100755 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ def report(*args): INSIGHTS_REQUIRES + TEST_REQUIRES + [ - "black==24.2.0", + "black", "flake8", "sphinx", "sphinx-autodoc-typehints", From 4a149906c7e153bdc9a8bb4226997ae41b2e0d4c Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 13 Mar 2024 03:04:18 -0700 Subject: [PATCH 08/16] Fix --- tests/helpers/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers/basic.py b/tests/helpers/basic.py index 3d212f6f15..ba317458d2 100644 --- a/tests/helpers/basic.py +++ b/tests/helpers/basic.py @@ -7,7 +7,7 @@ import numpy as np import torch from captum.log import patch_methods -from torch._tensor import Tensor +from torch import Tensor def deep_copy_args(func: 
Callable): From 26c5309fa271ac662e708233d68a803024217ac5 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 13 Mar 2024 05:13:32 -0700 Subject: [PATCH 09/16] Fixes --- .github/workflows/lint.yml | 2 +- captum/influence/_core/tracincp_fast_rand_proj.py | 8 +++----- captum/insights/attr_vis/attribution_calculation.py | 2 +- captum/insights/attr_vis/server.py | 8 ++++---- tests/attr/models/test_pytext.py | 2 +- 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 0aa7c68b4c..daa676d5c1 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,7 +13,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: runner: linux.12xlarge - docker-image: cimg/python:3.6 + docker-image: cimg/python:3.9 repository: pytorch/captum script: | sudo chmod -R 777 . diff --git a/captum/influence/_core/tracincp_fast_rand_proj.py b/captum/influence/_core/tracincp_fast_rand_proj.py index 0d970f1882..ed74901f8c 100644 --- a/captum/influence/_core/tracincp_fast_rand_proj.py +++ b/captum/influence/_core/tracincp_fast_rand_proj.py @@ -82,7 +82,7 @@ class TracInCPFast(TracInCPBase): def __init__( self, model: Module, - final_fc_layer: Union[Module, str], + final_fc_layer: Module, train_dataset: Union[Dataset, DataLoader], checkpoints: Union[str, List[str], Iterator], checkpoints_load_func: Callable = _load_flexible_state_dict, @@ -96,11 +96,9 @@ def __init__( model (torch.nn.Module): An instance of pytorch model. This model should define all of its layers as attributes of the model. - final_fc_layer (torch.nn.Module or str): The last fully connected layer in + final_fc_layer (torch.nn.Module): The last fully connected layer in the network for which gradients will be approximated via fast random - projection method. Can be either the layer module itself, or the - fully qualified name of the layer if it is a defined attribute of - the passed `model`. + projection method. train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader): In the `influence` method, we compute the influence score of training examples on examples in a test batch. 
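
The `final_fc_layer` hunk above narrows the parameter from `Union[Module, str]` to `Module`, so callers now pass the layer object itself rather than its attribute name, and the string-lookup branch disappears from the contract. A minimal sketch of the call-site implication, assuming a toy two-layer model (`TinyNet` and its dataset are hypothetical stand-ins, not part of this series):

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset


class TinyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Linear(4, 8)
        self.fc = nn.Linear(8, 2)  # the final fully connected layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(torch.relu(self.features(x)))


net = TinyNet()
train_ds = TensorDataset(torch.randn(10, 4), torch.randn(10, 2))

# Before this patch, both of these satisfied the annotation:
#   TracInCPFast(net, "fc", train_ds, checkpoints, ...)    # name lookup
#   TracInCPFast(net, net.fc, train_ds, checkpoints, ...)  # module object
# After the narrowing only the second form type-checks; the string form is
# rejected, which is why tests/influence/_core/test_tracin_validation.py
# later gains a `# type: ignore` on its deliberately invalid "invalid_layer".
final_fc_layer: nn.Module = net.fc
```
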
diff --git a/captum/insights/attr_vis/attribution_calculation.py b/captum/insights/attr_vis/attribution_calculation.py index 3f695b1807..ca41b74396 100644 --- a/captum/insights/attr_vis/attribution_calculation.py +++ b/captum/insights/attr_vis/attribution_calculation.py @@ -131,7 +131,7 @@ def calculate_attribution( ) if "baselines" in inspect.signature(attribution_method.attribute).parameters: attribution_arguments["baselines"] = baseline - attr = attribution_method.attribute.__wrapped__( + attr = attribution_method.attribute.__wrapped__( # type: ignore attribution_method, # self data, additional_forward_args=additional_forward_args, diff --git a/captum/insights/attr_vis/server.py b/captum/insights/attr_vis/server.py index 5b19a94514..98122f781f 100644 --- a/captum/insights/attr_vis/server.py +++ b/captum/insights/attr_vis/server.py @@ -44,7 +44,7 @@ def attribute() -> Response: r = request.get_json(force=True) return jsonify( namedtuple_to_dict( - visualizer._calculate_attribution_from_cache( + visualizer._calculate_attribution_from_cache( # type: ignore r["inputIndex"], r["modelIndex"], r["labelIndex"] ) ) @@ -54,15 +54,15 @@ def attribute() -> Response: @app.route("/fetch", methods=["POST"]) def fetch() -> Response: # force=True needed, see comment for "/attribute" route above - visualizer._update_config(request.get_json(force=True)) - visualizer_output = visualizer.visualize() + visualizer._update_config(request.get_json(force=True)) # type: ignore + visualizer_output = visualizer.visualize() # type: ignore clean_output = namedtuple_to_dict(visualizer_output) return jsonify(clean_output) @app.route("/init") def init() -> Response: - return jsonify(visualizer.get_insights_config()) + return jsonify(visualizer.get_insights_config()) # type: ignore @app.route("/") diff --git a/tests/attr/models/test_pytext.py b/tests/attr/models/test_pytext.py index ba76611106..ab4ed1a588 100644 --- a/tests/attr/models/test_pytext.py +++ b/tests/attr/models/test_pytext.py @@ -19,6 +19,7 @@ from pytext.config.component import create_featurizer, create_model from pytext.config.doc_classification import ModelInputConfig, TargetConfig from pytext.config.field_config import FeatureConfig, WordFeatConfig + from pytext.data.data_handler import CommonMetadata from pytext.data.doc_classification_data_handler import ( # @manual=//pytext:main_lib # noqa DocClassificationDataHandler, ) @@ -28,7 +29,6 @@ from pytext.models.doc_model import DocModel_Deprecated # @manual=//pytext:main_lib from pytext.models.embeddings.word_embedding import WordEmbedding from pytext.models.representations.bilstm_doc_attention import BiLSTMDocAttention - from pytext.data.data_handler import CommonMetadata except ImportError: HAS_PYTEXT = False From 885ed76bc9c8a67c3cbfb2129b33b223380ce06d Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 13 Mar 2024 23:12:13 -0700 Subject: [PATCH 10/16] Fixes --- captum/_utils/common.py | 4 ++-- captum/attr/_core/deep_lift.py | 8 ++------ captum/influence/_core/tracincp_fast_rand_proj.py | 8 +++----- captum/insights/attr_vis/server.py | 4 ++-- setup.cfg | 2 +- tests/attr/helpers/conductance_reference.py | 11 ++++++----- tests/attr/layer/test_layer_lrp.py | 1 + tests/attr/models/test_pytext.py | 2 +- tests/attr/test_class_summarizer.py | 4 +++- tests/attr/test_guided_grad_cam.py | 4 ++-- tests/attr/test_input_layer_wrapper.py | 3 ++- tests/attr/test_lime.py | 2 +- tests/attr/test_stat.py | 3 ++- tests/helpers/basic.py | 10 +++------- tests/helpers/basic_models.py | 2 +- 
tests/influence/_core/test_tracin_validation.py | 2 +- tests/robust/test_FGSM.py | 2 +- tests/robust/test_attack_comparator.py | 2 +- .../models/linear_models/_test_linear_classifier.py | 8 ++++---- tests/utils/test_helpers.py | 2 +- tests/utils/test_linear_model.py | 8 ++++---- tests/utils/test_sample_gradient.py | 12 ++++++++---- 22 files changed, 52 insertions(+), 52 deletions(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 942205a5a7..ddc2bbe692 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -3,7 +3,7 @@ from enum import Enum from functools import reduce from inspect import signature -from typing import Any, Callable, cast, Dict, List, overload, Tuple, Union +from typing import Any, Callable, cast, Dict, List, overload, Sequence, Tuple, Union import numpy as np import torch @@ -683,7 +683,7 @@ def _extract_device( def _reduce_list( - val_list: List[TupleOrTensorOrBoolGeneric], + val_list: Sequence[TupleOrTensorOrBoolGeneric], red_func: Callable[[List], Any] = torch.cat, ) -> TupleOrTensorOrBoolGeneric: """ diff --git a/captum/attr/_core/deep_lift.py b/captum/attr/_core/deep_lift.py index eea8234eef..e717fc12bc 100644 --- a/captum/attr/_core/deep_lift.py +++ b/captum/attr/_core/deep_lift.py @@ -582,9 +582,7 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None: def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - baselines: Union[ - TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] - ], + baselines: Union[BaselineType, Callable[..., TensorOrTupleOfTensorsGeneric]], target: TargetType = None, additional_forward_args: Any = None, return_convergence_delta: Literal[False] = False, @@ -595,9 +593,7 @@ def attribute( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - baselines: Union[ - TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] - ], + baselines: Union[BaselineType, Callable[..., TensorOrTupleOfTensorsGeneric]], target: TargetType = None, additional_forward_args: Any = None, *, diff --git a/captum/influence/_core/tracincp_fast_rand_proj.py b/captum/influence/_core/tracincp_fast_rand_proj.py index ed74901f8c..ccc3bf061f 100644 --- a/captum/influence/_core/tracincp_fast_rand_proj.py +++ b/captum/influence/_core/tracincp_fast_rand_proj.py @@ -867,7 +867,7 @@ class TracInCPFastRandProj(TracInCPFast): def __init__( self, model: Module, - final_fc_layer: Union[Module, str], + final_fc_layer: Module, train_dataset: Union[Dataset, DataLoader], checkpoints: Union[str, List[str], Iterator], checkpoints_load_func: Callable = _load_flexible_state_dict, @@ -884,11 +884,9 @@ def __init__( model (torch.nn.Module): An instance of pytorch model. This model should define all of its layers as attributes of the model. - final_fc_layer (torch.nn.Module or str): The last fully connected layer in + final_fc_layer (torch.nn.Module): The last fully connected layer in the network for which gradients will be approximated via fast random - projection method. Can be either the layer module itself, or the - fully qualified name of the layer if it is a defined attribute of - the passed `model`. + projection method. train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader): In the `influence` method, we compute the influence score of training examples on examples in a test batch. 
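
The `attribute` stubs touched in `deep_lift.py` earlier in this patch use the standard `typing.overload` pattern, in which a `Literal` flag value selects the return type; the one-line `def ...: ...` stubs that pattern produces are also why this patch adds `E704` to the flake8 ignore list in `setup.cfg` below. A self-contained sketch of the pattern under illustrative names (this is not Captum's actual signature):

```python
from typing import Literal, Tuple, Union, overload

import torch
from torch import Tensor


@overload
def attribute(inputs: Tensor, return_convergence_delta: Literal[False] = False) -> Tensor: ...
@overload
def attribute(inputs: Tensor, *, return_convergence_delta: Literal[True]) -> Tuple[Tensor, Tensor]: ...


def attribute(
    inputs: Tensor, return_convergence_delta: bool = False
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    # Placeholder "attribution": real DeepLift logic is irrelevant to the typing.
    attr = inputs.clone()
    if return_convergence_delta:
        delta = torch.zeros(inputs.shape[0])  # stand-in convergence error
        return attr, delta
    return attr


scores = attribute(torch.ones(2, 3))  # checker resolves this call to Tensor
scores, delta = attribute(torch.ones(2, 3), return_convergence_delta=True)  # tuple
```
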
diff --git a/captum/insights/attr_vis/server.py b/captum/insights/attr_vis/server.py index 98122f781f..d7f7384201 100644 --- a/captum/insights/attr_vis/server.py +++ b/captum/insights/attr_vis/server.py @@ -4,7 +4,7 @@ import socket import threading from time import sleep -from typing import Optional +from typing import cast, Dict, Optional from captum.log import log_usage from flask import Flask, jsonify, render_template, request @@ -41,7 +41,7 @@ def namedtuple_to_dict(obj): def attribute() -> Response: # force=True needed for Colab notebooks, which doesn't use the correct # Content-Type header when forwarding requests through the Colab proxy - r = request.get_json(force=True) + r = cast(Dict, request.get_json(force=True)) return jsonify( namedtuple_to_dict( visualizer._calculate_attribution_from_cache( # type: ignore diff --git a/setup.cfg b/setup.cfg index 9ead7322fc..d6702edfb4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,7 @@ [flake8] # E203: black and flake8 disagree on whitespace before ':' # W503: black and flake8 disagree on how to place operators -ignore = E203, W503 +ignore = E203, W503, E704 max-line-length = 88 exclude = build, dist, tutorials, website diff --git a/tests/attr/helpers/conductance_reference.py b/tests/attr/helpers/conductance_reference.py index 6706d431b6..fee05673a6 100644 --- a/tests/attr/helpers/conductance_reference.py +++ b/tests/attr/helpers/conductance_reference.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from typing import Optional, Tuple +from typing import cast, Optional, Tuple, Union import numpy as np import torch @@ -11,6 +11,7 @@ from captum.attr._utils.attribution import LayerAttribution from captum.attr._utils.common import _reshape_and_sum from torch import Tensor +from torch.utils.hooks import RemovableHandle """ Note: This implementation of conductance follows the procedure described in the original @@ -55,7 +56,7 @@ def forward_hook(module, inp, out): # The hidden layer tensor is assumed to have dimension (num_hidden, ...) # where the product of the dimensions >= 1 correspond to the total # number of hidden neurons in the layer. - layer_size = tuple(saved_tensor.size())[1:] + layer_size = tuple(cast(Tensor, saved_tensor).size())[1:] layer_units = int(np.prod(layer_size)) # Remove unnecessary forward hook. @@ -101,12 +102,12 @@ def forward_hook_register_back(module, inp, out): input_grads = torch.autograd.grad(torch.unbind(output), expanded_input) # Remove backwards hook - back_hook.remove() + cast(RemovableHandle, back_hook).remove() # Remove duplicates in gradient with respect to hidden layer, # choose one for each layer_units indices. 
output_mid_grads = torch.index_select( - saved_grads, + cast(Tensor, saved_grads), 0, torch.tensor(range(0, input_grads[0].shape[0], layer_units)), ) @@ -115,7 +116,7 @@ def forward_hook_register_back(module, inp, out): def attribute( self, inputs, - baselines: Optional[int] = None, + baselines: Union[None, int, Tensor] = None, target=None, n_steps: int = 500, method: str = "riemann_trapezoid", diff --git a/tests/attr/layer/test_layer_lrp.py b/tests/attr/layer/test_layer_lrp.py index d7877e6225..b76e186250 100644 --- a/tests/attr/layer/test_layer_lrp.py +++ b/tests/attr/layer/test_layer_lrp.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: ignore-errors from typing import Any, Tuple diff --git a/tests/attr/models/test_pytext.py b/tests/attr/models/test_pytext.py index ab4ed1a588..926aadbe0a 100644 --- a/tests/attr/models/test_pytext.py +++ b/tests/attr/models/test_pytext.py @@ -43,7 +43,7 @@ def __init__(self) -> None: class TestWordEmbeddings(unittest.TestCase): - def setUp(self) -> Optional[NoReturn]: + def setUp(self) -> None: if not HAS_PYTEXT: return self.skipTest("Skip the test since PyText is not installed") diff --git a/tests/attr/test_class_summarizer.py b/tests/attr/test_class_summarizer.py index 80bf74b0e4..78403ece11 100644 --- a/tests/attr/test_class_summarizer.py +++ b/tests/attr/test_class_summarizer.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from typing import List + import torch from captum.attr import ClassSummarizer, CommonStats from tests.helpers.basic import BaseTest @@ -45,7 +47,7 @@ def test_classes(self) -> None: ((3, 2, 10, 3), (1,)), # ((20,),), ] - list_of_classes = [ + list_of_classes: List[List] = [ list(range(100)), ["%d" % i for i in range(100)], list(range(300, 400)), diff --git a/tests/attr/test_guided_grad_cam.py b/tests/attr/test_guided_grad_cam.py index fa1f1ff0a0..8b33e583b6 100644 --- a/tests/attr/test_guided_grad_cam.py +++ b/tests/attr/test_guided_grad_cam.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import unittest -from typing import Any +from typing import Any, List, Tuple, Union import torch from captum._utils.typing import TensorOrTupleOfTensorsGeneric @@ -107,7 +107,7 @@ def _guided_grad_cam_test_assert( model: Module, target_layer: Module, test_input: TensorOrTupleOfTensorsGeneric, - expected: Tensor, + expected: Union[Tensor, List, Tuple], additional_input: Any = None, interpolate_mode: str = "nearest", attribute_to_layer_input: bool = False, diff --git a/tests/attr/test_input_layer_wrapper.py b/tests/attr/test_input_layer_wrapper.py index 8629146fdc..d82f8603ef 100644 --- a/tests/attr/test_input_layer_wrapper.py +++ b/tests/attr/test_input_layer_wrapper.py @@ -27,6 +27,7 @@ BasicModel_MultiLayer_TrueMultiInput, MixedKwargsAndArgsModule, ) +from torch.nn import Module layer_methods_to_test_with_equiv = [ # layer_method, equiv_method, whether or not to use multiple layers @@ -115,7 +116,7 @@ def layer_method_with_input_layer_patches( assertTensorTuplesAlmostEqual(self, a1, real_attributions) def forward_eval_layer_with_inputs_helper( - self, model: ModelInputWrapper, inputs_to_test + self, model: Module, inputs_to_test ) -> None: # hard coding for simplicity # 0 if using args, 1 if using kwargs diff --git a/tests/attr/test_lime.py b/tests/attr/test_lime.py index 404ed20f92..b6ec6135df 100644 --- a/tests/attr/test_lime.py +++ b/tests/attr/test_lime.py @@ -494,7 +494,7 @@ def _lime_test_assert( model: Callable, test_input: TensorOrTupleOfTensorsGeneric, expected_attr, - expected_coefs_only: Optional[Tensor] = None, + expected_coefs_only: 
Union[None, List, Tensor] = None, feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, additional_input: Any = None, perturbations_per_eval: Tuple[int, ...] = (1,), diff --git a/tests/attr/test_stat.py b/tests/attr/test_stat.py index 8fcefb6f22..30c5e336b4 100644 --- a/tests/attr/test_stat.py +++ b/tests/attr/test_stat.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import random +from typing import Callable, List import torch from captum.attr import Max, Mean, Min, MSE, StdDev, Sum, Summarizer, Var @@ -140,7 +141,7 @@ def test_stats_random_data(self) -> None: "sum", "mse", ] - gt_fns = [ + gt_fns: List[Callable] = [ torch.mean, lambda x: torch.var(x, unbiased=False), lambda x: torch.var(x, unbiased=True), diff --git a/tests/helpers/basic.py b/tests/helpers/basic.py index ba317458d2..dccb726f39 100644 --- a/tests/helpers/basic.py +++ b/tests/helpers/basic.py @@ -2,7 +2,7 @@ import copy import random import unittest -from typing import Callable +from typing import Callable, List, Tuple, Union import numpy as np import torch @@ -20,9 +20,7 @@ def copy_args(*args, **kwargs): return copy_args -def assertTensorAlmostEqual( - test, actual: Tensor, expected: Tensor, delta: float = 0.0001, mode: str = "sum" -) -> None: +def assertTensorAlmostEqual(test, actual, expected, delta=0.0001, mode="sum"): assert isinstance(actual, torch.Tensor), ( "Actual parameter given for " "comparison must be a tensor." ) @@ -60,9 +58,7 @@ def assertTensorAlmostEqual( raise ValueError("Mode for assertion comparison must be one of `max` or `sum`.") -def assertTensorTuplesAlmostEqual( - test, actual, expected, delta: float = 0.0001, mode: str = "sum" -) -> None: +def assertTensorTuplesAlmostEqual(test, actual, expected, delta=0.0001, mode="sum"): if isinstance(expected, tuple): assert len(actual) == len( expected diff --git a/tests/helpers/basic_models.py b/tests/helpers/basic_models.py index 5371af03f4..e5ea54c64b 100644 --- a/tests/helpers/basic_models.py +++ b/tests/helpers/basic_models.py @@ -44,7 +44,7 @@ def __init__(self) -> None: super().__init__() def forward(self, input: int): - input = 1 - F.relu(1 - input) + input = 1 - F.relu(torch.tensor(1 - input)) return input diff --git a/tests/influence/_core/test_tracin_validation.py b/tests/influence/_core/test_tracin_validation.py index d5fa654234..888a47142a 100644 --- a/tests/influence/_core/test_tracin_validation.py +++ b/tests/influence/_core/test_tracin_validation.py @@ -77,7 +77,7 @@ def test_tracincp_fast_rand_proj_inputs(self) -> None: ): TracInCPFast( net, - "invalid_layer", + "invalid_layer", # type: ignore train_dataset, tmpdir, loss_fn=nn.MSELoss(), diff --git a/tests/robust/test_FGSM.py b/tests/robust/test_FGSM.py index b39cc04b11..b0686f2cf6 100644 --- a/tests/robust/test_FGSM.py +++ b/tests/robust/test_FGSM.py @@ -188,7 +188,7 @@ def _FGSM_assert( inputs: TensorOrTupleOfTensorsGeneric, target: Any, epsilon: float, - answer: Union[TensorLikeList, Tuple[TensorLikeList, ...]], + answer: Union[List, Tuple[List, ...]], targeted: bool = False, additional_inputs: Any = None, lower_bound: float = float("-inf"), diff --git a/tests/robust/test_attack_comparator.py b/tests/robust/test_attack_comparator.py index 2239696fed..0c6d2cc707 100644 --- a/tests/robust/test_attack_comparator.py +++ b/tests/robust/test_attack_comparator.py @@ -202,7 +202,7 @@ def test_attack_comparator_with_additional_args(self) -> None: attack_comp.reset() self.assertEqual(len(attack_comp.summary()), 0) - def _compare_results(self, obtained: Tensor, expected) -> None: + def 
_compare_results(self, obtained, expected) -> None: if isinstance(expected, dict): self.assertIsInstance(obtained, dict) for key in expected: diff --git a/tests/utils/models/linear_models/_test_linear_classifier.py b/tests/utils/models/linear_models/_test_linear_classifier.py index 383786a1f3..37e3a6f2de 100644 --- a/tests/utils/models/linear_models/_test_linear_classifier.py +++ b/tests/utils/models/linear_models/_test_linear_classifier.py @@ -1,6 +1,6 @@ import argparse import random -from typing import Optional +from typing import cast, Optional import captum._utils.models.linear_model.model as pytorch_model_module import numpy as np @@ -98,9 +98,9 @@ def compare_to_sk_learn( o_pytorch["l1_reg"] = alpha * pytorch_h.norm(p=1, dim=-1) o_sklearn["l1_reg"] = alpha * sklearn_h.norm(p=1, dim=-1) - rel_diff = (sum(o_sklearn.values()) - sum(o_pytorch.values())) / abs( - sum(o_sklearn.values()) - ) + rel_diff = cast( + np.ndarray, (sum(o_sklearn.values()) - sum(o_pytorch.values())) + ) / abs(sum(o_sklearn.values())) return ( { "objective_rel_diff": rel_diff.tolist(), diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 46af61b58a..4c2f5d1ffe 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -7,7 +7,7 @@ class HelpersTest(BaseTest): def test_assert_tensor_almost_equal(self) -> None: with self.assertRaises(AssertionError) as cm: - assertTensorAlmostEqual(self, [[1.0]], [[1.0]]) + assertTensorAlmostEqual(self, [[1.0]], [[1.0]]) # type: ignore self.assertEqual( cm.exception.args, ("Actual parameter given for comparison must be a tensor.",), diff --git a/tests/utils/test_linear_model.py b/tests/utils/test_linear_model.py index 9b3f494a1a..b2627ea00e 100644 --- a/tests/utils/test_linear_model.py +++ b/tests/utils/test_linear_model.py @@ -12,11 +12,11 @@ from torch import Tensor -def _evaluate(test_data, classifier) -> Dict[str, float]: +def _evaluate(test_data, classifier) -> Dict[str, Tensor]: classifier.eval() - l1_loss = 0.0 - l2_loss = 0.0 + l1_loss = torch.tensor(0.0) + l2_loss = torch.tensor(0.0) n = 0 l2_losses = [] with torch.no_grad(): @@ -67,7 +67,7 @@ def train_and_compare( model_type, xs, ys, - expected_loss: Tensor, + expected_loss: Union[int, float, Tensor], expected_reg: Union[float, Tensor] = 0.0, expected_hyperplane: Optional[Tensor] = None, norm_hyperplane: bool = True, diff --git a/tests/utils/test_sample_gradient.py b/tests/utils/test_sample_gradient.py index 0f804f69fb..08ee0924a3 100644 --- a/tests/utils/test_sample_gradient.py +++ b/tests/utils/test_sample_gradient.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import unittest -from typing import Callable, Tuple +from typing import Callable, List, Tuple import torch from captum._utils.gradient import apply_gradient_requirements @@ -110,7 +110,11 @@ def test_sample_grads_layer_modules(self) -> None: # possible candidates for `layer_modules`, which are the modules whose # parameters we want to compute sample grads for - layer_moduless = [[model.conv1], [model.fc1], [model.conv1, model.fc1]] + layer_moduless: List[List[Module]] = [ + [model.conv1], + [model.fc1], + [model.conv1, model.fc1], + ] # hard coded all modules we want to check all_modules = [model.conv1, model.fc1] @@ -135,10 +139,10 @@ def test_sample_grads_layer_modules(self) -> None: # So, check that we did calculate sample grads for the desired # layers via the above checking approach. 
for parameter in module.parameters(): - assert not isinstance(parameter.sample_grad, int) + assert not isinstance(parameter.sample_grad, int) # type: ignore else: # For the layers we do not want sample grads for, their # `sample_grad` should still be 0, since they should not have been # over-written. for parameter in module.parameters(): - assert parameter.sample_grad == 0 + assert parameter.sample_grad == 0 # type: ignore From 3d4c1fd3014897538155034dc8c0a2cb5b8bbaf5 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 13 Mar 2024 23:29:18 -0700 Subject: [PATCH 11/16] Fix type issues --- captum/insights/attr_vis/server.py | 10 ++++++---- scripts/install_via_pip.sh | 2 +- tests/helpers/basic_models.py | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/captum/insights/attr_vis/server.py b/captum/insights/attr_vis/server.py index d7f7384201..f11371166e 100644 --- a/captum/insights/attr_vis/server.py +++ b/captum/insights/attr_vis/server.py @@ -6,6 +6,8 @@ from time import sleep from typing import cast, Dict, Optional +from captum.insights import AttributionVisualizer + from captum.log import log_usage from flask import Flask, jsonify, render_template, request from flask.wrappers import Response @@ -44,7 +46,7 @@ def attribute() -> Response: r = cast(Dict, request.get_json(force=True)) return jsonify( namedtuple_to_dict( - visualizer._calculate_attribution_from_cache( # type: ignore + cast(AttributionVisualizer, visualizer)._calculate_attribution_from_cache( r["inputIndex"], r["modelIndex"], r["labelIndex"] ) ) @@ -54,15 +56,15 @@ def attribute() -> Response: @app.route("/fetch", methods=["POST"]) def fetch() -> Response: # force=True needed, see comment for "/attribute" route above - visualizer._update_config(request.get_json(force=True)) # type: ignore - visualizer_output = visualizer.visualize() # type: ignore + cast(AttributionVisualizer, visualizer)._update_config(request.get_json(force=True)) + visualizer_output = cast(AttributionVisualizer, visualizer).visualize() clean_output = namedtuple_to_dict(visualizer_output) return jsonify(clean_output) @app.route("/init") def init() -> Response: - return jsonify(visualizer.get_insights_config()) # type: ignore + return jsonify(cast(AttributionVisualizer, visualizer).get_insights_config()) @app.route("/") diff --git a/scripts/install_via_pip.sh b/scripts/install_via_pip.sh index 2978de2456..27c8cab625 100755 --- a/scripts/install_via_pip.sh +++ b/scripts/install_via_pip.sh @@ -12,7 +12,7 @@ while getopts 'ndfv:' flag; do d) DEPLOY=true ;; f) FRAMEWORKS=true ;; v) CHOSEN_TORCH_VERSION=${OPTARG};; - *) echo "usage: $0 [-n] [-d] [-f] [-v version] [-m install_mode]" >&2 + *) echo "usage: $0 [-n] [-d] [-f] [-v version]" >&2 exit 1 ;; esac done diff --git a/tests/helpers/basic_models.py b/tests/helpers/basic_models.py index e5ea54c64b..52d3404ee5 100644 --- a/tests/helpers/basic_models.py +++ b/tests/helpers/basic_models.py @@ -43,8 +43,8 @@ class BasicModel(nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, input: int): - input = 1 - F.relu(torch.tensor(1 - input)) + def forward(self, input: Tensor): + input = 1 - F.relu(1 - input) return input From d0e6690b93e6d339b72e484e54654cbe457ab0e9 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 13 Mar 2024 23:47:48 -0700 Subject: [PATCH 12/16] Fix insights --- captum/insights/attr_vis/server.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/captum/insights/attr_vis/server.py b/captum/insights/attr_vis/server.py index 
f11371166e..d7f7384201 100644 --- a/captum/insights/attr_vis/server.py +++ b/captum/insights/attr_vis/server.py @@ -6,8 +6,6 @@ from time import sleep from typing import cast, Dict, Optional -from captum.insights import AttributionVisualizer - from captum.log import log_usage from flask import Flask, jsonify, render_template, request from flask.wrappers import Response @@ -46,7 +44,7 @@ def attribute() -> Response: r = cast(Dict, request.get_json(force=True)) return jsonify( namedtuple_to_dict( - cast(AttributionVisualizer, visualizer)._calculate_attribution_from_cache( + visualizer._calculate_attribution_from_cache( # type: ignore r["inputIndex"], r["modelIndex"], r["labelIndex"] ) ) @@ -56,15 +54,15 @@ def attribute() -> Response: @app.route("/fetch", methods=["POST"]) def fetch() -> Response: # force=True needed, see comment for "/attribute" route above - cast(AttributionVisualizer, visualizer)._update_config(request.get_json(force=True)) - visualizer_output = cast(AttributionVisualizer, visualizer).visualize() + visualizer._update_config(request.get_json(force=True)) # type: ignore + visualizer_output = visualizer.visualize() # type: ignore clean_output = namedtuple_to_dict(visualizer_output) return jsonify(clean_output) @app.route("/init") def init() -> Response: - return jsonify(cast(AttributionVisualizer, visualizer).get_insights_config()) + return jsonify(visualizer.get_insights_config()) # type: ignore @app.route("/") From 1a990172d3185ffa046c9d938ae3be1bd00f3f75 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Thu, 14 Mar 2024 00:31:46 -0700 Subject: [PATCH 13/16] Fix typing linear --- tests/utils/test_linear_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/utils/test_linear_model.py b/tests/utils/test_linear_model.py index b2627ea00e..e937057690 100644 --- a/tests/utils/test_linear_model.py +++ b/tests/utils/test_linear_model.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from typing import Dict, Optional, Union +from typing import cast, Dict, Optional, Union import torch from captum._utils.models.linear_model.model import ( @@ -15,8 +15,8 @@ def _evaluate(test_data, classifier) -> Dict[str, Tensor]: classifier.eval() - l1_loss = torch.tensor(0.0) - l2_loss = torch.tensor(0.0) + l1_loss = 0.0 + l2_loss = 0.0 n = 0 l2_losses = [] with torch.no_grad(): @@ -56,7 +56,7 @@ def _evaluate(test_data, classifier) -> Dict[str, Tensor]: assert ((l2_losses.mean(0) - l2_loss / n).abs() <= 0.1).all() classifier.train() - return {"l1": l1_loss / n, "l2": l2_loss / n} + return {"l1": cast(Tensor, l1_loss / n), "l2": cast(Tensor, l2_loss / n)} class TestLinearModel(BaseTest): From e20a56625ecce5358229a5c2d9a7d7a4117dc261 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Thu, 14 Mar 2024 01:10:01 -0700 Subject: [PATCH 14/16] Fix lint --- tests/attr/helpers/conductance_reference.py | 2 +- tests/attr/models/test_pytext.py | 2 +- tests/helpers/basic.py | 3 +-- tests/robust/test_FGSM.py | 2 +- tests/utils/test_sample_gradient.py | 4 +++- 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/attr/helpers/conductance_reference.py b/tests/attr/helpers/conductance_reference.py index fee05673a6..5a7b8906c8 100644 --- a/tests/attr/helpers/conductance_reference.py +++ b/tests/attr/helpers/conductance_reference.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from typing import cast, Optional, Tuple, Union +from typing import cast, Tuple, Union import numpy as np import torch diff --git a/tests/attr/models/test_pytext.py b/tests/attr/models/test_pytext.py index 
926aadbe0a..fa7394945b 100644 --- a/tests/attr/models/test_pytext.py +++ b/tests/attr/models/test_pytext.py @@ -5,7 +5,7 @@ import os import tempfile import unittest -from typing import Dict, List, NoReturn, Optional +from typing import Dict, List import torch diff --git a/tests/helpers/basic.py b/tests/helpers/basic.py index dccb726f39..8bf6a926c2 100644 --- a/tests/helpers/basic.py +++ b/tests/helpers/basic.py @@ -2,12 +2,11 @@ import copy import random import unittest -from typing import Callable, List, Tuple, Union +from typing import Callable import numpy as np import torch from captum.log import patch_methods -from torch import Tensor def deep_copy_args(func: Callable): diff --git a/tests/robust/test_FGSM.py b/tests/robust/test_FGSM.py index b0686f2cf6..19dffdacf1 100644 --- a/tests/robust/test_FGSM.py +++ b/tests/robust/test_FGSM.py @@ -2,7 +2,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union import torch -from captum._utils.typing import TensorLikeList, TensorOrTupleOfTensorsGeneric +from captum._utils.typing import TensorOrTupleOfTensorsGeneric from captum.robust import FGSM from tests.helpers.basic import assertTensorAlmostEqual, BaseTest from tests.helpers.basic_models import BasicModel, BasicModel2, BasicModel_MultiLayer diff --git a/tests/utils/test_sample_gradient.py b/tests/utils/test_sample_gradient.py index 08ee0924a3..2e4bdbf379 100644 --- a/tests/utils/test_sample_gradient.py +++ b/tests/utils/test_sample_gradient.py @@ -139,7 +139,9 @@ def test_sample_grads_layer_modules(self) -> None: # So, check that we did calculate sample grads for the desired # layers via the above checking approach. for parameter in module.parameters(): - assert not isinstance(parameter.sample_grad, int) # type: ignore + assert not isinstance( + parameter.sample_grad, int # type: ignore + ) else: # For the layers we do not want sample grads for, their # `sample_grad` should still be 0, since they should not have been From 6903fbe31b09b954f6f207f8e2ac451084f6cc9f Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Thu, 14 Mar 2024 07:26:23 -0700 Subject: [PATCH 15/16] Fixes --- captum/attr/_core/deep_lift.py | 8 ++++++-- tests/attr/test_deeplift_classification.py | 5 +++++ tests/helpers/basic.py | 8 ++++++-- tests/robust/test_attack_comparator.py | 6 ++++-- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/captum/attr/_core/deep_lift.py b/captum/attr/_core/deep_lift.py index e717fc12bc..eea8234eef 100644 --- a/captum/attr/_core/deep_lift.py +++ b/captum/attr/_core/deep_lift.py @@ -582,7 +582,9 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None: def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - baselines: Union[BaselineType, Callable[..., TensorOrTupleOfTensorsGeneric]], + baselines: Union[ + TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] + ], target: TargetType = None, additional_forward_args: Any = None, return_convergence_delta: Literal[False] = False, @@ -593,7 +595,9 @@ def attribute( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, - baselines: Union[BaselineType, Callable[..., TensorOrTupleOfTensorsGeneric]], + baselines: Union[ + TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] + ], target: TargetType = None, additional_forward_args: Any = None, *, diff --git a/tests/attr/test_deeplift_classification.py b/tests/attr/test_deeplift_classification.py index f3eb6e1aa0..87d5d40688 100644 --- a/tests/attr/test_deeplift_classification.py +++ 
b/tests/attr/test_deeplift_classification.py @@ -155,6 +155,11 @@ def softmax_classification( target: TargetType, ) -> None: # TODO add test cases for multiple different layers + if isinstance(attr_method, DeepLiftShap): + assert isinstance( + baselines, Tensor + ), "Non-tensor baseline not supported for DeepLiftShap" + model.zero_grad() attributions, delta = attr_method.attribute( input, baselines=baselines, target=target, return_convergence_delta=True diff --git a/tests/helpers/basic.py b/tests/helpers/basic.py index 8bf6a926c2..2cf8fd3de1 100644 --- a/tests/helpers/basic.py +++ b/tests/helpers/basic.py @@ -19,7 +19,9 @@ def copy_args(*args, **kwargs): return copy_args -def assertTensorAlmostEqual(test, actual, expected, delta=0.0001, mode="sum"): +def assertTensorAlmostEqual( + test, actual, expected, delta: float = 0.0001, mode: str = "sum" +): assert isinstance(actual, torch.Tensor), ( "Actual parameter given for " "comparison must be a tensor." ) @@ -57,7 +59,9 @@ def assertTensorAlmostEqual(test, actual, expected, delta=0.0001, mode="sum"): raise ValueError("Mode for assertion comparison must be one of `max` or `sum`.") -def assertTensorTuplesAlmostEqual(test, actual, expected, delta=0.0001, mode="sum"): +def assertTensorTuplesAlmostEqual( + test, actual, expected, delta: float = 0.0001, mode: str = "sum" +): if isinstance(expected, tuple): assert len(actual) == len( expected diff --git a/tests/robust/test_attack_comparator.py b/tests/robust/test_attack_comparator.py index 0c6d2cc707..7585ad8f9c 100644 --- a/tests/robust/test_attack_comparator.py +++ b/tests/robust/test_attack_comparator.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 import collections -from typing import List +from typing import Dict, List, Tuple, Union import torch from captum.robust import AttackComparator, FGSM @@ -202,7 +202,9 @@ def test_attack_comparator_with_additional_args(self) -> None: attack_comp.reset() self.assertEqual(len(attack_comp.summary()), 0) - def _compare_results(self, obtained, expected) -> None: + def _compare_results( + self, obtained: Union[Dict, Tuple, Tensor], expected: Union[Dict, Tuple, Tensor] + ) -> None: if isinstance(expected, dict): self.assertIsInstance(obtained, dict) for key in expected: From fbe9ab2b6f7b0ad56038fa2178d48d968adf04d6 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Thu, 14 Mar 2024 07:30:06 -0700 Subject: [PATCH 16/16] Fix typing --- tests/helpers/basic.py | 4 ++-- tests/utils/test_helpers.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/helpers/basic.py b/tests/helpers/basic.py index 2cf8fd3de1..047036fdbf 100644 --- a/tests/helpers/basic.py +++ b/tests/helpers/basic.py @@ -21,7 +21,7 @@ def copy_args(*args, **kwargs): def assertTensorAlmostEqual( test, actual, expected, delta: float = 0.0001, mode: str = "sum" -): +) -> None: assert isinstance(actual, torch.Tensor), ( "Actual parameter given for " "comparison must be a tensor." 
) @@ -61,7 +61,7 @@ def assertTensorAlmostEqual( def assertTensorTuplesAlmostEqual( test, actual, expected, delta: float = 0.0001, mode: str = "sum" -): +) -> None: if isinstance(expected, tuple): assert len(actual) == len( expected diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 4c2f5d1ffe..46af61b58a 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -7,7 +7,7 @@ class HelpersTest(BaseTest): def test_assert_tensor_almost_equal(self) -> None: with self.assertRaises(AssertionError) as cm: - assertTensorAlmostEqual(self, [[1.0]], [[1.0]]) # type: ignore + assertTensorAlmostEqual(self, [[1.0]], [[1.0]]) self.assertEqual( cm.exception.args, ("Actual parameter given for comparison must be a tensor.",),
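
A pattern worth calling out across patches 10 and 14 is the use of `typing.cast` in `tests/attr/helpers/conductance_reference.py` to narrow values that a forward hook populates: mypy cannot see the assignment inside the hook, so the variable stays `Optional[Tensor]` at the use site. A self-contained sketch of that situation, assuming a throwaway two-layer model (nothing here is Captum API):

```python
from typing import cast, Optional

import torch
from torch import Tensor

model = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.ReLU())

saved_tensor: Optional[Tensor] = None


def forward_hook(module: torch.nn.Module, inp: object, out: Tensor) -> None:
    # Runs during model(...) and stashes the layer output.
    global saved_tensor
    saved_tensor = out


handle = model[0].register_forward_hook(forward_hook)
model(torch.randn(2, 3))
handle.remove()

# mypy still types saved_tensor as Optional[Tensor] here, since the hook's
# side effect is invisible to it; cast() records the invariant once instead
# of requiring an `# type: ignore` at every use, mirroring the
# conductance_reference hunks above.
layer_size = tuple(cast(Tensor, saved_tensor).size())[1:]
assert layer_size == (4,)
```
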