diff --git a/train.py b/train.py
index ea6cdc3fc0..cdee551051 100644
--- a/train.py
+++ b/train.py
@@ -230,7 +230,9 @@ def loss_fn(pred, labels):
                 True
             )
         logger.info("Compiling model with torch.compile")
-        model = torch.compile(model)
+        # Dynamic shapes have issues with distributed training, so turn dynamic off, as
+        # Transformer training uses static shapes. TODO: resolve dynamic shape issue and restore defaults.
+        model = torch.compile(model, dynamic=False)
 
     train_state = TrainState()
 
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index 3b338d1135..9d3f5e3e42 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama3"
 flavor = "8B"
-norm_type = "fused_rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
+norm_type = "rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
 
 [optimizer]
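
For context, a minimal standalone sketch of what the train.py hunk does. The toy module and input sizes below are hypothetical stand-ins for the real Llama model; only the `torch.compile(..., dynamic=False)` call reflects the actual change in this diff:

```python
# Sketch: compiling a model with dynamic shapes disabled, as in the
# train.py hunk above. The toy module and shapes are assumptions,
# not torchtitan code; only dynamic=False comes from this diff.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# dynamic=False makes the compiler specialize on the exact input shapes
# it sees instead of tracing symbolic (dynamic) shapes. Transformer
# training here uses fixed batch and sequence sizes, so static
# specialization is safe and avoids the distributed dynamic-shape issues.
model = torch.compile(model, dynamic=False)

x = torch.randn(8, 512)  # fixed shape; a new shape would trigger a recompile
out = model(x)
```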