From 4515a348a1978d45c925926f7713ef10f2c27a91 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Mon, 31 Jan 2022 15:51:28 -0500 Subject: [PATCH 1/5] WIP: add CC100 --- torchtext/datasets/__init__.py | 2 ++ torchtext/datasets/cc100.py | 57 ++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 torchtext/datasets/cc100.py diff --git a/torchtext/datasets/__init__.py b/torchtext/datasets/__init__.py index 5fda4a8451..d7d33298ad 100644 --- a/torchtext/datasets/__init__.py +++ b/torchtext/datasets/__init__.py @@ -3,6 +3,7 @@ from .ag_news import AG_NEWS from .amazonreviewfull import AmazonReviewFull from .amazonreviewpolarity import AmazonReviewPolarity +from .cc100 import CC100 from .conll2000chunking import CoNLL2000Chunking from .dbpedia import DBpedia from .enwik9 import EnWik9 @@ -26,6 +27,7 @@ "AG_NEWS": AG_NEWS, "AmazonReviewFull": AmazonReviewFull, "AmazonReviewPolarity": AmazonReviewPolarity, + "CC100": CC100, "CoNLL2000Chunking": CoNLL2000Chunking, "DBpedia": DBpedia, "EnWik9": EnWik9, diff --git a/torchtext/datasets/cc100.py b/torchtext/datasets/cc100.py new file mode 100644 index 0000000000..229aae6f97 --- /dev/null +++ b/torchtext/datasets/cc100.py @@ -0,0 +1,57 @@ +import os.path + +from torchtext._internal.module_utils import is_module_available + +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper + +from torchtext.data.datasets_utils import ( + _create_dataset_directory, +) + +URL = "http://data.statmt.org/cc-100/%s.txt.xz" + +VALID_CODES = { + "am", "ar", "as", "az", "be", "bg", "bn", "bn_rom", "br", "bs", "ca", "cs", "cy", "da", "de", + "el", "en", "eo", "es", "et", "eu", "fa", "ff", "fi", "fr", "fy", "ga", "gd", "gl", "gn", "gu", + "ha", "he", "hi", "hi_rom", "hr", "ht", "hu", "hy", "id", "ig", "is", "it", "ja", "jv", "ka", + "kk", "km", "kn", "ko", "ku", "ky", "la", "lg", "li", "ln", "lo", "lt", "lv", "mg", "mk", "ml", + "mn", "mr", "ms", "my", "my_zaw", "ne", "nl", "no", "ns", "om", "or", "pa", "pl", "ps", "pt", + "qu", "rm", "ro", "ru", "sa", "si", "sc", "sd", "sk", "sl", "so", "sq", "sr", "ss", "su", "sv", + "sw", "ta", "ta_rom", "te", "te_rom", "th", "tl", "tn", "tr", "ug", "uk", "ur", "ur_rom", "uz", + "vi", "wo", "xh", "yi", "yo", "zh-Hans", "zh-Hant", "zu", +} + +NUM_LINES = None +MD5 = None + +DATASET_NAME = "CC100" + + +@_create_dataset_directory(dataset_name=DATASET_NAME) +def CC100(root: str, language_code: str): + if language_code not in VALID_CODES: + raise ValueError(f"Invalid language code {language_code}") + + url = URL % language_code + url_dp = IterableWrapper([url]) + cache_compressed_dp = url_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, os.path.basename(url)) + ) + + cache_compressed_dp = HttpReader(cache_compressed_dp) + cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True) + + cache_compressed_dp = FileOpener(cache_compressed_dp, mode="b").map(lambda x: x[0]) + + cache_decompressed_dp = cache_compressed_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, os.path.basename(x).rstrip(".xz")) + ) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_xz() + cache_decompressed_dp = cache_decompressed_dp.filter(lambda x: os.path.basename(x).rstrip(".xz") in x[0]) + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + + data_dp = FileOpener(cache_decompressed_dp, mode="r") + + units_dp = data_dp.readlines().map(lambda x: (language_code, x[1])) + return units_dp From 1794b8282db24260a3254958352ee86b9c164bb1 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Mon, 31 Jan 2022 20:56:04 -0500 Subject: [PATCH 2/5] add split following EnWiki9's train-only split convention --- torchtext/datasets/cc100.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchtext/datasets/cc100.py b/torchtext/datasets/cc100.py index 229aae6f97..820ddaa477 100644 --- a/torchtext/datasets/cc100.py +++ b/torchtext/datasets/cc100.py @@ -1,5 +1,7 @@ import os.path +from typing import Union, Tuple + from torchtext._internal.module_utils import is_module_available if is_module_available("torchdata"): @@ -7,6 +9,7 @@ from torchtext.data.datasets_utils import ( _create_dataset_directory, + _wrap_split_argument ) URL = "http://data.statmt.org/cc-100/%s.txt.xz" @@ -29,7 +32,8 @@ @_create_dataset_directory(dataset_name=DATASET_NAME) -def CC100(root: str, language_code: str): +@_wrap_split_argument(("train",)) +def CC100(root: str, split: Union[Tuple[str], str], language_code: str): if language_code not in VALID_CODES: raise ValueError(f"Invalid language code {language_code}") From f641e6df1246b185ed46d3f202a6442138120bf5 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Tue, 1 Feb 2022 09:49:18 -0500 Subject: [PATCH 3/5] add default args to satisfy _wrap_split_argument_with_fn --- torchtext/datasets/cc100.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/datasets/cc100.py b/torchtext/datasets/cc100.py index 820ddaa477..7ca7d7b3fb 100644 --- a/torchtext/datasets/cc100.py +++ b/torchtext/datasets/cc100.py @@ -33,7 +33,7 @@ @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train",)) -def CC100(root: str, split: Union[Tuple[str], str], language_code: str): +def CC100(root: str, split: Union[Tuple[str], str] = ("train",), language_code: str = "en"): if language_code not in VALID_CODES: raise ValueError(f"Invalid language code {language_code}") From ee478dff2c994ea688f45f13582e02508af3f777 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Fri, 4 Feb 2022 08:16:52 -0500 Subject: [PATCH 4/5] incorporate feedback from review. --- torchtext/datasets/cc100.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/torchtext/datasets/cc100.py b/torchtext/datasets/cc100.py index 7ca7d7b3fb..8de03bdcd0 100644 --- a/torchtext/datasets/cc100.py +++ b/torchtext/datasets/cc100.py @@ -33,7 +33,7 @@ @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train",)) -def CC100(root: str, split: Union[Tuple[str], str] = ("train",), language_code: str = "en"): +def CC100(root: str = ".data", split: Union[Tuple[str], str] = ("train",), language_code: str = "en"): if language_code not in VALID_CODES: raise ValueError(f"Invalid language code {language_code}") @@ -46,16 +46,11 @@ def CC100(root: str, split: Union[Tuple[str], str] = ("train",), language_code: cache_compressed_dp = HttpReader(cache_compressed_dp) cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True) - cache_compressed_dp = FileOpener(cache_compressed_dp, mode="b").map(lambda x: x[0]) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( filepath_fn=lambda x: os.path.join(root, os.path.basename(x).rstrip(".xz")) ) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_xz() - cache_decompressed_dp = cache_decompressed_dp.filter(lambda x: os.path.basename(x).rstrip(".xz") in x[0]) - cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - - data_dp = FileOpener(cache_decompressed_dp, mode="r") + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb") - units_dp = data_dp.readlines().map(lambda x: (language_code, x[1])) - return units_dp + data_dp = FileOpener(cache_decompressed_dp, mode="r").readlines(return_path=False) + return data_dp.map(lambda x: (language_code, x)) From 0775e787353205fdaf2b26c840af1d898142cf47 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Fri, 4 Feb 2022 08:30:07 -0500 Subject: [PATCH 5/5] fix issue with consistent tuple return. --- torchtext/datasets/cc100.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/datasets/cc100.py b/torchtext/datasets/cc100.py index 8de03bdcd0..0e8f55e589 100644 --- a/torchtext/datasets/cc100.py +++ b/torchtext/datasets/cc100.py @@ -33,7 +33,7 @@ @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train",)) -def CC100(root: str = ".data", split: Union[Tuple[str], str] = ("train",), language_code: str = "en"): +def CC100(root: str, split: Union[Tuple[str], str], language_code: str = "en"): if language_code not in VALID_CODES: raise ValueError(f"Invalid language code {language_code}")