|
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | | -# |
16 | | -# Modified script from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py |
17 | | -# |
18 | | -# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets |
19 | | -# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in |
20 | | -# the future. Once extracted, the weights don't require DeepSpeed and can be used in any |
21 | | -# application. Additionally the script has been modified to ensure we keep the lightning state inside the state dict |
22 | | -# for being able to run Model.load_from_checkpoint('...'). |
23 | | -# |
24 | | -# example usage within the lightning checkpoint directory where 'latest' is found: |
25 | | -# |
26 | | -# from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict |
27 | | -# |
28 | | -# lightning deepspeed has saved a directory instead of a file |
29 | | -# save_path = "lightning_logs/version_0/checkpoints/epoch=0-step=0.ckpt/" |
30 | | -# output_path = "lightning_model.pt" |
31 | | -# convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path) |
| 15 | + |
| 16 | +""" |
| 17 | +Modified script from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py |
| 18 | +
|
| 19 | +This script extracts consolidated fp32 weights from DeepSpeed ZeRO stage 2 and stage 3 checkpoints. It gets |
| 20 | +copied into the top level checkpoint dir, so the user can easily do the conversion at any point in |
| 21 | +the future. Once extracted, the weights don't require DeepSpeed and can be used in any |
| 22 | +application. Additionally, the script has been modified to keep the Lightning state inside the state dict |
| 23 | +so that Model.load_from_checkpoint('...') can be run on the result. |
| 24 | +
|
| 25 | +Example usage within the Lightning checkpoint directory where 'latest' is found: |
| 26 | +
|
| 27 | +>>> from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict # doctest: +SKIP |
| 28 | +
|
| 29 | +# Lightning with DeepSpeed saves a directory instead of a file |
| 30 | +
|
| 31 | +>>> save_path = "lightning_logs/version_0/checkpoints/epoch=0-step=0.ckpt/" # doctest: +SKIP |
| 32 | +>>> output_path = "lightning_model.pt" # doctest: +SKIP |
| 33 | +>>> convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path) # doctest: +SKIP |
| 34 | +Saving fp32 state dict to lightning_model.pt |
| 35 | +""" |
32 | 36 |
|
33 | 37 | import os |
34 | 38 |
|
|
0 commit comments