import copy
from collections import defaultdict
-from typing import Dict, Tuple
+from typing import Tuple, TYPE_CHECKING, Union

import torch
+import torch.nn as nn
+from torch.distributed import DeviceMesh

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
+from torchtitan.models.llama.model import ModelArgs
from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank

+if TYPE_CHECKING:
+    from torchtitan.parallelisms import ParallelDims
+
+
+DeviceType = Union[int, str, torch.device]
+
# for selective AC
no_recompute_list = {
    torch.ops.aten.mm.default,
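A quick note on the new typing imports above: the ParallelDims import sits under TYPE_CHECKING, the standard way to expose a name to static type checkers without a runtime import (and without risking an import cycle), and DeviceType spells out the three forms a device argument may take. A minimal sketch of the same pattern, with a hypothetical normalize_device helper that is not part of the diff:

from typing import TYPE_CHECKING, Union

import torch

if TYPE_CHECKING:
    # Only evaluated by type checkers, never at runtime, so it cannot
    # introduce a circular import.
    from torchtitan.parallelisms import ParallelDims

DeviceType = Union[int, str, torch.device]

def normalize_device(device: DeviceType) -> torch.device:
    # Hypothetical helper: pass a torch.device through unchanged, otherwise
    # let torch.device parse the int index or string form.
    return device if isinstance(device, torch.device) else torch.device(device)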
@@ -107,23 +116,27 @@ def get_tp_parallel_strategy(
def pipeline_llama(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
):
-    if job_config.experimental.pipeline_parallel_split_mode == "manual":
+    split_mode = job_config.experimental.pipeline_parallel_split_mode
+    if split_mode == "manual":
        return pipeline_llama_manual(
            model, world_mesh, parallel_dims, job_config, device, model_config
        )
-    elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
+    elif split_mode == "tracer":
        return pipeline_llama_tracer(
            model, world_mesh, parallel_dims, job_config, device, model_config
        )
    else:
-        raise NotImplementedError(
-            f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
-        )
+        raise NotImplementedError(f"{split_mode} is not a valid split mode")


-def _llama_trace_input(job_config, model_config, device="meta"):
+def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
    """Get meta tensors with the right input shapes used for tracing"""
    tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
    tokens = torch.randint(
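The hunk above only hoists the split-mode lookup into a local split_mode before dispatching to the manual or tracer path. A rough caller-side sketch, assuming the training script has already built world_mesh, parallel_dims, device, and model_config, and that both paths return the same (stages, models) pair the tracer path returns later in this diff:

# job_config.experimental.pipeline_parallel_split_mode selects "manual" or "tracer".
stages, models = pipeline_llama(
    model, world_mesh, parallel_dims, job_config, device, model_config
)
# Each pipeline rank now holds only the stage modules assigned to it; an
# unrecognized split mode raises NotImplementedError.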
@@ -135,18 +148,18 @@ def _llama_trace_input(job_config, model_config, device="meta"):
def _mixed_precision_dtype(
    job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
) -> torch.dtype:
-    """Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
+    """Get the mixed precision dtype if FSDP is enabled, otherwise return the default"""
    mp_arg = job_config.training.mixed_precision_param
    return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default


def pipeline_llama_manual(
-    whole_model,
-    world_mesh,
-    parallel_dims,
+    whole_model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
    job_config: JobConfig,
-    device,
-    model_config: Dict,
+    device: DeviceType,
+    model_config: ModelArgs,
):
    """
    This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.
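Since _mixed_precision_dtype is now used as a guard in the tracer path below, here is a self-contained sketch of its behavior; the two-entry TORCH_DTYPE_MAP is an illustrative subset, not the real mapping from config_manager:

import torch

TORCH_DTYPE_MAP = {"float32": torch.float32, "bfloat16": torch.bfloat16}  # illustrative subset

def mixed_precision_dtype_sketch(
    mixed_precision_param: str, dp_enabled: bool, default: torch.dtype = torch.float32
) -> torch.dtype:
    # Mirrors _mixed_precision_dtype: the FSDP param dtype only applies when
    # data parallelism is enabled; otherwise fall back to the default.
    return TORCH_DTYPE_MAP[mixed_precision_param] if dp_enabled else default

assert mixed_precision_dtype_sketch("bfloat16", dp_enabled=True) is torch.bfloat16
assert mixed_precision_dtype_sketch("bfloat16", dp_enabled=False) is torch.float32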
@@ -244,19 +257,17 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
def pipeline_llama_tracer(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
):
-    if job_config.model.norm_type == "fused_rmsnorm":
-        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
-        # coming from `if dy.stride(-1) != 1:` in fused_rmsnorm
+    if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
        raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
-        )
-
-    if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
-        raise NotImplementedError(
-            "pipeline tracer doesn't work with fsdp mixed precision currently. "
-            "To work around, edit fsdp mixed precision config to use fp32."
+            "Pipeline tracer does not work with FSDP mixed precision yet. "
+            "To work around, set mixed_precision_param to float32."
        )

    pp_mesh = world_mesh["pp"]
@@ -292,10 +303,13 @@ def pipeline_llama_tracer(
    return (stages, models)


-def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply tensor parallelism.
-    """
+def apply_tp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    """Apply tensor parallelism."""

    tp_mesh = world_mesh["tp"]
    # Parallel styles for transformer block linear weights may be different for
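The apply_tp body is mostly elided from this hunk; for orientation, the sketch below shows the general shape of a tensor-parallel plan built with torch.distributed.tensor.parallel on a toy feed-forward module. It is a generic illustration, not torchtitan's actual plan, and it assumes torch.distributed is already initialized with enough ranks for the ("tp",) mesh:

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))

tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))  # degree 8 is illustrative
ffn = FeedForward(dim=256, hidden=1024)
# Column-shard w1 and row-shard w2 so the hidden activation stays sharded and
# only w2's output needs a reduction across the TP group.
parallelize_module(ffn, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})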
@@ -374,10 +388,8 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
    return model


-def apply_ac(model, job_config: JobConfig):
-    """
-    Apply activation checkpointing to the model.
-    """
+def apply_ac(model: nn.Module, job_config: JobConfig):
+    """Apply activation checkpointing to the model."""

    ac_config = job_config.activation_checkpoint

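The apply_ac body is elided here; the no_recompute_list at the top of the file feeds its selective mode, whose wiring is not shown in this excerpt. As a point of reference, a minimal sketch of the full (non-selective) case using the stock checkpoint_wrapper, assuming model.layers holds the transformer blocks as in this file:

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

def apply_full_ac_sketch(model: nn.Module) -> nn.Module:
    # Re-register each transformer block wrapped in activation checkpointing,
    # so its activations are recomputed in backward rather than saved in forward.
    for layer_id, block in model.layers.named_children():
        model.layers.register_module(layer_id, checkpoint_wrapper(block))
    return model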
@@ -389,18 +401,10 @@ def apply_ac(model, job_config: JobConfig):
    return model


-def apply_compile(model, job_config: JobConfig):
-    """
-    Apply torch.compile to the model.
-    """
-
-    if job_config.model.norm_type == "fused_rmsnorm":
-        raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
-        )
+def apply_compile(model: nn.Module, job_config: JobConfig):
+    """Apply torch.compile to each transformer block."""

    for layer_id, transformer_block in model.layers.named_children():
-        # turn on per-transformer block compile after AC wrapping and before FSDP
        # TODO: dynamic shapes have some issues, so we turn them off for now.
        # TODO: inline inbuilt nn modules does not work yet; enable it to accelerate
        # compile time.
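The loop body itself is elided from this hunk; a sketch of what per-block compilation presumably looks like, with dynamic shapes disabled per the TODO above (register_module is the standard nn.Module API for swapping a child in place):

import torch
import torch.nn as nn

def apply_compile_sketch(model: nn.Module) -> nn.Module:
    for layer_id, block in model.layers.named_children():
        # Compile each transformer block individually rather than the whole
        # model; dynamic=False sidesteps the dynamic-shape issues noted above.
        model.layers.register_module(layer_id, torch.compile(block, dynamic=False))
    return model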
@@ -412,10 +416,13 @@ def apply_compile(model, job_config: JobConfig):
    return model


-def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply data parallelism (FSDP2) to the model.
-    """
+def apply_dp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    """Apply data parallelism (FSDP2) to the model."""

    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
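The rest of apply_dp is elided; below is a sketch of the FSDP2 pattern it builds on, using the fully_shard and MixedPrecisionPolicy imports from the top of this file. The bf16/fp32 dtype split and the per-block-then-root sharding order are assumptions, not read from the hunk:

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

def apply_dp_sketch(model: nn.Module, dp_mesh) -> nn.Module:
    # Parameters are gathered and used in bf16 while gradient reduction stays fp32.
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    for _, block in model.layers.named_children():
        fully_shard(block, mesh=dp_mesh, mp_policy=mp_policy)
    # Shard the remaining root parameters (embeddings, norms, output head).
    fully_shard(model, mesh=dp_mesh, mp_policy=mp_policy)
    return model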
@@ -448,7 +455,12 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
    return model


-def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
+def parallelize_llama(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.
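Finally, a rough sketch of the composition order the parallelize_llama docstring describes: TP, then activation checkpointing, then torch.compile, then FSDP, matching the removed comment about compiling after AC wrapping and before FSDP. The activation_checkpoint.mode and training.compile config fields used to gate the steps are assumptions, not shown in this excerpt:

def parallelize_llama_sketch(model, world_mesh, parallel_dims, job_config):
    if parallel_dims.tp_enabled:                          # assumed flag on ParallelDims
        model = apply_tp(model, world_mesh, parallel_dims, job_config)
    if job_config.activation_checkpoint.mode != "none":   # assumed config field
        model = apply_ac(model, job_config)
    if job_config.training.compile:                       # assumed config field
        model = apply_compile(model, job_config)
    if parallel_dims.dp_enabled:
        model = apply_dp(model, world_mesh, parallel_dims, job_config)
    return model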