
Commit 932bdd9

Adding Rope scaling. (#741)
# What does this PR do?

- Adds Rope NTK scaling. Done because #529 was closed. Took some code from huggingface/transformers#24653.
- `--rope-scaling` and `--rope-factor` are added separately. I considered having a single flag and parsing something like "linear:4.0" or "dynamic", but decided against it because it would push more parsing and validation a bit everywhere (both in the launcher and the server).

Fixes #512
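For context, here is a minimal sketch (not part of the diff) of how the server side resolves these settings: the launcher-provided `ROPE_SCALING`/`ROPE_FACTOR` environment variables take precedence over a transformers-style `rope_scaling` entry in the model config. It mirrors the `_get_rope_config` helper added to `server/text_generation_server/utils/layers.py` further down.

```python
import os

# Illustrative sketch only; mirrors `_get_rope_config` from the diff below.
def get_rope_config(config):
    # Environment variables set by the launcher win over the model config.
    if os.getenv("ROPE_SCALING") is not None:
        return {
            "type": os.environ["ROPE_SCALING"],          # "linear" or "dynamic"
            "factor": float(os.environ["ROPE_FACTOR"]),  # e.g. 4.0
        }
    # Otherwise fall back to the transformers-style config entry, if any.
    return getattr(config, "rope_scaling", None)
```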
1 parent b9633c4 commit 932bdd9

File tree

5 files changed, +141 -14 lines changed


launcher/src/main.rs

Lines changed: 61 additions & 0 deletions
@@ -60,6 +60,26 @@ impl std::fmt::Display for Dtype {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum RopeScaling {
+    Linear,
+    Dynamic,
+}
+
+impl std::fmt::Display for RopeScaling {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Keep in sync with `server`.
+        match self {
+            RopeScaling::Linear => {
+                write!(f, "linear")
+            }
+            RopeScaling::Dynamic => {
+                write!(f, "dynamic")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -250,6 +270,26 @@ struct Args {
     #[clap(default_value = "1.0", long, env)]
     cuda_memory_fraction: f32,
 
+    /// Rope scaling will only be used for RoPE models
+    /// and allows rescaling the position rotary to accommodate
+    /// larger prompts.
+    ///
+    /// Goes together with `rope_factor`.
+    ///
+    /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
+    /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
+    /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (nothing will be changed,
+    /// basically)
+    ///
+    /// `--rope-scaling linear --rope-factor` fully describes the scaling you want
+    #[clap(long, env)]
+    rope_scaling: Option<RopeScaling>,
+
+    /// Rope scaling will only be used for RoPE models.
+    /// See `rope_scaling`.
+    #[clap(long, env)]
+    rope_factor: Option<f32>,
+
     /// Outputs the logs in JSON format (useful for telemetry)
     #[clap(long, env)]
     json_output: bool,
@@ -305,6 +345,8 @@ fn shard_manager(
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
     cuda_memory_fraction: f32,
+    rope_scaling: Option<RopeScaling>,
+    rope_factor: Option<f32>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -358,6 +400,12 @@ fn shard_manager(
         shard_args.push(revision)
     }
 
+    let rope = match (rope_scaling, rope_factor) {
+        (None, None) => None,
+        (Some(scaling), None) => Some((scaling, 1.0)),
+        (Some(scaling), Some(factor)) => Some((scaling, factor)),
+        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
+    };
     // OpenTelemetry
     if let Some(otlp_endpoint) = otlp_endpoint {
         shard_args.push("--otlp-endpoint".to_string());
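The `match` above is where the defaulting documented on the flags happens: a scaling type without a factor defaults to 1.0, and a bare factor defaults to linear scaling. The same rule, restated as a small Python sketch purely for illustration (names are mine, not the launcher's):

```python
# Illustration of the launcher's defaulting rule (the real logic is the Rust
# match above): collapse the two optional flags into one (type, factor) pair.
def resolve_rope(rope_scaling, rope_factor):
    if rope_scaling is None and rope_factor is None:
        return None
    if rope_factor is None:
        return (rope_scaling, 1.0)      # e.g. --rope-scaling dynamic
    if rope_scaling is None:
        return ("linear", rope_factor)  # e.g. --rope-factor 4.0
    return (rope_scaling, rope_factor)  # both flags given

assert resolve_rope("dynamic", None) == ("dynamic", 1.0)
assert resolve_rope(None, 4.0) == ("linear", 4.0)
```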
@@ -395,6 +443,15 @@
         envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
     };
 
+    // Detect rope scaling.
+    // Sending as env instead of CLI args to not bloat everything.
+    // These can only be used by RoPE models, so passing the information around
+    // for all models would complicate the code unnecessarily.
+    if let Some((scaling, factor)) = rope {
+        envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
+        envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -784,6 +841,8 @@ fn spawn_shards(
         let watermark_gamma = args.watermark_gamma;
         let watermark_delta = args.watermark_delta;
         let cuda_memory_fraction = args.cuda_memory_fraction;
+        let rope_scaling = args.rope_scaling;
+        let rope_factor = args.rope_factor;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -802,6 +861,8 @@
                 watermark_gamma,
                 watermark_delta,
                 cuda_memory_fraction,
+                rope_scaling,
+                rope_factor,
                 otlp_endpoint,
                 status_sender,
                 shutdown,

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def __init__(
         self.head_size = self.hidden_size // self.num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
         )
 
         self.softmax_scale = self.head_size**-0.5

server/text_generation_server/models/custom_modeling/flash_neox_modeling.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def __init__(self, config, prefix, weights):
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
         )
 
         self.softmax_scale = self.head_size ** (-0.5)

server/text_generation_server/models/custom_modeling/flash_rw_modeling.py

Lines changed: 2 additions & 2 deletions
@@ -133,7 +133,7 @@ def __init__(
         self.head_size = self.hidden_size // self.num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            dim=self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
 
@@ -247,7 +247,7 @@ def __init__(
         self.head_size = hidden_size // num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
server/text_generation_server/utils/layers.py

Lines changed: 76 additions & 10 deletions
@@ -381,33 +381,65 @@ def forward(self, hidden_states, residual=None):
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
 
+    def _create_inv_freq(dim, base, device):
+        inv_freq = 1.0 / (
+            base
+            ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+        )
+        return inv_freq
+
+    def _get_rope_config(config):
+        if os.getenv("ROPE_SCALING", None) is not None:
+            rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
+            return rope_scaling
+        return getattr(config, "rope_scaling", None)
+
     class PositionRotaryEmbedding(nn.Module):
-        def __init__(self, inv_freq):
+        def __init__(self, inv_freq, scaling_factor):
             super().__init__()
-
             self.inv_freq = inv_freq
             self._seq_len_cached = 0
             self._cos_cached = None
             self._sin_cached = None
             self._cos_k_cached = None
             self._sin_k_cached = None
+            self.scaling_factor = scaling_factor
+            self.dynamic_args = None
 
         @classmethod
-        def static(cls, dim, base, device):
-            inv_freq = 1.0 / (
-                base
-                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
-            )
-            return cls(inv_freq)
+        def static(cls, config, dim, base, device):
+            inv_freq = _create_inv_freq(dim, base, device)
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                if rope_scaling["type"] == "linear":
+                    pass
+                elif rope_scaling["type"] == "dynamic":
+                    return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
+                else:
+                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+            return cls(inv_freq, scaling_factor)
 
         @classmethod
-        def load(cls, prefix, weights):
+        def load(cls, config, prefix, weights):
             # XXX: Always load this in float32 !
             dtype = weights.dtype
             weights.dtype = torch.float32
             inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
             weights.dtype = dtype
-            return cls(inv_freq)
+
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                if rope_scaling["type"] == "linear":
+                    pass
+                elif rope_scaling["type"] == "dynamic":
+                    return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
+                else:
+                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+            return cls(inv_freq, scaling_factor)
 
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
@@ -419,8 +451,11 @@ def _update_cos_sin_cache(self, dtype, device, seqlen):
             ):
                 self._seq_len_cached = seqlen
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
                 # Don't do einsum, it converts fp32 to fp16
                 # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
                 freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                 self._cos_cached = torch.cos(freqs).to(dtype)
                 self._sin_cached = torch.sin(freqs).to(dtype)
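The two added lines above are the whole of linear scaling: positions are divided by the factor before the outer product with `inv_freq`, so the cos/sin tables for a stretched context stay in the rotation range the model saw in training. A standalone sketch of the effect, with made-up sizes (head dim 64, a 2048-token training window, factor 4.0), not taken from the diff:

```python
import torch

# Standalone illustration of linear RoPE scaling (not TGI code).
dim, base, factor = 64, 10000.0, 4.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

positions = torch.arange(8192, dtype=torch.float32)  # 4x a 2048-token window
scaled = positions / factor                          # compress positions by the factor
freqs = torch.outer(scaled, inv_freq)                # angles stay in the trained range
cos, sin = torch.cos(freqs), torch.sin(freqs)
print(scaled.max())                                  # tensor(2047.7500)
```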
@@ -446,5 +481,36 @@ def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
             rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
             return x
 
+    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
+        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
+            inv_freq = _create_inv_freq(dim, base, device)
+            super().__init__(inv_freq, scaling_factor)
+            self.dim = dim
+            self.max_position_embeddings = max_position_embeddings
+            self.base = base
+
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
+            # Reset the tables if the sequence length has changed,
+            # or if we're on a new device (possibly due to tracing for instance)
+            if (
+                seqlen > self._seq_len_cached
+                or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype
+            ):
+                if seqlen > self.max_position_embeddings:
+                    newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
+                    self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
+                self._seq_len_cached = seqlen
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
+                # Don't do einsum, it converts fp32 to fp16
+                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+
+
 except ImportError:
     pass
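For the dynamic case, the interesting part is the `newbase` line: once the sequence exceeds `max_position_embeddings`, the rotary base is grown (NTK-aware scaling) rather than the positions being compressed further. A small sketch of that formula with illustrative numbers (not taken from the diff):

```python
# Standalone illustration of the dynamic NTK base adjustment used above.
def dynamic_ntk_base(base, dim, seqlen, max_position_embeddings, scaling_factor):
    return base * (
        (scaling_factor * seqlen / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

# A model trained with base 10000, head dim 128 and 2048 positions,
# queried at 4096 tokens with --rope-factor 2.0:
print(dynamic_ntk_base(10000.0, 128, 4096, 2048, 2.0))  # ~30500, about 3x the base
```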
