def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
    """Estimate the number of floating-point operations per token.

    Args:
        num_params: Parameter count of the model, as supplied by the caller.
        model_config: Object exposing the ``n_layers``, ``n_heads``, ``dim``
            and ``vocab_size`` attributes.
        seq_len: Sequence length used in the attention term.

    Returns:
        ``6 * (num_params - vocab_size * dim)
        + 7 * n_layers * n_heads * head_dim * seq_len``,
        where ``head_dim = dim // n_heads``.
    """
    n_layers = model_config.n_layers
    n_heads = model_config.n_heads
    dim = model_config.dim
    vocab_size = model_config.vocab_size
    # Floor division: any remainder of dim / n_heads is dropped.
    head_dim = dim // n_heads

    # vocab_size * dim parameters are excluded from the parameter term.
    # NOTE(review): presumably this is the token-embedding table, whose lookup
    # is not a matmul -- confirm against the model definition.
    param_flop = 6 * (num_params - vocab_size * dim)

    # NOTE(review): the coefficient 7 (previously 12) is taken as-is from the
    # patch; its derivation is not visible here -- confirm with the author.
    attention_flop = 7 * n_layers * n_heads * head_dim * seq_len

    return param_flop + attention_flop