@@ -40,8 +40,6 @@ class SharkInference:
4040 ----------
4141 mlir_module : str
4242 mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
43- function_name : str
44- function to execute in the given mlir_module.
4543 device : str
4644 device to execute the mlir_module on.
4745 currently supports cpu, cuda, vulkan, and metal backends.
@@ -53,10 +51,10 @@ class SharkInference:
5351
5452 Methods
5553 -------
56- run( inputs=None):
57- Runs the mlir_module with the given inputs, if the inputs are not
58- given it autogenerates the inputs. Also, the inputs should be a
59- numpy array.
54+ __call__(function_name, inputs=None):
55+ Runs the function with `function_name` within the mlir_module along
56+ with the given inputs, if the inputs are not given it autogenerates the
57+ inputs. Also, the inputs should be a numpy array.
6058 input_info():
6159 Gives the information about the inputs required by the `function_name`.
6260 This can be expensive as it does string matching to do so.
@@ -66,15 +64,13 @@ class SharkInference:
6664 def __init__ (
6765 self ,
6866 mlir_module : bytes ,
69- function_name : str = "forward" ,
7067 device : str = "none" ,
7168 mlir_dialect : str = "linalg" ,
7269 is_benchmark : bool = False ,
7370 dispatch_benchmark : str = None ,
7471 dispatch_benchmark_dir : str = "temp_dispatch_benchmarks" ,
7572 ):
7673 self .mlir_module = mlir_module
77- self .function_name = function_name
7874 self .device = shark_args .device if device == "none" else device
7975 self .mlir_dialect = mlir_dialect
8076 self .is_benchmark = is_benchmark
@@ -113,7 +109,6 @@ def compile(self, extra_args=[]):
113109
114110 self .shark_runner = SharkBenchmarkRunner (
115111 self .mlir_module ,
116- self .function_name ,
117112 self .device ,
118113 self .mlir_dialect ,
119114 extra_args = extra_args ,
@@ -122,7 +117,6 @@ def compile(self, extra_args=[]):
122117 else :
123118 self .shark_runner = SharkRunner (
124119 self .mlir_module ,
125- self .function_name ,
126120 self .device ,
127121 self .mlir_dialect ,
128122 extra_args = extra_args ,
@@ -138,21 +132,25 @@ def compile(self, extra_args=[]):
138132 os .system (f"rm -rf { self .temp_dispatch_benchmarks_dir } " )
139133
140134 # inputs are considered to be tuple of np.array.
141- def forward (self , inputs : tuple , send_to_host = True ):
142- return self .shark_runner .run (inputs , send_to_host )
135+ def __call__ (self , function_name : str , inputs : tuple , send_to_host = True ):
136+ return self .shark_runner .run (function_name , inputs , send_to_host )
137+
138+ # Get all function names defined within the compiled module.
139+ def get_functions_in_module (self ):
140+ return self .shark_runner .get_functions_in_module ()
143141
144142 # Captures the static input information from the mlir_module.
145143 # TODO(pashu123): Generate the input information for dynamic shapes.
146- def _input_info (self ):
144+ def _input_info (self , function_name ):
147145 # func_key to get the line which contains the function.
148- func_key = "func.func @" + self . function_name
146+ func_key = "func.func @" + function_name
149147 func_header = None
150148 for line in str (self .mlir_module ).splitlines ():
151149 if func_key in line :
152150 func_header = line
153151 break
154152 if func_header is None :
155- print (f"Function: { self . function_name } not found" )
153+ print (f"Function: { function_name } not found" )
156154
157155 import re
158156
@@ -190,15 +188,13 @@ def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
190188 self .device ,
191189 dir ,
192190 self .mlir_dialect ,
193- self .function_name ,
194191 module_name = module_name ,
195192 extra_args = extra_args ,
196193 )
197194
198195 # load and return the module.
199196 def load_module (self , path , extra_args = []):
200197 self .shark_runner = SharkRunner (
201- function_name = self .function_name ,
202198 device = self .device ,
203199 compile_vmfb = False ,
204200 extra_args = extra_args ,
@@ -209,6 +205,5 @@ def load_module(self, path, extra_args=[]):
209205 ) = load_flatbuffer (
210206 path ,
211207 self .device ,
212- self .function_name ,
213208 )
214209 return
0 commit comments