From 1581154c18171d50cf8cb6ff4a3f9c62f3979078 Mon Sep 17 00:00:00 2001
From: Kevin Yin
Date: Thu, 25 Apr 2024 17:01:39 -0700
Subject: [PATCH] Change MFU calculation

Uses the attention FLOP count from Flash Attention's benchmark, which applies
a factor of 7 instead of 12:
https://github.com/Dao-AILab/flash-attention/blob/23e8fa5a263d1c7122bc46a86ef32030ee7130f9/benchmarks/benchmark_flash_attention.py#L27

Excludes the vocab embedding parameters from the FLOP count.

Did not test.
---
 torchtitan/utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index 689335c4ef..70df32e2ce 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -109,13 +109,15 @@ def get_num_params(model: torch.nn.Module, only_trainable: bool = False) -> int:
 def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
-    l, h, q, t = (
+    l, h, q, v, d, t = (
         model_config.n_layers,
         model_config.n_heads,
         model_config.dim // model_config.n_heads,
+        model_config.vocab_size,
+        model_config.dim,
         seq_len,
     )
-    flop_per_token = 6 * num_params + 12 * l * h * q * t
+    flop_per_token = 6 * (num_params - v * d) + 7 * l * h * q * t
 
     return flop_per_token
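
Note (not part of the patch): below is a minimal sketch of how the function would read after this change, with comments unpacking where the factor of 7 and the embedding subtraction come from, plus an illustrative sanity check. The Cfg dataclass and the parameter/sequence-length numbers are made up for illustration; only the formula itself comes from the diff above, and the factor-of-7 breakdown reflects my reading of the linked benchmark.

from dataclasses import dataclass

def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
    l, h, q, v, d, t = (
        model_config.n_layers,
        model_config.n_heads,
        model_config.dim // model_config.n_heads,  # head dim
        model_config.vocab_size,
        model_config.dim,
        seq_len,
    )
    # 6 * (num_params - v * d): 2 FLOPs per non-embedding weight per token for
    # each of the forward, input-gradient, and weight-gradient matmuls.
    # 7 * l * h * q * t: attention score and output matmuls, counted as in the
    # linked Flash Attention benchmark (4 * t * h * q forward FLOPs per token
    # per layer, halved for causal masking, times 3.5 for forward + backward).
    flop_per_token = 6 * (num_params - v * d) + 7 * l * h * q * t
    return flop_per_token

@dataclass
class Cfg:  # hypothetical config, roughly Llama-2-7B shaped
    n_layers: int = 32
    n_heads: int = 32
    dim: int = 4096
    vocab_size: int = 32000

if __name__ == "__main__":
    # ~6.7e9 params and a 4096-token sequence are illustrative values only
    print(get_num_flop_per_token(6_700_000_000, Cfg(), 4096))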