From c96cd7221a7af9ba3f0877ff621c8d828c733b5f Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Mon, 20 Dec 2021 14:02:23 -0700 Subject: [PATCH 1/9] Composable loss improvements --- captum/optim/_core/loss.py | 75 +++++++++++++++++++++++-- tests/optim/core/test_loss.py | 31 ++++++++++ tests/optim/helpers/numpy_transforms.py | 3 +- 3 files changed, 104 insertions(+), 5 deletions(-) diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py index aa33793642..e718b352eb 100644 --- a/captum/optim/_core/loss.py +++ b/captum/optim/_core/loss.py @@ -16,6 +16,10 @@ def _make_arg_str(arg: Any) -> str: return arg[:15] + "..." if too_big else arg +# Reduction op for CompositeLoss loss composability size mismatch avoidance +REDUCTION_OP: Callable[[torch.Tensor], torch.Tensor] = torch.mean + + class Loss(ABC): """ Abstract Class to describe loss. @@ -40,6 +44,12 @@ def __repr__(self) -> str: def __neg__(self) -> "CompositeLoss": return module_op(self, None, operator.neg) + def __pos__(self) -> "CompositeLoss": + return module_op(self, None, operator.pos) + + def __abs__(self) -> "CompositeLoss": + return module_op(self, None, operator.abs) + def __add__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss": return module_op(self, other, operator.add) @@ -68,7 +78,7 @@ def __rtruediv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss": if isinstance(other, (int, float)): def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: - return operator.truediv(other, torch.mean(self(module))) + return operator.truediv(other, REDUCTION_OP(self(module))) name = self.__name__ target = self.target @@ -86,7 +96,7 @@ def __rpow__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss": if isinstance(other, (int, float)): def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: - return operator.pow(other, torch.mean(self(module))) + return operator.pow(other, REDUCTION_OP(self(module))) name = self.__name__ target = self.target @@ -100,6 +110,62 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: ) return CompositeLoss(loss_fn, name=name, target=target) + def mean(self, dim: Optional = None, keepdim: bool = False) -> "CompositeLoss": + """ + Composable torch.mean reduction operator. See torch.mean for more details: + https://pytorch.org/docs/stable/generated/torch.mean.html + + Args: + dim (int or tuple of int, optional): The dimension or dimensions to reduce. + Default: None for all dimension. + keepdim (bool, optional): Whether the output tensor has dim retained or + not. + Default: False + + Returns: + composite_loss (ComposableLoss): A composable loss instance. + """ + if dim is None: + # dim is equal to all dimensions unless specified + def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: + return torch.mean(self(module)) + + else: + + def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: + return torch.mean(self(module), dim=dim, keepdim=keepdim) + + name = "mean(" + self.__name__ + ")" + return CompositeLoss(loss_fn, name=name, target=self.target) + + def sum(self, dim: Optional = None, keepdim: bool = False) -> "CompositeLoss": + """ + Composable torch.sum reduction operator. See torch.sum for more details: + https://pytorch.org/docs/stable/generated/torch.sum.html + + Args: + dim (int or tuple of int, optional): The dimension or dimensions to reduce. + Default: None for all dimension. + keepdim (bool, optional): Whether the output tensor has dim retained or + not. + Default: False + + Returns: + composite_loss (ComposableLoss): A composable loss instance. 
+ """ + if dim is None: + # dim is equal to all dimensions unless specified + def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: + return torch.sum(self(module)) + + else: + + def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: + return torch.sum(self(module), dim=dim, keepdim=keepdim) + + name = "sum(" + self.__name__ + ")" + return CompositeLoss(loss_fn, name=name, target=self.target) + def module_op( self: Loss, other: Union[None, int, float, Loss], math_op: Callable @@ -107,7 +173,7 @@ def module_op( """ This is a general function for applying math operations to Losses """ - if other is None and math_op == operator.neg: + if other is None and math_op in [operator.neg, operator.pos, operator.abs]: def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: return math_op(self(module)) @@ -124,7 +190,7 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: elif isinstance(other, Loss): # We take the mean of the output tensor to resolve shape mismatches def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: - return math_op(torch.mean(self(module)), torch.mean(other(module))) + return math_op(REDUCTION_OP(self(module)), REDUCTION_OP(other(module))) name = f"Compose({', '.join([self.__name__, other.__name__])})" target = ( @@ -603,6 +669,7 @@ def default_loss_summarize(loss_value: torch.Tensor) -> torch.Tensor: __all__ = [ "Loss", + "REDUCTION_OP", "loss_wrapper", "BaseLoss", "LayerActivation", diff --git a/tests/optim/core/test_loss.py b/tests/optim/core/test_loss.py index 566745de25..da49d75247 100644 --- a/tests/optim/core/test_loss.py +++ b/tests/optim/core/test_loss.py @@ -168,6 +168,20 @@ def test_negative(self) -> None: get_loss_value(model, loss), -CHANNEL_ACTIVATION_0_LOSS, places=6 ) + def test_positive(self) -> None: + model = BasicModel_ConvNet_Optim() + loss = +opt_loss.ChannelActivation(model.layer, 0) + self.assertAlmostEqual( + get_loss_value(model, loss), CHANNEL_ACTIVATION_0_LOSS, places=6 + ) + + def test_abs(self) -> None: + model = BasicModel_ConvNet_Optim() + loss = abs(-opt_loss.ChannelActivation(model.layer, 0)) + self.assertAlmostEqual( + get_loss_value(model, loss), CHANNEL_ACTIVATION_0_LOSS, places=6 + ) + def test_addition(self) -> None: model = BasicModel_ConvNet_Optim() loss = ( @@ -242,3 +256,20 @@ def test_pow(self) -> None: # opt_loss.ChannelActivation(model.layer, 0) ** opt_loss.ChannelActivation( # model.layer, 1 # ) + + def test_sum(self) -> None: + model = torch.nn.Identity() + loss = opt_loss.LayerActivation(model).sum() + + self.assertAlmostEqual(get_loss_value(model, loss), 3.0, places=1) + + def test_mean(self) -> None: + model = torch.nn.Identity() + loss = opt_loss.LayerActivation(model).mean() + + self.assertAlmostEqual(get_loss_value(model, loss), 1.0, places=1) + + +class TestCompositeLossReductionOP(BaseTest): + def test_reduction_op(self) -> None: + self.assertEqual(opt_loss.REDUCTION_OP, torch.mean) diff --git a/tests/optim/helpers/numpy_transforms.py b/tests/optim/helpers/numpy_transforms.py index eec0afebac..e52a1a3e60 100644 --- a/tests/optim/helpers/numpy_transforms.py +++ b/tests/optim/helpers/numpy_transforms.py @@ -1,5 +1,5 @@ import math -from typing import List, Optional, Tuple, Union, cast +from typing import List, Optional, Tuple, Union, cast, no_type_check import numpy as np @@ -112,6 +112,7 @@ def forward(self, input: np.ndarray) -> np.ndarray: ) +@no_type_check def center_crop( input: np.ndarray, crop_vals: IntSeqOrIntType, From b7e2f35393a89448399553ed5910629dedd0894a Mon Sep 17 00:00:00 2001 From: 
ProGamerGov
Date: Mon, 20 Dec 2021 14:16:19 -0700
Subject: [PATCH 2/9] Set minimum torch version for positive CompositeLoss
 test

---
 tests/optim/core/test_loss.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/optim/core/test_loss.py b/tests/optim/core/test_loss.py
index da49d75247..cbbc167c2c 100644
--- a/tests/optim/core/test_loss.py
+++ b/tests/optim/core/test_loss.py
@@ -169,6 +169,11 @@ def test_negative(self) -> None:
         )
 
     def test_positive(self) -> None:
+        if torch.__version__ <= "1.3.0":
+            raise unittest.SkipTest(
+                "Skipping positive CompositeLoss test due to insufficient"
+                + " Torch version."
+            )
         model = BasicModel_ConvNet_Optim()
         loss = +opt_loss.ChannelActivation(model.layer, 0)
         self.assertAlmostEqual(

From da0c4725d9cd290e3db3613c1e2d3a81a677df64 Mon Sep 17 00:00:00 2001
From: ProGamerGov
Date: Mon, 20 Dec 2021 14:22:48 -0700
Subject: [PATCH 3/9] Fix import

---
 tests/optim/core/test_loss.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/optim/core/test_loss.py b/tests/optim/core/test_loss.py
index cbbc167c2c..0584e0716a 100644
--- a/tests/optim/core/test_loss.py
+++ b/tests/optim/core/test_loss.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+import unittest
 from typing import List, Union
 
 import numpy as np

From ac5bc7400054c9d161b3152298288813b0435370 Mon Sep 17 00:00:00 2001
From: ProGamerGov
Date: Wed, 22 Dec 2021 12:04:53 -0700
Subject: [PATCH 4/9] Improvements to CompositeLoss

* Added `operator.floordiv` support.

* Added the `basic_torch_module_op` function that should allow for the
  composability of many common torch operations.

* Added the `rmodule_op` function for handling the 3 "r" versions of math
  operations.

---
 captum/optim/_core/loss.py    | 163 +++++++++++++++++++++++-----------
 tests/optim/core/test_loss.py |  25 +++++-
 2 files changed, 132 insertions(+), 56 deletions(-)

diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py
index e718b352eb..c22e4e458a 100644
--- a/captum/optim/_core/loss.py
+++ b/captum/optim/_core/loss.py
@@ -1,7 +1,7 @@
 import functools
 import operator
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -62,6 +62,9 @@ def __mul__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
     def __truediv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
         return module_op(self, other, operator.truediv)
 
+    def __floordiv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
+        return module_op(self, other, operator.floordiv)
+
     def __pow__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
         return module_op(self, other, operator.pow)
 
@@ -75,40 +78,13 @@ def __rmul__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
         return self.__mul__(other)
 
     def __rtruediv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
-        if isinstance(other, (int, float)):
-
-            def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
-                return operator.truediv(other, REDUCTION_OP(self(module)))
-
-            name = self.__name__
-            target = self.target
-        elif isinstance(other, Loss):
-            # This should never get called because __div__ will be called instead
-            pass
-        else:
-            raise TypeError(
-                "Can only apply math operations with int, float or Loss. Received type "
-                + str(type(other))
-            )
-        return CompositeLoss(loss_fn, name=name, target=target)
+        return rmodule_op(self, other, operator.truediv)
+
+    def __rfloordiv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
+        return rmodule_op(self, other, operator.floordiv)
 
     def __rpow__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss":
-        if isinstance(other, (int, float)):
-
-            def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
-                return operator.pow(other, REDUCTION_OP(self(module)))
-
-            name = self.__name__
-            target = self.target
-        elif isinstance(other, Loss):
-            # This should never get called because __pow__ will be called instead
-            pass
-        else:
-            raise TypeError(
-                "Can only apply math operations with int, float or Loss. Received type "
-                + str(type(other))
-            )
-        return CompositeLoss(loss_fn, name=name, target=target)
+        return rmodule_op(self, other, operator.pow)
 
     def mean(self, dim: Optional = None, keepdim: bool = False) -> "CompositeLoss":
         """
@@ -115,64 +91,43 @@
         Composable torch.mean reduction operator. See torch.mean for more details:
         https://pytorch.org/docs/stable/generated/torch.mean.html
 
         Args:
             dim (int or tuple of int, optional): The dimension or dimensions to reduce.
                 Default: None for all dimension.
             keepdim (bool, optional): Whether the output tensor has dim retained or
                 not.
                 Default: False
 
         Returns:
-            composite_loss (ComposableLoss): A composable loss instance.
+            composite_loss (CompositeLoss): A composable loss instance.
         """
         if dim is None:
-            # dim is equal to all dimensions unless specified
-            def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
-                return torch.mean(self(module))
-
+            return basic_torch_module_op(self, torch.mean)
         else:
-
-            def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
-                return torch.mean(self(module), dim=dim, keepdim=keepdim)
-
-        name = "mean(" + self.__name__ + ")"
-        return CompositeLoss(loss_fn, name=name, target=self.target)
+            return basic_torch_module_op(self, torch.mean, dim=dim, keepdim=keepdim)
 
     def sum(self, dim: Optional = None, keepdim: bool = False) -> "CompositeLoss":
         """
         Composable torch.sum reduction operator. See torch.sum for more details:
         https://pytorch.org/docs/stable/generated/torch.sum.html
 
         Args:
+
             dim (int or tuple of int, optional): The dimension or dimensions to reduce.
                 Default: None for all dimension.
             keepdim (bool, optional): Whether the output tensor has dim retained or
                 not.
                 Default: False
 
         Returns:
-            composite_loss (ComposableLoss): A composable loss instance.
+            composite_loss (CompositeLoss): A composable loss instance.
         """
         if dim is None:
-            # dim is equal to all dimensions unless specified
-            def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
-                return torch.sum(self(module))
-
+            return basic_torch_module_op(self, torch.sum)
         else:
-
-            def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
-                return torch.sum(self(module), dim=dim, keepdim=keepdim)
-
-        name = "sum(" + self.__name__ + ")"
-        return CompositeLoss(loss_fn, name=name, target=self.target)
+            return basic_torch_module_op(self, torch.sum, dim=dim, keepdim=keepdim)
 
 
 def module_op(
     self: Loss, other: Union[None, int, float, Loss], math_op: Callable
@@ -204,6 +165,100 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
     return CompositeLoss(loss_fn, name=name, target=target)
 
 
+def rmodule_op(
+    self: Loss, other: Union[int, float, "Loss"], math_op: Callable
+) -> "CompositeLoss":
+    """
+    This is a general function for applying the "r" versions of math operations to
+    Losses.
+ """ + if isinstance(other, (int, float)): + + def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: + return math_op(other, REDUCTION_OP(self(module))) + + name = self.__name__ + target = self.target + elif isinstance(other, Loss): + # This should never get called because __math_op__ will be called instead + pass + else: + raise TypeError( + "Can only apply math operations with int, float or Loss. Received type " + + str(type(other)) + ) + return CompositeLoss(loss_fn, name=name, target=target) + + +def basic_torch_module_op( + loss: Union[Loss, List[Loss]], + torch_op: Callable[[Union[torch.Tensor, List[torch.Tensor]]], torch.Tensor], + *args: Any, + **kwargs: Any, +) -> "CompositeLoss": + """ + Implement composability for PyTorch operation that take a single tensor or list + of tensors as it's first input variable. + See here for possible torch_op choices: https://pytorch.org/docs/stable/torch.html + Some built-in Python functions can also be used as well if supported by PyTorch. + + Args: + + loss (Loss or list of Loss): A loss objective or list of loss objectives. + torch_op (Callable): A PyTorch or supported Python function. Ex: torch.mean, + torch.sum,, torch.linalg.norm, torch.sin, torch.cat, torch.stack, max, min, + sum, math.ceil, and others. + Default: torch.mean + args (Any, optional): Any additional arguments. + kwargs (Any, optional): Any additional arguments. + to_scalar_fn (Callable, optional): A function for converting loss function + outputs to scalar values, in order to prevent size mismatches. This is + variable only used if more than one loss is given. + Default: A non-op function. + + Returns: + composite_loss (CompositeLoss): A composable loss instance. + """ + + if isinstance(loss, (tuple, list)): + + def identity(x: torch.Tensor) -> torch.Tensor: + return x + + if "to_scalar_fn" not in kwargs: + to_scalar_fn = identity + else: + to_scalar_fn = kwargs["to_scalar_fn"] + del kwargs["to_scalar_fn"] + + def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: + loss_tensors = [to_scalar_fn(loss_obj(module)) for loss_obj in loss] + return torch_op(loss_tensors, *args, **kwargs) + + name_list = ", ".join([loss_obj.__name__ for loss_obj in loss]) + name = torch_op.__name__ + "(" + name_list + ")" + + # Collect targets from losses + target = [ + target + for targets in [ + [loss_obj.target] + if not hasattr(loss_obj.target, "__iter__") + else loss_obj.target + for loss_obj in loss + ] + for target in targets + ] + else: + + def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: + return torch_op(loss(module), *args, **kwargs) + + name = torch_op.__name__ + "(" + loss.__name__ + ")" + target = loss.target + return CompositeLoss(loss_fn, name=name, target=target) + + class BaseLoss(Loss): def __init__( self, target: nn.Module = [], batch_index: Optional[int] = None diff --git a/tests/optim/core/test_loss.py b/tests/optim/core/test_loss.py index 0584e0716a..48973ca8b4 100644 --- a/tests/optim/core/test_loss.py +++ b/tests/optim/core/test_loss.py @@ -245,6 +245,13 @@ def test_division(self) -> None: # model.layer, 1 # ) + def test_floor_division(self) -> None: + model = BasicModel_ConvNet_Optim() + loss = opt_loss.ChannelActivation(model.layer, 0) // 10 + self.assertAlmostEqual( + get_loss_value(model, loss), CHANNEL_ACTIVATION_0_LOSS // 10 + ) + def test_pow(self) -> None: model = BasicModel_ConvNet_Optim() loss = opt_loss.ChannelActivation(model.layer, 0) ** 2 @@ -266,16 +273,30 @@ def test_pow(self) -> None: def test_sum(self) -> None: model = torch.nn.Identity() 
         loss = opt_loss.LayerActivation(model).sum()
-
         self.assertAlmostEqual(get_loss_value(model, loss), 3.0, places=1)
 
     def test_mean(self) -> None:
         model = torch.nn.Identity()
         loss = opt_loss.LayerActivation(model).mean()
-
         self.assertAlmostEqual(get_loss_value(model, loss), 1.0, places=1)
 
 
 class TestCompositeLossReductionOP(BaseTest):
     def test_reduction_op(self) -> None:
         self.assertEqual(opt_loss.REDUCTION_OP, torch.mean)
+
+
+class TestBasisTorchModuleOP(BaseTest):
+    def test_torch_sum(self) -> None:
+        model = torch.nn.Identity()
+        loss = opt_loss.LayerActivation(model)
+        loss = opt_loss.basic_torch_module_op(loss, torch_op=torch.sum)
+        self.assertAlmostEqual(get_loss_value(model, loss), 3.0, places=1)
+
+    def test_sum_list_with_scalar_fn(self) -> None:
+        model = torch.nn.Identity()
+        loss_list = [opt_loss.LayerActivation(model)] * 5
+        loss = opt_loss.basic_torch_module_op(
+            loss_list, torch_op=sum, to_scalar_fn=torch.mean
+        )
+        self.assertAlmostEqual(get_loss_value(model, loss), 5.0, places=1)

From b25f927861ffcb13097f06fb1580e4ecfc5489d4 Mon Sep 17 00:00:00 2001
From: ProGamerGov
Date: Wed, 22 Dec 2021 12:13:47 -0700
Subject: [PATCH 5/9] Fix Mypy type hints

---
 captum/optim/_core/loss.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py
index c22e4e458a..b6f80a0dd2 100644
--- a/captum/optim/_core/loss.py
+++ b/captum/optim/_core/loss.py
@@ -191,8 +191,8 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
 
 def basic_torch_module_op(
-    loss: Union[Loss, List[Loss]],
-    torch_op: Callable[[Union[torch.Tensor, List[torch.Tensor]]], torch.Tensor],
+    loss,
+    torch_op: Callable,
     *args: Any,
     **kwargs: Any,
 ) -> "CompositeLoss":

From 26d7007f1b0ca4c5fea602e5c77e3189c355de4c Mon Sep 17 00:00:00 2001
From: ProGamerGov
Date: Wed, 22 Dec 2021 12:17:34 -0700
Subject: [PATCH 6/9] Fix flake8 error

---
 captum/optim/_core/loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py
index b6f80a0dd2..036102372d 100644
--- a/captum/optim/_core/loss.py
+++ b/captum/optim/_core/loss.py
@@ -1,7 +1,7 @@
 import functools
 import operator
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn

From 3a41426e6e5dcae104636f044643187806728ca4 Mon Sep 17 00:00:00 2001
From: ProGamerGov
Date: Fri, 7 Jan 2022 16:36:56 -0700
Subject: [PATCH 7/9] Add batch_index support for Diversity & Alignment

---
 captum/optim/_core/loss.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py
index 036102372d..df64c4071b 100644
--- a/captum/optim/_core/loss.py
+++ b/captum/optim/_core/loss.py
@@ -1,7 +1,7 @@
 import functools
 import operator
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -261,14 +261,20 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
 
 class BaseLoss(Loss):
     def __init__(
-        self, target: nn.Module = [], batch_index: Optional[int] = None
+        self,
+        target: nn.Module = [],
+        batch_index: Optional[Union[int, List[int]]] = None,
     ) -> None:
         super(BaseLoss, self).__init__()
         self._target = target
         if batch_index is None:
             self._batch_index = (None, None)
+        elif
isinstance(batch_index, (list, tuple)): + assert len(batch_index) == 2 + self._batch_index = tuple(batch_index) else: self._batch_index = (batch_index, batch_index + 1) + assert all([isinstance(b, (int, type(None))) for b in self._batch_index]) @property def target(self) -> nn.Module: @@ -459,6 +465,7 @@ class Diversity(BaseLoss): def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: activations = targets_to_values[self.target] + activations = activations[self.batch_index[0] : self.batch_index[1]] batch, channels = activations.shape[:2] flattened = activations.view(batch, channels, -1) grams = torch.matmul(flattened, torch.transpose(flattened, 1, 2)) @@ -533,12 +540,18 @@ class Alignment(BaseLoss): https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons """ - def __init__(self, target: nn.Module, decay_ratio: float = 2.0) -> None: - BaseLoss.__init__(self, target) + def __init__( + self, + target: nn.Module, + decay_ratio: float = 2.0, + batch_index: Optional[List[int]] = None, + ) -> None: + BaseLoss.__init__(self, target, batch_index) self.decay_ratio = decay_ratio def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: activations = targets_to_values[self.target] + activations = activations[self.batch_index[0] : self.batch_index[1]] B = activations.size(0) sum_tensor = torch.zeros(1, device=activations.device) From ce897d38259382d607f093e89b7b081073cefdc6 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sun, 9 Jan 2022 13:10:26 -0700 Subject: [PATCH 8/9] Improvements to custom_composable_op * Improved documentation. * Renamed `basic_torch_module_op` to `custom_composable_op`. * Removed the reduction OP from 'r' module calls as it's not required. --- captum/optim/_core/loss.py | 58 +++++++++++++++++------------------ tests/optim/core/test_loss.py | 54 +++++++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 34 deletions(-) diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py index df64c4071b..70ad8b8aed 100644 --- a/captum/optim/_core/loss.py +++ b/captum/optim/_core/loss.py @@ -17,6 +17,7 @@ def _make_arg_str(arg: Any) -> str: # Reduction op for CompositeLoss loss composability size mismatch avoidance +# REDUCTION_OP is only used for binary math operations using two Loss instances REDUCTION_OP: Callable[[torch.Tensor], torch.Tensor] = torch.mean @@ -86,7 +87,9 @@ def __rfloordiv__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss": def __rpow__(self, other: Union[int, float, "Loss"]) -> "CompositeLoss": rmodule_op(self, other, operator.pow) - def mean(self, dim: Optional = None, keepdim: bool = False) -> "CompositeLoss": + def mean( + self, dim: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False + ) -> "CompositeLoss": """ Composable torch.mean reduction operator. See torch.mean for more details: https://pytorch.org/docs/stable/generated/torch.mean.html @@ -102,11 +105,13 @@ def mean(self, dim: Optional = None, keepdim: bool = False) -> "CompositeLoss": composite_loss (CompositeLoss): A composable loss instance. 
""" if dim is None: - return basic_torch_module_op(self, torch.mean) + return custom_composable_op(self, torch.mean) else: - return basic_torch_module_op(self, torch.mean, dim=dim, keepdim=keepdim) + return custom_composable_op(self, torch.mean, dim=dim, keepdim=keepdim) - def sum(self, dim: Optional = None, keepdim: bool = False) -> "CompositeLoss": + def sum( + self, dim: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False + ) -> "CompositeLoss": """ Composable torch.sum reduction operator. See torch.sum for more details: https://pytorch.org/docs/stable/generated/torch.sum.html @@ -123,9 +128,9 @@ def sum(self, dim: Optional = None, keepdim: bool = False) -> "CompositeLoss": composite_loss (CompositeLoss): A composable loss instance. """ if dim is None: - return basic_torch_module_op(self, torch.sum) + return custom_composable_op(self, torch.sum) else: - return basic_torch_module_op(self, torch.sum, dim=dim, keepdim=keepdim) + return custom_composable_op(self, torch.sum, dim=dim, keepdim=keepdim) def module_op( @@ -175,7 +180,7 @@ def rmodule_op( if isinstance(other, (int, float)): def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: - return math_op(other, REDUCTION_OP(self(module))) + return math_op(other, self(module)) name = self.__name__ target = self.target @@ -190,53 +195,48 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: return CompositeLoss(loss_fn, name=name, target=target) -def basic_torch_module_op( +def custom_composable_op( loss, - torch_op: Callable, + loss_op_fn: Callable, *args: Any, **kwargs: Any, ) -> "CompositeLoss": """ - Implement composability for PyTorch operation that take a single tensor or list - of tensors as it's first input variable. - See here for possible torch_op choices: https://pytorch.org/docs/stable/torch.html - Some built-in Python functions can also be used as well if supported by PyTorch. + Implement composability for operations that take a single tensor or list of tensors + and then return a single tensor. Custom user defined functions can be used in + addition to some built-in Python functions and PyTorch operations. Args: loss (Loss or list of Loss): A loss objective or list of loss objectives. - torch_op (Callable): A PyTorch or supported Python function. Ex: torch.mean, - torch.sum,, torch.linalg.norm, torch.sin, torch.cat, torch.stack, max, min, - sum, math.ceil, and others. + loss_op_fn (Callable): A supported PyTorch, Python, or custom function. Default: torch.mean - args (Any, optional): Any additional arguments. - kwargs (Any, optional): Any additional arguments. + args (Any, optional): Any additional arguments to pass to loss_op_fn. + kwargs (Any, optional): Any additional arguments to pass to loss_op_fn. to_scalar_fn (Callable, optional): A function for converting loss function outputs to scalar values, in order to prevent size mismatches. This is variable only used if more than one loss is given. - Default: A non-op function. + Default: None Returns: composite_loss (CompositeLoss): A composable loss instance. 
""" if isinstance(loss, (tuple, list)): - - def identity(x: torch.Tensor) -> torch.Tensor: - return x - if "to_scalar_fn" not in kwargs: - to_scalar_fn = identity + to_scalar_fn = None else: to_scalar_fn = kwargs["to_scalar_fn"] del kwargs["to_scalar_fn"] def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: - loss_tensors = [to_scalar_fn(loss_obj(module)) for loss_obj in loss] - return torch_op(loss_tensors, *args, **kwargs) + loss_tensors = [loss_obj(module) for loss_obj in loss] + if to_scalar_fn is not None: + loss_tensors = [to_scalar_fn(tensor) for tensor in loss_tensors] + return loss_op_fn(loss_tensors, *args, **kwargs) name_list = ", ".join([loss_obj.__name__ for loss_obj in loss]) - name = torch_op.__name__ + "(" + name_list + ")" + name = loss_op_fn.__name__ + "(" + name_list + ")" # Collect targets from losses target = [ @@ -252,9 +252,9 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: else: def loss_fn(module: ModuleOutputMapping) -> torch.Tensor: - return torch_op(loss(module), *args, **kwargs) + return loss_op_fn(loss(module), *args, **kwargs) - name = torch_op.__name__ + "(" + loss.__name__ + ")" + name = loss_op_fn.__name__ + "(" + loss.__name__ + ")" target = loss.target return CompositeLoss(loss_fn, name=name, target=target) diff --git a/tests/optim/core/test_loss.py b/tests/optim/core/test_loss.py index 48973ca8b4..acdb79f47e 100644 --- a/tests/optim/core/test_loss.py +++ b/tests/optim/core/test_loss.py @@ -286,17 +286,61 @@ def test_reduction_op(self) -> None: self.assertEqual(opt_loss.REDUCTION_OP, torch.mean) -def TestBasisTorchModuleOP(BaseTest): +def TestCustomComposableOP(BaseTest): def test_torch_sum(self) -> None: model = torch.nn.Identity() loss = opt_loss.LayerActivation(model) - loss = opt_loss.basic_torch_module_op(loss, torch_op=torch.sum) + loss = opt_loss.custom_composable_op(loss, loss_op_fn=torch.sum) self.assertAlmostEqual(get_loss_value(model, loss), 3.0, places=1) def test_sum_list_with_scalar_fn(self) -> None: model = torch.nn.Identity() - loss_list = [opt_loss.LayerActivation(model)] * 5 - loss = opt_loss.basic_torch_module_op( - loss_list, torch_op=sum, to_scalar_fn=torch.mean + loss_list = [ + opt_loss.LayerActivation(model), + opt_loss.LayerActivation(model), + opt_loss.LayerActivation(model), + opt_loss.LayerActivation(model), + opt_loss.LayerActivation(model), + ] + loss = opt_loss.custom_composable_op( + loss_list, loss_op_fn=sum, to_scalar_fn=torch.mean ) self.assertAlmostEqual(get_loss_value(model, loss), 5.0, places=1) + + def test_custom_op(self) -> None: + def custom_op_fn( + losses: torch.Tensor, add_val: float = 1.0, mul_val: float = 1.0 + ) -> torch.Tensor: + return torch.sum(losses) + add_val * mul_val + + model = torch.nn.Identity() + loss = opt_loss.LayerActivation(model) + + loss = opt_loss.custom_composable_op( + loss, loss_op_fn=custom_op_fn, add_val=2.0, mul_val=2.0 + ) + self.assertAlmostEqual(get_loss_value(model, loss), 7.0, places=1) + + def test_custom_op_list(self) -> None: + def custom_op_list_fn( + losses: List[torch.Tensor], add_val: float = 1.0, mul_val: float = 1.0 + ) -> torch.Tensor: + return torch.cat( + [torch.sum(loss) + add_val * mul_val for loss in losses], 0 + ).sum() + + model = torch.nn.Identity() + loss_list = [ + opt_loss.LayerActivation(model), + opt_loss.LayerActivation(model), + opt_loss.LayerActivation(model), + opt_loss.LayerActivation(model), + opt_loss.LayerActivation(model), + ] + loss = opt_loss.custom_composable_op( + loss_list, + loss_op_fn=custom_op_list_fn, + add_val=2.0, + 
mul_val=2.0, + ) + self.assertAlmostEqual(get_loss_value(model, loss), 35.0, places=1) From 3438cf11a13a451220e20233fc7c2aae37d80f82 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sun, 30 Jan 2022 10:43:05 -0700 Subject: [PATCH 9/9] Support any number of target batch dimensions * Custom loss objections can support any number of batch dimension values. --- captum/optim/_core/loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py index da2be0a3be..3dfc54c58d 100644 --- a/captum/optim/_core/loss.py +++ b/captum/optim/_core/loss.py @@ -270,7 +270,6 @@ def __init__( if batch_index is None: self._batch_index = (None, None) elif isinstance(batch_index, (list, tuple)): - assert len(batch_index) == 2 self._batch_index = tuple(batch_index) else: self._batch_index = (batch_index, batch_index + 1)
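
Appendix: a usage sketch of the composable loss API built up by this series.
This example is illustrative only and is not part of any patch above. Two
assumptions are flagged inline: the `opt_loss` import path mirrors what the
test files use (their import line is not shown in these hunks), and calling an
objective with a `{module: tensor}` mapping mirrors what the tests'
`get_loss_value` helper does internally via `ModuleOutputMapping`.

    import torch

    # Assumption: the patched module is importable the same way the tests
    # import it as `opt_loss`.
    import captum.optim._core.loss as opt_loss

    layer = torch.nn.Identity()
    act = opt_loss.LayerActivation(layer)

    # Unary and binary operators compose into CompositeLoss objects
    # (PATCH 1 and PATCH 4): +x, -x, abs(x), and +, -, *, /, //, **
    # with ints and floats.
    objective = abs(-act) * 2.0 + 1.0

    # Reflected ("r") operators route through rmodule_op (PATCH 4/8).
    inverted = 1.0 / act

    # Composable reductions (PATCH 1), optionally over specific dims.
    total = act.sum()
    per_sample = act.mean(dim=1, keepdim=True)

    # custom_composable_op (PATCH 8) wraps any tensor -> tensor callable, or
    # combines a list of objectives; to_scalar_fn guards against shape
    # mismatches between the member losses.
    combined = opt_loss.custom_composable_op(
        [act, act], loss_op_fn=sum, to_scalar_fn=torch.mean
    )

    # batch_index may be a [start, stop] pair (PATCH 7); PATCH 9 lifts the
    # two-element restriction. Assumption: Diversity uses BaseLoss.__init__,
    # which the series leaves unchanged for it.
    diverse = opt_loss.Diversity(layer, batch_index=[0, 2])

    # Evaluate against a hand-built captured-activation mapping.
    activations = {layer: torch.ones(2, 4)}
    print(objective(activations))   # elementwise: abs(-1) * 2 + 1 == 3
    print(inverted(activations))    # elementwise: 1 / 1 == 1
    print(total(activations))       # 8.0
    print(per_sample(activations))  # shape (2, 1), all ones
    print(combined(activations))    # mean(act) + mean(act) == 2.0

In a real feature-visualization run these objectives would be evaluated
against activations recorded from a model during optimization rather than a
hand-built dict; the dict here only exercises the composition logic.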