Bug description
Summary
- I am importing and using the torchtitan Llama4 MoE module to write tests for the torchao model conversion API for float8 MoE training (see the test file in the float8 MoE training conversion API prototype, ao#2275).
- For the non-converted, bfloat16 MoE module I am getting an error in the backward pass (see below).
- When I run `out = torch._grouped_mm(...); out.sum().backward()` in isolation, I don't get the error, so I think it must have something to do with the surrounding routing, shuffle, etc. implementations?
torch._grouped_mm backward working in isolation in PDB:
```
(Pdb) test_out = torch._grouped_mm(x.reshape(-1, x.shape[-1]), model.experts.w1, offs=torch.tensor([4096, 8192], dtype=torch.int32, device="cuda"))
(Pdb) test_out.sum().backward()
(Pdb)
```
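For reference, a self-contained version of that isolated check (a minimal sketch; `num_experts`, `dim`, `hidden_dim`, and `total_tokens` are assumptions inferred from the offsets above and the sizes in the error below, not the actual torchtitan Llama4 config), run on hardware where `torch._grouped_mm` is supported:

```python
import torch

# Standalone version of the PDB check above. The shape values here are
# hypothetical placeholders, not the real Llama4 MoE config values.
num_experts, dim, hidden_dim = 2, 4096, 8192
total_tokens = 8192  # last entry of offs

x = torch.randn(total_tokens, dim, device="cuda", dtype=torch.bfloat16,
                requires_grad=True)
w1 = torch.randn(num_experts, dim, hidden_dim, device="cuda",
                 dtype=torch.bfloat16, requires_grad=True)
offs = torch.tensor([4096, 8192], dtype=torch.int32, device="cuda")

out = torch._grouped_mm(x, w1, offs=offs)  # (total_tokens, hidden_dim)
out.sum().backward()                       # no error when run like this
print(x.grad.shape, w1.grad.shape)
```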
Error:
```
...
    # backward pass
>   ref_out.sum().backward()

test/prototype/scaled_grouped_mm/test_moe_training_conversion.py:68:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../pytorch/torch/_tensor.py:648: in backward
    torch.autograd.backward(
../pytorch/torch/autograd/__init__.py:354: in backward
    _engine_run_backward(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

t_outputs = (tensor(-56360960., device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SumBackward0>),)
args = ((tensor(1., device='cuda:0', dtype=torch.bfloat16),), False, False, ())
kwargs = {'accumulate_grad': True, 'allow_unreachable': True}
attach_logging_hooks = False

    def _engine_run_backward(
        t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
        *args: Any,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, ...]:
        attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
        if attach_logging_hooks:
            unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
        try:
>           return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
                t_outputs, *args, **kwargs
            )  # Calls into the C++ engine to run the backward pass
E           RuntimeError: Tensor should have a contiguous dimension and not be self-overlapping, got [0, 0, 0] for strides and [1, 8192, 4096] for sizes

../pytorch/torch/autograd/graph.py:829: RuntimeError
```
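For what it's worth, the `[0, 0, 0]` strides with sizes `[1, 8192, 4096]` look like an expanded (broadcast) gradient: `sum().backward()` seeds autograd with a scalar grad that gets expanded to the input's shape, and `expand()` produces stride-0 views. A minimal sketch of how such a tensor arises (just a guess at the mechanism, not a confirmed root cause):

```python
import torch

# expand() creates a view with stride 0 along every broadcast dim, which
# matches the strides/sizes reported in the RuntimeError above.
g = torch.ones((), device="cuda", dtype=torch.bfloat16).expand(1, 8192, 4096)
print(list(g.stride()), list(g.shape))  # [0, 0, 0] [1, 8192, 4096]

# Materializing the view yields a tensor with a contiguous dimension,
# which would pass the check in the error message.
print(g.contiguous().stride())  # (33554432, 4096, 1)
```

If that's right, the surrounding routing/shuffle code may be handing grouped_mm's backward a grad that was never made contiguous.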
Versions
- torch (built from source): 2.8.0a0+git48807d5
- torchtitan: local install with latest changes pulled
- torchao (local install): 0.11.0+gitbf8ab460