
Commit dc1a283

Remove tf package dep for SharkDownloader tflite tests (huggingface#212)
1 parent cc4fa96 commit dc1a283

44 files changed: 681 additions and 1486 deletions


generate_sharktank.py

Lines changed: 7 additions & 17 deletions
@@ -115,6 +115,7 @@ def save_tflite_model(tflite_model_list):
     with open(tflite_model_list) as csvfile:
         tflite_reader = csv.reader(csvfile, delimiter=",")
         for row in tflite_reader:
+            print("\n")
             tflite_model_name = row[0]
             tflite_model_link = row[1]
             print("tflite_model_name", tflite_model_name)
@@ -125,13 +126,6 @@ def save_tflite_model(tflite_model_list):
             os.makedirs(tflite_model_name_dir, exist_ok=True)
             print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")

-            tflite_tosa_file = "/".join(
-                [
-                    tflite_model_name_dir,
-                    str(tflite_model_name) + "_tflite.mlir",
-                ]
-            )
-
             # Preprocess to get SharkImporter input args
             tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
             raw_model_file_path = tflite_preprocessor.get_raw_model_file()
@@ -145,15 +139,11 @@ def save_tflite_model(tflite_model_list):
                 frontend="tflite",
                 raw_model_file=raw_model_file_path,
             )
-            mlir_model, func_name = my_shark_importer.import_mlir()
-
-            if os.path.exists(tflite_tosa_file):
-                print("Exists", tflite_tosa_file)
-            else:
-                mlir_str = mlir_model.decode("utf-8")
-                with open(tflite_tosa_file, "w") as f:
-                    f.write(mlir_str)
-                print(f"Saved mlir in {tflite_tosa_file}")
+            my_shark_importer.import_debug(
+                dir=tflite_model_name_dir,
+                model_name=tflite_model_name,
+                func_name="main",
+            )


 # Validates whether the file is present or not.
@@ -170,7 +160,7 @@ def is_valid_file(arg):
     "--torch_model_csv",
     type=lambda x: is_valid_file(x),
     default="./tank/pytorch/torch_model_list.csv",
-    help="""Contains the file with torch_model name and args.
+    help="""Contains the file with torch_model name and args.
         Please see: https://github.com/nod-ai/SHARK/blob/main/tank/pytorch/torch_model_list.csv""",
 )
 parser.add_argument(
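
Net effect of these hunks: the script no longer imports the MLIR and writes the TOSA file by hand; SharkImporter.import_debug now emits the .mlir together with the test artifacts. A rough per-model sketch of the new flow follows; the import paths, the model name, and the output directory are assumptions for illustration, and the SharkImporter arguments other than frontend/raw_model_file are elided because they sit outside this hunk.

from shark.shark_importer import SharkImporter  # assumed import path
from shark.tflite_utils import TFLitePreprocessor  # assumed import path

model_name = "albert_lite_base"  # hypothetical model name from the tflite CSV
model_dir = "./gen_shark_tank/tflite/" + model_name  # hypothetical output dir

# Preprocess to get SharkImporter input args (mirrors the context lines above).
tflite_preprocessor = TFLitePreprocessor(model_name)
raw_model_file_path = tflite_preprocessor.get_raw_model_file()

my_shark_importer = SharkImporter(
    # The module=/inputs= arguments come from the preprocessor; that part of the
    # call is outside this hunk and is elided here rather than guessed.
    frontend="tflite",
    raw_model_file=raw_model_file_path,
)
# import_debug writes <model>_tflite.mlir, inputs.npz, golden_out.npz and
# function_name.npy into model_dir (see the shark_importer.py hunks below).
my_shark_importer.import_debug(
    dir=model_dir,
    model_name=model_name,
    func_name="main",
)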

shark/shark_downloader.py

Lines changed: 31 additions & 147 deletions
@@ -27,7 +27,7 @@
     "int8": np.int8,
 }

-WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
+WORKDIR = os.path.join(os.path.dirname(__file__), "./../gen_shark_tank")

 # Checks whether the directory and files exists.
 def check_dir_exists(model_name, frontend="torch", dynamic=""):
@@ -83,6 +83,36 @@ def download_torch_model(model_name, dynamic=False):
     return mlir_file, function_name, inputs_tuple, golden_out_tuple


+# Downloads the tflite model from gs://shark_tank dir.
+def download_tflite_model(model_name, dynamic=False):
+    dyn_str = "_dynamic" if dynamic else ""
+    os.makedirs(WORKDIR, exist_ok=True)
+    if not check_dir_exists(model_name, dyn_str):
+        gs_command = (
+            'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank'
+            + "/"
+            + model_name
+            + " "
+            + WORKDIR
+        )
+        if os.system(gs_command) != 0:
+            raise Exception("model not present in the tank. Contact Nod Admin")
+
+    model_dir = os.path.join(WORKDIR, model_name)
+    with open(
+        os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir")
+    ) as f:
+        mlir_file = f.read()
+
+    function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
+    inputs = np.load(os.path.join(model_dir, "inputs.npz"))
+    golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
+
+    inputs_tuple = tuple([inputs[key] for key in inputs])
+    golden_out_tuple = tuple([golden_out[key] for key in golden_out])
+    return mlir_file, function_name, inputs_tuple, golden_out_tuple
+
+
 def download_tf_model(model_name):
     model_name = model_name.replace("/", "_")
     os.makedirs(WORKDIR, exist_ok=True)
@@ -109,149 +139,3 @@ def download_tf_model(model_name):
     inputs_tuple = tuple([inputs[key] for key in inputs])
     golden_out_tuple = tuple([golden_out[key] for key in golden_out])
     return mlir_file, function_name, inputs_tuple, golden_out_tuple
-
-
-class SharkDownloader:
-    def __init__(
-        self,
-        model_name: str,
-        tank_url: str = "https://storage.googleapis.com/shark_tank",
-        local_tank_dir: str = "./../gen_shark_tank/tflite",
-        model_type: str = "tflite",
-        input_json: str = "input.json",
-        input_type: str = "int32",
-    ):
-        self.model_name = model_name
-        self.local_tank_dir = local_tank_dir
-        self.tank_url = tank_url
-        self.model_type = model_type
-        self.input_json = input_json  # optional if you don't have input
-        self.input_type = input_type_to_np_dtype[
-            input_type
-        ]  # optional if you don't have input
-        self.mlir_file = None  # .mlir file local address.
-        self.mlir_url = None
-        self.inputs = None  # Input has to be (list of np.array) for sharkInference.forward use
-        self.mlir_model = []
-
-        # create tmp model file directory
-        if self.tank_url is None and self.model_name is None:
-            print("Error. No tank_url, No model name,Please input either one.")
-            return
-
-        self.workdir = os.path.join(
-            os.path.dirname(__file__), self.local_tank_dir
-        )
-        os.makedirs(self.workdir, exist_ok=True)
-        print(f"TMP_MODEL_DIR = {self.workdir}")
-        # use model name get dir.
-        self.model_name_dir = os.path.join(self.workdir, str(self.model_name))
-        if not os.path.exists(self.model_name_dir):
-            print(
-                "Model has not been download."
-                "shark_downloader will automatically download by "
-                "tank_url if provided. You can also manually to "
-                "download the model from shark_tank by yourself."
-            )
-        os.makedirs(self.model_name_dir, exist_ok=True)
-        print(f"TMP_MODELNAME_DIR = {self.model_name_dir}")
-
-        # read inputs from json file
-        self.load_json_input()
-        # get milr model file
-        self.load_mlir_model()
-
-    def get_mlir_file(self):
-        return self.mlir_model
-
-    def get_inputs(self):
-        return self.inputs
-
-    def load_json_input(self):
-        print("load json inputs")
-        if self.model_type in ["tflite"]:
-            input_url = (
-                self.tank_url + "/" + str(self.model_name) + "/" + "input.json"
-            )
-            input_file = "/".join([self.model_name_dir, str(self.input_json)])
-            if os.path.exists(input_file):
-                print("Input has been downloaded before.", input_file)
-            else:
-                print("Download input", input_url)
-                urllib.request.urlretrieve(input_url, input_file)
-
-            args = []
-            with open(input_file, "r") as f:
-                args = json.load(f)
-            self.inputs = [
-                np.asarray(arg, dtype=self.input_type) for arg in args
-            ]
-        else:
-            print(
-                "No json input required for current model type. "
-                "You could call setup_inputs(YOU_INPUTS)."
-            )
-        return self.inputs
-
-    def load_mlir_model(self):
-        if self.model_type in ["tflite"]:
-            self.mlir_url = (
-                self.tank_url
-                + "/"
-                + str(self.model_name)
-                + "/"
-                + str(self.model_name)
-                + "_tflite.mlir"
-            )
-            self.mlir_file = "/".join(
-                [self.model_name_dir, str(self.model_name) + "_tfite.mlir"]
-            )
-        elif self.model_type in ["tensorflow"]:
-            self.mlir_url = (
-                self.tank_url
-                + "/"
-                + str(self.model_name)
-                + "/"
-                + str(self.model_name)
-                + "_tf.mlir"
-            )
-            self.mlir_file = "/".join(
-                [self.model_name_dir, str(self.model_name) + "_tf.mlir"]
-            )
-        elif self.model_type in ["torch", "jax", "mhlo", "tosa"]:
-            self.mlir_url = (
-                self.tank_url
-                + "/"
-                + str(self.model_name)
-                + "/"
-                + str(self.model_name)
-                + "_"
-                + str(self.model_type)
-                + ".mlir"
-            )
-            self.mlir_file = "/".join(
-                [
-                    self.model_name_dir,
-                    str(self.model_name)
-                    + "_"
-                    + str(self.model_type)
-                    + ".mlir",
-                ]
-            )
-        else:
-            print("Unsupported mlir model")
-
-        if os.path.exists(self.mlir_file):
-            print("Model has been downloaded before.", self.mlir_file)
-        else:
-            print("Download mlir model", self.mlir_url)
-            urllib.request.urlretrieve(self.mlir_url, self.mlir_file)
-
-        print("Get .mlir model return")
-        with open(self.mlir_file) as f:
-            self.mlir_model = f.read()
-        return self.mlir_model
-
-    def setup_inputs(self, inputs):
-        print("Setting up inputs. Input has to be (list of np.array)")
-        self.inputs = inputs
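
With the SharkDownloader class removed, tflite tests can fetch artifacts through the same function-style API the torch/tf paths already use. A minimal usage sketch of the download_tflite_model function added above; the model name is hypothetical and gsutil must be available on PATH:

from shark.shark_downloader import download_tflite_model  # assumed import path

mlir_file, function_name, inputs, golden_out = download_tflite_model(
    "albert_lite_base"  # hypothetical model name under gs://shark_tank
)
# mlir_file holds the text of <model>_tflite.mlir; inputs and golden_out are
# tuples of np.ndarrays loaded from inputs.npz and golden_out.npz.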

shark/shark_importer.py

Lines changed: 11 additions & 1 deletion
@@ -129,7 +129,7 @@ def save_data(
         inputs_name = "inputs.npz"
         outputs_name = "golden_out.npz"
         func_file_name = "function_name"
-        model_name_mlir = model_name + ".mlir"
+        model_name_mlir = model_name + "_" + self.frontend + ".mlir"
         np.savez(os.path.join(dir, inputs_name), *inputs)
         np.savez(os.path.join(dir, outputs_name), *outputs)
         np.save(os.path.join(dir, func_file_name), np.array(func_name))
@@ -139,6 +139,8 @@ def save_data(
             mlir_str = mlir_data.operation.get_asm()
         elif self.frontend == "tf":
             mlir_str = mlir_data.decode("utf-8")
+        elif self.frontend == "tflite":
+            mlir_str = mlir_data.decode("utf-8")
         with open(os.path.join(dir, model_name_mlir), "w") as mlir_file:
             mlir_file.write(mlir_str)

@@ -214,6 +216,14 @@ def import_debug(
         if self.frontend in ["tflite", "tf-lite"]:
             # TODO(Chi): Validate it for tflite models.
             golden_out = self.module.invoke_tflite(self.inputs)
+            self.save_data(
+                dir,
+                model_name,
+                imported_mlir[0],
+                imported_mlir[1],
+                self.inputs,
+                golden_out,
+            )
         return (
             imported_mlir,
             self.inputs,
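
For reference, the artifacts that import_debug now persists through save_data can be read back with plain numpy, which is what download_tflite_model above relies on. A sketch assuming a hypothetical model directory and the tflite frontend:

import os
import numpy as np

model_dir = "./gen_shark_tank/tflite/albert_lite_base"  # hypothetical path
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
# The MLIR file now carries the frontend suffix rather than a bare .mlir name.
with open(os.path.join(model_dir, "albert_lite_base_tflite.mlir")) as f:
    mlir_module = f.read()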

shark/tests/test_shark_importer.py

Lines changed: 2 additions & 2 deletions
@@ -98,7 +98,7 @@ def create_and_check_module(self):
         for i in range(len(output_details)):
             dtype = output_details[i]["dtype"]
             mlir_results[i] = mlir_results[i].astype(dtype)
-        tflite_results = tflite_preprocessor.get_raw_model_output()
+        tflite_results = tflite_preprocessor.get_golden_output()
         compare_results(mlir_results, tflite_results, output_details)

         # Case2: Use manually set inputs
@@ -114,7 +114,7 @@ def create_and_check_module(self):
         shark_module.compile()
         mlir_results = shark_module.forward(inputs)
         ## post process results for compare
-        tflite_results = tflite_preprocessor.get_raw_model_output()
+        tflite_results = tflite_preprocessor.get_golden_output()
         compare_results(mlir_results, tflite_results, output_details)
         # print(mlir_results)

shark/tflite_utils.py

Lines changed: 14 additions & 15 deletions
@@ -3,7 +3,6 @@
 import os
 import csv
 import urllib.request
-import json


 class TFLiteModelUtil:
@@ -48,6 +47,8 @@ def invoke_tflite(self, inputs):
         )

         for i in range(len(self.output_details)):
+            # print("output_details ", i, "shape", self.output_details[i]["shape"].__name__,
+            # ", dtype: ", self.output_details[i]["dtype"].__name__)
             out_dtype = self.output_details[i]["dtype"]
             tflite_results[i] = tflite_results[i].astype(out_dtype)
         return tflite_results
@@ -84,6 +85,7 @@ def __init__(
             None  # could be tflite/tf/torch_interpreter in utils
         )
         self.input_file = None
+        self.output_file = None

         # create tmp model file directory
         if self.model_path is None and self.model_name is None:
@@ -127,7 +129,9 @@ def load_tflite_model(self):
         self.mlir_file = "/".join(
             [tflite_model_name_dir, str(self.model_name) + "_tflite.mlir"]
         )
-        self.input_file = "/".join([tflite_model_name_dir, "input.json"])
+        self.input_file = "/".join([tflite_model_name_dir, "inputs"])
+        self.output_file = "/".join([tflite_model_name_dir, "golden_out"])
+        # np.save("/".join([tflite_model_name_dir, "function_name"]), np.array("main"))

         if os.path.exists(self.raw_model_file):
             print(
@@ -165,21 +169,15 @@ def setup_interpreter(self):
     def generate_inputs(self, input_details):
         self.inputs = []
         for tmp_input in input_details:
-            # print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
+            print(
+                "input_details shape:",
+                str(tmp_input["shape"]),
+                " type:",
+                tmp_input["dtype"].__name__,
+            )
             self.inputs.append(
                 np.ones(shape=tmp_input["shape"], dtype=tmp_input["dtype"])
             )
-        # save inputs into json file
-        tmp_json = []
-        for tmp_input in input_details:
-            # print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
-            tmp_json.append(
-                np.ones(
-                    shape=tmp_input["shape"], dtype=tmp_input["dtype"]
-                ).tolist()
-            )
-        with open(self.input_file, "w") as f:
-            json.dump(tmp_json, f)
         return self.inputs

     def setup_inputs(self, inputs):
@@ -195,8 +193,9 @@ def get_mlir_file(self):
     def get_inputs(self):
         return self.inputs

-    def get_raw_model_output(self):
+    def get_golden_output(self):
         self.output_tensor = self.interpreter.invoke_tflite(self.inputs)
+        np.savez(self.output_file, *self.output_tensor)
         return self.output_tensor

     def get_model_details(self):
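
Taken together with the test changes above, a tflite test can obtain reference outputs without importing the tf package. A hedged sketch, assuming the import path shown, a hypothetical model name, and that the preprocessor generates its default inputs during construction as the class appears to do:

from shark.tflite_utils import TFLitePreprocessor  # assumed import path

tflite_preprocessor = TFLitePreprocessor("albert_lite_base")  # hypothetical model
inputs = tflite_preprocessor.get_inputs()
# get_golden_output (formerly get_raw_model_output) runs the tflite interpreter
# on the current inputs and also saves the results to <model_dir>/golden_out.npz.
golden_outputs = tflite_preprocessor.get_golden_output()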
