Skip to content
2 changes: 2 additions & 0 deletions captum/optim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from captum.optim._utils.image import atlas # noqa: F401
from captum.optim._utils.image.common import ( # noqa: F401
hue_to_rgb,
make_grid_image,
nchannels_to_rgb,
save_tensor_as_image,
show,
Expand All @@ -25,6 +26,7 @@
"circuits",
"models",
"reducer",
"make_grid_image",
"atlas",
"hue_to_rgb",
"nchannels_to_rgb",
Expand Down
66 changes: 58 additions & 8 deletions captum/optim/_param/image/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@ def open(cls, path: str, scale: float = 255.0, mode: str = "RGB") -> "ImageTenso
path (str): A URL or filepath to an image.
scale (float, optional): The image scale to use.
Default: 255.0
mode (str, optional): The image loading mode to use.
mode (str, optional): The image loading mode / colorspace to use.
Default: "RGB"

Returns:
x (ImageTensor): An `ImageTensor` instance.
"""
if path.startswith("https://") or path.startswith("http://"):
response = requests.get(path, stream=True)
headers = {"User-Agent": "Captum"}
response = requests.get(path, stream=True, headers=headers)
img = Image.open(response.raw)
else:
img = Image.open(path)
Expand Down Expand Up @@ -95,7 +96,12 @@ def __torch_function__(
return super().__torch_function__(func, types, args, kwargs)

def show(
self, figsize: Optional[Tuple[int, int]] = None, scale: float = 255.0
self,
figsize: Optional[Tuple[int, int]] = None,
scale: float = 255.0,
images_per_row: Optional[int] = None,
padding: int = 2,
pad_value: float = 0.0,
) -> None:
"""
Display an `ImageTensor`.
Expand All @@ -107,10 +113,34 @@ def show(
scale (float, optional): Value to multiply the `ImageTensor` by so that
its value range is [0-255] for display.
Default: 255.0
"""
show(self, figsize=figsize, scale=scale)
images_per_row (int, optional): The number of images per row to use for the
grid image. Default is set to None for no grid image creation.
Default: None
padding (int, optional): The amount of padding between images in the grid
image. This parameter only has an effect if `images_per_row` is not None.
Default: 2
pad_value (float, optional): The value to use for the padding. This
parameter only has an effect if `images_per_row` is not None.
Default: 0.0
"""
show(
self,
figsize=figsize,
scale=scale,
images_per_row=images_per_row,
padding=padding,
pad_value=pad_value,
)

def export(self, filename: str, scale: float = 255.0) -> None:
def export(
self,
filename: str,
scale: float = 255.0,
mode: Optional[str] = None,
images_per_row: Optional[int] = None,
padding: int = 2,
pad_value: float = 0.0,
) -> None:
"""
Save an `ImageTensor` as an image file.

Expand All @@ -121,8 +151,28 @@ def export(self, filename: str, scale: float = 255.0) -> None:
scale (float, optional): Value to multiply the `ImageTensor` by so that
its value range is [0-255] for saving.
Default: 255.0
"""
save_tensor_as_image(self, filename=filename, scale=scale)
mode (str, optional): A PIL / Pillow supported colorspace. Default is
set to None for automatic RGB / RGBA detection and usage.
Default: None
images_per_row (int, optional): The number of images per row to use for the
grid image. Default is set to None for no grid image creation.
Default: None
padding (int, optional): The amount of padding between images in the grid
image. This parameter only has an effect if `images_per_row` is not None.
Default: 2
pad_value (float, optional): The value to use for the padding. This
parameter only has an effect if `images_per_row` is not None.
Default: 0.0
"""
save_tensor_as_image(
self,
filename=filename,
scale=scale,
mode=mode,
images_per_row=images_per_row,
padding=padding,
pad_value=pad_value,
)


class InputParameterization(torch.nn.Module):
Expand Down
116 changes: 109 additions & 7 deletions captum/optim/_utils/image/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -13,25 +13,100 @@
print("The Pillow/PIL library is required to use Captum's Optim library")


def make_grid_image(
    tiles: Union[torch.Tensor, List[torch.Tensor]],
    images_per_row: int = 4,
    padding: int = 2,
    pad_value: float = 0.0,
) -> torch.Tensor:
    """
    Make grids from NCHW Image tensors in a way similar to torchvision.utils.make_grid,
    but without any channel duplication or creation behaviour.

    Args:

        tiles (torch.Tensor or list of torch.Tensor): A stack of NCHW image tensors or
            a list of NCHW image tensors to create a grid from. All tiles must share
            the same channel count, height and width.
        images_per_row (int, optional): The maximum number of images to place in each
            row of the grid image.
            Default: 4
        padding (int, optional): The amount of padding, in pixels, between images in
            the grid image and around its border.
            Default: 2
        pad_value (float, optional): The value to use for the padding.
            Default: 0.0

    Returns:
        grid_img (torch.Tensor): The full grid image, with a shape of
            [1, C, grid_height, grid_width] and the same dtype and device as `tiles`.
    """
    assert padding >= 0 and images_per_row >= 1
    if isinstance(tiles, (list, tuple)):
        assert all(t.device == tiles[0].device for t in tiles)
        assert all(t.dim() == 4 for t in tiles)
        tiles = torch.cat(tiles, 0)
    assert tiles.dim() == 4

    B, C, H, W = tiles.shape
    # An empty batch would otherwise surface as an opaque ZeroDivisionError below.
    assert B >= 1, "make_grid_image requires at least one image tile."

    x_rows = min(images_per_row, B)  # Number of tiles placed side by side.
    y_rows = int(math.ceil(float(B) / x_rows))  # Number of tile rows needed.

    # Every tile is preceded by one strip of padding, plus one closing strip at the
    # far edge of each axis.
    base_height = ((H + padding) * y_rows) + padding
    base_width = ((W + padding) * x_rows) + padding

    # Build the canvas with the dtype of the tiles so that pasting them in does not
    # silently convert (e.g. float64 -> float32) the image data.
    grid_img = torch.full(
        (1, C, base_height, base_width),
        pad_value,
        dtype=tiles.dtype,
        device=tiles.device,
    )

    # Tiles are laid out in row-major order; unused cells in the final row keep
    # the padding value.
    for n in range(B):
        y, x = divmod(n, x_rows)
        y_idx = ((H + padding) * y) + padding
        x_idx = ((W + padding) * x) + padding
        grid_img[..., y_idx : y_idx + H, x_idx : x_idx + W] = tiles[n : n + 1]
    return grid_img


def show(
x: torch.Tensor, figsize: Optional[Tuple[int, int]] = None, scale: float = 255.0
x: torch.Tensor,
figsize: Optional[Tuple[int, int]] = None,
scale: float = 255.0,
images_per_row: Optional[int] = None,
padding: int = 2,
pad_value: float = 0.0,
) -> None:
"""
Show CHW & NCHW tensors as an image.

Args:

x (torch.Tensor): The tensor you want to display as an image.
figsize (Tuple[int, int], optional): height & width to use
for displaying the image figure.
scale (float): Value to multiply the input tensor by so that
its value range is [0-255] for display.
images_per_row (int, optional): The number of images per row to use for the
grid image. Default is set to None for no grid image creation.
Default: None
padding (int, optional): The amount of padding between images in the grid
image. This parameter only has an effect if images_per_row is not None.
Default: 2
pad_value (float, optional): The value to use for the padding. This parameter
only has an effect if images_per_row is not None.
Default: 0.0
"""

if x.dim() not in [3, 4]:
raise ValueError(
f"Incompatible number of dimensions. x.dim() = {x.dim()}; should be 3 or 4."
)
x = torch.cat([t[0] for t in x.split(1)], dim=2) if x.dim() == 4 else x
if images_per_row is not None:
x = make_grid_image(
x, images_per_row=images_per_row, padding=padding, pad_value=pad_value
)[0, ...]
else:
x = torch.cat([t[0] for t in x.split(1)], dim=2) if x.dim() == 4 else x
x = x.clone().cpu().detach().permute(1, 2, 0) * scale
if figsize is not None:
plt.figure(figsize=figsize)
Expand All @@ -40,25 +115,52 @@ def show(
plt.show()


def save_tensor_as_image(
    x: torch.Tensor,
    filename: str,
    scale: float = 255.0,
    mode: Optional[str] = None,
    images_per_row: Optional[int] = None,
    padding: int = 2,
    pad_value: float = 0.0,
) -> None:
    """
    Save RGB & RGBA image tensors with a shape of CHW or NCHW as images.

    Args:

        x (torch.Tensor): The tensor you want to save as an image.
        filename (str): The filename to use when saving the image.
        scale (float, optional): Value to multiply the input tensor by so that
            its value range is [0-255] for saving.
            Default: 255.0
        mode (str, optional): A PIL / Pillow supported colorspace. Default is
            set to None for automatic RGB / RGBA detection and usage.
            Default: None
        images_per_row (int, optional): The number of images per row to use for the
            grid image. Default is set to None for no grid image creation, in which
            case NCHW inputs are concatenated side by side along the width.
            Default: None
        padding (int, optional): The amount of padding between images in the grid
            image. This parameter only has an effect if `images_per_row` is not None.
            Default: 2
        pad_value (float, optional): The value to use for the padding. This parameter
            only has an effect if `images_per_row` is not None.
            Default: 0.0

    Raises:
        ValueError: If `x` is not a 3 or 4 dimensional tensor.
    """

    if x.dim() not in [3, 4]:
        raise ValueError(
            f"Incompatible number of dimensions. x.dim() = {x.dim()}; should be 3 or 4."
        )
    if images_per_row is not None:
        # make_grid_image only accepts NCHW input, so promote a single CHW image to a
        # batch of one instead of failing its dimension check.
        batch = x.unsqueeze(0) if x.dim() == 3 else x
        x = make_grid_image(
            batch, images_per_row=images_per_row, padding=padding, pad_value=pad_value
        )[0, ...]
    elif x.dim() == 4:
        # Without a grid, lay the batch out as a single horizontal strip.
        x = torch.cat([t[0] for t in x.split(1)], dim=2)
    # CHW -> HWC, as expected by Pillow.
    x = x.clone().cpu().detach().permute(1, 2, 0) * scale
    if mode is None:
        mode = "RGB" if x.shape[2] == 3 else "RGBA"
    # Clamp before the uint8 cast so out-of-range values saturate instead of
    # wrapping around (e.g. 256 -> 0).
    pixels = x.clamp(0, 255).numpy().astype(np.uint8)
    im = Image.fromarray(pixels, mode=mode)
    im.save(filename)


Expand Down
Loading