@@ -49,10 +49,10 @@ typedef struct {
     // float* freq_cis_real; // (seq_len, dim/2)
     // float* freq_cis_imag; // (seq_len, dim/2)
     // (optional) classifier weights for the logits, on the last layer
-    // float* wcls;
+    float* wcls;
 } TransformerWeights;
 
-void malloc_weights(TransformerWeights* w, Config* p) {
+void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) {
     // we calloc instead of malloc to keep valgrind happy
     w->token_embedding_table = new float[p->vocab_size * p->dim]();
     printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size, p->dim, p->vocab_size * p->dim);
@@ -86,9 +86,16 @@ void malloc_weights(TransformerWeights* w, Config* p) {
 
     w->rms_final_weight = new float[p->dim]();
     printf("[%s:AK] Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim);
+
+    if (shared_weights) {
+        w->wcls = NULL;
+    } else {
+        w->wcls = new float[p->vocab_size * p->dim]();
+        printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size, p->dim, p->vocab_size * p->dim);
+    }
 }
 
-int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) {
+int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shared_weights) {
     if (fread(w->token_embedding_table, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1;
     if (fread(w->rms_att_weight, sizeof(float), p->n_layers * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim)) return 1;
     if (fread(w->wq, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1;
@@ -100,6 +107,22 @@ int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) {
     if (fread(w->w2, sizeof(float), p->n_layers * p->hidden_dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->hidden_dim * p->dim)) return 1;
     if (fread(w->w3, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->hidden_dim)) return 1;
     if (fread(w->rms_final_weight, sizeof(float), p->dim, f) != static_cast<size_t>(p->dim)) return 1;
+
+    // Skip freq_cis_real & freq_cis_imag
+    int head_size = p->dim / p->n_heads;
+    fseek(f, p->seq_len * head_size * sizeof(float), SEEK_CUR);
+
+    if (!shared_weights && fread(w->wcls, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1;
+
+    // Check we didn't forget to read anything
+    auto curr = ftell(f);
+    fseek(f, 0, SEEK_END);
+    auto end = ftell(f);
+    if (curr != end) {
+        printf("Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", curr, end);
+        return 1;
+    }
+
     return 0;
 }
 
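The single `fseek` added above skips both RoPE tables in one step: in a llama2.c checkpoint, `freq_cis_real` and `freq_cis_imag` together occupy `seq_len * head_size` floats, which is exactly the distance seeked over before the optional `wcls` read. A small stand-alone sketch of that arithmetic (the config values below are illustrative, not taken from a real model):

```cpp
#include <cstdio>

int main() {
    // Illustrative config values only (roughly a small llama2.c model); not read from any file.
    int dim = 288, n_heads = 6, seq_len = 256;
    int head_size = dim / n_heads;
    // freq_cis_real and freq_cis_imag each hold seq_len * head_size / 2 floats,
    // so together they span seq_len * head_size floats -- the amount the fseek above skips.
    size_t rope_bytes = (size_t)seq_len * head_size * sizeof(float);
    printf("RoPE tables to skip: %zu bytes\n", rope_bytes);
    return 0;
}
```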
@@ -115,6 +138,7 @@ void free_weights(TransformerWeights* w) {
     delete w->w2;
     delete w->w3;
     delete w->rms_final_weight;
+    if (w->wcls) delete w->wcls;
 }
 
 void print_sample_weights(TransformerWeights *w){
@@ -131,6 +155,7 @@ void print_sample_weights(TransformerWeights *w){
     printf("%f\n", w->w2[0]);
     printf("%f\n", w->w3[0]);
     printf("%f\n", w->rms_att_weight[0]);
+    if (w->wcls) printf("%f\n", w->wcls[0]);
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -617,7 +642,7 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod
 //    // w->token_embedding_table -> model->tok_embeddings
 //    // float*                   -> struct ggml_tensor
 //    stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table);
-//    stuff_karpathy_weights_into_gg(model->output, w->token_embedding_table);
+//    stuff_karpathy_weights_into_gg(model->output, w->wcls ? w->wcls : w->token_embedding_table);
 //
 //    stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight);
 //    //print_row(model->norm, 0);
@@ -791,9 +816,12 @@ int main(int argc, char ** argv) {
         if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; }
         // read in the config header
         if (fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
+        auto shared_weights = config.vocab_size > 0;
+        config.vocab_size = abs(config.vocab_size);
+
         // read in the Transformer weights
-        malloc_weights(&weights, &config);
-        if (checkpoint_init_weights(&weights, &config, file)) { return 1; }
+        malloc_weights(&weights, &config, shared_weights);
+        if (checkpoint_init_weights(&weights, &config, file, shared_weights)) { return 1; }
         fclose(file);
     }
 
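For context, the convention the new `main()` code relies on comes from llama2.c's model export: a positive `vocab_size` in the checkpoint header means the classifier shares the token embedding table, while a negative value means a separate `wcls` tensor is appended after the RoPE tables. A minimal, self-contained sketch of reading the header and deciding this (the `Config` layout mirrors llama2.c; `model.bin` is just an example path):

```cpp
#include <cstdio>
#include <cstdlib>

// Header written by llama2.c's export: seven 32-bit ints, in this order.
struct Config {
    int dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len;
};

int main() {
    FILE * file = fopen("model.bin", "rb");              // example checkpoint path
    if (!file) { printf("Unable to open model.bin!\n"); return 1; }
    Config config;
    if (fread(&config, sizeof(Config), 1, file) != 1) { fclose(file); return 1; }
    // Sign of vocab_size encodes whether the output classifier is shared:
    //   > 0 : wcls shares the token embedding table (no extra tensor in the file)
    //   < 0 : a separate wcls tensor follows the RoPE tables
    bool shared_weights = config.vocab_size > 0;
    config.vocab_size = abs(config.vocab_size);
    printf("vocab_size = %d, shared_weights = %s\n",
           config.vocab_size, shared_weights ? "true" : "false");
    fclose(file);
    return 0;
}
```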