Test newly uploaded Flan-T5 weights #2074
Changes from all commits: 2b64d8b, 776aee6, 791dcb2, 4e57a3d, 91f5686, b85d228, dbeb4d0
@@ -122,7 +122,7 @@ The library currently consist of following pre-trained models:
* `DistilRoBERTa <https://github.com/huggingface/transformers/blob/main/examples/research_projects/distillation/README.md>`_
* XLM-RoBERTa: `Base and Large Architure <https://github.com/pytorch/fairseq/tree/main/examples/xlmr#pre-trained-models>`_
* T5: `Small, Base, Large, 3B, and 11B Architecture <https://github.com/google-research/text-to-text-transfer-transformer>`_
* Flan-T5: `Small, Base, Large, XL, and XXL Architecture <https://github.com/google-research/t5x>`_
Member (Author): We don't actually support Flan-T5 Small because its embedding dimension is not evenly divisible by its number of attention heads (see the sketch after this diff).
* Flan-T5: `Base, Large, XL, and XXL Architecture <https://github.com/google-research/t5x>`_

Tokenizers
==========
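A minimal sketch (not part of the PR) of the arithmetic behind dropping Flan-T5 Small, assuming the usual multi-head attention constraint that the embedding dimension is split evenly across heads; the numbers come from the removed FLAN_T5_SMALL config later in this diff (embedding_dim=512, num_attention_heads=6).

```python
# Hypothetical check for illustration only.
embedding_dim = 512          # from the removed FLAN_T5_SMALL config
num_attention_heads = 6      # from the removed FLAN_T5_SMALL config

head_dim, remainder = divmod(embedding_dim, num_attention_heads)
print(head_dim, remainder)   # 85 2 -> no integer per-head dimension, so the Small bundles were dropped
```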
@@ -1,10 +1,15 @@
import os
import tempfile

import pytest  # noqa: F401
import torch
from parameterized import parameterized_class
from torchtext.models import T5Bundle
from torchtext import _TEXT_BUCKET
from torchtext._download_hooks import _TEST_DOWNLOAD_MANAGER
from torchtext.models import (
    FLAN_T5_BASE,
    FLAN_T5_BASE_ENCODER,
    FLAN_T5_BASE_GENERATION,
    T5_BASE,
    T5_BASE_ENCODER,
    T5_BASE_GENERATION,
@@ -14,11 +19,11 @@
    T5_SMALL,
    T5_SMALL_ENCODER,
    T5_SMALL_GENERATION,
    T5Bundle,
)
from torchtext_unittest.common.assets import get_asset_path
from torchtext_unittest.common.parameterized_utils import nested_params
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
from transformers import T5EncoderModel, T5ForConditionalGeneration, T5Model

BUNDLERS = {
    "base_model": T5_BASE,
@@ -30,6 +35,9 @@
    "large_model": T5_LARGE,
    "large_encoder": T5_LARGE_ENCODER,
    "large_generation": T5_LARGE_GENERATION,
    "flan_base_encoder": FLAN_T5_BASE_ENCODER,
    "flan_base_model": FLAN_T5_BASE,
    "flan_base_generation": FLAN_T5_BASE_GENERATION,
}
@@ -45,6 +53,9 @@
        ("large_model",),
        ("large_encoder",),
        ("large_generation",),
        ("flan_base_encoder",),
        ("flan_base_model",),
        ("flan_base_generation",),
    ],
)
class TestT5Model(TorchtextTestCase):
@@ -74,126 +85,81 @@ def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):

    def _t5_get_encoder(self, model, model_input, encoder_output):
        encoder = model.get_encoder()
        # Need to set the tgt_key_padding_mask to ensure the same results
        # Need to set the key_padding_mask to ensure the same results
        encoder_padding_mask = model_input.eq(model.padding_idx)
        output_from_get_encoder = encoder(model_input, src_key_padding_mask=encoder_padding_mask)["encoder_output"]
        assert torch.all(output_from_get_encoder.eq(encoder_output))

    @nested_params(["jit", "not_jit"])
    @nested_params(["not_jit", "jit"])
    def test_t5_model(self, name) -> None:
        configuration, type = self.model_name.split("_")
        names = self.model_name.split("_")

        num_names = len(names)

        if num_names == 3:
            # Handled slightly differently for Flan-T5 model naming
            configuration = names[1]
            type = names[2]
            expected_asset_name = f"t5.flan.{configuration}.{type}.output.pt"
            t5_model = BUNDLERS["flan_" + configuration + "_" + type]
        elif num_names == 2:
            configuration = names[0]
            type = names[1]
            expected_asset_name = f"t5.{configuration}.{type}.output.pt"
            t5_model = BUNDLERS[configuration + "_" + type]
        else:
            raise RuntimeError(f"Unknown model name: {self.model_name}")

        expected_asset_name = f"t5.{configuration}.{type}.output.pt"
        test_text = ["Hello world", "Attention rocks!"]
        is_jit = name == "jit"
        t5_model = BUNDLERS[configuration + "_" + type]
        self._t5_model(is_jit=is_jit, t5_model=t5_model, expected_asset_name=expected_asset_name, test_text=test_text)
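For reference, a standalone sketch (not part of the PR) of the model-name resolution that test_t5_model above performs; the example names mirror the BUNDLERS keys in this diff.

```python
def resolve(model_name: str):
    # Mirrors the split logic in test_t5_model above.
    names = model_name.split("_")
    if len(names) == 3:  # Flan naming, e.g. "flan_base_encoder"
        configuration, kind = names[1], names[2]
        return f"t5.flan.{configuration}.{kind}.output.pt", f"flan_{configuration}_{kind}"
    configuration, kind = names[0], names[1]  # e.g. "base_encoder"
    return f"t5.{configuration}.{kind}.output.pt", f"{configuration}_{kind}"


print(resolve("flan_base_encoder"))  # ('t5.flan.base.encoder.output.pt', 'flan_base_encoder')
print(resolve("base_encoder"))       # ('t5.base.encoder.output.pt', 'base_encoder')
```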
@parameterized_class(
    ("model",),
    [
        ("hf_t5_small_encoder",),
        ("hf_t5_small",),
        ("hf_t5_small_generation",),
        ("hf_flan_base_encoder",),
        ("hf_flan_base",),
        ("hf_flan_base_generation",),
    ],
)
class TestLoadFromHFCheckpoints(TorchtextTestCase):
    def setUp(self) -> None:
        super().setUp()
        self.encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]])
        self.encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]])
        self.encoder_padding_mask = torch.tensor(
            [[False, False, False, False, False, False], [False, False, False, True, True, True]]
        )
        self.decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]])
        self.decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]])
    def check_outputs_of_models(self, our_output, hf_output, config, encoder_only) -> None:
        # check that encoder layers match
        for i in range(config.num_encoder_layers + 1):
            if i < config.num_encoder_layers:
                hf_output_sa = hf_output.attentions[i] if encoder_only else hf_output.encoder_attentions[i]
                # self-attention scores
                assert torch.equal(
                    our_output["encoder_sa_scores"][i], hf_output_sa
                ), f"Mismatched self-attention scores for encoder layer {i}"
            hf_output_hs = hf_output.hidden_states[i] if encoder_only else hf_output.encoder_hidden_states[i]
            # encoder hidden states
            assert torch.equal(
                our_output["encoder_hidden_states"][i], hf_output_hs
            ), f"Mismatched hidden states for encoder layer {i}"

        if not encoder_only:
            # check that decoder layers match
            for i in range(config.num_decoder_layers + 1):
                if i < config.num_encoder_layers:
                    # self-attention scores
                    assert torch.equal(
                        our_output["decoder_sa_scores"][i], hf_output.decoder_attentions[i]
                    ), f"Mismatched self-attention scores for decoder layer {i}"
                    # cross-attention scores
                    assert torch.equal(
                        our_output["decoder_ca_scores"][i], hf_output.cross_attentions[i]
                    ), f"Mismatched cross-attention scores for decoder layer {i}"
                # decoder hidden states
                assert torch.equal(
                    our_output["decoder_hidden_states"][i], hf_output.decoder_hidden_states[i]
                ), f"Mismatched hidden states for decoder layer {i}"
    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None:

Contributor: Why do we get rid of all of these tests?

Member (Author): We do the actual testing of the weights in the tests above; here we just want to make sure the code can be loaded and run.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = f"{tmp_dir}/hf_t5_small_enc"

            t5_small_enc = T5EncoderModel.from_pretrained("t5-small")
            t5_small_enc.save_pretrained(model_path)

            our_encoder = T5Bundle.build_model_from_huggingface_ckpt(model_path, encoder_only=True)

            hf_output = t5_small_enc(
                input_ids=self.encoder_input_ids,
                attention_mask=self.encoder_padding_mask,
                output_hidden_states=True,
                output_attentions=True,
            )

            our_output = our_encoder(self.encoder_input_ids)

            self.check_outputs_of_models(our_output, hf_output, our_encoder.config, True)

    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None:
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = f"{tmp_dir}/hf_t5_small"

            t5_small = T5Model.from_pretrained("t5-small")
            t5_small.save_pretrained(model_path)

            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)

            hf_output = t5_small(
                input_ids=self.encoder_input_ids,
                decoder_input_ids=self.decoder_input_ids,
                attention_mask=self.encoder_padding_mask,
                decoder_attention_mask=self.decoder_padding_mask,
                output_hidden_states=True,
                output_attentions=True,
            )

            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)

            self.check_outputs_of_models(our_output, hf_output, our_t5.config, False)

    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder_with_gen(self) -> None:
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = f"{tmp_dir}/hf_t5_small_gen"

            t5_small_gen = T5ForConditionalGeneration.from_pretrained("t5-small")
            t5_small_gen.save_pretrained(model_path)

            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)

            hf_output = t5_small_gen(
                input_ids=self.encoder_input_ids,
                decoder_input_ids=self.decoder_input_ids,
                attention_mask=self.encoder_padding_mask,
                decoder_attention_mask=self.decoder_padding_mask,
                output_hidden_states=True,
                output_attentions=True,
            )

            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)

            self.check_outputs_of_models(our_output, hf_output, our_t5.config, False)

    def test_flan_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None:
        # TODO(joecummings): Download FLAN-T5 chkpts and test here
        pass
        self.decoder_padding_mask = torch.tensor(
            [[False, False, False, True, True, True], [False, False, False, True, True, True]]
        )
    def test_t5_bundler_load_hf_ckpt_pretrained(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            local_path = f"{tmp}/{self.model}"
            remote_bucket = f"{_TEXT_BUCKET}test_models"

            os.mkdir(local_path)

            for f in {"config.json", "pytorch_model.bin"}:
                destination = f"{local_path}/{f}"
                remote_path = f"{remote_bucket}/{self.model}/{f}"
                _TEST_DOWNLOAD_MANAGER.get_local_path(url=remote_path, destination=destination)

            names = self.model.split("_")
            is_encoder_only = names[-1] == "encoder"

            model = T5Bundle.build_model_from_huggingface_ckpt(local_path, encoder_only=is_encoder_only)
            if is_encoder_only:
                model(self.encoder_input_ids, encoder_padding_mask=self.encoder_padding_mask)
            else:
                model(
                    self.encoder_input_ids,
                    self.decoder_input_ids,
                    encoder_padding_mask=self.encoder_padding_mask,
                    decoder_padding_mask=self.decoder_padding_mask,
                )
@@ -155,6 +155,7 @@ def build_model_from_huggingface_ckpt(
| """Build T5Model model from a HuggingFace checkpoint. | ||
|
|
||
| Note: Only works with Huggingface models saved in the PyTorch format. Will not work with TensorFlow or JAX. | ||
| This also requires a fully saved model, sharded checkpoints are not supported. | ||
|
|
||
| Args: | ||
| ckpt_path (str, Path): Path to the HF checkpoint file. Assumes that the file is local. | ||
|
|
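For context, a minimal usage sketch of this API, patterned on the removed tests in the test file above; it assumes the transformers package is available and that the public t5-small checkpoint can be downloaded.

```python
import tempfile

from torchtext.models import T5Bundle
from transformers import T5Model

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = f"{tmp_dir}/hf_t5_small"

    # Save a HuggingFace checkpoint in the PyTorch format (config.json + pytorch_model.bin).
    T5Model.from_pretrained("t5-small").save_pretrained(model_path)

    # Rebuild it as a torchtext T5 model from the local, non-sharded checkpoint.
    our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)
```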
@@ -238,12 +239,12 @@ def build_model_from_huggingface_ckpt(

        for i in range(config.num_decoder_layers):
            if config.is_gated_act:
                t5_model_state_dict[f"encoder.layers.{i}.linear1_0.weight"] = hf_weights[
Member (Author): This was a bug.

Member (Author): I need to confirm that the weights I generated before are correct.
| f"decoder.block.{i}.layer.1.DenseReluDense.wi_0.weight" | ||
| t5_model_state_dict[f"decoder.layers.{i}.linear1_0.weight"] = hf_weights[ | ||
| f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight" | ||
| ] | ||
|
|
||
| t5_model_state_dict[f"encoder.layers.{i}.linear1_1.weight"] = hf_weights[ | ||
| f"decoder.block.{i}.layer.1.DenseReluDense.wi_1.weight" | ||
| t5_model_state_dict[f"decoder.layers.{i}.linear1_1.weight"] = hf_weights[ | ||
| f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight" | ||
| ] | ||
| else: | ||
| t5_model_state_dict[f"decoder.layers.{i}.linear1.weight"] = hf_weights[ | ||
|
|
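A condensed restatement (not from the PR) of the corrected gated-activation mapping for the decoder feed-forward weights, using a hypothetical layer count; the decoder keys now read from decoder.block.{i}.layer.2 rather than the encoder-style layer.1 entries.

```python
num_decoder_layers = 12  # hypothetical count, for illustration only

# torchtext state-dict key -> HuggingFace checkpoint key, per the fixed lines above
corrected_decoder_ffn_keys = {
    f"decoder.layers.{i}.linear1_{j}.weight": f"decoder.block.{i}.layer.2.DenseReluDense.wi_{j}.weight"
    for i in range(num_decoder_layers)
    for j in (0, 1)
}
```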
@@ -650,56 +651,6 @@ def t5_transform() -> T5Transform:

T5_11B_GENERATION.__doc__ = GENERATION_DOC.format("11B", "11B")
FLAN_T5_SMALL_ENCODER = T5Bundle(
    _path=urljoin(_TEXT_BUCKET, "t5.flan.small.encoder.pt"),
    _config=T5Conf(
        encoder_only=True,
        embedding_dim=512,
        num_attention_heads=6,
        num_encoder_layers=8,
        num_decoder_layers=8,
        ffn_dimension=1024,
        feed_forward_proj="gated-gelu",
    ),
    transform=t5_transform,
)

FLAN_T5_SMALL_ENCODER.__doc__ = FLAN_ENCODER_DOC.format("SMALL", "SMALL")

FLAN_T5_SMALL = T5Bundle(
    _path=urljoin(_TEXT_BUCKET, "t5.flan.small.pt"),
    _config=T5Conf(
        encoder_only=False,
        embedding_dim=512,
        num_attention_heads=6,
        num_encoder_layers=8,
        num_decoder_layers=8,
        ffn_dimension=1024,
        feed_forward_proj="gated-gelu",
    ),
    transform=t5_transform,
)

FLAN_T5_SMALL.__doc__ = FLAN_DOC.format("SMALL", "SMALL")

FLAN_T5_SMALL_GENERATION = T5Bundle(
    _path=urljoin(_TEXT_BUCKET, "t5.flan.small.generation.pt"),
    _config=T5Conf(
        encoder_only=False,
        linear_head=True,
        embedding_dim=512,
        num_attention_heads=6,
        num_encoder_layers=8,
        num_decoder_layers=8,
        ffn_dimension=1024,
        feed_forward_proj="gated-gelu",
    ),
    transform=t5_transform,
)

FLAN_T5_SMALL_GENERATION.__doc__ = FLAN_GENERATION_DOC.format("SMALL", "SMALL")

FLAN_T5_BASE_ENCODER = T5Bundle(
    _path=urljoin(_TEXT_BUCKET, "t5.flan.base.encoder.pt"),
    _config=T5Conf(encoder_only=True, ffn_dimension=2048, feed_forward_proj="gated-gelu"),
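As a quick usage note, a hedged sketch of loading one of the remaining Flan-T5 bundles; it assumes the get_model() and transform() accessors that torchtext's other model bundles expose, and that the Flan weights are reachable from _TEXT_BUCKET.

```python
from torchtext.models import FLAN_T5_BASE_GENERATION

# Assumes the standard torchtext bundle interface; weights download on first use.
transform = FLAN_T5_BASE_GENERATION.transform()
model = FLAN_T5_BASE_GENERATION.get_model()
model.eval()
```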
@@ -762,7 +713,7 @@ def t5_transform() -> T5Transform:

FLAN_T5_LARGE_GENERATION = T5Bundle(
    _path=urljoin(_TEXT_BUCKET, "t5.flan.large.encoder.pt"),
    _path=urljoin(_TEXT_BUCKET, "t5.flan.large.generation.pt"),
    _config=T5Conf(
        encoder_only=False,
        linear_head=True,
We no longer need the transformers install for the integration tests.