From 3340265471c1b0015fd061cc7b86d8e815b7f13e Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Fri, 7 May 2021 17:38:56 -0400 Subject: [PATCH] unit testing for experimental vocab --- test/experimental/test_vocab.py | 9 +++++++++ torchtext/csrc/vocab.cpp | 18 ++++++++++++++++-- torchtext/csrc/vocab.h | 2 +- torchtext/experimental/vocab.py | 3 +++ 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index d091d7796b..83b19a8fa3 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -76,6 +76,10 @@ def test_vocab_insert_token(self): self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) + with self.assertRaises(RuntimeError) as context: + v.insert_token('b', 0) + + self.assertTrue("Token b already exists in the Vocab with index: 0" in str(context.exception)) def test_vocab_append_token(self): c = OrderedDict({'a': 2}) @@ -88,6 +92,11 @@ def test_vocab_append_token(self): self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) + with self.assertRaises(RuntimeError) as context: + v.append_token('b') + + self.assertTrue("Token b already exists in the Vocab with index: 2" in str(context.exception)) + def test_vocab_len(self): token_to_freq = {'a': 2, 'b': 2, 'c': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 1831d46f39..fec59b094d 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -38,7 +38,6 @@ bool Vocab::__contains__(const c10::string_view &token) const { return false; } - int64_t Vocab::__getitem__(const c10::string_view &token) const { int64_t id = _find(token); if (stoi_[id] != -1) { @@ -47,7 +46,22 @@ int64_t Vocab::__getitem__(const c10::string_view &token) const { return unk_index_; } -void Vocab::append_token(const std::string &token) { _add(token); } +void Vocab::append_token(const std::string &token) { + // if item already in stoi we throw an error + auto token_position = _find(c10::string_view{token.data(), token.size()}); + if (stoi_[token_position] != -1) { +#ifdef _MSC_VER + std::cerr << "[RuntimeError] Token " << token + << " already exists in the Vocab with index: " + << stoi_[token_position] << std::endl; +#endif + throw std::runtime_error("Token " + token + + " already exists in the Vocab with index: " + + std::to_string(stoi_[token_position]) + "."); + } + + _add(token); +} void Vocab::insert_token(const std::string &token, const int64_t &index) { if (index < 0 || index > itos_.size()) { diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index d915c7de27..40fc38e0e2 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -44,7 +44,7 @@ struct Vocab : torch::CustomClassHolder { uint32_t _find(const c10::string_view &w) const { uint32_t stoi_size = stoi_.size(); uint32_t id = _hash(w) % stoi_size; - while (stoi_[id] != -1 && itos_[stoi_[id]]!= w) { + while (stoi_[id] != -1 && itos_[stoi_[id]] != w) { id = (id + 1) % stoi_size; } return id; diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 26f393ce36..acc6d96a4d 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -214,6 +214,9 @@ def append_token(self, token: str) -> None: r""" Args: token (str): the token used to lookup the corresponding index. + + Raises: + RuntimeError: if token already exists in the vocab """ self.vocab.append_token(token)