From 70489182901375ee81fc6bcaa89fa9a947dc943a Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 11 May 2022 07:43:28 -0700 Subject: [PATCH 1/8] Linear fix --- captum/_utils/models/linear_model/model.py | 5 +++-- captum/_utils/models/linear_model/train.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/captum/_utils/models/linear_model/model.py b/captum/_utils/models/linear_model/model.py index bfffdbf38a..26e27a18bd 100644 --- a/captum/_utils/models/linear_model/model.py +++ b/captum/_utils/models/linear_model/model.py @@ -2,7 +2,7 @@ import torch.nn as nn from captum._utils.models.model import Model -from torch import Tensor +from torch import Tensor, dtype from torch.utils.data import DataLoader @@ -47,6 +47,7 @@ def _construct_model_params( weight_values: Optional[Tensor] = None, bias_value: Optional[Tensor] = None, classes: Optional[Tensor] = None, + dtype: Optional[dtype] = None, ): r""" Lazily initializes a linear model. This will be called for you in a @@ -102,7 +103,7 @@ def _construct_model_params( else: self.norm = None - self.linear = nn.Linear(in_features, out_features, bias=bias) + self.linear = nn.Linear(in_features, out_features, bias=bias, dtype=dtype) if weight_values is not None: self.linear.weight.data = weight_values diff --git a/captum/_utils/models/linear_model/train.py b/captum/_utils/models/linear_model/train.py index aaf8a2e4bf..43e58ca8ae 100644 --- a/captum/_utils/models/linear_model/train.py +++ b/captum/_utils/models/linear_model/train.py @@ -127,6 +127,7 @@ def get_point(datapoint): model._construct_model_params( in_features=x.shape[1], out_features=y.shape[1] if len(y.shape) == 2 else 1, + dtype=x.dtype, **construct_kwargs, ) model.train() From 71cabc46bd870441d4ecb56d68416f5f312f5bff Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 11 May 2022 10:58:07 -0700 Subject: [PATCH 2/8] Fixes --- captum/_utils/models/linear_model/train.py | 229 ++++++++++----------- tests/attr/test_lime.py | 26 ++- 2 files changed, 134 insertions(+), 121 deletions(-) diff --git a/captum/_utils/models/linear_model/train.py b/captum/_utils/models/linear_model/train.py index 43e58ca8ae..15737cdbc3 100644 --- a/captum/_utils/models/linear_model/train.py +++ b/captum/_utils/models/linear_model/train.py @@ -99,132 +99,129 @@ def sgd_train_linear_model( This will return the final training loss (averaged with `running_loss_window`) """ - - loss_window: List[torch.Tensor] = [] - min_avg_loss = None - convergence_counter = 0 - converged = False - - def get_point(datapoint): - if len(datapoint) == 2: - x, y = datapoint - w = None - else: - x, y, w = datapoint - - if device is not None: - x = x.to(device) - y = y.to(device) - if w is not None: - w = w.to(device) - - return x, y, w - - # get a point and construct the model - data_iter = iter(dataloader) - x, y, w = get_point(next(data_iter)) - - model._construct_model_params( - in_features=x.shape[1], - out_features=y.shape[1] if len(y.shape) == 2 else 1, - dtype=x.dtype, - **construct_kwargs, - ) - model.train() - - assert model.linear is not None - - if init_scheme is not None: - assert init_scheme in ["xavier", "zeros"] - - with torch.no_grad(): - if init_scheme == "xavier": - torch.nn.init.xavier_uniform_(model.linear.weight) + with torch.enable_grad(): + loss_window: List[torch.Tensor] = [] + min_avg_loss = None + convergence_counter = 0 + converged = False + + def get_point(datapoint): + if len(datapoint) == 2: + x, y = datapoint + w = None else: - model.linear.weight.zero_() + x, y, w = datapoint - if 
model.linear.bias is not None: - model.linear.bias.zero_() + if device is not None: + x = x.to(device) + y = y.to(device) + if w is not None: + w = w.to(device) - optim = torch.optim.SGD(model.parameters(), lr=initial_lr) - if reduce_lr: - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optim, factor=0.5, patience=patience, threshold=threshold - ) + return x, y, w - t1 = time.time() - epoch = 0 - i = 0 - while epoch < max_epoch: - while True: # for x, y, w in dataloader - if running_loss_window is None: - running_loss_window = x.shape[0] * len(dataloader) - - y = y.view(x.shape[0], -1) - if w is not None: - w = w.view(x.shape[0], -1) - - i += 1 - - out = model(x) - - loss = loss_fn(y, out, w) - if reg_term is not None: - reg = torch.norm(model.linear.weight, p=reg_term) - loss += reg.sum() * alpha - - if len(loss_window) >= running_loss_window: - loss_window = loss_window[1:] - loss_window.append(loss.clone().detach()) - assert len(loss_window) <= running_loss_window - - average_loss = torch.mean(torch.stack(loss_window)) - if min_avg_loss is not None: - # if we haven't improved by at least `threshold` - if average_loss > min_avg_loss or torch.isclose( - min_avg_loss, average_loss, atol=threshold - ): - convergence_counter += 1 - if convergence_counter >= patience: - converged = True - break - else: - convergence_counter = 0 - if min_avg_loss is None or min_avg_loss >= average_loss: - min_avg_loss = average_loss.clone() + # get a point and construct the model + data_iter = iter(dataloader) + x, y, w = get_point(next(data_iter)) - if debug: - print( - f"lr={optim.param_groups[0]['lr']}, Loss={loss}," - + "Aloss={average_loss}, min_avg_loss={min_avg_loss}" - ) + model._construct_model_params( + in_features=x.shape[1], + out_features=y.shape[1] if len(y.shape) == 2 else 1, + dtype=x.dtype, + **construct_kwargs, + ) + model.train() - loss.backward() + assert model.linear is not None - optim.step() - model.zero_grad() - if scheduler: - scheduler.step(average_loss) + if init_scheme is not None: + assert init_scheme in ["xavier", "zeros"] - temp = next(data_iter, None) - if temp is None: + with torch.no_grad(): + if init_scheme == "xavier": + torch.nn.init.xavier_uniform_(model.linear.weight) + else: + model.linear.weight.zero_() + + if model.linear.bias is not None: + model.linear.bias.zero_() + + optim = torch.optim.SGD(model.parameters(), lr=initial_lr) + if reduce_lr: + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optim, factor=0.5, patience=patience, threshold=threshold + ) + + t1 = time.time() + epoch = 0 + i = 0 + while epoch < max_epoch: + while True: # for x, y, w in dataloader + if running_loss_window is None: + running_loss_window = x.shape[0] * len(dataloader) + + y = y.view(x.shape[0], -1) + if w is not None: + w = w.view(x.shape[0], -1) + + i += 1 + out = model(x) + loss = loss_fn(y, out, w) + if reg_term is not None: + reg = torch.norm(model.linear.weight, p=reg_term) + loss += reg.sum() * alpha + + if len(loss_window) >= running_loss_window: + loss_window = loss_window[1:] + loss_window.append(loss.clone().detach()) + assert len(loss_window) <= running_loss_window + + average_loss = torch.mean(torch.stack(loss_window)) + if min_avg_loss is not None: + # if we haven't improved by at least `threshold` + if average_loss > min_avg_loss or torch.isclose( + min_avg_loss, average_loss, atol=threshold + ): + convergence_counter += 1 + if convergence_counter >= patience: + converged = True + break + else: + convergence_counter = 0 + if min_avg_loss is None or 
min_avg_loss >= average_loss: + min_avg_loss = average_loss.clone() + + if debug: + print( + f"lr={optim.param_groups[0]['lr']}, Loss={loss}," + + "Aloss={average_loss}, min_avg_loss={min_avg_loss}" + ) + + loss.backward() + optim.step() + model.zero_grad() + if scheduler: + scheduler.step(average_loss) + + temp = next(data_iter, None) + if temp is None: + break + x, y, w = get_point(temp) + + if converged: break - x, y, w = get_point(temp) - - if converged: - break - - epoch += 1 - data_iter = iter(dataloader) - x, y, w = get_point(next(data_iter)) - t2 = time.time() - return { - "train_time": t2 - t1, - "train_loss": torch.mean(torch.stack(loss_window)).item(), - "train_iter": i, - "train_epoch": epoch, - } + epoch += 1 + data_iter = iter(dataloader) + x, y, w = get_point(next(data_iter)) + + t2 = time.time() + return { + "train_time": t2 - t1, + "train_loss": torch.mean(torch.stack(loss_window)).item(), + "train_iter": i, + "train_epoch": epoch, + } class NormLayer(nn.Module): diff --git a/tests/attr/test_lime.py b/tests/attr/test_lime.py index 4287aa05ba..37acf59785 100644 --- a/tests/attr/test_lime.py +++ b/tests/attr/test_lime.py @@ -3,10 +3,11 @@ import io import unittest import unittest.mock -from typing import Any, Callable, Generator, List, Tuple, Union +from typing import Any, Callable, Generator, List, Tuple, Optional, Union import torch -from captum._utils.models.linear_model import SkLearnLasso +from captum._utils.models.linear_model import SkLearnLasso, SGDLasso +from captum._utils.models.model import Model from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric from captum.attr._core.lime import get_exp_kernel_similarity_function, Lime, LimeBase from captum.attr._utils.batching import _batch_example_iterator @@ -27,7 +28,7 @@ BasicModelBoolInput, ) from torch import Tensor - +from functools import partial def alt_perturb_func( original_inp: TensorOrTupleOfTensorsGeneric, **kwargs @@ -120,6 +121,20 @@ def test_simple_lime(self) -> None: test_generator=True, ) + def test_simple_lime_sgd_model(self) -> None: + net = BasicModel_MultiLayer() + inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) + interpretable_model = SGDLasso() + interpretable_model.fit = partial(interpretable_model.fit, initial_lr=0.1, max_epoch=500) + self._lime_test_assert( + net, + inp, + [[73.3716, 193.3349, 113.3349]], + n_samples=1000, + expected_coefs_only=[[73.3716, 193.3349, 113.3349]], + interpretable_model=interpretable_model + ) + def test_simple_lime_with_mask(self) -> None: net = BasicModel_MultiLayer() inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) @@ -487,12 +502,13 @@ def _lime_test_assert( batch_attr: bool = False, test_generator: bool = False, show_progress: bool = False, + interpretable_model: Optional[Model] = None, ) -> None: for batch_size in perturbations_per_eval: lime = Lime( model, similarity_func=get_exp_kernel_similarity_function("cosine", 10.0), - interpretable_model=SkLearnLasso(alpha=1.0), + interpretable_model=interpretable_model if interpretable_model else SkLearnLasso(alpha=1.0), ) attributions = lime.attribute( test_input, @@ -526,7 +542,7 @@ def _lime_test_assert( lime_alt = LimeBase( model, - SkLearnLasso(alpha=1.0), + interpretable_model if interpretable_model else SkLearnLasso(alpha=1.0), get_exp_kernel_similarity_function("euclidean", 1000.0), alt_perturb_generator if test_generator else alt_perturb_func, False, From 71f2d7a7597273b0dad036c1c8cf0ac391550475 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 11 May 2022 
10:58:44 -0700 Subject: [PATCH 3/8] Lint --- tests/attr/test_lime.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/attr/test_lime.py b/tests/attr/test_lime.py index 37acf59785..51b3414f17 100644 --- a/tests/attr/test_lime.py +++ b/tests/attr/test_lime.py @@ -3,6 +3,7 @@ import io import unittest import unittest.mock +from functools import partial from typing import Any, Callable, Generator, List, Tuple, Optional, Union import torch @@ -28,7 +29,7 @@ BasicModelBoolInput, ) from torch import Tensor -from functools import partial + def alt_perturb_func( original_inp: TensorOrTupleOfTensorsGeneric, **kwargs @@ -125,14 +126,16 @@ def test_simple_lime_sgd_model(self) -> None: net = BasicModel_MultiLayer() inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) interpretable_model = SGDLasso() - interpretable_model.fit = partial(interpretable_model.fit, initial_lr=0.1, max_epoch=500) + interpretable_model.fit = partial( + interpretable_model.fit, initial_lr=0.1, max_epoch=500 + ) self._lime_test_assert( net, inp, [[73.3716, 193.3349, 113.3349]], n_samples=1000, expected_coefs_only=[[73.3716, 193.3349, 113.3349]], - interpretable_model=interpretable_model + interpretable_model=interpretable_model, ) def test_simple_lime_with_mask(self) -> None: @@ -508,7 +511,9 @@ def _lime_test_assert( lime = Lime( model, similarity_func=get_exp_kernel_similarity_function("cosine", 10.0), - interpretable_model=interpretable_model if interpretable_model else SkLearnLasso(alpha=1.0), + interpretable_model=interpretable_model + if interpretable_model + else SkLearnLasso(alpha=1.0), ) attributions = lime.attribute( test_input, @@ -542,7 +547,9 @@ def _lime_test_assert( lime_alt = LimeBase( model, - interpretable_model if interpretable_model else SkLearnLasso(alpha=1.0), + interpretable_model + if interpretable_model + else SkLearnLasso(alpha=1.0), get_exp_kernel_similarity_function("euclidean", 1000.0), alt_perturb_generator if test_generator else alt_perturb_func, False, From 18a2246bd11682fa83b3f3d793e54d31c6faec3e Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 11 May 2022 11:37:20 -0700 Subject: [PATCH 4/8] Move enable grad --- captum/_utils/models/linear_model/train.py | 90 +++++++++++----------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/captum/_utils/models/linear_model/train.py b/captum/_utils/models/linear_model/train.py index 15737cdbc3..48c7c6bddc 100644 --- a/captum/_utils/models/linear_model/train.py +++ b/captum/_utils/models/linear_model/train.py @@ -99,53 +99,53 @@ def sgd_train_linear_model( This will return the final training loss (averaged with `running_loss_window`) """ - with torch.enable_grad(): - loss_window: List[torch.Tensor] = [] - min_avg_loss = None - convergence_counter = 0 - converged = False - - def get_point(datapoint): - if len(datapoint) == 2: - x, y = datapoint - w = None - else: - x, y, w = datapoint + loss_window: List[torch.Tensor] = [] + min_avg_loss = None + convergence_counter = 0 + converged = False + + def get_point(datapoint): + if len(datapoint) == 2: + x, y = datapoint + w = None + else: + x, y, w = datapoint - if device is not None: - x = x.to(device) - y = y.to(device) - if w is not None: - w = w.to(device) + if device is not None: + x = x.to(device) + y = y.to(device) + if w is not None: + w = w.to(device) - return x, y, w + return x, y, w - # get a point and construct the model - data_iter = iter(dataloader) - x, y, w = get_point(next(data_iter)) + # get a point and construct the model + 
data_iter = iter(dataloader) + x, y, w = get_point(next(data_iter)) - model._construct_model_params( - in_features=x.shape[1], - out_features=y.shape[1] if len(y.shape) == 2 else 1, - dtype=x.dtype, - **construct_kwargs, - ) - model.train() + model._construct_model_params( + in_features=x.shape[1], + out_features=y.shape[1] if len(y.shape) == 2 else 1, + dtype=x.dtype, + **construct_kwargs, + ) + model.train() - assert model.linear is not None + assert model.linear is not None - if init_scheme is not None: - assert init_scheme in ["xavier", "zeros"] + if init_scheme is not None: + assert init_scheme in ["xavier", "zeros"] - with torch.no_grad(): - if init_scheme == "xavier": - torch.nn.init.xavier_uniform_(model.linear.weight) - else: - model.linear.weight.zero_() + with torch.no_grad(): + if init_scheme == "xavier": + torch.nn.init.xavier_uniform_(model.linear.weight) + else: + model.linear.weight.zero_() - if model.linear.bias is not None: - model.linear.bias.zero_() + if model.linear.bias is not None: + model.linear.bias.zero_() + with torch.enable_grad(): optim = torch.optim.SGD(model.parameters(), lr=initial_lr) if reduce_lr: scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( @@ -215,13 +215,13 @@ def get_point(datapoint): data_iter = iter(dataloader) x, y, w = get_point(next(data_iter)) - t2 = time.time() - return { - "train_time": t2 - t1, - "train_loss": torch.mean(torch.stack(loss_window)).item(), - "train_iter": i, - "train_epoch": epoch, - } + t2 = time.time() + return { + "train_time": t2 - t1, + "train_loss": torch.mean(torch.stack(loss_window)).item(), + "train_iter": i, + "train_epoch": epoch, + } class NormLayer(nn.Module): From 20019617af2f56712c1f1191c3d7402cb6482ad4 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 11 May 2022 11:40:51 -0700 Subject: [PATCH 5/8] Lint fix --- captum/_utils/models/linear_model/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/captum/_utils/models/linear_model/train.py b/captum/_utils/models/linear_model/train.py index 48c7c6bddc..f697162b25 100644 --- a/captum/_utils/models/linear_model/train.py +++ b/captum/_utils/models/linear_model/train.py @@ -165,7 +165,9 @@ def get_point(datapoint): w = w.view(x.shape[0], -1) i += 1 + out = model(x) + loss = loss_fn(y, out, w) if reg_term is not None: reg = torch.norm(model.linear.weight, p=reg_term) From 3bcbaa618c72f1463c7a25551ad2e767baa29829 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 11 May 2022 14:33:07 -0700 Subject: [PATCH 6/8] Switch from double to float --- captum/_utils/models/linear_model/model.py | 5 ++--- captum/_utils/models/linear_model/train.py | 1 - captum/attr/_core/lime.py | 6 +++--- tests/attr/test_lime.py | 2 +- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/captum/_utils/models/linear_model/model.py b/captum/_utils/models/linear_model/model.py index 26e27a18bd..bfffdbf38a 100644 --- a/captum/_utils/models/linear_model/model.py +++ b/captum/_utils/models/linear_model/model.py @@ -2,7 +2,7 @@ import torch.nn as nn from captum._utils.models.model import Model -from torch import Tensor, dtype +from torch import Tensor from torch.utils.data import DataLoader @@ -47,7 +47,6 @@ def _construct_model_params( weight_values: Optional[Tensor] = None, bias_value: Optional[Tensor] = None, classes: Optional[Tensor] = None, - dtype: Optional[dtype] = None, ): r""" Lazily initializes a linear model. 
This will be called for you in a @@ -103,7 +102,7 @@ def _construct_model_params( else: self.norm = None - self.linear = nn.Linear(in_features, out_features, bias=bias, dtype=dtype) + self.linear = nn.Linear(in_features, out_features, bias=bias) if weight_values is not None: self.linear.weight.data = weight_values diff --git a/captum/_utils/models/linear_model/train.py b/captum/_utils/models/linear_model/train.py index f697162b25..30e5edf112 100644 --- a/captum/_utils/models/linear_model/train.py +++ b/captum/_utils/models/linear_model/train.py @@ -126,7 +126,6 @@ def get_point(datapoint): model._construct_model_params( in_features=x.shape[1], out_features=y.shape[1] if len(y.shape) == 2 else 1, - dtype=x.dtype, **construct_kwargs, ) model.train() diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py index 289a4e51b6..04325e5f00 100644 --- a/captum/attr/_core/lime.py +++ b/captum/attr/_core/lime.py @@ -512,17 +512,17 @@ def attribute( if show_progress: attr_progress.close() - combined_interp_inps = torch.cat(interpretable_inps).double() + combined_interp_inps = torch.cat(interpretable_inps).float() combined_outputs = ( torch.cat(outputs) if len(outputs[0].shape) > 0 else torch.stack(outputs) - ).double() + ).float() combined_sim = ( torch.cat(similarities) if len(similarities[0].shape) > 0 else torch.stack(similarities) - ).double() + ).float() dataset = TensorDataset( combined_interp_inps, combined_outputs, combined_sim ) diff --git a/tests/attr/test_lime.py b/tests/attr/test_lime.py index 51b3414f17..6d90792e24 100644 --- a/tests/attr/test_lime.py +++ b/tests/attr/test_lime.py @@ -126,7 +126,7 @@ def test_simple_lime_sgd_model(self) -> None: net = BasicModel_MultiLayer() inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) interpretable_model = SGDLasso() - interpretable_model.fit = partial( + interpretable_model.fit = partial( # type: ignore interpretable_model.fit, initial_lr=0.1, max_epoch=500 ) self._lime_test_assert( From 3d4a9e46fd759f7992af62a9d3657d8fba83bc06 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Wed, 11 May 2022 14:33:31 -0700 Subject: [PATCH 7/8] Lint fix --- tests/attr/test_lime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/attr/test_lime.py b/tests/attr/test_lime.py index 6d90792e24..9f4de46a85 100644 --- a/tests/attr/test_lime.py +++ b/tests/attr/test_lime.py @@ -126,7 +126,7 @@ def test_simple_lime_sgd_model(self) -> None: net = BasicModel_MultiLayer() inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) interpretable_model = SGDLasso() - interpretable_model.fit = partial( # type: ignore + interpretable_model.fit = partial( # type: ignore interpretable_model.fit, initial_lr=0.1, max_epoch=500 ) self._lime_test_assert( From 5372bee5bcb8dcc25eba82e544f35990dd165ecd Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Tue, 17 May 2022 22:40:49 -0700 Subject: [PATCH 8/8] Fix format --- tests/attr/test_lime.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/attr/test_lime.py b/tests/attr/test_lime.py index 9f4de46a85..45646c47d7 100644 --- a/tests/attr/test_lime.py +++ b/tests/attr/test_lime.py @@ -4,10 +4,10 @@ import unittest import unittest.mock from functools import partial -from typing import Any, Callable, Generator, List, Tuple, Optional, Union +from typing import Any, Callable, Generator, List, Optional, Tuple, Union import torch -from captum._utils.models.linear_model import SkLearnLasso, SGDLasso +from captum._utils.models.linear_model import SGDLasso, SkLearnLasso from 
captum._utils.models.model import Model from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric from captum.attr._core.lime import get_exp_kernel_similarity_function, Lime, LimeBase
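
A minimal usage sketch of the pattern the new test_simple_lime_sgd_model test exercises, pulled together from the patches above — net and inp are placeholders standing in for any model and input batch, and the hyperparameters are the ones the test binds onto fit():

    from functools import partial

    import torch
    from captum._utils.models.linear_model import SGDLasso
    from captum.attr._core.lime import Lime, get_exp_kernel_similarity_function

    # Placeholder model and input; substitute your own.
    net = torch.nn.Linear(3, 2)
    inp = torch.tensor([[20.0, 50.0, 30.0]])

    # Bind SGD training hyperparameters onto fit(), as the test does.
    interpretable_model = SGDLasso()
    interpretable_model.fit = partial(
        interpretable_model.fit, initial_lr=0.1, max_epoch=500
    )

    # Plug the SGD-trained Lasso into Lime in place of the default
    # SkLearnLasso surrogate.
    lime = Lime(
        net,
        interpretable_model=interpretable_model,
        similarity_func=get_exp_kernel_similarity_function("cosine", 10.0),
    )
    attr = lime.attribute(inp, target=0, n_samples=1000)

The torch.enable_grad() wrapper (added in patch 2, then narrowed to just the optimization loop in patch 4) is what keeps this working when the fit runs under a caller's torch.no_grad(): without it, loss.backward() inside sgd_train_linear_model would raise, since the forward pass would build no autograd graph.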