@@ -46,7 +46,7 @@ def __call__(self, **kwargs):
4646 return self .model .run (None , inputs )
4747
4848 @staticmethod
49- def load_model (path : Union [str , Path ], provider = None ):
49+ def load_model (path : Union [str , Path ], provider = None , sess_options = None ):
5050 """
5151 Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
5252
@@ -60,7 +60,7 @@ def load_model(path: Union[str, Path], provider=None):
6060 logger .info ("No onnxruntime provider specified, using CPUExecutionProvider" )
6161 provider = "CPUExecutionProvider"
6262
63- return ort .InferenceSession (path , providers = [provider ])
63+ return ort .InferenceSession (path , providers = [provider ], sess_options = sess_options )
6464
6565 def _save_pretrained (self , save_directory : Union [str , Path ], file_name : Optional [str ] = None , ** kwargs ):
6666 """
@@ -114,6 +114,7 @@ def _from_pretrained(
114114 cache_dir : Optional [str ] = None ,
115115 file_name : Optional [str ] = None ,
116116 provider : Optional [str ] = None ,
117+ sess_options : Optional [ort .SessionOptions ] = None ,
117118 ** kwargs ,
118119 ):
119120 """
@@ -143,7 +144,9 @@ def _from_pretrained(
143144 model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
144145 # load model from local directory
145146 if os .path .isdir (model_id ):
146- model = OnnxRuntimeModel .load_model (os .path .join (model_id , model_file_name ), provider = provider )
147+ model = OnnxRuntimeModel .load_model (
148+ os .path .join (model_id , model_file_name ), provider = provider , sess_options = sess_options
149+ )
147150 kwargs ["model_save_dir" ] = Path (model_id )
148151 # load model from hub
149152 else :
@@ -158,7 +161,7 @@ def _from_pretrained(
158161 )
159162 kwargs ["model_save_dir" ] = Path (model_cache_path ).parent
160163 kwargs ["latest_model_name" ] = Path (model_cache_path ).name
161- model = OnnxRuntimeModel .load_model (model_cache_path , provider = provider )
164+ model = OnnxRuntimeModel .load_model (model_cache_path , provider = provider , sess_options = sess_options )
162165 return cls (model = model , ** kwargs )
163166
164167 @classmethod
0 commit comments