@@ -1,11 +1,11 @@
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union, Callable
+from typing import Dict, List, Optional, Union, Callable
 
 import torch
 import torch.nn as nn
 from torch import Tensor
 
-from .modules import T5Stack, T5LayerNorm
+from .modules import T5Encoder, T5Decoder, T5LayerNorm
 
 
 @dataclass
@@ -69,14 +69,15 @@ def __init__(
         self,
         config: T5Conf,
         freeze: bool = False,
-        device=None,
+        device: Optional[torch.device] = None,
         dtype=None,
     ) -> None:
         super().__init__()
 
         assert isinstance(config, T5Conf)
 
         self.config = config
+        self.embedding_dim = config.embedding_dim
         self.encoder_only = config.encoder_only
         self.linear_head = config.linear_head
         self.padding_idx = config.padding_idx
@@ -86,8 +87,7 @@ def __init__(
         self.dtype = dtype
 
         self.token_embeddings = nn.Embedding(config.vocab_size, config.embedding_dim, config.padding_idx)
-        self.encoder = T5Stack(
-            is_decoder=False,
+        self.encoder = T5Encoder(
             d_model=config.embedding_dim,
             nhead=config.num_attention_heads,
             num_layers=config.num_encoder_layers,
@@ -105,8 +105,7 @@ def __init__(
         self.dropout2 = nn.Dropout(self.dropout)
 
         if not config.encoder_only:
-            self.decoder = T5Stack(
-                is_decoder=True,
+            self.decoder = T5Decoder(
                 d_model=config.embedding_dim,
                 nhead=config.num_attention_heads,
                 num_layers=config.num_decoder_layers,
@@ -122,9 +121,13 @@ def __init__(
             self.norm2 = T5LayerNorm(config.embedding_dim)
             self.dropout3 = nn.Dropout(self.dropout)
             self.dropout4 = nn.Dropout(self.dropout)
+        else:
+            self.decoder = None
 
         if config.linear_head:
             self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
+        else:
+            self.lm_head = None
 
         if freeze:
             for p in self.parameters():
@@ -133,10 +136,10 @@ def __init__(
     def forward(
         self,
         encoder_tokens: Tensor,
-        decoder_tokens: Tensor = None,
+        decoder_tokens: Optional[Tensor] = None,
         encoder_mask: Optional[Tensor] = None,
         decoder_mask: Optional[Tensor] = None,
-    ) -> Dict[str, Union[Tensor, Tuple[Tensor]]]:
+    ) -> Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]:
         r"""Pass the inputs (and mask) through the decoder layer in turn.
         Args:
             encoder_tokens: Tokenized input sequence to the encoder.
@@ -163,23 +166,27 @@ def forward(
         """
         encoder_padding_mask = encoder_tokens.eq(self.padding_idx)
         encoder_embeddings = self.dropout1(self.token_embeddings(encoder_tokens))
-        encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa, _ = self.encoder(
+        encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa = self.encoder(
             encoder_embeddings, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask
         )
 
         encoder_output = self.norm1(encoder_output)
         encoder_output = self.dropout2(encoder_output)
-        encoder_hidden_states = encoder_hidden_states + (encoder_output,)
+        encoder_hidden_states.append(encoder_output)
 
         if not self.encoder_only:
 
+            assert self.decoder is not None
+
            # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.
             if decoder_tokens is None:
                 decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * self.padding_idx
 
             if decoder_mask is None:
+                assert decoder_tokens is not None and decoder_tokens.dim() == 2
                 tgt_len = decoder_tokens.shape[1]
-                decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()
+                decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1)
+                decoder_mask = decoder_mask.to(torch.bool)
 
             decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
             # T5 implemention uses padding idx to start sequence. Want to ignore this when masking
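[Illustration only, not part of the patch: the two-step construction above builds the standard causal mask, where True marks positions a decoder token may not attend to. A minimal sketch of what it produces for a length-3 target:]

import torch

tgt_len = 3
decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1)
decoder_mask = decoder_mask.to(torch.bool)
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])
# Row i is the mask for target position i: it may attend to positions <= i only.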
@@ -197,13 +204,14 @@ def forward(
 
             decoder_output = self.norm2(decoder_output)
             decoder_output = self.dropout4(decoder_output)
-            decoder_hidden_states = decoder_hidden_states + (decoder_output,)
+            decoder_hidden_states.append(decoder_output)
 
             if self.linear_head:
+                assert self.lm_head is not None
                 # Rescale output before projecting on vocab. This happens when the encoder and decoder share the
                 # same word embeddings, which is always the case in our t5 implementation.
                 # See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
-                decoder_output = decoder_output * (self.config.embedding_dim**-0.5)
+                decoder_output = decoder_output * (self.embedding_dim**-0.5)
                 decoder_output = self.lm_head(decoder_output)
 
             t5_output = {
@@ -225,4 +233,8 @@ def forward(
                 "encoder_sa_scores": encoder_sa,
             }
 
+        assert torch.jit.isinstance(
+            t5_output, Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]
+        )
+
         return t5_output
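[A minimal usage sketch, not taken from the PR: the Optional/List annotations and the explicit None defaults for decoder and lm_head are the kind of typing that torch.jit.script needs to type-check the class; whether scripting succeeds also depends on the rest of the stack being scriptable. config is assumed to be a T5Conf built elsewhere, and the "encoder_hidden_states" key is assumed from the surrounding code.]

import torch
# T5Model and T5Conf come from the module shown in this diff.

model = T5Model(config)                                 # config: a T5Conf instance, defined elsewhere
scripted = torch.jit.script(model)                      # uses the Optional/List annotations added above

tokens = torch.randint(0, config.vocab_size, (2, 8))    # batch of encoder token ids
out = scripted(encoder_tokens=tokens)                   # decoder_tokens=None starts decoding from padding_idx
hidden = out["encoder_hidden_states"]                   # a List[Tensor] now, rather than a Tuple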