Add init_weights method to FlaxMixin #513
Conversation
The documentation is not available anymore as the PR was closed or merged.
The diff hunk under discussion:

````python
        ```"""
        return self._cast_floating_to(params, jnp.float16, mask)

    def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
````
I like that we don't allow the input_shape to be passed for now since it's much more restricted than Transformers, i.e. we should for now always be able to infer the correct shape from the config. This looks good to me!
+1
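For context, here is a minimal sketch of the config-driven approach described above. This is an editorial illustration, not code from this PR: the config attribute names (`in_channels`, `sample_size`, `cross_attention_dim`) are assumptions modeled on a UNet-like diffusion model.

```python
import jax
import jax.numpy as jnp

# Method-body sketch; this would live on a model class that mixes in
# FlaxModelMixin. No input_shape argument is accepted: dummy inputs are
# built entirely from the model's own config.
def init_weights(self, rng: jax.random.PRNGKey):
    sample = jnp.zeros(
        (1, self.config.in_channels, self.config.sample_size, self.config.sample_size),
        dtype=jnp.float32,
    )
    timesteps = jnp.ones((1,), dtype=jnp.int32)
    encoder_hidden_states = jnp.zeros(
        (1, 1, self.config.cross_attention_dim), dtype=jnp.float32
    )

    # Split the RNG for parameter init and dropout, as is conventional in Flax.
    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
```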
patrickvonplaten left a comment
Just some minor comments regarding naming, and I think we should remove the "allow mismatched keys" functionality for now. But apart from that, this is top!
I agree with @patrickvonplaten's comments. This is very cool, I'll test it later!
Should be ready for review; also updated the description with more details for clarification.
pcuenca left a comment
Looks great!
@patrickvonplaten @patil-suraj should I merge?
patil-suraj left a comment
Very cool! Just left a couple of nits.
+1
src/diffusers/modeling_flax_utils.py (outdated):

```python
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
```
docstring should be updated for diffusers
wdyt of 869014c, since we don't have LLM-like heads in diffusion models?
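For reference, the adapted warning might read along these lines. This is an editorial sketch, not the exact text of 869014c; it simply drops the BERT-specific examples that don't apply to diffusion models.

```python
logger.warning(
    f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
    f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
    f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
    " or with another architecture.\n- This IS NOT expected if you are initializing"
    f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
    " identical."
)
```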
Merging as tests are taking too long at the moment.
* First UNet Flax modeling blocks. Mimic the structure of the PyTorch files. The model classes themselves need work, depending on what we do about configuration and initialization.
* Remove FlaxUNet2DConfig class.
* ignore_for_config non-config args.
* Implement `FlaxModelMixin`
* Use new mixins for Flax UNet. For some reason the configuration is not correctly applied; the signature of the `__init__` method does not contain all the parameters by the time it's inspected in `extract_init_dict`.
* Import `FlaxUNet2DConditionModel` if flax is available.
* Rm unused method `framework`
* Update src/diffusers/modeling_flax_utils.py (Co-authored-by: Suraj Patil <[email protected]>)
* Indicate types in flax.struct.dataclass as pointed out by @mishig25 (Co-authored-by: Mishig Davaadorj <[email protected]>)
* Fix typo in transformer block.
* make style
* some more changes
* make style
* Add comment
* Update src/diffusers/modeling_flax_utils.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Rm unneeded comment
* Update docstrings
* correct ignore kwargs
* make style
* Update docstring examples
* Make style
* Style: remove empty line.
* Apply style (after upgrading black from pinned version)
* Remove some commented code and unused imports.
* Add init_weights (not yet in use until #513).
* Trickle down deterministic to blocks.
* Rename q, k, v according to the latest PyTorch version. Note that weights were exported with the old names, so we need to be careful.
* Flax UNet docstrings, default props as in PyTorch.
* Fix minor typos in PyTorch docstrings.
* Use FlaxUNet2DConditionOutput as output from UNet.
* make style

Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
* Add `init_weights` method to `FlaxMixin`
* Rn `random_state` -> `shape_state`
* `PRNGKey(0)` for `jax.eval_shape`
* No allow mismatched sizes
* Update src/diffusers/modeling_flax_utils.py (Co-authored-by: Suraj Patil <[email protected]>)
* Update src/diffusers/modeling_flax_utils.py (Co-authored-by: Suraj Patil <[email protected]>)
* docstring diffusers

Co-authored-by: Suraj Patil <[email protected]>
Implementation of the `init_weights` method is required for any class that inherits from `FlaxModelMixin`. Here is an example:
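(The example did not survive extraction; below is a minimal sketch consistent with the `init_weights` signature shown in the diff. The toy model, its fields, and the use of `FrozenDict` as the return type are assumptions.)

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict

from diffusers.modeling_flax_utils import FlaxModelMixin


class FlaxToyModel(nn.Module, FlaxModelMixin):
    # Hypothetical model: every shape needed for initialization can be
    # derived from the module's own fields / config.
    hidden_size: int = 32

    @nn.compact
    def __call__(self, sample):
        return nn.Dense(self.hidden_size)(sample)

    def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
        # No input_shape argument: the dummy input is built internally.
        sample = jnp.zeros((1, self.hidden_size), dtype=jnp.float32)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}
        return self.init(rngs, sample)["params"]
```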
Unlike `transformers.FlaxPreTrainedModel.init_weights`, the `diffusers.FlaxModelMixin.init_weights` signature does not have an `input_shape` parameter. Read more here & here on why that decision was made.

Users will have two options to init weights:
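(The list itself was truncated in the original. As I read the PR, the two paths are roughly the following; the model class and call patterns here are assumptions, not text from the PR.)

```python
import jax

from diffusers import FlaxUNet2DConditionModel  # assumed import path

# Option 1 (assumed): fresh, randomly initialized weights; all shapes
# come from the model's (default) config.
model = FlaxUNet2DConditionModel()
params = model.init_weights(rng=jax.random.PRNGKey(0))

# Option 2 (assumed): weights loaded from a checkpoint; from_pretrained
# returns the model together with its params, using init_weights (under
# jax.eval_shape with PRNGKey(0)) to determine the expected shapes.
model, params = FlaxUNet2DConditionModel.from_pretrained("path/to/checkpoint")
```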