Skip to content

Commit 77ea513

Browse files
amyerobertssayakpaulRocketknight1NielsRogge
authored
Add TF ResNet model (#17427)
* Rought TF conversion outline * Tidy up * Fix padding differences between layers * Add back embedder - whoops * Match test file to main * Match upstream test file * Correctly pass and assign image_size parameter Co-authored-by: Sayak Paul <[email protected]> * Add in MainLayer * Correctly name layer * Tidy up AdaptivePooler * Small tidy-up More accurate type hints and remove whitespaces * Change AdaptiveAvgPool Use the AdaptiveAvgPool implementation by @Rocketknight1, which correctly pools if the output shape does not evenly divide by input shape c.f. https://github.com/huggingface/transformers/pull/17554/files/9e26607e22aa8d069c86b50196656012ff0ce62a#r900109509 Co-authored-by: From: matt <[email protected]> Co-authored-by: Sayak Paul <[email protected]> * Use updated AdaptiveAvgPool Co-authored-by: matt <[email protected]> * Make AdaptiveAvgPool compatible with CPU * Remove image_size from configuration * Fixup * Tensorflow -> TensorFlow * Fix pt references in tests * Apply suggestions from code review - grammar and wording Co-authored-by: NielsRogge <[email protected]> Co-authored-by: NielsRogge <[email protected]> * Add TFResNet to doc tests * PR comments - GlobalAveragePooling and clearer comments * Remove unused import * Add in keepdims argument * Add num_channels check * grammar fix: by -> of Co-authored-by: matt <[email protected]> Co-authored-by: Matt <[email protected]> * Remove transposes - keep NHWC throughout forward pass * Fixup look sharp * Add missing layer names * Final tidy up - remove from_pt now weights on hub Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: matt <[email protected]> Co-authored-by: NielsRogge <[email protected]> Co-authored-by: Matt <[email protected]>
1 parent 7b18702 commit 77ea513

File tree

10 files changed

+818
-8
lines changed

10 files changed

+818
-8
lines changed

docs/source/en/index.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ Flax), PyTorch, and/or TensorFlow.
273273
| Reformer | | | | | |
274274
| RegNet | | | | | |
275275
| RemBERT | | | | | |
276-
| ResNet | | | | | |
276+
| ResNet | | | | | |
277277
| RetriBERT | | | | | |
278278
| RoBERTa | | | | | |
279279
| RoFormer | | | | | |

docs/source/en/model_doc/resnet.mdx

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ The figure below illustrates the architecture of ResNet. Taken from the [origina
3131

3232
<img width="600" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/resnet_architecture.png"/>
3333

34-
This model was contributed by [Francesco](https://huggingface.co/Francesco). The original code can be found [here](https://github.com/KaimingHe/deep-residual-networks).
34+
This model was contributed by [Francesco](https://huggingface.co/Francesco). The TensorFlow version of this model was added by [amyeroberts](https://huggingface.co/amyeroberts). The original code can be found [here](https://github.com/KaimingHe/deep-residual-networks).
3535

3636
## ResNetConfig
3737

@@ -47,4 +47,16 @@ This model was contributed by [Francesco](https://huggingface.co/Francesco). The
4747
## ResNetForImageClassification
4848

4949
[[autodoc]] ResNetForImageClassification
50-
- forward
50+
- forward
51+
52+
53+
## TFResNetModel
54+
55+
[[autodoc]] TFResNetModel
56+
- call
57+
58+
59+
## TFResNetForImageClassification
60+
61+
[[autodoc]] TFResNetForImageClassification
62+
- call

src/transformers/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,6 +2380,14 @@
23802380
"TFRemBertPreTrainedModel",
23812381
]
23822382
)
2383+
_import_structure["models.resnet"].extend(
2384+
[
2385+
"TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST",
2386+
"TFResNetForImageClassification",
2387+
"TFResNetModel",
2388+
"TFResNetPreTrainedModel",
2389+
]
2390+
)
23832391
_import_structure["models.roberta"].extend(
23842392
[
23852393
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4721,6 +4729,12 @@
47214729
TFRemBertModel,
47224730
TFRemBertPreTrainedModel,
47234731
)
4732+
from .models.resnet import (
4733+
TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
4734+
TFResNetForImageClassification,
4735+
TFResNetModel,
4736+
TFResNetPreTrainedModel,
4737+
)
47244738
from .models.roberta import (
47254739
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
47264740
TFRobertaForCausalLM,

src/transformers/modeling_tf_outputs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class TFBaseModelOutputWithNoAttention(ModelOutput):
6262
"""
6363

6464
last_hidden_state: tf.Tensor = None
65-
hidden_states: Optional[Tuple[tf.Tensor]] = None
65+
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
6666

6767

6868
@dataclass
@@ -118,7 +118,7 @@ class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
118118

119119
last_hidden_state: tf.Tensor = None
120120
pooler_output: tf.Tensor = None
121-
hidden_states: Optional[Tuple[tf.Tensor]] = None
121+
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
122122

123123

124124
@dataclass
@@ -886,4 +886,4 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput):
886886

887887
loss: Optional[tf.Tensor] = None
888888
logits: tf.Tensor = None
889-
hidden_states: Optional[Tuple[tf.Tensor]] = None
889+
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None

src/transformers/models/auto/modeling_tf_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
("pegasus", "TFPegasusModel"),
6565
("regnet", "TFRegNetModel"),
6666
("rembert", "TFRemBertModel"),
67+
("resnet", "TFResNetModel"),
6768
("roberta", "TFRobertaModel"),
6869
("roformer", "TFRoFormerModel"),
6970
("speech_to_text", "TFSpeech2TextModel"),
@@ -175,6 +176,7 @@
175176
("convnext", "TFConvNextForImageClassification"),
176177
("data2vec-vision", "TFData2VecVisionForImageClassification"),
177178
("regnet", "TFRegNetForImageClassification"),
179+
("resnet", "TFResNetForImageClassification"),
178180
("swin", "TFSwinForImageClassification"),
179181
("vit", "TFViTForImageClassification"),
180182
]

src/transformers/models/resnet/__init__.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import TYPE_CHECKING
1919

2020
# rely on isort to merge the imports
21-
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
21+
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
2222

2323

2424
_import_structure = {
@@ -38,6 +38,19 @@
3838
"ResNetPreTrainedModel",
3939
]
4040

41+
try:
42+
if not is_tf_available():
43+
raise OptionalDependencyNotAvailable()
44+
except OptionalDependencyNotAvailable:
45+
pass
46+
else:
47+
_import_structure["modeling_tf_resnet"] = [
48+
"TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST",
49+
"TFResNetForImageClassification",
50+
"TFResNetModel",
51+
"TFResNetPreTrainedModel",
52+
]
53+
4154

4255
if TYPE_CHECKING:
4356
from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig
@@ -55,6 +68,19 @@
5568
ResNetPreTrainedModel,
5669
)
5770

71+
try:
72+
if not is_tf_available():
73+
raise OptionalDependencyNotAvailable()
74+
except OptionalDependencyNotAvailable:
75+
pass
76+
else:
77+
from .modeling_tf_resnet import (
78+
TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
79+
TFResNetForImageClassification,
80+
TFResNetModel,
81+
TFResNetPreTrainedModel,
82+
)
83+
5884

5985
else:
6086
import sys

0 commit comments

Comments
 (0)