From f883417c3bb43f77fcb3ea76aa3d6d2ef1a61fca Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Fri, 30 Sep 2022 17:12:05 -0400 Subject: [PATCH 1/7] Add ability to load HF checkpoints into T5 model --- .../prototype/test_models.py | 129 +++++++++++++++++- torchtext/prototype/models/t5/bundler.py | 125 +++++++++++++++++ 2 files changed, 253 insertions(+), 1 deletion(-) diff --git a/test/integration_tests/prototype/test_models.py b/test/integration_tests/prototype/test_models.py index 7743807031..3bbd32c886 100644 --- a/test/integration_tests/prototype/test_models.py +++ b/test/integration_tests/prototype/test_models.py @@ -1,3 +1,5 @@ +import tempfile + import pytest # noqa: F401 import torch from parameterized import parameterized, parameterized_class @@ -14,11 +16,12 @@ T5Conf, T5Transform, ) +from torchtext.prototype.models.t5.bundler import T5Bundle from torchtext.prototype.models.t5.wrapper import T5Wrapper from torchtext_unittest.common.assets import get_asset_path from torchtext_unittest.common.parameterized_utils import nested_params from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase - +from transformers import T5Model, T5EncoderModel, T5ForConditionalGeneration BUNDLERS = { "base_model": T5_BASE, @@ -135,3 +138,127 @@ def test_t5_wrapper_checkpoint(self, name) -> None: output_text = model(test_text, beam_size, max_seq_len) self.assertEqual(output_text, expected_text) + + +class TestLoadFromHFCheckpoints(TorchtextTestCase): + def check_outputs_of_models(self, our_output, hf_output, config, encoder_only) -> None: + if encoder_only: + # check that encoder layers match + for i in range(config.num_encoder_layers + 1): + if i < config.num_encoder_layers: + # self-attention scores + assert torch.equal( + our_output["encoder_sa_scores"][i], hf_output.attentions[i] + ), f"Mismatched self-attention scores for encoder layer {i}" + # encoder hidden states + assert torch.equal( + our_output["encoder_hidden_states"][i], hf_output.hidden_states[i] 
+ ), f"Mismatched hidden states for encoder layer {i}" + + else: + # check that encoder layers match + for i in range(config.num_encoder_layers + 1): + if i < config.num_encoder_layers: + # self-attention scores + assert torch.equal( + our_output["encoder_sa_scores"][i], hf_output.encoder_attentions[i] + ), f"Mismatched self-attention scores for encoder layer {i}" + # encoder hidden states + assert torch.equal( + our_output["encoder_hidden_states"][i], hf_output.encoder_hidden_states[i] + ), f"Mismatched hidden states for encoder layer {i}" + + # check that decoder layers match + for i in range(config.num_decoder_layers + 1): + if i < config.num_encoder_layers: + # self-attention scores + assert torch.equal( + our_output["decoder_sa_scores"][i], hf_output.decoder_attentions[i] + ), f"Mismatched self-attention scores for decoder layer {i}" + # cross-attention scores + assert torch.equal( + our_output["decoder_ca_scores"][i], hf_output.cross_attentions[i] + ), f"Mismatched cross-attention scores for decoder layer {i}" + # decoder hidden states + assert torch.equal( + our_output["decoder_hidden_states"][i], hf_output.decoder_hidden_states[i] + ), f"Mismatched hidden states for decoder layer {i}" + + def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None: + with tempfile.TemporaryDirectory() as tmp_dir: + model_path = f"{tmp_dir}/hf_t5_small_enc" + + t5_small_enc = T5EncoderModel.from_pretrained("t5-small") + t5_small_enc.save_pretrained(model_path) + + our_encoder = T5Bundle.build_model_from_huggingface_ckpt(model_path) + + encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]]) + encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]) + + hf_output = t5_small_enc( + input_ids=encoder_input_ids, + attention_mask=encoder_padding_mask, + output_hidden_states=True, + output_attentions=True, + ) + + our_output = our_encoder(encoder_input_ids) + + self.check_outputs_of_models(our_output, hf_output, 
our_encoder.config, encoder_only=True) + + def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None: + with tempfile.TemporaryDirectory() as tmp_dir: + model_path = f"{tmp_dir}/hf_t5_small" + + t5_small = T5Model.from_pretrained("t5-small") + t5_small.save_pretrained(model_path) + + encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]]) + encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]) + + decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]]) + decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]]) + + our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path) + + hf_output = t5_small( + input_ids=encoder_input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=encoder_padding_mask, + decoder_attention_mask=decoder_padding_mask, + output_hidden_states=True, + output_attentions=True, + ) + + our_output = our_t5(encoder_input_ids, decoder_input_ids) + + self.check_outputs_of_models(our_output, hf_output, our_t5.config, encoder_only=False) + + def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder_with_gen(self) -> None: + with tempfile.TemporaryDirectory() as tmp_dir: + model_path = f"{tmp_dir}/hf_t5_small_gen" + + t5_small_gen = T5ForConditionalGeneration.from_pretrained("t5-small") + t5_small_gen.save_pretrained(model_path) + + encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]]) + encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]) + + decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]]) + decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]]) + + our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path) + + hf_output = t5_small_gen( + input_ids=encoder_input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=encoder_padding_mask, + decoder_attention_mask=decoder_padding_mask, + output_hidden_states=True, + 
output_attentions=True, + ) + + our_output = our_t5(encoder_input_ids, decoder_input_ids) + + self.check_outputs_of_models(our_output, hf_output, our_t5.config, encoder_only=False) diff --git a/torchtext/prototype/models/t5/bundler.py b/torchtext/prototype/models/t5/bundler.py index 76b0d9bba9..c69aa60ecf 100644 --- a/torchtext/prototype/models/t5/bundler.py +++ b/torchtext/prototype/models/t5/bundler.py @@ -1,4 +1,6 @@ +import json import logging +import os from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Union from urllib.parse import urljoin @@ -133,6 +135,129 @@ def build_model( return model + @staticmethod + def build_model_from_huggingface_ckpt( + ckpt_path: Union[str, os.PathLike], + *, + freeze_model: bool = False, + strict: bool = True, + ) -> T5Model: + config_path = f"{ckpt_path}/config.json" + model_path = f"{ckpt_path}/pytorch_model.bin" + + assert os.path.exists(model_path), f"No PyTorch model found at {model_path}" + + with open(config_path, "r") as handle: + config_json = json.load(handle) + + hf_weights = torch.load(model_path) + + # TODO(joecummings): find better way to determine `encoder_only` and `linear_head` + config = T5Conf( + encoder_only="decoder.final_layer_norm.weight" not in hf_weights.keys(), + linear_head="lm_head.weight" in hf_weights.keys(), + embedding_dim=config_json["d_model"], + num_attention_heads=config_json["num_heads"], + num_encoder_layers=config_json["num_layers"], + num_decoder_layers=config_json["num_decoder_layers"], + ffn_dimension=config_json["d_ff"], + ) + + t5_model = T5Model(config, freeze_model) + + t5_model_state_dict = { + "token_embeddings.weight": hf_weights["shared.weight"], + "norm1.weight": hf_weights["encoder.final_layer_norm.weight"], + "encoder.layers.0.self_attn.relative_attention_bias.weight": hf_weights[ + "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" + ], + } + # Convert encoder layers + for i in range(config.num_encoder_layers): + 
t5_model_state_dict[f"encoder.layers.{i}.linear1.weight"] = hf_weights[ + f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight" + ] + t5_model_state_dict[f"encoder.layers.{i}.linear2.weight"] = hf_weights[ + f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight" + ] + t5_model_state_dict[f"encoder.layers.{i}.norm1.weight"] = hf_weights[ + f"encoder.block.{i}.layer.0.layer_norm.weight" + ] + t5_model_state_dict[f"encoder.layers.{i}.norm2.weight"] = hf_weights[ + f"encoder.block.{i}.layer.1.layer_norm.weight" + ] + t5_model_state_dict[f"encoder.layers.{i}.self_attn.out_proj.weight"] = hf_weights[ + f"encoder.block.{i}.layer.0.SelfAttention.o.weight" + ] + t5_model_state_dict[f"encoder.layers.{i}.self_attn.q_proj_weight"] = hf_weights[ + f"encoder.block.{i}.layer.0.SelfAttention.q.weight" + ] + t5_model_state_dict[f"encoder.layers.{i}.self_attn.k_proj_weight"] = hf_weights[ + f"encoder.block.{i}.layer.0.SelfAttention.k.weight" + ] + t5_model_state_dict[f"encoder.layers.{i}.self_attn.v_proj_weight"] = hf_weights[ + f"encoder.block.{i}.layer.0.SelfAttention.v.weight" + ] + + # Convert decoder layers if model is encoder-decoder + if not config.encoder_only: + t5_model_state_dict["norm2.weight"] = hf_weights["decoder.final_layer_norm.weight"] + t5_model_state_dict["decoder.layers.0.self_attn.relative_attention_bias.weight"] = hf_weights[ + "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" + ] + + for i in range(config.num_decoder_layers): + t5_model_state_dict[f"decoder.layers.{i}.linear1.weight"] = hf_weights[ + f"decoder.block.{i}.layer.2.DenseReluDense.wi.weight" + ] + t5_model_state_dict[f"decoder.layers.{i}.linear2.weight"] = hf_weights[ + f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight" + ] + t5_model_state_dict[f"decoder.layers.{i}.norm1.weight"] = hf_weights[ + f"decoder.block.{i}.layer.0.layer_norm.weight" + ] + t5_model_state_dict[f"decoder.layers.{i}.norm2.weight"] = hf_weights[ + f"decoder.block.{i}.layer.2.layer_norm.weight" + ] + 
t5_model_state_dict[f"decoder.layers.{i}.norm3.weight"] = hf_weights[ + f"decoder.block.{i}.layer.1.layer_norm.weight" + ] + + t5_model_state_dict[f"decoder.layers.{i}.self_attn.out_proj.weight"] = hf_weights[ + f"decoder.block.{i}.layer.0.SelfAttention.o.weight" + ] + t5_model_state_dict[f"decoder.layers.{i}.self_attn.q_proj_weight"] = hf_weights[ + f"decoder.block.{i}.layer.0.SelfAttention.q.weight" + ] + t5_model_state_dict[f"decoder.layers.{i}.self_attn.k_proj_weight"] = hf_weights[ + f"decoder.block.{i}.layer.0.SelfAttention.k.weight" + ] + t5_model_state_dict[f"decoder.layers.{i}.self_attn.v_proj_weight"] = hf_weights[ + f"decoder.block.{i}.layer.0.SelfAttention.v.weight" + ] + + t5_model_state_dict[f"decoder.layers.{i}.cross_attn.out_proj.weight"] = hf_weights[ + f"decoder.block.{i}.layer.1.EncDecAttention.o.weight" + ] + t5_model_state_dict[f"decoder.layers.{i}.cross_attn.q_proj_weight"] = hf_weights[ + f"decoder.block.{i}.layer.1.EncDecAttention.q.weight" + ] + t5_model_state_dict[f"decoder.layers.{i}.cross_attn.k_proj_weight"] = hf_weights[ + f"decoder.block.{i}.layer.1.EncDecAttention.k.weight" + ] + t5_model_state_dict[f"decoder.layers.{i}.cross_attn.v_proj_weight"] = hf_weights[ + f"decoder.block.{i}.layer.1.EncDecAttention.v.weight" + ] + + # Convert language modeling head if there is one + if config.linear_head: + t5_model_state_dict["lm_head.weight"] = hf_weights["lm_head.weight"] + + # Load state dict into our model + t5_model.load_state_dict(t5_model_state_dict, strict) + + return t5_model + @property def config(self) -> T5Conf: return self._config From 0713bfa10ad8830d93f673562869afe2395491fa Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Fri, 30 Sep 2022 17:52:08 -0400 Subject: [PATCH 2/7] Add HuggingFace to integrations tests --- .github/workflows/integration-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 0c005a12a1..61e802d848 
100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -24,7 +24,7 @@ jobs: run: | python -m pip install --quiet --upgrade pip python -m pip install --quiet --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - python -m pip install --quiet pytest requests cmake ninja sentencepiece parameterized tqdm expecttest + python -m pip install --quiet pytest requests cmake ninja sentencepiece parameterized tqdm expecttest transformers python setup.py install - name: Run integration test run: | From 4bc64ed47b78b8ec17d14405c14e0b8efe8103ad Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Fri, 30 Sep 2022 18:02:03 -0400 Subject: [PATCH 3/7] Remove duplicate code --- .../prototype/test_models.py | 84 +++++++------------ 1 file changed, 31 insertions(+), 53 deletions(-) diff --git a/test/integration_tests/prototype/test_models.py b/test/integration_tests/prototype/test_models.py index 3bbd32c886..330863d4d8 100644 --- a/test/integration_tests/prototype/test_models.py +++ b/test/integration_tests/prototype/test_models.py @@ -141,33 +141,26 @@ def test_t5_wrapper_checkpoint(self, name) -> None: class TestLoadFromHFCheckpoints(TorchtextTestCase): - def check_outputs_of_models(self, our_output, hf_output, config, encoder_only) -> None: - if encoder_only: - # check that encoder layers match - for i in range(config.num_encoder_layers + 1): - if i < config.num_encoder_layers: - # self-attention scores - assert torch.equal( - our_output["encoder_sa_scores"][i], hf_output.attentions[i] - ), f"Mismatched self-attention scores for encoder layer {i}" - # encoder hidden states - assert torch.equal( - our_output["encoder_hidden_states"][i], hf_output.hidden_states[i] - ), f"Mismatched hidden states for encoder layer {i}" + def setUp(self) -> None: + self.encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]]) + self.encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]) + 
self.decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]]) + self.decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]]) - else: - # check that encoder layers match - for i in range(config.num_encoder_layers + 1): - if i < config.num_encoder_layers: - # self-attention scores - assert torch.equal( - our_output["encoder_sa_scores"][i], hf_output.encoder_attentions[i] - ), f"Mismatched self-attention scores for encoder layer {i}" - # encoder hidden states + def check_outputs_of_models(self, our_output, hf_output, config, encoder_only) -> None: + # check that encoder layers match + for i in range(config.num_encoder_layers + 1): + if i < config.num_encoder_layers: + # self-attention scores assert torch.equal( - our_output["encoder_hidden_states"][i], hf_output.encoder_hidden_states[i] - ), f"Mismatched hidden states for encoder layer {i}" - + our_output["encoder_sa_scores"][i], hf_output.encoder_attentions[i] + ), f"Mismatched self-attention scores for encoder layer {i}" + # encoder hidden states + assert torch.equal( + our_output["encoder_hidden_states"][i], hf_output.encoder_hidden_states[i] + ), f"Mismatched hidden states for encoder layer {i}" + + if not encoder_only: # check that decoder layers match for i in range(config.num_decoder_layers + 1): if i < config.num_encoder_layers: @@ -193,17 +186,14 @@ def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None: our_encoder = T5Bundle.build_model_from_huggingface_ckpt(model_path) - encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]]) - encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]) - hf_output = t5_small_enc( - input_ids=encoder_input_ids, - attention_mask=encoder_padding_mask, + input_ids=self.encoder_input_ids, + attention_mask=self.encoder_padding_mask, output_hidden_states=True, output_attentions=True, ) - our_output = our_encoder(encoder_input_ids) + our_output = our_encoder(self.encoder_input_ids) 
self.check_outputs_of_models(our_output, hf_output, our_encoder.config, encoder_only=True) @@ -214,24 +204,18 @@ def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None: t5_small = T5Model.from_pretrained("t5-small") t5_small.save_pretrained(model_path) - encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]]) - encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]) - - decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]]) - decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]]) - our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path) hf_output = t5_small( - input_ids=encoder_input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=encoder_padding_mask, - decoder_attention_mask=decoder_padding_mask, + input_ids=self.encoder_input_ids, + decoder_input_ids=self.decoder_input_ids, + attention_mask=self.encoder_padding_mask, + decoder_attention_mask=self.decoder_padding_mask, output_hidden_states=True, output_attentions=True, ) - our_output = our_t5(encoder_input_ids, decoder_input_ids) + our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids) self.check_outputs_of_models(our_output, hf_output, our_t5.config, encoder_only=False) @@ -242,23 +226,17 @@ def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder_with_gen(self) -> No t5_small_gen = T5ForConditionalGeneration.from_pretrained("t5-small") t5_small_gen.save_pretrained(model_path) - encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]]) - encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]) - - decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]]) - decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]]) - our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path) hf_output = t5_small_gen( - input_ids=encoder_input_ids, - decoder_input_ids=decoder_input_ids, - 
attention_mask=encoder_padding_mask, - decoder_attention_mask=decoder_padding_mask, + input_ids=self.encoder_input_ids, + decoder_input_ids=self.decoder_input_ids, + attention_mask=self.encoder_padding_mask, + decoder_attention_mask=self.decoder_padding_mask, output_hidden_states=True, output_attentions=True, ) - our_output = our_t5(encoder_input_ids, decoder_input_ids) + our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids) self.check_outputs_of_models(our_output, hf_output, our_t5.config, encoder_only=False) From e4ba94dc2fc187b4c2ce6de5bfcde2dc8b817125 Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Fri, 30 Sep 2022 18:25:29 -0400 Subject: [PATCH 4/7] Revert fix --- test/integration_tests/prototype/test_models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/integration_tests/prototype/test_models.py b/test/integration_tests/prototype/test_models.py index 330863d4d8..e0d2d4261b 100644 --- a/test/integration_tests/prototype/test_models.py +++ b/test/integration_tests/prototype/test_models.py @@ -151,13 +151,15 @@ def check_outputs_of_models(self, our_output, hf_output, config, encoder_only) - # check that encoder layers match for i in range(config.num_encoder_layers + 1): if i < config.num_encoder_layers: + hf_output_sa = hf_output.attentions[i] if encoder_only else hf_output.encoder_attentions[i] # self-attention scores assert torch.equal( - our_output["encoder_sa_scores"][i], hf_output.encoder_attentions[i] + our_output["encoder_sa_scores"][i], hf_output_sa ), f"Mismatched self-attention scores for encoder layer {i}" + hf_output_hs = hf_output.hidden_states[i] if encoder_only else hf_output.encoder_hidden_states[i] # encoder hidden states assert torch.equal( - our_output["encoder_hidden_states"][i], hf_output.encoder_hidden_states[i] + our_output["encoder_hidden_states"][i], hf_output_hs ), f"Mismatched hidden states for encoder layer {i}" if not encoder_only: From 36108b9ac5db80d3fa10e3c712e352b12deba220 Mon Sep 
17 00:00:00 2001 From: Joe Cummings Date: Fri, 30 Sep 2022 18:41:35 -0400 Subject: [PATCH 5/7] Add setup --- test/integration_tests/prototype/test_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/integration_tests/prototype/test_models.py b/test/integration_tests/prototype/test_models.py index e0d2d4261b..3a76f52cd6 100644 --- a/test/integration_tests/prototype/test_models.py +++ b/test/integration_tests/prototype/test_models.py @@ -142,6 +142,7 @@ def test_t5_wrapper_checkpoint(self, name) -> None: class TestLoadFromHFCheckpoints(TorchtextTestCase): def setUp(self) -> None: + super().setUp() self.encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]]) self.encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]) self.decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]]) From a624031648ed74613b2261dfc9bf2ceb37a730e3 Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Mon, 3 Oct 2022 14:13:02 -0400 Subject: [PATCH 6/7] Remove ability to download from remote URL --- torchtext/prototype/models/t5/bundler.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/torchtext/prototype/models/t5/bundler.py b/torchtext/prototype/models/t5/bundler.py index c69aa60ecf..23d611c5a8 100644 --- a/torchtext/prototype/models/t5/bundler.py +++ b/torchtext/prototype/models/t5/bundler.py @@ -142,14 +142,25 @@ def build_model_from_huggingface_ckpt( freeze_model: bool = False, strict: bool = True, ) -> T5Model: + """Build T5Model model from a HuggingFace checkpoint. + + Note: Only works with Huggingface models saved in the PyTorch format. Will not work \ + with TensorFlow or JAX. + + Args: + ckpt_path (str, Path): Path to the HF checkpoint file. Assumes that the file \ + is local. + freeze_model (bool): Freeze the model upon loading. (Default: `False`) + strict (bool): Load model in strict mode. 
(Default: `True`) + + Returns: + T5Model loaded with the weights of the HuggingFace checkpoint provided + """ config_path = f"{ckpt_path}/config.json" model_path = f"{ckpt_path}/pytorch_model.bin" - assert os.path.exists(model_path), f"No PyTorch model found at {model_path}" - with open(config_path, "r") as handle: config_json = json.load(handle) - hf_weights = torch.load(model_path) # TODO(joecummings): find better way to determine `encoder_only` and `linear_head` From 1cce3f077710a3d18aceee86ae9ab158aa1d094a Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Tue, 4 Oct 2022 18:54:17 -0400 Subject: [PATCH 7/7] Remove line break from docstring --- torchtext/prototype/models/t5/bundler.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchtext/prototype/models/t5/bundler.py b/torchtext/prototype/models/t5/bundler.py index 23d611c5a8..65c94dc63e 100644 --- a/torchtext/prototype/models/t5/bundler.py +++ b/torchtext/prototype/models/t5/bundler.py @@ -144,12 +144,10 @@ def build_model_from_huggingface_ckpt( ) -> T5Model: """Build T5Model model from a HuggingFace checkpoint. - Note: Only works with Huggingface models saved in the PyTorch format. Will not work \ - with TensorFlow or JAX. + Note: Only works with HuggingFace models saved in the PyTorch format. Will not work with TensorFlow or JAX. Args: - ckpt_path (str, Path): Path to the HF checkpoint file. Assumes that the file \ - is local. + ckpt_path (str, Path): Path to the HF checkpoint directory containing `config.json` and `pytorch_model.bin`. Assumes that the directory is local. freeze_model (bool): Freeze the model upon loading. (Default: `False`) strict (bool): Load model in strict mode. (Default: `True`)