|
| 1 | +import os.path |
| 2 | + |
| 3 | +from typing import Union, Tuple |
| 4 | + |
| 5 | +from torchtext._internal.module_utils import is_module_available |
| 6 | + |
| 7 | +if is_module_available("torchdata"): |
| 8 | + from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper |
| 9 | + |
| 10 | +from torchtext.data.datasets_utils import ( |
| 11 | + _create_dataset_directory, |
| 12 | + _wrap_split_argument |
| 13 | +) |
| 14 | + |
# Download URL template; ``%s`` is substituted with a CC-100 language code.
URL = "http://data.statmt.org/cc-100/%s.txt.xz"

# Language codes published for the CC-100 corpus, including romanized
# variants (e.g. "hi_rom", "ur_rom") and script variants ("zh-Hans", "zh-Hant").
VALID_CODES = {
    "am", "ar", "as", "az", "be", "bg", "bn", "bn_rom", "br", "bs", "ca", "cs", "cy", "da", "de",
    "el", "en", "eo", "es", "et", "eu", "fa", "ff", "fi", "fr", "fy", "ga", "gd", "gl", "gn", "gu",
    "ha", "he", "hi", "hi_rom", "hr", "ht", "hu", "hy", "id", "ig", "is", "it", "ja", "jv", "ka",
    "kk", "km", "kn", "ko", "ku", "ky", "la", "lg", "li", "ln", "lo", "lt", "lv", "mg", "mk", "ml",
    "mn", "mr", "ms", "my", "my_zaw", "ne", "nl", "no", "ns", "om", "or", "pa", "pl", "ps", "pt",
    "qu", "rm", "ro", "ru", "sa", "si", "sc", "sd", "sk", "sl", "so", "sq", "sr", "ss", "su", "sv",
    "sw", "ta", "ta_rom", "te", "te_rom", "th", "tl", "tn", "tr", "ug", "uk", "ur", "ur_rom", "uz",
    "vi", "wo", "xh", "yi", "yo", "zh-Hans", "zh-Hant", "zu",
}

# Per-dataset metadata kept as placeholders: line counts and archive MD5
# checksums are not published for CC-100, so no integrity check is performed.
NUM_LINES = None
MD5 = None

DATASET_NAME = "CC100"
| 32 | + |
| 33 | + |
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train",))
def CC100(root: str, split: Union[Tuple[str], str], language_code: str = "en"):
    """Build an iterable datapipe over the CC-100 corpus for one language.

    Downloads ``<language_code>.txt.xz``, caches both the compressed archive
    and its decompressed text under ``root``, and yields one tuple per line.

    Args:
        root: Directory where the raw dataset is downloaded and cached.
        split: Split name; only ``"train"`` exists (enforced by decorator).
        language_code: One of ``VALID_CODES``; defaults to ``"en"``.

    Returns:
        DataPipe yielding ``(language_code, line)`` tuples of strings.

    Raises:
        ValueError: If ``language_code`` is not a valid CC-100 code.
    """
    if language_code not in VALID_CODES:
        raise ValueError(f"Invalid language code {language_code}")

    url = URL % language_code
    url_dp = IterableWrapper([url])

    # Cache the downloaded .xz archive on disk; use the fn argument rather
    # than closing over `url` (equivalent here — only `url` flows through).
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=lambda x: os.path.join(root, os.path.basename(x))
    )
    cache_compressed_dp = HttpReader(cache_compressed_dp)
    cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    # BUG FIX: the original used `basename(x).rstrip(".xz")`, but str.rstrip
    # strips any run of the *characters* {'.', 'x', 'z'}, not the literal
    # suffix — it only worked by accident because the names end in ".txt.xz".
    # os.path.splitext removes exactly the final ".xz" extension.
    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
        filepath_fn=lambda x: os.path.join(root, os.path.splitext(os.path.basename(x))[0])
    )
    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_xz()
    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb")

    # Stream decompressed text line by line, tagging each line with its language.
    data_dp = FileOpener(cache_decompressed_dp, mode="r").readlines(return_path=False)
    return data_dp.map(lambda x: (language_code, x))
0 commit comments