@@ -381,8 +381,8 @@ std::vector<std::string> GPT2BPEEncoder::PreTokenize_(std::string input) {
 std::vector<int64_t> GPT2BPEEncoder::Encode(const std::string& text) {
   std::vector<int64_t> bpe_token_ids;
   for (const auto& token : PreTokenize_(text)) {
-    if (added_tokens_encoder.contains(token)) {
-      bpe_token_ids.push_back(added_tokens_encoder.at(token));
+    if (added_tokens_encoder_.contains(token)) {
+      bpe_token_ids.push_back(added_tokens_encoder_.at(token));
       continue;
     }
     bool is_never_split_token =
@@ -397,18 +397,45 @@ std::vector<int64_t> GPT2BPEEncoder::Encode(const std::string& text) {
 
 std::string GPT2BPEEncoder::Decode(const std::vector<int64_t>& tokens) {
   std::string text;
+  std::vector<bool> special_token_flags(tokens.size());
   // setup converter for converting wide chars to/from chars
   using convert_type = std::codecvt_utf8<wchar_t>;
   std::wstring_convert<convert_type, wchar_t> converter;
 
-  for (const auto token : tokens) {
-    // get unicode string for given integer key
-    const std::string str = bpe_decoder_.at(token);
-    const std::wstring ws = converter.from_bytes(str);
-    for (wchar_t wchr : ws) {
-      // get output character from byte decoder for each wide character
-      unsigned char uchr = byte_decoder_.at(converter.to_bytes(wchr));
-      text.push_back(uchr);
+  for (int tok_idx = 0; tok_idx < tokens.size(); tok_idx++) {
+    const auto token = tokens[tok_idx];
+    std::string decoded_token;
+
+    if (added_tokens_decoder_.contains(token)) {
+      // string is a special token from extended vocab
+      decoded_token = added_tokens_decoder_.at(token);
+      special_token_flags[tok_idx] = true;
+    } else {
+      const std::string str = bpe_decoder_.at(token);
+      if (bpe_never_split_set_.find(str) != bpe_never_split_set_.end()) {
+        // string is a special token from known vocab
+        decoded_token = str;
+        special_token_flags[tok_idx] = true;
+      } else {
+        // string is a regular token from known vocab
+        const std::wstring ws = converter.from_bytes(str);
+        for (wchar_t wchr : ws) {
+          // get output character from byte decoder for each wide character
+          unsigned char uchr = byte_decoder_.at(converter.to_bytes(wchr));
+          decoded_token.push_back(uchr);
+        }
+      }
+    }
+
+    // fix left space(s) for special tokens
+    if (special_token_flags[tok_idx] == true &&
+        (tok_idx > 0 && special_token_flags[tok_idx - 1] == false)) {
+      text.push_back(' ');
+    }
+    text.append(decoded_token);
+    // fix right space(s) for special tokens
+    if (special_token_flags[tok_idx] == true && tok_idx != tokens.size() - 1) {
+      text.push_back(' ');
     }
   }
   return text;
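The behavioral core of the hunk above is the spacing rule at the end of the new Decode loop: a special token gets a single leading space when it directly follows a regular token, and a single trailing space unless it ends the sequence. Below is a minimal standalone sketch of just that rule; the token strings and flags are hypothetical stand-ins, not the encoder's real vocab or API.

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical pre-decoded tokens and their "special" flags.
  std::vector<std::string> decoded = {"hello", "<eos>", "world"};
  std::vector<bool> special = {false, true, false};

  std::string text;
  for (size_t i = 0; i < decoded.size(); i++) {
    // left space: special token preceded by a regular token
    if (special[i] && i > 0 && !special[i - 1])
      text.push_back(' ');
    text.append(decoded[i]);
    // right space: special token that is not the last token
    if (special[i] && i + 1 != decoded.size())
      text.push_back(' ');
  }
  std::cout << text << '\n';  // prints: hello <eos> world
}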
@@ -433,30 +460,34 @@ int64_t GPT2BPEEncoder::AddSpecialTokens(
   int64_t newly_added = 0;
 
   /* All special tokens get added to `bpe_never_split_set_` set to avoid being
-   * split during tokenization. Tokens are added to `added_tokens_encoder` only
-   * if they are not already known (i.e. present in `bpe_encoder_`).
+   * split during tokenization. Tokens are added to `added_tokens_encoder_` only
+   * if they are not already known (i.e. not already present in `bpe_encoder_`).
    */
 
   // Loop for standard tokens such as "bos_token", "eos_token", etc.
   for (auto const& token : standard_special_tokens_dict) {
-    if (added_tokens_encoder.contains(token.value()))
+    if (added_tokens_encoder_.contains(token.value()))
       continue;
     bpe_never_split_set_.insert(token.value());
     if (!bpe_encoder_.contains(token.value())) {
-      added_tokens_encoder.insert(
-          token.value(), bpe_encoder_.size() + added_tokens_encoder.size());
+      added_tokens_encoder_.insert(
+          token.value(), bpe_encoder_.size() + added_tokens_encoder_.size());
+      added_tokens_decoder_.insert(
+          bpe_decoder_.size() + added_tokens_decoder_.size(), token.value());
       newly_added++;
     }
   }
 
   // Loop for any additional tokens
   for (auto const& token : additional_special_tokens) {
-    if (added_tokens_encoder.contains(token))
+    if (added_tokens_encoder_.contains(token))
       continue;
     bpe_never_split_set_.insert(token);
     if (!bpe_encoder_.contains(token)) {
-      added_tokens_encoder.insert(
-          token, bpe_encoder_.size() + added_tokens_encoder.size());
+      added_tokens_encoder_.insert(
+          token, bpe_encoder_.size() + added_tokens_encoder_.size());
+      added_tokens_decoder_.insert(
+          bpe_decoder_.size() + added_tokens_decoder_.size(), token);
       newly_added++;
     }
   }
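One detail worth calling out in the hunk above: new ids are assigned as `bpe_encoder_.size() + added_tokens_encoder_.size()` on the encoder side and `bpe_decoder_.size() + added_tokens_decoder_.size()` on the decoder side; assuming the base encoder and decoder are inverse maps of equal size, the two extension maps stay mirror images of each other. Here is a minimal sketch of that bookkeeping using plain std::unordered_map stand-ins (the real member types may differ):

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  // Tiny stand-in for bpe_encoder_ (the real GPT-2 vocab has ~50k entries).
  std::unordered_map<std::string, int64_t> bpe_encoder = {
      {"hello", 0}, {"world", 1}};

  // Stand-ins for added_tokens_encoder_ / added_tokens_decoder_.
  std::unordered_map<std::string, int64_t> added_enc;  // token -> id
  std::unordered_map<int64_t, std::string> added_dec;  // id -> token

  for (const std::string& token : {"<bos>", "<eos>", "hello"}) {
    if (added_enc.count(token))
      continue;  // already added earlier
    if (bpe_encoder.count(token))
      continue;  // already in the base vocab, e.g. "hello"
    // New ids continue right after the base vocab, so the encoder and
    // decoder extensions assign matching ids.
    int64_t id = static_cast<int64_t>(bpe_encoder.size() + added_enc.size());
    added_enc.emplace(token, id);
    added_dec.emplace(id, token);
  }

  std::cout << added_enc["<bos>"] << ' ' << added_enc["<eos>"] << '\n';
  // prints: 2 3
}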