
Different bitrates computed using the saved file size and computed using the likelihoods from the entropy model #236


Description

@Reza-Asiyabi

I'm using the EntropyBottleneck module in a factorized prior network. During training, the aux_loss and bpp_loss decrease, but when I use the trained network to compress and save the compressed images, the bitrate computed from the saved file size is much larger than the bitrate computed from the likelihoods of the entropy model (about 2.7 bpp vs. 1.1 bpp).
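For reference, these are the two bitrate definitions I'm comparing (illustrative sketch only; the tensor and the file name are stand-ins, not my actual data):

import os
import torch

num_pixels = 256 * 256
likelihoods = torch.full((1, 128, 16, 16), 0.5)  # stand-in entropy-model output: each symbol costs -log2(0.5) = 1 bit
bpp_from_likelihoods = (-torch.log2(likelihoods).sum() / num_pixels).item()   # 0.5 bpp for this stand-in
# bpp_from_file = os.path.getsize("compressed_patch.bin") * 8 / num_pixels    # bits actually written per input pixel
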
The network architecture is below:

import torch
import torch.nn as nn
from compressai.compressai.entropy_models import EntropyBottleneck
from compressai.compressai.models import CompressionModel
from compressai.compressai.layers import GDN
    
class FactorizedPrior_Net1(CompressionModel):
    def __init__(self, in_channels=1, out_channels=1, layer_num=3, intermediate_channels1=128, intermediate_channels2=192):
        super().__init__()
        self.layer_num = layer_num
        N = intermediate_channels1
        M = intermediate_channels2

        Encoder_layers = []
        for i in range(self.layer_num - 2):
            Encoder_layers.append( nn.Conv2d(in_channels=N, out_channels=N, kernel_size=5, stride=2, padding=2))
            Encoder_layers.append(GDN(N))

        self.Encoder = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=N, kernel_size=5, stride=2, padding=2),
            GDN(N),
            *Encoder_layers,
            nn.Conv2d(in_channels=N, out_channels=M, kernel_size=5, stride=2, padding=2),
        )

        self.entropy_bottleneck = EntropyBottleneck(channels=M)

        Decoder_layers = []
        for i in range(self.layer_num - 2):
            Decoder_layers.append(nn.ConvTranspose2d(in_channels=N, out_channels=N, kernel_size=5, stride=2, output_padding=1, padding=2))
            Decoder_layers.append(GDN(N, inverse=True))

        self.Decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=M, out_channels=N, kernel_size=5, stride=2, output_padding=1, padding=2),
            GDN(N, inverse=True),
            *Decoder_layers,
            nn.ConvTranspose2d(in_channels=N, out_channels=out_channels, kernel_size=5, stride=2, output_padding=1, padding=2),
        )

    def forward(self, x):
        y = self.Encoder(x)
        y_hat, y_likelihoods = self.entropy_bottleneck(y)
        x_hat = self.Decoder(y_hat)

        return {"x_hat": x_hat, "likelihoods": {'y': y_likelihoods}}

    def compress(self, x):
        y = self.Encoder(x)
        y_strings = self.entropy_bottleneck.compress(y)
        return {"EB_strings": y_strings, "EB_shape": y.size()[-2:]}

    def decompress(self, strings, shape):
        assert isinstance(strings, list) and len(strings) == 1
        y_hat = self.entropy_bottleneck.decompress(strings[0], shape)
        x_hat = self.Decoder(y_hat)
        return {"rec_image": x_hat}

This is the code for training the network:

import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
from datetime import datetime
from tqdm import tqdm
import logging
import torch.optim.lr_scheduler as lr_scheduler
from torchvision.utils import make_grid
from pytorch_msssim import ms_ssim
from FFT_Exercise_Paper.ComplexValuedNN.SAR_Data_Compression.RV_CAE.Models import FactorizedPrior_Net1
from compressai.compressai.losses import RateDistortionLoss
from compressai.compressai.optimizers import net_aux_optimizer
from compressai.compressai.models.utils import update_registered_buffers


##### Parameters
in_channels = 1
out_channels = 1
intermediate_channels1 = 64
intermediate_channels2 = 128
net1_layers_num = 4 #should be more than 2

test_freq = 50
save_freq = 1
total_step = 0
epoch_step = 0
best_testSQNR = 1

total_epochs = 5
train_batch_size = 8
test_batch_size = 1
lr = 0.0001
aux_lr = 0.001
use_lr_scheduler = True
if use_lr_scheduler:
    lr_scheduler_step = 1
clip_max_norm = 1
loss_lmbda = 2000
distortion_metric = "mse"
pretrained = False
if pretrained:
    checkpoint_path = 'logs/2023-05-08_10_36_39/Models/Model_epoch_31.pth.tar'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

description = 'Net1_Normal'
save_path = os.path.join('D:/Reza/My Python Projects/New Project/FFT_Exercise_Paper/ComplexValuedNN/SAR_Data_Compression/RV_CAE/Logs/', '{}_Layers {}_{}_{}'.format(datetime.now().strftime('%Y-%m-%d_%H_%M_%S'),intermediate_channels1, intermediate_channels2, description))
os.makedirs(os.path.join(save_path, 'Models/'), exist_ok=True)

class AverageMeter:
    """Compute running average."""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def configure_optimizers(net, lr, aux_lr):
    """Separate parameters for the main optimizer and the auxiliary optimizer.
    Return two optimizers"""
    conf = {
        "net": {"type": "Adam", "lr": lr},
        "aux": {"type": "Adam", "lr": aux_lr},
    }
    optimizer = net_aux_optimizer(net, conf)
    return optimizer["net"], optimizer["aux"]

def test(net1, test_dataloader, device, criterion, distortion_metric, epoch, total_step, best_testSQNR):
    net1.eval()

    loss = AverageMeter()
    SQNR = AverageMeter()
    bpp_loss = AverageMeter()
    distortion_loss = AverageMeter()
    aux_loss = AverageMeter()

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_dataloader):

            input = batch[:, :, 0, :, :].to(device)
            output = net1(input)

            out_criterion = criterion(output, input)
            SQNR_temp = torch.sum(torch.square(abs(input)))/torch.sum(torch.square(abs(input - output['x_hat'])))

            loss.update(out_criterion['loss'])
            SQNR.update(10*torch.log10(SQNR_temp))
            bpp_loss.update(out_criterion['bpp_loss'])
            if distortion_metric == 'mse':
                distortion_loss.update(out_criterion['mse_loss'])
            elif distortion_metric == 'ms_ssim':
                distortion_loss.update(out_criterion['ms_ssim_loss'])
            aux_loss.update(net1.entropy_bottleneck.loss())

            if batch_idx%1==0:
                try:
                    grid_imgs = torch.concat((grid_imgs, torch.abs(input), torch.abs(output['x_hat'])), dim=0)
                except:
                    grid_imgs = torch.concat((torch.abs(input), torch.abs(output['x_hat'])), dim=0)

        grid = make_grid(grid_imgs, nrow=4)
        logger.info("Test on test dataset: epoch-{}, step-{}".format(epoch, total_step))

        if event_writer is not None:
            logger.info("Add tensorboard for test dataset---epoch:{}-Step:{}".format(epoch, total_step))
            if distortion_metric == "mse":
                event_writer.add_scalar("Test MSE_avg", distortion_loss.avg.item(), total_step)
            elif distortion_metric == "ms_ssim":
                event_writer.add_scalar("Test MSE_avg", distortion_loss.avg.item(), total_step)
            event_writer.add_scalar("Test SQNR_avg (dB)", SQNR.avg.item(), total_step)
            event_writer.add_scalar("Test bpp_loss_avg (dB)", bpp_loss.avg.item(), total_step)
            event_writer.add_scalar("Test loss_avg (dB)", loss.avg.item(), total_step)
            event_writer.add_scalar("Test aux_loss_avg (dB)", aux_loss.avg.item(), total_step)
            event_writer.add_image('test example', grid, global_step=total_step)

    if SQNR.avg.item() > best_testSQNR:
        best_testSQNR = SQNR.avg.item()
        net1.entropy_bottleneck.update()
        checkpoint_dict = {
            "epoch": epoch,
            "state_dict": net1.state_dict(),
            "optimizer": optimizer.state_dict(),
            "aux_optimizer": aux_optimizer.state_dict(),
        }
        torch.save(checkpoint_dict, os.path.join(save_path, 'Models/Model_best_testSQNR.pth.tar'))
    net1.train()
    return best_testSQNR


def train(net1, train_dataloader, test_dataloader, epoch, device, criterion, optimizer, aux_optimizer, lr_scheduler, distortion_metric, train_batch_size=8, test_freq=10, total_step=0, total_epochs=50, best_testSQNR=1):
    epoch_step = 0
    net1.train()

    with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch + 1}/{total_epochs}', unit='batch') as pbar:
        for batch_idx, batch in enumerate(train_dataloader):

            input = batch[:, :, 0, :, :].to(device)
            output = net1(input)

            out_criterion = criterion(output, input)

            pbar.set_postfix(**{'train loss (batch)': out_criterion['loss'].item()})

            optimizer.zero_grad()
            aux_optimizer.zero_grad()
            out_criterion['loss'].backward()
            if clip_max_norm > 0:
                torch.nn.utils.clip_grad_norm_(net1.parameters(), clip_max_norm)
            optimizer.step()

            aux_loss = net1.entropy_bottleneck.loss()
            aux_loss.backward()
            aux_optimizer.step()

            if distortion_metric == "mse":
                train_batch_sumDistortion = torch.sum(out_criterion['mse_loss'])
            elif distortion_metric == "ms_ssim":
                train_batch_sumDistortion = torch.sum(out_criterion['ms_ssim_loss'])
            train_batch_sumbpp = torch.sum(out_criterion['bpp_loss'])

            if event_writer is not None:
                logger.info("Add tensorboard for Train batch---epoch:{}-Step:{}".format(epoch, total_step))
                if distortion_metric == "mse":
                    event_writer.add_scalar("Train MSE_avg", train_batch_sumDistortion/train_batch_size, total_step)
                elif distortion_metric == "ms_ssim":
                    event_writer.add_scalar("Train MS-SSIM_avg", train_batch_sumDistortion/train_batch_size, total_step)
                event_writer.add_scalar("Train bpp_loss_avg", train_batch_sumbpp/train_batch_size, total_step)
                event_writer.add_scalar("Train lr", optimizer.param_groups[0]["lr"], total_step)

            if (total_step % test_freq) == 0:
                best_testSQNR = test(net1=net1,
                                     test_dataloader=test_dataloader,
                                     device=device,
                                     criterion=criterion,
                                     distortion_metric=distortion_metric,
                                     epoch=epoch,
                                     total_step=total_step,
                                     best_testSQNR=best_testSQNR)

            epoch_step += 1
            total_step += 1
            pbar.update(1)
        if lr_scheduler is not None:
            lr_scheduler.step()
    if (epoch % save_freq) == 0:
        net1.entropy_bottleneck.update()
        checkpoint_dict = {
            "epoch": epoch,
            "state_dict": net1.state_dict(),
            "optimizer": optimizer.state_dict(),
            "aux_optimizer": aux_optimizer.state_dict(),
        }
        torch.save(checkpoint_dict, os.path.join(save_path, 'Models/Model_epoch_{}.pth.tar'.format(epoch)))

    return total_step, best_testSQNR


logger = logging.getLogger("Real-Valued autoencoder to compress and decompress the Imaginary component SAR data with Entropy Modeling")
event_writer = SummaryWriter(os.path.join(save_path, 'tb_logs/'))
summary_text = "Compression of the Imaginary component of the SLC data\n" \
               "Parameters:\n" \
               "in_channels = {}\n" \
               "out_channels = {}\n" \
               "intermediate_channels1 = {}\n" \
               "intermediate_channels2 = {}\n" \
               "test_freq = {}\n" \
               "save_freq = {}\n" \
               "total_epochs = {}\n" \
               "train_batch_size = {}\n" \
               "test_batch_size = {}\n" \
               "lr = {}\n" \
               "aux_lr = {}\n" \
               "use_lr_scheduler = {}\n" \
               "clip_max_norm = {}\n" \
               "loss_lmbda = {}\n" \
               "distortion_metric = {}\n" \
               "pretrained = {}\n"\
    .format(in_channels, out_channels, intermediate_channels1, intermediate_channels2, test_freq, save_freq,
            total_epochs, train_batch_size, test_batch_size, lr, aux_lr, use_lr_scheduler, clip_max_norm, loss_lmbda, distortion_metric, pretrained)
event_writer.add_text("summary", summary_text)

net1 = FactorizedPrior_Net1(in_channels=1, out_channels=1, layer_num=net1_layers_num, intermediate_channels1=intermediate_channels1, intermediate_channels2=intermediate_channels2).to(device)
optimizer, aux_optimizer = configure_optimizers(net=net1, lr=lr, aux_lr=aux_lr)
if use_lr_scheduler:
    lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_scheduler_step, gamma=0.5)
else:
    lr_scheduler = None
criterion = RateDistortionLoss(lmbda=loss_lmbda, metric=distortion_metric)


if pretrained:  # load from previous checkpoint
    print("Loading pretrained states from checkpoint")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    update_registered_buffers(net1.entropy_bottleneck, "entropy_bottleneck", ["_quantized_cdf", "_offset", "_cdf_length"], checkpoint["state_dict"])
    epoch = checkpoint["epoch"] + 1
    net1.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])


dataset = np.load("D:/Reza/Data/dataset.npy")
trainset = dataset[200:]
testset = dataset[:200]
trainset = np.expand_dims(trainset, axis=1)
testset = np.expand_dims(testset, axis=1)
train_dataloader = DataLoader(dataset=trainset,
                              batch_size=train_batch_size,
                              shuffle=True,
                              pin_memory=True,
                              num_workers=0)
test_dataloader = DataLoader(dataset=testset,
                             batch_size=test_batch_size,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=0)

for epoch in range(total_epochs):
    total_step, best_testSQNR = train(net1=net1,
                                      train_dataloader=train_dataloader,
                                      test_dataloader=test_dataloader,
                                      epoch=epoch,
                                      device=device,
                                      criterion=criterion,
                                      optimizer=optimizer,
                                      aux_optimizer=aux_optimizer,
                                      lr_scheduler=lr_scheduler,
                                      distortion_metric=distortion_metric,
                                      train_batch_size=train_batch_size,
                                      test_freq=test_freq,
                                      total_step=total_step,
                                      total_epochs=total_epochs,
                                      best_testSQNR=best_testSQNR)
net1.entropy_bottleneck.update()
checkpoint_dict = {
    "epoch": epoch,
    "state_dict": net1.state_dict(),
    "optimizer": optimizer.state_dict(),
    "aux_optimizer": aux_optimizer.state_dict(),
}
torch.save(checkpoint_dict, os.path.join(save_path, 'Models/Model_Final.pth.tar'))
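
As a quick sanity check right after training, the estimated and the actual bitrate can be compared on a single sample (sketch; sample below is a placeholder tensor, not my SAR data):

with torch.no_grad():
    net1.eval()
    net1.entropy_bottleneck.update(force=True)        # force-refresh the CDF tables with the trained distributions
    sample = torch.rand(1, 1, 256, 256).to(device)    # placeholder input patch
    out = net1(sample)
    num_pixels = sample.size(2) * sample.size(3)
    est_bpp = sum((-torch.log2(l).sum() / num_pixels) for l in out["likelihoods"].values()).item()
    strings = net1.compress(sample)["EB_strings"]
    real_bpp = sum(len(s) for s in strings) * 8 / num_pixels
    print(f"estimated {est_bpp:.3f} bpp vs. actual {real_bpp:.3f} bpp")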

The following is the code for compressing and decompressing using the trained network:

import torch
import numpy as np
import os
from FFT_Exercise_Paper.ComplexValuedNN.SAR_Data_Compression.RV_CAE.Models import FactorizedPrior_Net1
from compressai.compressai.models.utils import update_registered_buffers
import math

device = 'cuda' if torch.cuda.is_available() else 'cpu'
in_channels = 1
out_channels = 1
intermediate_channels1 = 64
intermediate_channels2 = 128
net1_layers_num = 4
bottleneck_size = (16, 16)
Checkpoint_path = "D:/Reza/My Python Projects/New Project/FFT_Exercise_Paper/ComplexValuedNN/SAR_Data_Compression/RV_CAE/Logs/2023-06-16_10_31_16_Layers 64_128_Net1_Natural_Normal/Models/Model_best_testSQNR.pth.tar"

############################################################ Compress
patch = np.real(np.load("D:/Reza/Data/test_data"))[0] #select only one patch

net1 = FactorizedPrior_Net1(in_channels=in_channels, out_channels=out_channels, layer_num=net1_layers_num, intermediate_channels1=intermediate_channels1, intermediate_channels2=intermediate_channels2).to(device)

print("Loading pretrained states from checkpoint")
checkpoint = torch.load(Checkpoint_path, map_location=device)
update_registered_buffers(net1.entropy_bottleneck, "entropy_bottleneck", ["_quantized_cdf", "_offset", "_cdf_length"], checkpoint["state_dict"])
epoch = checkpoint["epoch"] + 1
net1.load_state_dict(checkpoint["state_dict"])
net1.eval()

###### Compress and save
compressed_outp = net1.compress(torch.tensor(patch).to(device))
with open('Compressed_strings/compressed_patch.bin', 'wb') as f:
    f.write(compressed_outp["EB_strings"][0])

############################################################ Decompress
string_path = "Compressed_strings/compressed_patch.bin"

net1 = FactorizedPrior_Net1(in_channels=in_channels, out_channels=out_channels, layer_num=net1_layers_num, intermediate_channels1=intermediate_channels1, intermediate_channels2=intermediate_channels2).to(device)

print("Loading pretrained states from checkpoint")
checkpoint = torch.load(Checkpoint_path, map_location=device)
update_registered_buffers(net1.entropy_bottleneck, "entropy_bottleneck", ["_quantized_cdf", "_offset", "_cdf_length"], checkpoint["state_dict"])
epoch = checkpoint["epoch"] + 1
net1.load_state_dict(checkpoint["state_dict"])
net1.eval()


###### Load and Decompress
with open(string_path, 'rb') as f:
    string_temp = f.read()
    # Decompress the data
    decompressed_outp = net1.decompress(strings=[[string_temp]], shape=bottleneck_size)
    decompressed_patch = np.array(np.squeeze(decompressed_outp['rec_image'].cpu().detach().numpy()))

########################################################## Metrics
#compute bitrate using the saved file size
avg_bpp = (os.path.getsize('Compressed_strings/compressed_patch.bin')*8) /(256*256)

#compute bitrate using the likelihoods from the entropy model
def compute_bpp(out_net):
    size = out_net['x_hat'].size()
    num_pixels = size[0] * size[2] * size[3]
    return sum(torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)
              for likelihoods in out_net['likelihoods'].values()).item()
outp = net1.forward(torch.tensor(patch).to(device))
print(f'Bit-rate: {compute_bpp(outp):.3f} bpp')

Is there any mistake that I'm making while training or saving the network?
