Skip to content

Conversation

@datumbox
Copy link
Contributor

@datumbox datumbox commented Oct 24, 2022

This PR:

  • Avoids the (std == 0).any() idiom which caused synchronization between CPU and GPU (50% improvement in CUDA)
  • Minimizes memory writes by avoiding cloning (25% improvement on CPU)

The combined result is:

Float mean/std:
[------------- Normalize cpu torch.float32 -------------]
                     |  normalize stable  |  normalize v2
1 threads: ----------------------------------------------
      (3, 400, 400)  |        364         |      264     
6 threads: ----------------------------------------------
      (3, 400, 400)  |        497         |      351     

Times are in microseconds (us).

[------------- Normalize cuda torch.float32 ------------]
                     |  normalize stable  |  normalize v2
1 threads: ----------------------------------------------
      (3, 400, 400)  |        118         |      55.6    
6 threads: ----------------------------------------------
      (3, 400, 400)  |        118         |      55.6    

Times are in microseconds (us).


List mean/std:
[------------- Normalize cpu torch.float32 -------------]
                     |  normalize stable  |  normalize v2
1 threads: ----------------------------------------------
      (3, 400, 400)  |        378         |      271     
6 threads: ----------------------------------------------
      (3, 400, 400)  |        513         |      360     

Times are in microseconds (us).

[------------- Normalize cuda torch.float32 ------------]
                     |  normalize stable  |  normalize v2
1 threads: ----------------------------------------------
      (3, 400, 400)  |        116         |      61.6    
6 threads: ----------------------------------------------
      (3, 400, 400)  |        116         |      61.6    

Times are in microseconds (us).

Modified benchmark script from here

cc @vfdev-5 @bjuncek @pmeier

Copy link
Contributor

@pmeier pmeier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked at this earlier and saw no possible optimizations. It seems you have better eyes 😛

LGTM, if CI is green.

f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {image.size()}"
)

if (isinstance(std, (tuple, list)) and not all(std)) or std == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the first part of the check? What input would fail isinstance(std, (tuple, list))? Do we actually allow scalars here? Otherwise, this should be sufficient

Suggested change
if (isinstance(std, (tuple, list)) and not all(std)) or std == 0:
if not all(std):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually allow scalars. It's not visible due to the JIT-script types but if you pass mean=0.5, std=0.5 it works. So I'm keeping this for BC and provide separate benchmarks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh 🙄 We need to update the tests since they currently don't check scalars:

_NORMALIZE_MEANS_STDS = [
((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
]
def sample_inputs_normalize_image_tensor():
for image_loader, (mean, std) in itertools.product(
make_image_loaders(sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]),
_NORMALIZE_MEANS_STDS,
):
yield ArgsKwargs(image_loader, mean=mean, std=std)

Will send a PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I also had to rewrite the check because JIT couldn't understand the assertions were correct in one line... This version seems to pass. I've updated the benchmarks and we are still good.

Comment on lines +33 to +36
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was also looking into this earlier and one thing I asked myself, is when would this branch not trigger? The tensor should always have one dimensions unless we allow scalars. See above for that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is purely for broadcasting in case someone passes lists, not scalars. Aka [0.5, 0.5, 0.5]. This is needed else, the following div/sub fails.

@datumbox datumbox merged commit 788ad12 into pytorch:main Oct 24, 2022
@datumbox datumbox deleted the prototype/normalize branch October 24, 2022 14:01
facebook-github-bot pushed a commit that referenced this pull request Oct 27, 2022
Summary:
* Avoid GPU-CPU sync on Normalize

* Further optimizations.

* Apply code review changes.

* Fixing JIT.

* linter fix

Reviewed By: YosuaMichael

Differential Revision: D40722904

fbshipit-source-id: e452d89a42b34be852e3125d25756b3f598e50f4
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants