Skip to content

Commit d934d3d

Browse files
Mishig Davaadorjpatrickvonplatenpcuenca
authored
FlaxDiffusionPipeline & FlaxStableDiffusionPipeline (#559)
* WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline * todo comment * Fix imports * Fix imports * add dummies * Fix empty init * make pipeline work * up * Use Flax schedulers (typing, docstring) * Wrap model imports inside availability checks. * more updates * make sure flax is not broken * make style * more fixes * up Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
1 parent c6629e6 commit d934d3d

File tree

12 files changed

+798
-56
lines changed

12 files changed

+798
-56
lines changed

src/diffusers/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from .modeling_flax_utils import FlaxModelMixin
6767
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
6868
from .models.vae_flax import FlaxAutoencoderKL
69+
from .pipeline_flax_utils import FlaxDiffusionPipeline
6970
from .schedulers import (
7071
FlaxDDIMScheduler,
7172
FlaxDDPMScheduler,
@@ -76,3 +77,8 @@
7677
)
7778
else:
7879
from .utils.dummy_flax_objects import * # noqa F403
80+
81+
if is_flax_available() and is_transformers_available():
82+
from .pipelines import FlaxStableDiffusionPipeline
83+
else:
84+
from .utils.dummy_flax_and_transformers_objects import * # noqa F403

src/diffusers/modeling_flax_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,16 +306,16 @@ def from_pretrained(
306306

307307
# Load model
308308
if os.path.isdir(pretrained_model_name_or_path):
309-
if os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
310-
# Load from a Flax checkpoint
311-
model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
312-
# At this stage we don't have a weight file so we will raise an error.
313-
elif from_pt:
309+
if from_pt:
314310
if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
315311
raise EnvironmentError(
316312
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
317313
)
318314
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
315+
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
316+
# Load from a Flax checkpoint
317+
model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
318+
# At this stage we don't have a weight file so we will raise an error.
319319
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
320320
raise EnvironmentError(
321321
f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"

src/diffusers/models/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .unet_2d import UNet2DModel
16-
from .unet_2d_condition import UNet2DConditionModel
17-
from .vae import AutoencoderKL, VQModel
15+
from ..utils import is_flax_available, is_torch_available
16+
17+
18+
if is_torch_available():
19+
from .unet_2d import UNet2DModel
20+
from .unet_2d_condition import UNet2DConditionModel
21+
from .vae import AutoencoderKL, VQModel
22+
23+
if is_flax_available():
24+
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
25+
from .vae_flax import FlaxAutoencoderKL

src/diffusers/models/attention_flax.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ def setup(self):
144144

145145
def __call__(self, hidden_states, context, deterministic=True):
146146
batch, height, width, channels = hidden_states.shape
147-
# import ipdb; ipdb.set_trace()
148147
residual = hidden_states
149148
hidden_states = self.norm(hidden_states)
150149
hidden_states = self.proj_in(hidden_states)

0 commit comments

Comments
 (0)