-
Notifications
You must be signed in to change notification settings - Fork 295
Add Mix transformer #1780
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
divyashreepathihalli
merged 6 commits into
keras-team:keras-hub
from
sachinprasadhs:mix_transformer
Aug 20, 2024
Merged
Add Mix transformer #1780
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
4b5e4ab
Add MixTransformer
sachinprasadhs d7b993a
fix testcase
sachinprasadhs df9f65e
test changes and comments
sachinprasadhs c228eaa
lint fix
sachinprasadhs 3888b54
update config list
sachinprasadhs 85bda08
modify testcase for 2 layers
sachinprasadhs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2024 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
181 changes: 181 additions & 0 deletions
181
keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
# Copyright 2024 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import keras | ||
import numpy as np | ||
from keras import ops | ||
|
||
from keras_nlp.src.api_export import keras_nlp_export | ||
from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone | ||
from keras_nlp.src.models.mix_transformer.mix_transformer_layers import ( | ||
HierarchicalTransformerEncoder, | ||
) | ||
from keras_nlp.src.models.mix_transformer.mix_transformer_layers import ( | ||
OverlappingPatchingAndEmbedding, | ||
) | ||
|
||
|
||
@keras_nlp_export("keras_nlp.models.MiTBackbone") | ||
class MiTBackbone(FeaturePyramidBackbone): | ||
def __init__( | ||
self, | ||
depths, | ||
num_layers, | ||
blockwise_num_heads, | ||
blockwise_sr_ratios, | ||
end_value, | ||
patch_sizes, | ||
strides, | ||
include_rescaling=True, | ||
image_shape=(224, 224, 3), | ||
hidden_dims=None, | ||
**kwargs, | ||
): | ||
"""A Backbone implementing the MixTransformer. | ||
|
||
This architecture to be used as a backbone for the SegFormer | ||
architecture [SegFormer: Simple and Efficient Design for Semantic | ||
Segmentation with Transformers](https://arxiv.org/abs/2105.15203) | ||
[Based on the TensorFlow implementation from DeepVision]( | ||
https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer) | ||
|
||
Args: | ||
depths: The number of transformer encoders to be used per layer in the | ||
network. | ||
num_layers: int. The number of Transformer layers. | ||
blockwise_num_heads: list of integers, the number of heads to use | ||
in the attention computation for each layer. | ||
blockwise_sr_ratios: list of integers, the sequence reduction | ||
ratio to perform for each layer on the sequence before key and | ||
value projections. If set to > 1, a `Conv2D` layer is used to | ||
reduce the length of the sequence. | ||
end_value: The end value of the sequence. | ||
include_rescaling: bool, whether to rescale the inputs. If set | ||
to `True`, inputs will be passed through a `Rescaling(1/255.0)` | ||
layer. Defaults to `True`. | ||
image_shape: optional shape tuple, defaults to (224, 224, 3). | ||
hidden_dims: the embedding dims per hierarchical layer, used as | ||
the levels of the feature pyramid. | ||
patch_sizes: list of integers, the patch_size to apply for each layer. | ||
strides: list of integers, stride to apply for each layer. | ||
|
||
Examples: | ||
|
||
Using the class with a `backbone`: | ||
|
||
```python | ||
images = np.ones(shape=(1, 96, 96, 3)) | ||
labels = np.zeros(shape=(1, 96, 96, 1)) | ||
backbone = keras_nlp.models.MiTBackbone.from_preset("mit_b0_imagenet") | ||
|
||
# Evaluate model | ||
model(images) | ||
|
||
# Train model | ||
model.compile( | ||
optimizer="adam", | ||
loss=keras.losses.BinaryCrossentropy(from_logits=False), | ||
metrics=["accuracy"], | ||
) | ||
model.fit(images, labels, epochs=3) | ||
``` | ||
""" | ||
dpr = [x for x in np.linspace(0.0, end_value, sum(depths))] | ||
|
||
# === Layers === | ||
cur = 0 | ||
patch_embedding_layers = [] | ||
transformer_blocks = [] | ||
layer_norms = [] | ||
|
||
for i in range(num_layers): | ||
patch_embed_layer = OverlappingPatchingAndEmbedding( | ||
project_dim=hidden_dims[i], | ||
patch_size=patch_sizes[i], | ||
stride=strides[i], | ||
name=f"patch_and_embed_{i}", | ||
) | ||
patch_embedding_layers.append(patch_embed_layer) | ||
|
||
transformer_block = [ | ||
HierarchicalTransformerEncoder( | ||
project_dim=hidden_dims[i], | ||
num_heads=blockwise_num_heads[i], | ||
sr_ratio=blockwise_sr_ratios[i], | ||
drop_prob=dpr[cur + k], | ||
name=f"hierarchical_encoder_{i}_{k}", | ||
) | ||
for k in range(depths[i]) | ||
] | ||
transformer_blocks.append(transformer_block) | ||
cur += depths[i] | ||
layer_norms.append(keras.layers.LayerNormalization()) | ||
|
||
# === Functional Model === | ||
image_input = keras.layers.Input(shape=image_shape) | ||
x = image_input | ||
|
||
if include_rescaling: | ||
x = keras.layers.Rescaling(scale=1 / 255)(x) | ||
|
||
pyramid_outputs = {} | ||
for i in range(num_layers): | ||
# Compute new height/width after the `proj` | ||
# call in `OverlappingPatchingAndEmbedding` | ||
stride = strides[i] | ||
new_height, new_width = ( | ||
int(ops.shape(x)[1] / stride), | ||
int(ops.shape(x)[2] / stride), | ||
) | ||
|
||
x = patch_embedding_layers[i](x) | ||
for blk in transformer_blocks[i]: | ||
x = blk(x) | ||
x = layer_norms[i](x) | ||
x = keras.layers.Reshape( | ||
(new_height, new_width, -1), name=f"output_level_{i}" | ||
)(x) | ||
pyramid_outputs[f"P{i + 1}"] = x | ||
|
||
super().__init__(inputs=image_input, outputs=x, **kwargs) | ||
|
||
# === Config === | ||
self.depths = depths | ||
self.include_rescaling = include_rescaling | ||
self.image_shape = image_shape | ||
self.hidden_dims = hidden_dims | ||
self.pyramid_outputs = pyramid_outputs | ||
self.num_layers = num_layers | ||
self.blockwise_num_heads = blockwise_num_heads | ||
self.blockwise_sr_ratios = blockwise_sr_ratios | ||
self.end_value = end_value | ||
self.patch_sizes = patch_sizes | ||
self.strides = strides | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"depths": self.depths, | ||
"include_rescaling": self.include_rescaling, | ||
"hidden_dims": self.hidden_dims, | ||
"image_shape": self.image_shape, | ||
"num_layers": self.num_layers, | ||
"blockwise_num_heads": self.blockwise_num_heads, | ||
"blockwise_sr_ratios": self.blockwise_sr_ratios, | ||
"end_value": self.end_value, | ||
"patch_sizes": self.patch_sizes, | ||
"strides": self.strides, | ||
} | ||
) | ||
return config |
75 changes: 75 additions & 0 deletions
75
keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright 2024 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
import pytest | ||
from keras import models | ||
|
||
from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( | ||
MiTBackbone, | ||
) | ||
from keras_nlp.src.tests.test_case import TestCase | ||
|
||
|
||
class MiTBackboneTest(TestCase): | ||
def setUp(self): | ||
self.init_kwargs = { | ||
sachinprasadhs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"depths": [2, 2], | ||
"include_rescaling": True, | ||
"image_shape": (16, 16, 3), | ||
"hidden_dims": [4, 8], | ||
"num_layers": 2, | ||
"blockwise_num_heads": [1, 2], | ||
"blockwise_sr_ratios": [8, 4], | ||
"end_value": 0.1, | ||
"patch_sizes": [7, 3], | ||
"strides": [4, 2], | ||
} | ||
self.input_size = 16 | ||
self.input_data = np.ones( | ||
(2, self.input_size, self.input_size, 3), dtype="float32" | ||
) | ||
|
||
def test_backbone_basics(self): | ||
self.run_backbone_test( | ||
cls=MiTBackbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
expected_output_shape=(2, 2, 2, 8), | ||
run_quantization_check=False, | ||
run_mixed_precision_check=False, | ||
) | ||
|
||
def test_pyramid_output_format(self): | ||
init_kwargs = self.init_kwargs | ||
backbone = MiTBackbone(**init_kwargs) | ||
model = models.Model(backbone.inputs, backbone.pyramid_outputs) | ||
output_data = model(self.input_data) | ||
|
||
self.assertIsInstance(output_data, dict) | ||
self.assertEqual( | ||
list(output_data.keys()), list(backbone.pyramid_outputs.keys()) | ||
) | ||
self.assertEqual(list(output_data.keys()), ["P1", "P2"]) | ||
for k, v in output_data.items(): | ||
size = self.input_size // (2 ** (int(k[1:]) + 1)) | ||
self.assertEqual(tuple(v.shape[:3]), (2, size, size)) | ||
|
||
@pytest.mark.large | ||
def test_saved_model(self): | ||
self.run_model_saving_test( | ||
cls=MiTBackbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.