diff --git a/test/test_datapoints.py b/test/test_datapoints.py index 5b875a6ef20..39c05123333 100644 --- a/test/test_datapoints.py +++ b/test/test_datapoints.py @@ -28,5 +28,5 @@ def test_bbox_instance(data, format): assert isinstance(bboxes, torch.Tensor) assert bboxes.ndim == 2 and bboxes.shape[1] == 4 if isinstance(format, str): - format = datapoints.BoundingBoxFormat.from_str(format.upper()) + format = datapoints.BoundingBoxFormat[(format.upper())] assert bboxes.format == format diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py index 1dc46f8f21a..75e779f0b21 100644 --- a/torchvision/datapoints/_bounding_box.py +++ b/torchvision/datapoints/_bounding_box.py @@ -1,18 +1,18 @@ from __future__ import annotations +from enum import Enum from typing import Any, List, Optional, Sequence, Tuple, Union import torch -from torchvision._utils import StrEnum from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms from ._datapoint import _FillTypeJIT, Datapoint -class BoundingBoxFormat(StrEnum): - XYXY = StrEnum.auto() - XYWH = StrEnum.auto() - CXCYWH = StrEnum.auto() +class BoundingBoxFormat(Enum): + XYXY = "XYXY" + XYWH = "XYWH" + CXCYWH = "CXCYWH" class BoundingBox(Datapoint): @@ -39,7 +39,7 @@ def __new__( tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) if isinstance(format, str): - format = BoundingBoxFormat.from_str(format.upper()) + format = BoundingBoxFormat[format.upper()] return cls._wrap(tensor, format=format, spatial_size=spatial_size)