diff --git a/torchtext/datasets/__init__.py b/torchtext/datasets/__init__.py
index 5fda4a8451..d7d33298ad 100644
--- a/torchtext/datasets/__init__.py
+++ b/torchtext/datasets/__init__.py
@@ -3,6 +3,7 @@
 from .ag_news import AG_NEWS
 from .amazonreviewfull import AmazonReviewFull
 from .amazonreviewpolarity import AmazonReviewPolarity
+from .cc100 import CC100
 from .conll2000chunking import CoNLL2000Chunking
 from .dbpedia import DBpedia
 from .enwik9 import EnWik9
@@ -26,6 +27,7 @@
     "AG_NEWS": AG_NEWS,
     "AmazonReviewFull": AmazonReviewFull,
     "AmazonReviewPolarity": AmazonReviewPolarity,
+    "CC100": CC100,
     "CoNLL2000Chunking": CoNLL2000Chunking,
     "DBpedia": DBpedia,
     "EnWik9": EnWik9,
diff --git a/torchtext/datasets/cc100.py b/torchtext/datasets/cc100.py
new file mode 100644
index 0000000000..0e8f55e589
--- /dev/null
+++ b/torchtext/datasets/cc100.py
@@ -0,0 +1,56 @@
+import os.path
+
+from typing import Union, Tuple
+
+from torchtext._internal.module_utils import is_module_available
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
+
+from torchtext.data.datasets_utils import (
+    _create_dataset_directory,
+    _wrap_split_argument
+)
+
+URL = "http://data.statmt.org/cc-100/%s.txt.xz"
+
+VALID_CODES = {
+    "am", "ar", "as", "az", "be", "bg", "bn", "bn_rom", "br", "bs", "ca", "cs", "cy", "da", "de",
+    "el", "en", "eo", "es", "et", "eu", "fa", "ff", "fi", "fr", "fy", "ga", "gd", "gl", "gn", "gu",
+    "ha", "he", "hi", "hi_rom", "hr", "ht", "hu", "hy", "id", "ig", "is", "it", "ja", "jv", "ka",
+    "kk", "km", "kn", "ko", "ku", "ky", "la", "lg", "li", "ln", "lo", "lt", "lv", "mg", "mk", "ml",
+    "mn", "mr", "ms", "my", "my_zaw", "ne", "nl", "no", "ns", "om", "or", "pa", "pl", "ps", "pt",
+    "qu", "rm", "ro", "ru", "sa", "si", "sc", "sd", "sk", "sl", "so", "sq", "sr", "ss", "su", "sv",
+    "sw", "ta", "ta_rom", "te", "te_rom", "th", "tl", "tn", "tr", "ug", "uk", "ur", "ur_rom", "uz",
+    "vi", "wo", "xh", "yi", "yo", "zh-Hans", "zh-Hant", "zu",
+}
+
+NUM_LINES = None
+MD5 = None
+
+DATASET_NAME = "CC100"
+
+
+@_create_dataset_directory(dataset_name=DATASET_NAME)
+@_wrap_split_argument(("train",))
+def CC100(root: str, split: Union[Tuple[str], str], language_code: str = "en"):
+    if language_code not in VALID_CODES:
+        raise ValueError(f"Invalid language code {language_code}")
+
+    url = URL % language_code
+    url_dp = IterableWrapper([url])
+    cache_compressed_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, os.path.basename(url))
+    )
+
+    cache_compressed_dp = HttpReader(cache_compressed_dp)
+    cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True)
+
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, os.path.basename(x).rstrip(".xz"))
+    )
+    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_xz()
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb")
+
+    data_dp = FileOpener(cache_decompressed_dp, mode="r").readlines(return_path=False)
+    return data_dp.map(lambda x: (language_code, x))
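
A minimal usage sketch of the new entry point, assuming torchdata is installed and data.statmt.org is reachable; the ".data" root path and the print loop are illustrative and not part of this change:

    from torchtext.datasets import CC100

    # Yields (language_code, line_of_text) tuples produced by the final
    # map; the compressed .xz archive is downloaded, cached, and
    # decompressed under `root` on first use.
    data_pipe = CC100(root=".data", split="train", language_code="en")

    for language, line in data_pipe:
        print(language, line)
        break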