|
| 1 | +import os |
1 | 2 | import tempfile |
2 | 3 |
|
3 | 4 | import pytest # noqa: F401 |
4 | 5 | import torch |
5 | 6 | from parameterized import parameterized_class |
6 | | -from torchtext.models import T5Bundle |
| 7 | +from torchtext import _TEXT_BUCKET |
| 8 | +from torchtext._download_hooks import _TEST_DOWNLOAD_MANAGER |
7 | 9 | from torchtext.models import ( |
| 10 | + FLAN_T5_BASE, |
| 11 | + FLAN_T5_BASE_ENCODER, |
| 12 | + FLAN_T5_BASE_GENERATION, |
8 | 13 | T5_BASE, |
9 | 14 | T5_BASE_ENCODER, |
10 | 15 | T5_BASE_GENERATION, |
|
14 | 19 | T5_SMALL, |
15 | 20 | T5_SMALL_ENCODER, |
16 | 21 | T5_SMALL_GENERATION, |
| 22 | + T5Bundle, |
17 | 23 | ) |
18 | 24 | from torchtext_unittest.common.assets import get_asset_path |
19 | 25 | from torchtext_unittest.common.parameterized_utils import nested_params |
20 | 26 | from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase |
21 | | -from transformers import T5EncoderModel, T5ForConditionalGeneration, T5Model |
22 | 27 |
|
23 | 28 | BUNDLERS = { |
24 | 29 | "base_model": T5_BASE, |
|
30 | 35 | "large_model": T5_LARGE, |
31 | 36 | "large_encoder": T5_LARGE_ENCODER, |
32 | 37 | "large_generation": T5_LARGE_GENERATION, |
| 38 | + "flan_base_encoder": FLAN_T5_BASE_ENCODER, |
| 39 | + "flan_base_model": FLAN_T5_BASE, |
| 40 | + "flan_base_generation": FLAN_T5_BASE_GENERATION, |
33 | 41 | } |
34 | 42 |
|
35 | 43 |
|
|
45 | 53 | ("large_model",), |
46 | 54 | ("large_encoder",), |
47 | 55 | ("large_generation",), |
| 56 | + ("flan_base_encoder",), |
| 57 | + ("flan_base_model",), |
| 58 | + ("flan_base_generation",), |
48 | 59 | ], |
49 | 60 | ) |
50 | 61 | class TestT5Model(TorchtextTestCase): |
@@ -74,126 +85,81 @@ def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text): |
74 | 85 |
|
75 | 86 | def _t5_get_encoder(self, model, model_input, encoder_output): |
76 | 87 | encoder = model.get_encoder() |
77 | | - # Need to set the tgt_key_padding_mask to ensure the same results |
| 88 | + # Need to pass src_key_padding_mask to ensure the same results |
78 | 89 | encoder_padding_mask = model_input.eq(model.padding_idx) |
79 | 90 | output_from_get_encoder = encoder(model_input, src_key_padding_mask=encoder_padding_mask)["encoder_output"] |
80 | 91 | assert torch.all(output_from_get_encoder.eq(encoder_output)) |
81 | 92 |
|
82 | | - @nested_params(["jit", "not_jit"]) |
| 93 | + @nested_params(["not_jit", "jit"]) |
83 | 94 | def test_t5_model(self, name) -> None: |
84 | | - configuration, type = self.model_name.split("_") |
| 95 | + names = self.model_name.split("_") |
| 96 | + |
| 97 | + num_names = len(names) |
| 98 | + |
| 99 | + if num_names == 3: |
| 100 | + # Flan-T5 names have three parts ("flan_<configuration>_<type>"), so configuration and type are names[1] and names[2] |
| 101 | + configuration = names[1] |
| 102 | + type = names[2] |
| 103 | + expected_asset_name = f"t5.flan.{configuration}.{type}.output.pt" |
| 104 | + t5_model = BUNDLERS["flan_" + configuration + "_" + type] |
| 105 | + elif num_names == 2: |
| 106 | + configuration = names[0] |
| 107 | + type = names[1] |
| 108 | + expected_asset_name = f"t5.{configuration}.{type}.output.pt" |
| 109 | + t5_model = BUNDLERS[configuration + "_" + type] |
| 110 | + else: |
| 111 | + raise RuntimeError(f"Unknown model name: {self.model_name}") |
85 | 112 |
|
86 | | - expected_asset_name = f"t5.{configuration}.{type}.output.pt" |
87 | 113 | test_text = ["Hello world", "Attention rocks!"] |
88 | 114 | is_jit = name == "jit" |
89 | | - t5_model = BUNDLERS[configuration + "_" + type] |
90 | 115 | self._t5_model(is_jit=is_jit, t5_model=t5_model, expected_asset_name=expected_asset_name, test_text=test_text) |
91 | 116 |
|
92 | 117 |
|
| 118 | +@parameterized_class( |
| 119 | + ("model",), |
| 120 | + [ |
| 121 | + ("hf_t5_small_encoder",), |
| 122 | + ("hf_t5_small",), |
| 123 | + ("hf_t5_small_generation",), |
| 124 | + ("hf_flan_base_encoder",), |
| 125 | + ("hf_flan_base",), |
| 126 | + ("hf_flan_base_generation",), |
| 127 | + ], |
| 128 | +) |
93 | 129 | class TestLoadFromHFCheckpoints(TorchtextTestCase): |
94 | 130 | def setUp(self) -> None: |
95 | 131 | super().setUp() |
96 | 132 | self.encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]]) |
97 | | - self.encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]) |
| 133 | + self.encoder_padding_mask = torch.tensor( |
| 134 | + [[False, False, False, False, False, False], [False, False, False, True, True, True]] |
| 135 | + ) |
98 | 136 | self.decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]]) |
99 | | - self.decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]]) |
100 | | - |
101 | | - def check_outputs_of_models(self, our_output, hf_output, config, encoder_only) -> None: |
102 | | - # check that encoder layers match |
103 | | - for i in range(config.num_encoder_layers + 1): |
104 | | - if i < config.num_encoder_layers: |
105 | | - hf_output_sa = hf_output.attentions[i] if encoder_only else hf_output.encoder_attentions[i] |
106 | | - # self-attention scores |
107 | | - assert torch.equal( |
108 | | - our_output["encoder_sa_scores"][i], hf_output_sa |
109 | | - ), f"Mismatched self-attention scores for encoder layer {i}" |
110 | | - hf_output_hs = hf_output.hidden_states[i] if encoder_only else hf_output.encoder_hidden_states[i] |
111 | | - # encoder hidden states |
112 | | - assert torch.equal( |
113 | | - our_output["encoder_hidden_states"][i], hf_output_hs |
114 | | - ), f"Mismatched hidden states for encoder layer {i}" |
115 | | - |
116 | | - if not encoder_only: |
117 | | - # check that decoder layers match |
118 | | - for i in range(config.num_decoder_layers + 1): |
119 | | - if i < config.num_encoder_layers: |
120 | | - # self-attention scores |
121 | | - assert torch.equal( |
122 | | - our_output["decoder_sa_scores"][i], hf_output.decoder_attentions[i] |
123 | | - ), f"Mismatched self-attention scores for decoder layer {i}" |
124 | | - # cross-attention scores |
125 | | - assert torch.equal( |
126 | | - our_output["decoder_ca_scores"][i], hf_output.cross_attentions[i] |
127 | | - ), f"Mismatched cross-attention scores for decoder layer {i}" |
128 | | - # decoder hidden states |
129 | | - assert torch.equal( |
130 | | - our_output["decoder_hidden_states"][i], hf_output.decoder_hidden_states[i] |
131 | | - ), f"Mismatched hidden states for decoder layer {i}" |
132 | | - |
133 | | - def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None: |
134 | | - with tempfile.TemporaryDirectory() as tmp_dir: |
135 | | - model_path = f"{tmp_dir}/hf_t5_small_enc" |
136 | | - |
137 | | - t5_small_enc = T5EncoderModel.from_pretrained("t5-small") |
138 | | - t5_small_enc.save_pretrained(model_path) |
139 | | - |
140 | | - our_encoder = T5Bundle.build_model_from_huggingface_ckpt(model_path, encoder_only=True) |
141 | | - |
142 | | - hf_output = t5_small_enc( |
143 | | - input_ids=self.encoder_input_ids, |
144 | | - attention_mask=self.encoder_padding_mask, |
145 | | - output_hidden_states=True, |
146 | | - output_attentions=True, |
147 | | - ) |
148 | | - |
149 | | - our_output = our_encoder(self.encoder_input_ids) |
150 | | - |
151 | | - self.check_outputs_of_models(our_output, hf_output, our_encoder.config, True) |
152 | | - |
153 | | - def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None: |
154 | | - with tempfile.TemporaryDirectory() as tmp_dir: |
155 | | - model_path = f"{tmp_dir}/hf_t5_small" |
156 | | - |
157 | | - t5_small = T5Model.from_pretrained("t5-small") |
158 | | - t5_small.save_pretrained(model_path) |
159 | | - |
160 | | - our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path) |
161 | | - |
162 | | - hf_output = t5_small( |
163 | | - input_ids=self.encoder_input_ids, |
164 | | - decoder_input_ids=self.decoder_input_ids, |
165 | | - attention_mask=self.encoder_padding_mask, |
166 | | - decoder_attention_mask=self.decoder_padding_mask, |
167 | | - output_hidden_states=True, |
168 | | - output_attentions=True, |
169 | | - ) |
170 | | - |
171 | | - our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids) |
172 | | - |
173 | | - self.check_outputs_of_models(our_output, hf_output, our_t5.config, False) |
174 | | - |
175 | | - def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder_with_gen(self) -> None: |
176 | | - with tempfile.TemporaryDirectory() as tmp_dir: |
177 | | - model_path = f"{tmp_dir}/hf_t5_small_gen" |
178 | | - |
179 | | - t5_small_gen = T5ForConditionalGeneration.from_pretrained("t5-small") |
180 | | - t5_small_gen.save_pretrained(model_path) |
181 | | - |
182 | | - our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path) |
183 | | - |
184 | | - hf_output = t5_small_gen( |
185 | | - input_ids=self.encoder_input_ids, |
186 | | - decoder_input_ids=self.decoder_input_ids, |
187 | | - attention_mask=self.encoder_padding_mask, |
188 | | - decoder_attention_mask=self.decoder_padding_mask, |
189 | | - output_hidden_states=True, |
190 | | - output_attentions=True, |
191 | | - ) |
192 | | - |
193 | | - our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids) |
194 | | - |
195 | | - self.check_outputs_of_models(our_output, hf_output, our_t5.config, False) |
196 | | - |
197 | | - def test_flan_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None: |
198 | | - # TODO(joecummings): Download FLAN-T5 chkpts and test here |
199 | | - pass |
| 137 | + self.decoder_padding_mask = torch.tensor( |
| 138 | + [[False, False, False, True, True, True], [False, False, False, True, True, True]] |
| 139 | + ) |
| 140 | + |
| 141 | + def test_t5_bundler_load_hf_ckpt_pretrained(self) -> None: |
| 142 | + with tempfile.TemporaryDirectory() as tmp: |
| 143 | + local_path = f"{tmp}/{self.model}" |
| 144 | + remote_bucket = f"{_TEXT_BUCKET}test_models" |
| 145 | + |
| 146 | + os.mkdir(local_path) |
| 147 | + |
| 148 | + for f in {"config.json", "pytorch_model.bin"}: |
| 149 | + destination = f"{local_path}/{f}" |
| 150 | + remote_path = f"{remote_bucket}/{self.model}/{f}" |
| 151 | + _TEST_DOWNLOAD_MANAGER.get_local_path(url=remote_path, destination=destination) |
| 152 | + |
| 153 | + names = self.model.split("_") |
| 154 | + is_encoder_only = names[-1] == "encoder" |
| 155 | + |
| 156 | + model = T5Bundle.build_model_from_huggingface_ckpt(local_path, encoder_only=is_encoder_only) |
| 157 | + if is_encoder_only: |
| 158 | + model(self.encoder_input_ids, encoder_padding_mask=self.encoder_padding_mask) |
| 159 | + else: |
| 160 | + model( |
| 161 | + self.encoder_input_ids, |
| 162 | + self.decoder_input_ids, |
| 163 | + encoder_padding_mask=self.encoder_padding_mask, |
| 164 | + decoder_padding_mask=self.decoder_padding_mask, |
| 165 | + ) |
0 commit comments