diff --git a/benchmark/benchmark_experimental_vocab.py b/benchmark/benchmark_experimental_vocab.py
index e63904934d..3cd4e49a1d 100644
--- a/benchmark/benchmark_experimental_vocab.py
+++ b/benchmark/benchmark_experimental_vocab.py
@@ -3,7 +3,7 @@
 import time
 
 import torch
-from torchtext.experimental.datasets import AG_NEWS
+from torchtext.experimental.datasets import DATASETS
 from torchtext.experimental.vocab import (
     vocab as VocabExperimental,
     load_vocab_from_file,
@@ -76,7 +76,7 @@ def benchmark_experimental_vocab_construction(vocab_file_path,
     print("Construction time:", time.monotonic() - t0)
 
 
-def benchmark_experimental_vocab_lookup(vocab_file_path=None):
+def benchmark_experimental_vocab_lookup(vocab_file_path=None, dataset='AG_NEWS'):
     def _run_benchmark_lookup(tokens, vocab):
         t0 = time.monotonic()
         # list lookup
@@ -94,7 +94,7 @@ def _run_benchmark_lookup(tokens, vocab):
 
     tokens = []
     tokens_lists = []
-    train = AG_NEWS(split='train')
+    train = DATASETS[dataset](split='train')
     vocab = train.get_vocab()
     for (_, text) in train:
         cur_tokens = []
@@ -124,7 +124,7 @@ def token_iterator(file_path):
             v_experimental = load_vocab_from_file(f)
         print("Construction time:", time.monotonic() - t0)
     else:
-        print("Loading Vocab from AG News")
+        print("Loading Vocab from {}".format(dataset))
         counter = Counter(tokens)
         sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
         ordered_dict = OrderedDict(sorted_by_freq_tuples)
@@ -174,6 +174,8 @@ def token_iterator(file_path):
                         help='The name of vocab file used for construction')
     parser.add_argument('--vocab-filename-lookup', type=str, default=None,
                         help='The name of vocab file used for lookup')
+    parser.add_argument('--dataset', type=str, default='AG_NEWS',
+                        help='Dataset to use for the lookup benchmark')
     args = parser.parse_args()
 
     if args.run_construction_benchmark:
@@ -181,4 +183,4 @@ def token_iterator(file_path):
         benchmark_experimental_vocab_construction(args.vocab_filename_construction,
                                                   is_raw_text=args.is_raw_text, is_legacy=args.is_legacy)
     else:
-        benchmark_experimental_vocab_lookup(args.vocab_filename_lookup)
+        benchmark_experimental_vocab_lookup(args.vocab_filename_lookup, args.dataset)
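Note: routing the benchmark through `DATASETS` removes the hard-wired AG_NEWS dependency; any experimental dataset name can now be passed on the command line, e.g. `python benchmark/benchmark_experimental_vocab.py --dataset YelpReviewFull` (illustrative invocation; `AG_NEWS` remains the default).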
.def_readonly("to_lower_", &RegexTokenizer::to_lower_) @@ -48,15 +49,18 @@ PYBIND11_MODULE(_torchtext, m) { .def("forward", &RegexTokenizer::forward) .def(py::pickle( // __getstate__ - [](const c10::intrusive_ptr &self) -> RegexTokenizerStates { + [](const c10::intrusive_ptr &self) + -> RegexTokenizerStates { return _serialize_regex_tokenizer(self); }, // __setstate__ - [](RegexTokenizerStates states) -> c10::intrusive_ptr { + [](RegexTokenizerStates states) + -> c10::intrusive_ptr { return _deserialize_regex_tokenizer(std::move(states)); })); - py::class_>(m, "SentencePiece") + py::class_>(m, + "SentencePiece") .def(py::init()) .def("_return_content", [](const SentencePiece &self) { return py::bytes(self.content_); }) @@ -70,14 +74,14 @@ PYBIND11_MODULE(_torchtext, m) { .def("PieceToId", &SentencePiece::PieceToId) .def("IdToPiece", &SentencePiece::IdToPiece) .def(py::pickle( - // __getstate__ - [](const c10::intrusive_ptr &self) -> py::bytes{ - return py::bytes(self->content_); - }, - // __setstate__ - [](py::bytes state) -> c10::intrusive_ptr { - return c10::make_intrusive(std::string(state)); - })); + // __getstate__ + [](const c10::intrusive_ptr &self) -> py::bytes { + return py::bytes(self->content_); + }, + // __setstate__ + [](py::bytes state) -> c10::intrusive_ptr { + return c10::make_intrusive(std::string(state)); + })); py::class_>(m, "Vectors") .def(py::init, std::vector, @@ -103,13 +107,30 @@ PYBIND11_MODULE(_torchtext, m) { .def(py::init, std::string>()) .def_readonly("itos_", &Vocab::itos_) .def_readonly("unk_token_", &Vocab::unk_token_) - .def("__getitem__", &Vocab::__getitem__) + .def("__getitem__", + [](c10::intrusive_ptr &self, const py::str &item) -> int64_t { + ssize_t length; + const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length); + return self->__getitem__(c10::string_view{buffer, (size_t)length}); + }) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) .def("append_token", &Vocab::append_token) .def("lookup_token", &Vocab::lookup_token) .def("lookup_tokens", &Vocab::lookup_tokens) - .def("lookup_indices", &Vocab::lookup_indices) + .def("lookup_indices", + [](const c10::intrusive_ptr &self, const py::list &items) { + std::vector indices(items.size()); + int64_t counter = 0; + for (const auto &item : items) { + ssize_t length; + const char *buffer = + PyUnicode_AsUTF8AndSize(item.ptr(), &length); + indices[counter++] = + self->__getitem__(c10::string_view{buffer, (size_t)length}); + } + return indices; + }) .def("get_stoi", &Vocab::get_stoi) .def("get_itos", &Vocab::get_itos) .def(py::pickle( @@ -131,96 +152,112 @@ PYBIND11_MODULE(_torchtext, m) { TORCH_LIBRARY_FRAGMENT(torchtext, m) { m.class_("Regex") - .def(torch::init()) - .def("Sub", &Regex::Sub) - .def_pickle( - // __getstate__ - [](const c10::intrusive_ptr &self) -> std::string { - return _serialize_regex(self); - }, - // __setstate__ - [](std::string state) -> c10::intrusive_ptr { - return _deserialize_regex(std::move(state)); - }); + .def(torch::init()) + .def("Sub", &Regex::Sub) + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr &self) -> std::string { + return _serialize_regex(self); + }, + // __setstate__ + [](std::string state) -> c10::intrusive_ptr { + return _deserialize_regex(std::move(state)); + }); m.class_("RegexTokenizer") - .def(torch::init, std::vector, bool>()) - .def("forward", &RegexTokenizer::forward) - .def_pickle( - // __getstate__ - [](const c10::intrusive_ptr &self) -> RegexTokenizerStates { - return _serialize_regex_tokenizer(self); - }, - 
-        // __setstate__
-        [](RegexTokenizerStates states) -> c10::intrusive_ptr<RegexTokenizer> {
-          return _deserialize_regex_tokenizer(std::move(states));
-        });
+      .def(torch::init<std::vector<std::string>, std::vector<std::string>,
+                       bool>())
+      .def("forward", &RegexTokenizer::forward)
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<RegexTokenizer> &self)
+              -> RegexTokenizerStates {
+            return _serialize_regex_tokenizer(self);
+          },
+          // __setstate__
+          [](RegexTokenizerStates states)
+              -> c10::intrusive_ptr<RegexTokenizer> {
+            return _deserialize_regex_tokenizer(std::move(states));
+          });
 
  m.class_<SentencePiece>("SentencePiece")
-    .def(torch::init<std::string>())
-    .def("Encode", &SentencePiece::Encode)
-    .def("EncodeAsIds", &SentencePiece::EncodeAsIds)
-    .def("DecodeIds", &SentencePiece::DecodeIds)
-    .def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
-    .def("DecodePieces", &SentencePiece::DecodePieces)
-    .def("GetPieceSize", &SentencePiece::GetPieceSize)
-    .def("unk_id", &SentencePiece::unk_id)
-    .def("PieceToId", &SentencePiece::PieceToId)
-    .def("IdToPiece", &SentencePiece::IdToPiece)
-    .def_pickle(
-        // The underlying content of SentencePiece contains byte string,
-        // and returing it as std::string cause UTF8 decoding error.
-        // Since TorchScript does not support byte string, we use byte Tensor to
-        // pass around the data.
-        // __getstate__
-        [](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
-          auto *data = static_cast<void *>(const_cast<char *>(self->content_.data()));
-          auto numel = static_cast<int64_t>(self->content_.size());
-          return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
-        },
-        // __setstate__
-        [](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
-          auto *data = static_cast<char *>(state.data_ptr());
-          auto numel = state.size(0);
-          return c10::make_intrusive<SentencePiece>(std::string(data, numel));
-        });
+      .def(torch::init<std::string>())
+      .def("Encode", &SentencePiece::Encode)
+      .def("EncodeAsIds", &SentencePiece::EncodeAsIds)
+      .def("DecodeIds", &SentencePiece::DecodeIds)
+      .def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
+      .def("DecodePieces", &SentencePiece::DecodePieces)
+      .def("GetPieceSize", &SentencePiece::GetPieceSize)
+      .def("unk_id", &SentencePiece::unk_id)
+      .def("PieceToId", &SentencePiece::PieceToId)
+      .def("IdToPiece", &SentencePiece::IdToPiece)
+      .def_pickle(
+          // The underlying content of SentencePiece contains a byte string,
+          // and returning it as std::string causes a UTF8 decoding error.
+          // Since TorchScript does not support byte strings, we use a byte
+          // Tensor to pass around the data.
+          // __getstate__
+          [](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
+            auto *data =
+                static_cast<void *>(const_cast<char *>(self->content_.data()));
+            auto numel = static_cast<int64_t>(self->content_.size());
+            return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
+          },
+          // __setstate__
+          [](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
+            auto *data = static_cast<char *>(state.data_ptr());
+            auto numel = state.size(0);
+            return c10::make_intrusive<SentencePiece>(std::string(data, numel));
+          });
 
  m.class_<Vectors>("Vectors")
-    .def(torch::init<std::vector<std::string>, std::vector<int64_t>, torch::Tensor, torch::Tensor>())
-    .def("__getitem__", &Vectors::__getitem__)
-    .def("lookup_vectors", &Vectors::lookup_vectors)
-    .def("__setitem__", &Vectors::__setitem__)
-    .def("__len__", &Vectors::__len__)
-    .def_pickle(
-        // __getstate__
-        [](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
-          return _serialize_vectors(self);
-        },
-        // __setstate__
-        [](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
-          return _deserialize_vectors(states);
-        });
+      .def(torch::init<std::vector<std::string>, std::vector<int64_t>,
+                       torch::Tensor, torch::Tensor>())
+      .def("__getitem__", &Vectors::__getitem__)
+      .def("lookup_vectors", &Vectors::lookup_vectors)
+      .def("__setitem__", &Vectors::__setitem__)
+      .def("__len__", &Vectors::__len__)
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
+            return _serialize_vectors(self);
+          },
+          // __setstate__
+          [](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
+            return _deserialize_vectors(states);
+          });
 
  m.class_<Vocab>("Vocab")
-    .def(torch::init<std::vector<std::string>, std::string>())
-    .def("__getitem__", &Vocab::__getitem__)
-    .def("__len__", &Vocab::__len__)
-    .def("insert_token", &Vocab::insert_token)
-    .def("append_token", &Vocab::append_token)
-    .def("lookup_token", &Vocab::lookup_token)
-    .def("lookup_tokens", &Vocab::lookup_tokens)
-    .def("lookup_indices", &Vocab::lookup_indices)
-    .def("get_stoi", &Vocab::get_stoi)
-    .def("get_itos", &Vocab::get_itos)
-    .def_pickle(
-        // __getstate__
-        [](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
-          return _serialize_vocab(self);
-        },
-        // __setstate__
-        [](VocabStates states) -> c10::intrusive_ptr<Vocab> {
-          return _deserialize_vocab(states);
-        });
+      .def(torch::init<std::vector<std::string>, std::string>())
+      .def("__getitem__",
+           [](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
+               -> int64_t { return self->__getitem__(c10::string_view{item}); })
+      .def("__len__", &Vocab::__len__)
+      .def("insert_token", &Vocab::insert_token)
+      .def("append_token", &Vocab::append_token)
+      .def("lookup_token", &Vocab::lookup_token)
+      .def("lookup_tokens", &Vocab::lookup_tokens)
+      .def("lookup_indices",
+           [](const c10::intrusive_ptr<Vocab> &self,
+              const std::vector<std::string> &items) {
+             std::vector<int64_t> indices(items.size());
+             int64_t counter = 0;
+             for (const auto &item : items) {
+               indices[counter++] = self->__getitem__(c10::string_view{item});
+             }
+             return indices;
+           })
+      .def("get_stoi", &Vocab::get_stoi)
+      .def("get_itos", &Vocab::get_itos)
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
+            return _serialize_vocab(self);
+          },
+          // __setstate__
+          [](VocabStates states) -> c10::intrusive_ptr<Vocab> {
+            return _deserialize_vocab(states);
+          });
 
  m.def("torchtext::generate_sp_model", &generate_sp_model);
  m.def("torchtext::load_sp_model", &load_sp_model);
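Note: the pybind11 `__getitem__` and `lookup_indices` bindings above switch from member-function pointers to lambdas so that each Python `str` reaches C++ as a `c10::string_view` over CPython's cached UTF-8 buffer, instead of being copied into a temporary `std::string` by pybind11's default converter. Below is a minimal compilable sketch of that pattern; the module name, the dummy `lookup`, and the use of `std::string_view` in place of `c10::string_view` are illustrative assumptions, not part of the patch:

```cpp
#include <Python.h>
#include <pybind11/pybind11.h>
#include <cstdint>
#include <string>
#include <string_view>

namespace py = pybind11;

// Dummy stand-in for Vocab::__getitem__; the patch calls the real lookup.
static int64_t lookup(std::string_view token) {
  return static_cast<int64_t>(token.size());
}

PYBIND11_MODULE(example, m) {
  // Default pybind11 path: the str argument is copied into a std::string
  // before the lambda body runs.
  m.def("lookup_copy",
        [](const std::string &token) { return lookup(token); });

  // The patch's path: borrow CPython's cached UTF-8 representation and
  // wrap it in a view; no per-call allocation or copy.
  m.def("lookup_view", [](const py::str &token) {
    Py_ssize_t length;
    const char *buffer = PyUnicode_AsUTF8AndSize(token.ptr(), &length);
    // Production code should handle buffer == nullptr (encoding failure).
    return lookup(std::string_view(buffer, static_cast<size_t>(length)));
  });
}
```

Since CPython caches the UTF-8 form on the `str` object itself, repeated lookups of the same token do not even re-encode it.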
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 0e324dbbf5..b6fa2099d7 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -1,23 +1,19 @@
 #include <ATen/Parallel.h> // @manual
 #include <atomic>
+#include <c10/util/string_view.h>
 #include <fstream>
 #include <iostream>
-#include <torchtext/csrc/common.h>  // @manual
-#include <torchtext/csrc/vocab.h>   // @manual
-
+#include <torchtext/csrc/common.h> // @manual
+#include <torchtext/csrc/vocab.h>  // @manual
 namespace torchtext {
 
-Vocab::Vocab(const StringList &tokens, const IndexDict &stoi,
-             const std::string &unk_token, const int64_t unk_index)
-    : unk_index_(std::move(unk_index)), stoi_(std::move(stoi)),
-      itos_(std::move(tokens)), unk_token_(std::move(unk_token)) {}
-
 Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
-    : itos_(std::move(tokens)), unk_token_(std::move(unk_token)) {
-  stoi_.reserve(tokens.size());
+    : stoi_(MAX_VOCAB_SIZE, -1), unk_token_(std::move(unk_token)) {
   for (std::size_t i = 0; i < tokens.size(); i++) {
     // tokens should not have any duplicates
-    if (stoi_.find(tokens[i]) != stoi_.end()) {
+    auto token_position =
+        _find(c10::string_view{tokens[i].data(), tokens[i].size()});
+    if (stoi_[token_position] != -1) {
 #ifdef _MSC_VER
      std::cerr << "[RuntimeError] Duplicate token found in tokens list: "
                << tokens[i] << std::endl;
@@ -25,35 +21,27 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
      throw std::runtime_error("Duplicate token found in tokens list: " +
                               tokens[i]);
    }
-    stoi_[std::move(tokens[i])] = i;
+    _add(tokens[i]);
  }
-  unk_index_ = stoi_.find(unk_token)->second;
+
+  unk_index_ =
+      stoi_[_find(c10::string_view{unk_token.data(), unk_token.size()})];
 }
 
-int64_t Vocab::__len__() const { return stoi_.size(); }
+int64_t Vocab::__len__() const { return itos_.size(); }
 
-int64_t Vocab::__getitem__(const std::string &token) const {
-  const auto &item = stoi_.find(token);
-  if (item != stoi_.end()) {
-    return item->second;
+int64_t Vocab::__getitem__(const c10::string_view &token) const {
+  int64_t id = _find(token);
+  if (stoi_[id] != -1) {
+    return stoi_[id];
  }
  return unk_index_;
 }
 
-void Vocab::append_token(const std::string &token) {
-  if (stoi_.find(token) == stoi_.end()) {
-    // Note: we can't do `stoi_[token] = stoi_.size()` because of a bug
-    // on Windows where the size gets updated before the assign occurs.
-    // For example if the size of `stoi_` is 2, doing
-    // `stoi_["test"] = stoi_.size()` will set `stoi_["test"]` to a
-    // value of 3 instead of 2 on Windows
-    stoi_[token] = itos_.size();
-    itos_.push_back(token);
-  }
-}
+void Vocab::append_token(const std::string &token) { _add(token); }
 
 void Vocab::insert_token(const std::string &token, const int64_t &index) {
-  if (index < 0 || index > static_cast<int64_t>(stoi_.size())) {
+  if (index < 0 || index > static_cast<int64_t>(itos_.size())) {
 #ifdef _MSC_VER
    std::cerr << "[RuntimeError] Specified index " << index
              << " is out of bounds of the size of stoi dictionary: "
@@ -65,30 +53,31 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) {
                            std::to_string(stoi_.size()) + ".");
  }
 
-  const auto &item = stoi_.find(token);
  // if item already in stoi we throw an error
-  if (item != stoi_.end()) {
+  auto token_position = _find(c10::string_view{token.data(), token.size()});
+  if (stoi_[token_position] != -1) {
 #ifdef _MSC_VER
    std::cerr << "[RuntimeError] Token " << token
-              << " already exists in the Vocab with index: " << item->second
-              << std::endl;
+              << " already exists in the Vocab with index: "
+              << stoi_[token_position] << std::endl;
 #endif
    throw std::runtime_error("Token " + token +
                             " already exists in the Vocab with index: " +
-                             std::to_string(item->second) + ".");
+                             std::to_string(stoi_[token_position]) + ".");
  }
 
  // need to offset all tokens greater than or equal index by 1
  for (size_t i = index; i < itos_.size(); i++) {
-    stoi_[itos_[i]] = i + 1;
+    stoi_[_find(c10::string_view{itos_[i].data(), itos_[i].size()})] = i + 1;
  }
 
-  stoi_[token] = index;
  itos_.insert(itos_.begin() + index, token);
+  stoi_[_find(c10::string_view{token.data(), token.size()})] = index;
 
  // need to update unk_index in case token equals unk_token or token
  // inserted before unk_token
-  unk_index_ = stoi_.find(unk_token_)->second;
+  unk_index_ =
+      stoi_[_find(c10::string_view{unk_token_.data(), unk_token_.size()})];
 }
 
 std::string Vocab::lookup_token(const int64_t &index) {
@@ -96,7 +85,7 @@ std::string Vocab::lookup_token(const int64_t &index) {
 #ifdef _MSC_VER
    std::cerr << "[RuntimeError] Specified index " << index
              << " is out of bounds of the size of itos dictionary: "
-              << stoi_.size() << std::endl;
+              << itos_.size() << std::endl;
 #endif
    throw std::runtime_error(
        "Specified index " + std::to_string(index) +
@@ -115,7 +104,8 @@ StringList Vocab::lookup_tokens(const std::vector<int64_t> &indices) {
  return tokens;
 }
 
-std::vector<int64_t> Vocab::lookup_indices(const StringList &tokens) {
+std::vector<int64_t>
+Vocab::lookup_indices(const std::vector<c10::string_view> &tokens) {
  std::vector<int64_t> indices(tokens.size());
  for (int64_t i = 0; i < static_cast<int64_t>(tokens.size()); i++) {
    indices[i] = __getitem__(tokens[i]);
@@ -125,11 +115,9 @@ std::vector<int64_t> Vocab::lookup_indices(const StringList &tokens) {
 
 std::unordered_map<std::string, int64_t> Vocab::get_stoi() const {
  std::unordered_map<std::string, int64_t> stoi;
-  stoi.reserve(stoi_.size());
-
  // construct tokens and index list
-  for (const auto &item : stoi_) {
-    stoi[item.first] = item.second;
+  for (const auto &item : itos_) {
+    stoi[item] = __getitem__(c10::string_view{item});
  }
  return stoi;
 }
@@ -150,8 +138,11 @@ int64_t _infer_lines(const std::string &file_path) {
 void parse_vocab_file_chunk(const std::string &file_path, size_t offset,
                            const int64_t start_line, const int64_t end_line,
                            std::shared_ptr<IndexDict> counter) {
-  std::ifstream fin;
-  fin.open(file_path, std::ios::in);
+  std::ifstream fin(file_path, std::ios::in);
+  if (!fin.is_open()) {
+    throw std::runtime_error("Cannot open input file " + file_path + "\n");
file_path + "\n"); + } + fin.seekg(offset); for (int64_t i = start_line; i < end_line; i++) { @@ -171,8 +162,11 @@ void parse_raw_text_file_chunk(const std::string &file_path, size_t offset, const int64_t start_line, const int64_t end_line, std::shared_ptr counter, torch::jit::script::Module &module) { - std::ifstream fin; - fin.open(file_path, std::ios::in); + std::ifstream fin(file_path, std::ios::in); + if (!fin.is_open()) { + throw std::runtime_error("Cannot open input file " + file_path + "\n"); + } + fin.seekg(offset); std::string line; @@ -205,7 +199,7 @@ struct CompareTokens { } }; -std::tuple +StringList _concat_tokens(std::vector> chunk_counters, const std::string &unk_token, const int64_t min_freq, const int64_t num_lines, const bool sort_tokens) { @@ -264,17 +258,7 @@ _concat_tokens(std::vector> chunk_counters, unique_tokens.insert(unique_tokens.begin(), unk_token); } - // create stoi - IndexDict stoi; - stoi.reserve(num_lines); - int64_t index = 0; - - for (const auto &token : unique_tokens) { - stoi[token] = index; - index++; - } - - return std::make_tuple(std::move(stoi), std::move(unique_tokens)); + return unique_tokens; } constexpr int64_t GRAIN_SIZE = 13107; @@ -319,14 +303,10 @@ Vocab _load_vocab_from_file(const std::string &file_path, std::unique_lock lock(m); cv.wait(lock, [&thread_count] { return thread_count == 0; }); - IndexDict stoi; - StringList tokens; - std::tie(stoi, tokens) = + StringList tokens = _concat_tokens(chunk_counters, unk_token, min_freq, num_lines, false); - int64_t unk_index = stoi.find(unk_token)->second; - - return Vocab(std::move(tokens), std::move(stoi), unk_token, unk_index); + return Vocab(std::move(tokens), unk_token); } Vocab _build_vocab_from_text_file(const std::string &file_path, @@ -370,13 +350,10 @@ Vocab _build_vocab_from_text_file(const std::string &file_path, std::unique_lock lock(m); cv.wait(lock, [&thread_count] { return thread_count == 0; }); - IndexDict stoi; - StringList tokens; - std::tie(stoi, tokens) = + StringList tokens = _concat_tokens(chunk_counters, unk_token, min_freq, num_lines, true); - int64_t unk_index = stoi.find(unk_token)->second; - return Vocab(std::move(tokens), std::move(stoi), unk_token, unk_index); + return Vocab(std::move(tokens), unk_token); } VocabStates _serialize_vocab(const c10::intrusive_ptr &self) { diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index 0da660a633..660f6145d4 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -1,5 +1,5 @@ +#include #include - namespace torchtext { typedef std::vector StringList; @@ -10,29 +10,52 @@ typedef std::tuple, std::vector, VocabStates; struct Vocab : torch::CustomClassHolder { -private: + static const int32_t MAX_VOCAB_SIZE = 30000000; int64_t unk_index_; - IndexDict stoi_; - -public: + std::vector stoi_; const std::string version_str_ = "0.0.1"; StringList itos_; std::string unk_token_; explicit Vocab(const std::vector &tokens, const std::string &unk_token); - explicit Vocab(const StringList &tokens, const IndexDict &stoi, - - const std::string &unk_token, const int64_t unk_index); int64_t __len__() const; - int64_t __getitem__(const std::string &token) const; + int64_t __getitem__(const c10::string_view &token) const; void append_token(const std::string &token); void insert_token(const std::string &token, const int64_t &index); std::string lookup_token(const int64_t &index); std::vector lookup_tokens(const std::vector &indices); - std::vector lookup_indices(const std::vector &tokens); + std::vector + lookup_indices(const 
  std::unordered_map<std::string, int64_t> get_stoi() const;
  std::vector<std::string> get_itos() const;
+
+protected:
+  uint32_t _hash(const c10::string_view &str) const {
+    uint32_t h = 2166136261;
+    for (size_t i = 0; i < str.size(); i++) {
+      h = h ^ uint32_t(uint8_t(str[i]));
+      h = h * 16777619;
+    }
+    return h;
+  }
+
+  uint32_t _find(const c10::string_view &w) const {
+    uint32_t stoi_size = stoi_.size();
+    uint32_t id = _hash(w) % stoi_size;
+    while (stoi_[id] != -1 && itos_[stoi_[id]] != w) {
+      id = (id + 1) % stoi_size;
+    }
+    return id;
+  }
+
+  void _add(const std::string &w) {
+    uint32_t h = _find(c10::string_view{w.data(), w.size()});
+    if (stoi_[h] == -1) {
+      itos_.push_back(w);
+      stoi_[h] = itos_.size() - 1;
+    }
+  }
 };
 
 VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self);
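Note on the new lookup scheme: `stoi_` is no longer a token-to-index hash map but a fixed-size open-addressing table, following the same approach as fastText's word table. `_hash` is FNV-1a, `_find` probes linearly until it reaches the token's slot or an empty slot (marked -1), and each occupied slot stores an index into `itos_`. A self-contained sketch of the scheme; the class and its names are illustrative, not the patch's API:

```cpp
// Sketch of the new Vocab's lookup scheme: FNV-1a hashing plus linear
// probing over a fixed-size slot array, where each slot holds an index
// into itos_ and -1 marks an empty slot.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

class TinyVocab {
public:
  // num_slots plays the role of MAX_VOCAB_SIZE; it must stay larger than
  // the number of stored tokens, or probing in find_slot would never stop.
  explicit TinyVocab(size_t num_slots) : slots_(num_slots, -1) {}

  void add(const std::string &w) {
    size_t id = find_slot(w);
    if (slots_[id] == -1) { // not present yet
      itos_.push_back(w);
      slots_[id] = static_cast<int64_t>(itos_.size()) - 1;
    }
  }

  // Returns the token's index, or unk_index when the probe ends on an
  // empty slot, mirroring Vocab::__getitem__.
  int64_t lookup(const std::string &w, int64_t unk_index) const {
    int64_t idx = slots_[find_slot(w)];
    return idx != -1 ? idx : unk_index;
  }

private:
  // FNV-1a over the token's bytes, same constants as Vocab::_hash.
  uint32_t hash(const std::string &s) const {
    uint32_t h = 2166136261u;
    for (unsigned char c : s) {
      h = (h ^ c) * 16777619u;
    }
    return h;
  }

  // Linear probing, as in Vocab::_find: advance until we reach the slot
  // holding this token or the first empty slot.
  size_t find_slot(const std::string &w) const {
    size_t id = hash(w) % slots_.size();
    while (slots_[id] != -1 && itos_[slots_[id]] != w) {
      id = (id + 1) % slots_.size();
    }
    return id;
  }

  std::vector<int64_t> slots_;    // slot -> index into itos_, -1 if free
  std::vector<std::string> itos_; // index -> token
};

int main() {
  TinyVocab v(1024);
  v.add("<unk>"); // index 0
  v.add("hello"); // index 1
  v.add("world"); // index 2
  std::cout << v.lookup("world", 0) << "\n";   // prints 2
  std::cout << v.lookup("missing", 0) << "\n"; // prints 0, the unk index
}
```

The design trades memory for speed: `stoi_(MAX_VOCAB_SIZE, -1)` materializes all 30,000,000 slots up front (roughly 240 MB, assuming the 8-byte slots reconstructed above), but a lookup then touches a flat array instead of a node-based map and works directly on a `c10::string_view`, so no `std::string` is allocated per token.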