diff --git a/compressai/entropy_models/entropy_models.py b/compressai/entropy_models/entropy_models.py index 038ba21d..e14827fb 100644 --- a/compressai/entropy_models/entropy_models.py +++ b/compressai/entropy_models/entropy_models.py @@ -386,12 +386,15 @@ def _get_medians(self) -> Tensor: medians = self.quantiles[:, :, 1:2] return medians - def update(self, force: bool = False) -> bool: + def update(self, force: bool = False, update_quantiles: bool = True) -> bool: # Check if we need to update the bottleneck parameters, the offsets are # only computed and stored when the conditonal model is update()'d. if self._offset.numel() > 0 and not force: return False + if update_quantiles: + self._update_quantiles() + medians = self.quantiles[:, 0, 1] minima = medians - self.quantiles[:, 0, 0] @@ -521,6 +524,35 @@ def _build_indexes(size): def _extend_ndims(tensor, n): return tensor.reshape(-1, *([1] * n)) if n > 0 else tensor.reshape(-1) + @torch.no_grad() + def _update_quantiles(self, search_radius=1e5, rtol=1e-4, atol=1e-3): + device = self.quantiles.device + shape = (self.channels, 1, 1) + low = torch.full(shape, -search_radius, device=device) + high = torch.full(shape, search_radius, device=device) + + def f(y, self=self): + return self._logits_cumulative(y, stop_gradient=True) + + for i in range(len(self.target)): + q_i = self._search_target(f, self.target[i], low, high, rtol, atol) + self.quantiles[:, :, i] = q_i[:, :, 0] + + @staticmethod + def _search_target(f, target, low, high, rtol=1e-4, atol=1e-3, strict=False): + assert (low <= high).all() + if strict: + assert ((f(low) <= target) & (target <= f(high))).all() + else: + low = torch.where(target <= f(high), low, high) + high = torch.where(f(low) <= target, high, low) + while not torch.isclose(low, high, rtol=rtol, atol=atol).all(): + mid = (low + high) / 2 + f_mid = f(mid) + low = torch.where(f_mid <= target, mid, low) + high = torch.where(f_mid >= target, mid, high) + return (low + high) / 2 + def compress(self, x): indexes = self._build_indexes(x.size()) medians = self._get_medians().detach() diff --git a/tests/test_entropy_models.py b/tests/test_entropy_models.py index 00c83287..2dbfc1c4 100644 --- a/tests/test_entropy_models.py +++ b/tests/test_entropy_models.py @@ -261,26 +261,30 @@ def test_compression_2D(self): eb.update() s = eb.compress(x) x2 = eb.decompress(s, x.size()[2:]) + means = eb._get_medians() - assert torch.allclose(torch.round(x), x2) + assert torch.allclose(torch.round(x - means) + means, x2) def test_compression_ND(self): eb = EntropyBottleneck(128) eb.update() + # Test 0D x = torch.rand(1, 128) s = eb.compress(x) x2 = eb.decompress(s, []) + means = eb._get_medians().reshape(128) - assert torch.allclose(torch.round(x), x2) + assert torch.allclose(torch.round(x - means) + means, x2) # Test from 1 to 5 dimensions for i in range(1, 6): x = torch.rand(1, 128, *([4] * i)) s = eb.compress(x) x2 = eb.decompress(s, x.size()[2:]) + means = eb._get_medians().reshape(128, *([1] * i)) - assert torch.allclose(torch.round(x), x2) + assert torch.allclose(torch.round(x - means) + means, x2) class TestGaussianConditional: