
Commit 2283df7

Update SMP v2 notebooks to use latest PyTorch 2.3.1, TSM 2.4.0 release (#4678)
* Update SMP v2 notebooks to use the latest PT 2.3.1, TSM 2.4.0 release.
* Update SMP v2 shared_scripts.
* Update minimum SageMaker Python SDK version to 2.224.
1 parent a75d5f2 commit 2283df7

File tree: 8 files changed, +54 / -30 lines changed

training/distributed_training/pytorch/model_parallel_v2/gpt-neox/smp-finetuning-gpt-neox-fsdp-tp.ipynb

Lines changed: 3 additions & 3 deletions
@@ -80,7 +80,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install --upgrade \"sagemaker>=2.212\"\n",
+"%pip install --upgrade \"sagemaker>=2.224\"\n",
 "%pip install sagemaker-experiments"
 ]
 },
@@ -882,8 +882,8 @@
 " }\n",
 " },\n",
 " },\n",
-" py_version=\"py310\",\n",
-" framework_version=\"2.2.0\",\n",
+" py_version=\"py311\",\n",
+" framework_version=\"2.3.1\",\n",
 " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
 " output_path=s3_output_bucket,\n",
 " max_run=86400,\n",

training/distributed_training/pytorch/model_parallel_v2/gpt-neox/smp-train-gpt-neox-fsdp-tp.ipynb

Lines changed: 5 additions & 5 deletions
@@ -74,7 +74,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install --upgrade \"sagemaker>=2.212\"\n",
+"%pip install --upgrade \"sagemaker>=2.224\"\n",
 "%pip install sagemaker-experiments"
 ]
 },
@@ -873,8 +873,8 @@
 " }\n",
 " },\n",
 " },\n",
-" py_version=\"py310\",\n",
-" framework_version=\"2.2.0\",\n",
+" py_version=\"py311\",\n",
+" framework_version=\"2.3.1\",\n",
 " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
 " output_path=s3_output_bucket,\n",
 " max_run=86400,\n",
@@ -955,8 +955,8 @@
 " }\n",
 " },\n",
 " },\n",
-" py_version=\"py310\",\n",
-" framework_version=\"2.2.0\",\n",
+" py_version=\"py311\",\n",
+" framework_version=\"2.3.1\",\n",
 " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
 " output_path=s3_output_bucket,\n",
 " max_run=86400,\n",

training/distributed_training/pytorch/model_parallel_v2/llama_v2/smp-finetuning-llama-fsdp-tp.ipynb

Lines changed: 3 additions & 3 deletions
@@ -80,7 +80,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install --upgrade \"sagemaker>=2.212\"\n",
+"%pip install --upgrade \"sagemaker>=2.224\"\n",
 "%pip install sagemaker-experiments"
 ]
 },
@@ -867,8 +867,8 @@
 " }\n",
 " },\n",
 " },\n",
-" py_version=\"py310\",\n",
-" framework_version=\"2.2.0\",\n",
+" py_version=\"py311\",\n",
+" framework_version=\"2.3.1\",\n",
 " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
 " output_path=s3_output_bucket,\n",
 " max_run=86400,\n",

training/distributed_training/pytorch/model_parallel_v2/llama_v2/smp-train-llama-fsdp-tp-fp8.ipynb

Lines changed: 5 additions & 5 deletions
@@ -74,7 +74,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install --upgrade \"sagemaker>=2.212\"\n",
+"%pip install --upgrade \"sagemaker>=2.224\"\n",
 "%pip install sagemaker-experiments"
 ]
 },
@@ -831,8 +831,8 @@
 " }\n",
 " },\n",
 " },\n",
-" py_version=\"py310\",\n",
-" framework_version=\"2.2.0\",\n",
+" py_version=\"py311\",\n",
+" framework_version=\"2.3.1\",\n",
 " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
 " output_path=s3_output_bucket,\n",
 " max_run=86400,\n",
@@ -913,8 +913,8 @@
 " }\n",
 " },\n",
 " },\n",
-" py_version=\"py310\",\n",
-" framework_version=\"2.2.0\",\n",
+" py_version=\"py311\",\n",
+" framework_version=\"2.3.1\",\n",
 " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
 " output_path=s3_output_bucket,\n",
 " max_run=86400,\n",

training/distributed_training/pytorch/model_parallel_v2/mixtral/smp-train-mixtral-fsdp-ep.ipynb

Lines changed: 3 additions & 3 deletions
@@ -74,7 +74,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install --upgrade \"sagemaker>=2.215\"\n",
+"%pip install --upgrade \"sagemaker>=2.224\"\n",
 "%pip install sagemaker-experiments"
 ]
 },
@@ -916,8 +916,8 @@
 " }\n",
 " },\n",
 " },\n",
-" py_version=\"py310\",\n",
-" framework_version=\"2.2.0\",\n",
+" py_version=\"py311\",\n",
+" framework_version=\"2.3.1\",\n",
 " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
 " output_path=s3_output_bucket,\n",
 " max_run=86400,\n",
Lines changed: 3 additions & 3 deletions
@@ -1,9 +1,9 @@
 accelerate>=0.12.0
-datasets>=2.16.1
+datasets>=2.19.1
 einops
 evaluate
 expecttest
-flash-attn>=2.3.6
+flash-attn>=2.3.6,<2.4
 h5py
 humanize
 hypothesis
@@ -14,4 +14,4 @@ protobuf
 scikit-learn
 sentencepiece!=0.1.92
 tensorboard
-transformers>=4.37.1
+transformers>=4.40.1
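A quick, illustrative sanity check (not part of the shared scripts) that the training environment actually picks up the bumped minimum versions for the two libraries whose pins moved:

# Expect datasets >= 2.19.1 and transformers >= 4.40.1 after this commit.
import datasets
import transformers

print("datasets", datasets.__version__)
print("transformers", transformers.__version__)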

training/distributed_training/pytorch/model_parallel_v2/shared-scripts/train_lib.py

Lines changed: 1 addition & 1 deletion
@@ -397,7 +397,7 @@ def main(args):
         len(args.num_kept_checkpoints),
     )
     if len(set(ckpt_lens)) != 1:
-        raise ValueError(f"Len mismtach for checkpoint dir, freq vs num to keep: {ckpt_lens}.")
+        raise ValueError(f"Len mismatch for checkpoint dir, freq vs num to keep: {ckpt_lens}.")
 
     if args.distributed_backend == "smddp":
         import smdistributed.dataparallel.torch.torch_smddp  # pylint: disable=unused-import
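The change above only fixes a typo in the error message. For context, here is a sketch of the validation it belongs to: `ckpt_lens` is assumed to gather the lengths of three parallel list arguments. Only `num_kept_checkpoints` is visible in the hunk; the other two attribute names are guesses inferred from the error text.

def validate_checkpoint_args(args):
    # The three settings are parallel lists: each checkpoint directory needs
    # a matching save frequency and a matching number of checkpoints to keep.
    ckpt_lens = (
        len(args.checkpoint_dir),        # assumed attribute name
        len(args.checkpoint_freq),       # assumed attribute name
        len(args.num_kept_checkpoints),  # shown in the hunk above
    )
    if len(set(ckpt_lens)) != 1:
        raise ValueError(f"Len mismatch for checkpoint dir, freq vs num to keep: {ckpt_lens}.")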

training/distributed_training/pytorch/model_parallel_v2/shared-scripts/train_utils.py

Lines changed: 31 additions & 7 deletions
@@ -34,11 +34,22 @@ def compute_num_params(model):
 
 
 def compute_tflops(args, global_batch_size, step_time, world_size):
-    # Based on 
+    # Based on
     # https://github.com/NVIDIA/Megatron-LM/blob/ba773259dbe5735fbd91ca41e7f4ded60b335c52/megatron/training/training.py#L65
-    num_experts_routed_to = 1 if args.moe > 1 else args.num_experts_per_tok
-    if args.num_key_value_heads is None:
+    # Attention projection size.
+    kv_channels = args.hidden_width // args.num_heads
+    query_projection_size = kv_channels * args.num_heads
+    query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_width
+
+    # Group Query Attention.
+    if not args.num_key_value_heads:
         args.num_key_value_heads = args.num_heads
+
+    # MoE.
+    num_experts_routed_to = 1 if args.moe == 0 else args.num_experts_per_tok
+    gated_linear_multiplier = 3/2 if args.moe > 0 else 1
+
+    # Compute the number of floating point operations
     num_flops = (
         12
         * global_batch_size
@@ -47,13 +58,26 @@ def compute_tflops(args, global_batch_size, step_time, world_size):
         * args.hidden_width
         * args.hidden_width
         * (
-            1
-            + ((args.intermediate_size / args.hidden_width) * num_experts_routed_to)
-            + (args.num_key_value_heads / args.num_heads)
-            + (args.max_context_width / args.hidden_width)
+            # Attention.
+            (
+                (
+                    1
+                    + (args.num_key_value_heads / args.num_heads)
+                    + (args.max_context_width / args.hidden_width)
+                ) * query_projection_to_hidden_size_ratio
+            )
+            # MLP.
+            + (
+                (args.intermediate_size / args.hidden_width)
+                * num_experts_routed_to
+                * gated_linear_multiplier
+            )
+            # Logit.
             + (args.vocab_size / (2 * args.num_layers * args.hidden_width))
         )
     )
+
+    # Convert to TFLOPs per GPU
     tflops_per_gpu = num_flops / (
         step_time * 10**12 * world_size)
     return tflops_per_gpu
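Putting the two hunks together, here is a self-contained sketch of the updated `compute_tflops`, followed by an illustrative call. The two factors that fall in the gap between the hunks are assumed, following the Megatron-LM reference, to be the sequence length and layer count (`args.max_context_width` and `args.num_layers`); the example configuration values are made up for demonstration and do not come from the notebooks.

from types import SimpleNamespace


def compute_tflops(args, global_batch_size, step_time, world_size):
    # Attention projection size.
    kv_channels = args.hidden_width // args.num_heads
    query_projection_size = kv_channels * args.num_heads
    query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_width

    # Group Query Attention: fall back to standard multi-head attention if unset.
    if not args.num_key_value_heads:
        args.num_key_value_heads = args.num_heads

    # MoE: a dense model (moe == 0) routes every token to a single FFN.
    num_experts_routed_to = 1 if args.moe == 0 else args.num_experts_per_tok
    gated_linear_multiplier = 3 / 2 if args.moe > 0 else 1

    # Total floating point operations per optimizer step.
    num_flops = (
        12
        * global_batch_size
        * args.max_context_width  # assumed: elided between the hunks
        * args.num_layers         # assumed: elided between the hunks
        * args.hidden_width
        * args.hidden_width
        * (
            # Attention.
            (
                (
                    1
                    + (args.num_key_value_heads / args.num_heads)
                    + (args.max_context_width / args.hidden_width)
                ) * query_projection_to_hidden_size_ratio
            )
            # MLP.
            + (
                (args.intermediate_size / args.hidden_width)
                * num_experts_routed_to
                * gated_linear_multiplier
            )
            # Logit.
            + (args.vocab_size / (2 * args.num_layers * args.hidden_width))
        )
    )

    # Convert to TFLOPs per GPU.
    return num_flops / (step_time * 10**12 * world_size)


# Illustrative 7B-dense-style configuration (values made up for the example).
cfg = SimpleNamespace(
    hidden_width=4096, num_heads=32, num_key_value_heads=None,
    max_context_width=4096, intermediate_size=11008, vocab_size=32000,
    num_layers=32, moe=0, num_experts_per_tok=1,
)
print(f"{compute_tflops(cfg, global_batch_size=64, step_time=2.5, world_size=16):.1f} TFLOPs/GPU")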
