
Conversation

@YodaEmbedding (Contributor) commented Apr 29, 2023

This method completes in <1 second and reduces aux_loss to <0.01.

This makes the aux_loss optimization during training unnecessary.
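The quantile update itself is a bisection search. As a minimal, generic sketch of that idea (illustrative only: the CDF, target probability, and search bounds below are placeholder assumptions, not the PR's code):

```python
import torch

def bisection_search(f, target, lo, hi, steps=50):
    # Find x in [lo, hi] with f(x) ~= target, assuming f is monotonically increasing.
    for _ in range(steps):
        mid = (lo + hi) / 2
        val = f(mid)
        # Keep the half-interval that still brackets the target.
        lo = torch.where(val < target, mid, lo)
        hi = torch.where(val < target, hi, mid)
    return (lo + hi) / 2

# Hypothetical usage: locate the value whose CDF equals a small tail probability.
cdf = torch.distributions.Normal(0.0, 2.0).cdf
q = bisection_search(cdf, torch.tensor(1e-9), torch.tensor(-50.0), torch.tensor(50.0))
```

Each step halves the bracket, so a few dozen iterations already pin the value down to near machine precision, without any learning-rate tuning.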

Another alternative would be to run the following post-training:

```python
# Refine the quantiles post-training until the auxiliary loss is small enough.
aux_loss = model.aux_loss()
while aux_loss > 0.1:
    aux_optimizer.zero_grad()
    aux_loss.backward()
    aux_optimizer.step()
    aux_loss = model.aux_loss()
```

...but since we do not manage aux_loss learning rates, the bisection search method might converge better.

Note: I haven't yet tested this extensively, so there might be some rough edges.

@YodaEmbedding (Contributor, Author) commented Sep 25, 2023

Why are the automated tests failing?

Previously, update() did not modify the default quantiles, which are initialized to [-10, 0, 10]. With this PR's modification, update() now updates the quantiles automatically as well. After that update the quantiles become e.g. [-216, -0.4, 207], so median=-0.4, and thus y = -0.6 is quantized to y_hat = -0.4. The tests, however, were written assuming median=0, and therefore expect y_hat = -1.

```python
def test_compression_2D(self):
    x = torch.rand(1, 128, 32, 32)
    eb = EntropyBottleneck(128)
    eb.update()
    s = eb.compress(x)
    x2 = eb.decompress(s, x.size()[2:])

    # THE torch.round(x) ASSERT ASSUMES THAT:
    # assert (eb.quantiles[..., 1] == 0).all()

    assert torch.allclose(torch.round(x), x2)
```
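For concreteness, a tiny standalone check of the two outcomes described above:

```python
import torch

y = torch.tensor(-0.6)

# Old assumption baked into the tests: median = 0.
print(torch.round(y - 0.0) + 0.0)        # tensor(-1.)

# With auto-updated quantiles, e.g. median = -0.4:
print(torch.round(y - (-0.4)) + (-0.4))  # tensor(-0.4000)
```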

Possible fixes:

  1. Patch the tests to disable auto-updating the quantiles via eb.update(update_quantiles=False). (The "easy" way.)
  2. Fix the tests to use torch.round(x - means) + means instead. (The "correct" way.)

I went with option 2.
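Continuing from the test above, the corrected assertion could then look roughly like this (the reshape of eb.quantiles so it broadcasts over an (N, C, H, W) input is my assumption for illustration, not necessarily the exact patch):

```python
# Medians now come from the auto-updated quantiles; reshape them so they
# broadcast against the (N, C, H, W) input (shape handling assumed here).
medians = eb.quantiles[..., 1].detach().reshape(1, -1, 1, 1)
assert torch.allclose(torch.round(x - medians) + medians, x2)
```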

@fracape merged commit 4ccc68e into InterDigitalInc:master on Feb 2, 2024