diff --git a/test/asset/t5.base.output.pt b/test/asset/t5.base.model.output.pt similarity index 100% rename from test/asset/t5.base.output.pt rename to test/asset/t5.base.model.output.pt diff --git a/test/asset/t5.large.encoder.output.pt b/test/asset/t5.large.encoder.output.pt new file mode 100644 index 0000000000..174bee7566 Binary files /dev/null and b/test/asset/t5.large.encoder.output.pt differ diff --git a/test/asset/t5.large.generation.output.pt b/test/asset/t5.large.generation.output.pt new file mode 100644 index 0000000000..762f2f26a0 Binary files /dev/null and b/test/asset/t5.large.generation.output.pt differ diff --git a/test/asset/t5.large.model.output.pt b/test/asset/t5.large.model.output.pt new file mode 100644 index 0000000000..13b0c6ecad Binary files /dev/null and b/test/asset/t5.large.model.output.pt differ diff --git a/test/asset/t5.small.encoder.output.pt b/test/asset/t5.small.encoder.output.pt new file mode 100644 index 0000000000..97c2922103 Binary files /dev/null and b/test/asset/t5.small.encoder.output.pt differ diff --git a/test/asset/t5.small.generation.output.pt b/test/asset/t5.small.generation.output.pt new file mode 100644 index 0000000000..9553d88463 Binary files /dev/null and b/test/asset/t5.small.generation.output.pt differ diff --git a/test/asset/t5.small.model.output.pt b/test/asset/t5.small.model.output.pt new file mode 100644 index 0000000000..d0933cdb05 Binary files /dev/null and b/test/asset/t5.small.model.output.pt differ diff --git a/test/prototype/integration_tests/test_models.py b/test/prototype/integration_tests/test_models.py index 7c486d8f3d..378a95711c 100644 --- a/test/prototype/integration_tests/test_models.py +++ b/test/prototype/integration_tests/test_models.py @@ -1,11 +1,37 @@ import torch from parameterized import parameterized from test.common.assets import get_asset_path +from test.common.parameterized_utils import nested_params from test.common.torchtext_test_case import TorchtextTestCase -from torchtext.prototype.models import T5_BASE_ENCODER, T5_BASE, T5_BASE_GENERATION, T5Conf, T5Transform +from torchtext.prototype.models import ( + T5_BASE_ENCODER, + T5_BASE, + T5_BASE_GENERATION, + T5_SMALL_ENCODER, + T5_SMALL, + T5_SMALL_GENERATION, + T5_LARGE_ENCODER, + T5_LARGE, + T5_LARGE_GENERATION, + T5Conf, + T5Transform, +) from torchtext.prototype.models.t5.wrapper import T5Wrapper +BUNDLERS = { + "base_model": T5_BASE, + "base_encoder": T5_BASE_ENCODER, + "base_generation": T5_BASE_GENERATION, + "small_model": T5_SMALL, + "small_encoder": T5_SMALL_ENCODER, + "small_generation": T5_SMALL_GENERATION, + "large_model": T5_LARGE, + "large_encoder": T5_LARGE_ENCODER, + "large_generation": T5_LARGE_GENERATION, +} + + class TestT5(TorchtextTestCase): def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text): """Verify that pre-trained T5 models in torchtext produce @@ -29,43 +55,32 @@ def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text): expected = torch.load(expected_asset_path) torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06) - @parameterized.expand([("jit", True), ("not_jit", False)]) - def test_t5_base_encoder_model(self, name, is_jit) -> None: - expected_asset_name = "t5.base.encoder.output.pt" - test_text = ["Hello world", "Attention rocks!"] - self._t5_model( - is_jit=is_jit, t5_model=T5_BASE_ENCODER, expected_asset_name=expected_asset_name, test_text=test_text - ) - - @parameterized.expand([("jit", True), ("not_jit", False)]) - def test_t5_base_model(self, name, is_jit) -> None: - 
expected_asset_name = "t5.base.output.pt" - test_text = ["Hello world", "Attention rocks!"] - self._t5_model(is_jit=is_jit, t5_model=T5_BASE, expected_asset_name=expected_asset_name, test_text=test_text) - - @parameterized.expand([("jit", True), ("not_jit", False)]) - def test_t5_base_generation_model(self, name, is_jit) -> None: - expected_asset_name = "t5.base.generation.output.pt" + @nested_params(["base", "small", "large"], ["encoder", "model", "generation"], ["jit", "not_jit"]) + def test_t5_encoder_model(self, configuration, type, name) -> None: + expected_asset_name = f"t5.{configuration}.{type}.output.pt" test_text = ["Hello world", "Attention rocks!"] - self._t5_model( - is_jit=is_jit, t5_model=T5_BASE_GENERATION, expected_asset_name=expected_asset_name, test_text=test_text - ) + is_jit = name == "jit" + t5_model = BUNDLERS[configuration + "_" + type] + self._t5_model(is_jit=is_jit, t5_model=t5_model, expected_asset_name=expected_asset_name, test_text=test_text) - @parameterized.expand([("jit", True), ("not_jit", False)]) - def test_t5_wrapper(self, name, is_jit) -> None: + @nested_params(["base", "small", "large"], ["jit", "not_jit"]) + def test_t5_wrapper(self, configuration, name) -> None: test_text = ["translate English to French: I want to eat pizza for dinner."] - expected_text = ["Je veux manger de la pizza pour le dîner."] + if configuration == "small": + expected_text = ["Je veux manger la pizza pour le dîner."] + else: + expected_text = ["Je veux manger de la pizza pour le dîner."] beam_size = 3 max_seq_len = 512 - model = T5Wrapper(configuration="base") - if is_jit: + model = T5Wrapper(configuration=configuration) + if name == "jit": model = torch.jit.script(model) output_text = model(test_text, beam_size, max_seq_len) self.assertEqual(output_text, expected_text) - @parameterized.expand([("jit", True), ("not_jit", False)]) - def test_t5_wrapper_checkpoing(self, name, is_jit) -> None: + @parameterized.expand(["jit", "not_jit"]) + def test_t5_wrapper_checkpoint(self, name) -> None: test_text = ["translate English to French: I want to eat pizza for dinner."] expected_text = ["Je veux manger de la pizza pour le dîner."] beam_size = 3 @@ -84,7 +99,7 @@ def test_t5_wrapper_checkpoing(self, name, is_jit) -> None: freeze_model=True, strict=True, ) - if is_jit: + if name == "jit": model = torch.jit.script(model) output_text = model(test_text, beam_size, max_seq_len) diff --git a/torchtext/prototype/models/t5/__init__.py b/torchtext/prototype/models/t5/__init__.py index 69bb75aef6..d23b5b8308 100644 --- a/torchtext/prototype/models/t5/__init__.py +++ b/torchtext/prototype/models/t5/__init__.py @@ -2,6 +2,18 @@ T5_BASE_ENCODER, T5_BASE, T5_BASE_GENERATION, + T5_SMALL_ENCODER, + T5_SMALL, + T5_SMALL_GENERATION, + T5_LARGE_ENCODER, + T5_LARGE, + T5_LARGE_GENERATION, + T5_3B_ENCODER, + T5_3B, + T5_3B_GENERATION, + T5_11B_ENCODER, + T5_11B, + T5_11B_GENERATION, T5Bundle, ) from .model import T5Conf, T5Model @@ -14,5 +26,17 @@ "T5_BASE_ENCODER", "T5_BASE", "T5_BASE_GENERATION", + "T5_SMALL_ENCODER", + "T5_SMALL", + "T5_SMALL_GENERATION", + "T5_LARGE_ENCODER", + "T5_LARGE", + "T5_LARGE_GENERATION", + "T5_3B_ENCODER", + "T5_3B", + "T5_3B_GENERATION", + "T5_11B_ENCODER", + "T5_11B", + "T5_11B_GENERATION", "T5Transform", ] diff --git a/torchtext/prototype/models/t5/bundler.py b/torchtext/prototype/models/t5/bundler.py index c4c1d78fb6..76b0d9bba9 100644 --- a/torchtext/prototype/models/t5/bundler.py +++ b/torchtext/prototype/models/t5/bundler.py @@ -68,7 +68,7 @@ def get_model( *, 
load_weights: bool = True, freeze_model: bool = False, - dl_kwargs: Dict[str, Any] = None, + dl_kwargs: Optional[Dict[str, Any]] = None, ) -> T5Model: r"""get_model(load_weights: bool = True, freeze_model: bool = False, *, dl_kwargs=None) -> torctext.prototype.models.T5Model @@ -104,8 +104,8 @@ def build_model( *, freeze_model: bool = False, checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None, - strict=False, - dl_kwargs: Dict[str, Any] = None, + strict: bool = False, + dl_kwargs: Optional[Dict[str, Any]] = None, ) -> T5Model: """Class builder method @@ -138,19 +138,8 @@ def config(self) -> T5Conf: return self._config -T5_BASE_ENCODER = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.base.encoder.pt"), - _config=T5Conf(encoder_only=True), - transform=lambda: T5Transform( - urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), - max_seq_len=512, - eos_idx=1, - padding_idx=0, - ), -) - -T5_BASE_ENCODER.__doc__ = """ - T5_BASE_ENCODER is an encoder-only model from a pre-trained T5 model with the base configuration.. +ENCODER_DOC = """ + T5_{}_ENCODER is an encoder-only model from a pre-trained T5 model with the {} configuration. It returns the normalized output from the final layer of the encoder. The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer @@ -166,22 +155,10 @@ def config(self) -> T5Conf: `Source `__] Please refer to :func:`torchtext.prototype.models.T5Bundle` for the usage. - """ - - -T5_BASE = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.base.pt"), - _config=T5Conf(encoder_only=False), - transform=lambda: T5Transform( - urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), - max_seq_len=512, - eos_idx=1, - padding_idx=0, - ), -) +""" -T5_BASE.__doc__ = """ - T5_BASE is an encoder-decoder model from a pre-trained T5 model with the base configuration. +MODEL_DOC = """ + T5_{} is an encoder-decoder model from a pre-trained T5 model with the {} configuration. It returns the normalized output from the final layer of the decoder. The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer @@ -199,19 +176,8 @@ def config(self) -> T5Conf: Please refer to :func:`torchtext.prototype.models.T5Bundle` for the usage. """ -T5_BASE_GENERATION = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.base.generation.pt"), - _config=T5Conf(encoder_only=False, linear_head=True), - transform=lambda: T5Transform( - urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), - max_seq_len=512, - eos_idx=1, - padding_idx=0, - ), -) - -T5_BASE_GENERATION.__doc__ = """ - T5_BASE_GENERATION is an encoder-decoder model from a pre-trained T5 model with the base configuration. +GENERATION_DOC = """ + T5_{}_GENERATION is an encoder-decoder model from a pre-trained T5 model with the {} configuration. It returns the output of the final layer of the decoder after passing through a linear layer to project the hidden states to the model vocabulary. This output can then be used for language generation. @@ -229,3 +195,293 @@ def config(self) -> T5Conf: Please refer to :func:`torchtext.prototype.models.T5Bundle` for the usage. 
""" + +T5_BASE_ENCODER = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.base.encoder.pt"), + _config=T5Conf(encoder_only=True), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_BASE_ENCODER.__doc__ = ENCODER_DOC.format("BASE", "base") + +T5_BASE = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.base.pt"), + _config=T5Conf(encoder_only=False), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_BASE.__doc__ = MODEL_DOC.format("BASE", "base") + +T5_BASE_GENERATION = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.base.generation.pt"), + _config=T5Conf(encoder_only=False, linear_head=True), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_BASE_GENERATION.__doc__ = GENERATION_DOC.format("BASE", "base") + +T5_SMALL_ENCODER = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.small.encoder.pt"), + _config=T5Conf( + encoder_only=True, + embedding_dim=512, + num_attention_heads=8, + num_encoder_layers=6, + num_decoder_layers=6, + ffn_dimension=2048, + ), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_SMALL_ENCODER.__doc__ = ENCODER_DOC.format("SMALL", "small") + + +T5_SMALL = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.small.pt"), + _config=T5Conf( + encoder_only=False, + embedding_dim=512, + num_attention_heads=8, + num_encoder_layers=6, + num_decoder_layers=6, + ffn_dimension=2048, + ), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_SMALL.__doc__ = MODEL_DOC.format("SMALL", "small") + +T5_SMALL_GENERATION = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.small.generation.pt"), + _config=T5Conf( + encoder_only=False, + linear_head=True, + embedding_dim=512, + num_attention_heads=8, + num_encoder_layers=6, + num_decoder_layers=6, + ffn_dimension=2048, + ), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_SMALL_GENERATION.__doc__ = GENERATION_DOC.format("SMALL", "small") + +T5_LARGE_ENCODER = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.large.encoder.pt"), + _config=T5Conf( + encoder_only=True, + embedding_dim=1024, + num_attention_heads=16, + num_encoder_layers=24, + num_decoder_layers=24, + ffn_dimension=4096, + ), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_LARGE_ENCODER.__doc__ = ENCODER_DOC.format("LARGE", "large") + +T5_LARGE = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.large.pt"), + _config=T5Conf( + encoder_only=False, + embedding_dim=1024, + num_attention_heads=16, + num_encoder_layers=24, + num_decoder_layers=24, + ffn_dimension=4096, + ), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_LARGE.__doc__ = MODEL_DOC.format("LARGE", "large") + +T5_LARGE_GENERATION = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.large.generation.pt"), + _config=T5Conf( + encoder_only=False, + linear_head=True, + embedding_dim=1024, + num_attention_heads=16, + num_encoder_layers=24, + num_decoder_layers=24, + ffn_dimension=4096, + ), + 
transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_LARGE_GENERATION.__doc__ = GENERATION_DOC.format("LARGE", "large") + +T5_3B_ENCODER = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.3b.encoder.pt"), + _config=T5Conf( + encoder_only=True, + embedding_dim=1024, + qkv_dim=128, + num_attention_heads=32, + num_encoder_layers=24, + num_decoder_layers=24, + ffn_dimension=16384, + ), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_3B_ENCODER.__doc__ = ENCODER_DOC.format("3B", "3B") + +T5_3B = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.3b.pt"), + _config=T5Conf( + encoder_only=False, + embedding_dim=1024, + qkv_dim=128, + num_attention_heads=32, + num_encoder_layers=24, + num_decoder_layers=24, + ffn_dimension=16384, + ), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_3B.__doc__ = MODEL_DOC.format("3B", "3B") + +T5_3B_GENERATION = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.3b.generation.pt"), + _config=T5Conf( + encoder_only=False, + linear_head=True, + embedding_dim=1024, + qkv_dim=128, + num_attention_heads=32, + num_encoder_layers=24, + num_decoder_layers=24, + ffn_dimension=16384, + ), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_3B_GENERATION.__doc__ = GENERATION_DOC.format("3B", "3B") + +T5_11B_ENCODER = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.11b.encoder.pt"), + _config=T5Conf( + encoder_only=True, + embedding_dim=1024, + qkv_dim=128, + num_attention_heads=128, + num_encoder_layers=24, + num_decoder_layers=24, + ffn_dimension=65536, + ), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_11B_ENCODER.__doc__ = ENCODER_DOC.format("11B", "11B") + +T5_11B = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.11b.pt"), + _config=T5Conf( + encoder_only=False, + embedding_dim=1024, + qkv_dim=128, + num_attention_heads=128, + num_encoder_layers=24, + num_decoder_layers=24, + ffn_dimension=65536, + ), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_11B.__doc__ = MODEL_DOC.format("11B", "11B") + +T5_11B_GENERATION = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.11b.generation.pt"), + _config=T5Conf( + encoder_only=False, + linear_head=True, + embedding_dim=1024, + qkv_dim=128, + num_attention_heads=128, + num_encoder_layers=24, + num_decoder_layers=24, + ffn_dimension=65536, + ), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), +) + +T5_11B_GENERATION.__doc__ = GENERATION_DOC.format("11B", "11B") diff --git a/torchtext/prototype/models/t5/model.py b/torchtext/prototype/models/t5/model.py index 7113dfd9d1..2812af3c74 100644 --- a/torchtext/prototype/models/t5/model.py +++ b/torchtext/prototype/models/t5/model.py @@ -13,6 +13,7 @@ class T5Conf: encoder_only: bool = False linear_head: bool = False embedding_dim: int = 768 + qkv_dim: int = 64 num_attention_heads: int = 12 num_encoder_layers: int = 12 num_decoder_layers: int = 12 @@ -39,6 +40,7 @@ class T5Model(nn.Module): config.encoder_only: Whether or not model should consist of only 
the encoder as opposed to encoder-decoder (default=False). config.linear_head: Whether or not a linear layer should be used to project the output of the decoder's last layer to the vocab (default=False). config.embedding_dim: Number of expected features in the encoder/decoder inputs (default=768). + config.qkv_dim: Projection dimension (per head) for query, keys, and values. (defualt=64). config.num_attention_heads: Number of heads in the multiheadattention models (default=12). config.num_encoder_layers: Number of encoder layers in the encoder (default=12). config.num_decoder_layers: Number of decoder layers in the decoder (default=12). @@ -92,6 +94,7 @@ def __init__( nhead=config.num_attention_heads, num_layers=config.num_encoder_layers, dim_feedforward=config.ffn_dimension, + qkv_dim=config.qkv_dim, dropout=self.dropout, activation=config.activation, layer_norm_eps=config.layer_norm_eps, @@ -110,6 +113,7 @@ def __init__( nhead=config.num_attention_heads, num_layers=config.num_decoder_layers, dim_feedforward=config.ffn_dimension, + qkv_dim=config.qkv_dim, dropout=self.dropout, activation=config.activation, layer_norm_eps=config.layer_norm_eps, diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py index 806099f4ea..f66036b96e 100644 --- a/torchtext/prototype/models/t5/modules.py +++ b/torchtext/prototype/models/t5/modules.py @@ -21,6 +21,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor +from torch.nn.modules.linear import NonDynamicallyQuantizableLinear class T5MultiheadAttention(nn.MultiheadAttention): @@ -31,8 +32,7 @@ def __init__( is_decoder: bool = False, dropout: float = 0.0, bias: bool = False, - kdim: Optional[int] = None, - vdim: Optional[int] = None, + qkv_dim: int = 64, compute_relative_attention_bias: bool = False, relative_attention_num_buckets: int = 32, relative_attention_max_distance: int = 128, @@ -46,8 +46,7 @@ def __init__( is_decoder: Whether or not multihead attention is being performed on a decoder layer. Default: `False` dropout: Probability of an element to be zeroed. Default: 0.0 bias: If specified, adds bias to input / output projection layers. Default: `False`. - kdim: Total number of features for keys. Default: `None` (uses `kdim=embed_dim`). - vdim: Total number of features for values. Default: `None` (uses `vdim=embed_dim`). + qkv_dim: Projection dimension (per head) for query, keys, and values. Defualt: 64. compute_relative_attention_bias: Whether or not the relative position embeddings need to be computed. Wypically occurs in the first layer of the encoder/decoder and the resulting position embeddings are returned to be passed up to higher layers. (defualt: False) @@ -55,12 +54,15 @@ def __init__( relative_attention_max_distance: Maximum threshold on the relative distance used to allocate buckets. Anything larger gets placed in the same bucket. 
Default: `128` """ - super().__init__(embed_dim, num_heads, dropout, bias, False, False, kdim, vdim, True, device, dtype) + super().__init__(embed_dim, num_heads, dropout, bias, False, False, qkv_dim, qkv_dim, True, device, dtype) factory_kwargs = {"device": device, "dtype": dtype} self.is_decoder = is_decoder - self.q_proj_weight = nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs)) - self.k_proj_weight = nn.Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs)) - self.v_proj_weight = nn.Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs)) + self.inner_dim = qkv_dim * num_heads + self.q_proj_weight = nn.Parameter(torch.empty((self.inner_dim, embed_dim), **factory_kwargs)) + self.k_proj_weight = nn.Parameter(torch.empty((self.inner_dim, embed_dim), **factory_kwargs)) + self.v_proj_weight = nn.Parameter(torch.empty((self.inner_dim, embed_dim), **factory_kwargs)) + self.out_proj = NonDynamicallyQuantizableLinear(self.inner_dim, embed_dim, bias=bias, **factory_kwargs) + self.register_parameter("in_proj_weight", None) self.compute_relative_attention_bias = compute_relative_attention_bias @@ -171,14 +173,7 @@ def _t5_multi_head_attention_forward( assert ( embed_dim == self.embed_dim ), f"was expecting embedding dimension of {self.embed_dim}, but got {embed_dim}" - if isinstance(embed_dim, Tensor): - # Embed_dim can be a tensor when JIT tracing - head_dim = embed_dim.div(self.num_heads, rounding_mode="trunc") - else: - head_dim = embed_dim // self.num_heads - assert ( - head_dim * self.num_heads == embed_dim - ), f"embed_dim {embed_dim} not divisible by num_heads {self.num_heads}" + head_dim = self.inner_dim // self.num_heads # Allow MHA to have different embedding dimensions when separate projection weights are used assert ( key.shape[:2] == value.shape[:2] @@ -192,7 +187,7 @@ def _t5_multi_head_attention_forward( b_q = b_k = b_v = None else: b_q, b_k, b_v = self.in_proj_bias.chunk(3) - q, k, v = F._in_projection( + q, k, v = self._t5_in_projection( query, key, value, self.q_proj_weight, self.k_proj_weight, self.v_proj_weight, b_q, b_k, b_v ) @@ -221,7 +216,7 @@ def _t5_multi_head_attention_forward( warnings.warn("Byte tensor for key_padding_mask is not supported. Using bool tensor instead.") key_padding_mask = key_padding_mask.to(torch.bool) - # Reshape q, k, v for multihead attention and make em batch first + # Reshape q, k, v for multihead attention and make them batch first q = q.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2) k = k.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2) v = v.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2) @@ -289,6 +284,72 @@ def _t5_multi_head_attention_forward( return attn_output, position_bias, None + # NOTE: Modified from https://github.com/pytorch/pytorch/blob/5953fd9133c0bdcc0158acf1472fac403bc5f636/torch/nn/functional.py#L4761 + def _t5_in_projection( + self, + q: Tensor, + k: Tensor, + v: Tensor, + w_q: Tensor, + w_k: Tensor, + w_v: Tensor, + b_q: Optional[Tensor] = None, + b_k: Optional[Tensor] = None, + b_v: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Tensor]: + r""" + Performs the in-projection step of the attention operation. This is simply + a triple of linear projections, with shape constraints on the weights which + ensure embedding dimension uniformity in the projected outputs. + Output is a triple containing projection tensors for query, key and value. + Args: + q, k, v: query, key and value tensors to be projected. + w_q, w_k, w_v: weights for q, k and v, respectively. 
+ b_q, b_k, b_v: optional biases for q, k and v, respectively. + Shape: + Inputs: + - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any + number of leading dimensions. + - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any + number of leading dimensions. + - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any + number of leading dimensions. + - w_q: :math:`(Ei, Eq)` where Ei is the dimension to which the query, key, and value + emebeddings are to be projected + - w_k: :math:`(Ei, Ek)` + - w_v: :math:`(Ei, Ev)` + - b_q: :math:`(Ei)` + - b_k: :math:`(Ei)` + - b_v: :math:`(Ei)` + Output: in output triple :math:`(q', k', v')`, + - q': :math:`[Qdims..., Ei]` + - k': :math:`[Kdims..., Ei]` + - v': :math:`[Vdims..., Ei]` + """ + Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1) + assert w_q.shape == ( + self.inner_dim, + Eq, + ), f"expecting query weights shape of {(self.inner_dim, Eq)}, but got {w_q.shape}" + assert w_k.shape == ( + self.inner_dim, + Ek, + ), f"expecting key weights shape of {(self.inner_dim, Ek)}, but got {w_k.shape}" + assert w_v.shape == ( + self.inner_dim, + Ev, + ), f"expecting value weights shape of {(self.inner_dim, Ev)}, but got {w_v.shape}" + assert b_q is None or b_q.shape == ( + self.inner_dim, + ), f"expecting query bias shape of {(self.inner_dim,)}, but got {b_q.shape}" + assert b_k is None or b_k.shape == ( + self.inner_dim, + ), f"expecting key bias shape of {(self.inner_dim,)}, but got {b_k.shape}" + assert b_v is None or b_v.shape == ( + self.inner_dim, + ), f"expecting value bias shape of {(self.inner_dim,)}, but got {b_v.shape}" + return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) + # NOTE: Modified from https://github.com/pytorch/pytorch/blob/5953fd9133c0bdcc0158acf1472fac403bc5f636/torch/nn/functional.py#L4814 def _t5_dot_product_attention( self, @@ -459,6 +520,7 @@ class T5EncoderLayer(nn.Module): d_model: Number of expected features in the input (required). nhead: Number of heads in the multihead attention models (required). dim_feedforward: Dimension of the feedforward network model (default=3072). + qkv_dim: Projection dimension (per head) for query, keys, and values. (defualt=64). dropout: Dropout value (default=0.1). activation: Activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable. (default: relu) @@ -481,6 +543,7 @@ def __init__( d_model: int, nhead: int, dim_feedforward: int = 3072, + qkv_dim: int = 64, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-6, @@ -501,6 +564,7 @@ def __init__( nhead, is_decoder=False, dropout=dropout, + qkv_dim=qkv_dim, compute_relative_attention_bias=compute_relative_attention_bias, relative_attention_num_buckets=relative_attention_num_buckets, relative_attention_max_distance=relative_attention_max_distance, @@ -599,6 +663,7 @@ class T5DecoderLayer(T5EncoderLayer): d_model: Number of expected features in the input (required). nhead: Number of heads in the multihead attention models (required). dim_feedforward: Dimension of the feedforward network model (default=3072). + qkv_dim: Projection dimension (per head) for query, keys, and values. (defualt=64). dropout: Dropout value (default=0.1). activation: Activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable. 
(default: relu) @@ -622,6 +687,7 @@ def __init__( d_model: int, nhead: int, dim_feedforward: int = 3072, + qkv_dim: int = 64, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-6, @@ -635,6 +701,7 @@ def __init__( d_model, nhead, dim_feedforward, + qkv_dim, dropout, activation, layer_norm_eps, @@ -646,7 +713,7 @@ def __init__( ) self.cross_attn = T5MultiheadAttention( - d_model, nhead, is_decoder=True, dropout=dropout, device=device, dtype=dtype + d_model, nhead, is_decoder=True, dropout=dropout, qkv_dim=qkv_dim, device=device, dtype=dtype ) self.norm3 = T5LayerNorm(d_model, eps=layer_norm_eps) self.dropout4 = nn.Dropout(dropout) @@ -721,6 +788,7 @@ class T5Encoder(nn.Module): nhead: Number of heads in the multihead attention models (required). num_layers: Number of encoder layers in the stack (required) dim_feedforward: Dimension of the feedforward network model (default=3072). + qkv_dim: Projection dimension (per head) for query, keys, and values. (defualt=64). dropout: Dropout value (default=0.1). activation: Activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable. (default: relu) @@ -740,6 +808,7 @@ def __init__( nhead: int, num_layers: int, dim_feedforward: int = 3072, + qkv_dim: int = 64, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-6, @@ -756,6 +825,7 @@ def __init__( d_model, nhead, dim_feedforward, + qkv_dim, dropout, activation, layer_norm_eps, @@ -811,6 +881,7 @@ class T5Decoder(nn.Module): nhead: Number of heads in the multihead attention models (required). num_layers: Number of decoder layers in the stack (required) dim_feedforward: Dimension of the feedforward network model (default=3072). + qkv_dim: Projection dimension (per head) for query, keys, and values. (defualt=64). dropout: Dropout value (default=0.1). activation: Activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable. (default: relu) @@ -831,6 +902,7 @@ def __init__( nhead: int, num_layers: int, dim_feedforward: int = 3072, + qkv_dim: int = 64, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-6, @@ -847,6 +919,7 @@ def __init__( d_model, nhead, dim_feedforward, + qkv_dim, dropout, activation, layer_norm_eps, diff --git a/torchtext/prototype/models/t5/wrapper.py b/torchtext/prototype/models/t5/wrapper.py index 1aab3d047c..3eafb727a4 100644 --- a/torchtext/prototype/models/t5/wrapper.py +++ b/torchtext/prototype/models/t5/wrapper.py @@ -4,7 +4,25 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from torchtext.prototype.models import T5_BASE_GENERATION, T5Conf, T5Transform, T5Bundle +from torchtext.prototype.models import ( + T5_BASE_GENERATION, + T5_SMALL_GENERATION, + T5_LARGE_GENERATION, + T5_3B_GENERATION, + T5_11B_GENERATION, + T5Conf, + T5Transform, + T5Bundle, +) + + +BUNDLERS = { + "base": T5_BASE_GENERATION, + "small": T5_SMALL_GENERATION, + "large": T5_LARGE_GENERATION, + "3b": T5_3B_GENERATION, + "11b": T5_11B_GENERATION, +} class T5Wrapper(nn.Module): @@ -21,7 +39,7 @@ def __init__( ) -> None: """ Args: - configuration (str or None): The model configuration. Currently only support 'base'. Must be `None` if checkpoint is not `None`. (Default: `None`) + configuration (str or None): The model configuration. Only support 'base', 'small', 'large', '3b', and '11b' . 
Must be `None` if checkpoint is not `None`. (Default: `None`) checkpoint (str, Dict[str, torch.Tensor], or None): Path to or actual model state_dict. state_dict can have partial weights, i.e. only for the encoder. Must be `None` if configuration is not `None`. (Default: ``None``) t5_config (T5Conf or None): An instance of T5Conf that defines the model configuration (i.e. number of layers, attention heads, etc.). Must be provided if configuration is `None`. (Default: `None`) transform (T5Transform or None): An instance of T5Transform that defines the text processing pipeline. Must be provided if configuration is `None`. (Default: `None`) @@ -42,7 +60,9 @@ def __init__( else: assert checkpoint is None, "configuration and checkpoint were both provided. Can only provide one." - assert configuration in ("base"), "Invalid configuration provided. Only support 'base' configuration." + assert ( + configuration in BUNDLERS + ), f"Invalid configuration provided. Only the following configurations are supported: {list(BUNDLERS.keys())}" if configuration is None and checkpoint is not None: self.bundler = T5Bundle(_path=checkpoint, _config=t5_config, transform=lambda: transform) @@ -50,7 +70,7 @@ def __init__( config=t5_config, freeze_model=freeze_model, checkpoint=checkpoint, strict=strict, dl_kwargs=dl_kwargs ) else: - self.bundler = T5_BASE_GENERATION + self.bundler = BUNDLERS[configuration] self.model = self.bundler.get_model() self.transform = self.bundler.transform()
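
A quick usage sketch of the newly exposed small/large bundles, mirroring what the parameterized test drives. The forward-call output keys ("encoder_output" / "decoder_output") are an assumption based on how the existing test helper reads results, not something shown in this diff:

import torch
from torchtext.prototype.models import T5_SMALL, T5_SMALL_ENCODER

# Encoder-decoder bundle: tokenize, run, read the decoder's final-layer states.
transform = T5_SMALL.transform()
model = T5_SMALL.get_model().eval()
model_input = transform(["Hello world", "Attention rocks!"])  # same sentences as the test
with torch.no_grad():
    decoder_states = model(model_input)["decoder_output"]  # assumed output key

# Encoder-only bundle: only the encoder's normalized final-layer output is returned.
encoder = T5_SMALL_ENCODER.get_model().eval()
with torch.no_grad():
    encoder_states = encoder(model_input)["encoder_output"]  # assumed output key

print(decoder_states.shape, encoder_states.shape)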
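
The wrapper path can be sanity-checked end to end with the values used in the updated test_t5_wrapper test; the prompt, beam size, max sequence length, and expected French translations all come from that test (3b and 11b are accepted as well but are too large for a casual local run):

import torch
from torchtext.prototype.models.t5.wrapper import T5Wrapper

model = T5Wrapper(configuration="small")  # any of "base", "small", "large", "3b", "11b"
model = torch.jit.script(model)  # the wrapper stays scriptable, per the "jit" test case

text = ["translate English to French: I want to eat pizza for dinner."]
# Per the updated expectations, the small checkpoint yields
# "Je veux manger la pizza pour le dîner.", while base/large yield
# "Je veux manger de la pizza pour le dîner."
print(model(text, 3, 512))  # beam_size=3, max_seq_len=512, matching the test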
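
On why the attention module gained qkv_dim instead of keeping kdim/vdim: for base/small/large the per-head projection size times the head count happens to equal embedding_dim, but for the new 3B and 11B configs it does not, which is presumably why head_dim is now derived from inner_dim = qkv_dim * num_attention_heads and out_proj is rebuilt with that inner dimension. A tiny self-contained check using the numbers copied from the configs above:

# Values copied from the T5Conf arguments in bundler.py; 3b/11b set qkv_dim=128
# explicitly, the others rely on the new default of 64.
configs = {
    "base": dict(embedding_dim=768, qkv_dim=64, num_attention_heads=12),
    "small": dict(embedding_dim=512, qkv_dim=64, num_attention_heads=8),
    "large": dict(embedding_dim=1024, qkv_dim=64, num_attention_heads=16),
    "3b": dict(embedding_dim=1024, qkv_dim=128, num_attention_heads=32),
    "11b": dict(embedding_dim=1024, qkv_dim=128, num_attention_heads=128),
}
for name, c in configs.items():
    inner_dim = c["qkv_dim"] * c["num_attention_heads"]
    # The old head_dim = embed_dim // num_heads only holds when inner_dim == embedding_dim.
    print(f"{name}: inner_dim={inner_dim}, equals embedding_dim: {inner_dim == c['embedding_dim']}")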