diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index 39a4911d6f..1e507a407f 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -55,7 +55,6 @@ jobs:
           python3 -m pip --quiet install sentencepiece
           python3 -m pip --quiet install tqdm
           python3 -m pip --quiet install expecttest
-          python3 -m pip --quiet install transformers
           # Run Tests
           python3 -m torch.utils.collect_env
           cd test
diff --git a/README.rst b/README.rst
index 6b757af521..35399cb8db 100644
--- a/README.rst
+++ b/README.rst
@@ -122,7 +122,7 @@ The library currently consist of following pre-trained models:
 * `DistilRoBERTa `_
 * XLM-RoBERTa: `Base and Large Architure `_
 * T5: `Small, Base, Large, 3B, and 11B Architecture `_
-* Flan-T5: `Small, Base, Large, XL, and XXL Architecture `_
+* Flan-T5: `Base, Large, XL, and XXL Architecture `_
 
 Tokenizers
 ==========
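
The README change above narrows the advertised Flan-T5 lineup to the sizes this diff actually ships (the Small bundles are deleted from bundler.py below). For orientation, a minimal usage sketch for one of the remaining bundles; it assumes the same `get_model()`/`transform()` bundle API that torchtext documents for the plain T5 sizes:

    from torchtext.models import FLAN_T5_BASE_GENERATION

    # Instantiate the pre-trained model and its matching text transform.
    model = FLAN_T5_BASE_GENERATION.get_model()
    transform = FLAN_T5_BASE_GENERATION.transform()

    # Encode a batch of prompts into token ids for the model.
    model_input = transform(["translate English to German: Hello world"])
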
f"t5.{configuration}.{type}.output.pt" + t5_model = BUNDLERS[configuration + "_" + type] + else: + raise RuntimeError(f"Unknown model name: {self.model_name}") - expected_asset_name = f"t5.{configuration}.{type}.output.pt" test_text = ["Hello world", "Attention rocks!"] is_jit = name == "jit" - t5_model = BUNDLERS[configuration + "_" + type] self._t5_model(is_jit=is_jit, t5_model=t5_model, expected_asset_name=expected_asset_name, test_text=test_text) +@parameterized_class( + ("model",), + [ + ("hf_t5_small_encoder",), + ("hf_t5_small",), + ("hf_t5_small_generation",), + ("hf_flan_base_encoder",), + ("hf_flan_base",), + ("hf_flan_base_generation",), + ], +) class TestLoadFromHFCheckpoints(TorchtextTestCase): def setUp(self) -> None: super().setUp() self.encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]]) - self.encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]) + self.encoder_padding_mask = torch.tensor( + [[False, False, False, False, False, False], [False, False, False, True, True, True]] + ) self.decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]]) - self.decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]]) - - def check_outputs_of_models(self, our_output, hf_output, config, encoder_only) -> None: - # check that encoder layers match - for i in range(config.num_encoder_layers + 1): - if i < config.num_encoder_layers: - hf_output_sa = hf_output.attentions[i] if encoder_only else hf_output.encoder_attentions[i] - # self-attention scores - assert torch.equal( - our_output["encoder_sa_scores"][i], hf_output_sa - ), f"Mismatched self-attention scores for encoder layer {i}" - hf_output_hs = hf_output.hidden_states[i] if encoder_only else hf_output.encoder_hidden_states[i] - # encoder hidden states - assert torch.equal( - our_output["encoder_hidden_states"][i], hf_output_hs - ), f"Mismatched hidden states for encoder layer {i}" - - if not encoder_only: - # check that decoder layers match - for i in range(config.num_decoder_layers + 1): - if i < config.num_encoder_layers: - # self-attention scores - assert torch.equal( - our_output["decoder_sa_scores"][i], hf_output.decoder_attentions[i] - ), f"Mismatched self-attention scores for decoder layer {i}" - # cross-attention scores - assert torch.equal( - our_output["decoder_ca_scores"][i], hf_output.cross_attentions[i] - ), f"Mismatched cross-attention scores for decoder layer {i}" - # decoder hidden states - assert torch.equal( - our_output["decoder_hidden_states"][i], hf_output.decoder_hidden_states[i] - ), f"Mismatched hidden states for decoder layer {i}" - - def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None: - with tempfile.TemporaryDirectory() as tmp_dir: - model_path = f"{tmp_dir}/hf_t5_small_enc" - - t5_small_enc = T5EncoderModel.from_pretrained("t5-small") - t5_small_enc.save_pretrained(model_path) - - our_encoder = T5Bundle.build_model_from_huggingface_ckpt(model_path, encoder_only=True) - - hf_output = t5_small_enc( - input_ids=self.encoder_input_ids, - attention_mask=self.encoder_padding_mask, - output_hidden_states=True, - output_attentions=True, - ) - - our_output = our_encoder(self.encoder_input_ids) - - self.check_outputs_of_models(our_output, hf_output, our_encoder.config, True) - - def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None: - with tempfile.TemporaryDirectory() as tmp_dir: - model_path = f"{tmp_dir}/hf_t5_small" - - t5_small = T5Model.from_pretrained("t5-small") - 
diff --git a/test/torchtext_unittest/asset/t5.flan.base.encoder.output.pt b/test/torchtext_unittest/asset/t5.flan.base.encoder.output.pt
new file mode 100644
index 0000000000..75c8e3e4f3
Binary files /dev/null and b/test/torchtext_unittest/asset/t5.flan.base.encoder.output.pt differ
diff --git a/test/torchtext_unittest/asset/t5.flan.base.generation.output.pt b/test/torchtext_unittest/asset/t5.flan.base.generation.output.pt
new file mode 100644
index 0000000000..74126636a5
Binary files /dev/null and b/test/torchtext_unittest/asset/t5.flan.base.generation.output.pt differ
diff --git a/test/torchtext_unittest/asset/t5.flan.base.model.output.pt b/test/torchtext_unittest/asset/t5.flan.base.model.output.pt
new file mode 100644
index 0000000000..bd2a5aa0b9
Binary files /dev/null and b/test/torchtext_unittest/asset/t5.flan.base.model.output.pt differ
diff --git a/torchtext/_download_hooks.py b/torchtext/_download_hooks.py
index 505320efae..89baafafa5 100644
--- a/torchtext/_download_hooks.py
+++ b/torchtext/_download_hooks.py
@@ -59,3 +59,4 @@ def get_local_path(self, url, destination):
 
 
 _DATASET_DOWNLOAD_MANAGER = DownloadManager()
+_TEST_DOWNLOAD_MANAGER = DownloadManager()
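
`_TEST_DOWNLOAD_MANAGER` is simply a second `DownloadManager` instance, reserved for test assets so that test traffic stays separate from the dataset download path. The checkpoint test above uses it roughly as sketched here; the `test_models/<model>/{config.json,pytorch_model.bin}` bucket layout is an assumption read off that test, not a documented contract:

    import os

    from torchtext import _TEXT_BUCKET
    from torchtext._download_hooks import _TEST_DOWNLOAD_MANAGER

    model_name = "hf_t5_small"  # hypothetical, mirroring the parameterized cases
    os.makedirs(f"/tmp/{model_name}", exist_ok=True)  # destination dir must exist
    for f in ("config.json", "pytorch_model.bin"):
        _TEST_DOWNLOAD_MANAGER.get_local_path(
            url=f"{_TEXT_BUCKET}test_models/{model_name}/{f}",
            destination=f"/tmp/{model_name}/{f}",
        )
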
diff --git a/torchtext/models/t5/__init__.py b/torchtext/models/t5/__init__.py
index 5f7b4a275a..9340649dfe 100644
--- a/torchtext/models/t5/__init__.py
+++ b/torchtext/models/t5/__init__.py
@@ -1,7 +1,4 @@
 from .bundler import (
-    FLAN_T5_SMALL_ENCODER,
-    FLAN_T5_SMALL,
-    FLAN_T5_SMALL_GENERATION,
     FLAN_T5_BASE_ENCODER,
     FLAN_T5_BASE,
     FLAN_T5_BASE_GENERATION,
@@ -53,9 +50,6 @@
     "T5_11B_ENCODER",
     "T5_11B",
     "T5_11B_GENERATION",
-    "FLAN_T5_SMALL_ENCODER",
-    "FLAN_T5_SMALL",
-    "FLAN_T5_SMALL_GENERATION",
     "FLAN_T5_BASE_ENCODER",
     "FLAN_T5_BASE",
     "FLAN_T5_BASE_GENERATION",
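
For context on the bundler fix below: in HuggingFace's T5 module layout, an encoder block is [self-attention, feed-forward] (feed-forward at `layer.1`), while a decoder block is [self-attention, cross-attention, feed-forward] (feed-forward at `layer.2`). The old gated-act mapping both wrote into *encoder* destination keys and read from the encoder's feed-forward index. A sketch of the key pairs the corrected mapping produces, with a hypothetical `num_decoder_layers`:

    # (torchtext destination key, HF source key) pairs for the gated
    # feed-forward projections of each decoder layer.
    def gated_ffn_key_pairs(num_decoder_layers: int):
        for i in range(num_decoder_layers):
            for j in (0, 1):
                yield (
                    f"decoder.layers.{i}.linear1_{j}.weight",
                    f"decoder.block.{i}.layer.2.DenseReluDense.wi_{j}.weight",
                )
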
diff --git a/torchtext/models/t5/bundler.py b/torchtext/models/t5/bundler.py
index bf140191a3..f20c13d47b 100644
--- a/torchtext/models/t5/bundler.py
+++ b/torchtext/models/t5/bundler.py
@@ -155,6 +155,7 @@ def build_model_from_huggingface_ckpt(
        """Build T5Model model from a HuggingFace checkpoint.
 
        Note: Only works with Huggingface models saved in the PyTorch format. Will not work with TensorFlow or JAX.
+       This also requires a fully saved model; sharded checkpoints are not supported.
 
        Args:
            ckpt_path (str, Path): Path to the HF checkpoint file. Assumes that the file is local.
@@ -238,12 +239,12 @@ def build_model_from_huggingface_ckpt(
 
        for i in range(config.num_decoder_layers):
            if config.is_gated_act:
-                t5_model_state_dict[f"encoder.layers.{i}.linear1_0.weight"] = hf_weights[
-                    f"decoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"
+                t5_model_state_dict[f"decoder.layers.{i}.linear1_0.weight"] = hf_weights[
+                    f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"
                ]
-                t5_model_state_dict[f"encoder.layers.{i}.linear1_1.weight"] = hf_weights[
-                    f"decoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"
+                t5_model_state_dict[f"decoder.layers.{i}.linear1_1.weight"] = hf_weights[
+                    f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"
                ]
            else:
                t5_model_state_dict[f"decoder.layers.{i}.linear1.weight"] = hf_weights[
@@ -650,56 +651,6 @@ def t5_transform() -> T5Transform:
 T5_11B_GENERATION.__doc__ = GENERATION_DOC.format("11B", "11B")
 
-
-FLAN_T5_SMALL_ENCODER = T5Bundle(
-    _path=urljoin(_TEXT_BUCKET, "t5.flan.small.encoder.pt"),
-    _config=T5Conf(
-        encoder_only=True,
-        embedding_dim=512,
-        num_attention_heads=6,
-        num_encoder_layers=8,
-        num_decoder_layers=8,
-        ffn_dimension=1024,
-        feed_forward_proj="gated-gelu",
-    ),
-    transform=t5_transform,
-)
-
-FLAN_T5_SMALL_ENCODER.__doc__ = FLAN_ENCODER_DOC.format("SMALL", "SMALL")
-
-FLAN_T5_SMALL = T5Bundle(
-    _path=urljoin(_TEXT_BUCKET, "t5.flan.small.pt"),
-    _config=T5Conf(
-        encoder_only=False,
-        embedding_dim=512,
-        num_attention_heads=6,
-        num_encoder_layers=8,
-        num_decoder_layers=8,
-        ffn_dimension=1024,
-        feed_forward_proj="gated-gelu",
-    ),
-    transform=t5_transform,
-)
-
-FLAN_T5_SMALL.__doc__ = FLAN_DOC.format("SMALL", "SMALL")
-
-FLAN_T5_SMALL_GENERATION = T5Bundle(
-    _path=urljoin(_TEXT_BUCKET, "t5.flan.small.generation.pt"),
-    _config=T5Conf(
-        encoder_only=False,
-        linear_head=True,
-        embedding_dim=512,
-        num_attention_heads=6,
-        num_encoder_layers=8,
-        num_decoder_layers=8,
-        ffn_dimension=1024,
-        feed_forward_proj="gated-gelu",
-    ),
-    transform=t5_transform,
-)
-
-FLAN_T5_SMALL_GENERATION.__doc__ = FLAN_GENERATION_DOC.format("SMALL", "SMALL")
-
 FLAN_T5_BASE_ENCODER = T5Bundle(
     _path=urljoin(_TEXT_BUCKET, "t5.flan.base.encoder.pt"),
     _config=T5Conf(encoder_only=True, ffn_dimension=2048, feed_forward_proj="gated-gelu"),
@@ -762,7 +713,7 @@ def t5_transform() -> T5Transform:
 FLAN_T5_LARGE_GENERATION = T5Bundle(
-    _path=urljoin(_TEXT_BUCKET, "t5.flan.large.encoder.pt"),
+    _path=urljoin(_TEXT_BUCKET, "t5.flan.large.generation.pt"),
     _config=T5Conf(
         encoder_only=False,
         linear_head=True,
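
Putting the loader changes together, end-to-end usage looks roughly like this; the checkpoint directory path is hypothetical, and per the docstring note added above it must contain `config.json` plus a single, unsharded `pytorch_model.bin`:

    from torchtext.models import T5Bundle

    # Directory produced by HF's save_pretrained(); must hold config.json
    # and an unsharded pytorch_model.bin.
    model = T5Bundle.build_model_from_huggingface_ckpt(
        "/path/to/hf_checkpoint",
        encoder_only=False,
    )
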