111 changes: 69 additions & 42 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -9,9 +9,11 @@

import copy
from collections import defaultdict
from typing import Dict, Tuple
from typing import Tuple, TYPE_CHECKING, Union

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
@@ -29,8 +31,15 @@

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
from torchtitan.models.llama.model import ModelArgs
from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank

if TYPE_CHECKING:
from torchtitan.parallelisms import ParallelDims


DeviceType = Union[int, str, torch.device]

# for selective AC
no_recompute_list = {
torch.ops.aten.mm.default,
@@ -125,23 +134,30 @@ def get_tp_parallel_strategy(


def pipeline_llama(
model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
device: DeviceType,
model_config: ModelArgs,
):
if job_config.experimental.pipeline_parallel_split_mode == "manual":
split_mode = job_config.experimental.pipeline_parallel_split_mode
valid_split_modes = ("manual", "tracer")
if split_mode not in valid_split_modes:
raise ValueError(
f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
)
if split_mode == "manual":
return pipeline_llama_manual(
model, world_mesh, parallel_dims, job_config, device, model_config
)
elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
elif split_mode == "tracer":
return pipeline_llama_tracer(
model, world_mesh, parallel_dims, job_config, device, model_config
)
else:
raise NotImplementedError(
f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
)
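
For context, a minimal sketch of how this entry point is invoked from the training script; the call-site variable names are assumptions, while the (stages, models) return shape mirrors the tracer path further down in this file:

# Hypothetical call site; model, world_mesh, parallel_dims, etc. are built earlier.
stages, model_parts = pipeline_llama(
    model, world_mesh, parallel_dims, job_config, device, model_config
)
# Each pipeline rank keeps only the stage submodules it owns in model_parts.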


def _llama_trace_input(job_config, model_config, device="meta"):
def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
"""Get meta tensors with the right input shapes used for tracing"""
tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
tokens = torch.randint(
@@ -153,18 +169,18 @@ def _llama_trace_input(job_config, model_config, device="meta"):
def _mixed_precision_dtype(
job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
) -> torch.dtype:
"""Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
"""Get the mixed precision dtype if FSDP is enabled, otherwise return the default"""
mp_arg = job_config.training.mixed_precision_param
return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default
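
A self-contained illustration of the two branches this helper takes; the trimmed dtype map and the sketch function name are hypothetical stand-ins for TORCH_DTYPE_MAP and the helper above:

_DTYPE_MAP_SKETCH = {"float32": torch.float32, "bfloat16": torch.bfloat16}

def _mixed_precision_dtype_sketch(mp_arg: str, dp_enabled: bool) -> torch.dtype:
    # Only honor the configured param dtype when FSDP (data parallel) is enabled.
    return _DTYPE_MAP_SKETCH[mp_arg] if dp_enabled else torch.float32

assert _mixed_precision_dtype_sketch("bfloat16", dp_enabled=True) is torch.bfloat16
assert _mixed_precision_dtype_sketch("bfloat16", dp_enabled=False) is torch.float32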


def pipeline_llama_manual(
whole_model,
world_mesh,
parallel_dims,
whole_model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
device,
model_config: Dict,
device: DeviceType,
model_config: ModelArgs,
):
"""
This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.
@@ -262,19 +278,24 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal


def pipeline_llama_tracer(
model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
device: DeviceType,
model_config: ModelArgs,
):
if job_config.model.norm_type == "fused_rmsnorm":
# TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
# coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
# TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr
# invocation stride in strict mode from `if dy.stride(-1) != 1:` in
# fused_rmsnorm
raise NotImplementedError(
"fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
"fused_rmsnorm is not compatible with Pipeline Tracer yet. Please use rmsnorm or layernorm."
)

if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
raise NotImplementedError(
"pipeline tracer doesn't work with fsdp mixed precision currently. "
"To work around, edit fsdp mixed precision config to use fp32."
"Pipeline tracer does not work with FSDP mixed precision yet. "
"To work around, set mixed_precision_param to float32."
)

pp_mesh = world_mesh["pp"]
@@ -310,10 +331,13 @@ def pipeline_llama_tracer(
return (stages, models)


def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply tensor parallelism.
"""
def apply_tp(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
):
"""Apply tensor parallelism."""

tp_mesh = world_mesh["tp"]
# Parallel styles used for transformer block linear weights and their
@@ -392,10 +416,8 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
return model
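
Most of apply_tp's body is collapsed in this hunk; as background, a sketch of the primitive it builds on. The module paths in the plan are illustrative only; the real parallelize_plan in this file also covers embeddings, norms, the MLP, and the output projection:

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

def tp_sketch(transformer_block: nn.Module, tp_mesh: DeviceMesh) -> nn.Module:
    # Shard the attention projections column-wise on the way in and
    # row-wise on the way out, so the block's output stays replicated.
    return parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan={
            "attention.wq": ColwiseParallel(),
            "attention.wk": ColwiseParallel(),
            "attention.wv": ColwiseParallel(),
            "attention.wo": RowwiseParallel(),
        },
    )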


def apply_ac(model, job_config: JobConfig):
"""
Apply activation checkpointing to the model.
"""
def apply_ac(model: nn.Module, job_config: JobConfig):
"""Apply activation checkpointing to the model."""

ac_config = job_config.activation_checkpoint

@@ -407,18 +429,15 @@ def apply_ac(model, job_config: JobConfig):
return model
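
The selective-AC logic driven by no_recompute_list above is collapsed here; for context, a minimal sketch of the full-checkpointing case, assuming the standard PT-D checkpoint wrapper:

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

def ac_sketch(transformer_block: nn.Module) -> nn.Module:
    # Recompute the block's activations during backward instead of storing them.
    return checkpoint_wrapper(transformer_block, preserve_rng_state=False)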


def apply_compile(model, job_config: JobConfig):
"""
Apply torch.compile to the model.
"""
def apply_compile(model: nn.Module, job_config: JobConfig):
"""Apply torch.compile to each transformer block."""

if job_config.model.norm_type == "fused_rmsnorm":
raise NotImplementedError(
"fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
"fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
)

for layer_id, transformer_block in model.layers.named_children():
# turn on per-transformer block compile after AC wrapping and before FSDP
# TODO: dynamic shapes have some issues, so we turn them off for now.
# TODO: inlining inbuilt nn modules does not work yet; enable it to accelerate
# compile time.
@@ -430,10 +449,13 @@ def apply_compile(model, job_config: JobConfig):
return model
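
A sketch consistent with the per-block compile loop shown above; treating model.layers as a container of transformer blocks and keeping dynamic=False (per the dynamic-shape TODO) are assumptions carried over from this diff:

def compile_sketch(model: nn.Module) -> nn.Module:
    for layer_id, transformer_block in model.layers.named_children():
        compiled = torch.compile(transformer_block, dynamic=False)
        model.layers.register_module(layer_id, compiled)
    return model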


def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply data parallelism (FSDP2) to the model.
"""
def apply_dp(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
):
"""Apply data parallelism (FSDP2) to the model."""

dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
@@ -466,7 +488,12 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
return model
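
A sketch of the FSDP2 primitives imported at the top of this file (fully_shard, MixedPrecisionPolicy); the dtype choices here are assumptions, whereas the real values come from the job config's mixed precision settings:

def dp_sketch(transformer_block: nn.Module, dp_mesh: DeviceMesh) -> nn.Module:
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,  # assumed; normally read from job_config
        reduce_dtype=torch.float32,
    )
    # Shard parameters over dp_mesh and free them again after each forward.
    fully_shard(
        transformer_block,
        mesh=dp_mesh,
        mp_policy=mp_policy,
        reshard_after_forward=True,
    )
    return transformer_block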


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
def parallelize_llama(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
):
"""
Apply tensor parallelism, activation checkpointing, torch.compile, and data
parallelism to the model.
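
Finally, a sketch of the composition order the docstring describes; the gating flags parallel_dims.tp_enabled, job_config.training.compile, and parallel_dims.dp_enabled are inferred rather than shown in this hunk:

def parallelize_sketch(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: "ParallelDims",
    job_config: JobConfig,
) -> nn.Module:
    if parallel_dims.tp_enabled:
        model = apply_tp(model, world_mesh, parallel_dims, job_config)
    model = apply_ac(model, job_config)
    if job_config.training.compile:
        model = apply_compile(model, job_config)
    if parallel_dims.dp_enabled:
        model = apply_dp(model, world_mesh, parallel_dims, job_config)
    return model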