@@ -128,13 +128,12 @@ static const std::map<e_model, size_t> & MEM_REQ_EVAL()
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     uint32_t n_vocab = 32000;
-    uint32_t n_vocab_sp = 0;
+    uint32_t n_vocab_base = 32000;
     uint32_t n_ctx   = 512;   // this is provided as user input?
     uint32_t n_embd  = 4096;
     uint32_t n_mult  = 256;
     uint32_t n_head  = 32;
     uint32_t n_layer = 32;
-    uint32_t n_rot   = 64;
     enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;

     bool operator!=(const llama_hparams & other) const {
@@ -460,7 +459,6 @@ enum llama_file_version {
     LLAMA_FILE_VERSION_GGJT_V1, // added padding
     LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format
     LLAMA_FILE_VERSION_GGJT_V3, // changed Q4 and Q8 quantization format
-    LLAMA_FILE_VERSION_GGJT_V4, // improved support for added/special tokens
 };

 struct llama_file_loader {
@@ -476,6 +474,7 @@ struct llama_file_loader {
         read_hparams();
         read_vocab();
         read_tensor_metadata(file_idx, tensors_map);
+        set_vocab_sp();
     }
     void read_magic() {
         uint32_t magic = file.read_u32();
@@ -498,7 +497,6 @@ struct llama_file_loader {
             case 1: file_version = LLAMA_FILE_VERSION_GGJT_V1; return;
             case 2: file_version = LLAMA_FILE_VERSION_GGJT_V2; return;
             case 3: file_version = LLAMA_FILE_VERSION_GGJT_V3; return;
-            case 4: file_version = LLAMA_FILE_VERSION_GGJT_V4; return;
         }
     }

@@ -507,12 +505,12 @@ struct llama_file_loader {
     }
     void read_hparams() {
         hparams.n_vocab = file.read_u32();
-        hparams.n_vocab_sp = file_version >= LLAMA_FILE_VERSION_GGJT_V4 ? file.read_u32() : 0;
         hparams.n_embd = file.read_u32();
         hparams.n_mult = file.read_u32();
         hparams.n_head = file.read_u32();
         hparams.n_layer = file.read_u32();
-        hparams.n_rot = file.read_u32();
+        hparams.n_vocab_base = file.read_u32();
+        hparams.n_vocab_base = (hparams.n_vocab_base & 0xF0000000) == 0 ? hparams.n_vocab : (hparams.n_vocab_base & ~0xF0000000); // this bitwise operation is necessary for compatibility with older models
         hparams.ftype = (enum llama_ftype) file.read_u32();
     }
     void read_vocab() {
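The reader-side check works because the repurposed slot is the one that previously held n_rot, a small value (64 or 128 in practice) whose top four bits are never set. As an illustration with assumed sizes: a new-format file carrying n_vocab_base = 32000 stores 0xF0007D00 in that slot, and masking off the marker bits recovers 32000; an old-format file whose slot still holds n_rot = 128 has the marker bits clear, so n_vocab_base falls back to n_vocab and no added tokens are assumed.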
@@ -533,20 +531,6 @@ struct llama_file_loader {
             tok_score.tok = std::move(word);
             tok_score.score = score;
         }
-
-        vocab.special_token_to_id.reserve(hparams.n_vocab_sp);
-
-        for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
-            llama_vocab::id token_id = file.read_u32();
-            const auto & word = vocab.id_to_token[token_id].tok;
-
-            vocab.special_token_trie.add(word);
-            vocab.special_token_to_id[word] = token_id;
-
-            if (vocab.max_special_token_length < word.size()) {
-                vocab.max_special_token_length = word.size();
-            }
-        }
     }
     void read_tensor_metadata(size_t file_idx, llama_load_tensors_map & tensors_map) {
         while (file.tell() < file.size) {
@@ -601,6 +585,24 @@ struct llama_file_loader {
                 tensors_map.tensors.at(idx).shards.push_back(shard);
             }
         }
+    void set_vocab_sp() {
+        uint32_t vocab_sp = 3 + hparams.n_vocab - hparams.n_vocab_base;
+        vocab.special_token_to_id.reserve(vocab_sp);
+        for (uint32_t i = 0; i < vocab_sp; i++) {
+            llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i : i;
+            const auto & word = vocab.id_to_token[token_id].tok;
+            if (word.empty()) {
+                continue;
+            }
+
+            vocab.special_token_trie.add(word);
+            vocab.special_token_to_id[word] = token_id;
+
+            if (vocab.max_special_token_length < word.size()) {
+                vocab.max_special_token_length = word.size();
+            }
+        }
+    }
 };

 struct llama_file_saver {
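The candidate count in set_vocab_sp() is 3 + n_vocab - n_vocab_base, presumably the three control tokens at the start of the base vocabulary plus one slot per token added beyond n_vocab_base, with empty entries skipped. As a worked example with assumed sizes, n_vocab_base = 32000 and n_vocab = 32008 give vocab_sp = 3 + 32008 - 32000 = 11 candidate special tokens.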
@@ -620,12 +622,11 @@ struct llama_file_saver {
     void write_hparams(enum llama_ftype new_ftype) {
         const llama_hparams & hparams = any_file_loader->hparams;
         file.write_u32(hparams.n_vocab);
-        file.write_u32(hparams.n_vocab_sp);
         file.write_u32(hparams.n_embd);
         file.write_u32(hparams.n_mult);
         file.write_u32(hparams.n_head);
         file.write_u32(hparams.n_layer);
-        file.write_u32(hparams.n_rot);
+        file.write_u32(hparams.n_vocab_base | 0xF0000000); // this bitwise operation is necessary for compatibility with older models
         file.write_u32(new_ftype);
     }
     void write_vocab() {
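Taken together, write_hparams() and read_hparams() keep the seven-field GGJT v3 header layout and carry n_vocab_base through the slot that used to hold n_rot, tagged with the top four bits. Below is a minimal standalone sketch of that round trip; encode_vocab_base and decode_vocab_base are hypothetical helper names for illustration, not functions from this patch.

#include <cstdint>
#include <cstdio>

// hypothetical helpers mirroring the write_hparams()/read_hparams() logic above
static uint32_t encode_vocab_base(uint32_t n_vocab_base) {
    return n_vocab_base | 0xF0000000;  // tag the value so a reader can tell it apart from an old n_rot
}

static uint32_t decode_vocab_base(uint32_t raw, uint32_t n_vocab) {
    if ((raw & 0xF0000000) == 0) {
        return n_vocab;                // old file: the slot held n_rot, assume no added tokens
    }
    return raw & ~0xF0000000;          // new file: strip the marker bits
}

int main() {
    // new-style file: 32000-token base vocab, 8 added tokens (example values)
    uint32_t raw = encode_vocab_base(32000);         // 0xF0007D00
    printf("%u\n", decode_vocab_base(raw, 32008));   // prints 32000

    // old-style file: the slot still holds n_rot (e.g. 128), marker bits are clear
    printf("%u\n", decode_vocab_base(128, 32000));   // prints 32000 (falls back to n_vocab)
    return 0;
}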
@@ -639,9 +640,6 @@ struct llama_file_saver {
             file.write_raw(token_score.tok.data(), token_score.tok.size());
             file.write_raw(&token_score.score, sizeof(token_score.score));
         }
-        for (const auto & pair : any_file_loader->vocab.special_token_to_id) {
-            file.write_u32(pair.second);
-        }
     }
     void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) {
         switch (new_type) {
@@ -1015,8 +1013,7 @@ static const char *llama_file_version_name(llama_file_version version) {
         case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)";
         case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)";
         case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (pre #1508)";
-        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (pre #1931)";
-        case LLAMA_FILE_VERSION_GGJT_V4: return "ggjt v4 (latest)";
+        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (latest)";
     }

     return "unknown";
@@ -1113,7 +1110,7 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
         fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
         fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
-        fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
+        fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_embd/hparams.n_head);
         fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
         fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
         fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
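Dropping the stored n_rot loses no information: it is fully determined by the other hyperparameters as n_rot = n_embd / n_head, the per-head dimension used for the rotary embedding. With the 7B defaults above, that is 4096 / 32 = 128.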