|
21 | 21 | import torch |
22 | 22 | from torch import Tensor, device |
23 | 23 |
|
| 24 | +from diffusers.utils import is_accelerate_available |
24 | 25 | from huggingface_hub import hf_hub_download |
25 | 26 | from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError |
26 | 27 | from requests import HTTPError |
@@ -293,33 +294,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P |
293 | 294 | from_auto_class = kwargs.pop("_from_auto", False) |
294 | 295 | torch_dtype = kwargs.pop("torch_dtype", None) |
295 | 296 | subfolder = kwargs.pop("subfolder", None) |
| 297 | + device_map = kwargs.pop("device_map", None) |
296 | 298 |
|
297 | 299 | user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} |
298 | 300 |
|
299 | 301 | # Load config if we don't provide a configuration |
300 | 302 | config_path = pretrained_model_name_or_path |
301 | | - model, unused_kwargs = cls.from_config( |
302 | | - config_path, |
303 | | - cache_dir=cache_dir, |
304 | | - return_unused_kwargs=True, |
305 | | - force_download=force_download, |
306 | | - resume_download=resume_download, |
307 | | - proxies=proxies, |
308 | | - local_files_only=local_files_only, |
309 | | - use_auth_token=use_auth_token, |
310 | | - revision=revision, |
311 | | - subfolder=subfolder, |
312 | | - **kwargs, |
313 | | - ) |
314 | 303 |
|
315 | | - if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): |
316 | | - raise ValueError( |
317 | | - f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." |
318 | | - ) |
319 | | - elif torch_dtype is not None: |
320 | | - model = model.to(torch_dtype) |
321 | | - |
322 | | - model.register_to_config(_name_or_path=pretrained_model_name_or_path) |
323 | 304 | # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the |
324 | 305 | # Load model |
325 | 306 | pretrained_model_name_or_path = str(pretrained_model_name_or_path) |
@@ -391,25 +372,81 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P |
391 | 372 | ) |
392 | 373 |
|
393 | 374 | # restore default dtype |
394 | | - state_dict = load_state_dict(model_file) |
395 | | - model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( |
396 | | - model, |
397 | | - state_dict, |
398 | | - model_file, |
399 | | - pretrained_model_name_or_path, |
400 | | - ignore_mismatched_sizes=ignore_mismatched_sizes, |
401 | | - ) |
402 | 375 |
|
403 | | - # Set model in evaluation mode to deactivate DropOut modules by default |
404 | | - model.eval() |
| 376 | + if device_map == "auto": |
| 377 | + if is_accelerate_available(): |
| 378 | + import accelerate |
| 379 | + else: |
| 380 | + raise ImportError("Please install accelerate via `pip install accelerate`") |
| 381 | + |
| 382 | + with accelerate.init_empty_weights(): |
| 383 | + model, unused_kwargs = cls.from_config( |
| 384 | + config_path, |
| 385 | + cache_dir=cache_dir, |
| 386 | + return_unused_kwargs=True, |
| 387 | + force_download=force_download, |
| 388 | + resume_download=resume_download, |
| 389 | + proxies=proxies, |
| 390 | + local_files_only=local_files_only, |
| 391 | + use_auth_token=use_auth_token, |
| 392 | + revision=revision, |
| 393 | + subfolder=subfolder, |
| 394 | + device_map=device_map, |
| 395 | + **kwargs, |
| 396 | + ) |
| 397 | + |
| 398 | + accelerate.load_checkpoint_and_dispatch(model, model_file, device_map) |
| 399 | + |
| 400 | + loading_info = { |
| 401 | + "missing_keys": [], |
| 402 | + "unexpected_keys": [], |
| 403 | + "mismatched_keys": [], |
| 404 | + "error_msgs": [], |
| 405 | + } |
| 406 | + else: |
| 407 | + model, unused_kwargs = cls.from_config( |
| 408 | + config_path, |
| 409 | + cache_dir=cache_dir, |
| 410 | + return_unused_kwargs=True, |
| 411 | + force_download=force_download, |
| 412 | + resume_download=resume_download, |
| 413 | + proxies=proxies, |
| 414 | + local_files_only=local_files_only, |
| 415 | + use_auth_token=use_auth_token, |
| 416 | + revision=revision, |
| 417 | + subfolder=subfolder, |
| 418 | + device_map=device_map, |
| 419 | + **kwargs, |
| 420 | + ) |
| 421 | + |
| 422 | + state_dict = load_state_dict(model_file) |
| 423 | + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( |
| 424 | + model, |
| 425 | + state_dict, |
| 426 | + model_file, |
| 427 | + pretrained_model_name_or_path, |
| 428 | + ignore_mismatched_sizes=ignore_mismatched_sizes, |
| 429 | + ) |
405 | 430 |
|
406 | | - if output_loading_info: |
407 | 431 | loading_info = { |
408 | 432 | "missing_keys": missing_keys, |
409 | 433 | "unexpected_keys": unexpected_keys, |
410 | 434 | "mismatched_keys": mismatched_keys, |
411 | 435 | "error_msgs": error_msgs, |
412 | 436 | } |
| 437 | + |
| 438 | + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): |
| 439 | + raise ValueError( |
| 440 | + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." |
| 441 | + ) |
| 442 | + elif torch_dtype is not None: |
| 443 | + model = model.to(torch_dtype) |
| 444 | + |
| 445 | + model.register_to_config(_name_or_path=pretrained_model_name_or_path) |
| 446 | + |
| 447 | + # Set model in evaluation mode to deactivate DropOut modules by default |
| 448 | + model.eval() |
| 449 | + if output_loading_info: |
413 | 450 | return model, loading_info |
414 | 451 |
|
415 | 452 | return model |
|
0 commit comments