Skip to content

Commit 8211b62

Browse files
authored
Allow passing session_options for ORT backend (#620)
1 parent ce31f83 commit 8211b62

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

src/diffusers/onnx_utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __call__(self, **kwargs):
4646
return self.model.run(None, inputs)
4747

4848
@staticmethod
49-
def load_model(path: Union[str, Path], provider=None):
49+
def load_model(path: Union[str, Path], provider=None, sess_options=None):
5050
"""
5151
Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
5252
@@ -60,7 +60,7 @@ def load_model(path: Union[str, Path], provider=None):
6060
logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
6161
provider = "CPUExecutionProvider"
6262

63-
return ort.InferenceSession(path, providers=[provider])
63+
return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)
6464

6565
def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
6666
"""
@@ -114,6 +114,7 @@ def _from_pretrained(
114114
cache_dir: Optional[str] = None,
115115
file_name: Optional[str] = None,
116116
provider: Optional[str] = None,
117+
sess_options: Optional[ort.SessionOptions] = None,
117118
**kwargs,
118119
):
119120
"""
@@ -143,7 +144,9 @@ def _from_pretrained(
143144
model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
144145
# load model from local directory
145146
if os.path.isdir(model_id):
146-
model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider)
147+
model = OnnxRuntimeModel.load_model(
148+
os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
149+
)
147150
kwargs["model_save_dir"] = Path(model_id)
148151
# load model from hub
149152
else:
@@ -158,7 +161,7 @@ def _from_pretrained(
158161
)
159162
kwargs["model_save_dir"] = Path(model_cache_path).parent
160163
kwargs["latest_model_name"] = Path(model_cache_path).name
161-
model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider)
164+
model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
162165
return cls(model=model, **kwargs)
163166

164167
@classmethod

src/diffusers/pipeline_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
282282
revision = kwargs.pop("revision", None)
283283
torch_dtype = kwargs.pop("torch_dtype", None)
284284
provider = kwargs.pop("provider", None)
285+
sess_options = kwargs.pop("sess_options", None)
285286

286287
# 1. Download the checkpoints and configs
287288
# use snapshot download here to get it working from from_pretrained
@@ -398,6 +399,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
398399
loading_kwargs["torch_dtype"] = torch_dtype
399400
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
400401
loading_kwargs["provider"] = provider
402+
loading_kwargs["sess_options"] = sess_options
401403

402404
# check if the module is in a subdirectory
403405
if os.path.isdir(os.path.join(cached_folder, name)):

0 commit comments

Comments
 (0)