From 7978975f399f1dcfdb9afe5b5550c540a04ec6b3 Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Mon, 12 Sep 2022 14:53:17 -0700 Subject: [PATCH] Turn off mask checking for torchtext which is known to have a legal mask (#1896) Summary: Pull Request resolved: https://github.com/pytorch/text/pull/1896 Turn off mask checking for torchtext which is known to have a legal mask Reviewed By: zrphercule Differential Revision: D39445703 fbshipit-source-id: 8ac678ec6ad78607a5951e8604c1591613fbdcbb --- torchtext/models/roberta/modules.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py index 590502e77b..cd76b97909 100644 --- a/torchtext/models/roberta/modules.py +++ b/torchtext/models/roberta/modules.py @@ -120,7 +120,12 @@ def __init__( batch_first=True, norm_first=normalize_before, ) - self.layers = torch.nn.TransformerEncoder(encoder_layer=layer, num_layers=num_encoder_layers) + self.layers = torch.nn.TransformerEncoder( + encoder_layer=layer, + num_layers=num_encoder_layers, + enable_nested_tensor=True, + mask_check=False, + ) self.positional_embedding = PositionalEmbedding(max_seq_len, embedding_dim, padding_idx) self.embedding_layer_norm = nn.LayerNorm(embedding_dim) self.dropout = nn.Dropout(dropout)