@@ -354,7 +354,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
354354 # TODO(Patrick, Suraj) - delete later
355355 if class_name == "DummyChecker" :
356356 library_name = "stable_diffusion"
357- class_name = "StableDiffusionSafetyChecker "
357+ class_name = "FlaxStableDiffusionSafetyChecker "
358358
359359 is_pipeline_module = hasattr (pipelines , library_name )
360360 loaded_sub_model = None
@@ -421,16 +421,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
421421 loaded_sub_model = cached_folder
422422
423423 if issubclass (class_obj , FlaxModelMixin ):
424- # TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here
424+ loaded_sub_model , loaded_params = load_method (loadable_folder , from_pt = from_pt , dtype = dtype )
425+ params [name ] = loaded_params
426+ elif is_transformers_available () and issubclass (class_obj , FlaxPreTrainedModel ):
427+ # make sure we don't initialize the weights to save time
425428 if name == "safety_checker" :
426429 loaded_sub_model = DummyChecker ()
427430 loaded_params = DummyChecker ()
428- else :
429- loaded_sub_model , loaded_params = load_method (loadable_folder , from_pt = from_pt , dtype = dtype )
430- params [name ] = loaded_params
431- elif is_transformers_available () and issubclass (class_obj , FlaxPreTrainedModel ):
432- # make sure we don't initialize the weights to save time
433- if from_pt :
431+ elif from_pt :
434432 # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
435433 loaded_sub_model = load_method (loadable_folder , from_pt = from_pt )
436434 loaded_params = loaded_sub_model .params
0 commit comments