From 30ad4b721b19b964939a6e610bd32160fbd6ec30 Mon Sep 17 00:00:00 2001
From: Jianyu Huang
Date: Wed, 20 Jul 2022 10:48:03 -0700
Subject: [PATCH] refactor grad output non-contiguous handler

Summary:
This is a follow-up on D37951520 (https://github.com/pytorch/FBGEMM/commit/5a15342f874adb7dad908dfcf318c02f618778a8)

- Minor clean-up and refactoring for non-contiguous grad output.
- Add more comments.
- Add unit test coverage.

TODO: add the 16 alignment unit test coverage.

Differential Revision: D37988742

fbshipit-source-id: e0f6a565ec22fcf9a1135053d708dca1b33688db
---
 .../codegen/embedding_backward_dense_host.cpp      | 16 ++++++++++------
 .../embedding_backward_split_host_template.cpp     |  8 ++++----
 .../test/split_table_batched_embeddings_test.py    | 12 ++++++------
 3 files changed, 20 insertions(+), 16 deletions(-)

diff --git a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
index 1cb28e77d4..7117db7509 100644
--- a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
+++ b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
@@ -174,8 +174,12 @@ class SplitLookupFunction_Dense_Op
     using torch::autograd::Variable;

     auto grad_output = grad_outputs[0];
-    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
-        grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
+
+    // FIXME: to support aligned memory access in Vec4T load/store function
+    // 16 for FP32 and 8 for FP16
+    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
+      grad_output = at::empty_like(grad_output).copy_(grad_output);
+    } else if (!grad_output.is_contiguous()) {
       grad_output = grad_output.contiguous();
     }

@@ -324,12 +328,12 @@ class SplitNoBagLookupFunction_Dense_Op
     using torch::autograd::Variable;

     auto grad_output = grad_outputs[0];
-    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
-        grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
-      grad_output = grad_output.contiguous();
-    }
+    // FIXME: to support aligned memory access in Vec4T load/store function
+    // 16 for FP32 and 8 for FP16
     if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
       grad_output = at::empty_like(grad_output).copy_(grad_output);
+    } else if (!grad_output.is_contiguous()) {
+      grad_output = grad_output.contiguous();
     }

     auto grad_dev_weights =
diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
index 44867e4612..70fcd1f547 100644
--- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
+++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
@@ -274,12 +274,12 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op :
     using torch::autograd::Variable;
     auto grad_output = gradient_clipping ?
         clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0];
-    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
-        grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
-      grad_output = grad_output.contiguous();
-    }
+    // FIXME: to support aligned memory access in Vec4T load/store function
+    // 16 for FP32 and 8 for FP16
     if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
       grad_output = at::empty_like(grad_output).copy_(grad_output);
+    } else if (!grad_output.is_contiguous()) {
+      grad_output = grad_output.contiguous();
     }

     {% if not nobag %}
diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
index d153a9170d..f6bde3d7cf 100644
--- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py
+++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -1324,9 +1324,9 @@ def test_backward_dense(
             rtol=5.0e-3 if weights_precision == SparseType.FP16 else 1.0e-5,
         )
         if do_pooling:
-            goc = torch.cat([go.view(B, -1) for go in gos], dim=1).contiguous()
+            goc = torch.cat([go.view(B, -1) for go in gos], dim=1)
         else:
-            goc = torch.cat(gos, dim=0).contiguous()
+            goc = torch.cat(gos, dim=0)
         fc2.backward(goc)
         torch.testing.assert_close(
             cc.weights.grad,
@@ -1584,9 +1584,9 @@ def test_backward_sgd(  # noqa C901
             else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu))
         )
         if do_pooling:
-            goc = torch.cat([go.view(B, -1) for go in gos], dim=1).contiguous()
+            goc = torch.cat([go.view(B, -1) for go in gos], dim=1)
         else:
-            goc = torch.cat(gos, dim=0).contiguous()
+            goc = torch.cat(gos, dim=0)
         fc2.backward(goc)
         if use_cache:
             cc.flush()
@@ -1817,7 +1817,7 @@ def execute_backward_adagrad_(  # noqa C901
         if do_pooling:
             goc = torch.cat([go.view(B, -1) for go in gos], dim=1)
         else:
-            goc = torch.cat(gos, dim=0).contiguous()
+            goc = torch.cat(gos, dim=0)
         fc2.backward(goc)
         cc.flush()
         split_optimizer_states = [s for (s,) in cc.split_optimizer_states()]
@@ -2637,7 +2637,7 @@ def execute_backward_optimizers_(  # noqa C901
         if do_pooling:
             goc = torch.cat([go.view(B, -1) for go in gos], dim=1)
         else:
-            goc = torch.cat(gos, dim=0).contiguous()
+            goc = torch.cat(gos, dim=0)
         fc2.backward(goc)
         cc.flush()
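
Note on the refactored handling: the alignment check runs before the contiguity
check because .contiguous() returns the tensor itself when it is already
contiguous, so it cannot repair a misaligned base pointer; only a copy into a
freshly allocated tensor (at::empty_like().copy_()) can. Below is a minimal
standalone sketch of the same decision logic in PyTorch, for illustration only;
the helper name fix_grad_output is hypothetical and not part of this patch.

    import torch

    def fix_grad_output(grad_output: torch.Tensor) -> torch.Tensor:
        # Misaligned base pointer: copy into a freshly allocated tensor,
        # since .contiguous() would be a no-op on a contiguous-but-misaligned
        # input. The alignment target is 16 bytes (per the FIXME comment:
        # 16 for FP32, 8 for FP16).
        if grad_output.data_ptr() % 16 != 0:
            return torch.empty_like(grad_output).copy_(grad_output)
        # Aligned but non-contiguous: a plain .contiguous() suffices.
        elif not grad_output.is_contiguous():
            return grad_output.contiguous()
        return grad_output

    # Example: a slice is typically both non-contiguous and offset from the
    # 16-byte-aligned storage base (by 4 bytes here, assuming FP32), so it
    # takes the copy path before reaching the backward kernel.
    g = torch.randn(8, 9)[:, 1:]
    g = fix_grad_output(g)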