9 changes: 6 additions & 3 deletions run_llama_train.sh
@@ -8,7 +8,8 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 # e.g.
 # LOG_RANK=0,1 NGPU=4 SP=2 ./run_llama_train.sh
 
-MODEL=${MODEL:-"debugmodel"}
+MODEL=${MODEL:-"llama"}
+MODEL_CONF=${MODEL_CONF:-"debugmodel"}
 NGPU=${NGPU:-"8"}
 PP=${PP:-"1"}
 SP=${SP:-"1"}
@@ -24,6 +25,8 @@ CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}
 
 torchrun --nproc_per_node=${NGPU} \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --steps 10 --compile \
---pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP}
+train.py --steps 10 \
+--model ${MODEL} --model_conf ${MODEL_CONF} \
+--pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
+--compile \
 --checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
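
Note: the script now passes the new `--model` / `--model_conf` pair through to `train.py`, splitting "which model family" from "which size/config". The corresponding parser entries are not visible in this excerpt; a hypothetical sketch of what they would look like, following the style of the existing `--warmup_pct` argument:

```python
# Hypothetical argparse wiring for the new flags (not shown in this diff excerpt).
parser.add_argument(
    "--model", type=str, default="llama", help="which model family to train"
)
parser.add_argument(
    "--model_conf", type=str, default="debugmodel", help="which model config to use"
)
```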
4 changes: 3 additions & 1 deletion torchtrain/models/llama/__init__.py
@@ -6,9 +6,11 @@
 __all__ = ["Transformer"]
 
 llama_configs = {
-    "debugmodel": ModelArgs(dim=256, n_layers=1, n_heads=16),
+    "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16),
     "1B": ModelArgs(dim=1024, n_layers=16, n_heads=8),
     "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
+    "13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
+    "40B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
     "70B": ModelArgs(
         dim=8192,
         n_layers=80,
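
Note: `llama_configs` maps a size name to a `ModelArgs` instance, so the new `--model_conf` value can simply index this dict once the family (`--model llama`) has selected the module. A minimal lookup sketch (the exact wiring inside `train.py` is not part of this diff):

```python
# Minimal sketch: resolve a --model_conf value to a ModelArgs.
from torchtrain.models.llama import llama_configs  # module-level name; __all__ only limits wildcard imports

model_conf = "debugmodel"                # e.g. the value of --model_conf
model_args = llama_configs[model_conf]   # ModelArgs(dim=256, n_layers=2, n_heads=16)
print(model_args.dim, model_args.n_layers, model_args.n_heads)  # 256 2 16
```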
66 changes: 62 additions & 4 deletions torchtrain/models/llama/model.py
@@ -8,6 +8,8 @@
 import torch.nn.functional as F
 from torch import nn
 
+from torchtrain.logging_utils import rank0_log
+
 
 @dataclass
 class ModelArgs:
@@ -183,7 +185,6 @@ class Attention(nn.Module):
     """
 
     def __init__(self, model_args: ModelArgs):
-
         super().__init__()
         self.n_heads = model_args.n_heads
         self.n_kv_heads = (
@@ -203,6 +204,20 @@ def __init__(self, model_args: ModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
 
+    def reset_parameters(self, init_std):
+        for item in (self.wq, self.wk, self.wv):
+            nn.init.trunc_normal_(
+                item.weight,
+                mean=0.0,
+                std=0.02,
+            )
+
+        nn.init.trunc_normal_(
+            self.wo.weight,
+            mean=0.0,
+            std=init_std,
+        )
+
     def forward(
         self,
         x: torch.Tensor,
@@ -277,7 +292,6 @@ def __init__(
         multiple_of: int,
         ffn_dim_multiplier: Optional[float],
     ):
-
         super().__init__()
         hidden_dim = int(2 * hidden_dim / 3)
         # custom dim factor multiplier
@@ -292,6 +306,20 @@ def __init__(
     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
+    def reset_parameters(self, init_std):
+        nn.init.trunc_normal_(
+            self.w1.weight,
+            mean=0.0,
+            std=0.02,
+        )
+
+        for item in (self.w2, self.w3):
+            nn.init.trunc_normal_(
+                item.weight,
+                mean=0.0,
+                std=init_std,
+            )
+
 
 class RotaryEmbedding(nn.Module):
     """
@@ -350,7 +378,6 @@ class TransformerBlock(nn.Module):
     """
 
     def __init__(self, layer_id: int, model_args: ModelArgs):
-
         super().__init__()
         self.n_heads = model_args.n_heads
         self.dim = model_args.dim
@@ -362,8 +389,10 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
             ffn_dim_multiplier=model_args.ffn_dim_multiplier,
         )
         self.layer_id = layer_id
+        self.num_layers = model_args.n_layers
         self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
 
     def forward(
         self,
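
Note: `weight_init_std` gives every block a depth-dependent scale in the spirit of GPT-2's residual scaling: the projections that write back into the residual stream (`wo`, `w2`, `w3`) are initialized with std `0.02 / sqrt(2 * n_layers)`, while the input projections (`wq`, `wk`, `wv`, `w1`) keep a fixed std of 0.02. Worked out for a few of the configs above:

```python
# Depth-scaled init std for the residual-output projections, per the formula above.
for n_layers in (2, 32, 80):              # debugmodel, 7B, 70B
    std = 0.02 / (2 * n_layers) ** 0.5
    print(n_layers, round(std, 5))        # 2 -> 0.01, 32 -> 0.0025, 80 -> 0.00158
```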
@@ -385,6 +414,14 @@ def forward(
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
+    def reset_parameters(self):
+        """reset params and norms for entire block"""
+        self.attention_norm.reset_parameters()
+        self.ffn_norm.reset_parameters()
+
+        self.attention.reset_parameters(self.weight_init_std)
+        self.feed_forward.reset_parameters(self.weight_init_std)
+
 
 class Transformer(nn.Module):
     """
@@ -406,11 +443,11 @@ class Transformer(nn.Module):
     """
 
     def __init__(self, model_args: ModelArgs):
-
         super().__init__()
         self.model_args = model_args
         self.vocab_size = model_args.vocab_size
         self.n_layers = model_args.n_layers
+        self.model_dim = model_args.dim
 
         self.embeddings = RotaryEmbedding(model_args)
 
@@ -421,6 +458,27 @@ def __init__(self, model_args: ModelArgs):
         self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
 
+        # init model weights
+        self.reset_parameters()
+        rank0_log(f"Model built with: {self.model_args}")
+
+    def reset_parameters(
+        self,
+    ):
+        for layer in self.layers:
+            layer.reset_parameters()
+        self.norm.reset_parameters()
+        final_out_std = self.model_dim**-0.5
+        cutoff_factor = 3
+        nn.init.trunc_normal_(
+            self.output.weight,
+            mean=0.0,
+            std=final_out_std,
+            a=-cutoff_factor * final_out_std,
+            b=cutoff_factor * final_out_std,
+        )
+        rank0_log("Model fully initialized via reset_params")
+
     def forward(self, tokens: torch.Tensor):
         """
         Perform a forward pass through the Transformer model.
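
Note: `Transformer.reset_parameters` cascades through every block and norm, then draws the output projection from a truncated normal with std `dim ** -0.5`, clipped at ±3 std. A rough sanity check, assuming `ModelArgs` exposes `vocab_size` as a constructor field and that its remaining fields have usable defaults (neither is shown in this diff):

```python
# Hypothetical sanity check of the init scheme; fields beyond those in the diff are assumptions.
from torchtrain.models.llama.model import ModelArgs, Transformer

args = ModelArgs(dim=256, n_layers=2, n_heads=16, vocab_size=1000)
model = Transformer(args)

print(model.layers[0].weight_init_std)    # 0.02 / (2 * 2) ** 0.5 = 0.01
print(model.output.weight.std().item())   # roughly 256 ** -0.5 ≈ 0.0625 (slightly lower due to truncation)
```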
3 changes: 2 additions & 1 deletion train.py
@@ -78,6 +78,7 @@ def main(args):
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
 
     model_name = args.model
+    rank0_log(f"Building {model_name}")
     # build tokenizer
    tokenizer_type = model_name_to_tokenizer[model_name]
     tokenizer = create_tokenizer(tokenizer_type, args.tokenizer_path)
@@ -222,7 +223,7 @@ def main(args):
     parser.add_argument(
         "--warmup_pct",
         type=float,
-        default=0.10,
+        default=0.20,
         help="percentage of total training steps to use for warmup",
     )
     parser.add_argument(
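
Note: the default warmup fraction doubles from 10% to 20% of total training steps. If the scheduler derives the warmup length as `warmup_pct * steps` (the scheduler code is not in this diff), the debug run's `--steps 10` now warms up for 2 steps instead of 1:

```python
# Illustration only; how train.py turns warmup_pct into warmup steps is not shown in this diff.
steps = 10                                   # --steps 10 in run_llama_train.sh
print(int(0.10 * steps), int(0.20 * steps))  # 1 2
```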