Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 53 additions & 26 deletions test/experimental/test_vocab.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# -*- coding: utf-8 -*-
from collections import OrderedDict
import os
import platform
import torch
import unittest
from test.common.torchtext_test_case import TorchtextTestCase
from torchtext.experimental.vocab import (
vocab,
Expand All @@ -20,18 +18,12 @@ def tearDown(self):
def test_has_unk(self):
    """A vocab built from an empty counter still resolves the default unk token."""
    empty_counter = OrderedDict()
    empty_vocab = vocab(empty_counter)

    # Any out-of-vocabulary lookup and the unk token itself both map to index 0.
    self.assertEqual(empty_vocab['not_in_it'], 0)
    self.assertEqual(empty_vocab['<unk>'], 0)

def test_new_unk(self):
    """A custom unk token supplied to vocab() occupies index 0."""
    empty_counter = OrderedDict()
    custom_vocab = vocab(empty_counter, unk_token="<new_unk>")

    # The replacement unk token sits at index 0, and OOV lookups fall back to it.
    self.assertEqual(custom_vocab['<new_unk>'], 0)
    self.assertEqual(custom_vocab['not_in_it'], 0)

def test_vocab_membership(self):
token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
Expand All @@ -54,6 +46,50 @@ def test_vocab_get_item(self):
self.assertEqual(v['a'], 1)
self.assertEqual(v['b'], 2)

def test_reassign_token(self):
    """reassign_token moves an existing token to a new index, shifting the others."""
    freqs = {'<unk>': 1, 'a': 2, 'b': 2}
    ordered = OrderedDict(sorted(freqs.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(ordered, min_freq=1)

    # Sorted by descending frequency, '<unk>' (freq 1) lands at the last index.
    for token, index in [('<unk>', 2), ('a', 0), ('b', 1)]:
        self.assertEqual(v[token], index)

    # After moving '<unk>' to the front, 'a' and 'b' each shift up by one.
    v.reassign_token('<unk>', 0)
    for token, index in [('<unk>', 0), ('a', 1), ('b', 2)]:
        self.assertEqual(v[token], index)
    self.assertEqual(v.get_itos(), ['<unk>', 'a', 'b'])

    # Unknown tokens and out-of-range target indices are both rejected.
    with self.assertRaises(RuntimeError):
        v.reassign_token('not in vocab', 0)
    with self.assertRaises(RuntimeError):
        v.reassign_token('<unk>', 3)

def test_default_index(self):
    """Lookups of missing tokens raise until a default index is set.

    Verifies get_default_index() starts as None, that OOV lookup raises
    RuntimeError before set_default_index(), and that OOV lookups fall back
    to the configured index afterwards.
    """
    token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)
    v = vocab(c, min_freq=2)

    # assertIsNone is the idiomatic unittest check (clearer failure message
    # than assertTrue(x is None)).
    self.assertIsNone(v.get_default_index())
    with self.assertRaises(RuntimeError):
        v['not in vocab']

    # Once set, OOV lookups resolve to the default index.
    v.set_default_index(0)
    self.assertEqual(v['not in vocab'], 0)

def test_default_index_jit(self):
    """The default index survives torch.jit.script-ing the vocab."""
    freqs = {'<unk>': 2, 'a': 2, 'b': 2}
    ordered = OrderedDict(sorted(freqs.items(), key=lambda kv: kv[1], reverse=True))
    eager_vocab = vocab(ordered, min_freq=2)
    eager_vocab.set_default_index(0)

    # Scripting must preserve the OOV fallback behavior.
    scripted_vocab = torch.jit.script(eager_vocab)
    self.assertEqual(scripted_vocab['not in vocab'], 0)

def test_vocab_insert_token(self):
c = OrderedDict({'<unk>': 2, 'a': 2})

Expand Down Expand Up @@ -88,6 +124,10 @@ def test_vocab_append_token(self):
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

# token must not exist to be appended
with self.assertRaises(RuntimeError):
v.append_token('b')

def test_vocab_len(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
Expand Down Expand Up @@ -149,6 +189,8 @@ def test_vocab_lookup_token(self):
v = vocab(c)

self.assertEqual(v.lookup_token(1), 'a')
with self.assertRaises(RuntimeError):
v.lookup_token(100)

def test_vocab_lookup_tokens(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
Expand All @@ -172,24 +214,6 @@ def test_vocab_lookup_indices(self):

self.assertEqual(v.lookup_indices(tokens), expected_indices)

# we separate out these errors because Windows runs into seg faults when propagating
# exceptions from C++ using pybind11
@unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
def test_errors_vocab_cpp(self):
    """Out-of-bounds insert/lookup indices surface as RuntimeError from C++."""
    freqs = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
    ordered = OrderedDict(sorted(freqs.items(), key=lambda kv: kv[1], reverse=True))

    # Inserting a token far past the end of the vocab must raise.
    with self.assertRaises(RuntimeError):
        oob_vocab = vocab(ordered, min_freq=3)
        oob_vocab.insert_token('new_token', 100)

    # Looking up an index past the end of the vocab must raise.
    with self.assertRaises(RuntimeError):
        full_vocab = vocab(ordered)
        full_vocab.lookup_token(100)

def test_errors_vocab_python(self):
token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
Expand All @@ -205,6 +229,7 @@ def test_vocab_load_and_save(self):

c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c, min_freq=3)
v.set_default_index(0)

expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
Expand All @@ -218,6 +243,7 @@ def test_vocab_load_and_save(self):
loaded_v = torch.load(vocab_path)
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
self.assertEqual(v['not in vocab'], 0)

with self.subTest('torchscript'):
vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
Expand All @@ -227,6 +253,7 @@ def test_vocab_load_and_save(self):
loaded_v = torch.load(vocab_path)
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
self.assertEqual(v['not in vocab'], 0)

def test_build_vocab_iterator(self):
iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T',
Expand Down
33 changes: 19 additions & 14 deletions torchtext/csrc/register_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@ namespace py = pybind11;

namespace {
Vocab build_vocab_from_text_file(const std::string &file_path,
const std::string &unk_token,
const int64_t min_freq, const int64_t num_cpus,
py::object fn) {
torch::jit::script::Module module(*torch::jit::as_module(fn));
return _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus,
module);
return _build_vocab_from_text_file(file_path, min_freq, num_cpus, module);
}
} // namespace

Expand Down Expand Up @@ -104,23 +102,27 @@ PYBIND11_MODULE(_torchtext, m) {
}));

