Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
e46e2f8
checkpoint
Oct 9, 2020
3129036
checkpoint
Oct 9, 2020
441f392
checkpoint
Oct 9, 2020
bc87fb0
checkpoint
Oct 9, 2020
8ba6b21
checkpoint
Oct 9, 2020
6b9f015
update tests
Oct 9, 2020
3d21a18
clang
Oct 9, 2020
0efe7b2
flake8
Oct 9, 2020
dce7080
checkpoint
Oct 9, 2020
660e051
checkpoint
Oct 9, 2020
6b9368b
CI
Oct 9, 2020
ffeb7ab
checkpoint
Oct 9, 2020
ca1dbbb
checkpoint
Oct 9, 2020
40d6c06
checkpoint
Oct 9, 2020
4246ba2
checkpoint
Oct 9, 2020
6f6cfad
update save/load in vocab
Oct 9, 2020
a0d5fc2
checkpoint
Oct 9, 2020
a49c10d
checkpoint
Oct 9, 2020
f76d9b1
checkpoint
Oct 9, 2020
7bec937
skip test for windows
Oct 9, 2020
ba7e561
update unk_index with insert_token
Oct 12, 2020
ed7be7d
checkpoint
Oct 12, 2020
e25660d
Merge branch 'master' into remove_unk
Oct 12, 2020
6627847
Merge branch 'master' into remove_unk
Oct 12, 2020
e4e1e05
change unk_index to fallback_index
Oct 12, 2020
279ba95
checkpoint
Oct 12, 2020
1cdb82e
checkpoint
Oct 12, 2020
6f339c0
Merge branch 'master' into remove_unk
Oct 13, 2020
ccd3166
switch to default
Oct 18, 2020
98c508b
Merge remote-tracking branch 'upstream/master' into remove_unk
Oct 19, 2020
c61c127
checkpoint
Oct 19, 2020
3549c23
Merge remote-tracking branch 'upstream/master' into remove_unk
Oct 19, 2020
6b7e5ad
update test
Oct 20, 2020
6858227
add one more test for inserting existing token
Oct 20, 2020
b40d8dd
use c10::optional for default index
Oct 21, 2020
4ec66e6
checkpoint
Oct 22, 2020
1556074
sync with master branch
Oct 27, 2020
c5e3773
Update docs
Nov 2, 2020
f3ed767
checkpoint
Nov 2, 2020
a9b27de
set_default_index if the saved vocab has a default index
Nov 3, 2020
53a353f
checkpoint
Nov 3, 2020
7631d72
Merge branch 'master' into remove_unk
Nov 7, 2020
2cab04b
sync with master
Nov 7, 2020
a1cfea9
sync with master
Dec 23, 2020
aeb9995
checkpoint
Dec 23, 2020
588cce4
checkpoint
Dec 23, 2020
67ec466
checkpoint
Dec 23, 2020
4567d7c
add setitem func for vocab
Dec 23, 2020
15f29be
checkpoint
Dec 23, 2020
33f31d9
add delete_token func
Dec 28, 2020
6beae11
implement setitem func and add a test
Dec 28, 2020
efc820b
Merge branch 'master' into vocab_setitem
Jan 5, 2021
82fce2b
add __delete__ func to Vocab and remind users not to reassign an exis…
Jan 6, 2021
0c2dab7
checkpoint
zhangguanheng66 Jan 8, 2021
2f9ac0e
checkpoint
zhangguanheng66 Jan 8, 2021
fe5ce92
sync with master branch
Feb 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 83 additions & 22 deletions test/experimental/test_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,26 @@ def tearDown(self):
torch._C._jit_clear_class_registry()
torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()

def test_has_unk(self):
# we separate out these errors because Windows runs into seg faults when propagating
# exceptions from C++ using pybind11
@unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
def test_has_no_unk(self):
c = OrderedDict()
v = vocab(c)
with self.assertRaisesRegex(RuntimeError, 'bad optional access'):
v.get_default_index()

# check if unk is mapped to the first index
self.assertEqual(v['not_in_it'], 0)
self.assertEqual(v['<unk>'], 0)

def test_new_unk(self):
c = OrderedDict()
v = vocab(c, unk_token="<new_unk>")
with self.assertRaises(RuntimeError):
v['not_in_it']
with self.assertRaises(RuntimeError):
v['<unk>']

# check if new_unk is mapped to the first index
self.assertEqual(v['<new_unk>'], 0)
v.insert_token('not_in_it', 0)
v.set_default_index(0)
self.assertEqual(v.get_default_index(), 0)
self.assertEqual(v['not_in_it'], 0)
self.assertEqual(v['<unk>'], 0)

