Skip to content

Commit 66e4e02

Browse files
committed
remove compiled_rmsnorm
[ghstack-poisoned]
1 parent 81c74c7 commit 66e4e02

File tree

4 files changed

+104
-128
lines changed

4 files changed

+104
-128
lines changed

test_runner.py

Lines changed: 96 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,103 @@ def build_test_list():
4646
"""
4747
integration_tests_flavors = defaultdict(list)
4848
integration_tests_flavors["debug_model.toml"] = [
49+
OverrideDefinitions(
50+
[
51+
[],
52+
],
53+
"default",
54+
"default",
55+
),
56+
OverrideDefinitions(
57+
[
58+
[
59+
"--training.compile",
60+
],
61+
],
62+
"1D compile",
63+
"1d_compile",
64+
),
65+
OverrideDefinitions(
66+
[
67+
[
68+
"--training.compile",
69+
"--activation_checkpoint.mode selective",
70+
"--activation_checkpoint.selective_ac_option op",
71+
],
72+
],
73+
"1D compile with selective op AC",
74+
"1d_compile_sac_op",
75+
),
76+
OverrideDefinitions(
77+
[
78+
[
79+
"--training.tensor_parallel_degree 2",
80+
],
81+
],
82+
"2D eager",
83+
"2d_eager",
84+
),
85+
OverrideDefinitions(
86+
[
87+
[
88+
"--training.compile",
89+
"--training.tensor_parallel_degree 2",
90+
],
91+
],
92+
"2D compile",
93+
"2d_compile",
94+
),
95+
OverrideDefinitions(
96+
[
97+
[
98+
"--training.tensor_parallel_degree 2",
99+
"--model.norm_type=fused_rmsnorm",
100+
],
101+
],
102+
"2D eager with fused_rmsnorm",
103+
"2d_eager_fused_rmsnorm",
104+
),
105+
OverrideDefinitions(
106+
[
107+
[
108+
"--checkpoint.enable_checkpoint",
109+
],
110+
[
111+
"--checkpoint.enable_checkpoint",
112+
"--training.steps 20",
113+
],
114+
],
115+
"Checkpoint Integration Test - Save Load Full Checkpoint",
116+
"full_checkpoint",
117+
),
118+
OverrideDefinitions(
119+
[
120+
[
121+
"--checkpoint.enable_checkpoint",
122+
"--checkpoint.model_weights_only",
123+
],
124+
],
125+
"Checkpoint Integration Test - Save Model Weights Only fp32",
126+
"model_weights_only_fp32",
127+
),
128+
OverrideDefinitions(
129+
[
130+
[
131+
"--checkpoint.enable_checkpoint",
132+
"--checkpoint.model_weights_only",
133+
"--checkpoint.export_dtype bfloat16",
134+
],
135+
],
136+
"Checkpoint Integration Test - Save Model Weights Only bf16",
137+
"model_weights_only_bf16",
138+
),
49139
OverrideDefinitions(
50140
[
51141
[
52142
"--checkpoint.enable_checkpoint",
53143
"--experimental.pipeline_parallel_degree 4",
54144
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
55145
"--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
56-
"--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
57146
],
58147
],
59148
"PP looped flexible 1f1b test",
@@ -69,7 +158,6 @@ def build_test_list():
69158
"--experimental.pipeline_parallel_split_points layers.4",
70159
"--experimental.pipeline_parallel_schedule 1f1b",
71160
"--training.data_parallel_degree 1",
72-
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
73161
],
74162
],
75163
"PP 1D test 1f1b",
@@ -85,7 +173,6 @@ def build_test_list():
85173
"--experimental.pipeline_parallel_split_points layers.4",
86174
"--experimental.pipeline_parallel_schedule gpipe",
87175
"--training.data_parallel_degree 1",
88-
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
89176
],
90177
],
91178
"PP 1D test gpipe",
@@ -101,7 +188,6 @@ def build_test_list():
101188
"--experimental.pipeline_parallel_split_points layers.4",
102189
"--experimental.pipeline_parallel_schedule 1f1b",
103190
"--training.data_parallel_degree 2",
104-
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
105191
],
106192
],
107193
"PP+DP 1f1b 2D test",
@@ -116,7 +202,6 @@ def build_test_list():
116202
"--experimental.pipeline_parallel_split_points layers.4",
117203
"--experimental.pipeline_parallel_schedule gpipe",
118204
"--training.data_parallel_degree 2",
119-
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
120205
],
121206
],
122207
"PP+DP gpipe 2D test",
@@ -130,7 +215,6 @@ def build_test_list():
130215
"--experimental.pipeline_parallel_degree 2",
131216
"--experimental.pipeline_parallel_split_points layers.4",
132217
"--training.tensor_parallel_degree 2",
133-
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
134218
],
135219
],
136220
"PP+TP 2D test",
@@ -144,102 +228,13 @@ def build_test_list():
144228
"--experimental.pipeline_parallel_degree 2",
145229
"--experimental.pipeline_parallel_split_points layers.4",
146230
"--experimental.pipeline_parallel_split_mode tracer",
147-
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer
148231
],
149232
],
150233
"PP tracer frontend test",
151234
"pp_tracer",
152235
requires_seed_checkpoint=True,
153236
ngpu=2,
154237
),
155-
OverrideDefinitions(
156-
[
157-
[],
158-
],
159-
"default",
160-
"default",
161-
),
162-
OverrideDefinitions(
163-
[
164-
[
165-
"--training.compile --model.norm_type=rmsnorm",
166-
],
167-
],
168-
"1D compile",
169-
"1d_compile",
170-
),
171-
OverrideDefinitions(
172-
[
173-
[
174-
"--training.compile",
175-
"--activation_checkpoint.mode selective",
176-
"--activation_checkpoint.selective_ac_option op",
177-
],
178-
],
179-
"1D compile with selective op AC",
180-
"1d_compile_sac_op",
181-
),
182-
OverrideDefinitions(
183-
[
184-
[
185-
"--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
186-
],
187-
],
188-
"2D compile",
189-
"2d_compile",
190-
),
191-
OverrideDefinitions(
192-
[
193-
[
194-
"--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
195-
],
196-
],
197-
"Eager mode 2DParallel with rmsnorm",
198-
"eager_2d_rmsnorm",
199-
),
200-
OverrideDefinitions(
201-
[
202-
[
203-
"--training.tensor_parallel_degree 2 --model.norm_type=fused_rmsnorm",
204-
],
205-
],
206-
"Eager mode 2DParallel with fused_rmsnorm",
207-
"eager_2d_fused_rmsnorm",
208-
),
209-
OverrideDefinitions(
210-
[
211-
[
212-
"--checkpoint.enable_checkpoint",
213-
],
214-
[
215-
"--checkpoint.enable_checkpoint",
216-
"--training.steps 20",
217-
],
218-
],
219-
"Checkpoint Integration Test - Save Load Full Checkpoint",
220-
"full_checkpoint",
221-
),
222-
OverrideDefinitions(
223-
[
224-
[
225-
"--checkpoint.enable_checkpoint",
226-
"--checkpoint.model_weights_only",
227-
],
228-
],
229-
"Checkpoint Integration Test - Save Model Weights Only fp32",
230-
"model_weights_only_fp32",
231-
),
232-
OverrideDefinitions(
233-
[
234-
[
235-
"--checkpoint.enable_checkpoint",
236-
"--checkpoint.model_weights_only",
237-
"--checkpoint.export_dtype bfloat16",
238-
],
239-
],
240-
"Checkpoint Integration Test - Save Model Weights Only bf16",
241-
"model_weights_only_bf16",
242-
),
243238
OverrideDefinitions(
244239
[
245240
[
@@ -248,7 +243,6 @@ def build_test_list():
248243
"--experimental.pipeline_parallel_split_points layers.4",
249244
"--training.data_parallel_degree 2",
250245
"--training.tensor_parallel_degree 2",
251-
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
252246
],
253247
[
254248
"--training.steps 20",
@@ -257,7 +251,6 @@ def build_test_list():
257251
"--experimental.pipeline_parallel_split_points layers.4",
258252
"--training.data_parallel_degree 2",
259253
"--training.tensor_parallel_degree 2",
260-
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
261254
],
262255
],
263256
"PP+DP+TP 3D test with save/load resume ckpt",
@@ -272,7 +265,6 @@ def build_test_list():
272265
"--experimental.pipeline_parallel_degree 4",
273266
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
274267
"--experimental.pipeline_parallel_schedule interleaved_1f1b",
275-
"--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
276268
],
277269
],
278270
"PP looped 1f1b test",
@@ -292,21 +284,21 @@ def build_test_list():
292284
OverrideDefinitions(
293285
[
294286
[
295-
"--memory_estimation.enabled --model.norm_type rmsnorm",
287+
"--training.data_parallel_type ddp",
296288
]
297289
],
298-
"FSDP2 Memory Tracking and Estimation",
299-
"fsdp2_mem_tracker",
290+
"DDP",
291+
"ddp",
300292
ngpu=4,
301293
),
302294
OverrideDefinitions(
303295
[
304296
[
305-
"--training.data_parallel_type ddp",
297+
"--memory_estimation.enabled",
306298
]
307299
],
308-
"DDP",
309-
"ddp",
300+
"FSDP2 Memory Tracking and Estimation",
301+
"fsdp2_mem_tracker",
310302
ngpu=4,
311303
),
312304
]

torchtitan/config_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def __init__(self):
165165
"--model.norm_type",
166166
type=str,
167167
default="rmsnorm",
168-
help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, compiled_rmsnorm, fused_rmsnorm]",
168+
help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]",
169169
)
170170
self.parser.add_argument(
171171
"--model.tokenizer_path",

torchtitan/models/norms.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
2424
2525
Args:
2626
norm_type (str): The type of normalization layer to build.
27-
Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm
27+
Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm
2828
dim (int): The dimension of the normalization layer.
2929
eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
3030
@@ -42,13 +42,6 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
4242
return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
4343
elif norm_type == "rmsnorm":
4444
return RMSNorm(dim, eps=eps)
45-
elif norm_type == "compiled_rmsnorm":
46-
import warnings
47-
48-
warnings.warn(
49-
"compiled_rmsnorm is currently experimental and not ready to use yet."
50-
)
51-
return RMSNorm(dim, eps=eps, compile=True)
5245
elif norm_type == "fused_rmsnorm":
5346
return FusedRMSNorm(dim, eps=eps)
5447
else:
@@ -94,26 +87,17 @@ class RMSNorm(nn.Module):
9487
9588
"""
9689

