
Conversation

mishig25 (Contributor) commented Sep 14, 2022

An implementation of the `init_weights` method is required for any class that inherits from `FlaxModelMixin`.

Here is an example:

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict

from diffusers import ConfigMixin, FlaxModelMixin


class UNet2D(nn.Module, FlaxModelMixin, ConfigMixin):
    def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
        # Build dummy input tensors from the config (NHWC layout).
        sample_shape = (1, self.config.sample_size, self.config.sample_size, self.config.in_channels)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.config.cross_attention_dim), dtype=jnp.float32)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]

Unlike `transformers.FlaxPreTrainedModel.init_weights`, the `diffusers.FlaxModelMixin.init_weights` signature does not take an `input_shape` parameter. Read more here & here on why this decision was made.
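For comparison, the two signatures side by side (the transformers one is paraphrased from memory and may differ slightly; the diffusers one matches the example above):

# transformers: the caller supplies the dummy-input shape explicitly
# def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: ...

# diffusers: the dummy input is built internally from self.config
# def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: ...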

Users will have two options to initialize weights:

class FlaxModel(nn.Module, FlaxModelMixin, ConfigMixin):
    ...

my_model = FlaxModel()

# Option 1 (FlaxModelMixin): the dummy input is built internally from the config.
params = my_model.init_weights(key)

# Option 2 (plain linen.Module): the caller supplies a dummy input.
x = jax.random.normal(key1, (...))  # dummy input
params = my_model.init(key, x)

# Option 1 is more convenient since the random input is handled automatically inside
# `init_weights`, but option 2 remains available for users who want to treat the model
# as a regular linen.Module.
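Internally, the loading path can also call `init_weights` under `jax.eval_shape` (the commit log later mentions `PRNGKey(0)` for `jax.eval_shape`), so parameter shapes can be inspected without allocating real arrays. A minimal sketch of that pattern with a hypothetical toy module (`TinyFlaxModel`, its shapes, and the loader comment are illustrative assumptions, not the actual diffusers code):

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical toy module whose init_weights builds its own dummy input,
# mirroring the FlaxModelMixin pattern above (no input_shape argument).
class TinyFlaxModel(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features)(x)

    def init_weights(self, rng: jax.random.PRNGKey):
        sample = jnp.zeros((1, 4), dtype=jnp.float32)  # dummy input, shape known internally
        return self.init(rng, sample)["params"]

model = TinyFlaxModel()

# Concrete initialization: real parameter arrays.
params = model.init_weights(jax.random.PRNGKey(0))

# Abstract initialization: only shapes/dtypes are computed, no device buffers are allocated.
# A loader could compare this shape tree against checkpoint weights.
param_shapes = jax.eval_shape(model.init_weights, jax.random.PRNGKey(0))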

HuggingFaceDocBuilderDev commented Sep 14, 2022

The documentation is not available anymore as the PR was closed or merged.

mishig25 marked this pull request as draft September 14, 2022 17:05
```"""
return self._cast_floating_to(params, jnp.float16, mask)

def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
Contributor

I like that we don't allow input_shape to be passed for now, since diffusers is much more restricted than Transformers here, i.e. we should for now always be able to infer the correct shape from the config. This looks good to me!

Contributor

+1

patrickvonplaten (Contributor) left a comment

Just some minor comments regarding naming, and I think we should remove the "allow mismatched keys" functionality for now. But apart from this, this is top!

pcuenca (Member) commented Sep 15, 2022

I agree with @patrickvonplaten's comments. This is very cool, I'll test it later!

mishig25 marked this pull request as ready for review September 15, 2022 08:22
mishig25 (Contributor, Author) commented

Should be ready for review; also updated the description with more details for clarification.

pcuenca added a commit that referenced this pull request Sep 15, 2022
pcuenca (Member) left a comment

Looks great!

mishig25 (Contributor, Author) commented

@patrickvonplaten @patil-suraj should I merge?

patil-suraj (Contributor) left a comment

Very cool! Just left a couple of nits.

```"""
return self._cast_floating_to(params, jnp.float16, mask)

def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
Contributor


+1

Comment on lines 428 to 434
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
Contributor


docstring should be updated for diffusers

mishig25 (Contributor, Author) commented Sep 15, 2022

wdyt of 869014c, since we don't have LLM-like heads in diffusion models?

patrickvonplaten (Contributor) commented

Merging as tests are taking too long at the moment

patrickvonplaten merged commit fb5468a into main Sep 15, 2022
patrickvonplaten deleted the flax_init_weights branch September 15, 2022 15:01
pcuenca added a commit that referenced this pull request Sep 15, 2022
* First UNet Flax modeling blocks.

Mimic the structure of the PyTorch files.
The model classes themselves need work, depending on what we do about
configuration and initialization.

* Remove FlaxUNet2DConfig class.

* ignore_for_config non-config args.

* Implement `FlaxModelMixin`

* Use new mixins for Flax UNet.

For some reason the configuration is not correctly applied; the
signature of the `__init__` method does not contain all the parameters
by the time it's inspected in `extract_init_dict`.

* Import `FlaxUNet2DConditionModel` if flax is available.

* Rm unused method `framework`

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Suraj Patil <[email protected]>

* Indicate types in flax.struct.dataclass as pointed out by @mishig25

Co-authored-by: Mishig Davaadorj <[email protected]>

* Fix typo in transformer block.

* make style

* some more changes

* make style

* Add comment

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Patrick von Platen <[email protected]>

* Rm unneeded comment

* Update docstrings

* correct ignore kwargs

* make style

* Update docstring examples

* Make style

* Style: remove empty line.

* Apply style (after upgrading black from pinned version)

* Remove some commented code and unused imports.

* Add init_weights (not yet in use until #513).

* Trickle down deterministic to blocks.

* Rename q, k, v according to the latest PyTorch version.

Note that weights were exported with the old names, so we need to be
careful.

* Flax UNet docstrings, default props as in PyTorch.

* Fix minor typos in PyTorch docstrings.

* Use FlaxUNet2DConditionOutput as output from UNet.

* make style

Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Add `init_weights` method to `FlaxMixin`

* Rn `random_state` -> `shape_state`

* `PRNGKey(0)` for `jax.eval_shape`

* No allow mismatched sizes

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Suraj Patil <[email protected]>

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Suraj Patil <[email protected]>

* docstring diffusers

Co-authored-by: Suraj Patil <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
