200 changes: 96 additions & 104 deletions test_runner.py
@@ -46,14 +46,103 @@ def build_test_list():
"""
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
[],
],
"default",
"default",
),
OverrideDefinitions(
[
[
"--training.compile",
],
],
"1D compile",
"1d_compile",
),
OverrideDefinitions(
[
[
"--training.compile",
"--activation_checkpoint.mode selective",
"--activation_checkpoint.selective_ac_option op",
],
],
"1D compile with selective op AC",
"1d_compile_sac_op",
),
OverrideDefinitions(
[
[
"--training.tensor_parallel_degree 2",
],
],
"2D eager",
"2d_eager",
),
OverrideDefinitions(
[
[
"--training.compile",
"--training.tensor_parallel_degree 2",
],
],
"2D compile",
"2d_compile",
),
OverrideDefinitions(
[
[
"--training.tensor_parallel_degree 2",
"--model.norm_type=fused_rmsnorm",
],
],
"2D eager with fused_rmsnorm",
"2d_eager_fused_rmsnorm",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
],
[
"--checkpoint.enable_checkpoint",
"--training.steps 20",
],
],
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.model_weights_only",
],
],
"Checkpoint Integration Test - Save Model Weights Only fp32",
"model_weights_only_fp32",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.model_weights_only",
"--checkpoint.export_dtype bfloat16",
],
],
"Checkpoint Integration Test - Save Model Weights Only bf16",
"model_weights_only_bf16",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
"--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
],
],
"PP looped flexible 1f1b test",
@@ -69,7 +158,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
],
],
"PP 1D test 1f1b",
@@ -85,7 +173,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
],
],
"PP 1D test gpipe",
@@ -101,7 +188,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
],
],
"PP+DP 1f1b 2D test",
@@ -116,7 +202,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
],
],
"PP+DP gpipe 2D test",
@@ -130,7 +215,6 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
],
],
"PP+TP 2D test",
@@ -144,102 +228,13 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_split_mode tracer",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer
],
],
"PP tracer frontend test",
"pp_tracer",
requires_seed_checkpoint=True,
ngpu=2,
),
- OverrideDefinitions(
- [
- [],
- ],
- "default",
- "default",
- ),
- OverrideDefinitions(
- [
- [
- "--training.compile --model.norm_type=rmsnorm",
- ],
- ],
- "1D compile",
- "1d_compile",
- ),
- OverrideDefinitions(
- [
- [
- "--training.compile",
- "--activation_checkpoint.mode selective",
- "--activation_checkpoint.selective_ac_option op",
- ],
- ],
- "1D compile with selective op AC",
- "1d_compile_sac_op",
- ),
- OverrideDefinitions(
- [
- [
- "--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
- ],
- ],
- "2D compile",
- "2d_compile",
- ),
- OverrideDefinitions(
- [
- [
- "--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
- ],
- ],
- "Eager mode 2DParallel with rmsnorm",
- "eager_2d_rmsnorm",
- ),
- OverrideDefinitions(
- [
- [
- "--training.tensor_parallel_degree 2 --model.norm_type=fused_rmsnorm",
- ],
- ],
- "Eager mode 2DParallel with fused_rmsnorm",
- "eager_2d_fused_rmsnorm",
- ),
- OverrideDefinitions(
- [
- [
- "--checkpoint.enable_checkpoint",
- ],
- [
- "--checkpoint.enable_checkpoint",
- "--training.steps 20",
- ],
- ],
- "Checkpoint Integration Test - Save Load Full Checkpoint",
- "full_checkpoint",
- ),
- OverrideDefinitions(
- [
- [
- "--checkpoint.enable_checkpoint",
- "--checkpoint.model_weights_only",
- ],
- ],
- "Checkpoint Integration Test - Save Model Weights Only fp32",
- "model_weights_only_fp32",
- ),
- OverrideDefinitions(
- [
- [
- "--checkpoint.enable_checkpoint",
- "--checkpoint.model_weights_only",
- "--checkpoint.export_dtype bfloat16",
- ],
- ],
- "Checkpoint Integration Test - Save Model Weights Only bf16",
- "model_weights_only_bf16",
- ),
OverrideDefinitions(
[
[
@@ -248,7 +243,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
],
[
"--training.steps 20",
Expand All @@ -257,7 +251,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
],
],
"PP+DP+TP 3D test with save/load resume ckpt",
@@ -272,7 +265,6 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule interleaved_1f1b",
"--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
],
],
"PP looped 1f1b test",
@@ -292,21 +284,21 @@ def build_test_list():
OverrideDefinitions(
[
[
- "--memory_estimation.enabled --model.norm_type rmsnorm",
+ "--training.data_parallel_type ddp",
]
],
- "FSDP2 Memory Tracking and Estimation",
- "fsdp2_mem_tracker",
+ "DDP",
+ "ddp",
ngpu=4,
),
OverrideDefinitions(
[
[
- "--training.data_parallel_type ddp",
+ "--memory_estimation.enabled",
]
],
- "DDP",
- "ddp",
+ "FSDP2 Memory Tracking and Estimation",
+ "fsdp2_mem_tracker",
ngpu=4,
),
]
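For context on how these entries are consumed, here is a minimal sketch of the pattern; the OverrideDefinitions dataclass fields and the launcher script name are assumptions, since neither is defined in this diff:

from dataclasses import dataclass, field
from typing import List

@dataclass
class OverrideDefinitions:
    # Hypothetical reconstruction: each inner list holds the extra CLI flags
    # for one training run; a test may launch several runs (e.g. save, then load).
    override_args: List[List[str]] = field(default_factory=lambda: [[]])
    test_descr: str = "default"
    test_name: str = "default"
    requires_seed_checkpoint: bool = False
    ngpu: int = 4

def build_commands(config_file: str, test: OverrideDefinitions) -> List[str]:
    # "./run_llama_train.sh" is an assumed launcher name, not taken from this diff.
    return [
        " ".join(["./run_llama_train.sh", f"--job.config_file={config_file}", *overrides])
        for overrides in test.override_args
    ]

Under this reading, the "full_checkpoint" entry above expands to two runs: one that saves a checkpoint at the default step count, and a second with "--training.steps 20" that resumes from it.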
2 changes: 1 addition & 1 deletion torchtitan/config_manager.py
@@ -165,7 +165,7 @@ def __init__(self):
"--model.norm_type",
type=str,
default="rmsnorm",
help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, compiled_rmsnorm, fused_rmsnorm]",
help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]",
)
self.parser.add_argument(
"--model.tokenizer_path",
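With compiled_rmsnorm removed from the help text, the documented values now match the branches of build_norm below exactly. A stricter variant (an illustration only, not something this PR does) would let argparse reject unknown values at parse time via choices:

# Hypothetical tightening, not part of this PR: argparse rejects unknown
# norm types up front instead of deferring to build_norm's error path.
self.parser.add_argument(
    "--model.norm_type",
    type=str,
    default="rmsnorm",
    choices=["layernorm", "np_layernorm", "rmsnorm", "fused_rmsnorm"],
    help="Type of layer normalization to use",
)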
28 changes: 6 additions & 22 deletions torchtitan/models/norms.py
@@ -24,7 +24,7 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):

Args:
norm_type (str): The type of normalization layer to build.
- Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm
+ Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm
dim (int): The dimension of the normalization layer.
eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.

@@ -42,13 +42,6 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
elif norm_type == "rmsnorm":
return RMSNorm(dim, eps=eps)
elif norm_type == "compiled_rmsnorm":
import warnings

warnings.warn(
"compiled_rmsnorm is currently experimental and not ready to use yet."
)
return RMSNorm(dim, eps=eps, compile=True)
elif norm_type == "fused_rmsnorm":
return FusedRMSNorm(dim, eps=eps)
else:
@@ -94,26 +87,17 @@ class RMSNorm(nn.Module):

"""

- def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False):
+ def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
- self.rmsnorm_fn = (
- torch.compile(self.compute_rmsnorm, fullgraph=True)
- if compile
- else self.compute_rmsnorm
- )
-
- @staticmethod
- def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float):
- def _norm(x, eps):
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
-
- output = _norm(x.float(), eps).type_as(x)
- return output * weight
+ def _norm(self, x: torch.Tensor):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x: torch.Tensor):
- return self.rmsnorm_fn(x, self.weight, self.eps)
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight

def reset_parameters(self):
torch.nn.init.ones_(self.weight) # type: ignore
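As a quick sanity check of the simplified RMSNorm, a usage sketch follows; the import path mirrors this diff's file layout, and the exact-match assertion assumes float32 inputs (where the internal x.float() upcast is a no-op):

import torch
from torchtitan.models.norms import build_norm

norm = build_norm("rmsnorm", dim=8)  # returns RMSNorm(8, eps=1e-6)
x = torch.randn(2, 4, 8)
y = norm(x)

# Reference formula: x * rsqrt(mean(x^2 over the last dim) + eps), scaled by weight.
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
assert torch.allclose(y, ref)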
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -21,7 +21,7 @@ save_tb_folder = "tb"
[model]
name = "llama3"
flavor = "debugmodel"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
# test tokenizer.model, for debug purpose only
tokenizer_path = "./test/assets/test_tiktoken.model"
