@@ -2423,20 +2423,53 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
24232423 scoped_spin_lock lock (g_cuda_pool_lock);
24242424 int id;
24252425 CUDA_CHECK (cudaGetDevice (&id));
2426-
2426+ #ifdef DEBUG_CUDA_MALLOC
2427+ int nnz = 0 ;
2428+ size_t max_size = 0 , tot_size = 0 ;
2429+ #endif
2430+ size_t best_diff = 1ull << 36 ;
2431+ int ibest = -1 ;
24272432 for (int i = 0 ; i < MAX_CUDA_BUFFERS; ++i) {
24282433 cuda_buffer& b = g_cuda_buffer_pool[id][i];
2429- if (b.size >= size && b.ptr != nullptr ) {
2430- void * ptr = b.ptr ;
2431- *actual_size = b.size ;
2432- b.ptr = nullptr ;
2433- b.size = 0 ;
2434- return ptr;
2434+ if (b.ptr != nullptr ) {
2435+ #ifdef DEBUG_CUDA_MALLOC
2436+ ++nnz;
2437+ tot_size += b.size ;
2438+ if (b.size > max_size) max_size = b.size ;
2439+ #endif
2440+ if (b.size >= size) {
2441+ size_t diff = b.size - size;
2442+ if (diff < best_diff) {
2443+ best_diff = diff;
2444+ ibest = i;
2445+ if (!best_diff) {
2446+ void * ptr = b.ptr ;
2447+ *actual_size = b.size ;
2448+ b.ptr = nullptr ;
2449+ b.size = 0 ;
2450+ return ptr;
2451+ }
2452+ }
2453+ }
24352454 }
24362455 }
2456+ if (ibest >= 0 ) {
2457+ cuda_buffer& b = g_cuda_buffer_pool[id][ibest];
2458+ void * ptr = b.ptr ;
2459+ *actual_size = b.size ;
2460+ b.ptr = nullptr ;
2461+ b.size = 0 ;
2462+ return ptr;
2463+ }
2464+ #ifdef DEBUG_CUDA_MALLOC
2465+ fprintf (stderr, " %s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n " , __func__, nnz,
2466+ (uint32_t )(max_size/1024 /1024 ), (uint32_t )(tot_size/1024 /1024 ), (uint32_t )(size/1024 /1024 ));
2467+ #endif
24372468 void * ptr;
2438- CUDA_CHECK (cudaMalloc ((void **) &ptr, size));
2439- *actual_size = size;
2469+ size_t look_ahead_size = (size_t ) (1.05 * size);
2470+ look_ahead_size = 256 * ((look_ahead_size + 255 )/256 );
2471+ CUDA_CHECK (cudaMalloc ((void **) &ptr, look_ahead_size));
2472+ *actual_size = look_ahead_size;
24402473 return ptr;
24412474}
24422475
@@ -2955,8 +2988,13 @@ inline void ggml_cuda_op_rope(
29552988 const int mode = ((int32_t *) src1->data )[2 ];
29562989 const int n_ctx = ((int32_t *) src1->data )[3 ];
29572990
2958- const float theta_scale = powf (10000.0 , -2 .0f /n_dims);
2959- const float p = ((mode & 1 ) == 0 ? n_past + i02 : i02);
2991+ // RoPE alteration for extended context
2992+ float freq_base, freq_scale;
2993+ memcpy (&freq_base, (int32_t *) src1->data + 4 , sizeof (float ));
2994+ memcpy (&freq_scale, (int32_t *) src1->data + 5 , sizeof (float ));
2995+
2996+ const float theta_scale = powf (freq_base, -2 .0f /n_dims);
2997+ const float p = (((mode & 1 ) == 0 ? n_past + i02 : i02)) * freq_scale;
29602998
29612999 bool is_glm = mode & 4 ;
29623000
0 commit comments