66"""
77
88import glob
9+ import logging
910import operator
1011import os
11- import logging
12+ import shutil
1213
1314import torch
1415
@@ -32,7 +33,8 @@ def __init__(
3233 recovery_dir = '' ,
3334 decreasing = False ,
3435 max_history = 10 ,
35- unwrap_fn = unwrap_model ):
36+ unwrap_fn = unwrap_model
37+ ):
3638
3739 # objects to save state_dicts of
3840 self .model = model
@@ -46,7 +48,8 @@ def __init__(
4648 self .best_epoch = None
4749 self .best_metric = None
4850 self .curr_recovery_file = ''
49- self .last_recovery_file = ''
51+ self .prev_recovery_file = ''
52+ self .can_hardlink = True
5053
5154 # config
5255 self .checkpoint_dir = checkpoint_dir
@@ -60,41 +63,26 @@ def __init__(
6063 self .unwrap_fn = unwrap_fn
6164 assert self .max_history >= 1
6265
63- def save_checkpoint (self , epoch , metric = None ):
64- assert epoch >= 0
65- tmp_save_path = os .path .join (self .checkpoint_dir , 'tmp' + self .extension )
66- last_save_path = os .path .join (self .checkpoint_dir , 'last' + self .extension )
67- self ._save (tmp_save_path , epoch , metric )
68- if os .path .exists (last_save_path ):
69- os .unlink (last_save_path ) # required for Windows support.
70- os .rename (tmp_save_path , last_save_path )
71- worst_file = self .checkpoint_files [- 1 ] if self .checkpoint_files else None
72- if (len (self .checkpoint_files ) < self .max_history
73- or metric is None or self .cmp (metric , worst_file [1 ])):
74- if len (self .checkpoint_files ) >= self .max_history :
75- self ._cleanup_checkpoints (1 )
76- filename = '-' .join ([self .save_prefix , str (epoch )]) + self .extension
77- save_path = os .path .join (self .checkpoint_dir , filename )
78- os .link (last_save_path , save_path )
79- self .checkpoint_files .append ((save_path , metric ))
80- self .checkpoint_files = sorted (
81- self .checkpoint_files , key = lambda x : x [1 ],
82- reverse = not self .decreasing ) # sort in descending order if a lower metric is not better
83-
84- checkpoints_str = "Current checkpoints:\n "
85- for c in self .checkpoint_files :
86- checkpoints_str += ' {}\n ' .format (c )
87- _logger .info (checkpoints_str )
88-
89- if metric is not None and (self .best_metric is None or self .cmp (metric , self .best_metric )):
90- self .best_epoch = epoch
91- self .best_metric = metric
92- best_save_path = os .path .join (self .checkpoint_dir , 'model_best' + self .extension )
93- if os .path .exists (best_save_path ):
94- os .unlink (best_save_path )
95- os .link (last_save_path , best_save_path )
96-
97- return (None , None ) if self .best_metric is None else (self .best_metric , self .best_epoch )
def _replace(self, src, dst):
    """Move ``src`` onto ``dst``, overwriting ``dst`` if it already exists.

    ``os.replace`` overwrites an existing destination file on every
    platform, so ``dst`` is not unlinked first. Removing it beforehand would
    leave a window in which no file exists at ``dst`` if the process is
    interrupted between the two calls. It would also let an unrelated unlink
    failure turn off ``self.can_hardlink``, which only describes whether
    hard links can be created.

    Args:
        src: Path of the freshly written file to move.
        dst: Final path; any existing file there is replaced.

    Raises:
        OSError: Propagated from ``os.replace`` when the move fails.
    """
    os.replace(src, dst)
74+
75+ def _duplicate (self , src , dst ):
76+ if self .can_hardlink :
77+ try :
78+ if os .path .exists (dst ):
79+ # for Windows
80+ os .unlink (dst )
81+ os .link (src , dst )
82+ return
83+ except (OSError , NotImplementedError ) as e :
84+ self .can_hardlink = False
85+ shutil .copy2 (src , dst )
9886
9987 def _save (self , save_path , epoch , metric = None ):
10088 save_state = {
@@ -129,18 +117,61 @@ def _cleanup_checkpoints(self, trim=0):
129117 _logger .error ("Exception '{}' while deleting checkpoint" .format (e ))
130118 self .checkpoint_files = self .checkpoint_files [:delete_index ]
131119
def save_checkpoint(self, epoch, metric=None):
    """Save the current state as the 'last' checkpoint and update the history.

    The state is written to a temporary file and then moved onto
    ``'last' + self.extension`` in ``self.checkpoint_dir``. That file is
    duplicated into a per-epoch checkpoint when the history holds fewer than
    ``self.max_history`` entries, when ``metric`` is None, or when
    ``self.cmp`` ranks ``metric`` ahead of the last (worst) history entry.
    It is duplicated to ``'model_best' + self.extension`` when ``metric``
    beats ``self.best_metric``.

    Args:
        epoch: Non-negative epoch index, used in the checkpoint file name.
        metric: Optional score for this epoch.

    Returns:
        ``(best_metric, best_epoch)``, or ``(None, None)`` while no best
        metric has been recorded.
    """
    assert epoch >= 0
    tmp_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
    last_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
    self._save(tmp_path, epoch, metric)
    self._replace(tmp_path, last_path)

    worst_entry = self.checkpoint_files[-1] if self.checkpoint_files else None
    has_room = len(self.checkpoint_files) < self.max_history
    if has_room or metric is None or self.cmp(metric, worst_entry[1]):
        if not has_room:
            # history is full: drop one entry to make space for this one
            self._cleanup_checkpoints(1)
        epoch_name = '-'.join([self.save_prefix, str(epoch)]) + self.extension
        epoch_path = os.path.join(self.checkpoint_dir, epoch_name)
        self._duplicate(last_path, epoch_path)

        self.checkpoint_files.append((epoch_path, metric))
        # sort in descending order if a lower metric is not better
        self.checkpoint_files = sorted(
            self.checkpoint_files,
            key=lambda entry: entry[1],
            reverse=not self.decreasing,
        )

        listing = ''.join(' {}\n '.format(entry) for entry in self.checkpoint_files)
        _logger.info("Current checkpoints:\n " + listing)

        is_new_best = metric is not None and (
            self.best_metric is None or self.cmp(metric, self.best_metric)
        )
        if is_new_best:
            self.best_epoch = epoch
            self.best_metric = metric
            best_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
            self._duplicate(last_path, best_path)

    if self.best_metric is None:
        return (None, None)
    return (self.best_metric, self.best_epoch)
158+
def save_recovery(self, epoch, batch_idx=0):
    """Write a recovery checkpoint and rotate the tracked recovery files.

    The state is saved to ``'recovery_tmp' + self.extension`` in
    ``self.recovery_dir`` and then moved onto a file named from
    ``self.recovery_prefix``, ``epoch`` and ``batch_idx``. The file recorded
    in ``self.prev_recovery_file`` is then deleted if it exists; a failure
    to delete it is logged and not raised. Finally the current file becomes
    the previous one and the new path becomes the current one.

    Args:
        epoch: Non-negative epoch index, used in the file name.
        batch_idx: Batch index within the epoch, used in the file name.
    """
    assert epoch >= 0
    tmp_path = os.path.join(self.recovery_dir, 'recovery_tmp' + self.extension)
    self._save(tmp_path, epoch)

    name_parts = [self.recovery_prefix, str(epoch), str(batch_idx)]
    save_path = os.path.join(self.recovery_dir, '-'.join(name_parts) + self.extension)
    self._replace(tmp_path, save_path)

    stale_path = self.prev_recovery_file
    if os.path.exists(stale_path):
        try:
            _logger.debug("Cleaning recovery: {}".format(stale_path))
            os.remove(stale_path)
        except Exception as e:
            _logger.error("Exception '{}' while removing {}".format(e, stale_path))
    # shift the window: current becomes previous, the new file becomes current
    self.prev_recovery_file, self.curr_recovery_file = self.curr_recovery_file, save_path
145176
146177 def find_recovery (self ):
0 commit comments