From 04bb26fe404cec99ce0ac2727f4473cdf8edee17 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Thu, 3 Feb 2022 17:17:59 -0800 Subject: [PATCH] Fix handling of end of file while reading vocab from file --- test/asset/vocab_test.txt | 2 +- torchtext/csrc/vocab.cpp | 35 +++++++++++++++++------------------ 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/test/asset/vocab_test.txt b/test/asset/vocab_test.txt index 389d543a10..e041165aa0 100644 --- a/test/asset/vocab_test.txt +++ b/test/asset/vocab_test.txt @@ -4,4 +4,4 @@ c a b a -c \ No newline at end of file +c diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 6f7abc8db4..53a2fba6c2 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -1,14 +1,14 @@ -#include // @manual +#include // @manual #include +#include // @manual +#include // @manual + #include #include #include -#include // @manual -#include // @manual namespace torchtext { -Vocab::Vocab(StringList tokens, - const c10::optional &default_index) +Vocab::Vocab(StringList tokens, const c10::optional &default_index) : stoi_(MAX_VOCAB_SIZE, -1), default_index_{default_index} { for (auto &token : tokens) { // throw error if duplicate token is found @@ -34,8 +34,7 @@ bool Vocab::__contains__(const c10::string_view &token) const { int64_t Vocab::__getitem__(const c10::string_view &token) const { int64_t id = _find(token); - if (stoi_[id] != -1) - return stoi_[id]; + if (stoi_[id] != -1) return stoi_[id]; // throw error if default_index_ is not set TORCH_CHECK(default_index_.has_value(), @@ -110,8 +109,8 @@ StringList Vocab::lookup_tokens(const std::vector &indices) { return tokens; } -std::vector -Vocab::lookup_indices(const std::vector &tokens) { +std::vector Vocab::lookup_indices( + const std::vector &tokens) { std::vector indices(tokens.size()); for (size_t i = 0; i < tokens.size(); i++) { indices[i] = __getitem__(tokens[i]); @@ -191,11 +190,9 @@ void parse_raw_text_file_chunk(const std::string &file_path, size_t offset, } } -StringList -_concat_tokens(std::vector> chunk_counters, - const int64_t min_freq, const int64_t num_lines, - const bool sort_tokens) { - +StringList _concat_tokens( + std::vector> chunk_counters, + const int64_t min_freq, const int64_t num_lines, const bool sort_tokens) { TORCH_CHECK(chunk_counters.size() > 0, "There must be at least 1 chunk to concatenate!"); @@ -214,8 +211,11 @@ _concat_tokens(std::vector> chunk_counters, tokens_freq[item.first] = cur_token_freq; } - // add to tokens list only if we exceed min_freq for the first time - if (tokens_freq[item.first] - cur_token_freq < min_freq && + // add to tokens list only if all of the conditions are met: + // 1. token is not empty + // 2. we exceed min_freq for the first time + if (item.first.length() && + tokens_freq[item.first] - cur_token_freq < min_freq && tokens_freq[item.first] >= min_freq) { unique_tokens.push_back(item.first); } @@ -248,7 +248,6 @@ _concat_tokens(std::vector> chunk_counters, constexpr int64_t GRAIN_SIZE = 13107; Vocab _load_vocab_from_file(const std::string &file_path, const int64_t min_freq, const int64_t num_cpus) { - int64_t num_lines = _infer_lines(file_path); int64_t chunk_size = impl::divup(num_lines, num_cpus); // Launching a thread on less lines than this likely has too much overhead. @@ -374,4 +373,4 @@ c10::intrusive_ptr _deserialize_vocab(VocabStates states) { return c10::make_intrusive(std::move(strings), default_index); } -} // namespace torchtext +} // namespace torchtext