|
11 | 11 |
|
12 | 12 | ################################## MHLO/TF models ######################################### |
13 | 13 | # TODO : Generate these lists or fetch model source from tank/tf/tf_model_list.csv |
14 | | -keras_models = [ |
15 | | - "resnet50", |
16 | | -] |
| 14 | +keras_models = ["resnet50", "efficientnet-v2-s"] |
17 | 15 | maskedlm_models = [ |
18 | 16 | "albert-base-v2", |
19 | 17 | "bert-base-uncased", |
@@ -168,45 +166,85 @@ def get_causal_lm_model(hf_name, text="Hello, this is the default text."): |
168 | 166 | ##################### TensorFlow Keras Resnet Models ######################################################### |
169 | 167 | # Static shape, including batch size (1). |
170 | 168 | # Can be dynamic once dynamic shape support is ready. |
171 | | -INPUT_SHAPE = [1, 224, 224, 3] |
| 169 | +RESNET_INPUT_SHAPE = [1, 224, 224, 3] |
| 170 | +EFFICIENTNET_INPUT_SHAPE = [1, 384, 384, 3] |
| 171 | + |
| 172 | +tf_resnet_model = tf.keras.applications.resnet50.ResNet50( |
| 173 | + weights="imagenet", |
| 174 | + include_top=True, |
| 175 | + input_shape=tuple(RESNET_INPUT_SHAPE[1:]), |
| 176 | +) |
172 | 177 |
|
173 | | -tf_model = tf.keras.applications.resnet50.ResNet50( |
174 | | - weights="imagenet", include_top=True, input_shape=tuple(INPUT_SHAPE[1:]) |
| 178 | +tf_efficientnet_model = tf.keras.applications.efficientnet_v2.EfficientNetV2S( |
| 179 | + weights="imagenet", |
| 180 | + include_top=True, |
| 181 | + input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]), |
175 | 182 | ) |
176 | 183 |
|
177 | 184 |
|
178 | 185 | class ResNetModule(tf.Module): |
179 | 186 | def __init__(self): |
180 | 187 | super(ResNetModule, self).__init__() |
181 | | - self.m = tf_model |
| 188 | + self.m = tf_resnet_model |
182 | 189 | self.m.predict = lambda x: self.m.call(x, training=False) |
183 | 190 |
|
184 | 191 | @tf.function( |
185 | | - input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)], |
| 192 | + input_signature=[tf.TensorSpec(RESNET_INPUT_SHAPE, tf.float32)], |
186 | 193 | jit_compile=True, |
187 | 194 | ) |
188 | 195 | def forward(self, inputs): |
189 | 196 | return self.m.predict(inputs) |
190 | 197 |
|
| 198 | + def input_shape(self): |
| 199 | + return RESNET_INPUT_SHAPE |
| 200 | + |
| 201 | + def preprocess_input(self, image): |
| 202 | + return tf.keras.applications.resnet50.preprocess_input(image) |
| 203 | + |
191 | 204 |
|
192 | | -def load_image(path_to_image): |
| 205 | +class EfficientNetModule(tf.Module): |
| 206 | + def __init__(self): |
| 207 | + super(EfficientNetModule, self).__init__() |
| 208 | + self.m = tf_efficientnet_model |
| 209 | + self.m.predict = lambda x: self.m.call(x, training=False) |
| 210 | + |
| 211 | + @tf.function( |
| 212 | + input_signature=[tf.TensorSpec(EFFICIENTNET_INPUT_SHAPE, tf.float32)], |
| 213 | + jit_compile=True, |
| 214 | + ) |
| 215 | + def forward(self, inputs): |
| 216 | + return self.m.predict(inputs) |
| 217 | + |
| 218 | + def input_shape(self): |
| 219 | + return EFFICIENTNET_INPUT_SHAPE |
| 220 | + |
| 221 | + def preprocess_input(self, image): |
| 222 | + return tf.keras.applications.efficientnet_v2.preprocess_input(image) |
| 223 | + |
| 224 | + |
| 225 | +def load_image(path_to_image, width, height, channels): |
193 | 226 | image = tf.io.read_file(path_to_image) |
194 | | - image = tf.image.decode_image(image, channels=3) |
195 | | - image = tf.image.resize(image, (224, 224)) |
| 227 | + image = tf.image.decode_image(image, channels=channels) |
| 228 | + image = tf.image.resize(image, (width, height)) |
196 | 229 | image = image[tf.newaxis, :] |
197 | 230 | return image |
198 | 231 |
|
199 | 232 |
|
200 | 233 | def get_keras_model(modelname): |
201 | | - model = ResNetModule() |
| 234 | + if modelname == "efficientnet-v2-s": |
| 235 | + model = EfficientNetModule() |
| 236 | + else: |
| 237 | + model = ResNetModule() |
| 238 | + |
202 | 239 | content_path = tf.keras.utils.get_file( |
203 | 240 | "YellowLabradorLooking_new.jpg", |
204 | 241 | "https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg", |
205 | 242 | ) |
206 | | - content_image = load_image(content_path) |
207 | | - input_tensor = tf.keras.applications.resnet50.preprocess_input( |
208 | | - content_image |
| 243 | + input_shape = model.input_shape() |
| 244 | + content_image = load_image( |
| 245 | + content_path, input_shape[1], input_shape[2], input_shape[3] |
209 | 246 | ) |
| 247 | + input_tensor = model.preprocess_input(content_image) |
210 | 248 | input_data = tf.expand_dims(input_tensor, 0) |
211 | 249 | actual_out = model.forward(*input_data) |
212 | 250 | return model, input_data, actual_out |
|
0 commit comments