Conversation

@CaoE CaoE commented Sep 15, 2021

  • modify the directory structure: move the autocast files from torchvision/csrc/ops/autocast/ to torchvision/csrc/ops/autocast/cuda;

  • add the cpu directory under the autocast directory;

  • register deform_conv2d, nms, ps_roi_align, ps_roi_pool, roi_align, and roi_pool to AutocastCPU.

@CaoE force-pushed the autocast branch 2 times, most recently from a810d89 to 72fd957 on September 15, 2021 07:00
Member

@NicolasHug left a comment

Thanks @CaoE, I just took a brief look. For now I'm just curious: what is your use case for supporting autocast on CPU?

test/test_ops.py Outdated

@pytest.mark.parametrize('x_dtype', (torch.float, torch.half))
@pytest.mark.parametrize('rois_dtype', (torch.float, torch.half))
def test_autocast_cpu(self, x_dtype, rois_dtype):
Member

Instead of creating a new test, maybe we could just parametrize over the device with the cpu_and_gpu() function? Since the context manager is device-dependent, we could just set it in the code, like:

cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
with cm():
    self.test_forward(torch.device(device), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)

We could do the same for the rest of the newly introduced tests

Author

Yeah, good idea, I will modify it like this.


CaoE commented Sep 15, 2021

Because we want to run deep learning models on CPU servers, but torchvision ops like nms report errors when using BFloat16.

@CaoE force-pushed the autocast branch 3 times, most recently from ee6fa35 to 8367b5a on September 15, 2021 09:06
Member

@NicolasHug left a comment

Thanks @CaoE, this LGTM provided that the tests go green. There's a minor linting issue for now; would you mind fixing it?

test/test_ops.py:524:17: E117 over-indented

I will let @fmassa take a look before merging, also possibly @datumbox since you're more familiar with the src structure etc.

test/test_ops.py Outdated
self.test_nms_cuda(iou=iou, dtype=dtype)
@pytest.mark.parametrize("dtype", (torch.bfloat16, torch.half))
def test_autocast(self, device, iou, dtype):
def test_nms_cpu(iou, dtype):
Member

Instead of creating this new one here, do you think we could just rely on test_nms_ref instead?
It doesn't accept a dtype parameter so we could add it if it's relevant, or alternatively we can just define test_fn as a partial function.

No strong opinion on this, we can also leave it as is, even though it's a bit unfortunate that we have to define a new test_nms_cpu here.

Author

I created this new one just to check whether the input is converted to float type. Is your suggestion to use test_nms_ref instead of test_nms_cpu here? Either is fine with me. Thank you so much for your suggestion @NicolasHug

Member

test_nms_ref instead of test_nms_cpu here?

Yes :)

Author

Ok, I will modify it later.

Contributor

@datumbox left a comment

The C++ changes overall look good to me. I would advise testing them on FBcode and updating Buck before merging to ensure we won't break anything internally.

My only concern, and it's not linked to this PR, is the amount of duplicate and boilerplate code we are forced to add in the repo to handle the registrations.

@ezyang I know that a year ago you've been working on improving the op registration. Is there a better way to do it?

setup.py Outdated
source_cuda = glob.glob(os.path.join(extensions_dir, 'ops', 'cuda', '*.cu'))

source_cuda += glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', '*.cpp'))
source_cuda += glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', 'cuda', '*.cpp'))
Contributor

Buck will need changes due to this. It might be worth importing the PR into FBcode prior to merging to ensure it does not break anything.


fmassa commented Sep 16, 2021

If I understand it correctly, the only change between the CPU and CUDA folders is:

  • the autocast passes the CPU device
  • the registration mechanism registers it to AutocastCPU instead of Autocast

Is there any other difference, or do we expect it to potentially diverge more over time?

test/test_ops.py Outdated
with torch.cuda.amp.autocast():
self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
def test_autocast(self, device, x_dtype, rois_dtype):
cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
Contributor

Someone should file a bug asking for torch.amp.autocast that accepts a device argument lol
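Such a device-aware front end could, roughly, look up a per-backend context manager by device string. Here is a pure-Python sketch of that dispatch pattern; the registry, decorator, and yielded values are all illustrative and are not torch APIs:

```python
from contextlib import contextmanager

# Hypothetical sketch of a unified autocast(device_type=...) entry point.
# Real autocast would flip thread-local dispatcher state, not yield strings.
_AUTOCAST_BACKENDS = {}

def register_backend(device_type):
    """Record a context-manager factory for one device type (illustrative)."""
    def deco(cm_factory):
        _AUTOCAST_BACKENDS[device_type] = cm_factory
        return cm_factory
    return deco

@register_backend("cpu")
@contextmanager
def _cpu_autocast():
    yield "cpu-autocast-enabled"

@register_backend("cuda")
@contextmanager
def _cuda_autocast():
    yield "cuda-autocast-enabled"

def autocast(device_type):
    # Single front door: dispatch to the backend-specific context manager.
    if device_type not in _AUTOCAST_BACKENDS:
        raise ValueError(f"no autocast backend registered for {device_type!r}")
    return _AUTOCAST_BACKENDS[device_type]()
```

With something like this, the cm = ... ternary in the tests would collapse to a single with autocast(device): line.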


ezyang commented Sep 16, 2021

@ezyang I know that a year ago you've been working on improving the op registration. Is there a better way to do it?

I think the most direct way to reduce boilerplate is to reuse the templates/macros that are used inside PyTorch core to set up autocasting. I haven't checked whether you are doing unusual autocasting that would make these inapplicable. We don't have codegen support for autocasting, so that's right out, and no one has written a boxed fallback for autocasting (maybe someone should!)
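For intuition, the boxed-fallback idea can be modeled in a few lines of Python: one generic handler intercepts every op, upcasts low-precision arguments, and calls through, instead of a hand-written wrapper per op. This is only a toy model; tensors are simulated as (dtype, data) pairs, and the real mechanism would live in C++ behind the Autocast dispatch keys:

```python
# Toy model of a boxed autocast fallback: a single generic handler for
# all ops, instead of one wrapper file per op. A "tensor" here is just a
# (dtype, data) pair; real autocast operates on boxed at::Tensor stacks.
LOW_PRECISION = ("float16", "bfloat16")

def _to_fp32(arg):
    # Upcast simulated low-precision tensors; pass everything else through.
    if isinstance(arg, tuple) and len(arg) == 2 and arg[0] in LOW_PRECISION:
        return ("float32", arg[1])
    return arg

def boxed_autocast_fallback(op, *args):
    # One handler covers nms, roi_align, etc. -- no per-op wrapper needed.
    return op(*[_to_fp32(a) for a in args])

def fake_nms(boxes, scores, iou_threshold):
    # Stand-in op that reports the dtypes it actually received.
    return (boxes[0], scores[0], iou_threshold)
```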

@CaoE force-pushed the autocast branch 2 times, most recently from 53c7e00 to 83293bf on September 17, 2021 02:26
@datumbox
Contributor

@CaoE Are we good to merge?


CaoE commented Sep 17, 2021

@CaoE Are we good to merge?

Yes, I have nothing else to commit.

@datumbox
Contributor

Awesome, let us run some additional tests on the internal FB infra and we will merge right after.

Contributor

@datumbox left a comment

Marking this as "changes needed" to avoid accidental merges prior to testing it on FBcode.

@CaoE No further action is needed on your side, you are good to go. Just give us a bit more time to test things :)

Member

@fmassa left a comment

Marking as "request changes" so that we don't forget to answer my questions in my previous post.
While we can have duplicate code here initially, it would be good for us to understand whether we plan to keep those two sets of files practically identical, or whether further changes that could lead to divergence between the CPU / CUDA files are expected.

This is important because we can plan follow-up work to refactor this redundancy out, improving maintenance cost in the future, but only if we know that the changes will be kept to the ones I pointed out in my previous comment.


CaoE commented Oct 9, 2021

Hi @datumbox @NicolasHug @ezyang @fmassa, may I know when the PR can be merged, and whether it can make it into the torchvision release branch for the PyTorch 1.10 release? Thank you very much.

@datumbox
Contributor

@CaoE Note that the PR has conflicts with main that need to be resolved. Some of them can be addressed as described in #4539, but I think it might be easier to fix them manually.

Also could you please provide clarifications on this #4412 (comment)?


fmassa commented Oct 11, 2021

@CaoE we were waiting for your clarifications on my questions before moving forward merging this PR.

Also, the branch cut / freeze was on Friday, so it might be hard getting those changes into the 0.11 release.


CaoE commented Oct 12, 2021

@fmassa @datumbox Sorry for missing the questions and thank you for the detailed explanation.

If I understand it correctly, the only change between the CPU and CUDA folders is:

  • the autocast passes the CPU device
  • the registration mechanism registers it to AutocastCPU instead of Autocast

Yes.

Is there any other difference, or do we expect it to potentially diverge more over time?

  • There is no other difference.
  • If a new op is added to torchvision, it may also need to be registered to both AutocastCPU and Autocast. There should be no changes in other parts except for the cpu and cuda subfolders of autocast and test_ops.py.

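The duplication being discussed can be pictured with a tiny Python model of the dispatcher: the cpu/ and cuda/ files register the same "upcast, then call the real kernel" shim under two different dispatch keys. The registry below is purely illustrative; the real registrations use TORCH_LIBRARY_IMPL in C++:

```python
# Illustrative model of why autocast/cpu and autocast/cuda end up as
# near duplicates: the same wrapper is registered once per dispatch key.
REGISTRY = {}  # (dispatch_key, op_name) -> callable

def make_autocast_wrapper(real_op):
    # Upcast simulated (dtype, data) tensors to fp32, then call the kernel.
    def wrapper(*args):
        upcast = [("float32", a[1])
                  if isinstance(a, tuple) and a[0] in ("float16", "bfloat16")
                  else a
                  for a in args]
        return real_op(*upcast)
    return wrapper

def nms_kernel(boxes, scores):
    # Stand-in kernel that echoes the dtypes it received.
    return (boxes[0], scores[0])

# A new op must be registered under BOTH keys -- hence the twin files.
for key in ("Autocast", "AutocastCPU"):
    REGISTRY[(key, "torchvision::nms")] = make_autocast_wrapper(nms_kernel)
```

A boxed fallback, as mentioned earlier in the thread, would replace this per-op, per-key registration with one generic handler.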
CaoE added 2 commits October 12, 2021 13:01
6 participants