From 0d62aaf8baea3a96901f9f04e0509315c16c875c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 28 Jul 2022 18:40:40 +0100 Subject: [PATCH 01/31] Model registration mechanism. --- torchvision/models/__init__.py | 2 +- torchvision/models/_api.py | 51 +++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 7bca0276c34..698977552ad 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -14,4 +14,4 @@ from .vision_transformer import * from .swin_transformer import * from . import detection, optical_flow, quantization, segmentation, video -from ._api import get_weight +from ._api import get_weight, list_models, load diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 901bb0015e4..e8f88a8a930 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -3,7 +3,9 @@ import sys from dataclasses import dataclass, fields from inspect import signature -from typing import Any, Callable, cast, Dict, Mapping +from torch import nn +from types import ModuleType +from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Type, TypeVar from torchvision._utils import StrEnum @@ -140,3 +142,50 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum: ) return cast(WeightsEnum, weights_enum) + + +M = TypeVar("M", bound=Type[nn.Module]) + +BUILTIN_MODELS = {} + + +def register_model(name: str, overwrite: bool = False) -> Callable[[Callable[..., M]], Callable[..., M]]: + def wrapper(fn: Callable[..., M]) -> Callable[..., M]: + if name in BUILTIN_MODELS and not overwrite: + raise ValueError(f"A model is already registered under tha name '{name}'.") + BUILTIN_MODELS[name] = fn + return fn + return wrapper + + +def list_models(module: Optional[ModuleType] = None) -> List[str]: + """ + Returns a list with the names of registred models. + + Args: + module (ModuleType, optional): The module from which we want to extract the available models. + + Returns: + models (list): A list with the names of available models. + """ + models = [k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__ == module] + return sorted(models) + + +def load(name: str, **config: Any) -> M: + """ + Gets the model name and configuration and returns an instantiated model. + + Args: + name (str): The name under which the model is registered. + **config (Any): parameters passed to the model builder method. + + Returns: + model (nn.Module): The initialized model. + """ + name = name.lower() + try: + fn = BUILTIN_MODELS[name] + except KeyError: + raise ValueError(f"Unknown model {name}") + return fn(**config) From 0e7eb8a0dc1e013d3ffeeccc4520d5c8090fbbfa Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 28 Jul 2022 18:42:58 +0100 Subject: [PATCH 02/31] Add overwrite options to the dataset prototype registration mechanism. --- torchvision/models/_api.py | 2 +- torchvision/prototype/datasets/_api.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index e8f88a8a930..a5422d08443 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -152,7 +152,7 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum: def register_model(name: str, overwrite: bool = False) -> Callable[[Callable[..., M]], Callable[..., M]]: def wrapper(fn: Callable[..., M]) -> Callable[..., M]: if name in BUILTIN_MODELS and not overwrite: - raise ValueError(f"A model is already registered under tha name '{name}'.") + raise ValueError(f"An entry is already registered under the name '{name}'.") BUILTIN_MODELS[name] = fn return fn return wrapper diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index f6f06c60a21..c3f40c448ca 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -12,8 +12,10 @@ BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} -def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: +def register_info(name: str, overwrite: bool = False) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: + if name in BUILTIN_INFOS and not overwrite: + raise ValueError(f"An entry is already registered under the name '{name}'.") BUILTIN_INFOS[name] = fn() return fn @@ -23,8 +25,10 @@ def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: BUILTIN_DATASETS = {} -def register_dataset(name: str) -> Callable[[D], D]: +def register_dataset(name: str, overwrite: bool = False) -> Callable[[D], D]: def wrapper(dataset_cls: D) -> D: + if name in BUILTIN_DATASETS and not overwrite: + raise ValueError(f"An entry is already registered under the name '{name}'.") BUILTIN_DATASETS[name] = dataset_cls return dataset_cls From 1520566e7a73001e3610a97291437d53d1520c05 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 28 Jul 2022 18:43:26 +0100 Subject: [PATCH 03/31] Adding example models. --- torchvision/models/alexnet.py | 3 ++- torchvision/models/mobilenetv3.py | 4 +++- torchvision/models/quantization/mobilenetv3.py | 3 ++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 5d1401dcb36..0b1bd8b1310 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -6,7 +6,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import Weights, WeightsEnum, register_model from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -75,6 +75,7 @@ class AlexNet_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model("alexnet") @handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1)) def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks `__. diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 10d2a1c91ac..9ac8c99050e 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -8,7 +8,7 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import Weights, WeightsEnum, register_model from ._meta import _IMAGENET_CATEGORIES from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface @@ -371,6 +371,7 @@ class MobileNet_V3_Small_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model("mobilenet_v3_large") @handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1)) def mobilenet_v3_large( *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any @@ -401,6 +402,7 @@ def mobilenet_v3_large( return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) +@register_model("mobilenet_v3_small") @handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1)) def mobilenet_v3_small( *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 56341bb280e..835fd239453 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -7,7 +7,7 @@ from ...ops.misc import Conv2dNormActivation, SqueezeExcitation from ...transforms._presets import ImageClassification -from .._api import Weights, WeightsEnum +from .._api import Weights, WeightsEnum, register_model from .._meta import _IMAGENET_CATEGORIES from .._utils import _ovewrite_named_param, handle_legacy_interface from ..mobilenetv3 import ( @@ -184,6 +184,7 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): DEFAULT = IMAGENET1K_QNNPACK_V1 +@register_model("quantized_mobilenet_v3_large") @handle_legacy_interface( weights=( "pretrained", From 2e16077422a89e91067a9c89f5caa053cdb82e88 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 08:31:36 +0100 Subject: [PATCH 04/31] Fix module filtering --- torchvision/models/_api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index a5422d08443..74940a9c0e7 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -168,7 +168,9 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: Returns: models (list): A list with the names of available models. """ - models = [k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__ == module] + models = [ + k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__ + ] return sorted(models) From a02c124fb3b845fa387518081be16d365e87e37c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 08:41:25 +0100 Subject: [PATCH 05/31] Fix linter --- torchvision/models/_api.py | 4 +++- torchvision/models/alexnet.py | 2 +- torchvision/models/mobilenetv3.py | 2 +- torchvision/models/quantization/mobilenetv3.py | 2 +- torchvision/prototype/datasets/_api.py | 4 +++- 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 74940a9c0e7..713dc9ece1f 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -3,10 +3,11 @@ import sys from dataclasses import dataclass, fields from inspect import signature -from torch import nn from types import ModuleType from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Type, TypeVar +from torch import nn + from torchvision._utils import StrEnum from .._internally_replaced_utils import load_state_dict_from_url @@ -155,6 +156,7 @@ def wrapper(fn: Callable[..., M]) -> Callable[..., M]: raise ValueError(f"An entry is already registered under the name '{name}'.") BUILTIN_MODELS[name] = fn return fn + return wrapper diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 0b1bd8b1310..de12bb415e6 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -6,7 +6,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum, register_model +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 9ac8c99050e..e2b571d820a 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -8,7 +8,7 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum, register_model +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 835fd239453..bc916ad6d02 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -7,7 +7,7 @@ from ...ops.misc import Conv2dNormActivation, SqueezeExcitation from ...transforms._presets import ImageClassification -from .._api import Weights, WeightsEnum, register_model +from .._api import register_model, Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES from .._utils import _ovewrite_named_param, handle_legacy_interface from ..mobilenetv3 import ( diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index c3f40c448ca..8e237aa3897 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -12,7 +12,9 @@ BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} -def register_info(name: str, overwrite: bool = False) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: +def register_info( + name: str, overwrite: bool = False +) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: if name in BUILTIN_INFOS and not overwrite: raise ValueError(f"An entry is already registered under the name '{name}'.") From eedf8dfcb4c69adcfbcaa736f9e6e910567784ed Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 08:42:38 +0100 Subject: [PATCH 06/31] Fix docs --- torchvision/models/_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 713dc9ece1f..111a0afde06 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -162,7 +162,7 @@ def wrapper(fn: Callable[..., M]) -> Callable[..., M]: def list_models(module: Optional[ModuleType] = None) -> List[str]: """ - Returns a list with the names of registred models. + Returns a list with the names of registered models. Args: module (ModuleType, optional): The module from which we want to extract the available models. From a91a5b4aabca86a5911ffc94ea03eb554514859e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 09:29:34 +0100 Subject: [PATCH 07/31] Make name optional if same as model builder --- torchvision/models/_api.py | 9 +++++---- torchvision/models/alexnet.py | 2 +- torchvision/models/mobilenetv3.py | 4 ++-- torchvision/models/quantization/mobilenetv3.py | 2 +- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 111a0afde06..5fa6b95d99c 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -150,11 +150,12 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum: BUILTIN_MODELS = {} -def register_model(name: str, overwrite: bool = False) -> Callable[[Callable[..., M]], Callable[..., M]]: +def register_model(name: Optional[str] = None, overwrite: bool = False) -> Callable[[Callable[..., M]], Callable[..., M]]: def wrapper(fn: Callable[..., M]) -> Callable[..., M]: - if name in BUILTIN_MODELS and not overwrite: - raise ValueError(f"An entry is already registered under the name '{name}'.") - BUILTIN_MODELS[name] = fn + key = name if name is not None else fn.__name__ + if key in BUILTIN_MODELS and not overwrite: + raise ValueError(f"An entry is already registered under the name '{key}'.") + BUILTIN_MODELS[key] = fn return fn return wrapper diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index de12bb415e6..328f978ba11 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -75,7 +75,7 @@ class AlexNet_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 -@register_model("alexnet") +@register_model() @handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1)) def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks `__. diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index e2b571d820a..81fc3c5d4c0 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -371,7 +371,7 @@ class MobileNet_V3_Small_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 -@register_model("mobilenet_v3_large") +@register_model() @handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1)) def mobilenet_v3_large( *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any @@ -402,7 +402,7 @@ def mobilenet_v3_large( return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) -@register_model("mobilenet_v3_small") +@register_model() @handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1)) def mobilenet_v3_small( *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index bc916ad6d02..986f67c6080 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -184,7 +184,7 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): DEFAULT = IMAGENET1K_QNNPACK_V1 -@register_model("quantized_mobilenet_v3_large") +@register_model(name="quantized_mobilenet_v3_large") @handle_legacy_interface( weights=( "pretrained", From abbe23e83351f05b1f6cf3d982b266e4cbbaf251 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 10:03:22 +0100 Subject: [PATCH 08/31] Apply updates from code-review. --- torchvision/models/_api.py | 6 +++--- torchvision/prototype/datasets/_api.py | 10 ++-------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 5fa6b95d99c..c7d99c87833 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -150,10 +150,10 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum: BUILTIN_MODELS = {} -def register_model(name: Optional[str] = None, overwrite: bool = False) -> Callable[[Callable[..., M]], Callable[..., M]]: +def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]: def wrapper(fn: Callable[..., M]) -> Callable[..., M]: key = name if name is not None else fn.__name__ - if key in BUILTIN_MODELS and not overwrite: + if key in BUILTIN_MODELS: raise ValueError(f"An entry is already registered under the name '{key}'.") BUILTIN_MODELS[key] = fn return fn @@ -177,7 +177,7 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: return sorted(models) -def load(name: str, **config: Any) -> M: +def load_model(name: str, **config: Any) -> M: """ Gets the model name and configuration and returns an instantiated model. diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 8e237aa3897..f6f06c60a21 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -12,12 +12,8 @@ BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} -def register_info( - name: str, overwrite: bool = False -) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: +def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: - if name in BUILTIN_INFOS and not overwrite: - raise ValueError(f"An entry is already registered under the name '{name}'.") BUILTIN_INFOS[name] = fn() return fn @@ -27,10 +23,8 @@ def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: BUILTIN_DATASETS = {} -def register_dataset(name: str, overwrite: bool = False) -> Callable[[D], D]: +def register_dataset(name: str) -> Callable[[D], D]: def wrapper(dataset_cls: D) -> D: - if name in BUILTIN_DATASETS and not overwrite: - raise ValueError(f"An entry is already registered under the name '{name}'.") BUILTIN_DATASETS[name] = dataset_cls return dataset_cls From 1eb8159b13a737c1800f8b0143fb1c9b83b14d8c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 10:05:59 +0100 Subject: [PATCH 09/31] fix minor bug --- torchvision/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 698977552ad..5c28db4660c 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -14,4 +14,4 @@ from .vision_transformer import * from .swin_transformer import * from . import detection, optical_flow, quantization, segmentation, video -from ._api import get_weight, list_models, load +from ._api import get_weight, list_models, load_model From 924388efe8b3916c828c1e98bfa23cd1d8331d5e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 10:26:51 +0100 Subject: [PATCH 10/31] Adding getter for model weight enum --- torchvision/models/__init__.py | 2 +- torchvision/models/_api.py | 41 +++++++++++++++++++++++++--------- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 5c28db4660c..7dee4c389f8 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -14,4 +14,4 @@ from .vision_transformer import * from .swin_transformer import * from . import detection, optical_flow, quantization, segmentation, video -from ._api import get_weight, list_models, load_model +from ._api import get_weight, list_models, get_model, get_model_weight diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index c7d99c87833..f85fe0599b3 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -13,7 +13,7 @@ from .._internally_replaced_utils import load_state_dict_from_url -__all__ = ["WeightsEnum", "Weights", "get_weight"] +__all__ = ["WeightsEnum", "Weights", "get_weight", "list_models", "get_model", "get_model_weight"] @dataclass @@ -110,10 +110,26 @@ def get_weight(name: str) -> WeightsEnum: return weights_enum.from_str(value_name) -def _get_enum_from_fn(fn: Callable) -> WeightsEnum: +W = TypeVar("W", bound=Type[WeightsEnum]) + + +def get_model_weight(name: str) -> W: + """ + Retuns the Weights Enum from the model name. + + Args: + name (str): The name under which the model is registered. + + Returns: + weights_enum (W): The weights enum class associated with the model. + """ + fn = _find_model(name) + return _get_enum_from_fn(fn) + + +def _get_enum_from_fn(fn: Callable) -> W: """ Internal method that gets the weight enum of a specific model builder method. - Might be removed after the handle_legacy_interface is removed. Args: fn (Callable): The builder method used to create the model. @@ -142,7 +158,7 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum: "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct." ) - return cast(WeightsEnum, weights_enum) + return cast(W, weights_enum) M = TypeVar("M", bound=Type[nn.Module]) @@ -177,7 +193,16 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: return sorted(models) -def load_model(name: str, **config: Any) -> M: +def _find_model(name: str) -> Callable[..., M]: + name = name.lower() + try: + fn = BUILTIN_MODELS[name] + except KeyError: + raise ValueError(f"Unknown model {name}") + return fn + + +def get_model(name: str, **config: Any) -> M: """ Gets the model name and configuration and returns an instantiated model. @@ -188,9 +213,5 @@ def load_model(name: str, **config: Any) -> M: Returns: model (nn.Module): The initialized model. """ - name = name.lower() - try: - fn = BUILTIN_MODELS[name] - except KeyError: - raise ValueError(f"Unknown model {name}") + fn = _find_model(name) return fn(**config) From bd2327a5666f723753f70d411e319bdb503a2f8e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 10:51:58 +0100 Subject: [PATCH 11/31] Support both strings and callables on get_model_weight. --- torchvision/models/_api.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index f85fe0599b3..b09b5a5e7e4 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, fields from inspect import signature from types import ModuleType -from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Type, TypeVar +from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Type, TypeVar, Union from torch import nn @@ -113,18 +113,19 @@ def get_weight(name: str) -> WeightsEnum: W = TypeVar("W", bound=Type[WeightsEnum]) -def get_model_weight(name: str) -> W: +def get_model_weight(model: Union[Callable, str]) -> W: """ - Retuns the Weights Enum from the model name. + Retuns the Weights Enum of a model. Args: - name (str): The name under which the model is registered. + name (callable or str): The model builder function or the name under which it is registered. Returns: weights_enum (W): The weights enum class associated with the model. """ - fn = _find_model(name) - return _get_enum_from_fn(fn) + if isinstance(model, str): + model = _find_model(model) + return _get_enum_from_fn(model) def _get_enum_from_fn(fn: Callable) -> W: From a815a633394f02fba7c04f2f374d17fd98c8eff3 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 10:57:33 +0100 Subject: [PATCH 12/31] linter fixes --- torchvision/models/__init__.py | 2 +- torchvision/models/_api.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 7dee4c389f8..95f30b44872 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -14,4 +14,4 @@ from .vision_transformer import * from .swin_transformer import * from . import detection, optical_flow, quantization, segmentation, video -from ._api import get_weight, list_models, get_model, get_model_weight +from ._api import get_model, get_model_weight, get_weight, list_models diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index b09b5a5e7e4..66fdb652475 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -13,7 +13,7 @@ from .._internally_replaced_utils import load_state_dict_from_url -__all__ = ["WeightsEnum", "Weights", "get_weight", "list_models", "get_model", "get_model_weight"] +__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weight", "get_weight", "list_models"] @dataclass @@ -124,7 +124,7 @@ def get_model_weight(model: Union[Callable, str]) -> W: weights_enum (W): The weights enum class associated with the model. """ if isinstance(model, str): - model = _find_model(model) + model = find_model(model) return _get_enum_from_fn(model) @@ -194,7 +194,7 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: return sorted(models) -def _find_model(name: str) -> Callable[..., M]: +def find_model(name: str) -> Callable[..., M]: name = name.lower() try: fn = BUILTIN_MODELS[name] @@ -214,5 +214,5 @@ def get_model(name: str, **config: Any) -> M: Returns: model (nn.Module): The initialized model. """ - fn = _find_model(name) + fn = find_model(name) return fn(**config) From 9e4e62cf37cec8eaaf59b59d42727bffd9abe4fd Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 11:17:54 +0100 Subject: [PATCH 13/31] Fixing mypy. --- torchvision/models/_api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 66fdb652475..89436bced9a 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -125,10 +125,10 @@ def get_model_weight(model: Union[Callable, str]) -> W: """ if isinstance(model, str): model = find_model(model) - return _get_enum_from_fn(model) + return cast(W, _get_enum_from_fn(model)) -def _get_enum_from_fn(fn: Callable) -> W: +def _get_enum_from_fn(fn: Callable) -> WeightsEnum: """ Internal method that gets the weight enum of a specific model builder method. @@ -159,7 +159,7 @@ def _get_enum_from_fn(fn: Callable) -> W: "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct." ) - return cast(W, weights_enum) + return cast(WeightsEnum, weights_enum) M = TypeVar("M", bound=Type[nn.Module]) From 020932708a31eb666b487d6a9eb1ff2ac9229170 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 14:54:02 +0100 Subject: [PATCH 14/31] Renaming `get_model_weight` to `get_model_weights` --- torchvision/models/__init__.py | 2 +- torchvision/models/_api.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 95f30b44872..eb949fb3d5c 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -14,4 +14,4 @@ from .vision_transformer import * from .swin_transformer import * from . import detection, optical_flow, quantization, segmentation, video -from ._api import get_model, get_model_weight, get_weight, list_models +from ._api import get_model, get_model_weights, get_weight, list_models diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 89436bced9a..3f4e9e914af 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -13,7 +13,7 @@ from .._internally_replaced_utils import load_state_dict_from_url -__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weight", "get_weight", "list_models"] +__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weights", "get_weight", "list_models"] @dataclass @@ -78,7 +78,7 @@ def __getattr__(self, name): def get_weight(name: str) -> WeightsEnum: """ - Gets the weight enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1" + Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1" Args: name (str): The name of the weight enum entry. @@ -113,9 +113,9 @@ def get_weight(name: str) -> WeightsEnum: W = TypeVar("W", bound=Type[WeightsEnum]) -def get_model_weight(model: Union[Callable, str]) -> W: +def get_model_weights(model: Union[Callable, str]) -> W: """ - Retuns the Weights Enum of a model. + Retuns the weights enum class associated to the given model. Args: name (callable or str): The model builder function or the name under which it is registered. From 2a63dce7a65525d0cb717ced08e98c39f523bc46 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 15:26:18 +0100 Subject: [PATCH 15/31] Registering all classification models. --- test/test_models.py | 3 ++- torchvision/models/convnext.py | 6 +++++- torchvision/models/densenet.py | 6 +++++- torchvision/models/efficientnet.py | 13 ++++++++++++- torchvision/models/googlenet.py | 3 ++- torchvision/models/inception.py | 3 ++- torchvision/models/mnasnet.py | 6 +++++- torchvision/models/mobilenetv2.py | 3 ++- torchvision/models/regnet.py | 17 ++++++++++++++++- torchvision/models/resnet.py | 12 +++++++++++- torchvision/models/shufflenetv2.py | 6 +++++- torchvision/models/squeezenet.py | 4 +++- torchvision/models/swin_transformer.py | 5 ++++- torchvision/models/vgg.py | 10 +++++++++- torchvision/models/vision_transformer.py | 7 ++++++- 15 files changed, 89 insertions(+), 15 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 05bab11e479..3f70758491b 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -23,10 +23,11 @@ def get_models_from_module(module): # TODO add a registration mechanism to torchvision.models + non_model_fn = {"get_model", "get_model_weights", "get_weight", "list_models"} return [ v for k, v in module.__dict__.items() - if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight" + if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k not in non_model_fn ] diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 5b79e5934f4..025baa3d148 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -9,7 +9,7 @@ from ..ops.stochastic_depth import StochasticDepth from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -278,6 +278,7 @@ class ConvNeXt_Large_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model() @handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1)) def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: """ConvNeXt Tiny model architecture from the @@ -308,6 +309,7 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1)) def convnext_small( *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any @@ -340,6 +342,7 @@ def convnext_small( return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1)) def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: """ConvNeXt Base model architecture from the @@ -370,6 +373,7 @@ def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1)) def convnext_large( *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 8eaac615c86..9aa5ed176a0 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -11,7 +11,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -337,6 +337,7 @@ class DenseNet201_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model() @handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1)) def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-121 model from @@ -362,6 +363,7 @@ def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1)) def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-161 model from @@ -387,6 +389,7 @@ def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1)) def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-169 model from @@ -412,6 +415,7 @@ def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1)) def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-201 model from diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 417ebabcbe5..c98eb37f935 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -12,7 +12,7 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface @@ -729,6 +729,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model() @handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1)) def efficientnet_b0( *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any @@ -757,6 +758,7 @@ def efficientnet_b0( return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1)) def efficientnet_b1( *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any @@ -785,6 +787,7 @@ def efficientnet_b1( return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1)) def efficientnet_b2( *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any @@ -813,6 +816,7 @@ def efficientnet_b2( return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1)) def efficientnet_b3( *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any @@ -841,6 +845,7 @@ def efficientnet_b3( return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1)) def efficientnet_b4( *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any @@ -869,6 +874,7 @@ def efficientnet_b4( return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1)) def efficientnet_b5( *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any @@ -905,6 +911,7 @@ def efficientnet_b5( ) +@register_model() @handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1)) def efficientnet_b6( *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any @@ -941,6 +948,7 @@ def efficientnet_b6( ) +@register_model() @handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1)) def efficientnet_b7( *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any @@ -977,6 +985,7 @@ def efficientnet_b7( ) +@register_model() @handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1)) def efficientnet_v2_s( *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any @@ -1014,6 +1023,7 @@ def efficientnet_v2_s( ) +@register_model() @handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1)) def efficientnet_v2_m( *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any @@ -1051,6 +1061,7 @@ def efficientnet_v2_m( ) +@register_model() @handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1)) def efficientnet_v2_l( *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 895fcd1e4e6..0ea3dd5d0b9 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -10,7 +10,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -296,6 +296,7 @@ class GoogLeNet_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model() @handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1)) def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: """GoogLeNet (Inception v1) model architecture from diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index c1a87954f7c..928c07ac843 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -9,7 +9,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -428,6 +428,7 @@ class Inception_V3_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model() @handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1)) def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: """ diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 27117ae3a83..48103f11585 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -8,7 +8,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -314,6 +314,7 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa return model +@register_model() @handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1)) def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: """MNASNet with depth multiplier of 0.5 from @@ -341,6 +342,7 @@ def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = return _mnasnet(0.5, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1)) def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: """MNASNet with depth multiplier of 0.75 from @@ -368,6 +370,7 @@ def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool return _mnasnet(0.75, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1)) def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: """MNASNet with depth multiplier of 1.0 from @@ -395,6 +398,7 @@ def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = return _mnasnet(1.0, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1)) def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: """MNASNet with depth multiplier of 1.3 from diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 06fbff2802a..6d8796b7a16 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -8,7 +8,7 @@ from ..ops.misc import Conv2dNormActivation from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface @@ -238,6 +238,7 @@ class MobileNet_V2_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V2 +@register_model() @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1)) def mobilenet_v2( *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index d4b4147404c..67665d2ffd4 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -9,7 +9,7 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface @@ -1101,6 +1101,7 @@ class RegNet_X_32GF_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V2 +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1)) def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1126,6 +1127,7 @@ def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1)) def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1151,6 +1153,7 @@ def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1)) def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1178,6 +1181,7 @@ def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1)) def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1205,6 +1209,7 @@ def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1)) def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1232,6 +1237,7 @@ def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bo return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1)) def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1259,6 +1265,7 @@ def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1)) def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1286,6 +1293,7 @@ def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", None)) def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1313,6 +1321,7 @@ def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1)) def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1338,6 +1347,7 @@ def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1)) def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1363,6 +1373,7 @@ def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1)) def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1392,6 +1403,7 @@ def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1)) def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1421,6 +1433,7 @@ def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1)) def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1450,6 +1463,7 @@ def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bo return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1)) def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ @@ -1479,6 +1493,7 @@ def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: return _regnet(params, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1)) def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 667bece5730..e743474b331 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -7,7 +7,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -645,6 +645,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V2 +@register_model() @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: """ResNet-18 from `Deep Residual Learning for Image Recognition `__. @@ -670,6 +671,7 @@ def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = Tru return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1)) def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: """ResNet-34 from `Deep Residual Learning for Image Recognition `__. @@ -695,6 +697,7 @@ def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = Tru return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1)) def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: """ResNet-50 from `Deep Residual Learning for Image Recognition `__. @@ -726,6 +729,7 @@ def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = Tru return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1)) def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: """ResNet-101 from `Deep Residual Learning for Image Recognition `__. @@ -757,6 +761,7 @@ def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = T return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1)) def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: """ResNet-152 from `Deep Residual Learning for Image Recognition `__. @@ -788,6 +793,7 @@ def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = T return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1)) def resnext50_32x4d( *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any @@ -817,6 +823,7 @@ def resnext50_32x4d( return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1)) def resnext101_32x8d( *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any @@ -846,6 +853,7 @@ def resnext101_32x8d( return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) +@register_model() def resnext101_64x4d( *, weights: Optional[ResNeXt101_64X4D_Weights] = None, progress: bool = True, **kwargs: Any ) -> ResNet: @@ -874,6 +882,7 @@ def resnext101_64x4d( return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1)) def wide_resnet50_2( *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any @@ -907,6 +916,7 @@ def wide_resnet50_2( return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1)) def wide_resnet101_2( *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index cc4291c9a86..159e1be3bc8 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -7,7 +7,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -276,6 +276,7 @@ class ShuffleNet_V2_X2_0_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model() @handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1)) def shufflenet_v2_x0_5( *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any @@ -306,6 +307,7 @@ def shufflenet_v2_x0_5( return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1)) def shufflenet_v2_x1_0( *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any @@ -336,6 +338,7 @@ def shufflenet_v2_x1_0( return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1)) def shufflenet_v2_x1_5( *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any @@ -366,6 +369,7 @@ def shufflenet_v2_x1_5( return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1)) def shufflenet_v2_x2_0( *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index 8d43d3a0330..9fe6521e1a1 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -7,7 +7,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -159,6 +159,7 @@ class SqueezeNet1_1_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model() @handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1)) def squeezenet1_0( *, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any @@ -187,6 +188,7 @@ def squeezenet1_0( return _squeezenet("1_0", weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1)) def squeezenet1_1( *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index db5604fb377..c5bc43a14fd 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -9,7 +9,7 @@ from ..ops.stochastic_depth import StochasticDepth from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param @@ -515,6 +515,7 @@ class Swin_B_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model() def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_tiny architecture from @@ -551,6 +552,7 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * ) +@register_model() def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_small architecture from @@ -587,6 +589,7 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, * ) +@register_model() def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: """ Constructs a swin_base architecture from diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 7c141381ee8..dea783c2fb1 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -6,7 +6,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -285,6 +285,7 @@ class VGG19_BN_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model() @handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1)) def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: """VGG-11 from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. @@ -310,6 +311,7 @@ def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **k return _vgg("A", False, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1)) def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: """VGG-11-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. @@ -335,6 +337,7 @@ def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = Tru return _vgg("A", True, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1)) def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: """VGG-13 from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. @@ -360,6 +363,7 @@ def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **k return _vgg("B", False, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1)) def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: """VGG-13-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. @@ -385,6 +389,7 @@ def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = Tru return _vgg("B", True, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1)) def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: """VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. @@ -410,6 +415,7 @@ def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **k return _vgg("D", False, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1)) def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: """VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. @@ -435,6 +441,7 @@ def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = Tru return _vgg("D", True, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1)) def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: """VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. @@ -460,6 +467,7 @@ def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **k return _vgg("E", False, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1)) def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: """VGG-19_BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index e9a8c94cc67..a0a42ab07b7 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -9,7 +9,7 @@ from ..ops.misc import Conv2dNormActivation, MLP from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -596,6 +596,7 @@ class ViT_H_14_Weights(WeightsEnum): DEFAULT = IMAGENET1K_SWAG_E2E_V1 +@register_model() @handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1)) def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ @@ -629,6 +630,7 @@ def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = Tru ) +@register_model() @handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1)) def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ @@ -662,6 +664,7 @@ def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = Tru ) +@register_model() @handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1)) def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ @@ -695,6 +698,7 @@ def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = Tru ) +@register_model() @handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1)) def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ @@ -728,6 +732,7 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru ) +@register_model() def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_h_14 architecture from From 976a93eded51ccb7b4b136584ec6a7d8c5bf9237 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 15:34:13 +0100 Subject: [PATCH 16/31] Registering all video models. --- torchvision/models/video/mvit.py | 3 ++- torchvision/models/video/resnet.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 702116f047c..cfa82a4b851 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -10,7 +10,7 @@ from ...ops import MLP, StochasticDepth from ...transforms._presets import VideoClassification from ...utils import _log_api_usage_once -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _KINETICS400_CATEGORIES from .._utils import _ovewrite_named_param @@ -461,6 +461,7 @@ class MViT_V1_B_Weights(WeightsEnum): DEFAULT = KINETICS400_V1 +@register_model() def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT: """ Constructs a base MViTV1 architecture from diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index ab369c55553..352ae92d194 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -6,7 +6,7 @@ from ...transforms._presets import VideoClassification from ...utils import _log_api_usage_once -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _KINETICS400_CATEGORIES from .._utils import _ovewrite_named_param, handle_legacy_interface @@ -373,6 +373,7 @@ class R2Plus1D_18_Weights(WeightsEnum): DEFAULT = KINETICS400_V1 +@register_model() @handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1)) def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: """Construct 18 layer Resnet3D model. @@ -409,6 +410,7 @@ def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, * ) +@register_model() @handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1)) def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: """Construct 18 layer Mixed Convolution network as in @@ -445,6 +447,7 @@ def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, * ) +@register_model() @handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1)) def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: """Construct 18 layer deep R(2+1)D network as in From 2b8dc89f04c472e1bef05a070cb4b5e44e66d879 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 15:42:39 +0100 Subject: [PATCH 17/31] Registering all detection models. --- torchvision/models/detection/faster_rcnn.py | 6 +++++- torchvision/models/detection/fcos.py | 3 ++- torchvision/models/detection/keypoint_rcnn.py | 3 ++- torchvision/models/detection/mask_rcnn.py | 4 +++- torchvision/models/detection/retinanet.py | 4 +++- torchvision/models/detection/ssd.py | 3 ++- torchvision/models/detection/ssdlite.py | 3 ++- 7 files changed, 19 insertions(+), 7 deletions(-) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index de46aadfe4f..3160e8e89b3 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -7,7 +7,7 @@ from ...ops import misc as misc_nn_ops from ...transforms._presets import ObjectDetection -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _COCO_CATEGORIES from .._utils import _ovewrite_value_param, handle_legacy_interface from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights @@ -451,6 +451,7 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): DEFAULT = COCO_V1 +@register_model() @handle_legacy_interface( weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1), weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), @@ -569,6 +570,7 @@ def fasterrcnn_resnet50_fpn( return model +@register_model() def fasterrcnn_resnet50_fpn_v2( *, weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None, @@ -685,6 +687,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn( return model +@register_model() @handle_legacy_interface( weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1), weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), @@ -758,6 +761,7 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( ) +@register_model() @handle_legacy_interface( weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1), weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 872fae8a359..f95ee5b763f 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -11,7 +11,7 @@ from ...ops.feature_pyramid_network import LastLevelP6P7 from ...transforms._presets import ObjectDetection from ...utils import _log_api_usage_once -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _COCO_CATEGORIES from .._utils import _ovewrite_value_param, handle_legacy_interface from ..resnet import resnet50, ResNet50_Weights @@ -666,6 +666,7 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum): DEFAULT = COCO_V1 +@register_model() @handle_legacy_interface( weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1), weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index f4044a2c1a2..21fb53c2a49 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -6,7 +6,7 @@ from ...ops import misc as misc_nn_ops from ...transforms._presets import ObjectDetection -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._utils import _ovewrite_value_param, handle_legacy_interface from ..resnet import resnet50, ResNet50_Weights @@ -353,6 +353,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): DEFAULT = COCO_V1 +@register_model() @handle_legacy_interface( weights=( "pretrained", diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 422bacd135b..e2d105b5e41 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -6,7 +6,7 @@ from ...ops import misc as misc_nn_ops from ...transforms._presets import ObjectDetection -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _COCO_CATEGORIES from .._utils import _ovewrite_value_param, handle_legacy_interface from ..resnet import resnet50, ResNet50_Weights @@ -396,6 +396,7 @@ class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum): DEFAULT = COCO_V1 +@register_model() @handle_legacy_interface( weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1), weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), @@ -503,6 +504,7 @@ def maskrcnn_resnet50_fpn( return model +@register_model() def maskrcnn_resnet50_fpn_v2( *, weights: Optional[MaskRCNN_ResNet50_FPN_V2_Weights] = None, diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 57c75354389..792c2c36ce4 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -11,7 +11,7 @@ from ...ops.feature_pyramid_network import LastLevelP6P7 from ...transforms._presets import ObjectDetection from ...utils import _log_api_usage_once -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _COCO_CATEGORIES from .._utils import _ovewrite_value_param, handle_legacy_interface from ..resnet import resnet50, ResNet50_Weights @@ -715,6 +715,7 @@ class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum): DEFAULT = COCO_V1 +@register_model() @handle_legacy_interface( weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1), weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), @@ -817,6 +818,7 @@ def retinanet_resnet50_fpn( return model +@register_model() def retinanet_resnet50_fpn_v2( *, weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None, diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index 1a926116450..c30e508f488 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -9,7 +9,7 @@ from ...ops import boxes as box_ops from ...transforms._presets import ObjectDetection from ...utils import _log_api_usage_once -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _COCO_CATEGORIES from .._utils import _ovewrite_value_param, handle_legacy_interface from ..vgg import VGG, vgg16, VGG16_Weights @@ -568,6 +568,7 @@ def _vgg_extractor(backbone: VGG, highres: bool, trainable_layers: int): return SSDFeatureExtractorVGG(backbone, highres) +@register_model() @handle_legacy_interface( weights=("pretrained", SSD300_VGG16_Weights.COCO_V1), weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES), diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 7d695823b39..63ac0d2bc73 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -10,7 +10,7 @@ from ...transforms._presets import ObjectDetection from ...utils import _log_api_usage_once from .. import mobilenet -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _COCO_CATEGORIES from .._utils import _ovewrite_value_param, handle_legacy_interface from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights @@ -204,6 +204,7 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): DEFAULT = COCO_V1 +@register_model() @handle_legacy_interface( weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1), weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), From 040ddfc30113cf76c7439051f41dc6404af6fdf8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 15:47:24 +0100 Subject: [PATCH 18/31] Registering all optical flow models. --- torchvision/models/optical_flow/raft.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index c46e52a01a0..04076c96032 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -10,7 +10,7 @@ from ...transforms._presets import OpticalFlow from ...utils import _log_api_usage_once -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._utils import handle_legacy_interface from ._utils import grid_sample, make_coords_grid, upsample_flow @@ -800,6 +800,7 @@ def _raft( return model +@register_model() @handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2)) def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs) -> RAFT: """RAFT model from @@ -855,6 +856,7 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, * ) +@register_model() @handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2)) def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT: """RAFT "small" model from From 2031bf75fcb714cd4cd6b3f37c4a732ce0a53aac Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 15:55:31 +0100 Subject: [PATCH 19/31] Fixing mypy. --- torchvision/models/_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 3f4e9e914af..80f00f08498 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -162,7 +162,7 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum: return cast(WeightsEnum, weights_enum) -M = TypeVar("M", bound=Type[nn.Module]) +M = TypeVar("M", bound=nn.Module) BUILTIN_MODELS = {} From 1f27788125e330293c1f600a1a3daa5c2e5abf55 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 15:57:47 +0100 Subject: [PATCH 20/31] Registering all segmentation models. --- torchvision/models/_api.py | 2 +- torchvision/models/segmentation/deeplabv3.py | 5 ++++- torchvision/models/segmentation/fcn.py | 4 +++- torchvision/models/segmentation/lraspp.py | 3 ++- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 80f00f08498..276af80cbc2 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -110,7 +110,7 @@ def get_weight(name: str) -> WeightsEnum: return weights_enum.from_str(value_name) -W = TypeVar("W", bound=Type[WeightsEnum]) +W = TypeVar("W", bound=WeightsEnum) def get_model_weights(model: Union[Callable, str]) -> W: diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 0937369a1e7..3e451a21aaf 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -6,7 +6,7 @@ from torch.nn import functional as F from ...transforms._presets import SemanticSegmentation -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _VOC_CATEGORIES from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3 @@ -218,6 +218,7 @@ def _deeplabv3_mobilenetv3( return DeepLabV3(backbone, classifier, aux_classifier) +@register_model() @handle_legacy_interface( weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), @@ -273,6 +274,7 @@ def deeplabv3_resnet50( return model +@register_model() @handle_legacy_interface( weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), @@ -328,6 +330,7 @@ def deeplabv3_resnet101( return model +@register_model() @handle_legacy_interface( weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 2782d675ffe..5ec0747b710 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -4,7 +4,7 @@ from torch import nn from ...transforms._presets import SemanticSegmentation -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _VOC_CATEGORIES from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights @@ -110,6 +110,7 @@ def _fcn_resnet( return FCN(backbone, classifier, aux_classifier) +@register_model() @handle_legacy_interface( weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), @@ -168,6 +169,7 @@ def fcn_resnet50( return model +@register_model() @handle_legacy_interface( weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index 339d5feffe6..4bf71e77ae2 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -7,7 +7,7 @@ from ...transforms._presets import SemanticSegmentation from ...utils import _log_api_usage_once -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _VOC_CATEGORIES from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3 @@ -117,6 +117,7 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): DEFAULT = COCO_WITH_VOC_LABELS_V1 +@register_model() @handle_legacy_interface( weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), From 4d98f6c1fdafbf33864f55485beee2185d71d66a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 16:02:20 +0100 Subject: [PATCH 21/31] Registering all quantization models. --- torchvision/models/quantization/googlenet.py | 3 ++- torchvision/models/quantization/inception.py | 3 ++- torchvision/models/quantization/mobilenetv2.py | 3 ++- torchvision/models/quantization/resnet.py | 6 +++++- torchvision/models/quantization/shufflenetv2.py | 6 +++++- 5 files changed, 16 insertions(+), 5 deletions(-) diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index f0205ef608c..a75beb131b7 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -8,7 +8,7 @@ from torch.nn import functional as F from ...transforms._presets import ImageClassification -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES from .._utils import _ovewrite_named_param, handle_legacy_interface from ..googlenet import BasicConv2d, GoogLeNet, GoogLeNet_Weights, GoogLeNetOutputs, Inception, InceptionAux @@ -132,6 +132,7 @@ class GoogLeNet_QuantizedWeights(WeightsEnum): DEFAULT = IMAGENET1K_FBGEMM_V1 +@register_model(name="quantized_googlenet") @handle_legacy_interface( weights=( "pretrained", diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 1698cec7557..5af73c80fa0 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -10,7 +10,7 @@ from torchvision.models.inception import Inception_V3_Weights, InceptionOutputs from ...transforms._presets import ImageClassification -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES from .._utils import _ovewrite_named_param, handle_legacy_interface from .utils import _fuse_modules, _replace_relu, quantize_model @@ -198,6 +198,7 @@ class Inception_V3_QuantizedWeights(WeightsEnum): DEFAULT = IMAGENET1K_FBGEMM_V1 +@register_model(name="quantized_inception_v3") @handle_legacy_interface( weights=( "pretrained", diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 61a3cb7eeba..1f91967f146 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -7,7 +7,7 @@ from ...ops.misc import Conv2dNormActivation from ...transforms._presets import ImageClassification -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES from .._utils import _ovewrite_named_param, handle_legacy_interface from .utils import _fuse_modules, _replace_relu, quantize_model @@ -89,6 +89,7 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum): DEFAULT = IMAGENET1K_QNNPACK_V1 +@register_model(name="quantized_mobilenet_v2") @handle_legacy_interface( weights=( "pretrained", diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index bf3c733887e..39bea3f48f1 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -15,7 +15,7 @@ ) from ...transforms._presets import ImageClassification -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES from .._utils import _ovewrite_named_param, handle_legacy_interface from .utils import _fuse_modules, _replace_relu, quantize_model @@ -268,6 +268,7 @@ class ResNeXt101_64X4D_QuantizedWeights(WeightsEnum): DEFAULT = IMAGENET1K_FBGEMM_V1 +@register_model(name="quantized_resnet18") @handle_legacy_interface( weights=( "pretrained", @@ -317,6 +318,7 @@ def resnet18( return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) +@register_model(name="quantized_resnet50") @handle_legacy_interface( weights=( "pretrained", @@ -366,6 +368,7 @@ def resnet50( return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) +@register_model(name="quantized_resnext101_32x8d") @handle_legacy_interface( weights=( "pretrained", @@ -417,6 +420,7 @@ def resnext101_32x8d( return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs) +@register_model(name="quantized_resnext101_64x4d") def resnext101_64x4d( *, weights: Optional[Union[ResNeXt101_64X4D_QuantizedWeights, ResNeXt101_64X4D_Weights]] = None, diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 028df8be982..1d3622b6403 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -7,7 +7,7 @@ from torchvision.models import shufflenetv2 from ...transforms._presets import ImageClassification -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES from .._utils import _ovewrite_named_param, handle_legacy_interface from ..shufflenetv2 import ( @@ -203,6 +203,7 @@ class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum): DEFAULT = IMAGENET1K_FBGEMM_V1 +@register_model(name="quantized_shufflenet_v2_x0_5") @handle_legacy_interface( weights=( "pretrained", @@ -256,6 +257,7 @@ def shufflenet_v2_x0_5( ) +@register_model(name="quantized_shufflenet_v2_x1_0") @handle_legacy_interface( weights=( "pretrained", @@ -309,6 +311,7 @@ def shufflenet_v2_x1_0( ) +@register_model(name="quantized_shufflenet_v2_x1_5") def shufflenet_v2_x1_5( *, weights: Optional[Union[ShuffleNet_V2_X1_5_QuantizedWeights, ShuffleNet_V2_X1_5_Weights]] = None, @@ -354,6 +357,7 @@ def shufflenet_v2_x1_5( ) +@register_model(name="quantized_shufflenet_v2_x2_0") def shufflenet_v2_x2_0( *, weights: Optional[Union[ShuffleNet_V2_X2_0_QuantizedWeights, ShuffleNet_V2_X2_0_Weights]] = None, From ba0bb82b59dce2cd61e9bcc6f4038337dbc99190 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 16:05:48 +0100 Subject: [PATCH 22/31] Fixing linter --- torchvision/models/_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 276af80cbc2..4eecc389e4d 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, fields from inspect import signature from types import ModuleType -from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Type, TypeVar, Union +from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union from torch import nn From 0e2d120e74caa68e1aa8b400244818c677118647 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 16:15:19 +0100 Subject: [PATCH 23/31] Registering all prototype depth perception models. --- torchvision/prototype/models/depth/stereo/raft_stereo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/models/depth/stereo/raft_stereo.py b/torchvision/prototype/models/depth/stereo/raft_stereo.py index fa636f8ef00..522ad24c3a2 100644 --- a/torchvision/prototype/models/depth/stereo/raft_stereo.py +++ b/torchvision/prototype/models/depth/stereo/raft_stereo.py @@ -5,7 +5,7 @@ import torch.nn.functional as F import torchvision.models.optical_flow.raft as raft from torch import Tensor -from torchvision.models._api import WeightsEnum +from torchvision.models._api import register_model, WeightsEnum from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock from torchvision.ops import Conv2dNormActivation @@ -617,6 +617,7 @@ class Raft_Stereo_Base_Weights(WeightsEnum): pass +@register_model() def raft_stereo_realtime( *, weights: Optional[Raft_Stereo_Realtime_Weights] = None, progress=True, **kwargs ) -> RaftStereo: @@ -676,6 +677,7 @@ def raft_stereo_realtime( ) +@register_model() def raft_stereo_base(*, weights: Optional[Raft_Stereo_Base_Weights] = None, progress=True, **kwargs) -> RaftStereo: """RAFT-Stereo model from `RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching `_. From 2499c75f0249997c50ba317812a163a335af559c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 17:45:23 +0100 Subject: [PATCH 24/31] Adding tests and updating existing tests. --- test/test_extended_models.py | 102 ++++++++++++++++++++++------------ test/test_models.py | 34 +++++------- test/test_prototype_models.py | 14 ++--- 3 files changed, 88 insertions(+), 62 deletions(-) diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 677d19d18f7..cf9ca54420d 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -5,7 +5,7 @@ import test_models as TM import torch from torchvision import models -from torchvision.models._api import Weights, WeightsEnum +from torchvision.models._api import get_model_weights, Weights, WeightsEnum from torchvision.models._utils import handle_legacy_interface @@ -15,23 +15,53 @@ ) -def _get_parent_module(model_fn): - parent_module_name = ".".join(model_fn.__module__.split(".")[:-1]) - module = importlib.import_module(parent_module_name) - return module +@pytest.mark.parametrize( + "name, model_class", + [ + ("resnet50", models.ResNet), + ("retinanet_resnet50_fpn_v2", models.detection.RetinaNet), + ("raft_large", models.optical_flow.RAFT), + ("quantized_resnet50", models.quantization.QuantizableResNet), + ("lraspp_mobilenet_v3_large", models.segmentation.LRASPP), + ("mvit_v1_b", models.video.MViT), + ], +) +def test_get_model(name, model_class): + assert isinstance(models.get_model(name), model_class) + + +@pytest.mark.parametrize( + "name, weight", + [ + ("resnet50", models.ResNet50_Weights), + ("retinanet_resnet50_fpn_v2", models.detection.RetinaNet_ResNet50_FPN_V2_Weights), + ("raft_large", models.optical_flow.Raft_Large_Weights), + ("quantized_resnet50", models.quantization.ResNet50_QuantizedWeights), + ("lraspp_mobilenet_v3_large", models.segmentation.LRASPP_MobileNet_V3_Large_Weights), + ("mvit_v1_b", models.video.MViT_V1_B_Weights), + ], +) +def test_get_model_weights(name, weight): + assert models.get_model_weights(name) == weight -def _get_model_weights(model_fn): - module = _get_parent_module(model_fn) - weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights" - try: - return next( - v +@pytest.mark.parametrize( + "module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow] +) +def test_list_models(module): + def get_models_from_module(module): + return [ + v.__name__ for k, v in module.__dict__.items() - if k.endswith(weights_name) and k.replace(weights_name, "").lower() == model_fn.__name__ - ) - except StopIteration: - return None + if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k not in models._api.__all__ + ] + + a = set(get_models_from_module(module)) + b = set(x.replace("quantized_", "") for x in models.list_models(module)) + + assert len(b) > 0 + assert len(a - b) == 0 + assert len(b - a) == 0 @pytest.mark.parametrize( @@ -55,27 +85,27 @@ def test_get_weight(name, weight): @pytest.mark.parametrize( "model_fn", - TM.get_models_from_module(models) - + TM.get_models_from_module(models.detection) - + TM.get_models_from_module(models.quantization) - + TM.get_models_from_module(models.segmentation) - + TM.get_models_from_module(models.video) - + TM.get_models_from_module(models.optical_flow), + TM.list_model_fns(models) + + TM.list_model_fns(models.detection) + + TM.list_model_fns(models.quantization) + + TM.list_model_fns(models.segmentation) + + TM.list_model_fns(models.video) + + TM.list_model_fns(models.optical_flow), ) def test_naming_conventions(model_fn): - weights_enum = _get_model_weights(model_fn) + weights_enum = get_model_weights(model_fn) assert weights_enum is not None assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT") @pytest.mark.parametrize( "model_fn", - TM.get_models_from_module(models) - + TM.get_models_from_module(models.detection) - + TM.get_models_from_module(models.quantization) - + TM.get_models_from_module(models.segmentation) - + TM.get_models_from_module(models.video) - + TM.get_models_from_module(models.optical_flow), + TM.list_model_fns(models) + + TM.list_model_fns(models.detection) + + TM.list_model_fns(models.quantization) + + TM.list_model_fns(models.segmentation) + + TM.list_model_fns(models.video) + + TM.list_model_fns(models.optical_flow), ) @run_if_test_with_extended def test_schema_meta_validation(model_fn): @@ -112,7 +142,7 @@ def test_schema_meta_validation(model_fn): module_name = model_fn.__module__.split(".")[-2] expected_fields = defaults["all"] | defaults[module_name] - weights_enum = _get_model_weights(model_fn) + weights_enum = get_model_weights(model_fn) if len(weights_enum) == 0: pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.") @@ -153,17 +183,17 @@ def test_schema_meta_validation(model_fn): @pytest.mark.parametrize( "model_fn", - TM.get_models_from_module(models) - + TM.get_models_from_module(models.detection) - + TM.get_models_from_module(models.quantization) - + TM.get_models_from_module(models.segmentation) - + TM.get_models_from_module(models.video) - + TM.get_models_from_module(models.optical_flow), + TM.list_model_fns(models) + + TM.list_model_fns(models.detection) + + TM.list_model_fns(models.quantization) + + TM.list_model_fns(models.segmentation) + + TM.list_model_fns(models.video) + + TM.list_model_fns(models.optical_flow), ) @run_if_test_with_extended def test_transforms_jit(model_fn): model_name = model_fn.__name__ - weights_enum = _get_model_weights(model_fn) + weights_enum = get_model_weights(model_fn) if len(weights_enum) == 0: pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.") diff --git a/test/test_models.py b/test/test_models.py index 3f70758491b..5ab0640a70e 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -16,19 +16,15 @@ from _utils_internal import get_relative_path from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed from torchvision import models +from torchvision.models._api import find_model, list_models + ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1" -def get_models_from_module(module): - # TODO add a registration mechanism to torchvision.models - non_model_fn = {"get_model", "get_model_weights", "get_weight", "list_models"} - return [ - v - for k, v in module.__dict__.items() - if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k not in non_model_fn - ] +def list_model_fns(module): + return [find_model(name) for name in list_models(module)] @pytest.fixture @@ -598,7 +594,7 @@ def test_vitc_models(model_fn, dev): test_classification_model(model_fn, dev) -@pytest.mark.parametrize("model_fn", get_models_from_module(models)) +@pytest.mark.parametrize("model_fn", list_model_fns(models)) @pytest.mark.parametrize("dev", cpu_and_gpu()) def test_classification_model(model_fn, dev): set_rng_seed(0) @@ -634,7 +630,7 @@ def test_classification_model(model_fn, dev): _check_input_backprop(model, x) -@pytest.mark.parametrize("model_fn", get_models_from_module(models.segmentation)) +@pytest.mark.parametrize("model_fn", list_model_fns(models.segmentation)) @pytest.mark.parametrize("dev", cpu_and_gpu()) def test_segmentation_model(model_fn, dev): set_rng_seed(0) @@ -696,7 +692,7 @@ def check_out(out): _check_input_backprop(model, x) -@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection)) +@pytest.mark.parametrize("model_fn", list_model_fns(models.detection)) @pytest.mark.parametrize("dev", cpu_and_gpu()) def test_detection_model(model_fn, dev): set_rng_seed(0) @@ -794,7 +790,7 @@ def compute_mean_std(tensor): _check_input_backprop(model, model_input) -@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection)) +@pytest.mark.parametrize("model_fn", list_model_fns(models.detection)) def test_detection_model_validation(model_fn): set_rng_seed(0) model = model_fn(num_classes=50, weights=None, weights_backbone=None) @@ -823,7 +819,7 @@ def test_detection_model_validation(model_fn): model(x, targets=targets) -@pytest.mark.parametrize("model_fn", get_models_from_module(models.video)) +@pytest.mark.parametrize("model_fn", list_model_fns(models.video)) @pytest.mark.parametrize("dev", cpu_and_gpu()) def test_video_model(model_fn, dev): set_rng_seed(0) @@ -869,7 +865,7 @@ def test_video_model(model_fn, dev): ), reason="This Pytorch Build has not been built with fbgemm and qnnpack", ) -@pytest.mark.parametrize("model_fn", get_models_from_module(models.quantization)) +@pytest.mark.parametrize("model_fn", list_model_fns(models.quantization)) def test_quantized_classification_model(model_fn): set_rng_seed(0) defaults = { @@ -918,7 +914,7 @@ def test_quantized_classification_model(model_fn): torch.ao.quantization.convert(model, inplace=True) -@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection)) +@pytest.mark.parametrize("model_fn", list_model_fns(models.detection)) def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading): model_name = model_fn.__name__ max_trainable = _model_tests_values[model_name]["max_trainable"] @@ -931,9 +927,9 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load @needs_cuda -@pytest.mark.parametrize("model_builder", (models.optical_flow.raft_large, models.optical_flow.raft_small)) +@pytest.mark.parametrize("model_fn", list_model_fns(models.optical_flow)) @pytest.mark.parametrize("scripted", (False, True)) -def test_raft(model_builder, scripted): +def test_raft(model_fn, scripted): torch.manual_seed(0) @@ -943,7 +939,7 @@ def test_raft(model_builder, scripted): # reduced to 1) corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2) - model = model_builder(corr_block=corr_block).eval().to("cuda") + model = model_fn(corr_block=corr_block).eval().to("cuda") if scripted: model = torch.jit.script(model) @@ -955,7 +951,7 @@ def test_raft(model_builder, scripted): flow_pred = preds[-1] # Tolerance is fairly high, but there are 2 * H * W outputs to check # The .pkl were generated on the AWS cluter, on the CI it looks like the resuts are slightly different - _assert_expected(flow_pred, name=model_builder.__name__, atol=1e-2, rtol=1) + _assert_expected(flow_pred, name=model_fn.__name__, atol=1e-2, rtol=1) if __name__ == "__main__": diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index eefb1669901..a33502fff6e 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -1,22 +1,22 @@ import pytest import test_models as TM import torch -import torchvision.prototype.models.depth.stereo.raft_stereo as raft_stereo +from torchvision.prototype import models from common_utils import cpu_and_gpu, set_rng_seed -@pytest.mark.parametrize("model_builder", (raft_stereo.raft_stereo_base, raft_stereo.raft_stereo_realtime)) +@pytest.mark.parametrize("model_fn", TM.list_model_fns(models.depth.stereo)) @pytest.mark.parametrize("model_mode", ("standard", "scripted")) @pytest.mark.parametrize("dev", cpu_and_gpu()) -def test_raft_stereo(model_builder, model_mode, dev): +def test_raft_stereo(model_fn, model_mode, dev): # A simple test to make sure the model can do forward pass and jit scriptable set_rng_seed(0) # Use corr_pyramid and corr_block with smaller num_levels and radius to prevent nan output # get the idea from test_models.test_raft - corr_pyramid = raft_stereo.CorrPyramid1d(num_levels=2) - corr_block = raft_stereo.CorrBlock1d(num_levels=2, radius=2) - model = model_builder(corr_pyramid=corr_pyramid, corr_block=corr_block).eval().to(dev) + corr_pyramid = models.depth.stereo.raft_stereo.CorrPyramid1d(num_levels=2) + corr_block = models.depth.stereo.raft_stereo.CorrBlock1d(num_levels=2, radius=2) + model = model_fn(corr_pyramid=corr_pyramid, corr_block=corr_block).eval().to(dev) if model_mode == "scripted": model = torch.jit.script(model) @@ -35,4 +35,4 @@ def test_raft_stereo(model_builder, model_mode, dev): ), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}" # Test against expected file output - TM._assert_expected(depth_pred, name=model_builder.__name__, atol=1e-2, rtol=1e-2) + TM._assert_expected(depth_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2) From a75cfacb3c28f408a9ebdb1cd33e0627cdcd5aef Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 17:56:21 +0100 Subject: [PATCH 25/31] Fix linters --- test/test_extended_models.py | 1 - test/test_prototype_models.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_extended_models.py b/test/test_extended_models.py index cf9ca54420d..9dd818b3ba5 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -1,4 +1,3 @@ -import importlib import os import pytest diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index a33502fff6e..56f7b9cb6ac 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -1,8 +1,8 @@ import pytest import test_models as TM import torch -from torchvision.prototype import models from common_utils import cpu_and_gpu, set_rng_seed +from torchvision.prototype import models @pytest.mark.parametrize("model_fn", TM.list_model_fns(models.depth.stereo)) From 7e9c2e737e25ff392ec053a5a05eb9447df1065f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 30 Jul 2022 09:57:51 +0100 Subject: [PATCH 26/31] Fix tests. --- test/test_backbone_utils.py | 21 ++++++------------- .../test_models_detection_negative_samples.py | 4 ++-- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index a2b2406441e..bd1174819ef 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -11,15 +11,6 @@ from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names -def get_available_models(): - # TODO add a registration mechanism to torchvision.models - return [ - k - for k, v in models.__dict__.items() - if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight" - ] - - @pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50")) def test_resnet_fpn_backbone(backbone_name): x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu") @@ -135,10 +126,10 @@ def _get_return_nodes(self, model): eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)] return random.sample(train_nodes, 10), random.sample(eval_nodes, 10) - @pytest.mark.parametrize("model_name", get_available_models()) + @pytest.mark.parametrize("model_name", models.list_models()) def test_build_fx_feature_extractor(self, model_name): set_rng_seed(0) - model = models.__dict__[model_name](**self.model_defaults).eval() + model = models.get_model(model_name, **self.model_defaults).eval() train_return_nodes, eval_return_nodes = self._get_return_nodes(model) # Check that it works with both a list and dict for return nodes self._create_feature_extractor( @@ -172,9 +163,9 @@ def test_node_name_conventions(self): train_nodes, _ = get_graph_node_names(model) assert all(a == b for a, b in zip(train_nodes, test_module_nodes)) - @pytest.mark.parametrize("model_name", get_available_models()) + @pytest.mark.parametrize("model_name", models.list_models()) def test_forward_backward(self, model_name): - model = models.__dict__[model_name](**self.model_defaults).train() + model = models.get_model(model_name, **self.model_defaults).train() train_return_nodes, eval_return_nodes = self._get_return_nodes(model) model = self._create_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes @@ -211,10 +202,10 @@ def test_feature_extraction_methods_equivalence(self): for k in ilg_out.keys(): assert ilg_out[k].equal(fgn_out[k]) - @pytest.mark.parametrize("model_name", get_available_models()) + @pytest.mark.parametrize("model_name", models.list_models()) def test_jit_forward_backward(self, model_name): set_rng_seed(0) - model = models.__dict__[model_name](**self.model_defaults).train() + model = models.get_model(model_name, **self.model_defaults).train() train_return_nodes, eval_return_nodes = self._get_return_nodes(model) model = self._create_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py index 13db78d53fc..c91cfdf20a7 100644 --- a/test/test_models_detection_negative_samples.py +++ b/test/test_models_detection_negative_samples.py @@ -99,8 +99,8 @@ def test_assign_targets_to_proposals(self): ], ) def test_forward_negative_sample_frcnn(self, name): - model = torchvision.models.detection.__dict__[name]( - weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100 + model = torchvision.models.get_model( + name, weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100 ) images, targets = self._make_empty_sample() From 81a7e3ff109fbb8c8ba111543964a4d66e41feb8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 30 Jul 2022 10:17:56 +0100 Subject: [PATCH 27/31] Add beta annotation on docs. --- torchvision/models/_api.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 4eecc389e4d..da92fc3a596 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -80,6 +80,8 @@ def get_weight(name: str) -> WeightsEnum: """ Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1" + .. betastatus:: method + Args: name (str): The name of the weight enum entry. @@ -117,6 +119,8 @@ def get_model_weights(model: Union[Callable, str]) -> W: """ Retuns the weights enum class associated to the given model. + .. betastatus:: method + Args: name (callable or str): The model builder function or the name under which it is registered. @@ -182,6 +186,8 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: """ Returns a list with the names of registered models. + .. betastatus:: method + Args: module (ModuleType, optional): The module from which we want to extract the available models. @@ -207,6 +213,8 @@ def get_model(name: str, **config: Any) -> M: """ Gets the model name and configuration and returns an instantiated model. + .. betastatus:: method + Args: name (str): The name under which the model is registered. **config (Any): parameters passed to the model builder method. From 867f85f93f23f18c489e8b0f01bdd13bd7ea9e2d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 30 Jul 2022 10:25:58 +0100 Subject: [PATCH 28/31] Fix tests. --- test/test_backbone_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index bd1174819ef..4fba3c3d098 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -126,7 +126,7 @@ def _get_return_nodes(self, model): eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)] return random.sample(train_nodes, 10), random.sample(eval_nodes, 10) - @pytest.mark.parametrize("model_name", models.list_models()) + @pytest.mark.parametrize("model_name", models.list_models(models)) def test_build_fx_feature_extractor(self, model_name): set_rng_seed(0) model = models.get_model(model_name, **self.model_defaults).eval() @@ -163,7 +163,7 @@ def test_node_name_conventions(self): train_nodes, _ = get_graph_node_names(model) assert all(a == b for a, b in zip(train_nodes, test_module_nodes)) - @pytest.mark.parametrize("model_name", models.list_models()) + @pytest.mark.parametrize("model_name", models.list_models(models)) def test_forward_backward(self, model_name): model = models.get_model(model_name, **self.model_defaults).train() train_return_nodes, eval_return_nodes = self._get_return_nodes(model) @@ -202,7 +202,7 @@ def test_feature_extraction_methods_equivalence(self): for k in ilg_out.keys(): assert ilg_out[k].equal(fgn_out[k]) - @pytest.mark.parametrize("model_name", models.list_models()) + @pytest.mark.parametrize("model_name", models.list_models(models)) def test_jit_forward_backward(self, model_name): set_rng_seed(0) model = models.get_model(model_name, **self.model_defaults).train() From 7ca12b9bf05bd27822d4a7b61bc21fbc9e0e60c7 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 1 Aug 2022 08:57:12 +0100 Subject: [PATCH 29/31] Apply changes from code-review. --- test/test_extended_models.py | 5 ++--- torchvision/models/_api.py | 8 ++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 9dd818b3ba5..55259bb150d 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -52,15 +52,14 @@ def get_models_from_module(module): return [ v.__name__ for k, v in module.__dict__.items() - if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k not in models._api.__all__ + if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__ ] a = set(get_models_from_module(module)) b = set(x.replace("quantized_", "") for x in models.list_models(module)) assert len(b) > 0 - assert len(a - b) == 0 - assert len(b - a) == 0 + assert a == b @pytest.mark.parametrize( diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index da92fc3a596..4df988dcc9a 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -80,7 +80,7 @@ def get_weight(name: str) -> WeightsEnum: """ Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1" - .. betastatus:: method + .. betastatus:: function Args: name (str): The name of the weight enum entry. @@ -119,7 +119,7 @@ def get_model_weights(model: Union[Callable, str]) -> W: """ Retuns the weights enum class associated to the given model. - .. betastatus:: method + .. betastatus:: function Args: name (callable or str): The model builder function or the name under which it is registered. @@ -186,7 +186,7 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: """ Returns a list with the names of registered models. - .. betastatus:: method + .. betastatus:: function Args: module (ModuleType, optional): The module from which we want to extract the available models. @@ -213,7 +213,7 @@ def get_model(name: str, **config: Any) -> M: """ Gets the model name and configuration and returns an instantiated model. - .. betastatus:: method + .. betastatus:: function Args: name (str): The name under which the model is registered. From 7c3a0ba1153c0c55b07193dc4e5c624536437a02 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 1 Aug 2022 09:21:52 +0100 Subject: [PATCH 30/31] Adding documentation. --- docs/source/models.rst | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/docs/source/models.rst b/docs/source/models.rst index 769c2d2721b..64c14a1e208 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -120,6 +120,43 @@ behavior, such as batch normalization. To switch between these modes, use # Set model to eval mode model.eval() +Model Registration Mechanism +---------------------------- + +.. betastatus:: registration mechanism + +As of v0.14, TorchVision offers a new model registration mechanism which allows retreaving models +and weights by their names. Here are a few examples on how to use them: + +.. code:: python + + # List available models + all_models = list_models() + classification_models = list_models(module=torchvision.models) + + # Initialize models + m1 = get_model("mobilenet_v3_large", weights=None) + m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT") + + # Fetch weights + weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT") + assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT + + weights_enum = get_model_weights("quantized_mobilenet_v3_large") + assert weights_enum == MobileNet_V3_Large_QuantizedWeights + + weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large) + assert weights_enum == weights_enum2 + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + get_model + get_model_weights + get_weight + list_models + Using models from Hub --------------------- From 2eae177fec84fbdf494f2b0593bd3626a0adae4f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 1 Aug 2022 09:33:59 +0100 Subject: [PATCH 31/31] Fix docs. --- docs/source/models.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/models.rst b/docs/source/models.rst index 64c14a1e208..3cf52389e82 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -148,6 +148,9 @@ and weights by their names. Here are a few examples on how to use them: weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large) assert weights_enum == weights_enum2 +Here are the available public methods of the model registration mechanism: + +.. currentmodule:: torchvision.models .. autosummary:: :toctree: generated/ :template: function.rst