From 1773aee1810c5c47f45ec6867b53c610d43a00af Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 5 Apr 2024 09:39:42 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- torchtrain/models/llama/model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py index 77c5e4006f..410875b046 100644 --- a/torchtrain/models/llama/model.py +++ b/torchtrain/models/llama/model.py @@ -313,9 +313,7 @@ def __init__(self, model_args: ModelArgs): super().__init__() self.model_args = model_args self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) - self.register_buffer( - "freqs_cis", self._precompute_freqs_cis(), persistent=False - ) + self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) def _precompute_freqs_cis(self): return precompute_freqs_cis( From 5cea973b823db60cdb07a31245608d774f0d2d51 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 5 Apr 2024 14:21:43 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- torchtrain/models/llama/model.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py index 410875b046..0ab1e97733 100644 --- a/torchtrain/models/llama/model.py +++ b/torchtrain/models/llama/model.py @@ -313,6 +313,14 @@ def __init__(self, model_args: ModelArgs): super().__init__() self.model_args = model_args self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. 
(2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than by calling init_weights, freqs_cis must either + # be restored from that checkpoint (which requires a persistent buffer), or be + # recomputed by a separate non-persistent-buffer initializer run after loading. self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) def _precompute_freqs_cis(self):