13 changes: 13 additions & 0 deletions captum/optim/README.md
@@ -0,0 +1,13 @@
# Captum "optim" module

This is a WIP PR to integrate existing feature visualization code from the authors of `tensorflow/lucid` into Captum.
It is also an opportunity to review which parts of such interpretability tools still feel rough to implement in a system like PyTorch, and to make suggestions to the core PyTorch team for how to improve these aspects.

## Roadmap

* Unify with the Captum API: a single callable class per "technique" (check Captum's conventions before implementing; see the sketch after this list)
* Consider whether we need an abstraction around "an optimization process" (stopping criteria, reporting losses, etc.), or whether conventions in PyTorch land are already strong enough for such tasks
* Integrate Eli's FFT param changes (mostly for simplification)
* Make a table of PyTorch interpretability tools for the README?
* Decide whether we need the image viewing and IO helpers, or whether to drop them
* Can we integrate paper references more closely with the code?
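
A rough sketch of what that per-technique entry point could look like; `FeatureVisualization` and its argument names are placeholders, not settled API:

```python
import torch
import torch.nn as nn


class FeatureVisualization:
    """Hypothetical Captum-style entry point: one callable class per technique."""

    def __init__(self, model: nn.Module, target: nn.Module, channel: int) -> None:
        self.model = model      # network to visualize
        self.target = target    # module whose activations we maximize
        self.channel = channel  # channel index within that module

    def visualize(self, n_steps: int = 512) -> torch.Tensor:
        # Would run the parameterize -> forward -> loss -> step loop
        # and return the optimized input image.
        raise NotImplementedError
```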
27 changes: 27 additions & 0 deletions captum/optim/__init__.py
@@ -0,0 +1,27 @@
from typing import Dict, Optional, Union, Callable, Iterable
from typing_extensions import Protocol

import torch
import torch.nn as nn

ParametersForOptimizers = Iterable[Union[torch.Tensor, Dict[str, torch.Tensor]]]


class HasLoss(Protocol):
    def loss(self) -> torch.Tensor:
        ...


class Parameterized(Protocol):
    parameters: ParametersForOptimizers


class Objective(Parameterized, HasLoss):
    def cleanup(self):
        pass


ModuleOutputMapping = Dict[nn.Module, Optional[torch.Tensor]]

StopCriteria = Callable[[int, Objective, torch.optim.Optimizer], bool]
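
As a sanity check on these types, here is a minimal sketch of how they could compose into an optimization loop. It assumes a `StopCriteria` returns `True` while optimization should continue; whether that is the right convention is exactly the open question in the roadmap above.

```python
import torch

from captum.optim import Objective, StopCriteria


def max_steps(limit: int) -> StopCriteria:
    """Sketch of a stop criterion: run for a fixed number of steps."""

    def criterion(
        step: int, objective: Objective, optimizer: torch.optim.Optimizer
    ) -> bool:
        return step < limit  # True means "keep optimizing" in this sketch

    return criterion


def optimize(objective: Objective, stop: StopCriteria) -> None:
    """Drive any Objective until the stop criterion fires (sketch)."""
    optimizer = torch.optim.Adam(objective.parameters, lr=0.025)
    step = 0
    while stop(step, objective, optimizer):
        optimizer.zero_grad()
        objective.loss().backward()
        optimizer.step()
        step += 1
    objective.cleanup()
```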

139 changes: 139 additions & 0 deletions captum/optim/_scrap_and_testing.py
@@ -0,0 +1,139 @@
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import requests
from PIL import Image
from IPython.display import display

from clarity.pytorch.inception_v1 import googlenet
from lucid.misc.io import show, load, save
from lucid.modelzoo.other_models import InceptionV1

# get a test image
img_url = (
    "https://lucid-static.storage.googleapis.com/building-blocks/examples/dog_cat.png"
)
img_tf = load(img_url)
img_pt = torch.as_tensor(img_tf.transpose(2, 0, 1))[None, ...]
img_pil = Image.open(requests.get(img_url, stream=True).raw)

# instantiate ported model
net = googlenet(pretrained=True)

# get predictions
out = net(img_pt)
logits = out.detach().numpy()[0]
top_k = np.argsort(-logits)[:5]

# load labels
labels = load(InceptionV1.labels_path, split=True)

# show predictions
for i, k in enumerate(top_k):
    prediction = logits[k]
    label = labels[k]
    print(f"{i}: {label} ({prediction*100:.2f}%)")

# transforms


# def build_grid(source_size, target_size):
#     k = float(target_size) / float(source_size)
#     direct = (
#         torch.linspace(0, k, target_size)
#         .unsqueeze(0)
#         .repeat(target_size, 1)
#         .unsqueeze(-1)
#     )
#     full = torch.cat([direct, direct.transpose(1, 0)], dim=2).unsqueeze(0)
#     return full.cuda()


# def random_crop_grid(x, grid):
#     d = x.size(2) - grid.size(1)
#     grid = grid.repeat(x.size(0), 1, 1, 1).cuda()
#     # Add random shifts by x
#     grid[:, :, :, 0] += torch.FloatTensor(x.size(0)).cuda().random_(0, d).unsqueeze(
#         -1
#     ).unsqueeze(-1).expand(-1, grid.size(1), grid.size(2)) / x.size(2)
#     # Add random shifts by y
#     grid[:, :, :, 1] += torch.FloatTensor(x.size(0)).cuda().random_(0, d).unsqueeze(
#         -1
#     ).unsqueeze(-1).expand(-1, grid.size(1), grid.size(2)) / x.size(2)
#     return grid


# # We want to crop an 80x80 image randomly for our batch
# # Building central crop of 80 pixel size
# grid_source = build_grid(224, 80)
# # Make random shift for each batch
# grid_shifted = random_crop_grid(batch, grid_source)
# # Sample using grid sample
# sampled_batch = F.grid_sample(batch, grid_shifted)


from clarity.pytorch.transform import RandomSpatialJitter, RandomUpsample

# crop = torchvision.transforms.RandomCrop(
#     224, padding=34, pad_if_needed=True, padding_mode="reflect"
# )
jitter = RandomSpatialJitter(16)
ups = RandomUpsample()
for i in range(10):
    cropped = ups(img_pt)
    show(cropped.numpy()[0].transpose(1, 2, 0))
    # display(cropped)


# result = param().cpu().detach().numpy()[0].transpose(1, 2, 0)
# loss_curve = objective.history

# 2019-11-21 notes from Pytorch team
# Set up model
# net = googlenet(pretrained=True)
# parameterization = Image() # TODO: make size adjustable, currently hardcoded
# input_image = parameterization()

# writer = SummaryWriter()
# writer.add_graph(net, (input_image,))
# writer.close()

# Specify target module / "objective"
# target_module = net.mixed3b._pool_reduce[1]
# target_channel = 54
# hook = OutputHook(target_module) # TODO: investigate detach on rerun
# parameterization = Image() # TODO: make size adjustable, currently hardcoded
# optimizer = optim.Adam(parameterization.parameters, lr=0.025)

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# net = net.to(device)
# parameterization = parameterization.to(device)
# for i in range(1000):
#     optimizer.zero_grad()

#     # forward pass through entire net
#     input_image = parameterization()
#     with suppress(AbortForwardException):
#         _ = net(input_image.to(device))

#     # activations were stored during forward pass
#     assert hook.saved_output is not None
#     loss = -hook.saved_output[:, target_channel, :, :].sum()

#     loss.backward()
#     optimizer.step()

#     if i % 100 == 0:
#         print("Loss: ", -loss.cpu().detach().numpy())
#         url = show(
#             parameterization.raw_image.cpu()
#             .detach()
#             .numpy()[0]
#             .transpose(1, 2, 0)
#         )

# traced_net = torch.jit.trace(net, example_inputs=(input_image,))
# print(traced_net.graph)
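
For reference, the commented-out loop above depends on `OutputHook` and `AbortForwardException` from the clarity package; the same idea can be sketched with a standard PyTorch forward hook. This version optimizes the raw pixels directly rather than going through an `Image()` parameterization:

```python
import torch
import torch.nn as nn


def maximize_channel(net: nn.Module, target: nn.Module, channel: int,
                     image: torch.Tensor, n_steps: int = 1000) -> torch.Tensor:
    saved = {}

    def hook(module, inputs, output):
        saved["out"] = output  # stash activations during the forward pass

    handle = target.register_forward_hook(hook)
    for p in net.parameters():
        p.requires_grad_(False)  # only the image should receive updates
    image = image.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([image], lr=0.025)
    try:
        for i in range(n_steps):
            optimizer.zero_grad()
            net(image)
            loss = -saved["out"][:, channel].sum()
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                print(f"step {i}: loss {loss.item():.3f}")
    finally:
        handle.remove()
    return image.detach()
```

Something like `maximize_channel(net, net.mixed3b, 54, img_pt)` would then correspond to the target module and channel used in the notes above.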
1 change: 1 addition & 0 deletions captum/optim/io/__init__.py
@@ -0,0 +1 @@
from .io import show
13 changes: 13 additions & 0 deletions captum/optim/io/fixtures.py
@@ -0,0 +1,13 @@
import torch

# TODO: use imageio to redo load and avoid TF dependency
from lucid.misc.io import load

DOG_CAT_URL = (
    "https://lucid-static.storage.googleapis.com/building-blocks/examples/dog_cat.png"
)


def image(url: str = DOG_CAT_URL):
    img_np = load(url)
    return torch.as_tensor(img_np.transpose(2, 0, 1))
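
One possible shape for the TODO above, swapping lucid's `load` for `imageio`. This is a sketch that assumes an RGB image and lucid's float-in-[0, 1] convention; `image_via_imageio` is a placeholder name:

```python
import imageio
import numpy as np
import torch


def image_via_imageio(url: str = DOG_CAT_URL) -> torch.Tensor:
    img_np = np.asarray(imageio.imread(url), dtype=np.float32) / 255.0  # HWC in [0, 1]
    return torch.as_tensor(img_np.transpose(2, 0, 1))  # CHW, matching image() above
```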
22 changes: 22 additions & 0 deletions captum/optim/io/formatters.py
@@ -0,0 +1,22 @@
from io import BytesIO

import torch
from torchvision import transforms

from IPython import display, get_ipython


def tensor_jpeg(tensor: torch.Tensor):
    if tensor.dim() == 3:
        pil_image = transforms.ToPILImage()(tensor.cpu().detach()).convert("RGB")
        buffer = BytesIO()
        pil_image.save(buffer, format="jpeg")
        data = buffer.getvalue()
        return data
    else:
        # Not an image-shaped tensor; return None so IPython falls back to the repr.
        return None


def register_formatters():
    jpeg_formatter = get_ipython().display_formatter.formatters["image/jpeg"]
    jpeg_formatter.for_type(torch.Tensor, tensor_jpeg)
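
With the formatter registered, a bare CHW tensor at the end of a notebook cell renders inline as a JPEG instead of printing its repr:

```python
from captum.optim.io.fixtures import image
from captum.optim.io.formatters import register_formatters

register_formatters()
image()  # a (3, H, W) float tensor; now displays as an image in the notebook
```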
11 changes: 11 additions & 0 deletions captum/optim/io/io.py
@@ -0,0 +1,11 @@
# TODO: redo show using display or register handler for jupyter display directly
# maybe we could even have subtypes of tensors that are "ImageTensors" or "ActivationTensors" etc
from lucid.misc.io import show as lucid_show


def show(thing):
    if len(thing.shape) == 3:
        numpy_thing = thing.cpu().detach().numpy().transpose(1, 2, 0)
    elif len(thing.shape) == 4:
        numpy_thing = thing.cpu().detach().numpy()[0].transpose(1, 2, 0)
    else:
        raise ValueError(f"Expected a CHW or NCHW tensor, got shape {tuple(thing.shape)}")
    lucid_show(numpy_thing)
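
One lucid-free direction for the TODO above is to go through PIL and `IPython.display` directly. A notebook-only sketch; `show_via_display` is a placeholder name:

```python
import torch
from IPython.display import display
from torchvision import transforms


def show_via_display(thing: torch.Tensor) -> None:
    if thing.dim() == 4:
        thing = thing[0]  # drop the batch dimension, as show() does
    display(transforms.ToPILImage()(thing.cpu().detach()))
```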
1 change: 1 addition & 0 deletions captum/optim/models/__init__.py
@@ -0,0 +1 @@
from .inception_v1 import googlenet
119 changes: 119 additions & 0 deletions captum/optim/models/conv2d.py
@@ -0,0 +1,119 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0


def _get_padding(kernel_size, stride=1, dilation=1, **_):
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding


def _calc_same_pad(i, k, s, d):
    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)


def _split_channels(num_chan, num_groups):
    split = [num_chan // num_groups for _ in range(num_groups)]
    split[0] += num_chan - sum(split)
    return split


class Conv2dSame(nn.Conv2d):
    """Tensorflow-like 'SAME' convolution wrapper for 2D convolutions."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2dSame, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation,
            groups, bias)

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        pad_h = _calc_same_pad(ih, kh, self.stride[0], self.dilation[0])
        pad_w = _calc_same_pad(iw, kw, self.stride[1], self.dilation[1])
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return F.conv2d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


# def conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
#     padding = kwargs.pop('padding', '')
#     kwargs.setdefault('bias', False)
#     if isinstance(padding, str):
#         # for any string padding, the padding will be calculated for you, one of three ways
#         padding = padding.lower()
#         if padding == 'same':
#             # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
#             if _is_static_pad(kernel_size, **kwargs):
#                 # static case, no extra overhead
#                 padding = _get_padding(kernel_size, **kwargs)
#                 return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
#             else:
#                 # dynamic padding
#                 return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
#         elif padding == 'valid':
#             # 'VALID' padding, same as padding=0
#             return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs)
#         else:
#             # Default to PyTorch style 'same'-ish symmetric padding
#             padding = _get_padding(kernel_size, **kwargs)
#             return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
#     else:
#         # padding was specified as a number or pair
#         return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)


# class MixedConv2d(nn.Module):
#     """ Mixed Grouped Convolution
#     Based on MDConv and GroupedConv in MixNet impl:
#     https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
#     """

#     def __init__(self, in_channels, out_channels, kernel_size=3,
#                  stride=1, padding='', dilated=False, depthwise=False, **kwargs):
#         super(MixedConv2d, self).__init__()

#         kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
#         num_groups = len(kernel_size)
#         in_splits = _split_channels(in_channels, num_groups)
#         out_splits = _split_channels(out_channels, num_groups)
#         for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
#             d = 1
#             # FIXME make compat with non-square kernel/dilations/strides
#             if stride == 1 and dilated:
#                 d, k = (k - 1) // 2, 3
#             conv_groups = out_ch if depthwise else 1
#             # use add_module to keep key space clean
#             self.add_module(
#                 str(idx),
#                 conv2d_pad(
#                     in_ch, out_ch, k, stride=stride,
#                     padding=padding, dilation=d, groups=conv_groups, **kwargs)
#             )
#         self.splits = in_splits

#     def forward(self, x):
#         x_split = torch.split(x, self.splits, 1)
#         x_out = [c(x) for x, c in zip(x_split, self._modules.values())]
#         x = torch.cat(x_out, 1)
#         return x


# # helper method
# def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
#     assert 'groups' not in kwargs  # only use 'depthwise' bool arg
#     if isinstance(kernel_size, list):
#         # We're going to use only lists for defining the MixedConv2d kernel groups,
#         # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
#         return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
#     else:
#         depthwise = kwargs.pop('depthwise', False)
#         groups = out_chs if depthwise else 1
#         return conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
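
`_calc_same_pad` implements TF's SAME rule, `total_pad = max((ceil(i/s) - 1)*s + (k - 1)*d + 1 - i, 0)`, so `Conv2dSame` always produces output size `ceil(i/s)` per spatial dim. A quick sanity check (not part of the PR):

```python
import torch

conv = Conv2dSame(3, 8, kernel_size=7, stride=2)
out = conv(torch.randn(1, 3, 224, 224))
assert out.shape == (1, 8, 112, 112)  # ceil(224 / 2) == 112, as TF 'SAME' gives
```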