def test_vocab_get_item(self):
token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
Expand All @@ -43,35 +48,81 @@ def test_vocab_get_item(self):
self.assertEqual(v['a'], 1)
self.assertEqual(v['b'], 2)

# we separate out these errors because Windows runs into seg faults when propagating
# exceptions from C++ using pybind11
@unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
def test_vocab_set_item(self):
token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c, min_freq=2)

v.set_default_index(0)
with self.assertRaises(RuntimeError):
v['b'] = 1
del v['b']
self.assertEqual(v['<unk>'], 0)
self.assertEqual(v['a'], 1)
self.assertEqual(v['not_in_it'], 0)
self.assertEqual(v['b'], 0)

v['b'] = 1
self.assertEqual(v['<unk>'], 0)
self.assertEqual(v['b'], 1)
self.assertEqual(v['not_in_it'], 0)
self.assertEqual(v['a'], 0)

def test_vocab_insert_token(self):
c = OrderedDict({'<unk>': 2, 'a': 2})

# add item to end
v = vocab(c)
v.set_default_index(0)
v.insert_token('b', 2)

expected_itos = ['<unk>', 'a', 'b']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

self.assertEqual(v.get_default_index(), 0)
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

# add item to middle
v = vocab(c)
v.set_default_index(0)
v.insert_token('b', 0)

expected_itos = ['b', '<unk>', 'a']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

self.assertEqual(v.get_default_index(), 1)
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

# we separate out these errors because Windows runs into seg faults when propagating
# exceptions from C++ using pybind11
@unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
def test_insert_existing_token(self):
c = OrderedDict({'a': 2, 'b': 2, 'c': 2})

# add item to end
v = vocab(c)
v.insert_token('<unk>', 2)
v.set_default_index(2)

with self.assertRaises(RuntimeError):
# Test proper error raised when inserting a token that already exists
v.insert_token('<unk>', 1)

v.insert_token('d', 1)
self.assertEqual(v['not_in_it'], 3)

def test_vocab_append_token(self):
c = OrderedDict({'a': 2})
v = vocab(c)
v.append_token('b')

expected_itos = ['<unk>', 'a', 'b']
expected_itos = ['a', 'b']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

self.assertEqual(v.get_itos(), expected_itos)
Expand All @@ -83,7 +134,7 @@ def test_vocab_len(self):
c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c)

self.assertEqual(len(v), 4)
self.assertEqual(len(v), 3)

def test_vocab_basic(self):
token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
Expand All @@ -92,12 +143,15 @@ def test_vocab_basic(self):
c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c, min_freq=3)

expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

# we separate out these errors because Windows runs into seg faults when propagating
# exceptions from C++ using pybind11
@unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
def test_vocab_jit(self):
token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
Expand All @@ -106,7 +160,7 @@ def test_vocab_jit(self):
v = vocab(c, min_freq=3)
jit_v = torch.jit.script(v)

expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

assert not v.is_jitable
Expand All @@ -117,6 +171,9 @@ def test_vocab_jit(self):
self.assertEqual(jit_v.get_itos(), expected_itos)
self.assertEqual(dict(jit_v.get_stoi()), expected_stoi)

# we separate out these errors because Windows runs into seg faults when propagating
# exceptions from C++ using pybind11
@unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
def test_vocab_forward(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
Expand All @@ -126,7 +183,7 @@ def test_vocab_forward(self):
jit_v = torch.jit.script(v)

tokens = ['b', 'a', 'c']
expected_indices = [2, 1, 3]
expected_indices = [1, 0, 2]

self.assertEqual(v(tokens), expected_indices)
self.assertEqual(jit_v(tokens), expected_indices)
Expand All @@ -137,15 +194,15 @@ def test_vocab_lookup_token(self):
c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c)

self.assertEqual(v.lookup_token(1), 'a')
self.assertEqual(v.lookup_token(0), 'a')

def test_vocab_lookup_tokens(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c)

indices = [2, 1, 3]
indices = [1, 0, 2]
expected_tokens = ['b', 'a', 'c']

self.assertEqual(v.lookup_tokens(indices), expected_tokens)
Expand All @@ -157,7 +214,7 @@ def test_vocab_lookup_indices(self):
v = vocab(c)

tokens = ['b', 'a', 'c']
expected_indices = [2, 1, 3]
expected_indices = [1, 0, 2]

self.assertEqual(v.lookup_indices(tokens), expected_indices)