97-
def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False):
90+
def __init__(self, dim: int, eps: float = 1e-6):
9891
super().__init__()
9992
self.eps = eps
10093
self.weight = nn.Parameter(torch.ones(dim))
101-
self.rmsnorm_fn = (
102-
torch.compile(self.compute_rmsnorm, fullgraph=True)
103-
if compile
104-
else self.compute_rmsnorm
105-
)
106-
107-
@staticmethod
108-
def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float):
109-
def _norm(x, eps):
110-
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
11194

112-
output = _norm(x.float(), eps).type_as(x)
113-
return output * weight
95+
def _norm(self, x: torch.Tensor):
96+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
11497

11598
def forward(self, x: torch.Tensor):
116-
return self.rmsnorm_fn(x, self.weight, self.eps)
99+
output = self._norm(x.float()).type_as(x)
100+
return output * self.weight
117101

118102
def reset_parameters(self):
119103
torch.nn.init.ones_(self.weight) # type: ignore

train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ save_tb_folder = "tb"
2121
[model]
2222
name = "llama3"
2323
flavor = "debugmodel"
24-
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
24+
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
2525
# test tokenizer.model, for debug purpose only
2626
tokenizer_path = "./test/assets/test_tiktoken.model"
2727

0 commit comments

Comments (0)