@@ -580,6 +580,32 @@ kernel void kernel_alibi_f32(
580580 }
581581}
582582
583+ static float rope_ntkv2_ramp (const float low, const float high, const int i0) {
584+ const float y = (i0 / 2 - low) / min (0 .001f , high - low);
585+ return 1 .0f - min (1 .0f , max (0 .0f , y));
586+ }
587+
588+ // NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
589+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
590+ static float rope_ntkv2 (
591+ const float theta_base,
592+ const float theta_linear,
593+ const float theta_ntk,
594+ device const float corr_factors[4 ],
595+ const int64_t i0,
596+ const float ntk_factor,
597+ const float extrapolation_factor) {
598+ float ramp_mix;
599+ float theta;
600+
601+ ramp_mix = rope_ntkv2_ramp (corr_factors[0 ], corr_factors[1 ], i0) * ntk_factor;
602+ theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
603+
604+ ramp_mix = rope_ntkv2_ramp (corr_factors[2 ], corr_factors[3 ], i0) * extrapolation_factor;
605+ theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
606+ return theta;
607+ }
608+
583609kernel void kernel_rope (
584610 device const void * src0,
585611 device float * dst,
@@ -604,24 +630,33 @@ kernel void kernel_rope(
604630 constant int & mode,
605631 constant float & freq_base,
606632 constant float & freq_scale,
633+ constant float & ntk_factor,
634+ constant float & extrapolation_factor,
607635 uint3 tpig[[thread_position_in_grid]]) {
608636 const int64_t i3 = tpig[2 ];
609637 const int64_t i2 = tpig[1 ];
610638 const int64_t i1 = tpig[0 ];
611639
612- const bool is_neox = mode & 2 ;
613- const float theta_scale = pow (freq_base, -2 .0f /n_dims);
640+ const float theta_scale = powf (freq_base, -2 .0f /n_dims);
641+ const float theta_ntk_scale = powf (freq_base * powf (freq_scale, (n_dims / (n_dims - 2 .0f ))), -2 .0f /n_dims);
642+ device float corr_factors[4 ];
643+ ggml_rope_ntkv2_corr_factors (n_dims, freq_base, corr_factors);
614644
615- const int64_t p = ((mode & 1 ) == 0 ? n_past + i2 : i2);
645+ float theta_base = (mode & 1 ) == 0 ? n_past + i2 : i2;
646+ float theta_ntk = theta_base;
616647
617- float theta = freq_scale * ( float )p ;
648+ const bool is_neox = mode & 2 ;
618649
619650 if (!is_neox) {
620651 for (int64_t i0 = 0 ; i0 < ne0; i0 += 2 ) {
621- const float cos_theta = cos (theta);
622- const float sin_theta = sin (theta);
652+ const float theta_linear = freq_scale * theta_base;
653+ const float theta = rope_ntkv2 (theta_base, theta_linear, theta_ntk, corr_factors,
654+ i0, ntk_factor, extrapolation_factor);
655+ const float cos_theta = cosf (theta);
656+ const float sin_theta = sinf (theta);
623657
624- theta *= theta_scale;
658+ theta_base *= theta_scale;
659+ theta_ntk *= theta_ntk_scale;
625660
626661 device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
627662 device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -633,6 +668,7 @@ kernel void kernel_rope(
633668 dst_data[1 ] = x0*sin_theta + x1*cos_theta;
634669 }
635670 } else {
671+ theta_base *= freq_scale;
636672 // TODO: implement
637673 }
638674}
0 commit comments