diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py
index 590502e77b..cd76b97909 100644
--- a/torchtext/models/roberta/modules.py
+++ b/torchtext/models/roberta/modules.py
@@ -120,7 +120,12 @@ def __init__(
             batch_first=True,
             norm_first=normalize_before,
         )
-        self.layers = torch.nn.TransformerEncoder(encoder_layer=layer, num_layers=num_encoder_layers)
+        self.layers = torch.nn.TransformerEncoder(
+            encoder_layer=layer,
+            num_layers=num_encoder_layers,
+            enable_nested_tensor=True,
+            mask_check=False,
+        )
         self.positional_embedding = PositionalEmbedding(max_seq_len, embedding_dim, padding_idx)
         self.embedding_layer_norm = nn.LayerNorm(embedding_dim)
         self.dropout = nn.Dropout(dropout)
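
For context, a minimal standalone sketch of what the two new arguments do (not part of the diff; the model dimensions, mask values, and variable names are illustrative). `enable_nested_tensor=True` lets `torch.nn.TransformerEncoder` pack a padded batch into a nested tensor on the inference fast path when a `src_key_padding_mask` is supplied, skipping compute on padding positions; `mask_check=False` disables PyTorch's consistency check between the mask and the input, trusting the caller to pass a well-formed mask.

```python
# Standalone sketch (not from this PR); shapes and values are illustrative.
import torch

layer = torch.nn.TransformerEncoderLayer(
    d_model=768, nhead=12, dim_feedforward=3072, batch_first=True
)
encoder = torch.nn.TransformerEncoder(
    encoder_layer=layer,
    num_layers=12,
    enable_nested_tensor=True,  # pack padded batches as nested tensors on the fast path
    mask_check=False,           # skip mask/input consistency validation
)
encoder.eval()  # the nested-tensor fast path only applies at inference

tokens = torch.randn(2, 16, 768)                    # (batch, seq, embed); batch_first=True
padding_mask = torch.zeros(2, 16, dtype=torch.bool)
padding_mask[1, 8:] = True                          # second sequence is padding past position 8

with torch.inference_mode():
    out = encoder(tokens, src_key_padding_mask=padding_mask)
print(out.shape)  # torch.Size([2, 16, 768]); output is re-padded to dense shape
```

With `mask_check=False`, a mask that does not match the actual padding layout is silently accepted, so this is safe only where the caller constructs the mask itself, which appears to be the assumption for this encoder wrapper.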