
Commit 26f57c2

Update SMP v2 shared_scripts
1 parent 37a6874 commit 26f57c2

3 files changed (+35 −11 lines)

Lines changed: 3 additions & 3 deletions

@@ -1,9 +1,9 @@
 accelerate>=0.12.0
-datasets>=2.16.1
+datasets>=2.19.1
 einops
 evaluate
 expecttest
-flash-attn>=2.3.6
+flash-attn>=2.3.6,<2.4
 h5py
 humanize
 hypothesis
@@ -14,4 +14,4 @@ protobuf
 scikit-learn
 sentencepiece!=0.1.92
 tensorboard
-transformers>=4.37.1
+transformers>=4.40.1
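
The tightened pins above can be checked against an existing environment with a small standalone sketch (not part of this commit; it assumes the packaging library is installed, and the distribution-name lookups are an assumption that may need normalizing on older Python versions):

# Hedged sketch, not from the repo: verify installed versions against the pins above.
from importlib.metadata import version
from packaging.specifiers import SpecifierSet

pins = {
    "datasets": ">=2.19.1",
    "flash-attn": ">=2.3.6,<2.4",   # lookup may need "flash_attn" on older Pythons
    "transformers": ">=4.40.1",
}
for name, spec in pins.items():
    installed = version(name)  # raises PackageNotFoundError if the package is absent
    status = "ok" if installed in SpecifierSet(spec) else f"violates {spec}"
    print(f"{name} {installed}: {status}")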

training/distributed_training/pytorch/model_parallel_v2/shared-scripts/train_lib.py

Lines changed: 1 addition & 1 deletion

@@ -397,7 +397,7 @@ def main(args):
         len(args.num_kept_checkpoints),
     )
     if len(set(ckpt_lens)) != 1:
-        raise ValueError(f"Len mismtach for checkpoint dir, freq vs num to keep: {ckpt_lens}.")
+        raise ValueError(f"Len mismatch for checkpoint dir, freq vs num to keep: {ckpt_lens}.")
 
     if args.distributed_backend == "smddp":
         import smdistributed.dataparallel.torch.torch_smddp  # pylint: disable=unused-import
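
For reference, a minimal standalone sketch of the validation this error message belongs to; the values below are hypothetical, and only the tuple-of-lengths check mirrors the hunk above:

# Hypothetical inputs: one checkpoint dir but two save frequencies -> mismatch.
checkpoint_dir = ["/opt/ml/checkpoints"]
checkpoint_freq = [10, 100]
num_kept_checkpoints = [2]

ckpt_lens = (len(checkpoint_dir), len(checkpoint_freq), len(num_kept_checkpoints))
if len(set(ckpt_lens)) != 1:
    raise ValueError(f"Len mismatch for checkpoint dir, freq vs num to keep: {ckpt_lens}.")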

training/distributed_training/pytorch/model_parallel_v2/shared-scripts/train_utils.py

Lines changed: 31 additions & 7 deletions

@@ -34,11 +34,22 @@ def compute_num_params(model):
 
 
 def compute_tflops(args, global_batch_size, step_time, world_size):
-    # Based on
+    # Based on
     # https://github.com/NVIDIA/Megatron-LM/blob/ba773259dbe5735fbd91ca41e7f4ded60b335c52/megatron/training/training.py#L65
-    num_experts_routed_to = 1 if args.moe > 1 else args.num_experts_per_tok
-    if args.num_key_value_heads is None:
+    # Attention projection size.
+    kv_channels = args.hidden_width // args.num_heads
+    query_projection_size = kv_channels * args.num_heads
+    query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_width
+
+    # Group Query Attention.
+    if not args.num_key_value_heads:
         args.num_key_value_heads = args.num_heads
+
+    # MoE.
+    num_experts_routed_to = 1 if args.moe == 0 else args.num_experts_per_tok
+    gated_linear_multiplier = 3/2 if args.moe > 0 else 1
+
+    # Compute the number of floating point operations
     num_flops = (
         12
         * global_batch_size
@@ -47,13 +58,26 @@ def compute_tflops(args, global_batch_size, step_time, world_size):
         * args.hidden_width
         * args.hidden_width
         * (
-            1
-            + ((args.intermediate_size / args.hidden_width) * num_experts_routed_to)
-            + (args.num_key_value_heads / args.num_heads)
-            + (args.max_context_width / args.hidden_width)
+            # Attention.
+            (
+                (
+                    1
+                    + (args.num_key_value_heads / args.num_heads)
+                    + (args.max_context_width / args.hidden_width)
+                ) * query_projection_to_hidden_size_ratio
+            )
+            # MLP.
+            + (
+                (args.intermediate_size / args.hidden_width)
+                * num_experts_routed_to
+                * gated_linear_multiplier
+            )
+            # Logit.
             + (args.vocab_size / (2 * args.num_layers * args.hidden_width))
         )
     )
+
+    # Convert to TFLOPs per GPU
     tflops_per_gpu = num_flops / (
         step_time * 10**12 * world_size)
     return tflops_per_gpu
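
The updated formula can be sanity-checked in isolation. Below is a minimal self-contained sketch of the same arithmetic for a dense (non-MoE) model; the hyperparameters are hypothetical Llama-7B-style values, and the sequence-length and layer-count factors (which fall outside the hunks shown) are assumed to follow the referenced Megatron-LM formula:

def estimate_tflops_per_gpu(
    global_batch_size=256,
    step_time=10.0,              # seconds per training step (assumed)
    world_size=64,               # number of GPUs (assumed)
    hidden_width=4096,
    num_layers=32,
    num_heads=32,
    num_key_value_heads=None,    # None -> plain multi-head attention
    intermediate_size=11008,
    vocab_size=32000,
    max_context_width=4096,
    moe=0,                       # 0 = dense model, >0 = number of experts
    num_experts_per_tok=2,
):
    # Attention projection size.
    kv_channels = hidden_width // num_heads
    query_projection_size = kv_channels * num_heads
    query_projection_to_hidden_size_ratio = query_projection_size / hidden_width

    # Group Query Attention: fall back to MHA when no KV-head count is given.
    if not num_key_value_heads:
        num_key_value_heads = num_heads

    # MoE: a dense model routes every token through its single MLP;
    # gated MLPs in the MoE case add a 3/2 multiplier, as in the diff above.
    num_experts_routed_to = 1 if moe == 0 else num_experts_per_tok
    gated_linear_multiplier = 3 / 2 if moe > 0 else 1

    num_flops = (
        12
        * global_batch_size
        * max_context_width
        * num_layers
        * hidden_width
        * hidden_width
        * (
            # Attention.
            (
                1
                + (num_key_value_heads / num_heads)
                + (max_context_width / hidden_width)
            ) * query_projection_to_hidden_size_ratio
            # MLP.
            + (intermediate_size / hidden_width)
            * num_experts_routed_to
            * gated_linear_multiplier
            # Logit.
            + vocab_size / (2 * num_layers * hidden_width)
        )
    )
    # Convert to TFLOPs per GPU.
    return num_flops / (step_time * 10**12 * world_size)

print(f"{estimate_tflops_per_gpu():.1f} model TFLOPs per GPU")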
