Skip to content

Commit 3baa06a

Browse files
committed
feat: fast EntropyBottleneck aux_loss minimization via bisection search
This method completes in <1 second and reduces aux_loss to <0.01. This makes the aux_loss optimization during training unnecessary. Another alternative would be to run the following post-training:

```python
while aux_loss > 0.1:
    aux_loss = model.aux_loss()
    aux_loss.backward()
    aux_optimizer.step()
    aux_optimizer.zero_grad()
```

...but since we do not manage aux_loss learning rates, the bisection search method might converge better.
1 parent b10cc7c commit 3baa06a

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

compressai/entropy_models/entropy_models.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,8 @@ def update(self, force: bool = False) -> bool:
392392
if self._offset.numel() > 0 and not force:
393393
return False
394394

395+
self._update_quantiles()
396+
395397
medians = self.quantiles[:, 0, 1]
396398

397399
minima = medians - self.quantiles[:, 0, 0]
@@ -521,6 +523,31 @@ def _build_indexes(size):
521523
def _extend_ndims(tensor, n):
522524
return tensor.reshape(-1, *([1] * n)) if n > 0 else tensor.reshape(-1)
523525

@torch.no_grad()
def _update_quantiles(self):
    """Recompute ``self.quantiles`` by bisection search.

    For each entry in ``self.target``, finds the per-channel input at
    which ``_logits_cumulative`` reaches that target value, and writes
    it into the corresponding slot of ``self.quantiles``. Runs without
    gradient tracking and mutates ``self.quantiles`` in place.
    """
    device = self.quantiles.device
    bracket_shape = (self.channels, 1, 1)
    # Initial brackets wide enough to contain any realistic quantile.
    lower = torch.full(bracket_shape, -1e9, device=device)
    upper = torch.full(bracket_shape, 1e9, device=device)

    def cumulative_logits(y):
        # Monotone per-channel function whose level sets we search for.
        return self._logits_cumulative(y, stop_gradient=True)

    for slot, level in enumerate(self.target):
        solution = self._search_target(cumulative_logits, level, lower, upper)
        self.quantiles[:, :, slot] = solution[:, :, 0]
539+
@staticmethod
def _search_target(f, target, low, high):
    """Elementwise bisection: find ``y`` with ``f(y) == target``.

    ``f`` is assumed elementwise monotonically non-decreasing, and the
    initial brackets must satisfy ``f(low) <= target <= f(high)``.
    Returns the midpoint of the final (numerically collapsed) bracket.
    """
    assert (low <= high).all()
    assert ((f(low) <= target) & (target <= f(high))).all()
    # Halve every bracket until its endpoints are numerically equal.
    while (~torch.isclose(low, high)).any():
        midpoint = (low + high) / 2
        probe = f(midpoint)
        # Keep the half-bracket that still contains the target; where
        # probe == target, both endpoints collapse onto the midpoint.
        low = torch.where(probe <= target, midpoint, low)
        high = torch.where(probe >= target, midpoint, high)
    return (low + high) / 2
550+
524551
def compress(self, x):
525552
indexes = self._build_indexes(x.size())
526553
medians = self._get_medians().detach()

0 commit comments

Comments
 (0)