 
 import copy
 from collections import defaultdict
-from typing import Dict, Tuple
+from typing import Tuple, TYPE_CHECKING, Union
 
 import torch
+import torch.nn as nn
+from torch.distributed import DeviceMesh
 
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
@@ -29,8 +31,15 @@
 
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging_utils import logger
+from torchtitan.models.llama.model import ModelArgs
 from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank
 
+if TYPE_CHECKING:
+    from torchtitan.parallelisms import ParallelDims
+
+
+DeviceType = Union[int, str, torch.device]
+
 # for selective AC
 no_recompute_list = {
     torch.ops.aten.mm.default,
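
The hunk above gates the `ParallelDims` import behind `TYPE_CHECKING` and annotates it as the string `"ParallelDims"`, so the import is resolved only by static type checkers and never at runtime; the new `DeviceType = Union[int, str, torch.device]` alias simply names the accepted device spellings. A minimal sketch of that import pattern, using a hypothetical `my_pkg.dims` module rather than torchtitan's own layout:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers (mypy, pyright); never imported at runtime,
    # which avoids circular-import problems between modules.
    from my_pkg.dims import ParallelDims  # hypothetical module path


def shard_model(parallel_dims: "ParallelDims") -> None:
    # The annotation is a plain string at runtime (a "forward reference"),
    # so my_pkg.dims is never loaded when this function is defined or called.
    ...
```
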
@@ -125,23 +134,30 @@ def get_tp_parallel_strategy(
 
 
 def pipeline_llama(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
-    if job_config.experimental.pipeline_parallel_split_mode == "manual":
+    split_mode = job_config.experimental.pipeline_parallel_split_mode
+    valid_split_modes = ("manual", "tracer")
+    if split_mode not in valid_split_modes:
+        raise ValueError(
+            f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
+        )
+    if split_mode == "manual":
         return pipeline_llama_manual(
             model, world_mesh, parallel_dims, job_config, device, model_config
         )
-    elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
+    elif split_mode == "tracer":
         return pipeline_llama_tracer(
             model, world_mesh, parallel_dims, job_config, device, model_config
         )
-    else:
-        raise NotImplementedError(
-            f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
-        )
 
 
-def _llama_trace_input(job_config, model_config, device="meta"):
+def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
     """Get meta tensors with the right input shapes used for tracing"""
     tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
     tokens = torch.randint(
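
The refactor above validates the split mode up front and raises a `ValueError` listing the accepted values, instead of falling through to a `NotImplementedError` at the end of the if/elif chain. A standalone sketch of that validate-then-dispatch shape, with the config plumbing simplified to a plain string argument:

```python
def check_split_mode(split_mode: str) -> str:
    """Validate before dispatching, mirroring the change above."""
    valid_split_modes = ("manual", "tracer")
    if split_mode not in valid_split_modes:
        raise ValueError(
            f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
        )
    return split_mode


check_split_mode("manual")  # returns "manual"
# check_split_mode("auto")  # raises: Invalid split mode: auto. Valid split modes: ('manual', 'tracer')
```
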
@@ -153,18 +169,18 @@ def _llama_trace_input(job_config, model_config, device="meta"):
 def _mixed_precision_dtype(
     job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
 ) -> torch.dtype:
-    """Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
+    """Get the mixed precision dtype if FSDP is enabled, otherwise return the default"""
     mp_arg = job_config.training.mixed_precision_param
     return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default
 
 
 def pipeline_llama_manual(
-    whole_model,
-    world_mesh,
-    parallel_dims,
+    whole_model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
     job_config: JobConfig,
-    device,
-    model_config: Dict,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
     """
     This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.
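
`_mixed_precision_dtype` consults the mixed-precision setting only when data parallelism (FSDP) is enabled; otherwise it returns the default. A rough standalone illustration, with `TORCH_DTYPE_MAP` stubbed out since only two entries matter here (the real map lives in torchtitan.config_manager):

```python
import torch

# Stub of the relevant TORCH_DTYPE_MAP entries.
DTYPE_MAP_SKETCH = {"float32": torch.float32, "bfloat16": torch.bfloat16}


def mixed_precision_dtype_sketch(
    mp_arg: str, dp_enabled: bool, default: torch.dtype = torch.float32
) -> torch.dtype:
    # The mixed-precision param dtype only takes effect when FSDP (dp) is on.
    return DTYPE_MAP_SKETCH[mp_arg] if dp_enabled else default


assert mixed_precision_dtype_sketch("bfloat16", dp_enabled=True) is torch.bfloat16
assert mixed_precision_dtype_sketch("bfloat16", dp_enabled=False) is torch.float32
```
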
@@ -262,19 +278,24 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
 
 
 def pipeline_llama_tracer(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
     if job_config.model.norm_type == "fused_rmsnorm":
-        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
-        # coming from `if dy.stride(-1) != 1:` in fused_rmsnorm
+        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr
+        # invocation stride in strict mode from `if dy.stride(-1) != 1:` in
+        # fused_rmsnorm
         raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
+            "fused_rmsnorm is not compatible with Pipeline Tracer yet. Please use rmsnorm or layernorm."
         )
-
-    if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
+    if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
         raise NotImplementedError(
-            "pipeline tracer doesn't work with fsdp mixed precision currently. "
-            "To work around, edit fsdp mixed precision config to use fp32."
+            "Pipeline tracer does not work with FSDP mixed precision yet. "
+            "To work around, set mixed_precision_param to float32."
         )
 
     pp_mesh = world_mesh["pp"]
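
Note that replacing `== torch.bfloat16` with `!= torch.float32` broadens the guard: any non-fp32 mixed-precision param dtype (float16 as well as bfloat16) now hits the `NotImplementedError` on the tracer path. A tiny check of the new condition:

```python
import torch

# Which mixed-precision param dtypes the tracer path accepts under the new guard.
for dtype in (torch.float32, torch.bfloat16, torch.float16):
    rejected = dtype != torch.float32
    print(f"{dtype}: {'rejected' if rejected else 'accepted'} by the tracer path")
```
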
@@ -310,10 +331,13 @@ def pipeline_llama_tracer(
     return (stages, models)
 
 
-def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply tensor parallelism.
-    """
+def apply_tp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    """Apply tensor parallelism."""
 
     tp_mesh = world_mesh["tp"]
     # Parallel styles used for transformer block linear weights and their
@@ -392,10 +416,8 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     return model
 
 
-def apply_ac(model, job_config: JobConfig):
-    """
-    Apply activation checkpointing to the model.
-    """
+def apply_ac(model: nn.Module, job_config: JobConfig):
+    """Apply activation checkpointing to the model."""
 
     ac_config = job_config.activation_checkpoint
 
@@ -407,18 +429,15 @@ def apply_ac(model, job_config: JobConfig):
     return model
 
 
-def apply_compile(model, job_config: JobConfig):
-    """
-    Apply torch.compile to the model.
-    """
+def apply_compile(model: nn.Module, job_config: JobConfig):
+    """Apply torch.compile to each transformer block."""
 
     if job_config.model.norm_type == "fused_rmsnorm":
         raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
+            "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
         )
 
     for layer_id, transformer_block in model.layers.named_children():
-        # turn on per-transformer block compile after AC wrapping and before FSDP
         # TODO: dynamic shape have some issues so we turn it off for now.
         # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
         # compile time.
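
The loop above compiles each transformer block separately rather than wrapping the whole model in `torch.compile`. A self-contained sketch of that per-block pattern with toy modules (the real code operates on Llama's transformer blocks):

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Stand-in for a transformer block."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.ff(x))


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): Block() for i in range(2)})

    def forward(self, x):
        for blk in self.layers.values():
            x = blk(x)
        return x


model = Model()
for layer_id, block in model.layers.named_children():
    # Compile each block on its own and re-register it; dynamic=False mirrors
    # the TODO above about dynamic shapes being off for now.
    model.layers.register_module(layer_id, torch.compile(block, dynamic=False))
```
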
@@ -430,10 +449,13 @@ def apply_compile(model, job_config: JobConfig):
     return model
 
 
-def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply data parallelism (FSDP2) to the model.
-    """
+def apply_dp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    """Apply data parallelism (FSDP2) to the model."""
 
     dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
     assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
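
`apply_dp` wraps the model with the composable FSDP2 API imported at the top of the file (`fully_shard` plus `MixedPrecisionPolicy`). A hedged sketch of a minimal `fully_shard` call; it assumes a process group is already initialized (e.g. under torchrun), and the mesh and dtype choices below are illustrative rather than torchtitan's exact config:

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import init_device_mesh

# Assumes torch.distributed is already initialized (e.g. via torchrun).
dp_mesh = init_device_mesh(
    "cuda", (dist.get_world_size(),), mesh_dim_names=("dp",)
)

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # compute/communication dtype for params
    reduce_dtype=torch.float32,  # gradient-reduction dtype
)

model = nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 128))
for layer in model:
    # Shard each submodule first, then the root, following the usual FSDP2 pattern.
    fully_shard(layer, mesh=dp_mesh, mp_policy=mp_policy)
fully_shard(model, mesh=dp_mesh, mp_policy=mp_policy)
```
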
@@ -466,7 +488,12 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
     return model
 
 
-def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
+def parallelize_llama(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
     """
     Apply tensor parallelism, activation checkpointing, torch.compile, and data
     parallelism to the model.
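
For orientation, a hedged sketch of how `parallelize_llama` likely composes the helpers above, following the order named in its docstring and in the removed apply_compile comment (TP, then activation checkpointing, then per-block compile, then FSDP). Only `parallel_dims.dp_enabled` is visible in this diff; the other guard flags are assumptions:

```python
def parallelize_llama_sketch(model, world_mesh, parallel_dims, job_config):
    # Compile is applied per block after AC wrapping and before FSDP.
    if parallel_dims.tp_enabled:  # assumed flag, by analogy with dp_enabled
        model = apply_tp(model, world_mesh, parallel_dims, job_config)
    if job_config.activation_checkpoint.mode != "none":  # assumed config field
        model = apply_ac(model, job_config)
    if job_config.training.compile:  # assumed config field
        model = apply_compile(model, job_config)
    if parallel_dims.dp_enabled:
        model = apply_dp(model, world_mesh, parallel_dims, job_config)
    return model
```
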