You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
* first commit:
- add `from_pt` argument in `from_pretrained` function
- add `modeling_flax_pytorch_utils.py` file
* small nit
- fix a small nit - to not enter in the second if condition
* major changes
- modify FlaxUnet modules
- first conversion script
- more keys to be matched
* keys match
- now all keys match
- change module names for correct matching
- upsample module name changed
* working v1
- test pass with atol and rtol= `4e-02`
* replace unsued arg
* make quality
* add small docstring
* add more comments
- add TODO for embedding layers
* small change
- use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array
* add more conditions on conversion
- add better test to check for keys conversion
* make shapes consistent
- output `img_w x img_h x n_channels` from the VAE
* Revert "make shapes consistent"
This reverts commit 4cad1ae.
* fix unet shape
- channels first!
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
41
+
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
0 commit comments