
Commit 14cf93b

fix YaRN ramp, make mscale conditional, add --yarn-orig-ctx (#2)
1 parent 9ae10b3 commit 14cf93b

File tree

7 files changed, +33 -24 lines


common/common.cpp

Lines changed: 8 additions & 0 deletions
@@ -220,6 +220,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             break;
         }
         params.rope_freq_scale = 1.0f/std::stof(argv[i]);
+    } else if (arg == "--yarn-orig-ctx") {
+        if (++i >= argc) {
+            invalid_param = true;
+            break;
+        }
+        params.yarn_orig_ctx = std::stoi(argv[i]);
     } else if (arg == "--yarn-ext-factor") {
         if (++i >= argc) {
             invalid_param = true;
@@ -737,6 +743,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" --rope-scale N RoPE context scaling factor, expands context by a factor of N\n");
     printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
     printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
+    printf(" --yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size)\n");
     printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
     printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
     printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
@@ -861,6 +868,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_attn_factor = params.yarn_attn_factor;
     cparams.yarn_beta_fast = params.yarn_beta_fast;
     cparams.yarn_beta_slow = params.yarn_beta_slow;
+    cparams.yarn_orig_ctx = params.yarn_orig_ctx;
 
     return cparams;
 }

common/common.h

Lines changed: 3 additions & 2 deletions
@@ -57,8 +57,9 @@ struct gpt_params {
     float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
     float yarn_ext_factor = NAN; // YaRN extrapolation mix factor
     float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
-    float yarn_beta_fast = 32.0f; // YaRN low correction dim
-    float yarn_beta_slow = 1.0f; // YaRN high correction dim
+    float yarn_beta_fast = 32.0f;// YaRN low correction dim
+    float yarn_beta_slow = 1.0f; // YaRN high correction dim
+    int32_t yarn_orig_ctx = 0; // YaRN original context length
     int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
 
     // // sampling parameters

ggml-cuda.cu

Lines changed: 3 additions & 4 deletions
@@ -4406,7 +4406,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 }
 
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / min(0.001f, high - low);
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
 }
 
@@ -4426,11 +4426,10 @@ static __device__ void rope_yarn(
     if (ext_factor != 0.0f) {
         float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-    }
 
-    // Get n-d magnitude scaling corrected for interpolation
-    if (freq_scale < 1.0f)
+        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
     *cos_theta = cosf(theta) * mscale;
     *sin_theta = sinf(theta) * mscale;
 }
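For context on the min -> max change above (the same fix is repeated in ggml-metal.metal and ggml.c below): with min(0.001f, high - low) the denominator is capped at 0.001 whenever the correction range is wider than that, so y blows up and the clamped ramp degenerates into a hard step at i0/2 == low. With max(0.001f, high - low) the ramp is divided by the actual width of the range, and the 0.001 floor only guards against division by zero when high == low. A minimal host-side sketch, not part of the commit, with purely illustrative values:

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

// Old ramp: the denominator is capped at 0.001, so y saturates almost
// immediately and the smooth ramp collapses into a step function.
static float ramp_old(float low, float high, int i0) {
    const float y = (i0 / 2 - low) / MIN(0.001f, high - low);
    return 1 - MIN(1, MAX(0, y));
}

// Fixed ramp: divide by the real width of the correction range; the 0.001
// floor only prevents division by zero when high == low.
static float ramp_new(float low, float high, int i0) {
    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
    return 1 - MIN(1, MAX(0, y));
}

int main(void) {
    // Illustrative correction range [4, 24) over the rotary dimension index i0.
    for (int i0 = 0; i0 <= 64; i0 += 16) {
        printf("i0=%2d  old=%.3f  new=%.3f\n", i0,
               ramp_old(4.0f, 24.0f, i0), ramp_new(4.0f, 24.0f, i0));
    }
    return 0;
}

With the old code every i0 past the lower bound immediately returns 0; with the fix the values fall off smoothly from 1 to 0 across the range, which is the blending behaviour the YaRN ramp is meant to provide.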

ggml-metal.metal

Lines changed: 3 additions & 4 deletions
@@ -880,7 +880,7 @@ kernel void kernel_alibi_f32(
 }
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / min(0.001f, high - low);
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
 }
 
@@ -896,11 +896,10 @@ static void rope_yarn(
     if (ext_factor != 0.0f) {
         ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-    }
 
-    // Get n-d magnitude scaling corrected for interpolation
-    if (freq_scale < 1.0f)
+        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
     *cos_theta = cosf(theta) * mscale;
     *sin_theta = sinf(theta) * mscale;
 }

ggml.c

Lines changed: 3 additions & 4 deletions
@@ -13345,7 +13345,7 @@ static void ggml_compute_forward_clamp(
 // ggml_compute_forward_rope
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / MIN(0.001f, high - low);
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
     return 1 - MIN(1, MAX(0, y));
 }
 
@@ -13361,11 +13361,10 @@ static void rope_yarn(
     if (ext_factor != 0.0f) {
         float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-    }
 
-    // Get n-d magnitude scaling corrected for interpolation
-    if (freq_scale < 1.0f)
+        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
     *cos_theta = cosf(theta) * mscale;
     *sin_theta = sinf(theta) * mscale;
 }
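On the second change named in the commit message (making mscale conditional): in all three rope_yarn implementations above, the attention-magnitude correction 1 + 0.1*ln(1/freq_scale) is now applied only inside the ext_factor != 0 branch, i.e. only when YaRN mixing is actually active, instead of whenever freq_scale < 1. A quick standalone check of the term's size, assuming an illustrative freq_scale of 0.25 (4x interpolation); not part of the commit:

#include <math.h>
#include <stdio.h>

int main(void) {
    // Illustrative value only: freq_scale = 0.25 means positions are compressed 4x.
    const float freq_scale = 0.25f;

    // The YaRN attention-magnitude correction applied to cos/sin in rope_yarn.
    const float mscale = 1.0f + 0.1f * logf(1.0f / freq_scale);

    // ln(4) ~= 1.386, so mscale ~= 1.139. After this commit the factor is applied
    // only when ext_factor != 0, so plain (non-YaRN) frequency scaling is unaffected.
    printf("mscale = %.4f\n", mscale);
    return 0;
}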

llama.cpp

Lines changed: 6 additions & 4 deletions
@@ -1113,6 +1113,7 @@ struct llama_cparams {
     float rope_freq_base;
     float rope_freq_scale;
 
+    uint32_t n_yarn_orig_ctx;
     // These hyperparameters are not exposed in GGUF, because all
     // existing YaRN models use the same values for them.
     float yarn_ext_factor;
@@ -3028,7 +3029,7 @@ static struct ggml_cgraph * llm_build_llama(
     const int32_t n_embd = hparams.n_embd;
     const int32_t n_layer = hparams.n_layer;
     const int32_t n_ctx = cparams.n_ctx;
-    const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
     const int32_t n_head = hparams.n_head;
     const int32_t n_head_kv = hparams.n_head_kv;
     const int32_t n_embd_head = hparams.n_embd_head();
@@ -3430,7 +3431,7 @@ static struct ggml_cgraph * llm_build_baichaun(
     const int32_t n_embd = hparams.n_embd;
     const int32_t n_layer = hparams.n_layer;
     const int32_t n_ctx = cparams.n_ctx;
-    const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
     const int32_t n_head = hparams.n_head;
     const int32_t n_head_kv = hparams.n_head_kv;
     const int32_t n_embd_head = hparams.n_embd_head();
@@ -4194,7 +4195,7 @@ static struct ggml_cgraph * llm_build_falcon(
     const int32_t n_embd = hparams.n_embd;
     const int32_t n_layer = hparams.n_layer;
     const int32_t n_ctx = cparams.n_ctx;
-    const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
     const int32_t n_head = hparams.n_head;
     const int32_t n_head_kv = hparams.n_head_kv;
     const int32_t n_embd_head = hparams.n_embd_head();
@@ -4818,7 +4819,7 @@ static struct ggml_cgraph * llm_build_persimmon(
     const int64_t n_embd = hparams.n_embd;
     const int64_t n_layer = hparams.n_layer;
     const int64_t n_ctx = cparams.n_ctx;
-    const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
     const int64_t n_head_kv = hparams.n_head_kv;
     const int64_t n_head = hparams.n_head;
     const int64_t n_embd_head = hparams.n_embd_head();
@@ -8676,6 +8677,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.mul_mat_q = params.mul_mat_q;
 
     cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
+    cparams.n_yarn_orig_ctx = params.yarn_orig_ctx == 0 ? hparams.n_ctx_train : params.yarn_orig_ctx;
     cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
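The last hunk above is where the new parameter gets its default: passing 0 (the gpt_params default) means "use the model's training context". A tiny standalone illustration of that fallback, not from the commit, with hypothetical numbers:

#include <stdint.h>
#include <stdio.h>

// Hypothetical helper mirroring the fallback added in llama_new_context_with_model:
// a value of 0 falls back to the model's training context size.
static uint32_t resolve_yarn_orig_ctx(uint32_t yarn_orig_ctx, uint32_t n_ctx_train) {
    return yarn_orig_ctx == 0 ? n_ctx_train : yarn_orig_ctx;
}

int main(void) {
    // Assuming a model trained on 4096 tokens (illustrative value only):
    printf("%u\n", resolve_yarn_orig_ctx(0, 4096));    // -> 4096, training context
    printf("%u\n", resolve_yarn_orig_ctx(2048, 4096)); // -> 2048, explicit override wins
    return 0;
}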

llama.h

Lines changed: 7 additions & 6 deletions
@@ -182,12 +182,13 @@ extern "C" {
         int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
-        float rope_freq_base; // RoPE base frequency, 0 = from model
-        float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
-        float yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model
-        float yarn_attn_factor; // YaRN magnitude scaling factor
-        float yarn_beta_fast; // YaRN low correction dim
-        float yarn_beta_slow; // YaRN high correction dim
+        float    rope_freq_base; // RoPE base frequency, 0 = from model
+        float    rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
+        float    yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model
+        float    yarn_attn_factor; // YaRN magnitude scaling factor
+        float    yarn_beta_fast; // YaRN low correction dim
+        float    yarn_beta_slow; // YaRN high correction dim
+        uint32_t yarn_orig_ctx; // YaRN original context size
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool mul_mat_q; // if true, use experimental mul_mat_q kernels
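With yarn_orig_ctx now exposed in llama_context_params, a caller can set it directly. A minimal sketch of that, not taken from the commit: it assumes an already-loaded struct llama_model * and the existing llama_context_default_params() / llama_new_context_with_model() entry points, and all numeric values are placeholders.

#include "llama.h"

// Sketch: build a context that runs at an extended window while telling YaRN the
// context size the model was originally trained on.
static struct llama_context * make_extended_context(struct llama_model * model) {
    struct llama_context_params cparams = llama_context_default_params();

    cparams.n_ctx           = 16384;  // window we want to run at (placeholder)
    cparams.rope_freq_scale = 0.25f;  // 4x interpolation (placeholder)
    cparams.yarn_orig_ctx   = 4096;   // original training context; 0 = take it from the model

    return llama_new_context_with_model(model, cparams);
}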
