@@ -96,16 +96,11 @@ def init_distributed(job_config):
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
 
 
-def get_num_params(model: torch.nn.Module, only_trainable: bool = False) -> int:
-    """
-    Get the total model params
-    Args: only_trainable: whether to only count trainable params
-    """
-    param_list = list(model.parameters())
-    if only_trainable:
-        param_list = [p for p in param_list if p.requires_grad]
-    # unique_params = {p.data_ptr(): p for p in param_list}.values()
-    return sum(p.numel() for p in param_list)
+def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
+    num_params = sum(p.numel() for p in model.parameters())
+    if exclude_embedding:
+        num_params -= model.tok_embeddings.weight.numel()
+    return num_params
 
 
 def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
@@ -115,7 +110,14 @@ def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
         model_config.dim // model_config.n_heads,
         seq_len,
     )
+    # Reasoning behind the factor of 12 for the self-attention part of the formula:
+    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
+    # 2. the flash attention does 1 more matmul recomputation in the backward
+    #    but recomputation should not be counted in calculating MFU (+0)
+    # 3. each matmul performs 1 multiplication and 1 addition (*2)
+    # 4. we follow the convention and do not account for sparsity in causal attention
     flop_per_token = 6 * num_params + 12 * l * h * q * t
+
     return flop_per_token
 
 
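For reference only (not part of this commit), below is a minimal sketch of how the two helpers above are typically combined to estimate model FLOPS utilization (MFU). The function name estimate_mfu and the arguments tokens_per_second and peak_flops_per_sec are illustrative assumptions; 312e12 is the published bf16 peak of an A100, used here purely as an example value.

# Illustrative sketch, assuming get_num_params and get_num_flop_per_token
# from the patched module are importable; all names below are hypothetical.
def estimate_mfu(model, model_config, seq_len, tokens_per_second, peak_flops_per_sec):
    # Excluding the token embedding follows the common MFU convention:
    # an embedding lookup is not a matmul, so it does not contribute
    # 6 FLOPs per parameter per token to the 6 * num_params term.
    num_params = get_num_params(model, exclude_embedding=True)
    flop_per_token = get_num_flop_per_token(num_params, model_config, seq_len)
    # Achieved FLOPS divided by the hardware peak gives MFU in [0, 1].
    return flop_per_token * tokens_per_second / peak_flops_per_sec

# Example with assumed values:
# mfu = estimate_mfu(model, model_config, seq_len=2048,
#                    tokens_per_second=measured_tps,
#                    peak_flops_per_sec=312e12)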