Skip to content

Commit 1be7afd

Browse files
authored
rename TestWeights to appease pytest (#5054)
1 parent cca452f commit 1be7afd

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

test/test_prototype_models.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -203,27 +203,27 @@ def test_smoke():
203203
# With this filter, every unexpected warning will be turned into an error
204204
@pytest.mark.filterwarnings("error")
205205
class TestHandleLegacyInterface:
206-
class TestWeights(WeightsEnum):
206+
class ModelWeights(WeightsEnum):
207207
Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())
208208

209209
@pytest.mark.parametrize(
210210
"kwargs",
211211
[
212212
pytest.param(dict(), id="empty"),
213213
pytest.param(dict(weights=None), id="None"),
214-
pytest.param(dict(weights=TestWeights.Sentinel), id="Weights"),
214+
pytest.param(dict(weights=ModelWeights.Sentinel), id="Weights"),
215215
],
216216
)
217217
def test_no_warn(self, kwargs):
218-
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
218+
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
219219
def builder(*, weights=None):
220220
pass
221221

222222
builder(**kwargs)
223223

224224
@pytest.mark.parametrize("pretrained", (True, False))
225225
def test_pretrained_pos(self, pretrained):
226-
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
226+
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
227227
def builder(*, weights=None):
228228
pass
229229

@@ -232,7 +232,7 @@ def builder(*, weights=None):
232232

233233
@pytest.mark.parametrize("pretrained", (True, False))
234234
def test_pretrained_kw(self, pretrained):
235-
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
235+
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
236236
def builder(*, weights=None):
237237
pass
238238

@@ -242,12 +242,12 @@ def builder(*, weights=None):
242242
@pytest.mark.parametrize("pretrained", (True, False))
243243
@pytest.mark.parametrize("positional", (True, False))
244244
def test_equivalent_behavior_weights(self, pretrained, positional):
245-
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
245+
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
246246
def builder(*, weights=None):
247247
pass
248248

249249
args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
250-
with pytest.warns(UserWarning, match=f"weights={self.TestWeights.Sentinel if pretrained else None}"):
250+
with pytest.warns(UserWarning, match=f"weights={self.ModelWeights.Sentinel if pretrained else None}"):
251251
builder(*args, **kwargs)
252252

253253
def test_multi_params(self):
@@ -256,7 +256,7 @@ def test_multi_params(self):
256256

257257
@handle_legacy_interface(
258258
**{
259-
weights_param: (pretrained_param, self.TestWeights.Sentinel)
259+
weights_param: (pretrained_param, self.ModelWeights.Sentinel)
260260
for weights_param, pretrained_param in zip(weights_params, pretrained_params)
261261
}
262262
)
@@ -271,7 +271,7 @@ def test_default_callable(self):
271271
@handle_legacy_interface(
272272
weights=(
273273
"pretrained",
274-
lambda kwargs: self.TestWeights.Sentinel if kwargs["flag"] else None,
274+
lambda kwargs: self.ModelWeights.Sentinel if kwargs["flag"] else None,
275275
)
276276
)
277277
def builder(*, weights=None, flag):

0 commit comments

Comments
 (0)