py::class_<Vocab, c10::intrusive_ptr<Vocab>>(m, "Vocab")
.def(py::init<std::vector<std::string>, std::string>())
.def(py::init<StringList, c10::optional<int64_t>>())
.def_readonly("itos_", &Vocab::itos_)
.def_readonly("unk_token_", &Vocab::unk_token_)
.def("__contains__",
[](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> bool {
ssize_t length;
const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
return self->__contains__(c10::string_view{buffer, (size_t)length});
})
.def_readonly("default_index_", &Vocab::default_index_)
.def(
"__contains__",
[](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> bool {
ssize_t length;
const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
return self->__contains__(c10::string_view{buffer, (size_t)length});
})
.def("__getitem__",
[](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> int64_t {
ssize_t length;
const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
return self->__getitem__(c10::string_view{buffer, (size_t)length});
})
.def("__len__", &Vocab::__len__)
.def("reassign_token", &Vocab::reassign_token)
.def("insert_token", &Vocab::insert_token)
.def("set_default_index", &Vocab::set_default_index)
.def("get_default_index", &Vocab::get_default_index)
.def("__len__", &Vocab::__len__)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
Expand Down Expand Up @@ -234,15 +236,18 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
});

m.class_<Vocab>("Vocab")
.def(torch::init<StringList, std::string>())
.def(torch::init<StringList, c10::optional<int64_t>>())
.def("__contains__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> bool { return self->__contains__(c10::string_view{item}); })
.def("__getitem__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> int64_t { return self->__getitem__(c10::string_view{item}); })
.def("__len__", &Vocab::__len__)
.def("reassign_token", &Vocab::reassign_token)
.def("insert_token", &Vocab::insert_token)
.def("__len__", &Vocab::__len__)
.def("set_default_index", &Vocab::set_default_index)
.def("get_default_index", &Vocab::get_default_index)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
Expand Down
Loading