diff --git a/captum/_utils/models/linear_model/train.py b/captum/_utils/models/linear_model/train.py
index aaf8a2e4bf..30e5edf112 100644
--- a/captum/_utils/models/linear_model/train.py
+++ b/captum/_utils/models/linear_model/train.py
@@ -99,7 +99,6 @@ def sgd_train_linear_model(
     This will return the final training loss (averaged with
     `running_loss_window`)
     """
-
     loss_window: List[torch.Tensor] = []
     min_avg_loss = None
     convergence_counter = 0
@@ -145,77 +144,77 @@ def get_point(datapoint):
     if model.linear.bias is not None:
         model.linear.bias.zero_()
 
-    optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
-    if reduce_lr:
-        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optim, factor=0.5, patience=patience, threshold=threshold
-        )
-
-    t1 = time.time()
-    epoch = 0
-    i = 0
-    while epoch < max_epoch:
-        while True:  # for x, y, w in dataloader
-            if running_loss_window is None:
-                running_loss_window = x.shape[0] * len(dataloader)
-
-            y = y.view(x.shape[0], -1)
-            if w is not None:
-                w = w.view(x.shape[0], -1)
-
-            i += 1
-
-            out = model(x)
-
-            loss = loss_fn(y, out, w)
-            if reg_term is not None:
-                reg = torch.norm(model.linear.weight, p=reg_term)
-                loss += reg.sum() * alpha
-
-            if len(loss_window) >= running_loss_window:
-                loss_window = loss_window[1:]
-            loss_window.append(loss.clone().detach())
-            assert len(loss_window) <= running_loss_window
-
-            average_loss = torch.mean(torch.stack(loss_window))
-            if min_avg_loss is not None:
-                # if we haven't improved by at least `threshold`
-                if average_loss > min_avg_loss or torch.isclose(
-                    min_avg_loss, average_loss, atol=threshold
-                ):
-                    convergence_counter += 1
-                    if convergence_counter >= patience:
-                        converged = True
-                        break
-                else:
-                    convergence_counter = 0
-            if min_avg_loss is None or min_avg_loss >= average_loss:
-                min_avg_loss = average_loss.clone()
-
-            if debug:
-                print(
-                    f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
-                    + "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
-                )
-
-            loss.backward()
-
-            optim.step()
-            model.zero_grad()
-            if scheduler:
-                scheduler.step(average_loss)
-
-            temp = next(data_iter, None)
-            if temp is None:
+    with torch.enable_grad():
+        optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
+        if reduce_lr:
+            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                optim, factor=0.5, patience=patience, threshold=threshold
+            )
+
+        t1 = time.time()
+        epoch = 0
+        i = 0
+        while epoch < max_epoch:
+            while True:  # for x, y, w in dataloader
+                if running_loss_window is None:
+                    running_loss_window = x.shape[0] * len(dataloader)
+
+                y = y.view(x.shape[0], -1)
+                if w is not None:
+                    w = w.view(x.shape[0], -1)
+
+                i += 1
+
+                out = model(x)
+
+                loss = loss_fn(y, out, w)
+                if reg_term is not None:
+                    reg = torch.norm(model.linear.weight, p=reg_term)
+                    loss += reg.sum() * alpha
+
+                if len(loss_window) >= running_loss_window:
+                    loss_window = loss_window[1:]
+                loss_window.append(loss.clone().detach())
+                assert len(loss_window) <= running_loss_window
+
+                average_loss = torch.mean(torch.stack(loss_window))
+                if min_avg_loss is not None:
+                    # if we haven't improved by at least `threshold`
+                    if average_loss > min_avg_loss or torch.isclose(
+                        min_avg_loss, average_loss, atol=threshold
+                    ):
+                        convergence_counter += 1
+                        if convergence_counter >= patience:
+                            converged = True
+                            break
+                    else:
+                        convergence_counter = 0
+                if min_avg_loss is None or min_avg_loss >= average_loss:
+                    min_avg_loss = average_loss.clone()
+
+                if debug:
+                    print(
+                        f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
+                        + "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
+                    )
+
+                loss.backward()
+                optim.step()
+                model.zero_grad()
+                if scheduler:
+                    scheduler.step(average_loss)
+
+                temp = next(data_iter, None)
+                if temp is None:
+                    break
+                x, y, w = get_point(temp)
+
+            if converged:
                 break
-            x, y, w = get_point(temp)
-
-        if converged:
-            break
-        epoch += 1
-        data_iter = iter(dataloader)
-        x, y, w = get_point(next(data_iter))
+            epoch += 1
+            data_iter = iter(dataloader)
+            x, y, w = get_point(next(data_iter))
 
     t2 = time.time()
     return {
diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py
index 76f3f4ca71..520251ce53 100644
--- a/captum/attr/_core/lime.py
+++ b/captum/attr/_core/lime.py
@@ -512,17 +512,17 @@ def attribute(
             if show_progress:
                 attr_progress.close()
 
-            combined_interp_inps = torch.cat(interpretable_inps).double()
+            combined_interp_inps = torch.cat(interpretable_inps).float()
             combined_outputs = (
                 torch.cat(outputs)
                 if len(outputs[0].shape) > 0
                 else torch.stack(outputs)
-            ).double()
+            ).float()
             combined_sim = (
                 torch.cat(similarities)
                 if len(similarities[0].shape) > 0
                 else torch.stack(similarities)
-            ).double()
+            ).float()
             dataset = TensorDataset(
                 combined_interp_inps, combined_outputs, combined_sim
             )
diff --git a/tests/attr/test_lime.py b/tests/attr/test_lime.py
index 4287aa05ba..45646c47d7 100644
--- a/tests/attr/test_lime.py
+++ b/tests/attr/test_lime.py
@@ -3,10 +3,12 @@
 import io
 import unittest
 import unittest.mock
-from typing import Any, Callable, Generator, List, Tuple, Union
+from functools import partial
+from typing import Any, Callable, Generator, List, Optional, Tuple, Union
 
 import torch
-from captum._utils.models.linear_model import SkLearnLasso
+from captum._utils.models.linear_model import SGDLasso, SkLearnLasso
+from captum._utils.models.model import Model
 from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric
 from captum.attr._core.lime import get_exp_kernel_similarity_function, Lime, LimeBase
 from captum.attr._utils.batching import _batch_example_iterator
@@ -120,6 +122,22 @@ def test_simple_lime(self) -> None:
             test_generator=True,
         )
 
+    def test_simple_lime_sgd_model(self) -> None:
+        net = BasicModel_MultiLayer()
+        inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)
+        interpretable_model = SGDLasso()
+        interpretable_model.fit = partial(  # type: ignore
+            interpretable_model.fit, initial_lr=0.1, max_epoch=500
+        )
+        self._lime_test_assert(
+            net,
+            inp,
+            [[73.3716, 193.3349, 113.3349]],
+            n_samples=1000,
+            expected_coefs_only=[[73.3716, 193.3349, 113.3349]],
+            interpretable_model=interpretable_model,
+        )
+
     def test_simple_lime_with_mask(self) -> None:
         net = BasicModel_MultiLayer()
         inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)
@@ -487,12 +505,15 @@ def _lime_test_assert(
         batch_attr: bool = False,
         test_generator: bool = False,
         show_progress: bool = False,
+        interpretable_model: Optional[Model] = None,
     ) -> None:
         for batch_size in perturbations_per_eval:
             lime = Lime(
                 model,
                 similarity_func=get_exp_kernel_similarity_function("cosine", 10.0),
-                interpretable_model=SkLearnLasso(alpha=1.0),
+                interpretable_model=interpretable_model
+                if interpretable_model
+                else SkLearnLasso(alpha=1.0),
             )
             attributions = lime.attribute(
                 test_input,
@@ -526,7 +547,9 @@ def _lime_test_assert(
 
                 lime_alt = LimeBase(
                     model,
-                    SkLearnLasso(alpha=1.0),
+                    interpretable_model
+                    if interpretable_model
+                    else SkLearnLasso(alpha=1.0),
                     get_exp_kernel_similarity_function("euclidean", 1000.0),
                     alt_perturb_generator if test_generator else alt_perturb_func,
                     False,