diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 85c58ea67d..d091d7796b 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -33,6 +33,17 @@ def test_new_unk(self): self.assertEqual(v[''], 0) self.assertEqual(v['not_in_it'], 0) + def test_vocab_membership(self): + token_to_freq = {'': 2, 'a': 2, 'b': 2} + sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) + c = OrderedDict(sorted_by_freq_tuples) + v = vocab(c, min_freq=2) + + self.assertTrue('' in v) + self.assertTrue('a' in v) + self.assertTrue('b' in v) + self.assertFalse('c' in v) + def test_vocab_get_item(self): token_to_freq = {'': 2, 'a': 2, 'b': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index e41783cef1..0b325bcda6 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -107,6 +107,12 @@ PYBIND11_MODULE(_torchtext, m) { .def(py::init, std::string>()) .def_readonly("itos_", &Vocab::itos_) .def_readonly("unk_token_", &Vocab::unk_token_) + .def("__contains__", + [](c10::intrusive_ptr &self, const py::str &item) -> bool { + ssize_t length; + const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length); + return self->__contains__(c10::string_view{buffer, (size_t)length}); + }) .def("__getitem__", [](c10::intrusive_ptr &self, const py::str &item) -> int64_t { ssize_t length; @@ -229,6 +235,9 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) { m.class_("Vocab") .def(torch::init()) + .def("__contains__", + [](const c10::intrusive_ptr &self, const std::string &item) + -> bool { return self->__contains__(c10::string_view{item}); }) .def("__getitem__", [](const c10::intrusive_ptr &self, const std::string &item) -> int64_t { return self->__getitem__(c10::string_view{item}); }) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index b6fa2099d7..1831d46f39 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -30,6 +30,15 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token) int64_t Vocab::__len__() const { return itos_.size(); } +bool Vocab::__contains__(const c10::string_view &token) const { + int64_t id = _find(token); + if (stoi_[id] != -1) { + return true; + } + return false; +} + + int64_t Vocab::__getitem__(const c10::string_view &token) const { int64_t id = _find(token); if (stoi_[id] != -1) { diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index 660f6145d4..d915c7de27 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -21,6 +21,7 @@ struct Vocab : torch::CustomClassHolder { const std::string &unk_token); int64_t __len__() const; int64_t __getitem__(const c10::string_view &token) const; + bool __contains__(const c10::string_view &token) const; void append_token(const std::string &token); void insert_token(const std::string &token, const int64_t &index); std::string lookup_token(const int64_t &index); diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 7045ba65b0..26f393ce36 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -175,6 +175,17 @@ def __len__(self) -> int: """ return len(self.vocab) + @torch.jit.export + def __contains__(self, token: str) -> bool: + r""" + Args: + token (str): the token for which to check the membership + + Returns: + membership (bool): whether the token is member of vocab or not + """ + return self.vocab.__contains__(token) + @torch.jit.export def __getitem__(self, token: str) -> int: r""" diff --git a/version.txt b/version.txt index 7e4490fc70..37f1777fc3 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.10.0a0 \ No newline at end of file +0.10.0a0