Description
I am trying to use TensorBoard to log images generated by a GAN at the end of each epoch. To do that, I set up a callback that logs an image to TensorBoard, but TensorBoard only shows the image that was generated after the first epoch. I am also trying to log losses such as the gradient penalty and the min-max loss, and TensorBoard shows only some of the losses I am logging. I am not sure if I am doing something wrong or if there is a bug somewhere in Lightning.
The image saving callback code:
```python
import io

import PIL.Image
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from pytorch_lightning import Callback
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor


class GenerateCollage(Callback):
    def __init__(self, config):
        super().__init__()
        # ds_fac and utils are project-specific helpers (dataset factory and tensor-to-uint8 conversion).
        test_set = ds_fac[config['test_dataset_cfg']['type']](config['test_dataset_cfg'])
        self.test_data_loader = DataLoader(test_set,
                                           batch_size=config['test_dataloader_cfg']['batch_size'],
                                           shuffle=config['test_dataloader_cfg']['shuffle'],
                                           num_workers=config['test_dataloader_cfg']['num_workers'],
                                           drop_last=config['test_dataloader_cfg']['drop_last'],
                                           pin_memory=config['test_dataloader_cfg']['pin_memory'])
        self.test_every = config['test_every']

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        pl_module.gen.eval()
        test_batch = next(iter(self.test_data_loader))
        real = test_batch['real'].to(pl_module.device)
        noisy = test_batch['noisy'].to(pl_module.device)
        collage_size = 10

        # One deterministic pass (z=False) for the "Denoised" column.
        with torch.no_grad():
            g_out = pl_module.gen(y=noisy, z=False, encoder_assistance=True)
        if type(g_out) is list:
            g_out = g_out[0]
        g_out = g_out.to(torch.device("cpu"))

        all_imgs_to_show = [utils.tensor2uint(elem.to(torch.device("cpu"))) for elem in real]
        all_imgs_to_show.extend([utils.tensor2uint(elem.to(torch.device("cpu"))) for elem in noisy])
        all_imgs_to_show.extend([utils.tensor2uint(elem) for elem in g_out])

        # Several stochastic passes (z=True) for the "Generated" columns.
        for i in range(collage_size):
            with torch.no_grad():
                out = pl_module.gen(y=noisy, z=True, encoder_assistance=True)
            if type(out) is list:
                out = out[0]
            out = out.to(torch.device("cpu"))
            all_imgs_to_show.extend([utils.tensor2uint(elem) for elem in out])

        titles = ["Clean", "Noisy", "Denoised"] + ["Generated"] * collage_size
        num_cols = len(titles)
        batch_size = noisy.shape[0]

        fig = plt.figure(figsize=(num_cols * 2, batch_size * 2))
        grid = ImageGrid(fig, 111, nrows_ncols=(batch_size, num_cols), axes_pad=0.02, direction="column")
        for i, (ax, im) in enumerate(zip(grid, all_imgs_to_show)):
            # Iterating over the grid returns the Axes.
            ax.imshow(im)
            ax.set_axis_off()
        for ax, title in zip(grid.axes_row[0], titles):
            ax.set_title(title, size='xx-large')
        fig.set_tight_layout(True)

        # Render the figure into a tensor and log it to TensorBoard.
        buf = io.BytesIO()
        fig.savefig(buf)
        buf.seek(0)
        image = PIL.Image.open(buf)
        image = ToTensor()(image)
        buf.close()
        pl_module.logger.experiment.add_image('generated_collage', image, trainer.batch_idx)
        pl_module.gen.train()
```
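For context, here is a stripped-down sketch of the kind of per-epoch image logging I expect to work. It just writes a dummy CHW image under a fixed tag every epoch, using `trainer.current_epoch` as the TensorBoard step (the real callback above passes `trainer.batch_idx` instead); the image tensor is only a placeholder:

```python
import torch
from pytorch_lightning import Callback


class LogDummyImage(Callback):
    """Minimal sketch: log one image per epoch under a fixed tag."""

    def on_train_epoch_end(self, trainer, pl_module, *args):
        # Placeholder CHW image; in the real callback this is the rendered collage.
        image = torch.rand(3, 64, 64)
        # Use the epoch index as the step so TensorBoard keeps one entry per epoch.
        pl_module.logger.experiment.add_image('generated_collage', image, trainer.current_epoch)
```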
The way I am returning the losses, for example from a function that calculates the discriminator loss (see the `__call__` method):
```python
from collections import OrderedDict

import torch
import pytorch_lightning as pl


class DiscWGAN(pl.LightningModule):
    def __init__(self, config, gen, disc):
        super().__init__()
        self.gp_reg = config['gp_reg']
        self.gen = gen
        self.disc = disc

    def __call__(self, real, gen_input, step):
        # Generate fake samples without tracking generator gradients.
        with torch.no_grad():
            fake = self.gen(y=gen_input, z=True, encoder_assistance=True)
        if type(fake) is list:
            fake = fake[0]

        batch_size = real.shape[0]
        num_channels = real.shape[1]
        patch_h = real.shape[2]
        patch_w = real.shape[3]

        # Gradient penalty calculation
        alpha = torch.rand((batch_size, 1), device=self.device)
        alpha = alpha.expand(batch_size, int(real.nelement() / batch_size)).contiguous()
        alpha = alpha.view(batch_size, num_channels, patch_h, patch_w)
        interpolates = alpha * real.detach() + (1 - alpha) * fake.detach()
        interpolates.requires_grad_(True)
        disc_interpolates = self.disc(x=interpolates)
        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                        grad_outputs=torch.ones(disc_interpolates.size(), device=self.device),
                                        create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradients_norm = gradients.norm(2, dim=1)
        gp = ((gradients_norm - 1) ** 2).mean()

        # WGAN critic (min-max) loss.
        d_out_fake_mean = torch.mean(self.disc(x=fake))
        d_out_real_mean = torch.mean(self.disc(x=real))
        minmax_loss = d_out_fake_mean - d_out_real_mean

        return OrderedDict({"log": {"disc_gp": gp, "disc_minmax": minmax_loss},
                            "loss": minmax_loss + self.gp_reg * gp})
```
I tried returning a regular dictionary, using `self.log` instead, using `self.logger.experiment.add_scalar` instead, etc., with no luck.
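For example, the `self.log` variant I tried looked roughly like this (a simplified sketch; the losses are replaced by dummy tensors standing in for the values computed above):

```python
import torch
import pytorch_lightning as pl


class DiscWGANLogging(pl.LightningModule):
    """Sketch of the self.log variant (losses here are dummy tensors)."""

    def __init__(self, gp_reg=10.0):
        super().__init__()
        self.gp_reg = gp_reg

    def training_step(self, batch, batch_idx):
        gp = torch.tensor(0.0)           # stands in for the gradient penalty computed above
        minmax_loss = torch.tensor(0.0)  # stands in for the critic min-max loss computed above
        # Log each scalar explicitly instead of returning it under a "log" key.
        self.log("disc_gp", gp, on_step=True, on_epoch=True)
        self.log("disc_minmax", minmax_loss, on_step=True, on_epoch=True)
        return minmax_loss + self.gp_reg * gp
```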
I also noticed that LightningModules nested inside another LightningModule do not share the parent's logger, which might be a good feature to add (for example, in my code I build the loss function as a LightningModule and initialize it inside another module).
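To illustrate the nesting I mean (a minimal sketch, the class names are made up):

```python
import pytorch_lightning as pl


class InnerLoss(pl.LightningModule):
    """Loss implemented as a LightningModule (like DiscWGAN above)."""

    def __call__(self, pred, target):
        # self.logger here is not the outer module's logger, so anything
        # logged from inside this class does not reach TensorBoard.
        return (pred - target).abs().mean()


class OuterModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.loss_fn = InnerLoss()  # nested LightningModule
```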
I am running CUDA 10.2 with TensorBoard 2.3.
Thanks!