
Commit 3adcfae

ngc92 authored and msaroufim committed
CPUOffload: only offload parameters above a certain size (#1720)
* CPUOffload: only offload parameters above a certain size
* lint
* ruff

Co-authored-by: Mark Saroufim <[email protected]>
1 parent 8b70150 commit 3adcfae
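
For context, a rough usage sketch of the new knob (the model, learning rate, and the `from torchao.prototype.low_bit_optim import CPUOffloadOptimizer` import path are illustrative assumptions rather than part of this commit; `minimal_size=4096` matches the default added here):

import torch
import torch.nn as nn
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = nn.Sequential(nn.Linear(32, 131072), nn.ReLU(), nn.Linear(131072, 64)).cuda()

# Weights (and the 131072-element bias) are large enough to be offloaded to
# pinned CPU memory; the 64-element bias stays on the GPU and is updated by a
# regular on-device AdamW instance.
optim = CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,
    minimal_size=4096,  # new in this commit; smaller tensors are not offloaded
    lr=1e-3,
)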

File tree

2 files changed: +53 -7 lines changed


test/prototype/test_low_bit_optim.py

Lines changed: 8 additions & 4 deletions
@@ -273,11 +273,11 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         model1 = nn.Sequential(
             nn.Linear(32, 131072),
             nn.ReLU(),
-            nn.Linear(131072, 64),
+            nn.Linear(131072, 64, bias=True),
             nn.ReLU(),
-            nn.Linear(64, 64),
+            nn.Linear(64, 64, bias=True),
             nn.ReLU(),
-            nn.Linear(64, 128),
+            nn.Linear(64, 128, bias=True),
         )
         model1.to(device)

@@ -329,7 +329,11 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
     )
     def test_optim_cpu_offload_save_load(self):
         device = _DEVICES[-1]
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        # enable bias parameters so we have some small tensors that
+        # are not offloaded.
+        model1 = nn.Sequential(
+            nn.Linear(32, 1024, bias=True), nn.ReLU(), nn.Linear(1024, 128, bias=True)
+        )
         model1.to(device)
         optim1 = low_bit_optim.CPUOffloadOptimizer(
             model1.parameters(), torch.optim.AdamW
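
A quick sanity check of why bias=True matters for this test (a standalone sketch, not part of the test suite; it assumes the default minimal_size of 4096 from this commit):

import torch.nn as nn

model1 = nn.Sequential(
    nn.Linear(32, 1024, bias=True), nn.ReLU(), nn.Linear(1024, 128, bias=True)
)

# The biases (1024 and 128 elements) fall below the 4096-element default and are
# kept on the GPU, while the weights (32*1024 and 1024*128 elements) are offloaded,
# so the save/load test now exercises both the offloaded and the on-device paths.
for name, p in model1.named_parameters():
    print(name, p.numel(), "on-device" if p.numel() < 4096 else "offloaded")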

torchao/prototype/low_bit_optim/cpu_offload.py

Lines changed: 45 additions & 3 deletions
@@ -17,6 +17,7 @@ def __init__(
         optimizer_class: Type[Optimizer] = torch.optim.AdamW,
         *,
         offload_gradients: bool = False,
+        minimal_size: int = 4096,
         **kwargs,
     ) -> None:
         """Offload optimizer to CPU for single-GPU training. This will reduce GPU memory by the size of optimizer state.
@@ -26,6 +27,7 @@ def __init__(
             params: a list of parameters or parameter groups.
             optimizer_class: constructor of the base optimizer. Defaults to :class:`torch.optim.AdamW`.
             offload_gradients: free GPU gradients once they are moved to CPU. Not compatible with gradient accumulation.
+            minimal_size: tensors smaller than this are kept on the GPU, to avoid excessively many small transfers.
             kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`.
         """
         # default to fused CPU AdamW
@@ -42,6 +44,11 @@ def __init__(
         if not isinstance(param_groups[0], dict):
             param_groups = [{"params": param_groups}]

+        # any parameter smaller than minimal size will be handled by the on-device optimizer d_opt
+        self.minimal_size = minimal_size
+        self.d_opt = None
+        self.d_param_groups = []
+
         self.param_d2h_map = dict()
         self.optim_dict = dict()
         self.device = get_available_devices()[-1]
@@ -77,11 +84,16 @@ def backward_hook(p_device):

         for param_group in param_groups:
             params = param_group.pop("params")
+            retained_params = []

             for p_device in params:
                 if not p_device.requires_grad:
                     continue

+                if p_device.numel() < self.minimal_size:
+                    retained_params.append(p_device)
+                    continue
+
                 # pre-allocate CPU params and grads
                 p_host = torch.empty_like(p_device, device="cpu", pin_memory=True)
                 p_host.grad = torch.empty_like(p_host, pin_memory=True)
@@ -94,12 +106,22 @@ def backward_hook(p_device):
                     [{"params": p_host, **param_group}], **kwargs
                 )

+            if len(retained_params) > 0:
+                self.d_param_groups.append({"params": retained_params, **param_group})
+
+        if len(self.d_param_groups) > 0:
+            self.d_opt = optimizer_class(self.d_param_groups, **kwargs)
+
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
         if closure is not None:
             loss = closure()

+        # handle small parameters on the GPU, in parallel with the CPU calls below
+        if self.d_opt is not None:
+            self.d_opt.step()
+
         for p_device, grad_d2h_event in self.queue.items():
             grad_d2h_event.synchronize()
             self.optim_dict[p_device].step()
@@ -123,15 +145,35 @@ def zero_grad(self, set_to_none=True):
         for p_device in self.param_d2h_map.keys():
             p_device.grad = None

+        if self.d_opt is not None:
+            self.d_opt.zero_grad(set_to_none=set_to_none)
+
     @property
     def param_groups(self):
         # each param group will only has 1 parameter
         # TODO: we might want to return the original param_groups instead.
-        return sum((optim.param_groups for optim in self.optim_dict.values()), start=[])
+        return sum(
+            (optim.param_groups for optim in self.optim_dict.values()),
+            start=self.d_param_groups,
+        )

     def state_dict(self):
-        return [optim.state_dict() for optim in self.optim_dict.values()]
+        state_dict = {
+            "offloaded": [optim.state_dict() for optim in self.optim_dict.values()]
+        }
+        if self.d_opt:
+            state_dict["on-device"] = self.d_opt.state_dict()
+        return state_dict

     def load_state_dict(self, state_dict):
-        for optim, optim_state_dict in zip(self.optim_dict.values(), state_dict):
+        for optim, optim_state_dict in zip(
+            self.optim_dict.values(), state_dict["offloaded"]
+        ):
             optim.load_state_dict(optim_state_dict)
+
+        if self.d_opt:
+            self.d_opt.load_state_dict(state_dict["on-device"])
+        elif "on-device" in state_dict:
+            raise ValueError(
+                "loaded state dict has a 'on-device' parameter group not present in the optimizer"
+            )
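
Because the optimizer state dict changes from a plain list to a dict with "offloaded" and (optionally) "on-device" entries, a save/load round trip now looks roughly like this (a sketch assuming a CUDA device and the torchao prototype module; the file name is illustrative):

import torch
import torch.nn as nn
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = nn.Sequential(
    nn.Linear(32, 1024, bias=True), nn.ReLU(), nn.Linear(1024, 128, bias=True)
).cuda()
optim1 = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-3)

# ... run some training steps with optim1 ...

# "offloaded": one state dict per CPU-resident per-parameter optimizer;
# "on-device": present only when some parameters were small enough to be retained.
torch.save(optim1.state_dict(), "optim.pt")

optim2 = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-3)
optim2.load_state_dict(torch.load("optim.pt"))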
