@@ -203,27 +203,27 @@ def test_smoke():
203203# With this filter, every unexpected warning will be turned into an error
204204@pytest .mark .filterwarnings ("error" )
205205class TestHandleLegacyInterface :
206- class TestWeights (WeightsEnum ):
206+ class ModelWeights (WeightsEnum ):
207207 Sentinel = Weights (url = "https://pytorch.org" , transforms = lambda x : x , meta = dict ())
208208
209209 @pytest .mark .parametrize (
210210 "kwargs" ,
211211 [
212212 pytest .param (dict (), id = "empty" ),
213213 pytest .param (dict (weights = None ), id = "None" ),
214- pytest .param (dict (weights = TestWeights .Sentinel ), id = "Weights" ),
214+ pytest .param (dict (weights = ModelWeights .Sentinel ), id = "Weights" ),
215215 ],
216216 )
217217 def test_no_warn (self , kwargs ):
218- @handle_legacy_interface (weights = ("pretrained" , self .TestWeights .Sentinel ))
218+ @handle_legacy_interface (weights = ("pretrained" , self .ModelWeights .Sentinel ))
219219 def builder (* , weights = None ):
220220 pass
221221
222222 builder (** kwargs )
223223
224224 @pytest .mark .parametrize ("pretrained" , (True , False ))
225225 def test_pretrained_pos (self , pretrained ):
226- @handle_legacy_interface (weights = ("pretrained" , self .TestWeights .Sentinel ))
226+ @handle_legacy_interface (weights = ("pretrained" , self .ModelWeights .Sentinel ))
227227 def builder (* , weights = None ):
228228 pass
229229
@@ -232,7 +232,7 @@ def builder(*, weights=None):
232232
233233 @pytest .mark .parametrize ("pretrained" , (True , False ))
234234 def test_pretrained_kw (self , pretrained ):
235- @handle_legacy_interface (weights = ("pretrained" , self .TestWeights .Sentinel ))
235+ @handle_legacy_interface (weights = ("pretrained" , self .ModelWeights .Sentinel ))
236236 def builder (* , weights = None ):
237237 pass
238238
@@ -242,12 +242,12 @@ def builder(*, weights=None):
242242 @pytest .mark .parametrize ("pretrained" , (True , False ))
243243 @pytest .mark .parametrize ("positional" , (True , False ))
244244 def test_equivalent_behavior_weights (self , pretrained , positional ):
245- @handle_legacy_interface (weights = ("pretrained" , self .TestWeights .Sentinel ))
245+ @handle_legacy_interface (weights = ("pretrained" , self .ModelWeights .Sentinel ))
246246 def builder (* , weights = None ):
247247 pass
248248
249249 args , kwargs = ((pretrained ,), dict ()) if positional else ((), dict (pretrained = pretrained ))
250- with pytest .warns (UserWarning , match = f"weights={ self .TestWeights .Sentinel if pretrained else None } " ):
250+ with pytest .warns (UserWarning , match = f"weights={ self .ModelWeights .Sentinel if pretrained else None } " ):
251251 builder (* args , ** kwargs )
252252
253253 def test_multi_params (self ):
@@ -256,7 +256,7 @@ def test_multi_params(self):
256256
257257 @handle_legacy_interface (
258258 ** {
259- weights_param : (pretrained_param , self .TestWeights .Sentinel )
259+ weights_param : (pretrained_param , self .ModelWeights .Sentinel )
260260 for weights_param , pretrained_param in zip (weights_params , pretrained_params )
261261 }
262262 )
@@ -271,7 +271,7 @@ def test_default_callable(self):
271271 @handle_legacy_interface (
272272 weights = (
273273 "pretrained" ,
274- lambda kwargs : self .TestWeights .Sentinel if kwargs ["flag" ] else None ,
274+ lambda kwargs : self .ModelWeights .Sentinel if kwargs ["flag" ] else None ,
275275 )
276276 )
277277 def builder (* , weights = None , flag ):
0 commit comments