Commit e60b456
[SharkInference] Make SharkInference compile the entire module (huggingface#708)

* [SharkInference] Make SharkInference compile the entire module

-- Previously, SharkInference compiled and provided run APIs only for a hardcoded function named "forward".
-- This commit makes the compilation functionality generic, so any function defined within the module can be run.
-- It also adds an API to fetch all the function names defined within the compiled module.
-- It updates both the web and command-line execution of Stable Diffusion to use the new SharkInference API.

Signed-off-by: Abhishek Varma <[email protected]>
1 parent 4ee3d95 commit e60b456
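
The user-facing change is the same at every call site: instead of a hardcoded forward method, the function name is passed explicitly. A minimal before/after sketch in Python, assuming shark_module is a compiled SharkInference instance and input_array is a NumPy array (both names are illustrative):

# Before this commit: only the hardcoded "forward" function could be run.
result = shark_module.forward((input_array,))

# After this commit: the target function is named explicitly, so any
# function defined within the compiled module can be invoked.
result = shark_module("forward", (input_array,))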

12 files changed, +65 -72 lines changed


shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 6 additions & 5 deletions
@@ -151,8 +151,8 @@ def end_profiling(device):
 vae_warmup_input = torch.clone(latents).detach().numpy()
 clip_warmup_input = torch.randint(1, 2, (2, args.max_length))
 for i in range(args.warmup_count):
-    vae.forward((vae_warmup_input,))
-    clip.forward((clip_warmup_input,))
+    vae("forward", (vae_warmup_input,))
+    clip("forward", (clip_warmup_input,))
 
 start = time.time()
 
@@ -174,7 +174,7 @@ def end_profiling(device):
 text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
 
 clip_inf_start = time.time()
-text_embeddings = clip.forward((text_input,))
+text_embeddings = clip("forward", (text_input,))
 clip_inf_end = time.time()
 text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
 text_embeddings_numpy = text_embeddings.detach().numpy()
@@ -196,7 +196,8 @@ def end_profiling(device):
 
 profile_device = start_profiling(file_path="unet.rdc")
 
-noise_pred = unet.forward(
+noise_pred = unet(
+    "forward",
     (
         latent_model_input,
         timestep,
@@ -227,7 +228,7 @@ def end_profiling(device):
 latents_numpy = latents.detach().numpy()
 profile_device = start_profiling(file_path="vae.rdc")
 vae_start = time.time()
-images = vae.forward((latents_numpy,))
+images = vae("forward", (latents_numpy,))
 vae_end = time.time()
 end_profiling(profile_device)
 if args.use_base_vae:

shark/examples/shark_inference/stable_diffusion/schedulers.py

Lines changed: 4 additions & 2 deletions
@@ -108,7 +108,8 @@ def forward(self, noise_pred, sigma, latent, dt):
     def scale_model_input(self, sample, timestep):
         step_index = (self.timesteps == timestep).nonzero().item()
         sigma = self.sigmas[step_index]
-        return self.scaling_model.forward(
+        return self.scaling_model(
+            "forward",
             (
                 sample,
                 sigma,
@@ -120,7 +121,8 @@ def step(self, noise_pred, timestep, latent):
         step_index = (self.timesteps == timestep).nonzero().item()
         sigma = self.sigmas[step_index]
         dt = self.sigmas[step_index + 1] - sigma
-        return self.step_model.forward(
+        return self.step_model(
+            "forward",
             (
                 noise_pred,
                 sigma,

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 1 addition & 2 deletions
@@ -53,7 +53,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
         frontend="torch",
     )
     shark_module = SharkInference(
-        mlir_model, func_name, device=args.device, mlir_dialect="linalg"
+        mlir_model, device=args.device, mlir_dialect="linalg"
     )
     return _compile_module(shark_module, model_name, extra_args)
 
@@ -65,7 +65,6 @@ def compile_through_fx(model, inputs, model_name, extra_args=[]):
 
     shark_module = SharkInference(
         mlir_module,
-        func_name,
         device=args.device,
         mlir_dialect="linalg",
     )
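
With func_name removed from the constructor, both helpers now build the module the same way. A minimal sketch of the new construction path, assuming mlir_module holds a serialized MLIR module and "cpu" is a valid device string (illustrative values):

from shark.shark_inference import SharkInference

# The function name is no longer a constructor argument; it is
# supplied per call instead.
shark_module = SharkInference(
    mlir_module,
    device="cpu",
    mlir_dialect="linalg",
)
shark_module.compile()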

shark/iree_eager_backend.py

Lines changed: 1 addition & 3 deletions
@@ -21,7 +21,6 @@
 from iree.runtime import DeviceArray
 from torch_mlir._mlir_libs._mlir.ir import Module
 from torch_mlir.compiler_utils import (
-    get_module_name_for_debug_dump,
     run_pipeline_with_repro_report,
 )
 from torch_mlir.eager_mode.torch_mlir_eager_backend import (
@@ -64,14 +63,13 @@ def get_torch_metadata(
         )
 
     def compile(self, imported_module: Module):
-        fn_name = get_module_name_for_debug_dump(imported_module)
         run_pipeline_with_repro_report(
             imported_module,
             "torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline",
             "EagerMode",
         )
         callable, _ = get_iree_compiled_module(
-            imported_module, self.raw_device_str, func_name=fn_name
+            imported_module, self.raw_device_str
         )
         return callable
 

shark/iree_utils/compile_utils.py

Lines changed: 16 additions & 16 deletions
@@ -234,7 +234,6 @@ def compile_module_to_flatbuffer(
     module,
     device,
     frontend,
-    func_name,
     model_config_path,
     extra_args,
     model_name="None",
@@ -277,62 +276,58 @@
     return flatbuffer_blob
 
 
-def get_iree_module(flatbuffer_blob, device, func_name):
+def get_iree_module(flatbuffer_blob, device):
     # Returns the compiled module and the configs.
     config = get_iree_runtime_config(device)
     vm_module = ireert.VmModule.from_flatbuffer(
         config.vm_instance, flatbuffer_blob
     )
     ctx = ireert.SystemContext(config=config)
     ctx.add_vm_module(vm_module)
-    ModuleCompiled = ctx.modules.module[func_name]
+    ModuleCompiled = ctx.modules.module
     return ModuleCompiled, config
 
 
 def get_iree_compiled_module(
     module,
     device: str,
     frontend: str = "torch",
-    func_name: str = "forward",
     model_config_path: str = None,
     extra_args: list = [],
 ):
     """Given a module returns the compiled .vmfb and configs"""
     flatbuffer_blob = compile_module_to_flatbuffer(
-        module, device, frontend, func_name, model_config_path, extra_args
+        module, device, frontend, model_config_path, extra_args
     )
-    return get_iree_module(flatbuffer_blob, device, func_name)
+    return get_iree_module(flatbuffer_blob, device)
 
 
-def load_flatbuffer(
-    flatbuffer_path: str, device: str, func_name: str = "forward"
-):
+def load_flatbuffer(flatbuffer_path: str, device: str):
 
     with open(os.path.join(flatbuffer_path), "rb") as f:
         flatbuffer_blob = f.read()
 
-    return get_iree_module(flatbuffer_blob, device, func_name)
+    return get_iree_module(flatbuffer_blob, device)
 
 
 def export_iree_module_to_vmfb(
     module,
     device: str,
     directory: str,
     mlir_dialect: str = "linalg",
-    func_name: str = "forward",
     model_config_path: str = None,
     module_name: str = None,
     extra_args: list = [],
 ):
     # Compiles the module given specs and saves it as .vmfb file.
     flatbuffer_blob = compile_module_to_flatbuffer(
-        module, device, mlir_dialect, func_name, model_config_path, extra_args
+        module, device, mlir_dialect, model_config_path, extra_args
     )
     if module_name is None:
         device_name = (
             device if "://" not in device else "-".join(device.split("://"))
         )
-        module_name = f"{mlir_dialect}_{func_name}_{device_name}"
+        module_name = f"{mlir_dialect}_{device_name}"
     filename = os.path.join(directory, module_name + ".vmfb")
     print(f"Saved vmfb in {filename}.")
     with open(filename, "wb") as f:
@@ -355,11 +350,16 @@ def export_module_to_mlir_file(module, frontend, directory: str):
 
 
 def get_results(
-    compiled_vm, input, config, frontend="torch", send_to_host=True
+    compiled_vm,
+    function_name,
+    input,
+    config,
+    frontend="torch",
+    send_to_host=True,
 ):
     """Runs a .vmfb file given inputs and config and returns output."""
     device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
-    result = compiled_vm(*device_inputs)
+    result = compiled_vm[function_name](*device_inputs)
     result_tensors = []
     if isinstance(result, tuple):
         if send_to_host:
@@ -376,7 +376,7 @@ def get_results(
             return np.copy(res)
         return data
     else:
-        if send_to_host:
+        if send_to_host and result is not None:
             return result.to_host()
         return result
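
The core of the change sits in get_iree_module and get_results: the runtime context now hands back the whole VM module rather than one pre-bound function, and the entry point is looked up by name per invocation. A condensed sketch of that pattern using the iree.runtime calls already present in this file, assuming flatbuffer_blob, config, and inputs (a list of NumPy arrays) are prepared as above:

from iree import runtime as ireert

# Build the VM context, as get_iree_module now does.
vm_module = ireert.VmModule.from_flatbuffer(config.vm_instance, flatbuffer_blob)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)

# Keep the whole module; no function name is baked in at load time.
compiled_module = ctx.modules.module

# get_results then selects the entry point by name on each call.
device_inputs = [ireert.asdevicearray(config.device, a) for a in inputs]
result = compiled_module["forward"](*device_inputs)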

shark/shark_benchmark_runner.py

Lines changed: 2 additions & 5 deletions
@@ -60,7 +60,6 @@ class SharkBenchmarkRunner(SharkRunner):
     def __init__(
         self,
         mlir_module: bytes,
-        function_name: str = "forward",
         device: str = "none",
         mlir_dialect: str = "linalg",
         extra_args: list = [],
@@ -73,7 +72,6 @@ def __init__(
         SharkRunner.__init__(
             self,
             mlir_module,
-            function_name,
             device,
             self.mlir_dialect,
             self.extra_args,
@@ -85,7 +83,6 @@ def __init__(
                 device,
                 shark_args.repro_dir,
                 self.mlir_dialect,
-                function_name,
                 extra_args=self.extra_args,
             )
 
@@ -185,11 +182,11 @@ def benchmark_c(self):
     def benchmark_python(self, inputs):
         input_list = [x for x in inputs]
        for i in range(shark_args.num_warmup_iterations):
-            self.run(input_list)
+            self.run("forward", input_list)
 
         begin = time.time()
         for i in range(shark_args.num_iterations):
-            out = self.run(input_list)
+            out = self.run("forward", input_list)
             if i == shark_args.num_iterations - 1:
                 end = time.time()
                 print(

shark/shark_inference.py

Lines changed: 13 additions & 18 deletions
@@ -40,8 +40,6 @@ class SharkInference:
     ----------
     mlir_module : str
         mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
-    function_name : str
-        function to execute in the given mlir_module.
     device : str
         device to execute the mlir_module on.
         currently supports cpu, cuda, vulkan, and metal backends.
@@ -53,10 +51,10 @@ class SharkInference:
 
     Methods
     -------
-    run(inputs=None):
-        Runs the mlir_module with the given inputs, if the inputs are not
-        given it autogenerates the inputs. Also, the inputs should be a
-        numpy array.
+    __call__(function_name, inputs=None):
+        Runs the function with `function_name` within the mlir_module along
+        with the given inputs, if the inputs are not given it autogenerates the
+        inputs. Also, the inputs should be a numpy array.
     input_info():
         Gives the information about the inputs required by the `function_name`.
         This can be expensive as it does string matching to do so.
@@ -66,15 +64,13 @@ class SharkInference:
     def __init__(
         self,
         mlir_module: bytes,
-        function_name: str = "forward",
         device: str = "none",
         mlir_dialect: str = "linalg",
         is_benchmark: bool = False,
         dispatch_benchmark: str = None,
         dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
     ):
         self.mlir_module = mlir_module
-        self.function_name = function_name
         self.device = shark_args.device if device == "none" else device
         self.mlir_dialect = mlir_dialect
         self.is_benchmark = is_benchmark
@@ -113,7 +109,6 @@ def compile(self, extra_args=[]):
 
             self.shark_runner = SharkBenchmarkRunner(
                 self.mlir_module,
-                self.function_name,
                 self.device,
                 self.mlir_dialect,
                 extra_args=extra_args,
@@ -122,7 +117,6 @@ def compile(self, extra_args=[]):
         else:
             self.shark_runner = SharkRunner(
                 self.mlir_module,
-                self.function_name,
                 self.device,
                 self.mlir_dialect,
                 extra_args=extra_args,
@@ -138,21 +132,25 @@ def compile(self, extra_args=[]):
             os.system(f"rm -rf {self.temp_dispatch_benchmarks_dir}")
 
     # inputs are considered to be tuple of np.array.
-    def forward(self, inputs: tuple, send_to_host=True):
-        return self.shark_runner.run(inputs, send_to_host)
+    def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
+        return self.shark_runner.run(function_name, inputs, send_to_host)
+
+    # Get all function names defined within the compiled module.
+    def get_functions_in_module(self):
+        return self.shark_runner.get_functions_in_module()
 
     # Captures the static input information from the mlir_module.
     # TODO(pashu123): Generate the input information for dynamic shapes.
-    def _input_info(self):
+    def _input_info(self, function_name):
         # func_key to get the line which contains the function.
-        func_key = "func.func @" + self.function_name
+        func_key = "func.func @" + function_name
         func_header = None
         for line in str(self.mlir_module).splitlines():
             if func_key in line:
                 func_header = line
                 break
         if func_header is None:
-            print(f"Function: {self.function_name} not found")
+            print(f"Function: {function_name} not found")
 
         import re
 
@@ -190,15 +188,13 @@ def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
             self.device,
             dir,
             self.mlir_dialect,
-            self.function_name,
             module_name=module_name,
             extra_args=extra_args,
         )
 
     # load and return the module.
     def load_module(self, path, extra_args=[]):
         self.shark_runner = SharkRunner(
-            function_name=self.function_name,
             device=self.device,
             compile_vmfb=False,
             extra_args=extra_args,
@@ -209,6 +205,5 @@ def load_module(self, path, extra_args=[]):
         ) = load_flatbuffer(
             path,
             self.device,
-            self.function_name,
         )
         return
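
Taken together, one compiled artifact can now expose several entry points. A minimal end-to-end sketch of the new SharkInference surface, assuming mlir_module is a valid serialized module and input_array is a NumPy array (illustrative names):

from shark.shark_inference import SharkInference

shark_module = SharkInference(mlir_module, device="cpu", mlir_dialect="linalg")
shark_module.compile()

# List every function name defined within the compiled module...
print(shark_module.get_functions_in_module())

# ...and invoke any of them by name with a tuple of NumPy inputs.
output = shark_module("forward", (input_array,))

# save_module and load_module no longer carry a function name either.
shark_module.save_module()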
