Add truncated llama-style model init via reset_parameters() (#54)
This PR adds the following:
1 - via reset_parameters(), a full layerwise init for the llama models
under /llama. This uses the total model depth as part of the init (see the sketch after this list):
self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
2 - The final output ffn (head) is initialized based on the square root of the
model's dim, with a slightly wider truncation cutoff factor of 3.
3 - tangential change - updates run_llama_train.sh with updated MODEL
and MODEL_CONF params to allow direct model control via the sh
script. (There was already a MODEL variable, but it was incorrectly being
used in place of MODEL_CONF; we should revisit this since the naming
isn't intuitive.)
4 - made the debugmodel default to 2 layers as an improved debug check.
5 - added 1B and 40B configs for additional testing. I can't currently
run 70B on my H100 due to OOM, but can run 40B.
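For reference, a minimal sketch of the init scheme described in items 1 and 2, assuming `torch.nn.init.trunc_normal_` and illustrative module/parameter names (the real per-layer structure under /llama differs); the `1/sqrt(dim)` std for the head is an assumption based on "sqrt of the dim":

```python
import torch
import torch.nn as nn


class Transformer(nn.Module):
    def __init__(self, dim: int, num_layers: int, vocab_size: int):
        super().__init__()
        self.num_layers = num_layers
        # Stand-ins for the real decoder blocks; only the init logic matters here.
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        self.output = nn.Linear(dim, vocab_size, bias=False)  # final output head
        self.reset_parameters()

    def reset_parameters(self):
        # Item 1: depth-scaled std for the per-layer weights.
        self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
        for layer in self.layers:
            nn.init.trunc_normal_(
                layer.weight,
                mean=0.0,
                std=self.weight_init_std,
                a=-2 * self.weight_init_std,
                b=2 * self.weight_init_std,
            )
        # Item 2: head std tied to the model dim (assumed 1/sqrt(dim)),
        # truncated at a slightly wider cutoff of 3 std.
        head_std = self.output.in_features ** -0.5
        cutoff_factor = 3
        nn.init.trunc_normal_(
            self.output.weight,
            mean=0.0,
            std=head_std,
            a=-cutoff_factor * head_std,
            b=cutoff_factor * head_std,
        )
```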
Testing:
Verified proper init and training with 7B, 13B and ~40B:
<img width="1085" alt="Screenshot 2024-02-11 at 10 39 12 PM"
src="https://github.com/pytorch-labs/torchtrain/assets/46302957/049037ed-63a4-4ab0-bebc-f297857aab72">
[ghstack-poisoned]