
Error in backward pass when importing and using Llama4 MoE layer #1265

@danielvegamyhre

Bug description

Summary

  • I am importing and using the torchtitan Llama4 MoE module to write tests for the torchao model conversion API for float8 MoE training (see the test file in the float8 MoE training conversion API prototype, ao#2275).
  • For the non-converted, bfloat16 MoE module, the backward pass raises the error shown below (a minimal repro sketch follows this list).
  • When I run out = torch._grouped_mm(...); out.sum().backward() in isolation, I don't get the error, so I suspect the problem lies in the surrounding routing/shuffle logic rather than in the grouped GEMM itself.
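
For context, here is roughly what the test does, as a minimal repro sketch. The import paths, arg class, and field values below are assumptions inferred from the PDB session and the error shapes (dim=4096, 2 experts, 8192 tokens); the actual setup is in the test file linked in ao#2275.

import torch
# Assumed import paths for the torchtitan Llama4 MoE; the real test in ao#2275
# is the source of truth for these.
from torchtitan.experiments.llama4.model.args import TransformerModelArgs
from torchtitan.experiments.llama4.model.moe import MoE

device = "cuda"
torch.manual_seed(0)

# Assumed arg names/values, chosen to match the shapes seen in the error below.
model_args = TransformerModelArgs(dim=4096, moe_enabled=True, num_experts=2)
model = MoE(model_args).to(torch.bfloat16).to(device)

bs, seq_len = 1, 8192
x = torch.randn(bs, seq_len, 4096, dtype=torch.bfloat16, device=device, requires_grad=True)

ref_out = model(x)          # forward pass succeeds
ref_out.sum().backward()    # backward pass raises the RuntimeError below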

torch._grouped_mm backward working in isolation in PDB:

(Pdb) test_out = torch._grouped_mm(x.reshape(-1, x.shape[-1]), model.experts.w1, offs=torch.tensor([4096, 8192], dtype=torch.int32, device="cuda"))
(Pdb) test_out.sum().backward()
(Pdb) 
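
The same check as a standalone snippet (shapes inferred from the offsets [4096, 8192] and x.shape[-1] = 4096; w1 is assumed to be laid out as [num_experts, dim, hidden_dim] so that each token group multiplies against its expert's weight, and hidden_dim here is just a placeholder):

import torch

device = "cuda"
num_experts, dim, hidden_dim, total_tokens = 2, 4096, 8192, 8192

# 2D activations, 3D per-expert weights, int32 cumulative group offsets.
x = torch.randn(total_tokens, dim, dtype=torch.bfloat16, device=device, requires_grad=True)
w1 = torch.randn(num_experts, dim, hidden_dim, dtype=torch.bfloat16, device=device, requires_grad=True)
offs = torch.tensor([4096, 8192], dtype=torch.int32, device=device)

# Isolated grouped GEMM: forward and backward both run without error here.
out = torch._grouped_mm(x, w1, offs=offs)
out.sum().backward()
print(x.grad.shape, w1.grad.shape)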

Error:

...
        # backward pass
>       ref_out.sum().backward()

test/prototype/scaled_grouped_mm/test_moe_training_conversion.py:68: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../pytorch/torch/_tensor.py:648: in backward
    torch.autograd.backward(
../pytorch/torch/autograd/__init__.py:354: in backward
    _engine_run_backward(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

t_outputs = (tensor(-56360960., device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SumBackward0>),)
args = ((tensor(1., device='cuda:0', dtype=torch.bfloat16),), False, False, ())
kwargs = {'accumulate_grad': True, 'allow_unreachable': True}
attach_logging_hooks = False

    def _engine_run_backward(
        t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
        *args: Any,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, ...]:
        attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
        if attach_logging_hooks:
            unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
        try:
>           return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
                t_outputs, *args, **kwargs
            )  # Calls into the C++ engine to run the backward pass
E           RuntimeError: Tensor should have a contiguous dimension and not be self-overlapping, got [0, 0, 0] for strides and [1, 8192, 4096] for sizes

../pytorch/torch/autograd/graph.py:829: RuntimeError
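
For what it's worth, strides of [0, 0, 0] with sizes [1, 8192, 4096] are characteristic of an expanded (broadcast) view rather than a materialized tensor, and the incoming gradient in the trace above is the scalar tensor(1., ...) produced by sum(). My guess is that the scalar grad is being expanded to the MoE output shape (bs=1, seq_len=8192, dim=4096) with zero strides, which the grouped_mm backward then rejects. A small illustration of how such strides arise and how .contiguous() normalizes them (this is not the actual failing tensor):

import torch

# A scalar expanded to [1, 8192, 4096] has strides (0, 0, 0): nothing is
# materialized, every element aliases the same storage location.
g = torch.tensor(1.0, dtype=torch.bfloat16, device="cuda").expand(1, 8192, 4096)
print(g.shape, g.stride())       # torch.Size([1, 8192, 4096]) (0, 0, 0)

# .contiguous() materializes the data and restores normal strides, which is
# the kind of input the backward kernel appears to expect.
print(g.contiguous().stride())   # (33554432, 4096, 1)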

Versions

  • torch (built from source): 2.8.0a0+git48807d5
  • torchtitan: local install with latest changes pulled
  • torchao (local install): 0.11.0+gitbf8ab460
