@@ -377,7 +377,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
377377 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
378378 model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
379379 setting this argument to `True` will raise an error.
380-
380+ return_cached_folder (`bool`, *optional*, defaults to `False`):
381+ If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline.
381382 kwargs (remaining dictionary of keyword arguments, *optional*):
382383 Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
383384 specific pipeline class. The overwritten components are then directly passed to the pipelines
@@ -430,33 +431,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
430431 sess_options = kwargs .pop ("sess_options" , None )
431432 device_map = kwargs .pop ("device_map" , None )
432433 low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , _LOW_CPU_MEM_USAGE_DEFAULT )
433-
434- if low_cpu_mem_usage and not is_accelerate_available ():
435- low_cpu_mem_usage = False
436- logger .warning (
437- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
438- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
439- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n ```\n pip"
440- " install accelerate\n ```\n ."
441- )
442-
443- if device_map is not None and not is_torch_version (">=" , "1.9.0" ):
444- raise NotImplementedError (
445- "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
446- " `device_map=None`."
447- )
448-
449- if low_cpu_mem_usage is True and not is_torch_version (">=" , "1.9.0" ):
450- raise NotImplementedError (
451- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
452- " `low_cpu_mem_usage=False`."
453- )
454-
455- if low_cpu_mem_usage is False and device_map is not None :
456- raise ValueError (
457- f"You cannot set `low_cpu_mem_usage` to False while using device_map={ device_map } for loading and"
458- " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
459- )
434+ return_cached_folder = kwargs .pop ("return_cached_folder" , False )
460435
461436 # 1. Download the checkpoints and configs
462437 # use snapshot download here to get it working from from_pretrained
@@ -585,6 +560,33 @@ def load_module(name, value):
585560 f"Keyword arguments { unused_kwargs } are not expected by { pipeline_class .__name__ } and will be ignored."
586561 )
587562
563+ if low_cpu_mem_usage and not is_accelerate_available ():
564+ low_cpu_mem_usage = False
565+ logger .warning (
566+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
567+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
568+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n ```\n pip"
569+ " install accelerate\n ```\n ."
570+ )
571+
572+ if device_map is not None and not is_torch_version (">=" , "1.9.0" ):
573+ raise NotImplementedError (
574+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
575+ " `device_map=None`."
576+ )
577+
578+ if low_cpu_mem_usage is True and not is_torch_version (">=" , "1.9.0" ):
579+ raise NotImplementedError (
580+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
581+ " `low_cpu_mem_usage=False`."
582+ )
583+
584+ if low_cpu_mem_usage is False and device_map is not None :
585+ raise ValueError (
586+ f"You cannot set `low_cpu_mem_usage` to False while using device_map={ device_map } for loading and"
587+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
588+ )
589+
588590 # import it here to avoid circular import
589591 from diffusers import pipelines
590592
@@ -704,6 +706,9 @@ def load_module(name, value):
704706
705707 # 5. Instantiate the pipeline
706708 model = pipeline_class (** init_kwargs )
709+
710+ if return_cached_folder :
711+ return model , cached_folder
707712 return model
708713
709714 @staticmethod
0 commit comments