Skip to content

Commit 50ddf91

Browse files
YodaEmbedding authored and fracape committed
fix: LatentCodec shapes, etc
1 parent 2758f12 commit 50ddf91

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

compressai/latent_codecs/gaussian_conditional.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
2828
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929

30-
from typing import Any, Dict, List, Optional, Tuple
30+
from typing import Any, Dict, List, Optional, Tuple, Union
3131

3232
import torch.nn as nn
3333

@@ -81,6 +81,7 @@ class GaussianConditionalLatentCodec(LatentCodec):
8181

8282
def __init__(
8383
self,
84+
scale_table: Optional[Union[List, Tuple]] = None,
8485
gaussian_conditional: Optional[GaussianConditional] = None,
8586
entropy_parameters: Optional[nn.Module] = None,
8687
quantizer: str = "noise",
@@ -89,7 +90,7 @@ def __init__(
8990
super().__init__()
9091
self.quantizer = quantizer
9192
self.gaussian_conditional = gaussian_conditional or GaussianConditional(
92-
**kwargs
93+
scale_table, **kwargs
9394
)
9495
self.entropy_parameters = entropy_parameters or nn.Identity()
9596

@@ -109,7 +110,7 @@ def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
109110
y_hat = self.gaussian_conditional.decompress(
110111
y_strings, indexes, means=means_hat
111112
)
112-
return {"strings": [y_strings], "y_hat": y_hat}
113+
return {"strings": [y_strings], "shape": y.shape[2:4], "y_hat": y_hat}
113114

114115
def decompress(
115116
self, strings: List[List[bytes]], shape: Tuple[int, int], ctx_params: Tensor

compressai/latent_codecs/rasterscan.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,7 @@ def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
126126
)
127127
for i in range(n)
128128
]
129-
return default_collate(ds)
129+
return {**default_collate(ds), "shape": y.shape[2:4]}
130130

131131
def _compress_single(self, **kwargs):
132132
encoder = BufferedRansEncoder()

0 commit comments

Comments
 (0)