Commit c9fd0b5

Author: Andrew Gu

Add truncated llama style model init via reset_parameters() (#54)

This PR adds the following:

1. Via reset_parameters(), a full layer-wise init for the llama models under /llama. The init scales with total model depth: self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 (a quick numeric illustration follows below).
2. The final output layer (head) is initialized from a truncated normal with std equal to the inverse square root of the model dim (dim ** -0.5) and a slightly wider cutoff factor of 3.
3. Tangential change: updates run_llama_train.sh with MODEL and MODEL_CONF params to allow direct model control via the shell script. (A MODEL variable already existed, but it was being used in place of MODEL_CONF; we should revisit this since it's not intuitive.)
4. Makes the debugmodel default to 2 layers as an improved debug check.
5. Adds 1B and 40B configs for additional testing. I can't currently run 70B on my H100 due to OOM, but can run 40B.

Testing: verified proper init and training with 7B, 13B, and ~40B:

![Screenshot 2024-02-11 at 10 39 12 PM](https://github.com/pytorch-labs/torchtrain/assets/46302957/049037ed-63a4-4ab0-bebc-f297857aab72)

[ghstack-poisoned]
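As a quick illustration of item 1 (not part of the commit), here is a short Python sketch that evaluates the depth-scaled std for the configs touched in this PR; it ranges from 0.01 for the 2-layer debugmodel down to roughly 0.0016 for the 80-layer 40B config. The n_layers values are taken from the llama_configs diff below.

```python
# Illustrative only: evaluate weight_init_std = 0.02 / (2 * n_layers) ** 0.5
# for the configs added/changed in this PR (n_layers per the __init__.py diff).
for name, n_layers in [("debugmodel", 2), ("1B", 16), ("7B", 32), ("13B", 40), ("40B", 80)]:
    std = 0.02 / (2 * n_layers) ** 0.5
    print(f"{name:10s} n_layers={n_layers:3d} weight_init_std={std:.5f}")
```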
Parent commit: 06958ce

File tree: 4 files changed (+73 / -9 lines)


run_llama_train.sh (6 additions, 3 deletions)

@@ -8,7 +8,8 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 # e.g.
 # LOG_RANK=0,1 NGPU=4 SP=2 ./run_llama_train.sh

-MODEL=${MODEL:-"debugmodel"}
+MODEL=${MODEL:-"llama"}
+MODEL_CONF=${MODEL_CONF:-"debugmodel"}
 NGPU=${NGPU:-"8"}
 PP=${PP:-"1"}
 SP=${SP:-"1"}
@@ -24,6 +25,8 @@ CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}

 torchrun --nproc_per_node=${NGPU} \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --steps 10 --compile \
---pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP}
+train.py --steps 10 \
+--model ${MODEL} --model_conf ${MODEL_CONF} \
+--pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
+--compile
 --checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}

torchtrain/models/llama/__init__.py (3 additions, 1 deletion)

@@ -6,9 +6,11 @@
 __all__ = ["Transformer"]

 llama_configs = {
-    "debugmodel": ModelArgs(dim=256, n_layers=1, n_heads=16),
+    "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16),
+    "1B": ModelArgs(dim=1024, n_layers=16, n_heads=8),
     "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
     "13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
+    "40B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
     "70B": ModelArgs(
         dim=8192,
         n_layers=80,

torchtrain/models/llama/model.py (62 additions, 4 deletions)

@@ -8,6 +8,8 @@
 import torch.nn.functional as F
 from torch import nn

+from torchtrain.logging_utils import rank0_log
+

 @dataclass
 class ModelArgs:
@@ -183,7 +185,6 @@ class Attention(nn.Module):
     """

     def __init__(self, model_args: ModelArgs):
-
         super().__init__()
         self.n_heads = model_args.n_heads
         self.n_kv_heads = (
@@ -203,6 +204,20 @@ def __init__(self, model_args: ModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )

+    def reset_parameters(self, init_std):
+        for item in (self.wq, self.wk, self.wv):
+            nn.init.trunc_normal_(
+                item.weight,
+                mean=0.0,
+                std=0.02,
+            )
+
+        nn.init.trunc_normal_(
+            self.wo.weight,
+            mean=0.0,
+            std=init_std,
+        )
+
     def forward(
         self,
         x: torch.Tensor,
@@ -277,7 +292,6 @@ def __init__(
         multiple_of: int,
         ffn_dim_multiplier: Optional[float],
     ):
-
         super().__init__()
         hidden_dim = int(2 * hidden_dim / 3)
         # custom dim factor multiplier
@@ -292,6 +306,20 @@ def __init__(
     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))

+    def reset_parameters(self, init_std):
+        nn.init.trunc_normal_(
+            self.w1.weight,
+            mean=0.0,
+            std=0.02,
+        )
+
+        for item in (self.w2, self.w3):
+            nn.init.trunc_normal_(
+                item.weight,
+                mean=0.0,
+                std=init_std,
+            )
+

 class RotaryEmbedding(nn.Module):
     """
@@ -350,7 +378,6 @@ class TransformerBlock(nn.Module):
     """

     def __init__(self, layer_id: int, model_args: ModelArgs):
-
         super().__init__()
         self.n_heads = model_args.n_heads
         self.dim = model_args.dim
@@ -362,8 +389,10 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
             ffn_dim_multiplier=model_args.ffn_dim_multiplier,
         )
         self.layer_id = layer_id
+        self.num_layers = model_args.n_layers
         self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5

     def forward(
         self,
@@ -385,6 +414,14 @@ def forward(
         out = h + self.feed_forward(self.ffn_norm(h))
         return out

+    def reset_parameters(self):
+        """reset params and norms for entire block"""
+        self.attention_norm.reset_parameters()
+        self.ffn_norm.reset_parameters()
+
+        self.attention.reset_parameters(self.weight_init_std)
+        self.feed_forward.reset_parameters(self.weight_init_std)
+

 class Transformer(nn.Module):
     """
@@ -406,11 +443,11 @@ class Transformer(nn.Module):
     """

     def __init__(self, model_args: ModelArgs):
-
         super().__init__()
         self.model_args = model_args
         self.vocab_size = model_args.vocab_size
         self.n_layers = model_args.n_layers
+        self.model_dim = model_args.dim

         self.embeddings = RotaryEmbedding(model_args)

@@ -421,6 +458,27 @@ def __init__(self, model_args: ModelArgs):
         self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)

+        # init model weights
+        self.reset_parameters()
+        rank0_log(f"Model built with: {self.model_args}")
+
+    def reset_parameters(
+        self,
+    ):
+        for layer in self.layers:
+            layer.reset_parameters()
+        self.norm.reset_parameters()
+        final_out_std = self.model_dim**-0.5
+        cutoff_factor = 3
+        nn.init.trunc_normal_(
+            self.output.weight,
+            mean=0.0,
+            std=final_out_std,
+            a=-cutoff_factor * final_out_std,
+            b=cutoff_factor * final_out_std,
+        )
+        rank0_log("Model fully initialized via reset_params")
+
     def forward(self, tokens: torch.Tensor):
         """
         Perform a forward pass through the Transformer model.
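For readers skimming the diff, here is a minimal standalone sketch of the same init pattern (not from the repository; the toy dimensions and variable names are chosen for illustration): 0.02 std for the input-facing projections, the depth-scaled std for output-facing ones, and a 3-std cutoff for the final head. Note that nn.init.trunc_normal_ takes absolute bounds a/b (default ±2.0), which is why the head init above scales them by final_out_std.

```python
from torch import nn

# Toy stand-ins for one block's projections plus the final head
# (dimensions and names are illustrative, not the repo's).
n_layers, dim, vocab_size = 4, 64, 1000
block_std = 0.02 / (2 * n_layers) ** 0.5  # depth-scaled std, as in TransformerBlock

wq = nn.Linear(dim, dim, bias=False)             # "input-facing": fixed 0.02 std
wo = nn.Linear(dim, dim, bias=False)             # "output-facing": depth-scaled std
output = nn.Linear(dim, vocab_size, bias=False)  # final head

nn.init.trunc_normal_(wq.weight, mean=0.0, std=0.02)
nn.init.trunc_normal_(wo.weight, mean=0.0, std=block_std)

final_out_std = dim**-0.5
cutoff_factor = 3  # truncate the head at +/- 3 std
nn.init.trunc_normal_(
    output.weight,
    mean=0.0,
    std=final_out_std,
    a=-cutoff_factor * final_out_std,
    b=cutoff_factor * final_out_std,
)

for name, w in [("wq", wq.weight), ("wo", wo.weight), ("output", output.weight)]:
    print(f"{name:7s} std={w.std().item():.4f}  max|w|={w.abs().max().item():.4f}")
```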

train.py (2 additions, 1 deletion)

@@ -78,6 +78,7 @@ def main(args):
     world_mesh = parallel_dims.build_mesh(device_type="cuda")

     model_name = args.model
+    rank0_log(f"Building {model_name}")
     # build tokenizer
     tokenizer_type = model_name_to_tokenizer[model_name]
     tokenizer = create_tokenizer(tokenizer_type, args.tokenizer_path)
@@ -222,7 +223,7 @@ def main(args):
     parser.add_argument(
         "--warmup_pct",
         type=float,
-        default=0.10,
+        default=0.20,
         help="percentage of total training steps to use for warmup",
     )
     parser.add_argument(
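(Side note, arithmetic not stated in the commit: with run_llama_train.sh's default of --steps 10, raising warmup_pct from 0.10 to 0.20 moves the warmup window from roughly 1 step to roughly 2 steps.)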
