Skip to content

Commit 647a89d

Browse files
[Low cpu memory] Correct naming and improve default usage (huggingface#1122)
* correct naming
* finish
* Apply suggestions from code review
* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
1 parent ed16571 commit 647a89d

File tree

2 files changed

+61
-17
lines changed

2 files changed

+61
-17
lines changed

modeling_utils.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535
logger = logging.get_logger(__name__)
3636

3737

38+
if is_torch_version(">=", "1.9.0"):
39+
_LOW_CPU_MEM_USAGE_DEFAULT = True
40+
else:
41+
_LOW_CPU_MEM_USAGE_DEFAULT = False
42+
43+
3844
def get_parameter_device(parameter: torch.nn.Module):
3945
try:
4046
return next(parameter.parameters()).device
@@ -278,11 +284,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
278284
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
279285
more information about each option see [designing a device
280286
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
281-
fast_load (`bool`, *optional*, defaults to `True`):
287+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
282288
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
283289
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
284290
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
285-
this argument will be ignored and the model will be loaded normally.
291+
setting this argument to `True` will raise an error.
286292
287293
<Tip>
288294
@@ -311,16 +317,26 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
311317
torch_dtype = kwargs.pop("torch_dtype", None)
312318
subfolder = kwargs.pop("subfolder", None)
313319
device_map = kwargs.pop("device_map", None)
314-
fast_load = kwargs.pop("fast_load", True)
320+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
315321

316322
# Check if we can handle device_map and dispatching the weights
317323
if device_map is not None and not is_torch_version(">=", "1.9.0"):
318-
raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0")
324+
raise NotImplementedError(
325+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
326+
" `device_map=None`."
327+
)
319328

320-
# Fast init is only possible if torch version is >= 1.9.0
321-
_INIT_EMPTY_WEIGHTS = fast_load or device_map is not None
322-
if _INIT_EMPTY_WEIGHTS and not is_torch_version(">=", "1.9.0"):
323-
logger.warn("Loading with `fast_load` requires torch >= 1.9.0. Falling back to normal loading.")
329+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
330+
raise NotImplementedError(
331+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
332+
" `low_cpu_mem_usage=False`."
333+
)
334+
335+
if low_cpu_mem_usage is False and device_map is not None:
336+
raise ValueError(
337+
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
338+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
339+
)
324340

325341
user_agent = {
326342
"diffusers": __version__,
@@ -403,7 +419,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
403419

404420
# restore default dtype
405421

406-
if _INIT_EMPTY_WEIGHTS:
422+
if low_cpu_mem_usage:
407423
# Instantiate model with empty weights
408424
with accelerate.init_empty_weights():
409425
model, unused_kwargs = cls.from_config(

pipeline_utils.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import diffusers
2727
import PIL
28+
from accelerate.utils.versions import is_torch_version
2829
from huggingface_hub import snapshot_download
2930
from packaging import version
3031
from PIL import Image
@@ -33,6 +34,7 @@
3334
from .configuration_utils import ConfigMixin
3435
from .dynamic_modules_utils import get_class_from_dynamic_module
3536
from .hub_utils import http_user_agent
37+
from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
3638
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
3739
from .utils import (
3840
CONFIG_NAME,
@@ -328,6 +330,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
328330
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
329331
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
330332
Please refer to the mirror site for more information. Specify the folder name here.
333+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
334+
A map that specifies where each submodule should go. It doesn't need to be refined to each
335+
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
336+
same device.
337+
338+
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
339+
more information about each option see [designing a device
340+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
341+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
342+
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
343+
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
344+
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
345+
setting this argument to `True` will raise an error.
331346
332347
kwargs (remaining dictionary of keyword arguments, *optional*):
333348
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
@@ -380,7 +395,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
380395
provider = kwargs.pop("provider", None)
381396
sess_options = kwargs.pop("sess_options", None)
382397
device_map = kwargs.pop("device_map", None)
383-
fast_load = kwargs.pop("fast_load", True)
398+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
399+
400+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
401+
raise NotImplementedError(
402+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
403+
" `device_map=None`."
404+
)
405+
406+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
407+
raise NotImplementedError(
408+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
409+
" `low_cpu_mem_usage=False`."
410+
)
411+
412+
if low_cpu_mem_usage is False and device_map is not None:
413+
raise ValueError(
414+
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
415+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
416+
)
384417

385418
# 1. Download the checkpoints and configs
386419
# use snapshot download here to get it working from from_pretrained
@@ -573,17 +606,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
573606
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
574607
)
575608

576-
if is_diffusers_model:
577-
loading_kwargs["fast_load"] = fast_load
578-
579609
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
580-
# To make default loading faster we set the `low_cpu_mem_usage=fast_load` flag which is `True` by default.
610+
# To make default loading faster we pass `low_cpu_mem_usage=low_cpu_mem_usage`, which defaults to `True` when torch >= 1.9.0 (and `False` otherwise).
581611
# This makes sure that the weights won't be initialized which significantly speeds up loading.
582-
if is_transformers_model and device_map is None:
583-
loading_kwargs["low_cpu_mem_usage"] = fast_load
584-
585612
if is_diffusers_model or is_transformers_model:
586613
loading_kwargs["device_map"] = device_map
614+
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
587615

588616
# check if the module is in a subdirectory
589617
if os.path.isdir(os.path.join(cached_folder, name)):

0 commit comments

Comments
 (0)