From ced0a2b0b456474b7494a000ea2d17b333ce9afe Mon Sep 17 00:00:00 2001 From: Nayef Ahmed Date: Tue, 27 Sep 2022 12:58:59 -0400 Subject: [PATCH] Resolve inconsistency in IMDB label output --- test/torchtext_unittest/datasets/test_imdb.py | 4 ++-- torchtext/datasets/imdb.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/torchtext_unittest/datasets/test_imdb.py b/test/torchtext_unittest/datasets/test_imdb.py index cb9ab3b62d..c71f53560e 100644 --- a/test/torchtext_unittest/datasets/test_imdb.py +++ b/test/torchtext_unittest/datasets/test_imdb.py @@ -29,8 +29,8 @@ def _get_mock_dataset(root_dir): for i in range(5): # all negative labels are read first before positive labels in the # IMDB dataset implementation - label = "neg" if i < 2 else "pos" - cur_dir = pos_dir if label == "pos" else neg_dir + label = 1 if i < 2 else 2 + cur_dir = pos_dir if label == 2 else neg_dir txt_file = os.path.join(cur_dir, f"{i}{i}_{i}.txt") with open(txt_file, "w", encoding="utf-8") as f: rand_string = get_random_unicode(seed) diff --git a/torchtext/datasets/imdb.py b/torchtext/datasets/imdb.py index 1e11ad95ab..debb5c06f3 100644 --- a/torchtext/datasets/imdb.py +++ b/torchtext/datasets/imdb.py @@ -20,6 +20,8 @@ "test": 25000, } +MAP_LABELS = {"neg": 1, "pos": 2} + _PATH = "aclImdb_v1.tar.gz" DATASET_NAME = "IMDB" @@ -50,7 +52,7 @@ def _cache_filepath_fn(root, decompressed_folder, split, x): def _modify_res(t): - return Path(t[0]).parts[-1], t[1] + return MAP_LABELS[Path(t[0]).parts[-1]], t[1] def filter_imdb_data(key, fname):