diff --git a/test_runner.py b/test_runner.py index a7c95ce1e6..4773058657 100755 --- a/test_runner.py +++ b/test_runner.py @@ -46,6 +46,96 @@ def build_test_list(): """ integration_tests_flavors = defaultdict(list) integration_tests_flavors["debug_model.toml"] = [ + OverrideDefinitions( + [ + [], + ], + "default", + "default", + ), + OverrideDefinitions( + [ + [ + "--training.compile", + ], + ], + "1D compile", + "1d_compile", + ), + OverrideDefinitions( + [ + [ + "--training.compile", + "--activation_checkpoint.mode selective", + "--activation_checkpoint.selective_ac_option op", + ], + ], + "1D compile with selective op AC", + "1d_compile_sac_op", + ), + OverrideDefinitions( + [ + [ + "--training.tensor_parallel_degree 2", + ], + ], + "2D eager", + "2d_eager", + ), + OverrideDefinitions( + [ + [ + "--training.compile", + "--training.tensor_parallel_degree 2", + ], + ], + "2D compile", + "2d_compile", + ), + OverrideDefinitions( + [ + [ + "--training.tensor_parallel_degree 2", + "--model.norm_type=fused_rmsnorm", + ], + ], + "2D eager with fused_rmsnorm", + "2d_eager_fused_rmsnorm", + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + ], + [ + "--checkpoint.enable_checkpoint", + "--training.steps 20", + ], + ], + "Checkpoint Integration Test - Save Load Full Checkpoint", + "full_checkpoint", + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + "--checkpoint.model_weights_only", + ], + ], + "Checkpoint Integration Test - Save Model Weights Only fp32", + "model_weights_only_fp32", + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + "--checkpoint.model_weights_only", + "--checkpoint.export_dtype bfloat16", + ], + ], + "Checkpoint Integration Test - Save Model Weights Only bf16", + "model_weights_only_bf16", + ), OverrideDefinitions( [ [ @@ -53,7 +143,6 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 4", "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7", "--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b", - "--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp ], ], "PP looped flexible 1f1b test", @@ -69,7 +158,6 @@ def build_test_list(): "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule 1f1b", "--training.data_parallel_degree 1", - "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP ], ], "PP 1D test 1f1b", @@ -85,7 +173,6 @@ def build_test_list(): "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule gpipe", "--training.data_parallel_degree 1", - "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP ], ], "PP 1D test gpipe", @@ -101,7 +188,6 @@ def build_test_list(): "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule 1f1b", "--training.data_parallel_degree 2", - "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP ], ], "PP+DP 1f1b 2D test", @@ -116,7 +202,6 @@ def build_test_list(): "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule gpipe", "--training.data_parallel_degree 2", - "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP ], ], "PP+DP gpipe 2D test", @@ -130,7 +215,6 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", "--training.tensor_parallel_degree 2", - "--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP ], ], "PP+TP 2D test", @@ -144,7 +228,6 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_split_mode tracer", - "--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer ], ], "PP tracer frontend test", @@ -152,94 +235,6 @@ def build_test_list(): requires_seed_checkpoint=True, ngpu=2, ), - OverrideDefinitions( - [ - [], - ], - "default", - "default", - ), - OverrideDefinitions( - [ - [ - "--training.compile --model.norm_type=rmsnorm", - ], - ], - "1D compile", - "1d_compile", - ), - OverrideDefinitions( - [ - [ - "--training.compile", - "--activation_checkpoint.mode selective", - "--activation_checkpoint.selective_ac_option op", - ], - ], - "1D compile with selective op AC", - "1d_compile_sac_op", - ), - OverrideDefinitions( - [ - [ - "--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm", - ], - ], - "2D compile", - "2d_compile", - ), - OverrideDefinitions( - [ - [ - "--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm", - ], - ], - "Eager mode 2DParallel with rmsnorm", - "eager_2d_rmsnorm", - ), - OverrideDefinitions( - [ - [ - "--training.tensor_parallel_degree 2 --model.norm_type=fused_rmsnorm", - ], - ], - "Eager mode 2DParallel with fused_rmsnorm", - "eager_2d_fused_rmsnorm", - ), - OverrideDefinitions( - [ - [ - "--checkpoint.enable_checkpoint", - ], - [ - "--checkpoint.enable_checkpoint", - "--training.steps 20", - ], - ], - "Checkpoint Integration Test - Save Load Full Checkpoint", - "full_checkpoint", - ), - OverrideDefinitions( - [ - [ - "--checkpoint.enable_checkpoint", - "--checkpoint.model_weights_only", - ], - ], - "Checkpoint Integration Test - Save Model Weights Only fp32", - "model_weights_only_fp32", - ), - OverrideDefinitions( - [ - [ - "--checkpoint.enable_checkpoint", - "--checkpoint.model_weights_only", - "--checkpoint.export_dtype bfloat16", - ], - ], - "Checkpoint Integration Test - Save Model Weights Only bf16", - "model_weights_only_bf16", - ), OverrideDefinitions( [ [ @@ -248,7 +243,6 @@ def build_test_list(): "--experimental.pipeline_parallel_split_points layers.4", "--training.data_parallel_degree 2", "--training.tensor_parallel_degree 2", - "--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP ], [ "--training.steps 20", @@ -257,7 +251,6 @@ def build_test_list(): "--experimental.pipeline_parallel_split_points layers.4", "--training.data_parallel_degree 2", "--training.tensor_parallel_degree 2", - "--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP ], ], "PP+DP+TP 3D test with save/load resume ckpt", @@ -272,7 +265,6 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 4", "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7", "--experimental.pipeline_parallel_schedule interleaved_1f1b", - "--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp ], ], "PP looped 1f1b test", @@ -292,21 +284,21 @@ def build_test_list(): OverrideDefinitions( [ [ - "--memory_estimation.enabled --model.norm_type rmsnorm", + "--training.data_parallel_type ddp", ] ], - "FSDP2 Memory Tracking and Estimation", - "fsdp2_mem_tracker", + "DDP", + "ddp", ngpu=4, ), OverrideDefinitions( [ [ - "--training.data_parallel_type ddp", + "--memory_estimation.enabled", ] ], - "DDP", - "ddp", + "FSDP2 Memory Tracking and Estimation", + "fsdp2_mem_tracker", ngpu=4, ), ] diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 2bc37bfbf0..56080e0d51 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -165,7 +165,7 @@ def __init__(self): "--model.norm_type", type=str, default="rmsnorm", - help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, compiled_rmsnorm, fused_rmsnorm]", + help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]", ) self.parser.add_argument( "--model.tokenizer_path", diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py index 798c7c4dbb..266453301f 100644 --- a/torchtitan/models/norms.py +++ b/torchtitan/models/norms.py @@ -24,7 +24,7 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6): Args: norm_type (str): The type of normalization layer to build. - Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm + Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm dim (int): The dimension of the normalization layer. eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. @@ -42,13 +42,6 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6): return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) elif norm_type == "rmsnorm": return RMSNorm(dim, eps=eps) - elif norm_type == "compiled_rmsnorm": - import warnings - - warnings.warn( - "compiled_rmsnorm is currently experimental and not ready to use yet." - ) - return RMSNorm(dim, eps=eps, compile=True) elif norm_type == "fused_rmsnorm": return FusedRMSNorm(dim, eps=eps) else: @@ -94,26 +87,17 @@ class RMSNorm(nn.Module): """ - def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False): + def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) - self.rmsnorm_fn = ( - torch.compile(self.compute_rmsnorm, fullgraph=True) - if compile - else self.compute_rmsnorm - ) - - @staticmethod - def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float): - def _norm(x, eps): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - output = _norm(x.float(), eps).type_as(x) - return output * weight + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x: torch.Tensor): - return self.rmsnorm_fn(x, self.weight, self.eps) + output = self._norm(x.float()).type_as(x) + return output * self.weight def reset_parameters(self): torch.nn.init.ones_(self.weight) # type: ignore diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index e428060c80..af54721481 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -21,7 +21,7 @@ save_tb_folder = "tb" [model] name = "llama3" flavor = "debugmodel" -norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm # test tokenizer.model, for debug purpose only tokenizer_path = "./test/assets/test_tiktoken.model"