This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit f7c2985

Some simple c++ refactoring (#1327)
1 parent 0d7b9e2 commit f7c2985

File tree

6 files changed: +43 -42 lines changed

torchtext/csrc/sentencepiece.cpp

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ c10::intrusive_ptr<SentencePiece> load_sp_model(const std::string &path) {
 }
 
 c10::intrusive_ptr<SentencePiece>
-load_sp_model_string(const std::string &content) {
+load_sp_model_string(std::string content) {
   return c10::make_intrusive<SentencePiece>(std::move(content));
 }
 
torchtext/csrc/sentencepiece.h

Lines changed: 1 addition & 1 deletion

@@ -33,6 +33,6 @@ void generate_sp_model(const std::string &filename, const int64_t &vocab_size,
                        const std::string &model_prefix);
 c10::intrusive_ptr<SentencePiece> load_sp_model(const std::string &path);
 c10::intrusive_ptr<SentencePiece>
-load_sp_model_string(const std::string &content);
+load_sp_model_string(std::string content);
 
 } // namespace torchtext
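
Note: the signature change above is the sink-parameter idiom. Taking std::string by value and moving it into the owning object lets callers that pass an rvalue avoid any copy, while lvalue callers pay exactly one copy instead of copy-plus-construct. A minimal sketch of the idiom; the Model type and load_model_string helper below are illustrative stand-ins, not part of torchtext:

#include <string>
#include <utility>

// Illustrative stand-in for SentencePiece: owns the serialized model content.
struct Model {
  std::string content_;
  explicit Model(std::string content) : content_(std::move(content)) {}
};

// Sink parameter: take by value, then move into the owning object.
Model load_model_string(std::string content) {
  return Model(std::move(content));
}

int main() {
  std::string serialized = "...model bytes...";
  Model a = load_model_string(serialized);            // one copy (lvalue argument)
  Model b = load_model_string(std::move(serialized)); // no copy (rvalue argument)
  (void)a;
  (void)b;
}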

torchtext/csrc/vectors.cpp

Lines changed: 8 additions & 8 deletions

@@ -15,16 +15,16 @@
 
 namespace torchtext {
 
-Vectors::Vectors(const IndexMap &stoi, const torch::Tensor vectors,
-                 const torch::Tensor &unk_tensor)
-    : stoi_(stoi), vectors_(vectors), unk_tensor_(unk_tensor) {}
+Vectors::Vectors(const IndexMap &stoi, torch::Tensor vectors,
+                 torch::Tensor unk_tensor)
+    : stoi_(stoi), vectors_(std::move(vectors)), unk_tensor_(std::move(unk_tensor)) {}
 
 Vectors::Vectors(const std::vector<std::string> &tokens,
                  const std::vector<std::int64_t> &indices,
-                 const torch::Tensor &vectors, const torch::Tensor &unk_tensor)
+                 torch::Tensor vectors, torch::Tensor unk_tensor)
     : vectors_(std::move(vectors)), unk_tensor_(std::move(unk_tensor)) {
   // guarding against size mismatch of tokens and indices
-  if (static_cast<int>(tokens.size()) != indices.size()) {
+  if (tokens.size() != indices.size()) {
 #ifdef _MSC_VER
     std::cerr << "[RuntimeError] Mismatching sizes for tokens and indices. "
                  "Size of tokens: "

@@ -72,7 +72,7 @@ torch::Tensor Vectors::__getitem__(const std::string &token) {
 torch::Tensor Vectors::lookup_vectors(const std::vector<std::string> &tokens) {
   std::vector<torch::Tensor> vectors;
   for (const std::string &token : tokens) {
-    vectors.push_back(__getitem__(token));
+    vectors.emplace_back(__getitem__(token));
   }
   return torch::stack(vectors, 0);
 }

@@ -206,7 +206,7 @@ _concat_vectors(std::vector<std::shared_ptr<StringList>> chunk_tokens,
 
 constexpr int64_t GRAIN_SIZE = 131072;
 std::tuple<Vectors, std::vector<std::string>> _load_token_and_vectors_from_file(
-    const std::string &file_path, const std::string delimiter_str,
+    const std::string &file_path, const std::string &delimiter_str,
     int64_t num_cpus, c10::optional<torch::Tensor> opt_unk_tensor) {
 
   TORCH_CHECK(delimiter_str.size() == 1,

@@ -265,7 +265,7 @@ std::tuple<Vectors, std::vector<std::string>> _load_token_and_vectors_from_file(
 
   torch::Tensor unk_tensor;
   if (opt_unk_tensor) {
-    unk_tensor = *opt_unk_tensor;
+    unk_tensor = std::move(*opt_unk_tensor);
   } else {
     unk_tensor = torch::zeros({vector_dim}, torch::kFloat32);
  }
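
Note: torch::Tensor is a reference-counted handle, so copying one copies no data, only an atomic refcount bump; taking the constructor parameters by value and moving them into the members avoids even that bump for rvalue arguments. A minimal sketch of the pattern under that assumption, using a hypothetical Holder type rather than the real Vectors class:

#include <torch/torch.h>
#include <utility>

// Hypothetical holder mirroring the Vectors constructor change above.
struct Holder {
  torch::Tensor vectors_;
  torch::Tensor unk_tensor_;

  // Take tensors by value and move them into members: rvalue arguments are
  // moved all the way through, lvalue arguments cost one cheap handle copy.
  Holder(torch::Tensor vectors, torch::Tensor unk_tensor)
      : vectors_(std::move(vectors)), unk_tensor_(std::move(unk_tensor)) {}
};

int main() {
  torch::Tensor v = torch::rand({10, 3});
  Holder h(std::move(v), torch::zeros({3}));  // no refcount churn for either argument
  (void)h;
}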

torchtext/csrc/vectors.h

Lines changed: 5 additions & 5 deletions

@@ -19,12 +19,12 @@ struct Vectors : torch::CustomClassHolder {
   torch::Tensor vectors_;
   torch::Tensor unk_tensor_;
 
-  explicit Vectors(const IndexMap &stoi, const torch::Tensor vectors,
-                   const torch::Tensor &unk_tensor);
+  explicit Vectors(const IndexMap &stoi, torch::Tensor vectors,
+                   torch::Tensor unk_tensor);
   explicit Vectors(const std::vector<std::string> &tokens,
                    const std::vector<std::int64_t> &indices,
-                   const torch::Tensor &vectors,
-                   const torch::Tensor &unk_tensor);
+                   torch::Tensor vectors,
+                   torch::Tensor unk_tensor);
   std::unordered_map<std::string, int64_t> get_stoi();
   torch::Tensor __getitem__(const std::string &token);
   torch::Tensor lookup_vectors(const std::vector<std::string> &tokens);

@@ -36,7 +36,7 @@ VectorsStates _serialize_vectors(const c10::intrusive_ptr<Vectors> &self);
 c10::intrusive_ptr<Vectors> _deserialize_vectors(VectorsStates states);
 
 std::tuple<Vectors, std::vector<std::string>> _load_token_and_vectors_from_file(
-    const std::string &file_path, const std::string delimiter_str,
+    const std::string &file_path, const std::string &delimiter_str,
     const int64_t num_cpus, c10::optional<torch::Tensor> opt_unk_tensor);
 
 } // namespace torchtext
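
Note: the other change in this file switches delimiter_str from a by-value const std::string to a const reference, so the delimiter is no longer copied on every call even though the function only reads it. A minimal sketch of the difference; is_single_char is a hypothetical helper, not a torchtext function:

#include <string>

// Read-only string parameters: `const std::string &` binds to the caller's
// string without copying; a by-value `const std::string` parameter would
// copy the argument on every call even though the function never mutates it.
bool is_single_char(const std::string &delimiter) {
  return delimiter.size() == 1;
}

int main() {
  std::string tab = "\t";
  return is_single_char(tab) ? 0 : 1;
}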

torchtext/csrc/vocab.cpp

Lines changed: 21 additions & 20 deletions

@@ -7,20 +7,20 @@
 #include <vocab.h> // @manual
 namespace torchtext {
 
-Vocab::Vocab(const StringList &tokens,
+Vocab::Vocab(StringList tokens,
              const c10::optional<int64_t> &default_index)
     : stoi_(MAX_VOCAB_SIZE, -1), default_index_{default_index} {
-  for (size_t i = 0; i < tokens.size(); i++) {
+  for (auto &token : tokens) {
     // throw error if duplicate token is found
-    auto id = _find(c10::string_view{tokens[i].data(), tokens[i].size()});
+    auto id = _find(c10::string_view{token});
     TORCH_CHECK(stoi_[id] == -1,
                 "Duplicate token found in tokens list: " + token);
 
-    _add(tokens[i]);
+    _add(std::move(token));
   }
 }
 
-Vocab::Vocab(const StringList &tokens) : Vocab(tokens, {}) {}
+Vocab::Vocab(StringList tokens) : Vocab(std::move(tokens), {}) {}
 
 int64_t Vocab::__len__() const { return itos_.size(); }
 

@@ -54,17 +54,17 @@ c10::optional<int64_t> Vocab::get_default_index() const {
   return default_index_;
 }
 
-void Vocab::append_token(const std::string &token) {
+void Vocab::append_token(std::string token) {
   // throw error if token already exist in vocab
-  auto id = _find(c10::string_view{token.data(), token.size()});
+  auto id = _find(c10::string_view{token});
   TORCH_CHECK(stoi_[id] == -1, "Token " + token +
                                    " already exists in the Vocab with index: " +
                                    std::to_string(stoi_[id]));
 
-  _add(token);
+  _add(std::move(token));
 }
 
-void Vocab::insert_token(const std::string &token, const int64_t &index) {
+void Vocab::insert_token(std::string token, const int64_t &index) {
   // throw error if index is not valid
   TORCH_CHECK(index >= 0 && index <= __len__(),
               "Specified index " + std::to_string(index) +

@@ -76,11 +76,11 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) {
 
   // need to offset all tokens greater than or equal index by 1
   for (size_t i = index; i < __len__(); i++) {
-    stoi_[_find(c10::string_view{itos_[i].data(), itos_[i].size()})] = i + 1;
+    stoi_[_find(c10::string_view{itos_[i]})] = i + 1;
   }
 
-  itos_.insert(itos_.begin() + index, token);
-  stoi_[_find(c10::string_view{token.data(), token.size()})] = index;
+  stoi_[_find(c10::string_view{token})] = index;
+  itos_.insert(itos_.begin() + index, std::move(token));
 }
 
 std::string Vocab::lookup_token(const int64_t &index) {

@@ -144,7 +144,7 @@ int64_t _infer_lines(const std::string &file_path) {
 
 void parse_vocab_file_chunk(const std::string &file_path, size_t offset,
                             const int64_t start_line, const int64_t end_line,
-                            std::shared_ptr<IndexDict> counter) {
+                            const std::shared_ptr<IndexDict> &counter) {
   std::ifstream fin(file_path, std::ios::in);
   TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path);
 

@@ -165,7 +165,7 @@ void parse_vocab_file_chunk(const std::string &file_path, size_t offset,
 
 void parse_raw_text_file_chunk(const std::string &file_path, size_t offset,
                                const int64_t start_line, const int64_t end_line,
-                               std::shared_ptr<IndexDict> counter,
+                               const std::shared_ptr<IndexDict> &counter,
                                torch::jit::script::Module &module) {
   std::ifstream fin(file_path, std::ios::in);
   TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path);

@@ -225,9 +225,11 @@ _concat_tokens(std::vector<std::shared_ptr<IndexDict>> chunk_counters,
   // create token freq pairs
   std::vector<std::pair<std::string, int64_t>> token_freq_pairs;
 
-  for (std::string token : unique_tokens) {
-    token_freq_pairs.push_back(std::make_pair(token, tokens_freq[token]));
+  for (std::string &token : unique_tokens) {
+    auto token_freq = tokens_freq[token];
+    token_freq_pairs.emplace_back(std::move(token), token_freq);
   }
+  unique_tokens.clear();
 
   // sort tokens by freq
   if (sort_tokens) {

@@ -236,9 +238,8 @@ _concat_tokens(std::vector<std::shared_ptr<IndexDict>> chunk_counters,
   }
 
   // update unique tokens with correct order
-  unique_tokens.clear();
-  for (const auto &token_freq_pair : token_freq_pairs) {
-    unique_tokens.push_back(token_freq_pair.first);
+  for (auto &token_freq_pair : token_freq_pairs) {
+    unique_tokens.emplace_back(std::move(token_freq_pair.first));
   }
 
   return unique_tokens;
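
Note on the _concat_tokens change above: the frequency is read into a local (token_freq) before the token is handed to emplace_back, so the map lookup clearly happens while the token still holds its value; after the loop, unique_tokens contains only moved-from strings, which is why it is cleared right away and later refilled from the sorted pairs. A minimal, self-contained sketch of the pattern under that reading, using standard containers instead of torchtext's IndexDict/StringList aliases:

#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

int main() {
  std::vector<std::string> unique_tokens = {"the", "cat", "sat"};
  std::unordered_map<std::string, int64_t> tokens_freq = {
      {"the", 2}, {"cat", 1}, {"sat", 1}};

  std::vector<std::pair<std::string, int64_t>> token_freq_pairs;
  token_freq_pairs.reserve(unique_tokens.size());

  for (std::string &token : unique_tokens) {
    // Look the frequency up first, then move the string into the pair instead
    // of copying it, as the commit does.
    auto token_freq = tokens_freq[token];
    token_freq_pairs.emplace_back(std::move(token), token_freq);
  }
  unique_tokens.clear();  // the strings have been moved out; drop the husks

  return token_freq_pairs.size() == 3 ? 0 : 1;
}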

torchtext/csrc/vocab.h

Lines changed: 7 additions & 7 deletions

@@ -27,7 +27,7 @@ int64_t _infer_lines(const std::string &file_path);
 
 struct Vocab : torch::CustomClassHolder {
   static const int32_t MAX_VOCAB_SIZE = 30000000;
-  int64_t unk_index_;
+  int64_t unk_index_{};
   std::vector<int32_t> stoi_;
   const std::string version_str_ = "0.0.2";
   StringList itos_;

@@ -36,16 +36,16 @@ struct Vocab : torch::CustomClassHolder {
   // TODO: [can we remove this?] we need to keep this constructor, otherwise
   // torch binding gets compilation error: no matching constructor for
   // initialization of 'torchtext::Vocab'
-  explicit Vocab(const StringList &tokens);
-  explicit Vocab(const StringList &tokens,
+  explicit Vocab(StringList tokens);
+  explicit Vocab(StringList tokens,
                  const c10::optional<int64_t> &default_index);
   int64_t __len__() const;
   int64_t __getitem__(const c10::string_view &token) const;
   bool __contains__(const c10::string_view &token) const;
   void set_default_index(c10::optional<int64_t> index);
   c10::optional<int64_t> get_default_index() const;
-  void insert_token(const std::string &token, const int64_t &index);
-  void append_token(const std::string &token);
+  void insert_token(std::string token, const int64_t &index);
+  void append_token(std::string token);
   std::string lookup_token(const int64_t &index);
   std::vector<std::string> lookup_tokens(const std::vector<int64_t> &indices);
   std::vector<int64_t>

@@ -72,10 +72,10 @@ struct Vocab : torch::CustomClassHolder {
     return id;
   }
 
-  void _add(const std::string &w) {
+  void _add(std::string w) {
     uint32_t h = _find(c10::string_view{w.data(), w.size()});
     if (stoi_[h] == -1) {
-      itos_.push_back(w);
+      itos_.emplace_back(std::move(w));
       stoi_[h] = itos_.size() - 1;
     }
   }
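
Note: int64_t unk_index_{}; adds a default member initializer, so the field is value-initialized to 0 whenever a constructor's init list does not set it, instead of being left indeterminate. A minimal sketch of the difference, with hypothetical struct names:

#include <cstdint>

struct Uninitialized {
  int64_t unk_index_;    // indeterminate unless every constructor sets it
};

struct ZeroInitialized {
  int64_t unk_index_{};  // default member initializer: always starts at 0
};

int main() {
  ZeroInitialized z;
  return static_cast<int>(z.unk_index_);  // guaranteed 0
}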
