 import hashlib
 import cv2
 import skimage
+from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
+    EulerAncestralDiscreteScheduler

 from omegaconf import OmegaConf
 from ldm.invoke.generator.base import downsampling
@@ -386,7 +388,10 @@ def process_image(image,seed):
         width = width or self.width
         height = height or self.height

-        configure_model_padding(model, seamless, seamless_axes)
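+        # A diffusers DiffusionPipeline wraps the UNet, so the seamless padding
+        # configuration has to be applied to the wrapped model.unet module.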
+        if isinstance(model, DiffusionPipeline):
+            configure_model_padding(model.unet, seamless, seamless_axes)
+        else:
+            configure_model_padding(model, seamless, seamless_axes)

         assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
         assert threshold >= 0.0, '--threshold must be >=0.0'
@@ -930,9 +935,15 @@ def sample_to_image(self, samples):
     def sample_to_lowres_estimated_image(self, samples):
         return self._make_base().sample_to_lowres_estimated_image(samples)

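+    # Dispatch on model type: diffusers pipelines take a scheduler, while the
+    # legacy ldm checkpoints take one of the DDIM/PLMS/k-diffusion samplers.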
+    def _set_sampler(self):
+        if isinstance(self.model, DiffusionPipeline):
+            return self._set_scheduler()
+        else:
+            return self._set_sampler_legacy()
+
     # very repetitive code - can this be simplified? The KSampler names are
     # consistent, at least
-    def _set_sampler(self):
+    def _set_sampler_legacy(self):
         msg = f'>> Setting Sampler to {self.sampler_name}'
         if self.sampler_name == 'plms':
             self.sampler = PLMSSampler(self.model, device=self.device)
@@ -956,6 +967,44 @@ def _set_sampler(self):

         print(msg)

+    def _set_scheduler(self):
+        msg = f'>> Setting Sampler to {self.sampler_name}'
+        default = self.model.scheduler
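+        # Reuse the pipeline's default scheduler configuration so the replacement
+        # scheduler uses the same noise schedule parameters as the loaded model.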
+        # TODO: Test me! Not all schedulers take the same args.
+        scheduler_args = dict(
+            num_train_timesteps=default.num_train_timesteps,
+            beta_start=default.beta_start,
+            beta_end=default.beta_end,
+            beta_schedule=default.beta_schedule,
+        )
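+        # Some checkpoints ship an explicit beta table; carry it over when present.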
+        trained_betas = getattr(self.model.scheduler, 'trained_betas', None)
+        if trained_betas is not None:
+            scheduler_args.update(trained_betas=trained_betas)
+        if self.sampler_name == 'plms':
+            raise NotImplementedError("What's the diffusers implementation of PLMS?")
+        elif self.sampler_name == 'ddim':
+            self.sampler = DDIMScheduler(**scheduler_args)
+        elif self.sampler_name == 'k_dpm_2_a':
+            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
+        elif self.sampler_name == 'k_dpm_2':
+            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
+        elif self.sampler_name == 'k_euler_a':
+            self.sampler = EulerAncestralDiscreteScheduler(**scheduler_args)
+        elif self.sampler_name == 'k_euler':
+            self.sampler = EulerDiscreteScheduler(**scheduler_args)
+        elif self.sampler_name == 'k_heun':
+            raise NotImplementedError("no diffusers implementation of Heun's sampler")
+        elif self.sampler_name == 'k_lms':
+            self.sampler = LMSDiscreteScheduler(**scheduler_args)
+        else:
+            msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to {default}'
+            self.sampler = default
+
+        print(msg)
+
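+        # Legacy samplers expose uses_inpainting_model(); diffusers schedulers do
+        # not, so shim a stub onto the scheduler so downstream checks still work.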
+        if not hasattr(self.sampler, 'uses_inpainting_model'):
+            # FIXME: terrible kludge!
+            self.sampler.uses_inpainting_model = lambda: False
+
     def _load_img(self, img) -> Image:
         if isinstance(img, Image.Image):
             image = img