 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Any, Dict, Optional, Tuple

 import numpy as np
 import torch
-from torch import nn
+from torch import FloatTensor, nn

 from .attention import AdaGroupNorm, AttentionBlock
 from .attention_processor import Attention, AttnAddedKVProcessor
 from .dual_transformer_2d import DualTransformer2DModel
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
-from .transformer_2d import Transformer2DModel
+from .transformer_2d import Transformer2DModel, Transformer2DModelOutput


 def get_down_block(
@@ -533,15 +533,24 @@ def __init__(
         self.resnets = nn.ModuleList(resnets)

     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
-    ):
+        self,
+        hidden_states: FloatTensor,
+        temb: Optional[FloatTensor] = None,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        encoder_attention_mask: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(
+            output: Transformer2DModelOutput = attn(
                 hidden_states,
+                attention_mask=attention_mask,
                 encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+            )
+            hidden_states = output.sample
             hidden_states = resnet(hidden_states, temb)

         return hidden_states
@@ -808,9 +817,14 @@ def __init__(
         self.gradient_checkpointing = False

     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+        self,
+        hidden_states: FloatTensor,
+        temb: Optional[FloatTensor] = None,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        encoder_attention_mask: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        # TODO(Patrick, William) - attention mask is not used
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
@@ -829,14 +843,18 @@ def custom_forward(*inputs):
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(attn, return_dict=False),
                     hidden_states,
+                    attention_mask,
                     encoder_hidden_states,
+                    encoder_attention_mask,
                     cross_attention_kwargs,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
+                    attention_mask=attention_mask,
                     encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
                 ).sample

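For context on the two branches above: `torch.utils.checkpoint.checkpoint` only forwards positional arguments, so in the gradient-checkpointing branch the new masks are threaded through `create_custom_forward` by position, while the eager branch can pass them as keywords. The following is a minimal, self-contained sketch of that pattern; `TinyAttn` and its argument order are illustrative stand-ins, not the real `Transformer2DModel` signature.

```python
import torch
import torch.utils.checkpoint
from torch import nn


class TinyAttn(nn.Module):
    # Illustrative stand-in for an attention block; not the real Transformer2DModel.
    def forward(self, hidden_states, attention_mask=None, encoder_hidden_states=None):
        if encoder_hidden_states is not None:
            # Fold a summary of the "prompt" into the hidden states (stand-in for cross-attention).
            hidden_states = hidden_states + encoder_hidden_states.mean(dim=1, keepdim=True)
        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask
        return (hidden_states,)  # tuple output, mirroring return_dict=False


def create_custom_forward(module):
    # checkpoint() re-runs this closure during backward; it only receives positional
    # tensors, so the wrapped module must accept them in exactly this order.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


attn = TinyAttn()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)
attention_mask = torch.ones(2, 4, 1)
encoder_hidden_states = torch.randn(2, 7, 8)

# Positional order must match TinyAttn.forward: hidden_states, attention_mask, encoder_hidden_states.
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(attn),
    hidden_states,
    attention_mask,
    encoder_hidden_states,
)[0]
out.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 4, 8])
```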
@@ -1775,15 +1793,15 @@ def __init__(

     def forward(
         self,
-        hidden_states,
-        res_hidden_states_tuple,
-        temb=None,
-        encoder_hidden_states=None,
-        cross_attention_kwargs=None,
-        upsample_size=None,
-        attention_mask=None,
+        hidden_states: FloatTensor,
+        res_hidden_states_tuple: Tuple[FloatTensor, ...],
+        temb: Optional[FloatTensor] = None,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        encoder_attention_mask: Optional[FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[FloatTensor] = None,
     ):
-        # TODO(Patrick, William) - attention mask is not used
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -1805,14 +1823,18 @@ def custom_forward(*inputs):
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(attn, return_dict=False),
                     hidden_states,
+                    attention_mask,
                     encoder_hidden_states,
+                    encoder_attention_mask,
                     cross_attention_kwargs,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
+                    attention_mask=attention_mask,
                     encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
                 ).sample

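As a usage note, the new `encoder_attention_mask` argument carries a per-key mask over the text/prompt sequence that feeds cross-attention, so padded tokens can be ignored. Below is a small sketch of the kind of mask a caller might prepare; the example values and the additive-bias conversion are illustrative assumptions, not taken from this diff.

```python
import torch

# Hypothetical prompt batch: 2 sequences padded to length 6 (True = real token, False = padding).
token_mask = torch.tensor(
    [
        [True, True, True, True, False, False],
        [True, True, True, True, True, True],
    ]
)

# Per-key-position mask of shape (batch, sequence_length); this is the kind of
# object the encoder_attention_mask argument is meant to carry.
encoder_attention_mask = token_mask

# Equivalent additive-bias form (0 keeps a key, -inf removes it), broadcastable
# over (batch, heads, query_len, key_len). This conversion is one common
# convention, not necessarily the exact transformation applied upstream.
additive_bias = torch.zeros(token_mask.shape, dtype=torch.float32)
additive_bias = additive_bias.masked_fill(~token_mask, float("-inf"))
additive_bias = additive_bias[:, None, None, :]

print(encoder_attention_mask.shape)  # torch.Size([2, 6])
print(additive_bias.shape)           # torch.Size([2, 1, 1, 6])
```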