Commit af4257d

Add tf image classification auto model (huggingface#213)
1 parent dc1a283 commit af4257d

4 files changed: +108 -3 lines changed

generate_sharktank.py

Lines changed: 5 additions & 2 deletions
@@ -80,6 +80,7 @@ def save_torch_model(torch_model_list):
 
 def save_tf_model(tf_model_list):
     from tank.masked_lm_tf import get_causal_lm_model
+    from tank.tf.automodelimageclassification import get_causal_image_model
 
     with open(tf_model_list) as csvfile:
         tf_reader = csv.reader(csvfile, delimiter=",")
@@ -93,6 +94,8 @@ def save_tf_model(tf_model_list):
             print(model_type)
             if model_type == "hf":
                 model, input, _ = get_causal_lm_model(tf_model_name)
+            if model_type == "img":
+                model, input, _ = get_causal_image_model(tf_model_name)
 
             tf_model_name = tf_model_name.replace("/", "_")
             tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
@@ -184,8 +187,8 @@ def is_valid_file(arg):
     if args.tf_model_csv:
         save_tf_model(args.tf_model_csv)
 
-    if args.tflite_model_csv:
-        save_tflite_model(args.tflite_model_csv)
+    # if args.tflite_model_csv:
+    #     save_tflite_model(args.tflite_model_csv)
 
     if args.upload:
         print("uploading files to gs://shark_tank/")

shark/shark_importer.py

Lines changed: 4 additions & 1 deletion
@@ -197,8 +197,11 @@ def import_debug(
             golden_out = tuple(
                 golden_out.numpy(),
            )
-        else:
+        elif golden_out is tuple:
             golden_out = self.convert_to_numpy(golden_out)
+        else:
+            # from transformers import TFSequenceClassifierOutput
+            golden_out = golden_out.logits
         # Save the artifacts in the directory dir.
         self.save_data(
             dir,
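
For context on the new else branch: the TF image classifiers added in this commit return a TFSequenceClassifierOutput object rather than a bare tensor or a tuple, so the golden output is taken from its logits field before being saved. A minimal sketch of that unwrapping, assuming the google/vit-base-patch16-224 checkpoint and a dummy input (names and shapes here are illustrative, not the importer's actual code path):

import numpy as np
import tensorflow as tf
from transformers import TFAutoModelForImageClassification

# Illustrative sketch: image-classification heads return a structured
# output, so the plain-numpy "golden output" comes from .logits.
model = TFAutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224"
)
pixel_values = tf.zeros([1, 3, 224, 224], dtype=tf.float32)  # dummy image batch
out = model(pixel_values)

if isinstance(out, tuple):
    golden_out = tuple(np.asarray(t) for t in out)
else:
    # e.g. TFSequenceClassifierOutput(loss=None, logits=<tf.Tensor ...>, ...)
    golden_out = out.logits.numpy()
    print(golden_out.shape)  # (1, 1000) for an ImageNet-1k head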
tank/tf/automodelimageclassification.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+from transformers import TFAutoModelForImageClassification
+from transformers import ConvNextFeatureExtractor, ViTFeatureExtractor
+from transformers import BeitFeatureExtractor, AutoFeatureExtractor
+import tensorflow as tf
+from PIL import Image
+import requests
+from shark.shark_inference import SharkInference
+from shark.shark_downloader import download_tf_model
+
+# Create a set of input signature.
+inputs_signature = [
+    tf.TensorSpec(shape=[1, 3, 224, 224], dtype=tf.float32),
+]
+
+
+class AutoModelImageClassfication(tf.Module):
+    def __init__(self, model_name):
+        super(AutoModelImageClassfication, self).__init__()
+        self.m = TFAutoModelForImageClassification.from_pretrained(
+            model_name, output_attentions=False
+        )
+        self.m.predict = lambda x: self.m(x)
+
+    @tf.function(input_signature=inputs_signature)
+    def forward(self, inputs):
+        return self.m.predict(inputs)
+
+
+fail_models = [
+    "facebook/data2vec-vision-base-ft1k",
+    "microsoft/swin-tiny-patch4-window7-224",
+]
+
+supported_models = [
+    # "facebook/convnext-tiny-224",
+    "google/vit-base-patch16-224",
+]
+
+img_models_fe_dict = {
+    "facebook/convnext-tiny-224": ConvNextFeatureExtractor,
+    "facebook/data2vec-vision-base-ft1k": BeitFeatureExtractor,
+    "microsoft/swin-tiny-patch4-window7-224": AutoFeatureExtractor,
+    "google/vit-base-patch16-224": ViTFeatureExtractor,
+}
+
+
+def preprocess_input_image(model_name):
+    # from datasets import load_dataset
+    # dataset = load_dataset("huggingface/cats-image")
+    # image1 = dataset["test"]["image"][0]
+    # # print("image1: ", image1)  # <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x480 at 0x7FA0B86BB6D0>
+    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    # <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x480 at 0x7FA0B86BB6D0>
+    image = Image.open(requests.get(url, stream=True).raw)
+    feature_extractor = img_models_fe_dict[model_name].from_pretrained(
+        model_name
+    )
+    # inputs: {'pixel_values': <tf.Tensor: shape=(1, 3, 224, 224), dtype=float32, numpy=array([[[[]]]], dtype=float32)>}
+    inputs = feature_extractor(images=image, return_tensors="tf")
+
+    return [inputs[str(*inputs)]]
+
+
+def get_causal_image_model(hf_name):
+    model = AutoModelImageClassfication(hf_name)
+    test_input = preprocess_input_image(hf_name)
+    # TFSequenceClassifierOutput(loss=None, logits=<tf.Tensor: shape=(1, 1000), dtype=float32, numpy=
+    # array([[]], dtype=float32)>, hidden_states=None, attentions=None)
+    actual_out = model.forward(*test_input)
+    return model, test_input, actual_out
+
+
+if __name__ == "__main__":
+    for model_name in supported_models:
+        print(f"Running model: {model_name}")
+        inputs = preprocess_input_image(model_name)
+        model = AutoModelImageClassfication(model_name)
+
+        # 1. USE SharkImporter to get the mlir
+        # from shark.shark_importer import SharkImporter
+        # mlir_importer = SharkImporter(
+        #     model,
+        #     inputs,
+        #     frontend="tf",
+        # )
+        # imported_mlir, func_name = mlir_importer.import_mlir()
+
+        # 2. USE SharkDownloader to get the mlir
+        imported_mlir, func_name, inputs, golden_out = download_tf_model(
+            model_name
+        )
+
+        shark_module = SharkInference(
+            imported_mlir, func_name, device="cpu", mlir_dialect="mhlo"
+        )
+        shark_module.compile()
+        shark_module.forward(inputs)
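
As a usage note, get_causal_image_model is the hook that generate_sharktank.py now calls for "img" rows; it returns the wrapped tf.Module, the preprocessed pixel-values input, and the eager forward output. A minimal sketch of calling it directly, assuming the module path implied by the import in generate_sharktank.py:

# Illustrative: mirrors what save_tf_model() does for an "img" CSV row.
from tank.tf.automodelimageclassification import get_causal_image_model

model, test_input, actual_out = get_causal_image_model(
    "google/vit-base-patch16-224"
)
print(test_input[0].shape)      # (1, 3, 224, 224) pixel_values tensor
print(actual_out.logits.shape)  # (1, 1000) classification logits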

tank/tf/tf_model_list.csv

Lines changed: 2 additions & 0 deletions
@@ -14,3 +14,5 @@ roberta-base,hf
 xlm-roberta-base,hf
 microsoft/MiniLM-L12-H384-uncased,hf
 funnel-transformer/small,hf
+facebook/convnext-tiny-224,img
+google/vit-base-patch16-224,img
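
The two new rows use the "img" tag that save_tf_model now dispatches on, alongside the existing "hf" tag for the text models. A minimal sketch of how the tag column maps to a loader, using only what the diffs above show (the helper name here is illustrative):

import csv

# Illustrative: each row of tank/tf/tf_model_list.csv is "<hf_name>,<tag>",
# and the tag column picks the loader used by save_tf_model().
def list_loaders(csv_path="tank/tf/tf_model_list.csv"):
    with open(csv_path) as csvfile:
        for row in csv.reader(csvfile, delimiter=","):
            hf_name, tag = row[0], row[1]
            if tag == "hf":
                loader = "get_causal_lm_model"      # text / masked-LM models
            elif tag == "img":
                loader = "get_causal_image_model"   # new image classifiers
            else:
                loader = "unhandled"
            print(f"{hf_name} -> {loader}")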
