|
16 | 16 | # and https://github.com/hojonathanho/diffusion |
17 | 17 |
|
18 | 18 | import math |
19 | | -import warnings |
20 | 19 | from dataclasses import dataclass |
21 | 20 | from typing import Optional, Tuple, Union |
22 | 21 |
|
23 | 22 | import numpy as np |
24 | 23 | import torch |
25 | 24 |
|
26 | 25 | from ..configuration_utils import ConfigMixin, register_to_config |
27 | | -from ..utils import BaseOutput |
| 26 | +from ..utils import BaseOutput, deprecate |
28 | 27 | from .scheduling_utils import SchedulerMixin |
29 | 28 |
|
30 | 29 |
|
@@ -122,12 +121,12 @@ def __init__( |
122 | 121 | steps_offset: int = 0, |
123 | 122 | **kwargs, |
124 | 123 | ): |
125 | | - if "tensor_format" in kwargs: |
126 | | - warnings.warn( |
127 | | - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." |
128 | | - "If you're running your code in PyTorch, you can safely remove this argument.", |
129 | | - DeprecationWarning, |
130 | | - ) |
| 124 | + deprecate( |
| 125 | + "tensor_format", |
| 126 | + "0.5.0", |
| 127 | + "If you're running your code in PyTorch, you can safely remove this argument.", |
| 128 | + take_from=kwargs, |
| 129 | + ) |
131 | 130 |
|
132 | 131 | if trained_betas is not None: |
133 | 132 | self.betas = torch.from_numpy(trained_betas) |
@@ -175,17 +174,10 @@ def set_timesteps(self, num_inference_steps: int, **kwargs): |
175 | 174 | num_inference_steps (`int`): |
176 | 175 | the number of diffusion steps used when generating samples with a pre-trained model. |
177 | 176 | """ |
178 | | - |
179 | | - offset = self.config.steps_offset |
180 | | - |
181 | | - if "offset" in kwargs: |
182 | | - warnings.warn( |
183 | | - "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." |
184 | | - " Please pass `steps_offset` to `__init__` instead.", |
185 | | - DeprecationWarning, |
186 | | - ) |
187 | | - |
188 | | - offset = kwargs["offset"] |
| 177 | + deprecated_offset = deprecate( |
| 178 | + "offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs |
| 179 | + ) |
| 180 | + offset = deprecated_offset or self.config.steps_offset |
189 | 181 |
|
190 | 182 | self.num_inference_steps = num_inference_steps |
191 | 183 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps |
|
0 commit comments