@@ -1875,8 +1875,53 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
18751875 cpy_1 (cx + x_offset, cdst + dst_offset);
18761876}
18771877
// Linear ramp over rotary dimension pairs; the result is written to *out.
//
// The pair index is i0 / 2 (integer division, i0 is the even column index).
// *out is 1 for pair indices at or below `low`, 0 for pair indices at or above
// `high`, and decreases linearly in between.
//
//   low, high : start and end of the ramp, in units of pair index
//   i0        : column index of the element being rotated
//   out       : receives the ramp value, always in [0, 1]
static __device__ void ntkv2_ramp(const float low, const float high, const int i0, float *out) {
    // The width of the ramp is clamped from BELOW with max(): this keeps the
    // division well defined when high <= low. Using min() here would cap the
    // divisor at 0.001 for every normal range (turning the ramp into a step)
    // and would still divide by zero when high == low.
    const float y = (i0 / 2 - low) / max(0.001f, high - low);
    *out = 1.0f - min(1.0f, max(0.0f, y));
}
1882+
1883+ // NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
1884+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
// Computes the rotation angle for one rotary dimension pair by blending three
// candidate angles, and stores the result in *theta.
//
//   theta_base           : angle from the unmodified frequency base
//   theta_ntk            : angle from the NTK-adjusted frequency base
//   dims_over_base       : multiplier applied to the ramp coefficients below
//                          (the caller in this file passes n_dims / logf(freq_base))
//   freq_scale           : factor applied to theta_base to get the linearly scaled angle
//   i0                   : column index of the element being rotated
//   ntk_factor           : weight applied to the first ramp (linear -> NTK blend)
//   extrapolation_factor : weight applied to the second ramp (blend -> base angle)
//   n_dims               : number of rotary dimensions; caps the upper ramp bounds at n_dims - 1
//   theta                : output angle
static __device__ void compute_ntkv2(
        float theta_base,
        float theta_ntk,
        float dims_over_base,
        float freq_scale,
        int64_t i0,
        float ntk_factor,
        float extrapolation_factor,
        int n_dims,
        float *theta) {
    // Ramp coefficients found experimentally for LLaMA (possibly not fully optimal).
    // Do not change them without a good reason. They are written as precomputed
    // literals because CUDA does not allow dynamic initialization of device constants.
    const float ntk_lo_coef    = 2.6135630f;
    const float ntk_hi_coef    = 2.7817991f;
    const float extrap_lo_coef = 1.5070765f;
    const float extrap_hi_coef = 2.5467973f;

    // Ramp boundaries: lower bounds are floored and kept >= 0,
    // upper bounds are ceiled and kept <= n_dims - 1.
    const float last_dim  = n_dims - 1.0f;
    const float ntk_lo    = max(0.0f, floorf(ntk_lo_coef * dims_over_base));
    const float ntk_hi    = min(last_dim, ceilf(ntk_hi_coef * dims_over_base));
    const float extrap_lo = max(0.0f, floorf(extrap_lo_coef * dims_over_base));
    const float extrap_hi = min(last_dim, ceilf(extrap_hi_coef * dims_over_base));

    // Stage 1: mix the linearly scaled angle with the NTK angle.
    float ntk_weight;
    ntkv2_ramp(ntk_lo, ntk_hi, i0, &ntk_weight);
    ntk_weight *= ntk_factor;
    const float theta_linear  = freq_scale * theta_base;
    const float theta_blended = theta_linear * (1 - ntk_weight) + theta_ntk * ntk_weight;

    // Stage 2: mix the stage-1 result with the unscaled base angle.
    float extrap_weight;
    ntkv2_ramp(extrap_lo, extrap_hi, i0, &extrap_weight);
    extrap_weight *= extrapolation_factor;
    *theta = theta_blended * (1 - extrap_weight) + theta_base * extrap_weight;
}
1919+
18781920// rope == RoPE == rotary positional embedding
1879- static __global__ void rope_f32 (const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
1921+ static __global__ void rope_f32 (const float * x, float * dst, const int ncols, const int n_dims, const float freq_base,
1922+ const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
1923+ const float theta_ntk_scale, const float dims_over_base, const float p) {
1924+
18801925 const int col = 2 *(blockDim .x *blockIdx .x + threadIdx .x );
18811926
18821927 if (col >= ncols) {
@@ -1886,7 +1931,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
18861931 const int row = blockDim .y *blockIdx .y + threadIdx .y ;
18871932 const int i = row*ncols + col;
18881933
1889- const float theta = p*powf (theta_scale, col/2 );
1934+ const float theta_base = p*powf (theta_scale, col/2 );
1935+ const float theta_ntk = p*powf (theta_ntk_scale, col/2 );
1936+ float theta;
1937+ compute_ntkv2 (theta_base, theta_ntk, dims_over_base,
1938+ freq_scale, col, ntk_factor, extrapolation_factor, n_dims, &theta);
18901939 const float sin_theta = sinf (theta);
18911940 const float cos_theta = cosf (theta);
18921941
@@ -2365,12 +2414,17 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
23652414 scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0 , stream>>> (x, dst, scale, k);
23662415}
23672416
// Host-side launcher for the rope_f32 kernel on `stream`.
//
// Launch layout: blocks are 2*CUDA_ROPE_BLOCK_SIZE threads wide in x; the grid has
// ceil(ncols / (2*CUDA_ROPE_BLOCK_SIZE)) blocks in x and one block per row in y.
// No dynamic shared memory is requested.
//
// x and dst are handed straight to the kernel. All scalar arguments are
// forwarded to the kernel unchanged. nrows must be even (asserted).
static void rope_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int n_dims, const float freq_base,
        const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
        const float theta_ntk_scale, const float dims_over_base, const float p, cudaStream_t stream) {

    GGML_ASSERT(nrows % 2 == 0);

    const int threads_x = 2*CUDA_ROPE_BLOCK_SIZE;
    const int blocks_x  = (ncols + threads_x - 1) / threads_x; // ceil-div over the columns

    const dim3 block_dims(threads_x, 1, 1);
    const dim3 block_nums(blocks_x, nrows, 1);

    rope_f32<<<block_nums, block_dims, 0, stream>>>(
        x, dst, ncols, n_dims, freq_base, freq_scale, ntk_factor,
        extrapolation_factor, theta_scale, theta_ntk_scale, dims_over_base, p);
}
23752429
23762430static void rope_glm_f32_cuda (const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
@@ -2947,12 +3001,23 @@ inline void ggml_cuda_op_rope(
29473001 const int64_t ne00 = src0->ne [0 ];
29483002 const int64_t i01_diff = i01_high - i01_low;
29493003
3004+ float freq_base;
3005+ float freq_scale;
3006+ float ntk_factor;
3007+ float extrapolation_factor;
3008+
29503009 const int n_past = ((int32_t *) src1->data )[0 ];
29513010 const int n_dims = ((int32_t *) src1->data )[1 ];
29523011 const int mode = ((int32_t *) src1->data )[2 ];
29533012 const int n_ctx = ((int32_t *) src1->data )[3 ];
2954-
2955- const float theta_scale = powf (10000.0 , -2 .0f /n_dims);
3013+ memcpy (&freq_base, (int32_t *) src1->data + 4 , sizeof (float ));
3014+ memcpy (&freq_scale, (int32_t *) src1->data + 5 , sizeof (float ));
3015+ memcpy (&ntk_factor, (int32_t *) src1->data + 6 , sizeof (float ));
3016+ memcpy (&extrapolation_factor, (int32_t *) src1->data + 7 , sizeof (float ));
3017+
3018+ const float theta_scale = powf (freq_base, -2 .0f /n_dims);
3019+ const float theta_ntk_scale = powf (freq_base * powf (freq_scale, (n_dims / (n_dims - 2 .0f ))), -2 .0f /n_dims);
3020+ const float dims_over_base = n_dims / logf (freq_base);
29563021 const float p = ((mode & 1 ) == 0 ? n_past + i02 : i02);
29573022
29583023 bool is_glm = mode & 4 ;
@@ -2963,7 +3028,8 @@ inline void ggml_cuda_op_rope(
29633028 const float block_p = max (p - (n_ctx - 2 .f ), 0 .f );
29643029 rope_glm_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
29653030 } else {
2966- rope_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
3031+ rope_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, n_dims, freq_base, freq_scale, ntk_factor,
3032+ extrapolation_factor, theta_scale, theta_ntk_scale, dims_over_base, p, cudaStream_main);
29673033 }
29683034
29693035 (void ) dst;
0 commit comments