2727# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
2828# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929
30- from typing import Any , Dict , List , Optional , Tuple
30+ from typing import Any , Dict , List , Optional , Tuple , Union
3131
3232import torch .nn as nn
3333
@@ -81,6 +81,7 @@ class GaussianConditionalLatentCodec(LatentCodec):
8181
8282 def __init__ (
8383 self ,
84+ scale_table : Optional [Union [List , Tuple ]] = None ,
8485 gaussian_conditional : Optional [GaussianConditional ] = None ,
8586 entropy_parameters : Optional [nn .Module ] = None ,
8687 quantizer : str = "noise" ,
@@ -89,7 +90,7 @@ def __init__(
8990 super ().__init__ ()
9091 self .quantizer = quantizer
9192 self .gaussian_conditional = gaussian_conditional or GaussianConditional (
92- ** kwargs
93+ scale_table , ** kwargs
9394 )
9495 self .entropy_parameters = entropy_parameters or nn .Identity ()
9596
@@ -109,7 +110,7 @@ def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
109110 y_hat = self .gaussian_conditional .decompress (
110111 y_strings , indexes , means = means_hat
111112 )
112- return {"strings" : [y_strings ], "y_hat" : y_hat }
113+ return {"strings" : [y_strings ], "shape" : y . shape [ 2 : 4 ], " y_hat" : y_hat }
113114
114115 def decompress (
115116 self , strings : List [List [bytes ]], shape : Tuple [int , int ], ctx_params : Tensor
0 commit comments