16 changes: 10 additions & 6 deletions fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
@@ -174,8 +174,12 @@ class SplitLookupFunction_Dense_Op
     using torch::autograd::Variable;
 
     auto grad_output = grad_outputs[0];
-    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
-        grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
+
+    // FIXME: to support aligned memory access in Vec4T load/store function
+    // 16 for FP32 and 8 for FP16
+    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
+      grad_output = at::empty_like(grad_output).copy_(grad_output);
+    } else if (!grad_output.is_contiguous()) {
       grad_output = grad_output.contiguous();
     }
 
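For context on why the old `.contiguous()` call alone was not enough: a tensor can be contiguous yet start at a pointer that is not 16-byte aligned (for example, a slice that skips the first element of a larger allocation), and `.contiguous()` is a no-op on such a tensor. The new `at::empty_like(...).copy_(...)` path forces a fresh allocation, which the default allocators return well aligned. A minimal PyTorch sketch (not part of this PR) illustrating the distinction, assuming the base allocation itself is at least 16-byte aligned, as the default CPU/CUDA allocators guarantee:

import torch

base = torch.empty(17, dtype=torch.float32)   # fresh allocation, comfortably aligned
view = base[1:]                               # contiguous view, data_ptr offset by 4 bytes

print(view.is_contiguous())                   # True
print(view.data_ptr() % 16)                   # 4 -> unusable for 16-byte vector loads

same = view.contiguous()                      # no-op: tensor is already contiguous
print(same.data_ptr() % 16)                   # still 4, same storage as `view`

fresh = torch.empty_like(view).copy_(view)    # new allocation, values copied over
print(fresh.data_ptr() % 16)                  # 0 on the default allocators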
@@ -324,12 +328,12 @@ class SplitNoBagLookupFunction_Dense_Op
     using torch::autograd::Variable;
 
     auto grad_output = grad_outputs[0];
-    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
-        grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
-      grad_output = grad_output.contiguous();
-    }
+    // FIXME: to support aligned memory access in Vec4T load/store function
+    // 16 for FP32 and 8 for FP16
+    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
+      grad_output = at::empty_like(grad_output).copy_(grad_output);
+    } else if (!grad_output.is_contiguous()) {
+      grad_output = grad_output.contiguous();
+    }
 
     auto grad_dev_weights =
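The new `else if (!grad_output.is_contiguous())` branch takes over from the old stride checks (`stride(1) != 1 || stride(0) % 4 != 0`): an aligned but non-contiguous gradient, such as one coming out of a transpose or a narrow, still needs a compacting copy before the vectorized kernels can walk it linearly. A small sketch (again outside this PR) of a gradient that would take that branch:

import torch

# A gradient laid out column-major: aligned storage, but not contiguous.
g = torch.randn(64, 32).t()       # shape (32, 64), strides (1, 32)

print(g.data_ptr() % 16)          # typically 0: the allocation itself is aligned
print(g.is_contiguous())          # False -> would take the `.contiguous()` branch

gc = g.contiguous()               # compacting copy, strides become (64, 1)
print(gc.is_contiguous())         # True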
8 changes: 4 additions & 4 deletions fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
@@ -274,12 +274,12 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op :
     using torch::autograd::Variable;
 
     auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0];
-    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
-        grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
-      grad_output = grad_output.contiguous();
-    }
+    // FIXME: to support aligned memory access in Vec4T load/store function
+    // 16 for FP32 and 8 for FP16
+    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
+      grad_output = at::empty_like(grad_output).copy_(grad_output);
+    } else if (!grad_output.is_contiguous()) {
+      grad_output = grad_output.contiguous();
+    }
 
     {% if not nobag %}
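The FIXME comment flags that the hard-coded `% 16` is conservative: Vec4T moves four elements per load, so the real requirement is 4 x element size, i.e. 16 bytes for FP32 and 8 bytes for FP16. A hedged sketch of what a dtype-aware check might look like, purely illustrative and not what the PR implements:

import torch

def needs_realignment(grad: torch.Tensor) -> bool:
    # Vec4T reads 4 elements per load: 16 bytes for FP32, 8 bytes for FP16.
    required = 4 * grad.element_size()
    return grad.data_ptr() % required != 0

g32 = torch.empty(128, dtype=torch.float32)[1:]   # offset by 4 bytes
g16 = torch.empty(128, dtype=torch.float16)[4:]   # offset by 8 bytes

print(needs_realignment(g32))   # True  (4 % 16 != 0)
print(needs_realignment(g16))   # False (8 % 8 == 0), though the blanket % 16 check would still copy it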
12 changes: 6 additions & 6 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -1324,9 +1324,9 @@ def test_backward_dense(
             rtol=5.0e-3 if weights_precision == SparseType.FP16 else 1.0e-5,
         )
         if do_pooling:
-            goc = torch.cat([go.view(B, -1) for go in gos], dim=1).contiguous()
+            goc = torch.cat([go.view(B, -1) for go in gos], dim=1)
         else:
-            goc = torch.cat(gos, dim=0).contiguous()
+            goc = torch.cat(gos, dim=0)
         fc2.backward(goc)
         torch.testing.assert_close(
             cc.weights.grad,
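Dropping the trailing `.contiguous()` in these tests is a no-op for the tensors involved: `torch.cat` always materializes its output into a fresh, contiguous buffer, so the extra call never copied anything. A quick check of that property (outside the PR, with stand-in shapes):

import torch

gos = [torch.randn(4, 8) for _ in range(3)]            # stand-ins for the per-table grads
goc = torch.cat([go.view(4, -1) for go in gos], dim=1)

print(goc.is_contiguous())                             # True: cat allocates a fresh buffer
print(goc.data_ptr() == goc.contiguous().data_ptr())   # True: .contiguous() returns self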
@@ -1584,9 +1584,9 @@ def test_backward_sgd( # noqa C901
             else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu))
         )
         if do_pooling:
-            goc = torch.cat([go.view(B, -1) for go in gos], dim=1).contiguous()
+            goc = torch.cat([go.view(B, -1) for go in gos], dim=1)
         else:
-            goc = torch.cat(gos, dim=0).contiguous()
+            goc = torch.cat(gos, dim=0)
         fc2.backward(goc)
         if use_cache:
             cc.flush()
@@ -1817,7 +1817,7 @@ def execute_backward_adagrad_( # noqa C901
         if do_pooling:
             goc = torch.cat([go.view(B, -1) for go in gos], dim=1)
         else:
-            goc = torch.cat(gos, dim=0).contiguous()
+            goc = torch.cat(gos, dim=0)
         fc2.backward(goc)
         cc.flush()
         split_optimizer_states = [s for (s,) in cc.split_optimizer_states()]
@@ -2637,7 +2637,7 @@ def execute_backward_optimizers_( # noqa C901
         if do_pooling:
            goc = torch.cat([go.view(B, -1) for go in gos], dim=1)
         else:
-            goc = torch.cat(gos, dim=0).contiguous()
+            goc = torch.cat(gos, dim=0)
         fc2.backward(goc)
         cc.flush()

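If a follow-up test ever needs to hit the new `empty_like().copy_()` path directly, a deliberately misaligned gradient can be built by slicing one element into a padded copy. A hypothetical sketch, reusing `goc` and `fc2` from the tests above and assuming the base allocation is 16-byte aligned (true for the default CPU/CUDA allocators); this is not part of the PR:

# Hypothetical: shift the gradient by one element so its data_ptr is no longer
# 16-byte aligned, then feed it to backward.
pad = torch.zeros(1, dtype=goc.dtype, device=goc.device)
misaligned = torch.cat([pad, goc.detach().flatten()])[1:].view_as(goc)
assert misaligned.data_ptr() % 16 != 0
fc2.backward(misaligned)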