@@ -34,11 +34,22 @@ def compute_num_params(model):
 
 
 def compute_tflops(args, global_batch_size, step_time, world_size):
-    # Based on 
+    # Based on
     # https://github.com/NVIDIA/Megatron-LM/blob/ba773259dbe5735fbd91ca41e7f4ded60b335c52/megatron/training/training.py#L65
-    num_experts_routed_to = 1 if args.moe > 1 else args.num_experts_per_tok
-    if args.num_key_value_heads is None:
+    # Attention projection size.
+    kv_channels = args.hidden_width // args.num_heads
+    query_projection_size = kv_channels * args.num_heads
+    query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_width
+
+    # Group Query Attention.
+    if not args.num_key_value_heads:
         args.num_key_value_heads = args.num_heads
+
+    # MoE.
+    num_experts_routed_to = 1 if args.moe == 0 else args.num_experts_per_tok
+    gated_linear_multiplier = 3 / 2 if args.moe > 0 else 1
+
+    # Compute the number of floating point operations
     num_flops = (
         12
         * global_batch_size
@@ -47,13 +58,26 @@ def compute_tflops(args, global_batch_size, step_time, world_size):
         * args.hidden_width
         * args.hidden_width
         * (
-            1
-            + ((args.intermediate_size / args.hidden_width) * num_experts_routed_to)
-            + (args.num_key_value_heads / args.num_heads)
-            + (args.max_context_width / args.hidden_width)
+            # Attention.
+            (
+                (
+                    1
+                    + (args.num_key_value_heads / args.num_heads)
+                    + (args.max_context_width / args.hidden_width)
+                ) * query_projection_to_hidden_size_ratio
+            )
+            # MLP.
+            + (
+                (args.intermediate_size / args.hidden_width)
+                * num_experts_routed_to
+                * gated_linear_multiplier
+            )
+            # Logit.
             + (args.vocab_size / (2 * args.num_layers * args.hidden_width))
         )
     )
+
+    # Convert to TFLOPs per GPU
     tflops_per_gpu = num_flops / (
         step_time * 10**12 * world_size)
     return tflops_per_gpu
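
As a sanity check on the updated formula, here is a minimal usage sketch. The field names mirror what the function reads from `args`; the concrete values, the `SimpleNamespace` stand-in for the training script's argparse namespace, and the batch/step/world-size numbers are illustrative assumptions, not taken from the commit.

# Hypothetical example: a dense, Llama-style configuration (moe=0).
from types import SimpleNamespace

args = SimpleNamespace(
    num_layers=32,            # transformer blocks
    hidden_width=4096,        # model hidden size
    num_heads=32,             # attention heads
    num_key_value_heads=8,    # GQA KV heads; a falsy value falls back to num_heads
    max_context_width=4096,   # sequence length
    intermediate_size=14336,  # MLP hidden size
    vocab_size=32000,
    moe=0,                    # 0 = dense model, >0 = number of experts
    num_experts_per_tok=2,    # only used when moe > 0
)

# 512 sequences per step, 2.5 s per step, 64 GPUs (all assumed values).
tflops = compute_tflops(args, global_batch_size=512, step_time=2.5, world_size=64)
print(f"{tflops:.1f} TFLOP/s per GPU")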