@@ -17,6 +17,7 @@ def __init__(
         optimizer_class: Type[Optimizer] = torch.optim.AdamW,
         *,
         offload_gradients: bool = False,
+        minimal_size: int = 4096,
         **kwargs,
     ) -> None:
         """Offload optimizer to CPU for single-GPU training. This will reduce GPU memory by the size of optimizer state.
@@ -26,6 +27,7 @@ def __init__(
             params: a list of parameters or parameter groups.
             optimizer_class: constructor of the base optimizer. Defaults to :class:`torch.optim.AdamW`.
             offload_gradients: free GPU gradients once they are moved to CPU. Not compatible with gradient accumulation.
+            minimal_size: parameters with fewer elements than this are kept on the GPU, to avoid excessively many small transfers.
             kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`.
         """
         # default to fused CPU AdamW
@@ -42,6 +44,11 @@ def __init__(
         if not isinstance(param_groups[0], dict):
             param_groups = [{"params": param_groups}]

+        # parameters with fewer than minimal_size elements are handled by the on-device optimizer d_opt
+        self.minimal_size = minimal_size
+        self.d_opt = None
+        self.d_param_groups = []
+
         self.param_d2h_map = dict()
         self.optim_dict = dict()
         self.device = get_available_devices()[-1]
@@ -77,11 +84,16 @@ def backward_hook(p_device):

         for param_group in param_groups:
             params = param_group.pop("params")
+            retained_params = []

             for p_device in params:
                 if not p_device.requires_grad:
                     continue

+                if p_device.numel() < self.minimal_size:
+                    retained_params.append(p_device)
+                    continue
+
                 # pre-allocate CPU params and grads
                 p_host = torch.empty_like(p_device, device="cpu", pin_memory=True)
                 p_host.grad = torch.empty_like(p_host, pin_memory=True)
@@ -94,12 +106,22 @@ def backward_hook(p_device):
                     [{"params": p_host, **param_group}], **kwargs
                 )

+            if len(retained_params) > 0:
+                self.d_param_groups.append({"params": retained_params, **param_group})
+
+        if len(self.d_param_groups) > 0:
+            self.d_opt = optimizer_class(self.d_param_groups, **kwargs)
+
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
         if closure is not None:
             loss = closure()

+        # handle small parameters on the GPU, in parallel with the CPU optimizer steps below
+        if self.d_opt is not None:
+            self.d_opt.step()
+
         for p_device, grad_d2h_event in self.queue.items():
             grad_d2h_event.synchronize()
             self.optim_dict[p_device].step()
@@ -123,15 +145,35 @@ def zero_grad(self, set_to_none=True):
         for p_device in self.param_d2h_map.keys():
             p_device.grad = None

+        if self.d_opt is not None:
+            self.d_opt.zero_grad(set_to_none=set_to_none)
+
     @property
     def param_groups(self):
         # each param group will only have 1 parameter
         # TODO: we might want to return the original param_groups instead.
-        return sum((optim.param_groups for optim in self.optim_dict.values()), start=[])
+        return sum(
+            (optim.param_groups for optim in self.optim_dict.values()),
+            start=self.d_param_groups,
+        )

     def state_dict(self):
-        return [optim.state_dict() for optim in self.optim_dict.values()]
+        state_dict = {
+            "offloaded": [optim.state_dict() for optim in self.optim_dict.values()]
+        }
+        if self.d_opt:
+            state_dict["on-device"] = self.d_opt.state_dict()
+        return state_dict

     def load_state_dict(self, state_dict):
-        for optim, optim_state_dict in zip(self.optim_dict.values(), state_dict):
+        for optim, optim_state_dict in zip(
+            self.optim_dict.values(), state_dict["offloaded"]
+        ):
             optim.load_state_dict(optim_state_dict)
+
+        if self.d_opt:
+            self.d_opt.load_state_dict(state_dict["on-device"])
+        elif "on-device" in state_dict:
+            raise ValueError(
+                "loaded state dict has an 'on-device' parameter group not present in the optimizer"
+            )
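
For reference, a minimal usage sketch of the patched optimizer, showing the effect of the new minimal_size argument and the new state-dict layout. The class name CPUOffloadOptimizer and the import path are assumptions; the diff does not show the class definition or the module it lives in.

import torch
import torch.nn as nn

# Hypothetical import path; adjust to wherever this optimizer class is defined.
from cpu_offload import CPUOffloadOptimizer

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 10)).cuda()

# Parameters with at least minimal_size elements are offloaded to pinned CPU memory
# and stepped by per-parameter CPU optimizers; smaller ones (here, the 10-element
# bias of the second Linear) stay on the GPU and are stepped by the internal d_opt.
opt = CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,
    offload_gradients=False,
    minimal_size=4096,
    lr=1e-3,
    weight_decay=0.01,
)

for _ in range(3):
    loss = model(torch.randn(8, 1024, device="cuda")).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()

# state_dict() now returns a dict instead of a plain list:
# {"offloaded": [per-parameter optimizer states], "on-device": d_opt state}
# where "on-device" is present only if some parameters were small enough to retain.
ckpt = opt.state_dict()
opt.load_state_dict(ckpt)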