
Commit c60704a

Update base for Update on "create T5MultiheadAttention module"
# Description

Add the T5 architecture to torchtext.

# Process

The T5 architecture is very similar to that of a traditional transformer. The main differences are that, rather than using positional embeddings, T5 computes a relative attention bias that encodes the relative position of a token within a sequence. This position bias is then passed into each layer and used to compute the attention scores. T5 also uses a simplified layer normalization (root-mean-square normalization), which occurs at the start of every attention and feed-forward block.

Incorporating the relative attention bias requires under-the-hood changes to the MultiheadAttention module. We can use HF's implementation for computing the relative attention bias and modify the source code of torch.nn.MultiheadAttention to incorporate it. We can also create our own layer normalization, similar to HF's. Given the above components, we can then define our own T5Layer, T5Stack, and T5Model.

* T5Layer can be used as either an encoder layer or a decoder layer based on an input boolean parameter. The only difference between the two is that the decoder layer also performs cross-attention with the encoder output.
* T5Stack can also be used as either an encoder or a decoder based on an input boolean parameter, which dictates which type of layer composes the stack.
* T5Model can be used as either an encoder-only or an encoder-decoder model based on an input boolean parameter. If it is an encoder-decoder model, a causal mask is generated for the decoder input tokens.

# Testing

Not yet implemented.

# Stack

WIP PR where implementation details were discussed: #1812

[ghstack-poisoned]
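Not part of this commit's diff: to make the description above concrete, here is a minimal PyTorch sketch of the two pieces it calls out, the root-mean-square layer normalization and the relative attention bias, following the HF-style formulation mentioned above. All names (`RMSLayerNorm`, `relative_position_bucket`, `compute_position_bias`) and default values are illustrative, not the modules added by the stacked PRs.

```python
import math

import torch
import torch.nn as nn


class RMSLayerNorm(nn.Module):
    """T5-style layer norm: rescale by the reciprocal root-mean-square, no mean centering, no bias."""

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)


def relative_position_bucket(
    relative_position: torch.Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
) -> torch.Tensor:
    """Map signed key-minus-query offsets to a small set of buckets: exact buckets for
    short distances, logarithmically coarser buckets for long ones."""
    relative_buckets = torch.zeros_like(relative_position)
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).long() * num_buckets
        relative_position = relative_position.abs()
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    if_large = torch.min(if_large, torch.full_like(if_large, num_buckets - 1))
    return relative_buckets + torch.where(is_small, relative_position, if_large)


def compute_position_bias(bias_embedding: nn.Embedding, query_length: int, key_length: int) -> torch.Tensor:
    """Embed the bucketed relative positions into a (1, num_heads, q_len, k_len) bias.

    `bias_embedding` is assumed to have `num_buckets` rows and `num_heads` columns."""
    context_position = torch.arange(query_length)[:, None]
    memory_position = torch.arange(key_length)[None, :]
    buckets = relative_position_bucket(memory_position - context_position)
    values = bias_embedding(buckets)  # (q_len, k_len, num_heads)
    return values.permute(2, 0, 1).unsqueeze(0)
```

In each T5Layer, this bias would be added to the raw attention scores before the softmax, and the same tensor can be passed down the stack so it only needs to be computed once.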
2 parents f7c8046 + e964051 commit c60704a

File tree

11 files changed: +69 −41 lines


.circleci/config.yml

Lines changed: 1 addition & 0 deletions
(Generated file; diff not rendered.)

.circleci/config.yml.in

Lines changed: 1 addition & 0 deletions
```diff
@@ -57,6 +57,7 @@ binary_common: &binary_common
     BUILD_VERSION: << parameters.build_version >>
     PYTORCH_VERSION: << parameters.pytorch_version >>
     CU_VERSION: cpu
+    MACOSX_DEPLOYMENT_TARGET: 10.9

 smoke_test_common: &smoke_test_common
   <<: *binary_common
```

.circleci/unittest/linux/scripts/environment.yml

Lines changed: 0 additions & 2 deletions
```diff
@@ -13,8 +13,6 @@ dependencies:
   - pytest-pythonpath
   - sacremoses
   - spacy
-  - sphinx
-  - sphinx-rtd-theme
   - tqdm
   - expecttest
   - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0
```

.circleci/unittest/windows/scripts/environment.yml

Lines changed: 0 additions & 3 deletions
```diff
@@ -14,11 +14,8 @@ dependencies:
   - pytest-pythonpath
   - sacremoses
   - spacy
-  - sphinx
-  - sphinx-rtd-theme
   - tqdm
   - certifi
-  - future
   - expecttest
   - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0
   - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0
```

benchmark/benchmark_torcharrow_ops.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -1,6 +1,7 @@
 import sys, os

 import torcharrow as ta
+import torcharrow.pytorch as tap
 import torchtext.transforms as T
 from benchmark.utils import Timer
 from torcharrow import functional as ta_F
@@ -26,6 +27,7 @@ def run_torchtext_ops():
     add_eos_str = T.AddToken(token="<eros>", begin=False)
     add_bos_int = T.AddToken(token=0, begin=True)
     add_eos_int = T.AddToken(token=-1, begin=False)
+    convert_to_tensor = T.ToTensor(padding_value=1)

     # dataset
     train_dp = SST2(split="train")
@@ -45,6 +47,9 @@ def run_torchtext_ops():
     add_bos_int(token_ids)
     add_eos_int(token_ids)

+    with Timer("Running torchtext's to tensor conversion"):
+        convert_to_tensor(token_ids)
+

 def run_torcharrow_ops():
     # tokenizer converting text into tokens
@@ -56,7 +61,8 @@ def run_torcharrow_ops():
     # dataset
     train_dp = SST2(split="train")
     text_list = list(train_dp.map(lambda x: x[0]))
-    data_frame = ta.dataframe({"text": text_list})
+    with Timer("Converting python data to TA data frame"):
+        data_frame = ta.dataframe({"text": text_list})

     with Timer("Running torcharrow's GPT2BPE tokenizer"):
         data_frame["tokens"] = ta_F.bpe_tokenize(tokenizer, data_frame["text"])
@@ -72,6 +78,9 @@ def run_torcharrow_ops():
     ta_F.add_tokens(data_frame["token_ids"], [0], begin=True)
     ta_F.add_tokens(data_frame["token_ids"], [-1], begin=False)

+    with Timer("Running torcharrow's to tensor conversion"):
+        data_frame.to_tensor({"token_ids": tap.PadSequence(padding_value=1)})
+

 if __name__ == "__main__":
     run_torchtext_ops()
```
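For reference (not part of the diff), the `T.ToTensor` transform being timed pads ragged token-id lists with `padding_value` and stacks them into a single tensor; a small illustrative snippet with made-up token ids:

```python
import torchtext.transforms as T

# same padding value as in the benchmark above
convert_to_tensor = T.ToTensor(padding_value=1)

# the shorter list is padded up to the length of the longest one
batch = convert_to_tensor([[9906, 11, 4435], [9906, 11]])
print(batch.shape)  # expected: torch.Size([2, 3]), with the padding value 1 in the last slot
```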

packaging/torchtext/meta.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -29,6 +29,7 @@ build:
   string: py{{py}}
   script_env:
     - BUILD_VERSION
+    - MACOSX_DEPLOYMENT_TARGET

 test:
   imports:
```

requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -12,7 +12,6 @@ git+https://github.com/jekbradbury/revtok.git
1212

1313
# Documentation
1414
Sphinx
15-
sphinx_rtd_theme
1615

1716
# Required for tests only:
1817

test/datasets/test_cnndm.py

Lines changed: 9 additions & 9 deletions
```diff
@@ -41,7 +41,7 @@ def _get_mock_dataset(root_dir):
             stories.append((txt_file, dataset_line))
             seed += 2

-        # append stories to correct dataset split, must be in legixographic order of filenames per dataset
+        # append stories to correct dataset split, must be in lexicographic order of filenames per dataset
         stories.sort(key=lambda x: x[0])
         mocked_data[split] += [t[1] for t in stories]

@@ -70,15 +70,14 @@ def tearDownClass(cls):
         cls.patcher.stop()
         super().tearDownClass()

-    def _mock_split_list(split):
+    def _mock_split_list(source, split):
         story_fnames = []
-        for source in ["cnn", "dailymail"]:
-            for i in range(5):
-                url = "_".join([source, split, str(i)])
-                h = hashlib.sha1()
-                h.update(url.encode())
-                filename = h.hexdigest() + ".story"
-                story_fnames.append(filename)
+        for i in range(5):
+            url = "_".join([source, split, str(i)])
+            h = hashlib.sha1()
+            h.update(url.encode())
+            filename = h.hexdigest() + ".story"
+            story_fnames.append(filename)

         return story_fnames

@@ -92,6 +91,7 @@ def test_cnndm(self, split):
         self.assertEqual(sample, expected_sample)

     @parameterized.expand(["train", "val", "test"])
+    @patch("torchtext.datasets.cnndm._get_split_list", _mock_split_list)
     def test_cnndm_split_argument(self, split):
         dataset1 = CNNDM(root=self.root_dir, split=split)
         (dataset2,) = CNNDM(root=self.root_dir, split=(split,))
```

torchtext/datasets/cnndm.py

Lines changed: 41 additions & 22 deletions
```diff
@@ -1,5 +1,6 @@
 import hashlib
 import os
+from collections import defaultdict
 from functools import partial
 from typing import Union, Tuple

@@ -20,9 +21,12 @@
 DATASET_NAME = "CNNDM"

 URL_LIST = {
-    "train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt",
-    "val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt",
-    "test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt",
+    "cnn_train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_training_urls.txt",
+    "cnn_val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_validation_urls.txt",
+    "cnn_test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_test_urls.txt",
+    "dailymail_train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_training_urls.txt",
+    "dailymail_val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_validation_urls.txt",
+    "dailymail_test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_test_urls.txt",
 }

 STORIES_LIST = {
@@ -39,24 +43,34 @@

 _EXTRACTED_FOLDERS = {
     "cnn": os.path.join("cnn", "stories"),
-    "daily_mail": os.path.join("dailymail", "stories"),
+    "dailymail": os.path.join("dailymail", "stories"),
 }

+story_fnames = defaultdict(set)
+

 def _filepath_fn(root: str, source: str, _=None):
     return os.path.join(root, PATH_LIST[source])


-# this function will be used to cache the contents of the tar file
-def _extracted_filepath_fn(root: str, source: str):
-    return os.path.join(root, _EXTRACTED_FOLDERS[source])
+# called once per tar file, therefore no duplicate processing
+def _extracted_folder_fn(root: str, source: str, split: str, _=None):
+    global story_fnames
+    key = source + "_" + split
+    story_fnames[key] = set(_get_split_list(source, split))
+    filepaths = [os.path.join(root, _EXTRACTED_FOLDERS[source], story) for story in story_fnames[key]]
+    return filepaths
+

+def _extracted_filepath_fn(root: str, source: str, x: str):
+    return os.path.join(root, _EXTRACTED_FOLDERS[source], os.path.basename(x))

-def _filter_fn(story_fnames, x):
-    return os.path.basename(x[0]) in story_fnames

+def _filter_fn(source: str, split: str, x: tuple):
+    return os.path.basename(x[0]) in story_fnames[source + "_" + split]

-def _hash_urls(s):
+
+def _hash_urls(s: tuple):
     """
     Returns story filename as a heximal formated SHA1 hash of the input url string.
     Code is inspired from https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py
@@ -69,23 +83,32 @@ def _hash_urls(s):
     return story_fname


-def _get_split_list(split: str):
-    url_dp = IterableWrapper([URL_LIST[split]])
+def _get_split_list(source: str, split: str):
+    url_dp = IterableWrapper([URL_LIST[source + "_" + split]])
     online_dp = OnlineReader(url_dp)
     return online_dp.readlines().map(fn=_hash_urls)


-def _load_stories(root: str, source: str):
+def _load_stories(root: str, source: str, split: str):
     story_dp = IterableWrapper([STORIES_LIST[source]])
     cache_compressed_dp = story_dp.on_disk_cache(
         filepath_fn=partial(_filepath_fn, root, source),
         hash_dict={_filepath_fn(root, source): STORIES_MD5[source]},
         hash_type="md5",
     )
     cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
-    # TODO: cache the contents of the extracted tar file
-    cache_decompressed_dp = FileOpener(cache_compressed_dp, mode="b").load_from_tar()
-    return cache_decompressed_dp
+
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=partial(_extracted_folder_fn, root, source, split)
+    )
+    cache_decompressed_dp = (
+        FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, source, split))
+    )
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(
+        mode="wb", filepath_fn=partial(_extracted_filepath_fn, root, source)
+    )
+    data_dp = FileOpener(cache_decompressed_dp, mode="b")
+    return data_dp


 @_create_dataset_directory(dataset_name=DATASET_NAME)
@@ -119,11 +142,7 @@ def CNNDM(root: str, split: Union[Tuple[str], str]):
             "Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
         )

-    cnn_dp = _load_stories(root, "cnn")
-    dailymail_dp = _load_stories(root, "dailymail")
+    cnn_dp = _load_stories(root, "cnn", split)
+    dailymail_dp = _load_stories(root, "dailymail", split)
     data_dp = cnn_dp.concat(dailymail_dp)
-    # TODO: store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn
-    # of the on_disk_cache_dp which caches the files extracted from the tar
-    story_fnames = set(_get_split_list(split))
-    data_dp = data_dp.filter(partial(_filter_fn, story_fnames))
     return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter()
```
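A hypothetical usage sketch (not part of the diff): with the changes above, the first pass over a split triggers the download, the per-source tar extraction filtered to that split's stories, and the on-disk caching of the extracted files.

```python
from torchtext.datasets import CNNDM

# root is wherever the archives and extracted stories should be cached;
# repeated runs reuse the cached files instead of re-extracting the tars
val_dp = CNNDM(root=".data", split="val")

for sample in val_dp:
    print(sample)  # one parsed story record from parse_cnndm_data
    break
```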

torchtext/transforms.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -572,7 +572,9 @@ def __init__(
         self, vocab_path: str, do_lower_case: bool = True, strip_accents: Optional[bool] = None, return_tokens=False
     ) -> None:
         super().__init__()
-        self.bert_model = BERTEncoderPyBind(get_asset_local_path(vocab_path), do_lower_case, strip_accents)
+        self.bert_model = BERTEncoderPyBind(
+            get_asset_local_path(vocab_path, overwite=True), do_lower_case, strip_accents
+        )
         self._return_tokens = return_tokens
         self._vocab_path = vocab_path
         self._do_lower_case = do_lower_case
```
