
Commit b5c71f1

piEsposito authored and patrickvonplaten committed
add accelerate to load models with smaller memory footprint (huggingface#361)
* add accelerate to load models with smaller memory footprint
* remove low_cpu_mem_usage as it is redundant
* move accelerate init weights context to modeling utils
* add test to ensure results are the same when loading with accelerate
* add tests to ensure RAM usage gets lower when using accelerate
* move accelerate logic to a single snippet under modeling utils and remove it from configuration utils
* format code to pass quality check
* fix imports with isort
* add accelerate to test extra deps
* only import accelerate if device_map is set to "auto"
* move accelerate availability check to diffusers import utils
* format code

Co-authored-by: Patrick von Platen <[email protected]>
1 parent a3a13fd commit b5c71f1
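
In user code, the new loading path is opt-in through the `device_map="auto"` keyword. Below is a minimal usage sketch mirroring the tests added in this commit (it assumes a CUDA device and accelerate>=0.12.0 installed; the checkpoint is the same dummy model those tests load):

import torch
from diffusers import UNet2DModel

# device_map="auto" makes from_pretrained build the model with empty (meta)
# weights and lets accelerate fill them from the checkpoint in place,
# avoiding a second full copy of the weights in CPU RAM.
model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", device_map="auto")
model.to("cuda")

noise = torch.randn(
    1, model.config.in_channels, model.config.sample_size, model.config.sample_size
).to("cuda")
timestep = torch.tensor([10], device="cuda")
sample = model(noise, timestep).sample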

File tree

5 files changed: 152 additions & 32 deletions

setup.py

Lines changed: 1 addition & 0 deletions

@@ -104,6 +104,7 @@
     "torch>=1.4",
     "torchvision",
     "transformers>=4.21.0",
+    "accelerate>=0.12.0"
 ]
 
 # this is a lookup table with items like:

src/diffusers/modeling_utils.py

Lines changed: 69 additions & 32 deletions

@@ -21,6 +21,7 @@
 import torch
 from torch import Tensor, device
 
+from diffusers.utils import is_accelerate_available
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
@@ -293,33 +294,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         from_auto_class = kwargs.pop("_from_auto", False)
         torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)
+        device_map = kwargs.pop("device_map", None)
 
         user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
 
         # Load config if we don't provide a configuration
         config_path = pretrained_model_name_or_path
-        model, unused_kwargs = cls.from_config(
-            config_path,
-            cache_dir=cache_dir,
-            return_unused_kwargs=True,
-            force_download=force_download,
-            resume_download=resume_download,
-            proxies=proxies,
-            local_files_only=local_files_only,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            subfolder=subfolder,
-            **kwargs,
-        )
 
-        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
-            raise ValueError(
-                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
-            )
-        elif torch_dtype is not None:
-            model = model.to(torch_dtype)
-
-        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # Load model
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
@@ -391,25 +372,81 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         )
 
         # restore default dtype
-        state_dict = load_state_dict(model_file)
-        model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
-            model,
-            state_dict,
-            model_file,
-            pretrained_model_name_or_path,
-            ignore_mismatched_sizes=ignore_mismatched_sizes,
-        )
 
-        # Set model in evaluation mode to deactivate DropOut modules by default
-        model.eval()
+        if device_map == "auto":
+            if is_accelerate_available():
+                import accelerate
+            else:
+                raise ImportError("Please install accelerate via `pip install accelerate`")
+
+            with accelerate.init_empty_weights():
+                model, unused_kwargs = cls.from_config(
+                    config_path,
+                    cache_dir=cache_dir,
+                    return_unused_kwargs=True,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    device_map=device_map,
+                    **kwargs,
+                )
+
+            accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
+
+            loading_info = {
+                "missing_keys": [],
+                "unexpected_keys": [],
+                "mismatched_keys": [],
+                "error_msgs": [],
+            }
+        else:
+            model, unused_kwargs = cls.from_config(
+                config_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                device_map=device_map,
+                **kwargs,
+            )
+
+            state_dict = load_state_dict(model_file)
+            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+                model,
+                state_dict,
+                model_file,
+                pretrained_model_name_or_path,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+            )
 
-        if output_loading_info:
             loading_info = {
                 "missing_keys": missing_keys,
                 "unexpected_keys": unexpected_keys,
                 "mismatched_keys": mismatched_keys,
                 "error_msgs": error_msgs,
             }
+
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            raise ValueError(
+                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+            )
+        elif torch_dtype is not None:
+            model = model.to(torch_dtype)
+
+        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+        # Set model in evaluation mode to deactivate DropOut modules by default
+        model.eval()
+        if output_loading_info:
             return model, loading_info
 
         return model
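
The core of the new branch is accelerate's big-model loading pattern: create the module on the meta device so no weight memory is allocated, then load the checkpoint directly onto the mapped devices. A standalone sketch of that same pattern, using a toy torch module rather than a diffusers model (the checkpoint path is hypothetical; assumes accelerate is installed):

import accelerate
import torch

# Step 1: build the module skeleton on the meta device; no real weight
# storage is allocated here.
with accelerate.init_empty_weights():
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.ReLU(),
        torch.nn.Linear(4096, 1024),
    )

# Step 2: materialize the weights straight from disk onto the devices that
# device_map="auto" selects, so peak host RAM stays well below two full
# copies of the model.
model = accelerate.load_checkpoint_and_dispatch(
    model, "/path/to/checkpoint.bin", device_map="auto"  # hypothetical path
)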

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -23,6 +23,7 @@
     USE_TF,
     USE_TORCH,
     DummyObject,
+    is_accelerate_available,
     is_flax_available,
     is_inflect_available,
     is_modelcards_available,

src/diffusers/utils/import_utils.py

Lines changed: 11 additions & 0 deletions

@@ -159,6 +159,13 @@
 except importlib_metadata.PackageNotFoundError:
     _scipy_available = False
 
+_accelerate_available = importlib.util.find_spec("accelerate") is not None
+try:
+    _accelerate_version = importlib_metadata.version("accelerate")
+    logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
+except importlib_metadata.PackageNotFoundError:
+    _accelerate_available = False
+
 
 def is_torch_available():
     return _torch_available
@@ -196,6 +203,10 @@ def is_scipy_available():
     return _scipy_available
 
 
+def is_accelerate_available():
+    return _accelerate_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
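
The availability check follows the pattern import_utils.py already uses for scipy, flax, and the other optional packages: probe once at import time, cache the result, expose a cheap predicate. Call sites then guard the optional import exactly as the modeling_utils.py hunk above does; condensed:

from diffusers.utils import is_accelerate_available

if is_accelerate_available():
    import accelerate  # safe: the import-time probe found the package
else:
    raise ImportError("Please install accelerate via `pip install accelerate`")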

tests/test_models_unet.py

Lines changed: 70 additions & 0 deletions

@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import math
+import tracemalloc
 import unittest
 
 import torch
@@ -133,6 +135,74 @@ def test_from_pretrained_hub(self):
 
         assert image is not None, "Make sure output is not None"
 
+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_from_pretrained_accelerate(self):
+        model, _ = UNet2DModel.from_pretrained(
+            "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
+        )
+        model.to(torch_device)
+        image = model(**self.dummy_input).sample
+
+        assert image is not None, "Make sure output is not None"
+
+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_from_pretrained_accelerate_wont_change_results(self):
+        model_accelerate, _ = UNet2DModel.from_pretrained(
+            "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
+        )
+        model_accelerate.to(torch_device)
+        model_accelerate.eval()
+
+        noise = torch.randn(
+            1,
+            model_accelerate.config.in_channels,
+            model_accelerate.config.sample_size,
+            model_accelerate.config.sample_size,
+            generator=torch.manual_seed(0),
+        )
+        noise = noise.to(torch_device)
+        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
+
+        arr_accelerate = model_accelerate(noise, time_step)["sample"]
+
+        # two models don't need to stay in the device at the same time
+        del model_accelerate
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
+        model_normal_load.to(torch_device)
+        model_normal_load.eval()
+        arr_normal_load = model_normal_load(noise, time_step)["sample"]
+
+        assert torch.allclose(arr_accelerate, arr_normal_load, rtol=1e-3)
+
+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_memory_footprint_gets_reduced(self):
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        tracemalloc.start()
+        model_accelerate, _ = UNet2DModel.from_pretrained(
+            "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
+        )
+        model_accelerate.to(torch_device)
+        model_accelerate.eval()
+        _, peak_accelerate = tracemalloc.get_traced_memory()
+
+        del model_accelerate
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
+        model_normal_load.to(torch_device)
+        model_normal_load.eval()
+        _, peak_normal = tracemalloc.get_traced_memory()
+
+        tracemalloc.stop()
+
+        assert peak_accelerate < peak_normal
+
     def test_output_pretrained(self):
         model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
         model.eval()
