@@ -25,6 +25,7 @@

 import diffusers
 import PIL
+from accelerate.utils.versions import is_torch_version
 from huggingface_hub import snapshot_download
 from packaging import version
 from PIL import Image
@@ -33,6 +34,7 @@
 from .configuration_utils import ConfigMixin
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .hub_utils import http_user_agent
+from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from .utils import (
     CONFIG_NAME,
@@ -328,6 +330,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P |
                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or
                 safety. Please refer to the mirror site for more information.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to
+                the same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
+                For more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights.
+                This also tries to not use more than 1x the model size in CPU memory (including peak memory) while
+                loading the model. This is only supported when torch version >= 1.9.0. If you are using an older
+                version of torch, setting this argument to `True` will raise an error.

             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load- and saveable variables (*i.e.*, the pipeline components) of the
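For context beyond the diff, here is a minimal usage sketch of the two new arguments. The model id is illustrative, and `device_map="auto"` assumes the `accelerate` package is installed alongside torch >= 1.9.0:

```python
import torch

from diffusers import DiffusionPipeline

# Let Accelerate compute a device placement for every pipeline submodule.
# This path requires torch >= 1.9.0 and `accelerate`; the model id below is
# illustrative.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
    device_map="auto",  # per the guards below, implies low_cpu_mem_usage=True
)

# Empty-weight ("meta" device) initialization without dispatching; passing
# the flag explicitly is redundant on torch >= 1.9.0, where it is the default.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    low_cpu_mem_usage=True,
)
```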
@@ -380,7 +395,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P |
         provider = kwargs.pop("provider", None)
         sess_options = kwargs.pop("sess_options", None)
         device_map = kwargs.pop("device_map", None)
-        fast_load = kwargs.pop("fast_load", True)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if device_map is not None and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `device_map=None`."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        if low_cpu_mem_usage is False and device_map is not None:
+            raise ValueError(
+                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
+                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+            )

         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
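The three guards encode a small compatibility matrix: dispatching via `device_map` and empty-weight initialization both need torch >= 1.9.0 (for the `meta` device), and dispatching additionally requires `low_cpu_mem_usage=True`. A sketch of the resulting behavior, under the assumption that `_LOW_CPU_MEM_USAGE_DEFAULT` is derived from the same version check (model id illustrative; note the `ValueError` fires before anything is downloaded):

```python
from accelerate.utils.versions import is_torch_version
from diffusers import DiffusionPipeline

# is_torch_version(op, ref) compares the installed torch release against
# `ref` using the operator string; it is the same helper the guards use,
# and presumably what _LOW_CPU_MEM_USAGE_DEFAULT is computed from.
print(is_torch_version(">=", "1.9.0"))  # True on e.g. torch 1.12

# Dispatching builds on empty-weight initialization, so this combination is
# rejected up front, before the checkpoint download even starts:
try:
    DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",  # illustrative model id
        device_map="auto",
        low_cpu_mem_usage=False,
    )
except ValueError as err:
    print(err)
```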
@@ -573,17 +606,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P |
                and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
            )

-            if is_diffusers_model:
-                loading_kwargs["fast_load"] = fast_load
-
             # When loading a transformers model, if the device_map is None, the weights will be initialized, as opposed to diffusers models.
-            # To make default loading faster we set the `low_cpu_mem_usage=fast_load` flag which is `True` by default.
+            # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag, which is `True` by default.
             # This makes sure that the weights won't be initialized, which significantly speeds up loading.
-            if is_transformers_model and device_map is None:
-                loading_kwargs["low_cpu_mem_usage"] = fast_load
-
             if is_diffusers_model or is_transformers_model:
                 loading_kwargs["device_map"] = device_map
+                loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
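Downstream of this hunk, both kwargs are forwarded unchanged to every component that is a diffusers model or a recent enough transformers model. Roughly, each component load then amounts to the following sketch (component class, model id, and subfolder are illustrative):

```python
from diffusers import UNet2DConditionModel

# What the loop above effectively calls for one pipeline component once
# `device_map` and `low_cpu_mem_usage` are merged into loading_kwargs:
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="unet",
    device_map=None,          # forwarded verbatim from the pipeline call
    low_cpu_mem_usage=True,   # forwarded default on torch >= 1.9.0
)
```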