Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_auto_augment(self, transform, input):
(
transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
itertools.chain.from_iterable(
fn(color_spaces=["rgb"], dtypes=[torch.float32])
fn(color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32])
for fn in [
make_images,
make_vanilla_tensor_images,
Expand Down
2 changes: 0 additions & 2 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32):
size = size or torch.randint(16, 33, (2,)).tolist()

if isinstance(color_space, str):
color_space = features.ColorSpace[color_space]
num_channels = {
features.ColorSpace.GRAYSCALE: 1,
features.ColorSpace.RGB: 3,
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/features/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __new__(
bounding_box = super().__new__(cls, data, dtype=dtype, device=device)

if isinstance(format, str):
format = BoundingBoxFormat[format]
format = BoundingBoxFormat.from_str(format.upper())

bounding_box._metadata.update(dict(format=format, image_size=image_size))

Expand Down
8 changes: 4 additions & 4 deletions torchvision/prototype/features/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@


class ColorSpace(StrEnum):
OTHER = 0
GRAYSCALE = 1
RGB = 3
OTHER = StrEnum.auto()
GRAYSCALE = StrEnum.auto()
RGB = StrEnum.auto()


class Image(_Feature):
Expand All @@ -37,7 +37,7 @@ def __new__(
if color_space == ColorSpace.OTHER:
warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
elif isinstance(color_space, str):
color_space = ColorSpace[color_space]
color_space = ColorSpace.from_str(color_space.upper())

image._metadata.update(dict(color_space=color_space))

Expand Down
11 changes: 3 additions & 8 deletions torchvision/prototype/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import sys
from collections import OrderedDict
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, Callable, Dict

from torchvision.prototype.utils._internal import StrEnum

from ..._internally_replaced_utils import load_state_dict_from_url


Expand Down Expand Up @@ -34,7 +35,7 @@ class Weights:
meta: Dict[str, Any]


class WeightsEnum(Enum):
class WeightsEnum(StrEnum):
"""
This class is the parent class of all model weights. Each model building method receives an optional `weights`
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
Expand All @@ -58,12 +59,6 @@ def verify(cls, obj: Any) -> Any:
)
return obj

@classmethod
def from_str(cls, value: str) -> "WeightsEnum":
if value in cls.__members__:
return cls.__members__[value]
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")

def get_state_dict(self, progress: bool) -> OrderedDict:
return load_state_dict_from_url(self.url, progress=progress)

Expand Down
13 changes: 11 additions & 2 deletions torchvision/prototype/utils/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,17 @@
class StrEnumMeta(enum.EnumMeta):
auto = enum.auto

def __getitem__(self, item):
return super().__getitem__(item.upper() if isinstance(item, str) else item)
def from_str(self, member: str):
try:
return self[member]
except KeyError:
raise ValueError(
add_suggestion(
f"Unknown value '{member}' for {self.__name__}.",
word=member,
possibilities=list(self.__members__.keys()),
)
) from None


class StrEnum(enum.Enum, metaclass=StrEnumMeta):
Expand Down