[rfc] getting rid of seed-checkpoint for Pipeline Parallelism #514

@wconstab

Description

Currently PP uses a 'seed checkpoint' for initialization because (1) it's nice to initialize the model the same way (same RNG) as non-PP for loss-comparison purposes, and (2) the whole model may not fit in GPU memory (or even CPU memory, when accounting for 8 copies, one per GPU).

The downside is that creating the seed checkpoint is an extra step, and one that gets slower as the model grows, which is a poor user experience.

Step 1: Make model.init_weights 'pipeline friendly'

Currently, calling init_weights on a model chunk after meta-init and pipeline splitting crashes: init_weights expects all the parameters of the model to be present, but pipeline splitting deletes some of them.

A pretty simple fix is to modify Transformer.init_weights to tolerate self.tok_embeddings or self.output being None (skipping them if so). Layers should already be OK, since the loop only visits layers that PP did not delete. With that change, the initializer runs without crashing.
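A minimal sketch of the idea (not the actual torchtitan code; the module names follow the llama-style model, and the weight-init scheme here is illustrative):

```python
import torch
import torch.nn as nn

class Transformer(nn.Module):
    def __init__(self, vocab=32, dim=16, n_layers=4):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleDict(
            {str(i): nn.Linear(dim, dim) for i in range(n_layers)}
        )
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab, bias=False)

    def init_weights(self):
        # Skip modules that pipeline splitting removed on this stage.
        if self.tok_embeddings is not None:
            nn.init.normal_(self.tok_embeddings.weight)
        # The loop only visits layers PP kept, so no guard is needed here.
        for layer in self.layers.values():
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
        if self.norm is not None:
            self.norm.reset_parameters()
        if self.output is not None:
            nn.init.normal_(self.output.weight, std=0.02)

# Simulate a middle pipeline stage: embedding, output, and some layers deleted.
stage = Transformer()
stage.tok_embeddings = None
stage.output = None
del stage.layers["0"], stage.layers["1"]
stage.init_weights()  # no crash
```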

So far, this only unblocks basic functionality (running for CI, checking WPS and peak memory): every PP stage still uses the same RNG state, so convergence may be affected.

Note: this approach does not take finer-grained splitting into account. If users wanted to put "half a transformer layer" on a pipeline stage, additional work would be needed to make initialization work.

Step 2: Fix the RNG problem

Option 1: A quick thing to try is to add one function to torch.pipelining that draws PP_ranks - 1 random integers on rank 0 and broadcasts them to the nonzero PP ranks, which use them to set their own seeds. Every PP rank then starts from a different seed, and no two layers get the same initial values. At this point convergence should probably be OK, even though we'd still not exactly match a non-PP initialization.
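A sketch of what that helper could look like (hypothetical; this is not an existing torch.pipelining API, and the helper name is made up):

```python
import torch
import torch.distributed as dist

def draw_stage_seeds(num_pp_ranks: int) -> torch.Tensor:
    # One int64 seed per PP rank, drawn from rank 0's current RNG state.
    return torch.randint(0, 2**63 - 1, (num_pp_ranks,), dtype=torch.int64)

def reseed_pp_ranks(pp_group=None):
    """Give every nonzero PP rank a distinct seed for weight init.

    Hypothetical helper: rank 0 draws the seeds and broadcasts them over
    the PP process group; rank 0 keeps its original seed so the first
    stage still matches the non-PP initialization.
    """
    rank = dist.get_rank(pp_group)
    world = dist.get_world_size(pp_group)
    seeds = (
        draw_stage_seeds(world) if rank == 0
        else torch.empty(world, dtype=torch.int64)
    )
    dist.broadcast(seeds, src=0, group=pp_group)
    if rank != 0:
        torch.manual_seed(int(seeds[rank]))
```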

Option 2: A more advanced version would be a function on the torch.pipelining Schedule class that accepts the model's init_weights function as an argument. The schedule already has pointers to all the local stages and their model chunks, and knows the pipeline order (looped, interleaved, V-shaped, etc.). It would initialize one chunk at a time, starting from chunk 0 on rank 0. That rank would then extract its current RNG state locally and send it to the rank that holds chunk 1, which would set its RNG state to match, initialize chunk 1, and send its updated RNG state on to the next rank, and so on until complete.
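A single-process sketch of the RNG handoff at the heart of Option 2 (the mechanism is assumed, not an existing Schedule API; in a real pipeline the state tensor would travel between ranks via dist.send/dist.recv, and CUDA RNG state would need the same treatment via torch.cuda.get_rng_state/set_rng_state):

```python
import torch
import torch.nn as nn

def init_chunk(chunk, rng_state):
    # Resume the shared RNG stream where the previous stage left off,
    # initialize this chunk, and hand the updated state onward.
    torch.set_rng_state(rng_state)
    for p in chunk.parameters():
        nn.init.normal_(p)
    return torch.get_rng_state()

# Two chunks initialized sequentially with the handoff...
torch.manual_seed(1234)
chunk0, chunk1 = nn.Linear(8, 8), nn.Linear(8, 8)
state = torch.get_rng_state()
state = init_chunk(chunk0, state)   # "rank 0" initializes chunk 0
state = init_chunk(chunk1, state)   # state handed to "rank 1" for chunk 1

# ...match a single-model reference initialized in one continuous stream.
torch.manual_seed(1234)
ref0, ref1 = nn.Linear(8, 8), nn.Linear(8, 8)
for p in [*ref0.parameters(), *ref1.parameters()]:
    nn.init.normal_(p)
```

Because each stage resumes exactly where the previous one stopped, the result is bit-identical to initializing the unsplit model, which is what makes this option attractive for loss comparison against non-PP runs.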

cc @H-Huang @wanchaol @tianyu-l @lessw2020

Metadata

Labels

enhancement (New feature or request); release blocking (Issues that are blocking the milestone / release completion)
