|
1 | | -from typing import Optional, Tuple |
| 1 | +from dataclasses import dataclass |
| 2 | +from typing import Optional, Tuple, Union |
2 | 3 |
|
3 | 4 | import numpy as np |
4 | 5 | import torch |
5 | 6 | import torch.nn as nn |
6 | 7 |
|
7 | 8 | from ..configuration_utils import ConfigMixin, register_to_config |
8 | 9 | from ..modeling_utils import ModelMixin |
| 10 | +from ..utils import BaseOutput |
9 | 11 | from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block |
10 | 12 |
|
11 | 13 |
|
@dataclass
class DecoderOutput(BaseOutput):
    """
    Container returned by the decoding methods of the autoencoder models.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded sample, i.e. the output of the model's last layer.
    """

    sample: torch.FloatTensor
| 25 | + |
| 26 | + |
@dataclass
class VQEncoderOutput(BaseOutput):
    """
    Container returned by the VQModel encoding method.

    Args:
        latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The encoded sample, i.e. the output of the model's last encoding layer.
    """

    latents: torch.FloatTensor
| 38 | + |
| 39 | + |
@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Container returned by the AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            The `Encoder` outputs expressed as the mean and logvar of a `DiagonalGaussianDistribution`.
            Latents can be sampled from this distribution.
    """

    # Quoted forward reference: the annotation names a class without requiring it at this point in the module.
    latent_dist: "DiagonalGaussianDistribution"
| 52 | + |
| 53 | + |
12 | 54 | class Encoder(nn.Module): |
13 | 55 | def __init__( |
14 | 56 | self, |
@@ -369,26 +411,40 @@ def __init__( |
369 | 411 | act_fn=act_fn, |
370 | 412 | ) |
371 | 413 |
|
def encode(self, x, return_dict: bool = True):
    """
    Encode `x` with `self.encoder` followed by `self.quant_conv`.

    Args:
        x: Input passed to `self.encoder`.
        return_dict (`bool`, defaults to `True`):
            Whether to wrap the result in a `VQEncoderOutput` instead of a plain tuple.

    Returns:
        A `VQEncoderOutput` whose `latents` field holds the encoded result when `return_dict` is `True`,
        otherwise a 1-tuple containing the encoded result.
    """
    latents = self.quant_conv(self.encoder(x))

    if return_dict:
        return VQEncoderOutput(latents=latents)

    return (latents,)
| 422 | + |
def decode(
    self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
    """
    Decode latents `h`, optionally routing them through the quantization layer first.

    Args:
        h (`torch.FloatTensor`):
            Latents to decode.
        force_not_quantize (`bool`, defaults to `False`):
            When `True`, skip `self.quantize` and feed `h` straight into `self.post_quant_conv`.
        return_dict (`bool`, defaults to `True`):
            Whether to wrap the result in a `DecoderOutput` instead of a plain tuple.

    Returns:
        A `DecoderOutput` whose `sample` field holds the decoder output when `return_dict` is `True`,
        otherwise a 1-tuple containing the decoder output.
    """
    if force_not_quantize:
        quant = h
    else:
        # also go through quantization layer; `self.quantize` returns three values,
        # of which only the first is needed for decoding
        quant, _, _ = self.quantize(h)

    quant = self.post_quant_conv(quant)
    dec = self.decoder(quant)

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
| 438 | + |
def forward(
    self, sample: torch.FloatTensor, return_dict: bool = True
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
    """
    Run `sample` through `self.encode` and then `self.decode`.

    Args:
        sample (`torch.FloatTensor`):
            Input sample.
        return_dict (`bool`, defaults to `True`):
            Whether to wrap the result in a `DecoderOutput` instead of a plain tuple.

    Returns:
        A `DecoderOutput` whose `sample` field holds the decoded result when `return_dict` is `True`,
        otherwise a 1-tuple containing the decoded result.
    """
    # Both calls use the default `return_dict=True`, so the results are unwrapped by attribute.
    h = self.encode(sample).latents
    dec = self.decode(h).sample

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
392 | 448 |
|
393 | 449 |
|
394 | 450 | class AutoencoderKL(ModelMixin, ConfigMixin): |
@@ -431,23 +487,37 @@ def __init__( |
431 | 487 | self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) |
432 | 488 | self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) |
433 | 489 |
|
def encode(self, x, return_dict: bool = True):
    """
    Encode `x` into a `DiagonalGaussianDistribution`.

    The input goes through `self.encoder` and `self.quant_conv`; the result is used to build the
    distribution.

    Args:
        x: Input passed to `self.encoder`.
        return_dict (`bool`, defaults to `True`):
            Whether to wrap the result in an `AutoencoderKLOutput` instead of a plain tuple.

    Returns:
        An `AutoencoderKLOutput` whose `latent_dist` field holds the distribution when `return_dict` is
        `True`, otherwise a 1-tuple containing the distribution.
    """
    moments = self.quant_conv(self.encoder(x))
    latent_dist = DiagonalGaussianDistribution(moments)

    if return_dict:
        return AutoencoderKLOutput(latent_dist=latent_dist)

    return (latent_dist,)
| 499 | + |
def decode(
    self, z: torch.FloatTensor, return_dict: bool = True
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
    """
    Decode latents `z` with `self.post_quant_conv` followed by `self.decoder`.

    Args:
        z (`torch.FloatTensor`):
            Latents to decode.
        return_dict (`bool`, defaults to `True`):
            Whether to wrap the result in a `DecoderOutput` instead of a plain tuple.

    Returns:
        A `DecoderOutput` whose `sample` field holds the decoder output when `return_dict` is `True`,
        otherwise a 1-tuple containing the decoder output.
    """
    z = self.post_quant_conv(z)
    dec = self.decoder(z)

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
| 508 | + |
def forward(
    self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
    """
    Encode `sample` to a posterior distribution, pick a latent from it, and decode that latent.

    Args:
        sample (`torch.FloatTensor`):
            Input sample.
        sample_posterior (`bool`, defaults to `False`):
            When `True`, the latent is drawn with `posterior.sample()`; otherwise `posterior.mode()` is used.
        return_dict (`bool`, defaults to `True`):
            Whether to wrap the result in a `DecoderOutput` instead of a plain tuple.

    Returns:
        A `DecoderOutput` whose `sample` field holds the decoded result when `return_dict` is `True`,
        otherwise a 1-tuple containing the decoded result.
    """
    # Both calls use the default `return_dict=True`, so the results are unwrapped by attribute.
    posterior = self.encode(sample).latent_dist
    z = posterior.sample() if sample_posterior else posterior.mode()
    dec = self.decode(z).sample

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
0 commit comments