Skip to content

Commit a92c165

Browse files
committed
Add helper functions: _plot_images and _sample_image
1 parent 7f43913 commit a92c165

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

torchvision/utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
import torch
21
import math
2+
import numpy as np
3+
import torch
4+
35
irange = range
46

57

@@ -107,3 +109,25 @@ def save_image(tensor, fp, nrow=8, padding=2,
107109
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
108110
im = Image.fromarray(ndarr)
109111
im.save(fp, format=format)
112+
113+
114+
def _sample_image():
115+
from PIL import Image
116+
# TODO properly load the image
117+
return Image.open("/tmp/grace_hopper_517x606.jpg")
118+
119+
120+
def _plot_images(*imgs):
121+
import matplotlib.pyplot as plt
122+
import matplotlib
123+
124+
n = len(imgs)
125+
fig, axes = plt.subplots(1, n, figsize=(n * 2, 2))
126+
if isinstance(axes, matplotlib.axes.Axes):
127+
axes = np.array(axes)
128+
129+
for img, ax in zip(imgs, axes.flat):
130+
ax.imshow(img)
131+
ax.axis("off")
132+
133+
return fig

0 commit comments

Comments
 (0)