From fac375248e77109cff4be1cf4a610a6662193d37 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Thu, 29 Apr 2021 17:46:09 -0400 Subject: [PATCH 1/2] adding __contains__ method to vocab --- test/experimental/test_vocab.py | 11 +++++++++++ torchtext/csrc/register_bindings.cpp | 9 +++++++++ torchtext/csrc/vocab.cpp | 9 +++++++++ torchtext/csrc/vocab.h | 1 + torchtext/experimental/vocab.py | 11 +++++++++++ 5 files changed, 41 insertions(+) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 85c58ea67d..d091d7796b 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -33,6 +33,17 @@ def test_new_unk(self): self.assertEqual(v[''], 0) self.assertEqual(v['not_in_it'], 0) + def test_vocab_membership(self): + token_to_freq = {'': 2, 'a': 2, 'b': 2} + sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) + c = OrderedDict(sorted_by_freq_tuples) + v = vocab(c, min_freq=2) + + self.assertTrue('' in v) + self.assertTrue('a' in v) + self.assertTrue('b' in v) + self.assertFalse('c' in v) + def test_vocab_get_item(self): token_to_freq = {'': 2, 'a': 2, 'b': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index e41783cef1..0b325bcda6 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -107,6 +107,12 @@ PYBIND11_MODULE(_torchtext, m) { .def(py::init, std::string>()) .def_readonly("itos_", &Vocab::itos_) .def_readonly("unk_token_", &Vocab::unk_token_) + .def("__contains__", + [](c10::intrusive_ptr &self, const py::str &item) -> bool { + ssize_t length; + const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length); + return self->__contains__(c10::string_view{buffer, (size_t)length}); + }) .def("__getitem__", [](c10::intrusive_ptr &self, const py::str &item) -> int64_t { ssize_t length; @@ -229,6 +235,9 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) { m.class_("Vocab") .def(torch::init()) + .def("__contains__", + [](const c10::intrusive_ptr &self, const std::string &item) + -> bool { return self->__contains__(c10::string_view{item}); }) .def("__getitem__", [](const c10::intrusive_ptr &self, const std::string &item) -> int64_t { return self->__getitem__(c10::string_view{item}); }) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index b6fa2099d7..1831d46f39 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -30,6 +30,15 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token) int64_t Vocab::__len__() const { return itos_.size(); } +bool Vocab::__contains__(const c10::string_view &token) const { + int64_t id = _find(token); + if (stoi_[id] != -1) { + return true; + } + return false; +} + + int64_t Vocab::__getitem__(const c10::string_view &token) const { int64_t id = _find(token); if (stoi_[id] != -1) { diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index 660f6145d4..d915c7de27 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -21,6 +21,7 @@ struct Vocab : torch::CustomClassHolder { const std::string &unk_token); int64_t __len__() const; int64_t __getitem__(const c10::string_view &token) const; + bool __contains__(const c10::string_view &token) const; void append_token(const std::string &token); void insert_token(const std::string &token, const int64_t &index); std::string lookup_token(const int64_t &index); diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 7045ba65b0..26f393ce36 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -175,6 +175,17 @@ def __len__(self) -> int: """ return len(self.vocab) + @torch.jit.export + def __contains__(self, token: str) -> bool: + r""" + Args: + token (str): the token for which to check the membership + + Returns: + membership (bool): whether the token is member of vocab or not + """ + return self.vocab.__contains__(token) + @torch.jit.export def __getitem__(self, token: str) -> int: r""" From 6fbe99bfe28219243d2ba7e3719b39580e3e35ff Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Thu, 29 Apr 2021 18:25:42 -0400 Subject: [PATCH 2/2] adding new line to version .txt --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index 7e4490fc70..37f1777fc3 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.10.0a0 \ No newline at end of file +0.10.0a0