
Commit a702728

vivekmig authored and facebook-github-bot committed

SGD Linear Model Fixes for Lime (#938)

Summary: This updates SGD linear models to work appropriately with Lime, addressing #910. Specifically, it switches Lime interpretable model inputs/outputs from double to float and enables gradients when necessary. It also adds a unit test covering Lime with SGD linear models.

Pull Request resolved: #938
Reviewed By: NarineK
Differential Revision: D36331146
Pulled By: vivekmig
fbshipit-source-id: 84d7aecf293404f9ba0b14c48e8723e0e489b392
1 parent 33d2b75 commit a702728
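
For orientation, the end-to-end usage this fix enables looks roughly like the sketch below. It mirrors the new unit test further down; MyModel is a hypothetical stand-in for whatever PyTorch model is being explained, and the import paths are the ones used in this PR's test file.

    import torch
    from captum._utils.models.linear_model import SGDLasso
    from captum.attr import Lime

    net = MyModel()  # hypothetical: any torch.nn.Module being explained
    inp = torch.tensor([[20.0, 50.0, 30.0]])

    # Prior to this commit, an SGD-trained interpretable model could fail here:
    # Lime handed it float64 tensors, and autograd could be globally disabled.
    lime = Lime(net, interpretable_model=SGDLasso())
    attr = lime.attribute(inp, target=0, n_samples=1000)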

File tree

3 files changed: +99 −77 lines changed

captum/_utils/models/linear_model/train.py
captum/attr/_core/lime.py
tests/attr/test_lime.py


captum/_utils/models/linear_model/train.py

Lines changed: 69 additions & 70 deletions
@@ -99,7 +99,6 @@ def sgd_train_linear_model(
         This will return the final training loss (averaged with
         `running_loss_window`)
     """
-
     loss_window: List[torch.Tensor] = []
     min_avg_loss = None
     convergence_counter = 0
@@ -145,77 +144,77 @@ def get_point(datapoint):
     if model.linear.bias is not None:
         model.linear.bias.zero_()
 
-    optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
-    if reduce_lr:
-        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optim, factor=0.5, patience=patience, threshold=threshold
-        )
-
-    t1 = time.time()
-    epoch = 0
-    i = 0
-    while epoch < max_epoch:
-        while True:  # for x, y, w in dataloader
-            if running_loss_window is None:
-                running_loss_window = x.shape[0] * len(dataloader)
-
-            y = y.view(x.shape[0], -1)
-            if w is not None:
-                w = w.view(x.shape[0], -1)
-
-            i += 1
-
-            out = model(x)
-
-            loss = loss_fn(y, out, w)
-            if reg_term is not None:
-                reg = torch.norm(model.linear.weight, p=reg_term)
-                loss += reg.sum() * alpha
-
-            if len(loss_window) >= running_loss_window:
-                loss_window = loss_window[1:]
-            loss_window.append(loss.clone().detach())
-            assert len(loss_window) <= running_loss_window
-
-            average_loss = torch.mean(torch.stack(loss_window))
-            if min_avg_loss is not None:
-                # if we haven't improved by at least `threshold`
-                if average_loss > min_avg_loss or torch.isclose(
-                    min_avg_loss, average_loss, atol=threshold
-                ):
-                    convergence_counter += 1
-                    if convergence_counter >= patience:
-                        converged = True
-                        break
-                else:
-                    convergence_counter = 0
-            if min_avg_loss is None or min_avg_loss >= average_loss:
-                min_avg_loss = average_loss.clone()
-
-            if debug:
-                print(
-                    f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
-                    + "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
-                )
-
-            loss.backward()
-
-            optim.step()
-            model.zero_grad()
-            if scheduler:
-                scheduler.step(average_loss)
-
-            temp = next(data_iter, None)
-            if temp is None:
+    with torch.enable_grad():
+        optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
+        if reduce_lr:
+            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                optim, factor=0.5, patience=patience, threshold=threshold
+            )
+
+        t1 = time.time()
+        epoch = 0
+        i = 0
+        while epoch < max_epoch:
+            while True:  # for x, y, w in dataloader
+                if running_loss_window is None:
+                    running_loss_window = x.shape[0] * len(dataloader)
+
+                y = y.view(x.shape[0], -1)
+                if w is not None:
+                    w = w.view(x.shape[0], -1)
+
+                i += 1
+
+                out = model(x)
+
+                loss = loss_fn(y, out, w)
+                if reg_term is not None:
+                    reg = torch.norm(model.linear.weight, p=reg_term)
+                    loss += reg.sum() * alpha
+
+                if len(loss_window) >= running_loss_window:
+                    loss_window = loss_window[1:]
+                loss_window.append(loss.clone().detach())
+                assert len(loss_window) <= running_loss_window
+
+                average_loss = torch.mean(torch.stack(loss_window))
+                if min_avg_loss is not None:
+                    # if we haven't improved by at least `threshold`
+                    if average_loss > min_avg_loss or torch.isclose(
+                        min_avg_loss, average_loss, atol=threshold
+                    ):
+                        convergence_counter += 1
+                        if convergence_counter >= patience:
+                            converged = True
+                            break
+                    else:
+                        convergence_counter = 0
+                if min_avg_loss is None or min_avg_loss >= average_loss:
+                    min_avg_loss = average_loss.clone()
+
+                if debug:
+                    print(
+                        f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
+                        + "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
+                    )
+
+                loss.backward()
+                optim.step()
+                model.zero_grad()
+                if scheduler:
+                    scheduler.step(average_loss)
+
+                temp = next(data_iter, None)
+                if temp is None:
+                    break
+                x, y, w = get_point(temp)
+
+            if converged:
                 break
-            x, y, w = get_point(temp)
-
-        if converged:
-            break
 
-        epoch += 1
-        data_iter = iter(dataloader)
-        x, y, w = get_point(next(data_iter))
+            epoch += 1
+            data_iter = iter(dataloader)
+            x, y, w = get_point(next(data_iter))
 
     t2 = time.time()
     return {
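
The torch.enable_grad() wrapper is the core of the train.py change: if a caller invokes attribution inside a torch.no_grad() block (common in evaluation loops), autograd would otherwise be disabled for the surrogate model's training and loss.backward() would fail. A minimal standalone sketch of the mechanism, independent of Captum:

    import torch

    model = torch.nn.Linear(3, 1)

    with torch.no_grad():  # e.g., attribution requested inside an eval loop
        with torch.enable_grad():  # locally re-enables autograd
            out = model(torch.randn(2, 3))
            loss = (out ** 2).mean()
            loss.backward()  # succeeds: the graph was recorded
        print(model.weight.grad)  # gradients were populated

Without the inner enable_grad context, the forward pass records no graph and backward() raises a RuntimeError.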

captum/attr/_core/lime.py

Lines changed: 3 additions & 3 deletions
@@ -512,17 +512,17 @@ def attribute(
         if show_progress:
             attr_progress.close()
 
-        combined_interp_inps = torch.cat(interpretable_inps).double()
+        combined_interp_inps = torch.cat(interpretable_inps).float()
         combined_outputs = (
             torch.cat(outputs)
             if len(outputs[0].shape) > 0
             else torch.stack(outputs)
-        ).double()
+        ).float()
         combined_sim = (
             torch.cat(similarities)
             if len(similarities[0].shape) > 0
             else torch.stack(similarities)
-        ).double()
+        ).float()
         dataset = TensorDataset(
             combined_interp_inps, combined_outputs, combined_sim
         )
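
The double-to-float switch matches PyTorch's defaults: nn.Linear, which backs the SGD linear models, creates float32 parameters, so handing the surrogate float64 training data triggers a dtype mismatch. A small illustration of the failure mode (not Captum code):

    import torch

    lin = torch.nn.Linear(3, 1)  # weights default to float32

    x = torch.rand(2, 3, dtype=torch.float64)
    try:
        lin(x)  # float64 inputs vs. float32 weights
    except RuntimeError as err:
        print(err)  # dtype mismatch raised by the matmul

    lin(x.float())  # fine once inputs are cast, as Lime now does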

tests/attr/test_lime.py

Lines changed: 27 additions & 4 deletions
@@ -3,10 +3,12 @@
 import io
 import unittest
 import unittest.mock
-from typing import Any, Callable, Generator, List, Tuple, Union
+from functools import partial
+from typing import Any, Callable, Generator, List, Optional, Tuple, Union
 
 import torch
-from captum._utils.models.linear_model import SkLearnLasso
+from captum._utils.models.linear_model import SGDLasso, SkLearnLasso
+from captum._utils.models.model import Model
 from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric
 from captum.attr._core.lime import get_exp_kernel_similarity_function, Lime, LimeBase
 from captum.attr._utils.batching import _batch_example_iterator
@@ -120,6 +122,22 @@ def test_simple_lime(self) -> None:
             test_generator=True,
         )
 
+    def test_simple_lime_sgd_model(self) -> None:
+        net = BasicModel_MultiLayer()
+        inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)
+        interpretable_model = SGDLasso()
+        interpretable_model.fit = partial(  # type: ignore
+            interpretable_model.fit, initial_lr=0.1, max_epoch=500
+        )
+        self._lime_test_assert(
+            net,
+            inp,
+            [[73.3716, 193.3349, 113.3349]],
+            n_samples=1000,
+            expected_coefs_only=[[73.3716, 193.3349, 113.3349]],
+            interpretable_model=interpretable_model,
+        )
+
     def test_simple_lime_with_mask(self) -> None:
         net = BasicModel_MultiLayer()
         inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)
@@ -487,12 +505,15 @@ def _lime_test_assert(
         batch_attr: bool = False,
         test_generator: bool = False,
         show_progress: bool = False,
+        interpretable_model: Optional[Model] = None,
     ) -> None:
         for batch_size in perturbations_per_eval:
             lime = Lime(
                 model,
                 similarity_func=get_exp_kernel_similarity_function("cosine", 10.0),
-                interpretable_model=SkLearnLasso(alpha=1.0),
+                interpretable_model=interpretable_model
+                if interpretable_model
+                else SkLearnLasso(alpha=1.0),
             )
             attributions = lime.attribute(
                 test_input,
@@ -526,7 +547,9 @@ def _lime_test_assert(
 
             lime_alt = LimeBase(
                 model,
-                SkLearnLasso(alpha=1.0),
+                interpretable_model
+                if interpretable_model
+                else SkLearnLasso(alpha=1.0),
                 get_exp_kernel_similarity_function("euclidean", 1000.0),
                 alt_perturb_generator if test_generator else alt_perturb_func,
                 False,
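
One usage detail worth noting from the new test: because Lime calls fit internally with its own arguments, the test pins the SGD training hyperparameters ahead of time by wrapping the bound method with functools.partial. The same pattern works in user code, assuming an SGD model configured like the one in the test:

    from functools import partial

    from captum._utils.models.linear_model import SGDLasso

    interpretable_model = SGDLasso()
    # Freeze training hyperparameters now; Lime will call fit(...) later
    # without needing to know about initial_lr or max_epoch.
    interpretable_model.fit = partial(  # type: ignore
        interpretable_model.fit, initial_lr=0.1, max_epoch=500
    )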
