 # llama model, i.e. activation checkpointing, etc.

 from collections import defaultdict
-from typing import Tuple
+from typing import List, Tuple

 import torch

@@ -138,7 +138,112 @@ def get_tp_parallel_strategy(
     return RowwiseParallel, ColwiseParallel


-def apply_pipeline_parallelism(model, world_mesh, parallel_dims, job_config: JobConfig):
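+# Identity placeholder for transformer layers owned by other pipeline stages: it keeps each
+# layer's position in layers[] (and hence its FQN) intact while passing activations through unchanged.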
+class DummyTransformerLayer(torch.nn.Module):
+    def forward(self, input, freqs_cis):
+        return input
+
+
+class TransformerChunk(torch.nn.Module):
+    def __init__(
+        self,
+        orig_model,  # : Transformer,
+        this_stage_layer_names: List[str],
+        device,
+        input_seqlen: int,
+    ):
+        super().__init__()
+        self.tok_embeddings = None
+
+        # Inferring seqlen from forward(input) only works on stage 0, because on later stages
+        # the hidden-state input may have a reduced seqlen due to TP.  We need to use the
+        # original (full) seqlen for freqs_cis to be correct.
+        self.input_seqlen = input_seqlen
+
+        if "tok_embeddings" in this_stage_layer_names:
+            self.tok_embeddings = orig_model.tok_embeddings
+
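+        # torch.device(...) as a context manager sets the default device, so freqs_cis is
+        # materialized directly on the target device rather than on CPU.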
+        with torch.device(device):
+            self.freqs_cis = orig_model._precompute_freqs_cis()
+
+        # preserve FQNs of original model by preserving structure
+        # (including preserving position in the layers[] list) - use dummy module
+        self.layers = orig_model.layers
+        for i in range(len(self.layers)):
+            if f"layers.{i}" not in this_stage_layer_names:
+                self.layers[i] = DummyTransformerLayer()
+        self.norm = None
+        if "norm" in this_stage_layer_names:
+            self.norm = orig_model.norm
+        self.output = None
+        if "output" in this_stage_layer_names:
+            self.output = orig_model.output
+
+    def forward(self, input):
+        """
+        Copy of the original Transformer.forward, with conditionals and unpacking added so that
+        we handle the cases where this rank doesn't have the embedding or the norm/output layers.
+        """
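+        # Stage 0 embeds the raw token ids; later stages receive the hidden state from the
+        # previous stage and use it directly.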
+        if self.tok_embeddings:
+            h = self.tok_embeddings(input)
+        else:
+            h = input
+
+        freqs_cis = self.freqs_cis[0 : self.input_seqlen]
+
+        for layer in self.layers:
+            h = layer(h, freqs_cis)
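+        # On stages that own neither norm nor output, the last hidden state is this stage's output.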
+        output = h
+
+        if self.norm:
+            h = self.norm(h)
+            output = h
+
+        if self.output:
+            output = self.output(h).float()
+        return output
+
+
+def apply_pipeline_parallelism_manual(
+    model, world_mesh, parallel_dims, job_config: JobConfig, device
+):
+    """
+    This API builds individual torch.nn.Module objects for each pipeline stage (including virtual stages).
+
+    The SPMD parallelisms (TP, FSDP) should then be applied to each stage module by the caller.
+    """
+    pp_mesh = world_mesh["pp"]
+    pp_rank = pp_mesh.get_local_rank()
+    pp_size = pp_mesh.size()
+    stage_idx = pp_rank  # TODO support virtual stages
+    layers_per_rank = len(model.layers) // parallel_dims.pp
+    layer_offset = layers_per_rank * pp_rank
+    this_stage_layer_names = [
+        f"layers.{i + layer_offset}" for i in range(layers_per_rank)
+    ]
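+    # e.g. with 8 layers and pp_size=2: rank 0 owns layers.0-3 (plus tok_embeddings),
+    # rank 1 owns layers.4-7 (plus norm and output).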
+    if pp_rank == 0:
+        this_stage_layer_names.insert(0, "tok_embeddings")
+        assert "layers.0" in this_stage_layer_names
+    elif pp_rank == pp_size - 1:
+        this_stage_layer_names.append("norm")
+        this_stage_layer_names.append("output")
+        assert "layers.1" in this_stage_layer_names
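+    # NOTE: these asserts only hold for a 2-layer, 2-stage configuration.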
+
+    input_seqlen = 2048  # TODO hack
+
+    stage_model = TransformerChunk(model, this_stage_layer_names, device, input_seqlen)
+    # Create a pipeline representation from the model
+
+    # note for the PiPPy API: it would be nice if we could get fx.graph out of PipeInfo, make it
+    # possible to manually construct PipeInfo, and then use the same _PipelineStage ctor in either
+    # the tracer or the manual case.
+
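+    # A 1-tuple for now; with virtual stages (see TODO above) multiple chunks per rank would be returned.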
+    return (stage_model,)
+
+
+def apply_pipeline_parallelism_tracer(
+    model, world_mesh, parallel_dims, job_config: JobConfig
+):
     assert (
         parallel_dims.pp_enabled
     ), "can't apply pipeline parallelism if it is not enabled"
@@ -212,6 +317,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

     # Apply tensor + sequence parallelism to every transformer block
     for layer_id, transformer_block in enumerate(model.layers):
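+        # Skip placeholders for layers owned by other pipeline stages.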
+        if isinstance(transformer_block, DummyTransformerLayer):
+            continue
         layer_plan = {
             "attention": PrepareModuleInput(
                 input_layouts=(Shard(1), None),
@@ -259,6 +366,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     ac_mode = job_config.activation_checkpoint.mode
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
     for layer_name, transformer_block in model.layers.named_children():
+        if isinstance(transformer_block, DummyTransformerLayer):
+            continue
         if job_config.activation_checkpoint.mode in ("full", "selective"):
             transformer_block = checkpoint_wrapper(
                 transformer_block, job_config.activation_checkpoint
@@ -275,6 +384,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         )
         model.layers.add_module(layer_name, transformer_block)

+    # TODO(whc) do we need reshard_after_forward setting here too?
     model = fully_shard(model, **fsdp_config)
     if ac_mode in ("full", "selective"):
         logger.info(f"Applied {ac_mode} activation checkpointing to the model")