@@ -136,231 +136,6 @@ StringList Vocab::get_itos() const {
136136 return itos_;
137137}
138138
139- int64_t _infer_lines (const std::string& file_path) {
140- int64_t num_lines = 0 ;
141- std::ifstream fin;
142- fin.open (file_path, std::ios::in);
143- TORCH_CHECK (fin.is_open (), " Cannot open input file " + file_path);
144-
145- while (fin.ignore (std::numeric_limits<std::streamsize>::max (), ' \n ' )) {
146- num_lines++;
147- }
148- return num_lines;
149- }
150-
151- void parse_vocab_file_chunk (
152- const std::string& file_path,
153- size_t offset,
154- const int64_t start_line,
155- const int64_t end_line,
156- const std::shared_ptr<IndexDict>& counter) {
157- std::ifstream fin (file_path, std::ios::in);
158- TORCH_CHECK (fin.is_open (), " Cannot open input file " + file_path);
159-
160- fin.seekg (offset);
161-
162- for (int64_t i = start_line; i < end_line; i++) {
163- std::string token;
164- fin >> token;
165- fin >> std::ws;
166-
167- if ((*counter).find (token) == (*counter).end ()) {
168- (*counter)[token] = 1 ;
169- } else {
170- (*counter)[token] += 1 ;
171- }
172- }
173- }
174-
175- void parse_raw_text_file_chunk (
176- const std::string& file_path,
177- size_t offset,
178- const int64_t start_line,
179- const int64_t end_line,
180- const std::shared_ptr<IndexDict>& counter,
181- torch::jit::script::Module& module ) {
182- std::ifstream fin (file_path, std::ios::in);
183- TORCH_CHECK (fin.is_open (), " Cannot open input file " + file_path);
184-
185- fin.seekg (offset);
186-
187- std::string line;
188- for (int64_t i = start_line; i < end_line; i++) {
189- std::getline (fin, line);
190- auto token_list =
191- module .forward (std::vector<c10::IValue>({c10::IValue (line)})).toList ();
192-
193- for (size_t i = 0 ; i < token_list.size (); i++) {
194- c10::IValue token_ref = token_list.get (i);
195- std::string token = token_ref.toStringRef ();
196-
197- if ((*counter).find (token) == (*counter).end ()) {
198- (*counter)[token] = 1 ;
199- } else {
200- (*counter)[token] += 1 ;
201- }
202- }
203- }
204- }
205-
206- StringList _concat_tokens (
207- std::vector<std::shared_ptr<IndexDict>> chunk_counters,
208- const int64_t min_freq,
209- const int64_t num_lines,
210- const bool sort_tokens) {
211- TORCH_CHECK (
212- chunk_counters.size () > 0 ,
213- " There must be at least 1 chunk to concatenate!" );
214-
215- IndexDict tokens_freq;
216- StringList unique_tokens;
217- unique_tokens.reserve (num_lines);
218-
219- // concatenate all counters
220- for (size_t i = 0 ; i < chunk_counters.size (); i++) {
221- auto & cur_counter = *chunk_counters[i];
222- for (const auto & item : cur_counter) {
223- int64_t cur_token_freq = item.second ;
224- if (tokens_freq.find (item.first ) != tokens_freq.end ()) {
225- tokens_freq[item.first ] += cur_token_freq;
226- } else {
227- tokens_freq[item.first ] = cur_token_freq;
228- }
229-
230- // add to tokens list only if all of the conditions are met:
231- // 1. token is not empty
232- // 2. we exceed min_freq for the first time
233- if (item.first .length () &&
234- tokens_freq[item.first ] - cur_token_freq < min_freq &&
235- tokens_freq[item.first ] >= min_freq) {
236- unique_tokens.push_back (item.first );
237- }
238- }
239- }
240-
241- // create token freq pairs
242- std::vector<std::pair<std::string, int64_t >> token_freq_pairs;
243-
244- for (std::string& token : unique_tokens) {
245- auto token_freq = tokens_freq[token];
246- token_freq_pairs.emplace_back (std::move (token), token_freq);
247- }
248- unique_tokens.clear ();
249-
250- // sort tokens by freq
251- if (sort_tokens) {
252- CompareTokens compare_tokens;
253- std::sort (token_freq_pairs.begin (), token_freq_pairs.end (), compare_tokens);
254- }
255-
256- // update unique tokens with correct order
257- for (auto & token_freq_pair : token_freq_pairs) {
258- unique_tokens.emplace_back (std::move (token_freq_pair.first ));
259- }
260-
261- return unique_tokens;
262- }
263-
// Minimum number of lines per worker task; launching a task on fewer
// lines than this likely costs more in overhead than it saves.
constexpr int64_t GRAIN_SIZE = 13107;

