@@ -66,6 +66,7 @@ enum e_model {
6666 MODEL_65B,
6767};
6868
// size units used for buffer-size arithmetic and log output
static const size_t kB = 1024;
static const size_t MB = 1024*kB; // defined in terms of kB so the two units cannot drift apart
7071
7172// computed for n_ctx == 2048
@@ -129,6 +130,34 @@ static const std::map<e_model, size_t> & MEM_REQ_EVAL()
129130 return k_sizes;
130131}
131132
133+ // amount of VRAM needed per batch size to hold temporary results
134+ // the values for 3b and 65b are not derived from testing but instead chosen conservatively
135+ static const std::map<e_model, size_t > & VRAM_REQ_SCRATCH_BASE ()
136+ {
137+ static std::map<e_model, size_t > k_sizes = {
138+ { MODEL_3B, 512ull * kB },
139+ { MODEL_7B, 512ull * kB },
140+ { MODEL_13B, 640ull * kB },
141+ { MODEL_30B, 768ull * kB },
142+ { MODEL_65B, 1536ull * kB },
143+ };
144+ return k_sizes;
145+ }
146+
147+ // amount of VRAM needed per batch size and context to hold temporary results
148+ // the values for 3b and 65b are not derived from testing but instead chosen conservatively
149+ static const std::map<e_model, size_t > & VRAM_REQ_SCRATCH_PER_CONTEXT ()
150+ {
151+ static std::map<e_model, size_t > k_sizes = {
152+ { MODEL_3B, 128ull },
153+ { MODEL_7B, 128ull },
154+ { MODEL_13B, 160ull },
155+ { MODEL_30B, 208ull },
156+ { MODEL_65B, 416ull },
157+ };
158+ return k_sizes;
159+ }
160+
132161// default hparams (LLaMA 7B)
133162struct llama_hparams {
134163 uint32_t n_vocab = 32000 ;
@@ -1118,11 +1147,14 @@ static void llama_model_load_internal(
11181147 fprintf (stderr, " %s: not allocating a VRAM scratch buffer due to low VRAM option\n " , __func__);
11191148 ggml_cuda_set_scratch_size (0 ); // disable scratch
11201149 } else {
1121- vram_scratch = n_batch * MB;
1150+ const size_t vram_scratch_base = VRAM_REQ_SCRATCH_BASE ().at (model.type );
1151+ const size_t vram_scratch_per_context = VRAM_REQ_SCRATCH_PER_CONTEXT ().at (model.type );
1152+ vram_scratch = n_batch * (vram_scratch_base + n_ctx * vram_scratch_per_context);
11221153 ggml_cuda_set_scratch_size (vram_scratch);
11231154 if (n_gpu_layers > 0 ) {
1124- fprintf (stderr, " %s: allocating batch_size x 1 MB = %zd MB VRAM for the scratch buffer\n " ,
1125- __func__, vram_scratch / MB);
1155+ fprintf (stderr, " %s: allocating batch_size x (%zd kB + n_ctx x %zd B) = %zd MB VRAM for the scratch buffer\n " ,
1156+ __func__, vram_scratch_base / kB , vram_scratch_per_context,
1157+ (vram_scratch + MB - 1 ) / MB); // round up
11261158 }
11271159 }
11281160#endif // GGML_USE_CUBLAS
0 commit comments