6868#include <cstdio>
6969#include <cstring>
7070#include <ctime>
71+ #include <cwctype>
7172#include <forward_list>
7273#include <fstream>
7374#include <functional>
7475#include <initializer_list>
76+ #include <locale>
7577#include <map>
7678#include <memory>
7779#include <mutex>
@@ -8941,37 +8943,46 @@ struct llm_tokenizer_wpm {
89418943    }
89428944
89438945    std::vector<std::string> preprocess(const std::string & text) {
8944-         std::string ori_str = normalize(text);
8945-         uint64_t ori_size = ori_str.size();
8946+         // normalalization form D
8947+         std::vector<uint32_t> codepoints = codepoints_from_utf8(text);
8948+         std::vector<uint32_t> nfd_codepoints;
8949+         for (uint32_t code : codepoints) {
8950+             auto it = nfd_map.find(code);
8951+             if (it != nfd_map.end()) {
8952+                 for (uint32_t c : it->second) {
8953+                     nfd_codepoints.push_back(c);
8954+                 }
8955+             } else {
8956+                 nfd_codepoints.push_back(code);
8957+             }
8958+         }
89468959
8947-         // single punct / single symbol / single digit
8948-         // baseline: add whitespace on the left and right of punct and chinese characters
8949-         std::vector<std::string> words;
8960+         // strip accents, strip control, uniformize whitespace,
8961+         // to lowercase, pad chinese characters, pad punctuation
89508962        std::string new_str = "";
8951-         uint64_t i = 0;
8952-         while (i < ori_size) {
8953-             int utf_char_len = utf8_len(ori_str[i]);
8954-             if ((utf_char_len == 1) && ispunct(ori_str[i])) {
8955-                 new_str += " ";
8956-                 new_str += ori_str[i];
8957-                 new_str += " ";
8958-                 i += 1;
8963+         for (uint32_t code : nfd_codepoints) {
8964+             int type = codepoint_type(code);
8965+             if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
8966+                 continue;
89598967            }
8960-             else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) {
8968+             code = to_lower(code);
8969+             if (type == CODEPOINT_TYPE_WHITESPACE) {
8970+                 code = ' ';
8971+             }
8972+             std::string s = codepoint_to_utf8(code);
8973+             if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
89618974                new_str += " ";
8962-                 new_str += ori_str.substr(i, 3) ;
8975+                 new_str += s ;
89638976                new_str += " ";
8964-                 i += 3;
8965-             }
8966-             else {
8967-                 new_str += ori_str[i];
8968-                 i += 1;
8977+             } else {
8978+                 new_str += s;
89698979            }
89708980        }
89718981
89728982        // split by whitespace
89738983        uint64_t l = 0;
89748984        uint64_t r = 0;
8985+         std::vector<std::string> words;
89758986        while (r < new_str.size()) {
89768987            // if is whitespace
89778988            if (isspace(new_str[r])) {
@@ -8989,47 +9000,20 @@ struct llm_tokenizer_wpm {
89899000        return words;
89909001    }
89919002
8992-     std::string normalize(const std::string & text) {
8993-         // TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98
8994-         std::string text2 = strip_accents(text);
8995-         for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) {
8996-             char c = text2[i];
8997-             if (c >= 'A' && c <= 'Z') {
8998-                 text2[i] = c - 'A' + 'a';
8999-             }
9003+     uint32_t to_lower(uint32_t code) {
9004+ #if defined(_WIN32)
9005+         if (code > 0xFFFF) {
9006+             return code;
90009007        }
9001-         return text2;
9008+ #endif
9009+         return std::tolower(wchar_t(code), std::locale("en_US.UTF-8"));
90029010    }
90039011
9004-     bool is_chinese_char(const std::string & str) {
9005-         int len = str.length();
9006-         unsigned int codepoint = 0;
9007-         int num_bytes = 0;
9008-         int i = 0;
9009-         unsigned char ch = static_cast<unsigned char>(str[i]);
9010-         if (ch <= 0x7f) {
9011-             codepoint = ch;
9012-             num_bytes = 1;
9013-         } else if ((ch >> 5) == 0x06) {
9014-             codepoint = ch & 0x1f;
9015-             num_bytes = 2;
9016-         } else if ((ch >> 4) == 0x0e) {
9017-             codepoint = ch & 0x0f;
9018-             num_bytes = 3;
9019-         } else if ((ch >> 3) == 0x1e) {
9020-             codepoint = ch & 0x07;
9021-             num_bytes = 4;
9022-         }
9023-         for (int j = 1; j < num_bytes; ++j) {
9024-             if (i + j >= len) {
9025-                 return false; // incomplete UTF-8 character
9026-             }
9027-             unsigned char next_ch = static_cast<unsigned char>(str[i + j]);
9028-             if ((next_ch >> 6) != 0x02) {
9029-                 return false; // invalid trailing byte
9030-             }
9031-             codepoint = (codepoint << 6) | (next_ch & 0x3f);
9032-         }
9012+     bool is_ascii_punct(uint32_t code) {
9013+         return code < 256 && ispunct(code);
9014+     }
9015+ 
9016+     bool is_chinese_char(uint32_t codepoint) {
90339017        if ((codepoint >= 0x4E00  && codepoint <= 0x9FFF)  ||
90349018            (codepoint >= 0x3400  && codepoint <= 0x4DBF)  ||
90359019            (codepoint >= 0x20000 && codepoint <= 0x2A6DF) ||
@@ -9045,41 +9029,6 @@ struct llm_tokenizer_wpm {
90459029        return false;
90469030    }
90479031
9048-     std::string strip_accents(const std::string & input_string) {
9049-         std::string resultString;
9050-         std::map<std::string, char> accent_map = {
9051-             {"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'},
9052-             {"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'},
9053-             {"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'},
9054-             {"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'},
9055-             {"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'},
9056-             {"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'},
9057-             {"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'},
9058-             {"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'},
9059-             {"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'},
9060-         };
9061- 
9062-         for (size_t i = 0; i <  input_string.length();) {
9063-             int len = utf8_len(input_string[i]);
9064-             std::string curChar = input_string.substr(i, len);
9065-             auto iter = accent_map.find(curChar);
9066-             if (iter != accent_map.end()) {
9067-                 resultString += iter->second;
9068-             } else {
9069-                 resultString += curChar;
9070-             }
9071-             i += len;
9072-         }
9073- 
9074-         return resultString;
9075-     }
9076- 
9077-     static size_t utf8_len(char src) {
9078-         const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
9079-         uint8_t highbits = static_cast<uint8_t>(src) >> 4;
9080-         return lookup[highbits];
9081-     }
9082- 
90839032    const llama_vocab & vocab;
90849033};
90859034
0 commit comments