38 changes: 21 additions & 17 deletions pytorch_lightning/utilities/deepspeed.py
@@ -12,23 +12,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified script from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py
#
# This script extracts fp32 consolidated weights from ZeRO 2 and ZeRO 3 DeepSpeed checkpoints. It gets
# copied into the top-level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application. Additionally, the script has been modified to keep the Lightning state inside the
# state dict so that Model.load_from_checkpoint('...') still works.
#
# Example usage within the Lightning checkpoint directory where 'latest' is found:
#
# from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
#
# Lightning DeepSpeed has saved a directory instead of a file
# save_path = "lightning_logs/version_0/checkpoints/epoch=0-step=0.ckpt/"
# output_path = "lightning_model.pt"
# convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path)

"""
Modified script from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py

This script extracts fp32 consolidated weights from ZeRO 2 and ZeRO 3 DeepSpeed checkpoints. It gets
copied into the top-level checkpoint dir, so the user can easily do the conversion at any point in
the future. Once extracted, the weights don't require DeepSpeed and can be used in any
application. Additionally, the script has been modified to keep the Lightning state inside the
state dict so that Model.load_from_checkpoint('...') still works.

Example usage within the Lightning checkpoint directory where 'latest' is found:

>>> from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict # doctest: +SKIP

# Lightning DeepSpeed has saved a directory instead of a file

>>> save_path = "lightning_logs/version_0/checkpoints/epoch=0-step=0.ckpt/" # doctest: +SKIP
>>> output_path = "lightning_model.pt" # doctest: +SKIP
>>> convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path) # doctest: +SKIP
Saving fp32 state dict to lightning_model.pt
"""

import os

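Putting the docstring example together end to end: below is a minimal sketch of converting a ZeRO checkpoint and then reloading it without DeepSpeed, in the spirit of the doctest above. LitModel and my_project are hypothetical placeholders for the user's own LightningModule subclass and module; they are not part of this PR.

from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

# Hypothetical user-defined LightningModule subclass, standing in for your model.
from my_project import LitModel

# Lightning's DeepSpeed strategy saves the checkpoint as a directory, not a file.
save_path = "lightning_logs/version_0/checkpoints/epoch=0-step=0.ckpt/"
output_path = "lightning_model.pt"

# Consolidate the sharded ZeRO states into a single fp32 state dict file.
convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path)

# Because the Lightning state is kept inside the state dict, the converted
# file loads like an ordinary Lightning checkpoint, with no DeepSpeed required.
model = LitModel.load_from_checkpoint(output_path)
model.eval()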