34 changes: 33 additions & 1 deletion compressai/entropy_models/entropy_models.py
@@ -386,12 +386,15 @@ def _get_medians(self) -> Tensor:
medians = self.quantiles[:, :, 1:2]
return medians

def update(self, force: bool = False) -> bool:
def update(self, force: bool = False, update_quantiles: bool = True) -> bool:
# Check if we need to update the bottleneck parameters, the offsets are
# only computed and stored when the conditional model is update()'d.
if self._offset.numel() > 0 and not force:
return False

if update_quantiles:
self._update_quantiles()

medians = self.quantiles[:, 0, 1]

minima = medians - self.quantiles[:, 0, 0]
@@ -521,6 +524,35 @@ def _build_indexes(size):
def _extend_ndims(tensor, n):
return tensor.reshape(-1, *([1] * n)) if n > 0 else tensor.reshape(-1)

@torch.no_grad()
def _update_quantiles(self, search_radius=1e5, rtol=1e-4, atol=1e-3):
device = self.quantiles.device
shape = (self.channels, 1, 1)
low = torch.full(shape, -search_radius, device=device)
high = torch.full(shape, search_radius, device=device)

def f(y, self=self):
return self._logits_cumulative(y, stop_gradient=True)

for i in range(len(self.target)):
q_i = self._search_target(f, self.target[i], low, high, rtol, atol)
self.quantiles[:, :, i] = q_i[:, :, 0]

@staticmethod
def _search_target(f, target, low, high, rtol=1e-4, atol=1e-3, strict=False):
assert (low <= high).all()
if strict:
assert ((f(low) <= target) & (target <= f(high))).all()
else:
low = torch.where(target <= f(high), low, high)
high = torch.where(f(low) <= target, high, low)
while not torch.isclose(low, high, rtol=rtol, atol=atol).all():
mid = (low + high) / 2
f_mid = f(mid)
low = torch.where(f_mid <= target, mid, low)
high = torch.where(f_mid >= target, mid, high)
return (low + high) / 2

def compress(self, x):
indexes = self._build_indexes(x.size())
medians = self._get_medians().detach()
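The new `_search_target` helper is a vectorized bisection: `_logits_cumulative` is constructed to be monotonically non-decreasing in its input, so each quantile is found by repeatedly halving a `[low, high]` interval until the bounds agree within `rtol`/`atol`. A minimal standalone sketch of the same idea, using `torch.sigmoid` as a stand-in for `_logits_cumulative` (the target value and search radius below are illustrative assumptions):

```python
import torch

def search_target(f, target, low, high, rtol=1e-4, atol=1e-3):
    # Vectorized bisection: f is assumed monotonically non-decreasing,
    # and [low, high] is assumed to bracket the solution.
    assert (low <= high).all()
    while not torch.isclose(low, high, rtol=rtol, atol=atol).all():
        mid = (low + high) / 2
        f_mid = f(mid)
        low = torch.where(f_mid <= target, mid, low)
        high = torch.where(f_mid >= target, mid, high)
    return (low + high) / 2

# Toy usage: solve sigmoid(x) = 0.75, whose closed form is log(3) ~= 1.0986.
target = torch.tensor([0.75])
low = torch.full_like(target, -1e5)
high = torch.full_like(target, 1e5)
print(search_target(torch.sigmoid, target, low, high))  # ~= 1.0986
```

In the diff itself, `_update_quantiles` runs this search once per entry in `self.target`, vectorized over channels, and writes the results into `self.quantiles` under `torch.no_grad()`, so the quantiles are refreshed in place without recording gradients.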
10 changes: 7 additions & 3 deletions tests/test_entropy_models.py
@@ -261,26 +261,30 @@ def test_compression_2D(self):
eb.update()
s = eb.compress(x)
x2 = eb.decompress(s, x.size()[2:])
means = eb._get_medians()

assert torch.allclose(torch.round(x), x2)
assert torch.allclose(torch.round(x - means) + means, x2)

def test_compression_ND(self):
eb = EntropyBottleneck(128)
eb.update()

# Test 0D
x = torch.rand(1, 128)
s = eb.compress(x)
x2 = eb.decompress(s, [])
means = eb._get_medians().reshape(128)

assert torch.allclose(torch.round(x), x2)
assert torch.allclose(torch.round(x - means) + means, x2)

# Test from 1 to 5 dimensions
for i in range(1, 6):
x = torch.rand(1, 128, *([4] * i))
s = eb.compress(x)
x2 = eb.decompress(s, x.size()[2:])
means = eb._get_medians().reshape(128, *([1] * i))

assert torch.allclose(torch.round(x), x2)
assert torch.allclose(torch.round(x - means) + means, x2)


class TestGaussianConditional:
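The updated assertions reflect how the bottleneck's round trip behaves once the medians are non-trivial: `decompress(compress(x))` returns values quantized relative to the per-channel medians, i.e. `round(x - median) + median`, not plain `round(x)`. A small sketch of that property, with a hypothetical `medians` tensor standing in for `eb._get_medians()`:

```python
import torch

# Hypothetical per-channel medians standing in for eb._get_medians(),
# reshaped to broadcast over an (N, C, H, W) input.
medians = torch.tensor([0.3, -0.2, 0.0, 0.7]).reshape(1, 4, 1, 1)

x = torch.rand(1, 4, 8, 8)

# Quantize relative to the medians: shift, round, shift back.
x_hat = torch.round(x - medians) + medians

# The reconstruction error never exceeds half a quantization step.
assert torch.all((x_hat - x).abs() <= 0.5 + 1e-6)

# Plain rounding only agrees when the medians are (near-)integers, which is
# why the tests now compare against round(x - means) + means.
print(torch.allclose(torch.round(x), x_hat))  # typically False here
```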