+import tempfile
+
 import pytest  # noqa: F401
 import torch
 from parameterized import parameterized, parameterized_class
...
     T5Conf,
     T5Transform,
 )
+from torchtext.prototype.models.t5.bundler import T5Bundle
 from torchtext.prototype.models.t5.wrapper import T5Wrapper
 from torchtext_unittest.common.assets import get_asset_path
 from torchtext_unittest.common.parameterized_utils import nested_params
 from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
-
+from transformers import T5Model, T5EncoderModel, T5ForConditionalGeneration

 BUNDLERS = {
     "base_model": T5_BASE,
@@ -135,3 +138,108 @@ def test_t5_wrapper_checkpoint(self, name) -> None: |
 
         output_text = model(test_text, beam_size, max_seq_len)
         self.assertEqual(output_text, expected_text)
+
+
+class TestLoadFromHFCheckpoints(TorchtextTestCase):
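+    # Round-trip HuggingFace "t5-small" checkpoints through a temporary directory,
+    # rebuild them with T5Bundle.build_model_from_huggingface_ckpt, and compare
+    # per-layer attentions and hidden states against the HuggingFace reference outputs.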
+    def setUp(self) -> None:
+        super().setUp()
+        self.encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]])
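+        # HF-style attention masks: 1 marks a real token, 0 marks padding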
+        self.encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]])
+        self.decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]])
+        self.decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]])
+
+    def check_outputs_of_models(self, our_output, hf_output, config, encoder_only) -> None:
+        # check that encoder layers match
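+        # attentions have one tensor per layer, while hidden_states also include the
+        # embedding output, hence num_encoder_layers + 1 iterations with a guard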
+        for i in range(config.num_encoder_layers + 1):
+            if i < config.num_encoder_layers:
+                hf_output_sa = hf_output.attentions[i] if encoder_only else hf_output.encoder_attentions[i]
+                # self-attention scores
+                assert torch.equal(
+                    our_output["encoder_sa_scores"][i], hf_output_sa
+                ), f"Mismatched self-attention scores for encoder layer {i}"
+            hf_output_hs = hf_output.hidden_states[i] if encoder_only else hf_output.encoder_hidden_states[i]
+            # encoder hidden states
+            assert torch.equal(
+                our_output["encoder_hidden_states"][i], hf_output_hs
+            ), f"Mismatched hidden states for encoder layer {i}"
+
+        if not encoder_only:
+            # check that decoder layers match
+            for i in range(config.num_decoder_layers + 1):
+                if i < config.num_decoder_layers:
+                    # self-attention scores
+                    assert torch.equal(
+                        our_output["decoder_sa_scores"][i], hf_output.decoder_attentions[i]
+                    ), f"Mismatched self-attention scores for decoder layer {i}"
+                    # cross-attention scores
+                    assert torch.equal(
+                        our_output["decoder_ca_scores"][i], hf_output.cross_attentions[i]
+                    ), f"Mismatched cross-attention scores for decoder layer {i}"
+                # decoder hidden states
+                assert torch.equal(
+                    our_output["decoder_hidden_states"][i], hf_output.decoder_hidden_states[i]
+                ), f"Mismatched hidden states for decoder layer {i}"
+
+    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model_path = f"{tmp_dir}/hf_t5_small_enc"
+
+            t5_small_enc = T5EncoderModel.from_pretrained("t5-small")
+            t5_small_enc.save_pretrained(model_path)
+
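+            # build the torchtext T5 encoder from the HF checkpoint saved on disk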
+            our_encoder = T5Bundle.build_model_from_huggingface_ckpt(model_path)
+
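+            # request per-layer hidden states and attentions so they can be compared below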
+            hf_output = t5_small_enc(
+                input_ids=self.encoder_input_ids,
+                attention_mask=self.encoder_padding_mask,
+                output_hidden_states=True,
+                output_attentions=True,
+            )
+
+            our_output = our_encoder(self.encoder_input_ids)
+
+            self.check_outputs_of_models(our_output, hf_output, our_encoder.config, encoder_only=True)
+
+    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model_path = f"{tmp_dir}/hf_t5_small"
+
+            t5_small = T5Model.from_pretrained("t5-small")
+            t5_small.save_pretrained(model_path)
+
+            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)
+
+            hf_output = t5_small(
+                input_ids=self.encoder_input_ids,
+                decoder_input_ids=self.decoder_input_ids,
+                attention_mask=self.encoder_padding_mask,
+                decoder_attention_mask=self.decoder_padding_mask,
+                output_hidden_states=True,
+                output_attentions=True,
+            )
+
+            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)
+
+            self.check_outputs_of_models(our_output, hf_output, our_t5.config, encoder_only=False)
+
+    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder_with_gen(self) -> None:
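+        # T5ForConditionalGeneration adds an LM head on top of the encoder-decoder stack;
+        # the comparison below still covers only the shared attentions and hidden states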
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model_path = f"{tmp_dir}/hf_t5_small_gen"
+
+            t5_small_gen = T5ForConditionalGeneration.from_pretrained("t5-small")
+            t5_small_gen.save_pretrained(model_path)
+
+            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)
+
+            hf_output = t5_small_gen(
+                input_ids=self.encoder_input_ids,
+                decoder_input_ids=self.decoder_input_ids,
+                attention_mask=self.encoder_padding_mask,
+                decoder_attention_mask=self.decoder_padding_mask,
+                output_hidden_states=True,
+                output_attentions=True,
+            )
+
+            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)
+
+            self.check_outputs_of_models(our_output, hf_output, our_t5.config, encoder_only=False)