File tree Expand file tree Collapse file tree 2 files changed +2
-7
lines changed Expand file tree Collapse file tree 2 files changed +2
-7
lines changed Original file line number Diff line number Diff line change @@ -27,6 +27,7 @@ def on_train_start(self):
2727
2828 def on_train_end (self ):
2929 # clean up memory
30+ self .model .cpu ()
3031 with torch .cuda .device (self .root_device ):
3132 torch .cuda .empty_cache ()
3233
Original file line number Diff line number Diff line change @@ -148,13 +148,7 @@ def on_train_end(self):
148148 self .trainer .profiler .describe ()
149149
150150 # give accelerators a chance to finish
151- self .trainer .accelerator_backend .on_train_end ()
152-
153- # clear mem
154- if self .trainer ._device_type == DeviceType .GPU :
155- model = self .trainer .get_model ()
156- model .cpu ()
157- torch .cuda .empty_cache ()
151+ self .trainer .accelerator .on_train_end ()
158152
159153 def check_checkpoint_callback (self , should_update , is_last = False ):
160154 # TODO bake this logic into the ModelCheckpoint callback
You can’t perform that action at this time.
0 commit comments