Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions test/experimental/test_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ def test_new_unk(self):
self.assertEqual(v['<new_unk>'], 0)
self.assertEqual(v['not_in_it'], 0)

def test_vocab_membership(self):
    """Check that ``token in vocab`` reports membership correctly.

    Builds a vocab from three tokens whose frequency equals ``min_freq``
    and verifies that each of them is reported as a member, while a token
    that was never supplied is not.
    """
    token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)
    v = vocab(c, min_freq=2)

    # assertIn/assertNotIn exercise the same ``in`` operator (and therefore
    # ``__contains__``) as ``assertTrue(x in v)``, but report the offending
    # token and container on failure instead of "False is not true".
    for token in ('<unk>', 'a', 'b'):
        self.assertIn(token, v)
    self.assertNotIn('c', v)

def test_vocab_get_item(self):
token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
Expand Down
9 changes: 9 additions & 0 deletions torchtext/csrc/register_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ PYBIND11_MODULE(_torchtext, m) {
.def(py::init<std::vector<std::string>, std::string>())
.def_readonly("itos_", &Vocab::itos_)
.def_readonly("unk_token_", &Vocab::unk_token_)
.def("__contains__",
     // Python-facing membership test: receives the token as a Python str,
     // views its UTF-8 bytes without copying, and forwards the view to
     // Vocab::__contains__.
     [](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> bool {
       // PyUnicode_AsUTF8AndSize writes through a Py_ssize_t*, so use that
       // exact type rather than relying on ssize_t being an alias for it.
       Py_ssize_t length = 0;
       const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
       if (buffer == nullptr) {
         // The conversion failed and a Python exception is already set
         // (for example the string cannot be encoded as UTF-8). Propagate
         // it instead of building a string_view from a null pointer and
         // an unset length.
         throw py::error_already_set();
       }
       return self->__contains__(
           c10::string_view{buffer, static_cast<size_t>(length)});
     })
.def("__getitem__",
[](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> int64_t {
ssize_t length;
Expand Down Expand Up @@ -229,6 +235,9 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {

m.class_<Vocab>("Vocab")
.def(torch::init<StringList, std::string>())
.def("__contains__",
     // TorchScript-facing membership test: wraps the incoming std::string
     // in a non-owning view and forwards it to Vocab::__contains__.
     [](const c10::intrusive_ptr<Vocab> &self,
        const std::string &item) -> bool {
       const c10::string_view token_view{item};
       return self->__contains__(token_view);
     })
.def("__getitem__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> int64_t { return self->__getitem__(c10::string_view{item}); })
Expand Down
9 changes: 9 additions & 0 deletions torchtext/csrc/vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)

int64_t Vocab::__len__() const { return itos_.size(); }

// Reports whether `token` is present in the vocab.
//
// Looks up the position that `_find` yields for the token and treats a
// stored value of -1 in `stoi_` as "no entry"; any other value means the
// token is a member.
bool Vocab::__contains__(const c10::string_view &token) const {
  const int64_t slot = _find(token);
  return stoi_[slot] != -1;
}


int64_t Vocab::__getitem__(const c10::string_view &token) const {
int64_t id = _find(token);
if (stoi_[id] != -1) {
Expand Down
1 change: 1 addition & 0 deletions torchtext/csrc/vocab.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct Vocab : torch::CustomClassHolder {
const std::string &unk_token);
int64_t __len__() const;
int64_t __getitem__(const c10::string_view &token) const;
// Returns true when `token` is present in the vocab, false otherwise.
// Exposed to Python/TorchScript under the same `__contains__` name.
bool __contains__(const c10::string_view &token) const;
void append_token(const std::string &token);
void insert_token(const std::string &token, const int64_t &index);
std::string lookup_token(const int64_t &index);
Expand Down
11 changes: 11 additions & 0 deletions torchtext/experimental/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,17 @@ def __len__(self) -> int:
"""
return len(self.vocab)

@torch.jit.export
def __contains__(self, token: str) -> bool:
    r"""Check whether a token is present in the vocab.

    Args:
        token (str): the token whose membership is being checked

    Returns:
        bool: ``True`` if the token is a member of the vocab, ``False`` otherwise
    """
    # Delegate to the underlying vocab object's own membership check.
    is_member = self.vocab.__contains__(token)
    return is_member

@torch.jit.export
def __getitem__(self, token: str) -> int:
r"""
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.10.0a0
0.10.0a0