diff --git a/torchvision/prototype/models/_utils.py b/torchvision/prototype/models/_utils.py index 6286d7b19b1..d3d0d659668 100644 --- a/torchvision/prototype/models/_utils.py +++ b/torchvision/prototype/models/_utils.py @@ -43,6 +43,7 @@ def inner_wrapper(*args: Any, **kwargs: Any) -> M: if ( (weights_param not in kwargs and pretrained_param not in kwargs) or isinstance(weights_arg, WeightsEnum) + or (isinstance(weights_arg, str) and weights_arg != "legacy") or weights_arg is None ): continue