Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions torchvision/prototype/transforms/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,10 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:

mask_maxc_neq_r = maxc != r
mask_maxc_eq_g = maxc == g
mask_maxc_neq_g = ~mask_maxc_eq_g

hr = (bc - gc).mul_(~mask_maxc_neq_r)
hg = (2.0 + rc).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r)
hg = rc.add(2.0).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
hr = bc.sub_(gc).mul_(~mask_maxc_neq_r)
hb = gc.add_(4.0).sub_(rc).mul_(mask_maxc_neq_r.logical_and_(mask_maxc_eq_g.logical_not_()))
Comment on lines +212 to +214
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing the order of operations allows us to do more in-place ops.

In particular once hg is estimated, we can do an in-place on bc during the hr estimation. Then in the estimation of hb we can do inplace on gc and the logical masks.


h = hr.add_(hg).add_(hb)
h = h.mul_(1.0 / 6.0).add_(1.0).fmod_(1.0)
Expand All @@ -221,14 +220,16 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:

def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
h, s, v = img.unbind(dim=-3)
h6 = h * 6
h6 = h.mul(6)
i = torch.floor(h6)
f = h6 - i
f = h6.sub_(i)
Comment on lines +223 to +225
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor optimizations based on @pmeier's finding that mul is preferable to * for numbers. Also h6 can be modified in-place as it's not reused.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, we only get benefits if we can eliminate a tensor division, which is not the case here.

i = i.to(dtype=torch.int32)

p = (v * (1.0 - s)).clamp_(0.0, 1.0)
q = (v * (1.0 - s * f)).clamp_(0.0, 1.0)
t = (v * (1.0 - s * (1.0 - f))).clamp_(0.0, 1.0)
sxf = s * f
one_minus_s = 1.0 - s
q = (1.0 - sxf).mul_(v).clamp_(0.0, 1.0)
t = sxf.add_(one_minus_s).mul_(v).clamp_(0.0, 1.0)
p = one_minus_s.mul_(v).clamp_(0.0, 1.0)
Comment on lines +228 to +232
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again we reorder to be able to do more in-place ops. We also expand the math ops to reuse components.

We precompute s*f which is used in q and t estimation. We also do that for 1-s. Then we estimate first q, so that we canl ater modify sxf inplace on the t estimation. Finally one_minus_s can be in-place modified in the estimation of p.

i.remainder_(6)

mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
Expand All @@ -238,7 +239,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)

return (a4.mul_(mask.to(dtype=img.dtype).unsqueeze(dim=-4))).sum(dim=-3)
return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary casting of mask. It's possible to do a multiplication with bools as we did previously for mask_maxc_eq_g etc.



def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def convert_format_bounding_box(
if new_format == old_format:
return bounding_box

# TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not the highest priority as we don't do such conversions internally but it might be good to offer those 2 and stop doing 2 conversions on the future.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm ok with that as the number of formats is low. If that changes in the future, we maybe need to walk back or only partially implement 1-to-1 conversions for all formats.

if old_format == BoundingBoxFormat.XYWH:
bounding_box = _xywh_to_xyxy(bounding_box, inplace)
elif old_format == BoundingBoxFormat.CXCYWH:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import unittest.mock
from typing import Any, Dict, Tuple, Union

import numpy as np
Expand All @@ -20,6 +19,8 @@ def decode_image_with_pil(encoded_image: torch.Tensor) -> features.Image:

@torch.jit.unused
def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
import unittest.mock
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a drive by change to avoid the hard dependency on unittest. @pmeier said offline that we can clean up many methods that are no longer used. He is going to do this on a separate PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will send a PR soon.


with unittest.mock.patch("torchvision.io.video.os.path.exists", return_value=True):
return read_video(ReadOnlyTensorBuffer(encoded_video)) # type: ignore[arg-type]

Expand Down