Expand All @@ -179,23 +236,27 @@ def test_errors_vocab_cpp(self):
v = vocab(c)
v.lookup_token(100)

# we separate out these errors because Windows runs into seg faults when propagating
# exceptions from C++ using pybind11
@unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
def test_errors_vocab_python(self):
token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c)

with self.assertRaises(ValueError):
with self.assertRaises(RuntimeError):
# Test proper error raised when looking up an unknown token with no default index set
vocab(c, unk_token=None)
v(['not_in_vocab'])

def test_vocab_load_and_save(self):
token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)

c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c, min_freq=3)

expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
v.set_default_index(1)
expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

self.assertEqual(v.get_itos(), expected_itos)
Expand All @@ -221,7 +282,7 @@ def test_build_vocab_iterator(self):
iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'freq_low', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T']]
v = build_vocab_from_iterator(iterator)
expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', 'freq_low']
expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', 'freq_low']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)
14 changes: 9 additions & 5 deletions test/experimental/test_with_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
load_vocab_from_file,
build_vocab_from_text_file,
)
import unittest
import platform
import shutil
import tempfile
import os
import unittest
import platform
from torchtext.experimental.vectors import (
GloVe,
build_vectors,
Expand Down Expand Up @@ -75,6 +75,9 @@ def test_wikitext103(self):


class TestTransformsWithAsset(TorchtextTestCase):
# we separate out these errors because Windows runs into seg faults when propagating
# exceptions from C++ using pybind11
@unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
def test_vocab_transform(self):
asset_name = 'vocab_test2.txt'
asset_path = get_asset_path(asset_name)
Expand Down Expand Up @@ -180,7 +183,8 @@ def test_vocab_from_file(self):
asset_name = 'vocab_test.txt'
asset_path = get_asset_path(asset_name)
with open(asset_path, 'r') as f:
v = load_vocab_from_file(f, unk_token='<new_unk>')
v = load_vocab_from_file(f)
v.insert_token('<new_unk>', 0)
expected_itos = ['<new_unk>', 'b', 'a', 'c']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
self.assertEqual(v.get_itos(), expected_itos)
Expand All @@ -192,8 +196,8 @@ def test_vocab_from_raw_text_file(self):
with open(asset_path, 'r') as f:
tokenizer = basic_english_normalize()
jit_tokenizer = torch.jit.script(tokenizer)
v = build_vocab_from_text_file(f, jit_tokenizer, unk_token='<new_unk>')
expected_itos = ['<new_unk>', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
v = build_vocab_from_text_file(f, jit_tokenizer)
expected_itos = ["'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall', 'parent',
'pension', 'representing', 'say', 'stricken', 't', 'they', 'turner',
'unions', 'with', 'workers']
Expand Down
16 changes: 11 additions & 5 deletions torchtext/csrc/register_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ namespace py = pybind11;

namespace {
Vocab build_vocab_from_text_file(const std::string &file_path,
const std::string &unk_token,
const int64_t min_freq,
const int64_t num_cpus,
py::object fn) {
torch::jit::script::Module module(*torch::jit::as_module(fn));
return _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus, module);
return _build_vocab_from_text_file(file_path, min_freq, num_cpus, module);
}
} // namespace

Expand Down Expand Up @@ -100,12 +99,15 @@ PYBIND11_MODULE(_torchtext, m) {
}));

py::class_<Vocab, c10::intrusive_ptr<Vocab>>(m, "Vocab")
.def(py::init<std::vector<std::string>, std::string>())
.def(py::init<std::vector<std::string>>())
.def_readonly("itos_", &Vocab::itos_)
.def_readonly("unk_token_", &Vocab::unk_token_)
.def("__getitem__", &Vocab::__getitem__)
.def("__setitem__", &Vocab::__setitem__)
.def("__delitem__", &Vocab::__delitem__)
.def("__len__", &Vocab::__len__)
.def("insert_token", &Vocab::insert_token)
.def("set_default_index", &Vocab::set_default_index)
.def("get_default_index", &Vocab::get_default_index)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
Expand Down Expand Up @@ -202,10 +204,14 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
});

m.class_<Vocab>("Vocab")
.def(torch::init<StringList, std::string>())
.def(torch::init<StringList>())
.def("__getitem__", &Vocab::__getitem__)
.def("__setitem__", &Vocab::__setitem__)
.def("__delitem__", &Vocab::__delitem__)
.def("__len__", &Vocab::__len__)
.def("insert_token", &Vocab::insert_token)
.def("set_default_index", &Vocab::set_default_index)
.def("get_default_index", &Vocab::get_default_index)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
Expand Down
Loading