 import io
 from torchtext.utils import download_from_url, extract_archive
 from torchtext.experimental.datasets.raw.common import RawTextIterableDataset
-from torchtext.experimental.datasets.raw.common import check_default_set
-from torchtext.experimental.datasets.raw.common import wrap_datasets
+from torchtext.experimental.datasets.raw.common import wrap_split_argument
+from torchtext.experimental.datasets.raw.common import add_docstring_header

 URLS = {
     'WikiText2':
 }


-def _setup_datasets(dataset_name, root, split_, year, language, offset):
-    if dataset_name == 'WMTNewsCrawl':
-        split = check_default_set(split_, ('train',), dataset_name)
-    else:
-        split = check_default_set(split_, ('train', 'test', 'valid'), dataset_name)
-
+def _setup_datasets(dataset_name, root, split, year, language, offset):
     if dataset_name == 'PennTreebank':
         extracted_files = [download_from_url(URLS['PennTreebank'][key],
                                              root=root, hash_value=MD5['PennTreebank'][key],
@@ -49,23 +44,13 @@ def _setup_datasets(dataset_name, root, split_, year, language, offset):
         datasets.append(RawTextIterableDataset(dataset_name,
                         NUM_LINES[dataset_name][item], iter(io.open(path[item], encoding="utf8")), offset=offset))

-    return wrap_datasets(tuple(datasets), split_)
+    return datasets


+@wrap_split_argument
+@add_docstring_header
 def WikiText2(root='.data', split=('train', 'valid', 'test'), offset=0):
-    """ Defines WikiText2 datasets.
-
-    Create language modeling dataset: WikiText2
-    Separately returns the train/test/valid set
-
-    Args:
-        root: Directory where the datasets are saved. Default: ".data"
-        split: a string or tuple for the returned datasets. Default: ('train', 'valid', 'test')
-            By default, all the three datasets (train, test, valid) are generated. Users
-            could also choose any one or two of them, for example ('train', 'test') or
-            just a string 'train'.
-        offset: the number of the starting line. Default: 0
-
+    """
     Examples:
         >>> from torchtext.experimental.datasets.raw import WikiText2
         >>> train_dataset, valid_dataset, test_dataset = WikiText2()
@@ -76,19 +61,10 @@ def WikiText2(root='.data', split=('train', 'valid', 'test'), offset=0):
     return _setup_datasets("WikiText2", root, split, None, None, offset)


+@wrap_split_argument
+@add_docstring_header
 def WikiText103(root='.data', split=('train', 'valid', 'test'), offset=0):
-    """ Defines WikiText103 datasets.
-
-    Create language modeling dataset: WikiText103
-    Separately returns the train/test/valid set
-
-    Args:
-        root: Directory where the datasets are saved. Default: ".data"
-        split: the returned datasets. Default: ('train', 'valid', 'test')
-            By default, all the three datasets (train, test, valid) are generated. Users
-            could also choose any one or two of them, for example ('train', 'test').
-        offset: the number of the starting line. Default: 0
-
+    """
     Examples:
         >>> from torchtext.experimental.datasets.raw import WikiText103
         >>> train_dataset, valid_dataset, test_dataset = WikiText103()
@@ -98,21 +74,10 @@ def WikiText103(root='.data', split=('train', 'valid', 'test'), offset=0):
     return _setup_datasets("WikiText103", root, split, None, None, offset)


+@wrap_split_argument
+@add_docstring_header
 def PennTreebank(root='.data', split=('train', 'valid', 'test'), offset=0):
-    """ Defines PennTreebank datasets.
-
-    Create language modeling dataset: PennTreebank
-    Separately returns the train/test/valid set
-
-    Args:
-        root: Directory where the datasets are saved. Default: ".data"
-        split: a string or tuple for the returned datasets
-            (Default: ('train', 'test', 'valid'))
-            By default, all the three datasets ('train', 'valid', 'test') are generated. Users
-            could also choose any one or two of them, for example ('train', 'test') or
-            just a string 'train'.
-        offset: the number of the starting line. Default: 0
-
+    """
     Examples:
         >>> from torchtext.experimental.datasets.raw import PennTreebank
         >>> train_dataset, valid_dataset, test_dataset = PennTreebank()
@@ -123,18 +88,11 @@ def PennTreebank(root='.data', split=('train', 'valid', 'test'), offset=0):
     return _setup_datasets("PennTreebank", root, split, None, None, offset)


-def WMTNewsCrawl(root='.data', split=('train'), year=2010, language='en', offset=0):
-    """ Defines WMT News Crawl.
-
-    Create language modeling dataset: WMTNewsCrawl
-
-    Args:
-        root: Directory where the datasets are saved. Default: ".data"
-        split: a string or tuple for the returned datasets.
-            (Default: 'train')
-        year: the year of the dataset (Default: 2010)
+@wrap_split_argument
+@add_docstring_header
+def WMTNewsCrawl(root='.data', split='train', offset=0, year=2010, language='en'):
+    """ year: the year of the dataset (Default: 2010)
         language: the language of the dataset (Default: 'en')
-        offset: the number of the starting line. Default: 0

     Note: WMTNewsCrawl provides datasets based on the year and language instead of train/valid/test.
     """
@@ -148,12 +106,14 @@ def WMTNewsCrawl(root='.data', split=('train'), year=2010, language='en', offset
     'PennTreebank': PennTreebank,
     'WMTNewsCrawl': WMTNewsCrawl
 }
+
 NUM_LINES = {
     'WikiText2': {'train': 36718, 'valid': 3760, 'test': 4358},
     'WikiText103': {'train': 1801350, 'valid': 3760, 'test': 4358},
     'PennTreebank': {'train': 42068, 'valid': 3370, 'test': 3761},
     'WMTNewsCrawl': {'train': 17676013}
 }
+
 MD5 = {
     'WikiText2': '542ccefacc6c27f945fb54453812b3cd',
     'WikiText103': '9ddaacaf6af0710eda8c456decff7832',
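
The two decorators this diff starts using, wrap_split_argument and add_docstring_header, live in torchtext.experimental.datasets.raw.common and are not shown here. As a rough sketch only, assuming they absorb the removed check_default_set/wrap_datasets logic (normalizing the split argument, unwrapping single-split results, and restoring the shared Args docstring section), they could look like the following; none of these bodies are confirmed by the diff:

import functools
import inspect


def wrap_split_argument(fn):
    # Hypothetical sketch. Read the wrapped function's own default for
    # `split` so a dataset with a single split (e.g. WMTNewsCrawl's
    # 'train') keeps its default when the caller passes nothing.
    default_split = inspect.signature(fn).parameters['split'].default

    @functools.wraps(fn)
    def wrapper(root='.data', split=None, **kwargs):
        if split is None:
            split = default_split
        if isinstance(split, str):
            # Normalize a bare string to a 1-tuple, then unwrap the single
            # resulting dataset so the caller gets the object directly.
            return fn(root=root, split=(split,), **kwargs)[0]
        return fn(root=root, split=tuple(split), **kwargs)
    return wrapper


def add_docstring_header(fn):
    # Hypothetical sketch: prepend the shared Args header that this diff
    # deletes from each individual docstring.
    header = (
        "    Args:\n"
        '        root: Directory where the datasets are saved. Default: ".data"\n'
        "        split: a string or tuple for the returned datasets.\n"
        "        offset: the number of the starting line. Default: 0\n\n"
    )
    fn.__doc__ = header + (fn.__doc__ or "")
    return fn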
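A quick usage sketch of the resulting surface, again assuming the semantics above; the only public signature change in this diff is WMTNewsCrawl, where offset now precedes year and language:

from torchtext.experimental.datasets.raw import WikiText2, WMTNewsCrawl

# Default: all three splits come back in order.
train, valid, test = WikiText2(root='.data')

# A bare string returns one dataset rather than a 1-tuple
# (assuming wrap_split_argument unwraps single-split results).
train_only = WikiText2(split='train')

# WMTNewsCrawl exposes only a 'train' split; year and language select the file.
news_train = WMTNewsCrawl(split='train', year=2010, language='en')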