Skip to content

Commit ff649b5

Browse files
authored
Add TF EfficientNet Model (huggingface#502)
1 parent e9e138c commit ff649b5

File tree

4 files changed

+57
-15
lines changed

4 files changed

+57
-15
lines changed

tank/all_models.csv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ resnet18,linalg,torch,1e-2,1e-3,default
3232
resnet50,linalg,torch,1e-2,1e-3,default
3333
squeezenet1_0,linalg,torch,1e-2,1e-3,default
3434
wide_resnet50_2,linalg,torch,1e-2,1e-3,default
35+
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default

tank/model_metadata.csv

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,5 @@ microsoft/mpnet-base,False,False,-,-,-
2727
roberta-base,False,False,-,-,-
2828
xlm-roberta-base,False,False,-,-,-
2929
facebook/convnext-tiny-224,False,False,-,-,-
30+
efficientnet-v2-s,False,False,22M,"image-classification,cnn","Includes MBConv and Fused-MBConv"
31+

tank/model_utils_tf.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111

1212
################################## MHLO/TF models #########################################
1313
# TODO : Generate these lists or fetch model source from tank/tf/tf_model_list.csv
14-
keras_models = [
15-
"resnet50",
16-
]
14+
keras_models = ["resnet50", "efficientnet-v2-s"]
1715
maskedlm_models = [
1816
"albert-base-v2",
1917
"bert-base-uncased",
@@ -168,45 +166,85 @@ def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
168166
##################### TensorFlow Keras Resnet Models #########################################################
169167
# Static shape, including batch size (1).
170168
# Can be dynamic once dynamic shape support is ready.
171-
INPUT_SHAPE = [1, 224, 224, 3]
169+
RESNET_INPUT_SHAPE = [1, 224, 224, 3]
170+
EFFICIENTNET_INPUT_SHAPE = [1, 384, 384, 3]
171+
172+
tf_resnet_model = tf.keras.applications.resnet50.ResNet50(
173+
weights="imagenet",
174+
include_top=True,
175+
input_shape=tuple(RESNET_INPUT_SHAPE[1:]),
176+
)
172177

173-
tf_model = tf.keras.applications.resnet50.ResNet50(
174-
weights="imagenet", include_top=True, input_shape=tuple(INPUT_SHAPE[1:])
178+
tf_efficientnet_model = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
179+
weights="imagenet",
180+
include_top=True,
181+
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
175182
)
176183

177184

178185
class ResNetModule(tf.Module):
179186
def __init__(self):
180187
super(ResNetModule, self).__init__()
181-
self.m = tf_model
188+
self.m = tf_resnet_model
182189
self.m.predict = lambda x: self.m.call(x, training=False)
183190

184191
@tf.function(
185-
input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)],
192+
input_signature=[tf.TensorSpec(RESNET_INPUT_SHAPE, tf.float32)],
186193
jit_compile=True,
187194
)
188195
def forward(self, inputs):
189196
return self.m.predict(inputs)
190197

198+
def input_shape(self):
199+
return RESNET_INPUT_SHAPE
200+
201+
def preprocess_input(self, image):
202+
return tf.keras.applications.resnet50.preprocess_input(image)
203+
191204

192-
def load_image(path_to_image):
205+
class EfficientNetModule(tf.Module):
206+
def __init__(self):
207+
super(EfficientNetModule, self).__init__()
208+
self.m = tf_efficientnet_model
209+
self.m.predict = lambda x: self.m.call(x, training=False)
210+
211+
@tf.function(
212+
input_signature=[tf.TensorSpec(EFFICIENTNET_INPUT_SHAPE, tf.float32)],
213+
jit_compile=True,
214+
)
215+
def forward(self, inputs):
216+
return self.m.predict(inputs)
217+
218+
def input_shape(self):
219+
return EFFICIENTNET_INPUT_SHAPE
220+
221+
def preprocess_input(self, image):
222+
return tf.keras.applications.efficientnet_v2.preprocess_input(image)
223+
224+
225+
def load_image(path_to_image, width, height, channels):
193226
image = tf.io.read_file(path_to_image)
194-
image = tf.image.decode_image(image, channels=3)
195-
image = tf.image.resize(image, (224, 224))
227+
image = tf.image.decode_image(image, channels=channels)
228+
image = tf.image.resize(image, (width, height))
196229
image = image[tf.newaxis, :]
197230
return image
198231

199232

200233
def get_keras_model(modelname):
201-
model = ResNetModule()
234+
if modelname == "efficientnet-v2-s":
235+
model = EfficientNetModule()
236+
else:
237+
model = ResNetModule()
238+
202239
content_path = tf.keras.utils.get_file(
203240
"YellowLabradorLooking_new.jpg",
204241
"https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg",
205242
)
206-
content_image = load_image(content_path)
207-
input_tensor = tf.keras.applications.resnet50.preprocess_input(
208-
content_image
243+
input_shape = model.input_shape()
244+
content_image = load_image(
245+
content_path, input_shape[1], input_shape[2], input_shape[3]
209246
)
247+
input_tensor = model.preprocess_input(content_image)
210248
input_data = tf.expand_dims(input_tensor, 0)
211249
actual_out = model.forward(*input_data)
212250
return model, input_data, actual_out

tank/tf_model_list.csv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ funnel-transformer/small,hf
1717
microsoft/mpnet-base,hf
1818
facebook/convnext-tiny-224,img
1919
google/vit-base-patch16-224,img
20+
efficientnet-v2-s,keras

0 commit comments

Comments
 (0)