2 changes: 1 addition & 1 deletion torchvision/models/__init__.py
@@ -14,4 +14,4 @@
from .vision_transformer import *
from .swin_transformer import *
from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_weight
from ._api import get_weight, list_models, load_model
56 changes: 55 additions & 1 deletion torchvision/models/_api.py
@@ -3,7 +3,10 @@
import sys
from dataclasses import dataclass, fields
from inspect import signature
from typing import Any, Callable, cast, Dict, Mapping
from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar

from torch import nn

from torchvision._utils import StrEnum

@@ -140,3 +143,54 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
)

return cast(WeightsEnum, weights_enum)


M = TypeVar("M", bound=nn.Module)

BUILTIN_MODELS = {}


def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
key = name if name is not None else fn.__name__
if key in BUILTIN_MODELS:
raise ValueError(f"An entry is already registered under the name '{key}'.")
BUILTIN_MODELS[key] = fn
return fn

return wrapper


def list_models(module: Optional[ModuleType] = None) -> List[str]:
"""
Returns a list with the names of registered models.

Args:
module (ModuleType, optional): The module from which we want to extract the available models.

Returns:
models (list): A list with the names of available models.
"""
models = [
k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
]
return sorted(models)


def load_model(name: str, **config: Any) -> M:
"""
Gets the model name and configuration and returns an instantiated model.

Args:
name (str): The name under which the model is registered.
**config (Any): parameters passed to the model builder method.

Returns:
model (nn.Module): The initialized model.
"""
name = name.lower()
try:
fn = BUILTIN_MODELS[name]
except KeyError:
raise ValueError(f"Unknown model {name}")
return fn(**config)
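
Taken together, the new helpers form a small registry API: builders opt in via @register_model, list_models enumerates what has been registered (optionally filtered by module), and load_model instantiates a model from its registered name. The snippet below is a minimal usage sketch, assuming a torchvision build that includes this change and that only the builders decorated in this diff (alexnet, mobilenet_v3_large, mobilenet_v3_small, and the quantized mobilenet_v3_large) are registered; weights=None simply skips pretrained weights.

from torchvision import models

# Every registered builder, sorted by name.
print(models.list_models())
# e.g. ['alexnet', 'mobilenet_v3_large', 'mobilenet_v3_small', 'quantized_mobilenet_v3_large']

# Restrict the listing to builders defined under a given module.
print(models.list_models(module=models.quantization))
# e.g. ['quantized_mobilenet_v3_large']

# Instantiate by registered name; extra keyword arguments are forwarded to the builder.
model = models.load_model("mobilenet_v3_large", weights=None)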
3 changes: 2 additions & 1 deletion torchvision/models/alexnet.py
@@ -6,7 +6,7 @@

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface

@@ -75,6 +75,7 @@ class AlexNet_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
"""AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__.
4 changes: 3 additions & 1 deletion torchvision/models/mobilenetv3.py
@@ -8,7 +8,7 @@
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface

@@ -371,6 +371,7 @@ class MobileNet_V3_Small_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1))
def mobilenet_v3_large(
*, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
@@ -401,6 +402,7 @@ def mobilenet_v3_large(
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1))
def mobilenet_v3_small(
*, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
3 changes: 2 additions & 1 deletion torchvision/models/quantization/mobilenetv3.py
@@ -7,7 +7,7 @@

from ...ops.misc import Conv2dNormActivation, SqueezeExcitation
from ...transforms._presets import ImageClassification
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from ..mobilenetv3 import (
@@ -184,6 +184,7 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
DEFAULT = IMAGENET1K_QNNPACK_V1


@register_model(name="quantized_mobilenet_v3_large")
@handle_legacy_interface(
weights=(
"pretrained",
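
Passing an explicit name to the decorator controls the registry key, which is how the quantized builder avoids clashing with the fp32 mobilenet_v3_large registered above (both builder functions share the same __name__). A short sketch under that assumption, also assuming the quantized builder keeps its existing quantize flag:

from torchvision import models

# The explicit name given to @register_model is the lookup key, not the
# function's own __name__.
qmodel = models.load_model("quantized_mobilenet_v3_large", weights=None, quantize=False)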