Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 6a43bd5

Browse files
joecummings, Alexander Mazukabzov, and Nayef211
authored
Remove dependency on the torch::jit::script::Module for mobile builds (#1885)
* Remove dependency on the torch::jit::script::Module for mobile builds (#1885)

  Summary: In order to resolve linkage errors. Specifically, when the vocab is being built for the "mobile" version, it can't resolve symbols for torch::jit::script::Module.

  Reviewed By: Nayef211

  Differential Revision: D38771271

  fbshipit-source-id: 693b656f2a17af9fa5a7a1904742557f902edb55

* Add vocab factory to CMakeLists

* Fix type conversion from py::object to STL container (#1887)

* Export symbols in `common.h` file (#1888)

  * Fix type conversion from py::object to STL container
  * Add TORCHTEXT_API to expose symbols in common.h
  * Add common.h import in corresponding cpp file

Co-authored-by: Alexander Mazukabzov <[email protected]>
Co-authored-by: Nayef Ahmed <[email protected]>
1 parent 0225abe commit 6a43bd5

File tree

7 files changed

+310
-284
lines changed

7 files changed

+310
-284
lines changed

torchtext/csrc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ if (BUILD_TORCHTEXT_PYTHON_EXTENSION)
113113
set(
114114
EXTENSION_SOURCES
115115
register_pybindings.cpp
116+
vocab_factory.cpp
116117
)
117118

118119
set(

torchtext/csrc/common.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1+
#include <torchtext/csrc/common.h>
2+
13
#include <fstream>
24
#include <ios>
3-
#include <iostream>
45
#include <limits>
5-
#include <string>
6-
#include <vector>
76

87
namespace torchtext {
98
namespace impl {

torchtext/csrc/common.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
#include <torchtext/csrc/export.h>
2+
3+
#include <cstdint>
4+
#include <string>
5+
#include <vector>
6+
17
namespace torchtext {
28

39
namespace impl {
4-
int64_t divup(int64_t x, int64_t y);
5-
void infer_offsets(
10+
TORCHTEXT_API int64_t divup(int64_t x, int64_t y);
11+
TORCHTEXT_API void infer_offsets(
612
const std::string& file_path,
713
int64_t num_lines,
814
int64_t chunk_size,

torchtext/csrc/vocab.cpp

Lines changed: 0 additions & 225 deletions
Original file line numberDiff line numberDiff line change
@@ -136,231 +136,6 @@ StringList Vocab::get_itos() const {
136136
return itos_;
137137
}
138138

139-
int64_t _infer_lines(const std::string& file_path) {
140-
int64_t num_lines = 0;
141-
std::ifstream fin;
142-
fin.open(file_path, std::ios::in);
143-
TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path);
144-
145-
while (fin.ignore(std::numeric_limits<std::streamsize>::max(), '\n')) {
146-
num_lines++;
147-
}
148-
return num_lines;
149-
}
150-
151-
void parse_vocab_file_chunk(
152-
const std::string& file_path,
153-
size_t offset,
154-
const int64_t start_line,
155-
const int64_t end_line,
156-
const std::shared_ptr<IndexDict>& counter) {
157-
std::ifstream fin(file_path, std::ios::in);
158-
TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path);
159-
160-
fin.seekg(offset);
161-
162-
for (int64_t i = start_line; i < end_line; i++) {
163-
std::string token;
164-
fin >> token;
165-
fin >> std::ws;
166-
167-
if ((*counter).find(token) == (*counter).end()) {
168-
(*counter)[token] = 1;
169-
} else {
170-
(*counter)[token] += 1;
171-
}
172-
}
173-
}
174-
175-
void parse_raw_text_file_chunk(
176-
const std::string& file_path,
177-
size_t offset,
178-
const int64_t start_line,
179-
const int64_t end_line,
180-
const std::shared_ptr<IndexDict>& counter,
181-
torch::jit::script::Module& module) {
182-
std::ifstream fin(file_path, std::ios::in);
183-
TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path);
184-
185-
fin.seekg(offset);
186-
187-
std::string line;
188-
for (int64_t i = start_line; i < end_line; i++) {
189-
std::getline(fin, line);
190-
auto token_list =
191-
module.forward(std::vector<c10::IValue>({c10::IValue(line)})).toList();
192-
193-
for (size_t i = 0; i < token_list.size(); i++) {
194-
c10::IValue token_ref = token_list.get(i);
195-
std::string token = token_ref.toStringRef();
196-
197-
if ((*counter).find(token) == (*counter).end()) {
198-
(*counter)[token] = 1;
199-
} else {
200-
(*counter)[token] += 1;
201-
}
202-
}
203-
}
204-
}
205-
206-
StringList _concat_tokens(
207-
std::vector<std::shared_ptr<IndexDict>> chunk_counters,
208-
const int64_t min_freq,
209-
const int64_t num_lines,
210-
const bool sort_tokens) {
211-
TORCH_CHECK(
212-
chunk_counters.size() > 0,
213-
"There must be at least 1 chunk to concatenate!");
214-
215-
IndexDict tokens_freq;
216-
StringList unique_tokens;
217-
unique_tokens.reserve(num_lines);
218-
219-
// concatenate all counters
220-
for (size_t i = 0; i < chunk_counters.size(); i++) {
221-
auto& cur_counter = *chunk_counters[i];
222-
for (const auto& item : cur_counter) {
223-
int64_t cur_token_freq = item.second;
224-
if (tokens_freq.find(item.first) != tokens_freq.end()) {
225-
tokens_freq[item.first] += cur_token_freq;
226-
} else {
227-
tokens_freq[item.first] = cur_token_freq;
228-
}
229-
230-
// add to tokens list only if all of the conditions are met:
231-
// 1. token is not empty
232-
// 2. we exceed min_freq for the first time
233-
if (item.first.length() &&
234-
tokens_freq[item.first] - cur_token_freq < min_freq &&
235-
tokens_freq[item.first] >= min_freq) {
236-
unique_tokens.push_back(item.first);
237-
}
238-
}
239-
}
240-
241-
// create token freq pairs
242-
std::vector<std::pair<std::string, int64_t>> token_freq_pairs;
243-
244-
for (std::string& token : unique_tokens) {
245-
auto token_freq = tokens_freq[token];
246-
token_freq_pairs.emplace_back(std::move(token), token_freq);
247-
}
248-
unique_tokens.clear();
249-
250-
// sort tokens by freq
251-
if (sort_tokens) {
252-
CompareTokens compare_tokens;
253-
std::sort(token_freq_pairs.begin(), token_freq_pairs.end(), compare_tokens);
254-
}
255-
256-
// update unique tokens with correct order
257-
for (auto& token_freq_pair : token_freq_pairs) {
258-
unique_tokens.emplace_back(std::move(token_freq_pair.first));
259-
}
260-
261-
return unique_tokens;
262-
}
263-
264-
constexpr int64_t GRAIN_SIZE = 13107;
265-
Vocab _load_vocab_from_file(
266-
const std::string& file_path,
267-
const int64_t min_freq,
268-
const int64_t num_cpus) {
269-
int64_t num_lines = _infer_lines(file_path);
270-
int64_t chunk_size = impl::divup(num_lines, num_cpus);
271-
// Launching a thread on less lines than this likely has too much overhead.
272-
// TODO: Add explicit test beyond grain size to cover multithreading
273-
chunk_size = std::max(chunk_size, GRAIN_SIZE);
274-
275-
std::vector<size_t> offsets;
276-
impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
277-
278-
std::vector<std::shared_ptr<IndexDict>> chunk_counters;
279-
280-
std::mutex m;
281-
std::condition_variable cv;
282-
std::atomic<int> thread_count(0);
283-
284-
// create threads
285-
int64_t j = 0;
286-
for (int64_t i = 0; i < num_lines; i += chunk_size) {
287-
auto counter_ptr = std::make_shared<IndexDict>();
288-
289-
thread_count++;
290-
at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
291-
parse_vocab_file_chunk(
292-
file_path,
293-
offsets[j],
294-
i,
295-
std::min(num_lines, i + chunk_size),
296-
counter_ptr);
297-
std::lock_guard<std::mutex> lk(m);
298-
thread_count--;
299-
cv.notify_all();
300-
});
301-
chunk_counters.push_back(counter_ptr);
302-
j++;
303-
}
304-
305-
// block until all threads finish execution
306-
std::unique_lock<std::mutex> lock(m);
307-
cv.wait(lock, [&thread_count] { return thread_count == 0; });
308-
309-
StringList tokens =
310-
_concat_tokens(chunk_counters, min_freq, num_lines, false);
311-
312-
return Vocab(std::move(tokens));
313-
}
314-
315-
Vocab _build_vocab_from_text_file(
316-
const std::string& file_path,
317-
const int64_t min_freq,
318-
const int64_t num_cpus,
319-
torch::jit::script::Module tokenizer) {
320-
int64_t num_lines = _infer_lines(file_path);
321-
int64_t chunk_size = impl::divup(num_lines, num_cpus);
322-
// Launching a thread on less lines than this likely has too much overhead.
323-
chunk_size = std::max(chunk_size, GRAIN_SIZE);
324-
325-
std::vector<size_t> offsets;
326-
impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
327-
328-
std::vector<std::shared_ptr<IndexDict>> chunk_counters;
329-
330-
std::mutex m;
331-
std::condition_variable cv;
332-
std::atomic<int> thread_count(0);
333-
334-
// create threads
335-
int64_t j = 0;
336-
for (int64_t i = 0; i < num_lines; i += chunk_size) {
337-
auto counter_ptr = std::make_shared<IndexDict>();
338-
thread_count++;
339-
at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
340-
parse_raw_text_file_chunk(
341-
file_path,
342-
offsets[j],
343-
i,
344-
std::min(num_lines, i + chunk_size),
345-
counter_ptr,
346-
tokenizer);
347-
std::lock_guard<std::mutex> lk(m);
348-
thread_count--;
349-
cv.notify_all();
350-
});
351-
chunk_counters.push_back(counter_ptr);
352-
j++;
353-
}
354-
355-
// block until all threads finish execution
356-
std::unique_lock<std::mutex> lock(m);
357-
cv.wait(lock, [&thread_count] { return thread_count == 0; });
358-
359-
StringList tokens = _concat_tokens(chunk_counters, min_freq, num_lines, true);
360-
361-
return Vocab(std::move(tokens));
362-
}
363-
364139
VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab>& self) {
365140
std::vector<int64_t> integers;
366141
StringList strings = self->itos_;

torchtext/csrc/vocab.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,4 @@ struct Vocab : torch::CustomClassHolder {
9191
TORCHTEXT_API VocabStates
9292
_serialize_vocab(const c10::intrusive_ptr<Vocab>& self);
9393
TORCHTEXT_API c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states);
94-
95-
TORCHTEXT_API Vocab _load_vocab_from_file(
96-
const std::string& file_path,
97-
const int64_t min_freq,
98-
const int64_t num_cpus);
99-
TORCHTEXT_API Vocab _build_vocab_from_text_file(
100-
const std::string& file_path,
101-
const int64_t min_freq,
102-
const int64_t num_cpus,
103-
torch::jit::script::Module tokenizer);
10494
} // namespace torchtext

0 commit comments

Comments (0)