
Commit 78878f5

add configurable unique layer init, clean up lr and loss display (#64)
Small PR:

1. Add a configurable init style in model_args: when the new `depth_init` flag is set, the init stddev denominator uses each block's own `layer_id`; otherwise the original style based on the total layer count is used. (Both verified to work on 7B Llama; it is not yet clear whether one is better than the other.)
2. Clean up lr display formatting: the lr was printed to 12+ digits, which isn't very informative, and was wrapped in list brackets. This PR rounds the displayed value to at most 8 digits of precision and drops the brackets. (This is purely UI; the full float precision is still used in the actual lr calculations.)
3. Clean up loss display: the loss is rounded to 4 digits of precision to make it more readable and informative.

Previously:
<img width="1198" alt="Screenshot 2024-02-16 at 2 33 34 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/77733af0-42db-4fab-a047-fccc7d404278">

Now:
<img width="1063" alt="Screenshot 2024-02-16 at 2 51 53 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/4eb75b98-67f4-41ec-83d8-dd84a0e8b29e">
1 parent 70be86e commit 78878f5
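As a quick illustration of what the new flag toggles, here is a minimal usage sketch (the import path is inferred from the file path below, and constructing `ModelArgs` directly with keyword arguments is an assumption; the field names come from the diff):

```python
from torchtrain.models.llama.model import ModelArgs  # module path assumed from the diff's file name

# depth_init=True: each block's init stddev uses its own (layer_id + 1) in the denominator
# depth_init=False: every block uses the total layer count (the original init style)
args = ModelArgs(n_layers=32, depth_init=True)
```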

2 files changed: +17 -7 lines

torchtrain/models/llama/model.py
Lines changed: 8 additions & 1 deletion

@@ -24,6 +24,9 @@ class ModelArgs:
     max_batch_size: int = 32
     max_seq_len: int = 32768
+    depth_init: bool = (
+        True  # initialization uses each unique layer_id or total model layer count
+    )


 class RMSNorm(torch.nn.Module):

@@ -392,7 +395,11 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
         self.num_layers = model_args.n_layers
         self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
-        self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
+
+        if model_args.depth_init:
+            self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
+        else:
+            self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5

     def forward(
         self,
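For a rough sense of the difference, here is a small sketch that mirrors the two formulas in the diff above (the `init_std` helper is hypothetical, written only to show the resulting numbers; it is not part of the PR):

```python
# Hypothetical helper mirroring the two init-stddev formulas from the diff above.
def init_std(layer_id: int, n_layers: int, depth_init: bool) -> float:
    if depth_init:
        # depth-based: deeper blocks get a progressively smaller init stddev
        return 0.02 / (2 * (layer_id + 1)) ** 0.5
    # original style: one shared stddev based on the total layer count
    return 0.02 / (2 * n_layers) ** 0.5

# For a 32-layer model, depth_init gives ~0.0141 at layer 0 and 0.0025 at layer 31,
# while the original style gives 0.0025 for every layer.
for layer_id in (0, 15, 31):
    print(layer_id, init_std(layer_id, 32, True), init_std(layer_id, 32, False))
```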

train.py
Lines changed: 9 additions & 6 deletions

@@ -207,12 +207,14 @@ def main(args):

         # log metrics
         if (train_state.step - 1) % args.log_freq == 0:
-            avg_loss, max_loss = np.mean(losses_since_last_log), np.max(
-                losses_since_last_log
+            avg_loss, max_loss = (
+                np.mean(losses_since_last_log),
+                np.max(losses_since_last_log),
+            )
+            global_avg_loss, global_max_loss = (
+                dist_mean(avg_loss, world_mesh),
+                dist_max(max_loss, world_mesh),
             )
-            global_avg_loss, global_max_loss = dist_mean(
-                avg_loss, world_mesh
-            ), dist_max(max_loss, world_mesh)

             time_delta = timer() - time_last_log
             wps = nwords_since_last_log / (

@@ -239,7 +241,8 @@ def main(args):
             time_last_log = timer()

         rank0_log(
-            f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
+            f"step: {train_state.step}, current loss: {round(train_state.current_loss,4)},"
+            f" lr: {round(float(scheduler.get_last_lr()[0]), 8)}"
         )
         scheduler.step()
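To show what the formatting change does to the printed line, here is a small standalone sketch (the numeric values are made up; `get_last_lr()` returning a one-element list matches how the diff indexes it with `[0]`):

```python
current_loss = 10.456732749938965   # stand-in for train_state.current_loss
last_lr = [8.000000000000001e-05]   # stand-in for scheduler.get_last_lr()

# before: full float precision, and the lr printed as a list
print(f"step: 10, current loss: {current_loss}, lr: {last_lr}")
# -> step: 10, current loss: 10.456732749938965, lr: [8.000000000000001e-05]

# after: loss rounded to 4 digits, lr unwrapped from the list and rounded to 8 digits
print(f"step: 10, current loss: {round(current_loss, 4)}, lr: {round(float(last_lr[0]), 8)}")
# -> step: 10, current loss: 10.4567, lr: 8e-05
```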
