@@ -19,6 +19,13 @@ class BoundingBox(_Feature):
1919 format : BoundingBoxFormat
2020 image_size : Tuple [int , int ]
2121
22+ @classmethod
23+ def _wrap (cls , tensor : torch .Tensor , * , format : BoundingBoxFormat , image_size : Tuple [int , int ]) -> BoundingBox :
24+ bounding_box = tensor .as_subclass (cls )
25+ bounding_box .format = format
26+ bounding_box .image_size = image_size
27+ return bounding_box
28+
2229 def __new__ (
2330 cls ,
2431 data : Any ,
@@ -29,52 +36,46 @@ def __new__(
2936 device : Optional [Union [torch .device , str , int ]] = None ,
3037 requires_grad : bool = False ,
3138 ) -> BoundingBox :
32- bounding_box = super (). __new__ ( cls , data , dtype = dtype , device = device , requires_grad = requires_grad )
39+ tensor = cls . _to_tensor ( data , dtype = dtype , device = device , requires_grad = requires_grad )
3340
3441 if isinstance (format , str ):
3542 format = BoundingBoxFormat .from_str (format .upper ())
36- bounding_box .format = format
37-
38- bounding_box .image_size = image_size
3943
40- return bounding_box
41-
42- def __repr__ (self , * , tensor_contents : Any = None ) -> str : # type: ignore[override]
43- return self ._make_repr (format = self .format , image_size = self .image_size )
44+ return cls ._wrap (tensor , format = format , image_size = image_size )
4445
4546 @classmethod
46- def new_like (
47+ def wrap_like (
4748 cls ,
4849 other : BoundingBox ,
49- data : Any ,
50+ tensor : torch . Tensor ,
5051 * ,
51- format : Optional [Union [ BoundingBoxFormat , str ] ] = None ,
52+ format : Optional [BoundingBoxFormat ] = None ,
5253 image_size : Optional [Tuple [int , int ]] = None ,
53- ** kwargs : Any ,
5454 ) -> BoundingBox :
55- return super ().new_like (
56- other ,
57- data ,
55+ return cls ._wrap (
56+ tensor ,
5857 format = format if format is not None else other .format ,
5958 image_size = image_size if image_size is not None else other .image_size ,
60- ** kwargs ,
6159 )
6260
61+ def __repr__ (self , * , tensor_contents : Any = None ) -> str : # type: ignore[override]
62+ return self ._make_repr (format = self .format , image_size = self .image_size )
63+
6364 def to_format (self , format : Union [str , BoundingBoxFormat ]) -> BoundingBox :
6465 if isinstance (format , str ):
6566 format = BoundingBoxFormat .from_str (format .upper ())
6667
67- return BoundingBox .new_like (
68+ return BoundingBox .wrap_like (
6869 self , self ._F .convert_format_bounding_box (self , old_format = self .format , new_format = format ), format = format
6970 )
7071
7172 def horizontal_flip (self ) -> BoundingBox :
7273 output = self ._F .horizontal_flip_bounding_box (self , format = self .format , image_size = self .image_size )
73- return BoundingBox .new_like (self , output )
74+ return BoundingBox .wrap_like (self , output )
7475
7576 def vertical_flip (self ) -> BoundingBox :
7677 output = self ._F .vertical_flip_bounding_box (self , format = self .format , image_size = self .image_size )
77- return BoundingBox .new_like (self , output )
78+ return BoundingBox .wrap_like (self , output )
7879
7980 def resize ( # type: ignore[override]
8081 self ,
@@ -84,19 +85,19 @@ def resize( # type: ignore[override]
8485 antialias : bool = False ,
8586 ) -> BoundingBox :
8687 output , image_size = self ._F .resize_bounding_box (self , image_size = self .image_size , size = size , max_size = max_size )
87- return BoundingBox .new_like (self , output , image_size = image_size )
88+ return BoundingBox .wrap_like (self , output , image_size = image_size )
8889
8990 def crop (self , top : int , left : int , height : int , width : int ) -> BoundingBox :
9091 output , image_size = self ._F .crop_bounding_box (
9192 self , self .format , top = top , left = left , height = height , width = width
9293 )
93- return BoundingBox .new_like (self , output , image_size = image_size )
94+ return BoundingBox .wrap_like (self , output , image_size = image_size )
9495
9596 def center_crop (self , output_size : List [int ]) -> BoundingBox :
9697 output , image_size = self ._F .center_crop_bounding_box (
9798 self , format = self .format , image_size = self .image_size , output_size = output_size
9899 )
99- return BoundingBox .new_like (self , output , image_size = image_size )
100+ return BoundingBox .wrap_like (self , output , image_size = image_size )
100101
101102 def resized_crop (
102103 self ,
@@ -109,7 +110,7 @@ def resized_crop(
109110 antialias : bool = False ,
110111 ) -> BoundingBox :
111112 output , image_size = self ._F .resized_crop_bounding_box (self , self .format , top , left , height , width , size = size )
112- return BoundingBox .new_like (self , output , image_size = image_size )
113+ return BoundingBox .wrap_like (self , output , image_size = image_size )
113114
114115 def pad (
115116 self ,
@@ -120,7 +121,7 @@ def pad(
120121 output , image_size = self ._F .pad_bounding_box (
121122 self , format = self .format , image_size = self .image_size , padding = padding , padding_mode = padding_mode
122123 )
123- return BoundingBox .new_like (self , output , image_size = image_size )
124+ return BoundingBox .wrap_like (self , output , image_size = image_size )
124125
125126 def rotate (
126127 self ,
@@ -133,7 +134,7 @@ def rotate(
133134 output , image_size = self ._F .rotate_bounding_box (
134135 self , format = self .format , image_size = self .image_size , angle = angle , expand = expand , center = center
135136 )
136- return BoundingBox .new_like (self , output , image_size = image_size )
137+ return BoundingBox .wrap_like (self , output , image_size = image_size )
137138
138139 def affine (
139140 self ,
@@ -155,7 +156,7 @@ def affine(
155156 shear = shear ,
156157 center = center ,
157158 )
158- return BoundingBox .new_like (self , output , dtype = output . dtype )
159+ return BoundingBox .wrap_like (self , output )
159160
160161 def perspective (
161162 self ,
@@ -164,7 +165,7 @@ def perspective(
164165 fill : FillTypeJIT = None ,
165166 ) -> BoundingBox :
166167 output = self ._F .perspective_bounding_box (self , self .format , perspective_coeffs )
167- return BoundingBox .new_like (self , output , dtype = output . dtype )
168+ return BoundingBox .wrap_like (self , output )
168169
169170 def elastic (
170171 self ,
@@ -173,4 +174,4 @@ def elastic(
173174 fill : FillTypeJIT = None ,
174175 ) -> BoundingBox :
175176 output = self ._F .elastic_bounding_box (self , self .format , displacement )
176- return BoundingBox .new_like (self , output , dtype = output . dtype )
177+ return BoundingBox .wrap_like (self , output )
0 commit comments