llama models fail to torch.compile - errs out re: 2 compatible backends for target (cuda) ([<class 'c.CUDABackend'>, <class 'nvi.CUDABackend'>]). There should only be one. #55

@lessw2020

Description

With the latest nightly I am unable to compile the llama models (any size) via torch.compile:

[rank0]:2024-02-12 10:00:40,026 - root - INFO - Compiling model llama with torch.compile...
[rank0]:[rank0]:[2024-02-12 10:00:42,377] [0/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:NCCL version 2.19.3+cuda12.3
[rank0]:/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/lowering.py:1697: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]:   File "/data/users/less/local/torchtrain/train.py", line 288, in <module>
[rank0]:[rank0]:     main(args)
[rank0]:[rank0]:   File "/data/users/less/local/torchtrain/train.py", line 161, in main
[rank0]:[rank0]:     pred = model(input_ids)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:[rank0]:     return forward_call(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 454, in _fn
[rank0]:[rank0]:     return fn(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 25, in inner
[rank0]:[rank0]:     return fn(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:[rank0]:     return forward_call(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 853, in forward
[rank0]:[rank0]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:[rank0]:     return forward_call(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/less/local/torchtrain/torchtrain/models/llama/model.py", line 485, in forward
[rank0]:[rank0]:     h, freqs_cis = self.embeddings(tokens)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:[rank0]:     return forward_call(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 904, in catch_errors
[rank0]:[rank0]:     return callback(frame, cache_entry, hooks, frame_state, skip=1)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 769, in _convert_frame
[rank0]:[rank0]:     result = inner_convert(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 398, in _convert_frame_assert
[rank0]:[rank0]:     return _compile(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:[rank0]:     return func(*args, **kwds)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 669, in _compile
[rank0]:[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 250, in time_wrapper
[rank0]:[rank0]:     r = func(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 542, in compile_inner
[rank0]:[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
[rank0]:[rank0]:     transformations(instructions, code_options)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 163, in _fn
[rank0]:[rank0]:     return fn(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 507, in transform
[rank0]:[rank0]:     tracer.run()
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2130, in run
[rank0]:[rank0]:     super().run()
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]:     and self.step()
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]:     getattr(self, inst.opname)(inst)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1302, in STORE_ATTR
[rank0]:[rank0]:     return self.store_attr_graph_break(inst)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1326, in store_attr_graph_break
[rank0]:[rank0]:     self.output.compile_subgraph(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 957, in compile_subgraph
[rank0]:[rank0]:     self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:[rank0]:     return func(*args, **kwds)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1102, in compile_and_call_fx_graph
[rank0]:[rank0]:     compiled_fn = self.call_user_compiler(gm)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 250, in time_wrapper
[rank0]:[rank0]:     r = func(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1175, in call_user_compiler
[rank0]:[rank0]:     raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1156, in call_user_compiler
[rank0]:[rank0]:     compiled_fn = compiler_fn(gm, self.example_inputs())
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
[rank0]:[rank0]:     compiled_gm = compiler_fn(gm, example_inputs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/__init__.py", line 1730, in __call__
[rank0]:[rank0]:     return compile_fx(model_, inputs_, config_patches=self.config)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:[rank0]:     return func(*args, **kwds)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1321, in compile_fx
[rank0]:[rank0]:     return aot_autograd(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 57, in compiler_fn
[rank0]:[rank0]:     cg = aot_module_simplified(gm, example_inputs, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 879, in aot_module_simplified
[rank0]:[rank0]:     compiled_fn = create_aot_dispatcher_function(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 250, in time_wrapper
[rank0]:[rank0]:     r = func(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 604, in create_aot_dispatcher_function
[rank0]:[rank0]:     compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 434, in aot_wrapper_dedupe
[rank0]:[rank0]:     return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 639, in aot_wrapper_synthetic_base
[rank0]:[rank0]:     return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 297, in aot_dispatch_autograd
[rank0]:[rank0]:     compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 250, in time_wrapper
[rank0]:[rank0]:     r = func(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1249, in fw_compiler_base
[rank0]:[rank0]:     return inner_compile(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
[rank0]:[rank0]:     inner_compiled_fn = compiler_fn(gm, example_inputs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/debug.py", line 304, in inner
[rank0]:[rank0]:     return fn(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:[rank0]:     return func(*args, **kwds)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:[rank0]:     return func(*args, **kwds)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 423, in compile_fx_inner
[rank0]:[rank0]:     compiled_graph = fx_codegen_and_compile(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 689, in fx_codegen_and_compile
[rank0]:[rank0]:     compiled_fn = graph.compile_to_fn()
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1224, in compile_to_fn
[rank0]:[rank0]:     return self.compile_to_module().call
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 250, in time_wrapper
[rank0]:[rank0]:     r = func(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1176, in compile_to_module
[rank0]:[rank0]:     mod = PyCodeCache.load_by_key_path(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2053, in load_by_key_path
[rank0]:[rank0]:     exec(code, mod.__dict__, mod.__dict__)
[rank0]:[rank0]:   File "/tmp/torchinductor_less/rf/crfqzglf5slafezlc46mvfhf7rn5xkcdjgxfhjfyek4ffjh5mbpo.py", line 72, in <module>
[rank0]:[rank0]:     async_compile.wait(globals())
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2593, in wait
[rank0]:[rank0]:     scope[key] = result.result()
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2400, in result
[rank0]:[rank0]:     self.future.result()
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/concurrent/futures/_base.py", line 445, in result
[rank0]:[rank0]:     return self.__get_result()
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/triton/lib/python3.10/concurrent/futures/_base.py", line 390, in __get_result
[rank0]:[rank0]:     raise self._exception
[rank0]:[rank0]: torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[rank0]:[rank0]: RuntimeError: 2 compatible backends for target (cuda) ([<class 'c.CUDABackend'>, <class 'nvi.CUDABackend'>]). There should only be one.
[rank0]:
[rank0]:[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[rank0]:
[rank0]:
[rank0]:[rank0]: You can suppress this exception and fall back to eager by setting:
[rank0]:[rank0]:     import torch._dynamo
[rank0]:[rank0]:     torch._dynamo.config.suppress_errors = True
[rank0]:
[2024-02-12 10:00:55,905] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1864627 closing signal SIGTERM
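The "2 compatible backends for target (cuda)" failure comes from Triton's backend discovery finding two `CUDABackend` classes, which suggests two Triton distributions are installed side by side (for example, a standalone `triton` alongside the `pytorch-triton` nightly — an assumption, not confirmed by the log). A stdlib-only diagnostic sketch to check whether more than one installed distribution ships a top-level `triton` package:

```python
# Hypothetical diagnostic sketch (not from the issue): list which installed
# distributions provide a given top-level package, to spot a duplicate
# Triton install that could register two CUDABackend classes.
from importlib import metadata


def find_providers(top_level: str) -> list[str]:
    """Return names of installed distributions that ship `top_level`."""
    providers = []
    for dist in metadata.distributions():
        files = dist.files or []  # files metadata may be missing
        if any(f.parts and f.parts[0] == top_level for f in files):
            providers.append(dist.metadata.get("Name", "unknown"))
    return providers


if __name__ == "__main__":
    providers = find_providers("triton")
    if len(providers) > 1:
        print(f"Multiple distributions provide 'triton': {providers}")
        print("Uninstalling one (e.g. `pip uninstall triton`) may resolve "
              "the duplicate-backend error.")
    else:
        print(f"'triton' providers: {providers or 'none found'}")
```

If a duplicate shows up, removing one copy and re-running the compile is a reasonable first check before falling back to `torch._dynamo.config.suppress_errors = True`, which only papers over the failure by running eager.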
