@@ -1875,52 +1875,36 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
18751875 cpy_1 (cx + x_offset, cdst + dst_offset);
18761876}
18771877
// Linear ramp used by the NTKv2 blend.
//
// Compares the integer-halved index i0/2 against [low, high] and returns
//   1.0f  when i0/2 <= low,
//   0.0f  when i0/2 >= high,
//   a value falling linearly from 1 to 0 in between.
//
//   low, high : ramp bounds, in the same units as i0/2
//   i0        : element index; halved with integer division before use
static __device__ float ntkv2_ramp(const float low, const float high, const int i0) {
    // The divisor must be the real span (high - low); 0.001f is only a floor that
    // protects against a zero or negative span. Using min() here would cap the span
    // at 0.001f and collapse the ramp into a hard step.
    const float span = max(0.001f, high - low);
    const float y    = (i0 / 2 - low) / span;
    return 1.0f - min(1.0f, max(0.0f, y));
}
18821882
18831883// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
18841884// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
// Blends three candidate rotation angles into the final NTKv2 angle for one element.
//
//   theta_base           : unscaled angle
//   theta_linear         : linearly scaled angle
//   theta_ntk            : NTK-scaled angle
//   corr_factors         : four ramp bounds, read as {low_1, high_1, low_2, high_2};
//                          indexed here in device code, so the pointer must be
//                          dereferenceable on the device
//   i0                   : element index forwarded to ntkv2_ramp
//   ntk_factor           : weight applied to the first ramp
//   extrapolation_factor : weight applied to the second ramp
//
// Returns the blended angle.
static __device__ float compute_ntkv2(
        float theta_base,
        float theta_linear,
        float theta_ntk,
        const float corr_factors[4],
        int64_t i0,
        float ntk_factor,
        float extrapolation_factor) {
    // Stage 1: mix the linear and NTK angles using the first pair of bounds.
    const float ntk_mix     = ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor;
    const float theta_stage = theta_linear * (1 - ntk_mix) + theta_ntk * ntk_mix;

    // Stage 2: mix that result with the unscaled angle using the second pair of bounds.
    const float extrap_mix = ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor;
    return theta_stage * (1 - extrap_mix) + theta_base * extrap_mix;
}
19191903
19201904// rope == RoPE == rotary positional embedding
1921- static __global__ void rope_f32 (const float * x, float * dst, const int ncols, const int n_dims, const float freq_base,
1905+ static __global__ void rope_f32 (const float * x, float * dst, const int ncols,
19221906 const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
1923- const float theta_ntk_scale, const float dims_over_base , const float p ) {
1907+ const float theta_ntk_scale, const float p , const float corr_factors[ 4 ] ) {
19241908
19251909 const int col = 2 *(blockDim .x *blockIdx .x + threadIdx .x );
19261910
@@ -1931,11 +1915,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
19311915 const int row = blockDim .y *blockIdx .y + threadIdx .y ;
19321916 const int i = row*ncols + col;
19331917
1934- const float theta_base = p*powf (theta_scale, col/2 );
1935- const float theta_ntk = p* powf (theta_ntk_scale, col/ 2 ) ;
1936- float theta ;
1937- compute_ntkv2 (theta_base, theta_ntk, dims_over_base ,
1938- freq_scale, col, ntk_factor, extrapolation_factor, n_dims, &theta );
1918+ const float theta_base = p*powf (theta_scale, col/2 );
1919+ const float theta_linear = freq_scale * theta_base ;
1920+ const float theta_ntk = p* powf (theta_ntk_scale, col/ 2 ) ;
1921+ const float theta = compute_ntkv2 (theta_base, theta_linear, theta_ntk, corr_factors, col, ntk_factor ,
1922+ extrapolation_factor);
19391923 const float sin_theta = sinf (theta);
19401924 const float cos_theta = cosf (theta);
19411925
@@ -2415,16 +2399,16 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
24152399}
24162400
// Host-side launcher for the rope_f32 kernel.
//
//   x, dst        : input / output buffers forwarded to the kernel
//   ncols, nrows  : problem size; the grid is (ceil(ncols / (2*CUDA_ROPE_BLOCK_SIZE)), nrows)
//                   with 2*CUDA_ROPE_BLOCK_SIZE threads per block
//   freq_scale, ntk_factor, extrapolation_factor,
//   theta_scale, theta_ntk_scale, p : scalars forwarded unchanged to the kernel
//   corr_factors  : four ramp bounds; may point to host or device memory
//   stream        : stream the copy and the kernel are issued on
//
// Unlike a plain launch, this function synchronizes `stream` before returning,
// because the temporary device copy of corr_factors has to outlive the kernel.
static void rope_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows,
        const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
        const float theta_ntk_scale, const float p, const float corr_factors[4], cudaStream_t stream) {

    // NOTE(review): the kernel pairs up columns (col = 2*...), not rows - confirm
    // whether this was meant to check ncols.
    GGML_ASSERT(nrows % 2 == 0);

    // `corr_factors` is an array parameter and therefore a plain pointer. The caller
    // in this file builds it as a host stack array, so handing it straight to the
    // kernel would make the device dereference a host address. Stage the four floats
    // in device memory first. cudaMemcpyDefault lets the runtime infer the direction,
    // so a caller that already passes a device pointer keeps working.
    // NOTE(review): the cheaper long-term fix is to pass the four values to the kernel
    // by value (e.g. wrapped in a small struct); that requires changing the kernel
    // signature and is therefore not done here.
    float * corr_factors_dev = nullptr;
    cudaError_t err = cudaMalloc((void **) &corr_factors_dev, 4*sizeof(float));
    GGML_ASSERT(err == cudaSuccess);
    err = cudaMemcpyAsync(corr_factors_dev, corr_factors, 4*sizeof(float), cudaMemcpyDefault, stream);
    GGML_ASSERT(err == cudaSuccess);

    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(num_blocks_x, nrows, 1);
    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, freq_scale, ntk_factor,
        extrapolation_factor, theta_scale, theta_ntk_scale, p, corr_factors_dev);

    // Launch-configuration errors are only reported through cudaGetLastError().
    err = cudaGetLastError();
    GGML_ASSERT(err == cudaSuccess);

    // The kernel must be done with the staging buffer before it is released.
    err = cudaStreamSynchronize(stream);
    GGML_ASSERT(err == cudaSuccess);
    err = cudaFree(corr_factors_dev);
    GGML_ASSERT(err == cudaSuccess);
}
24292413
24302414static void rope_glm_f32_cuda (const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
@@ -2990,6 +2974,13 @@ inline void ggml_cuda_op_mul_mat_cublas(
29902974 (void ) i1;
29912975}
29922976
// Returns the value x for which
//   n_rot = max_pos_emb / (2pi * base^(2x / n_dims))
// holds, i.e.
//   x = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))
// with max_pos_emb fixed at 2048.
//
//   n_dims : multiplier of the log term
//   n_rot  : divides max_pos_emb together with 2pi
//   base   : base whose logarithm forms the denominator
static float ntkv2_correction_factor(const int n_dims, const float n_rot, const float base) {
    static const float max_pos_emb = 2048;

    const float ratio       = max_pos_emb / (n_rot * 2 * (float)M_PI);
    const float numerator   = n_dims * logf(ratio);
    const float denominator = 2 * logf(base);
    return numerator / denominator;
}
2983+
29932984inline void ggml_cuda_op_rope (
29942985 const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
29952986 float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -3016,8 +3007,6 @@ inline void ggml_cuda_op_rope(
30163007 memcpy (&extrapolation_factor, (int32_t *) src1->data + 7 , sizeof (float ));
30173008
30183009 const float theta_scale = powf (freq_base, -2 .0f /n_dims);
3019- const float theta_ntk_scale = powf (freq_base * powf (freq_scale, (n_dims / (n_dims - 2 .0f ))), -2 .0f /n_dims);
3020- const float dims_over_base = n_dims / logf (freq_base);
30213010 const float p = ((mode & 1 ) == 0 ? n_past + i02 : i02);
30223011
30233012 bool is_glm = mode & 4 ;
@@ -3028,8 +3017,25 @@ inline void ggml_cuda_op_rope(
30283017 const float block_p = max (p - (n_ctx - 2 .f ), 0 .f );
30293018 rope_glm_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
30303019 } else {
3031- rope_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, n_dims, freq_base, freq_scale, ntk_factor,
3032- extrapolation_factor, theta_scale, theta_ntk_scale, dims_over_base, p, cudaStream_main);
3020+ const float theta_ntk_scale = powf (freq_base * powf (freq_scale, (n_dims / (n_dims - 2 .0f ))), -2 .0f /n_dims);
3021+
3022+ // Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
3023+ // Do not change unless there is a good reason for doing so!
3024+ static const float BETA_0 = 1 .75f ;
3025+ static const float BETA_1 = 1 .25f ;
3026+ static const float GAMMA_0 = 16 .0f ;
3027+ static const float GAMMA_1 = 2 .0f ;
3028+
3029+ // start and end correction factors
3030+ const float corr_factors[4 ] = {
3031+ max (0 .0f , floorf (ntkv2_correction_factor (n_dims, BETA_0, freq_base))),
3032+ min (n_dims - 1 .0f , ceilf (ntkv2_correction_factor (n_dims, BETA_1, freq_base))),
3033+ max (0 .0f , floorf (ntkv2_correction_factor (n_dims, GAMMA_0, freq_base))),
3034+ min (n_dims - 1 .0f , ceilf (ntkv2_correction_factor (n_dims, GAMMA_1, freq_base))),
3035+ };
3036+
3037+ rope_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ntk_factor,
3038+ extrapolation_factor, theta_scale, theta_ntk_scale, p, corr_factors, cudaStream_main);
30333039 }
30343040
30353041 (void ) dst;
0 commit comments