diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index 8a21e53ddc..7eb16a1f3f 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -11,6 +11,7 @@
 from torchtitan.components.optimizer import build_optimizers
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
+
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

 from .infra.parallelize import parallelize_deepseekv3
diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index f8090683c1..99338663f6 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -4,13 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-
 import torch.nn as nn
-
 from torch.distributed.device_mesh import DeviceMesh

-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
+from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp
+from torchtitan.tools.logging import logger


 def parallelize_deepseekv3(
@@ -19,5 +19,32 @@ def parallelize_deepseekv3(
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ):
-    # TODO: Add support for parallelizing the model, this is a placeholder function for now
+    if job_config.activation_checkpoint.mode != "none":
+        apply_ac(model, job_config.activation_checkpoint)
+
+    dp_mesh: DeviceMesh | None = None
+    if (
+        parallel_dims.dp_shard_enabled
+    ):  # apply FSDP or HSDP, potentially with Context Parallel
+        if parallel_dims.dp_replicate_enabled:
+            dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        else:
+            dp_mesh_dim_names = ("dp_shard",)
+        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
+
+        apply_fsdp(
+            model,
+            dp_mesh,
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
+            pp_enabled=parallel_dims.pp_enabled,
+            cpu_offload=job_config.training.enable_cpu_offload,
+            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
+        )
+
+        if parallel_dims.dp_replicate_enabled:
+            logger.info("Applied HSDP to the model")
+        else:
+            logger.info("Applied FSDP to the model")
+
     return model
diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py
index c0134bf548..09e882764f 100644
--- a/torchtitan/models/deepseek_v3/model/args.py
+++ b/torchtitan/models/deepseek_v3/model/args.py
@@ -111,7 +111,6 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in
         nparams_dense = 0

         for name, p in model.named_parameters():
-            print(name)
             if "embedding" in name:
                 nparams_embedding += p.numel()
                 nparams_dense += p.numel()
diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py
index c5ee02327a..3eb0f2fbc6 100644
--- a/torchtitan/models/deepseek_v3/model/model.py
+++ b/torchtitan/models/deepseek_v3/model/model.py
@@ -217,6 +217,10 @@ def forward(
             [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1
         )  # (bsz, seqlen, n_heads, qk_head_dim)

+        q = q.transpose(1, 2)  # (bsz, n_heads, seqlen, qk_head_dim)
+        k = k.transpose(1, 2)  # (bsz, n_heads, seqlen, qk_head_dim)
+        v = v.transpose(1, 2)  # (bsz, n_heads, seqlen, v_head_dim)
+
         # TODO: Need to pass softmax_scale to sdpa() interface.
         # For mask, DeepseekV3 uses causal mask, so we can use the default mask in sdpa
         # https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L17
@@ -310,11 +314,10 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
             "freqs_cis", precompute_freqs_cis(model_args), persistent=False
         )

-        self.layers = torch.nn.ModuleList()
+        self.layers = torch.nn.ModuleDict()
         for layer_id in range(model_args.n_layers):
-            self.layers.append(
-                TransformerBlock(layer_id=layer_id, model_args=model_args)
-            )
+            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
+
         self.norm = nn.RMSNorm(model_args.dim)
         self.output = nn.Linear(
             model_args.dim, model_args.vocab_size, dtype=torch.get_default_dtype()
@@ -333,7 +336,7 @@ def forward(self, tokens: torch.Tensor):
         """
         h = self.tok_embeddings(tokens)

-        for layer in self.layers:
+        for layer in self.layers.values():
             h = layer(h, self.freqs_cis)
         h = self.norm(h)
         output = self.output(h)  # (batch_size, seq_len, dim)
diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py
index 3e17968e11..c9217c8be8 100644
--- a/torchtitan/models/deepseek_v3/model/moe.py
+++ b/torchtitan/models/deepseek_v3/model/moe.py
@@ -307,7 +307,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         # shape (bs*slen*top_k, dim)
         routed_output = self.experts(routed_input, num_local_tokens_per_expert)
-        routed_output = routed_output * top_scores.unsqueeze(-1)
+        routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
+            x.dtype
+        )

         # shared expert
         if self.shared_expert is not None:
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
new file mode 100644
index 0000000000..4f08fb0982
--- /dev/null
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
@@ -0,0 +1,67 @@
+# torchtitan Config.toml
+
+[job]
+dump_folder = "./outputs"
+description = "DeepSeek-V3 16B model training"
+print_args = false
+
+[profiling]
+enable_profiling = false
+save_traces_folder = "profile_trace"
+profile_freq = 10
+enable_memory_snapshot = false
+save_memory_snapshot_folder = "memory_snapshot"
+
+[metrics]
+log_freq = 1
+disable_color_printing = false
+enable_tensorboard = false
+save_tb_folder = "tb"
+enable_wandb = false
+
+[model]
+name = "deepseek_v3"
+flavor = "16B"
+# test tokenizer.model, for debug purpose only
+tokenizer_path = "./tests/assets/test_tiktoken.model"
+# converters = ["float8"]
+
+[optimizer]
+name = "AdamW"
+lr = 8e-4
+eps = 1e-8
+
+[lr_scheduler]
+warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
+decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
+decay_type = "linear"
+lr_min = 0.0
+
+[training]
+local_batch_size = 32
+seq_len = 2048
+max_norm = 1.0  # grad norm clipping
+steps = 10
+compile = false
+dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
+
+[parallelism]
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1
+fsdp_reshard_after_forward = "default"  # default / never / always
+
+[checkpoint]
+enable_checkpoint = false
+folder = "checkpoint"
+interval = 10
+last_save_model_weights_only = false
+export_dtype = "float32"
+async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
+
+[activation_checkpoint]
+mode = "full"  # ["none", "selective", "full"]
+
+[float8]
+enable_fsdp_float8_all_gather = false
+precompute_float8_dynamic_scale_for_fsdp = false
+filter_fqns = ["output"]
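Note on the q/k/v transposes added in model/model.py: `torch.nn.functional.scaled_dot_product_attention` takes tensors laid out as (batch, heads, seq, head_dim), so the (batch, seq, heads, head_dim) projections are transposed before the call. The snippet below is a standalone sketch with illustrative shapes only (it is not torchtitan code, and in the patch qk_head_dim and v_head_dim may differ):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes; the transpose pattern is what matters here.
bsz, seqlen, n_heads, head_dim = 2, 16, 4, 8
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)

# sdpa expects (bsz, n_heads, seqlen, head_dim), hence the transpose(1, 2) calls.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
)
print(out.shape)  # torch.Size([2, 4, 16, 8])
```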