// Builds a Vocab from a file containing one token per line.
//
// The file is split into line-chunks of at least GRAIN_SIZE lines, one
// at::launch task per chunk tallies frequencies into its own private
// counter (no shared mutable state during parsing), and the counters are
// merged by _concat_tokens once all tasks finish.
//
// file_path: path to the vocab file (one token per line).
// min_freq:  minimum total frequency for a token to be kept.
// num_cpus:  target number of parallel chunks (before GRAIN_SIZE clamp).
Vocab _load_vocab_from_file(
    const std::string& file_path,
    const int64_t min_freq,
    const int64_t num_cpus) {
  int64_t num_lines = _infer_lines(file_path);
  int64_t chunk_size = impl::divup(num_lines, num_cpus);
  // Launching a thread on less lines than this likely has too much overhead.
  // TODO: Add explicit test beyond grain size to cover multithreading
  chunk_size = std::max(chunk_size, GRAIN_SIZE);

  // offsets[j] is the byte position of chunk j's first line.
  std::vector<size_t> offsets;
  impl::infer_offsets(file_path, num_lines, chunk_size, offsets);

  std::vector<std::shared_ptr<IndexDict>> chunk_counters;

  // thread_count tracks in-flight tasks; the mutex/condvar pair lets the
  // main thread block until it drops back to zero.
  std::mutex m;
  std::condition_variable cv;
  std::atomic<int> thread_count(0);

  // create threads
  int64_t j = 0;
  for (int64_t i = 0; i < num_lines; i += chunk_size) {
    auto counter_ptr = std::make_shared<IndexDict>();

    thread_count++;
    // NOTE(review): offsets, m, cv and thread_count are captured by
    // reference; they outlive the tasks only because this function blocks
    // on cv.wait below before returning.
    at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
      parse_vocab_file_chunk(
          file_path,
          offsets[j],
          i,
          std::min(num_lines, i + chunk_size),
          counter_ptr);
      // Decrement under the lock so the waiting thread cannot miss the
      // notification between its predicate check and its sleep.
      std::lock_guard<std::mutex> lk(m);
      thread_count--;
      cv.notify_all();
    });
    chunk_counters.push_back(counter_ptr);
    j++;
  }

  // block until all threads finish execution
  std::unique_lock<std::mutex> lock(m);
  cv.wait(lock, [&thread_count] { return thread_count == 0; });

  // sort_tokens=false: tokens keep the order in which their cumulative
  // count first reached min_freq (i.e. essentially file order).
  StringList tokens =
      _concat_tokens(chunk_counters, min_freq, num_lines, false);

  return Vocab(std::move(tokens));
}
314-
// Builds a Vocab from a raw text file, tokenizing each line with the
// given TorchScript `tokenizer` module.
//
// Same chunked fan-out/fan-in scheme as _load_vocab_from_file: one
// at::launch task per line-chunk with a private counter, then a merge.
//
// file_path: path to the raw text file.
// min_freq:  minimum total frequency for a token to be kept.
// num_cpus:  target number of parallel chunks (before GRAIN_SIZE clamp).
// tokenizer: scripted module whose forward(line) returns a token list.
Vocab _build_vocab_from_text_file(
    const std::string& file_path,
    const int64_t min_freq,
    const int64_t num_cpus,
    torch::jit::script::Module tokenizer) {
  int64_t num_lines = _infer_lines(file_path);
  int64_t chunk_size = impl::divup(num_lines, num_cpus);
  // Launching a thread on less lines than this likely has too much overhead.
  chunk_size = std::max(chunk_size, GRAIN_SIZE);

  // offsets[j] is the byte position of chunk j's first line.
  std::vector<size_t> offsets;
  impl::infer_offsets(file_path, num_lines, chunk_size, offsets);

  std::vector<std::shared_ptr<IndexDict>> chunk_counters;

  // thread_count tracks in-flight tasks; the mutex/condvar pair lets the
  // main thread block until it drops back to zero.
  std::mutex m;
  std::condition_variable cv;
  std::atomic<int> thread_count(0);

  // create threads
  int64_t j = 0;
  for (int64_t i = 0; i < num_lines; i += chunk_size) {
    auto counter_ptr = std::make_shared<IndexDict>();
    thread_count++;
    // NOTE(review): `tokenizer` (a by-value parameter) is captured by
    // reference and invoked from multiple tasks concurrently — this
    // assumes Module::forward is safe to call from several threads at
    // once; confirm against the torch::jit threading guarantees. The
    // reference stays valid because this function blocks on cv.wait
    // below before returning.
    at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
      parse_raw_text_file_chunk(
          file_path,
          offsets[j],
          i,
          std::min(num_lines, i + chunk_size),
          counter_ptr,
          tokenizer);
      // Decrement under the lock so the waiter cannot miss the notify.
      std::lock_guard<std::mutex> lk(m);
      thread_count--;
      cv.notify_all();
    });
    chunk_counters.push_back(counter_ptr);
    j++;
  }

  // block until all threads finish execution
  std::unique_lock<std::mutex> lock(m);
  cv.wait(lock, [&thread_count] { return thread_count == 0; });

  // sort_tokens=true: final vocab is ordered by token frequency.
  StringList tokens = _concat_tokens(chunk_counters, min_freq, num_lines, true);

  return Vocab(std::move(tokens));
}
363-
364139VocabStates _serialize_vocab (const c10::intrusive_ptr<Vocab>& self) {
365140 std::vector<int64_t > integers;
366141 StringList strings = self->itos_ ;
0 commit comments