Commit 9b14f36
update tracin influence API
Summary:
This diff changes the API for implementations of `TracInCPBase` as discussed in https://fb.quip.com/JbpnAiWluZmI. In particular, the arguments representing test data of the `influence` method are changed from `inputs: Tuple, targets: Optional[Tensor]` to `inputs: Union[Tuple[Any], DataLoader]`, which is either a single batch, or a dataloader yielding batches. In both cases, `model(*batch)` is assumed to produce the predictions for a batch, and `batch[-1]` is assumed to be the labels for a batch. This is the same format assumed of the batches yielded by `train_dataloader`.
We make this change for 2 reasons
- it unifies the assumptions made of the test data and the assumptions made of the training data
- for some implementations, we want to allow the test data to be represented by a dataloader. with the old API, there was no clean way to allow both a single as well as a dataloader to be passed in, since a batch required 2 arguments, but a dataloader only requires 1.
For now, all implementations only allow `inputs` to be a tuple (and not a dataloader). This is okay due to inheritance rules. Later on, we will allow some implementations (i.e. `TracInCP`) to accept a dataloader as `inputs`.
Other changes:
- changes to make documentation. for example, documentation in `TracInCPBase.influence` now refers to the "test dataset" instead of test batch.
- the `unpack_inputs` argument is no longer needed for the `influence` methods, and is removed
- the usage of `influence` in all the tests is changed to match new API.
- signature of helper methods `_influence_batch_tracincp` and `_influence_batch_tracincp_fast` are changed to match new representation of batches.
Reviewed By: cyrjano
Differential Revision: D41324297
fbshipit-source-id: ee06ef12c79c7bc143b786ae2c5c15d5c8c116841 parent ada8c0d commit 9b14f36
File tree
10 files changed
+243
-281
lines changed- captum/influence
- _core
- _utils
- tests/influence
- _core
- _utils
10 files changed
+243
-281
lines changedLarge diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
187 | 187 | | |
188 | 188 | | |
189 | 189 | | |
190 | | - | |
191 | 190 | | |
192 | 191 | | |
193 | 192 | | |
| |||
202 | 201 | | |
203 | 202 | | |
204 | 203 | | |
205 | | - | |
| 204 | + | |
206 | 205 | | |
207 | | - | |
208 | | - | |
209 | | - | |
210 | | - | |
211 | | - | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
212 | 210 | | |
213 | 211 | | |
214 | 212 | | |
| |||
270 | 268 | | |
271 | 269 | | |
272 | 270 | | |
273 | | - | |
| 271 | + | |
274 | 272 | | |
275 | 273 | | |
276 | 274 | | |
| |||
Lines changed: 5 additions & 16 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
11 | 11 | | |
12 | 12 | | |
13 | 13 | | |
| 14 | + | |
14 | 15 | | |
15 | 16 | | |
16 | 17 | | |
| |||
162 | 163 | | |
163 | 164 | | |
164 | 165 | | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
165 | 169 | | |
166 | | - | |
| 170 | + | |
167 | 171 | | |
168 | 172 | | |
169 | | - | |
170 | | - | |
171 | | - | |
172 | | - | |
173 | | - | |
174 | | - | |
175 | | - | |
176 | | - | |
177 | | - | |
178 | | - | |
179 | | - | |
180 | | - | |
181 | | - | |
182 | | - | |
183 | | - | |
184 | 173 | | |
185 | 174 | | |
186 | 175 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
| 11 | + | |
11 | 12 | | |
12 | 13 | | |
13 | 14 | | |
| |||
107 | 108 | | |
108 | 109 | | |
109 | 110 | | |
110 | | - | |
| 111 | + | |
| 112 | + | |
111 | 113 | | |
112 | 114 | | |
113 | 115 | | |
114 | | - | |
115 | | - | |
| 116 | + | |
116 | 117 | | |
117 | 118 | | |
118 | | - | |
119 | 119 | | |
120 | 120 | | |
121 | 121 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
165 | 165 | | |
166 | 166 | | |
167 | 167 | | |
168 | | - | |
| 168 | + | |
169 | 169 | | |
170 | | - | |
| 170 | + | |
171 | 171 | | |
172 | 172 | | |
173 | 173 | | |
174 | 174 | | |
175 | 175 | | |
176 | 176 | | |
177 | 177 | | |
178 | | - | |
| 178 | + | |
179 | 179 | | |
180 | | - | |
| 180 | + | |
181 | 181 | | |
182 | 182 | | |
183 | 183 | | |
| |||
208 | 208 | | |
209 | 209 | | |
210 | 210 | | |
211 | | - | |
| 211 | + | |
212 | 212 | | |
213 | | - | |
| 213 | + | |
214 | 214 | | |
215 | 215 | | |
216 | 216 | | |
217 | 217 | | |
218 | 218 | | |
219 | | - | |
| 219 | + | |
220 | 220 | | |
221 | | - | |
| 221 | + | |
222 | 222 | | |
223 | 223 | | |
224 | 224 | | |
| |||
270 | 270 | | |
271 | 271 | | |
272 | 272 | | |
273 | | - | |
| 273 | + | |
274 | 274 | | |
275 | 275 | | |
276 | 276 | | |
| |||
364 | 364 | | |
365 | 365 | | |
366 | 366 | | |
367 | | - | |
| 367 | + | |
368 | 368 | | |
369 | | - | |
| 369 | + | |
370 | 370 | | |
371 | 371 | | |
372 | 372 | | |
| |||
397 | 397 | | |
398 | 398 | | |
399 | 399 | | |
400 | | - | |
| 400 | + | |
401 | 401 | | |
402 | | - | |
| 402 | + | |
403 | 403 | | |
404 | 404 | | |
405 | 405 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
| 11 | + | |
11 | 12 | | |
12 | 13 | | |
13 | 14 | | |
| |||
108 | 109 | | |
109 | 110 | | |
110 | 111 | | |
111 | | - | |
112 | | - | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
113 | 115 | | |
114 | | - | |
115 | 116 | | |
116 | 117 | | |
117 | 118 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
178 | 178 | | |
179 | 179 | | |
180 | 180 | | |
181 | | - | |
182 | | - | |
| 181 | + | |
183 | 182 | | |
184 | 183 | | |
185 | 184 | | |
| |||
196 | 195 | | |
197 | 196 | | |
198 | 197 | | |
199 | | - | |
200 | | - | |
| 198 | + | |
201 | 199 | | |
202 | 200 | | |
203 | 201 | | |
| |||
218 | 216 | | |
219 | 217 | | |
220 | 218 | | |
221 | | - | |
222 | | - | |
| 219 | + | |
223 | 220 | | |
224 | 221 | | |
225 | 222 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
258 | 258 | | |
259 | 259 | | |
260 | 260 | | |
261 | | - | |
| 261 | + | |
262 | 262 | | |
263 | 263 | | |
264 | 264 | | |
| |||
288 | 288 | | |
289 | 289 | | |
290 | 290 | | |
291 | | - | |
| 291 | + | |
292 | 292 | | |
293 | | - | |
| 293 | + | |
294 | 294 | | |
295 | 295 | | |
296 | 296 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2 | 2 | | |
3 | 3 | | |
4 | 4 | | |
5 | | - | |
| 5 | + | |
6 | 6 | | |
7 | 7 | | |
8 | 8 | | |
| |||
14 | 14 | | |
15 | 15 | | |
16 | 16 | | |
| 17 | + | |
17 | 18 | | |
18 | 19 | | |
19 | 20 | | |
| |||
344 | 345 | | |
345 | 346 | | |
346 | 347 | | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
0 commit comments