FSDP + SP does not work with --compile #61

@tianyu-l

Description

FSDP + SP works fine when compile is off, but with compile on I get the following error:

error log

SP=2 ./run_llama_train.sh
+ TRAINER_DIR=/home/lty/local/torchtrain
+ MODEL=llama
+ MODEL_CONF=debugmodel
+ NGPU=8
+ PP=1
+ SP=2
+ DP=-1
+ LOG_RANK=0
+ CHECKPOINT_FOLDER=
+ CHECKPOINT_INTERVAL=5
+ torchrun --nproc_per_node=8 --rdzv_endpoint=localhost:5972 --local-ranks-filter 0 --role rank --tee 3 train.py --steps 10 --model llama --model_conf debugmodel --pp_degree 1 --sp_degree 2 --dp_degree -1 --compile --checkpoint-folder= --checkpoint-interval=5
W0215 17:38:16.585000 140337690436736 torch/distributed/run.py:717
W0215 17:38:16.585000 140337690436736 torch/distributed/run.py:717 *****************************************
W0215 17:38:16.585000 140337690436736 torch/distributed/run.py:717 Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0215 17:38:16.585000 140337690436736 torch/distributed/run.py:717 *****************************************
[rank0]:2024-02-15 17:38:20,132 - torchtrain.parallelisms - INFO - Building 2-D device mesh with ('dp', 'sp'), [4, 2]
[rank0]:2024-02-15 17:38:28,308 - root - INFO - Building llama
[rank0]:2024-02-15 17:38:28,325 - root - INFO - Reloaded SentencePiece model from ./torchtrain/datasets/tokenizer/tokenizer.model
[rank0]:2024-02-15 17:38:28,325 - root - INFO - #words: 32000 - BOS ID: 1 - EOS ID: 2
[rank0]:2024-02-15 17:38:31,662 - root - INFO - Model fully initialized via reset_params
[rank0]:2024-02-15 17:38:31,662 - root - INFO - Model built with: ModelArgs(dim=256, n_layers=2, n_heads=16, n_kv_heads=None, vocab_size=32000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, max_batch_size=32, max_seq_len=32768)
[rank0]:2024-02-15 17:38:31,662 - root - INFO - Model llama debugmodel size: 18,089,216 total parameters
[rank0]:2024-02-15 17:38:31,663 - root - INFO - GPU memory usage: NVIDIA PG509-210 (0): 79.1537 GB capacity, 0.0 GB in-use, 0.0% in-use
[rank0]:NCCL version 2.19.3+cuda12.0
[rank0]:2024-02-15 17:38:36,274 - root - INFO - Applied Sequence Parallelism to the model...
[rank0]:2024-02-15 17:38:36,575 - root - INFO - Applied FSDP to the model...
[rank0]:2024-02-15 17:38:36,579 - root - INFO - Gradient scaling not enabled.
[rank0]:2024-02-15 17:38:36,579 - root - INFO - Metrics logging active. Tensorboard logs will be saved at ./torchtrain/outputs/tb/20240215-1738.
[rank0]:2024-02-15 17:38:36,580 - root - INFO - Compiling model llama with torch.compile...
[rank0]:2024-02-15 17:38:40,957 - root - INFO - Profiling active. Traces will be saved at ./torchtrain/outputs/profiling/traces
[rank0]:[rank0]:W0215 17:38:41.362000 139938524181632 torch/_logging/_internal.py:873 [0/0] Profiler function will be ignored
[rank0]:/home/lty/pytorch/torch/_inductor/lowering.py:1704: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]: File "/home/lty/torchtrain/train.py", line 349, in <module>
[rank0]:[rank0]: main(args)
[rank0]:[rank0]: File "/home/lty/torchtrain/train.py", line 179, in main
[rank0]:[rank0]: pred = model(input_ids)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/eval_frame.py", line 455, in _fn
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/external_utils.py", line 25, in inner
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 853, in forward
[rank0]:[rank0]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/torchtrain/torchtrain/models/llama/model.py", line 482, in forward
[rank0]:[rank0]: def forward(self, tokens: torch.Tensor):
[rank0]:[rank0]: File "/home/lty/torchtrain/torchtrain/models/llama/model.py", line 498, in torch_dynamo_resume_in_forward_at_493
[rank0]:[rank0]: h = layer(h, freqs_cis)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 853, in forward
[rank0]:[rank0]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1529, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 912, in catch_errors
[rank0]:[rank0]: return callback(frame, cache_entry, hooks, frame_state, skip=1)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 777, in _convert_frame
[rank0]:[rank0]: result = inner_convert(
[rank0]:[rank0]: ^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 398, in _convert_frame_assert
[rank0]:[rank0]: return _compile(
[rank0]:[rank0]: ^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/.conda/envs/pytorch-3.11/lib/python3.11/contextlib.py", line 81, in inner
[rank0]:[rank0]: return func(*args, **kwds)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 669, in _compile
[rank0]:[rank0]: guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 250, in time_wrapper
[rank0]:[rank0]: r = func(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 542, in compile_inner
[rank0]:[rank0]: out_code = transform_code_object(code, transform)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
[rank0]:[rank0]: transformations(instructions, code_options)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 163, in _fn
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/convert_frame.py", line 507, in transform
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2130, in run
[rank0]:[rank0]: super().run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1243, in CALL_FUNCTION_EX
[rank0]:[rank0]: self.call_function(fn, argsvars.items, kwargsvars)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 734, in call_function
[rank0]:[rank0]: return self.func.call_function(tx, merged_args, merged_kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1392, in call_function
[rank0]:[rank0]: ) = self.create_wrapped_node(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1204, in create_wrapped_node
[rank0]:[rank0]: ) = speculate_subgraph(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 396, in speculate_subgraph
[rank0]:[rank0]: output = f.call_function(tx, args, sub_kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/nn_module.py", line 716, in call_function
[rank0]:[rank0]: return variables.UserFunctionVariable(fn, source=source).call_function(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1243, in CALL_FUNCTION_EX
[rank0]:[rank0]: self.call_function(fn, argsvars.items, kwargsvars)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1785, in CALL
[rank0]:[rank0]: self.call_function(fn, args, kwargs)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/nn_module.py", line 716, in call_function
[rank0]:[rank0]: return variables.UserFunctionVariable(fn, source=source).call_function(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1785, in CALL
[rank0]:[rank0]: self.call_function(fn, args, kwargs)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/lazy.py", line 94, in realize_and_forward
[rank0]:[rank0]: return getattr(self.realize(), name)(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1785, in CALL
[rank0]:[rank0]: self.call_function(fn, args, kwargs)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 288, in call_function
[rank0]:[rank0]: return super().call_function(tx, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/functions.py", line 89, in call_function
[rank0]:[rank0]: return tx.inline_user_function_return(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in inline_user_function_return
[rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2266, in inline_call
[rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2380, in inline_call_
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: ^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[rank0]:[rank0]: return inner_fn(self, inst)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 1785, in CALL
[rank0]:[rank0]: self.call_function(fn, args, kwargs)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in call_function
[rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/misc.py", line 547, in call_function
[rank0]:[rank0]: return self.obj.call_method(tx, self.name, args, kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/tensor.py", line 388, in call_method
[rank0]:[rank0]: result = handler_method(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/tensor.py", line 730, in method_redistribute
[rank0]:[rank0]: return wrap_fx_proxy(
[rank0]:[rank0]: ^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/builder.py", line 1273, in wrap_fx_proxy
[rank0]:[rank0]: return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/builder.py", line 1358, in wrap_fx_proxy_cls
[rank0]:[rank0]: example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1683, in get_fake_value
[rank0]:[rank0]: raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1629, in get_fake_value
[rank0]:[rank0]: ret_val = wrap_fake_exception(
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1165, in wrap_fake_exception
[rank0]:[rank0]: return fn()
[rank0]:[rank0]: ^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1630, in <lambda>
[rank0]:[rank0]: lambda: run_node(tx.output, node, args, kwargs, nnmodule)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1750, in run_node
[rank0]:[rank0]: raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/utils.py", line 1729, in run_node
[rank0]:[rank0]: return node.target(*args, **kwargs)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/_dynamo/variables/tensor.py", line 723, in redistribute_fn_with_prim_types
[rank0]:[rank0]: return x.redistribute(*args_as_value, **kwargs_as_value)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/api.py", line 467, in redistribute
[rank0]:[rank0]: return Redistribute.apply(self, device_mesh, placements)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/autograd/function.py", line 572, in apply
[rank0]:[rank0]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/redistribute.py", line 263, in forward
[rank0]:[rank0]: output = redistribute_local_tensor(local_tensor, current_spec, target_spec)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/redistribute.py", line 164, in redistribute_local_tensor
[rank0]:[rank0]: transform_infos = _gen_transform_infos(current_spec, target_spec)
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/placement_types.py", line 441, in __hash__
[rank0]:[rank0]: self._hash = self._hash_impl()
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/_tensor/placement_types.py", line 424, in _hash_impl
[rank0]:[rank0]: return hash(
[rank0]:[rank0]: ^^^^^
[rank0]:[rank0]: File "/home/lty/pytorch/torch/__init__.py", line 309, in __hash__
[rank0]:[rank0]: raise TypeError("unhashable type: non-singleton SymInt")
[rank0]:[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function .redistribute_fn_with_prim_types at 0x7f45431c1b20>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), device_mesh=DeviceMesh([0, 1], mesh_dim_names=('sp',)), placements=(Shard(dim=0),)),), **{}):
[rank0]:[rank0]: unhashable type: non-singleton SymInt
[rank0]:
[rank0]:[rank0]: from user code:
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward
[rank0]:[rank0]: return self.checkpoint_fn( # type: ignore[misc]
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1538, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/home/lty/torchtrain/torchtrain/models/llama/model.py", line 413, in forward
[rank0]:[rank0]: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/nn/modules/module.py", line 1568, in _call_impl
[rank0]:[rank0]: args_result = hook(self, args)
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/tensor/parallel/style.py", line 323, in <lambda>
[rank0]:[rank0]: module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)) # type: ignore[misc, call-arg]
[rank0]:[rank0]: File "/home/lty/pytorch/torch/distributed/tensor/parallel/style.py", line 316, in _prepare_input_fn
[rank0]:[rank0]: dt_inp = dt_inp.redistribute(placements=(desired_layout,))
[rank0]:
[rank0]:[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[rank0]:
[rank0]:
[rank0]:[rank0]: You can suppress this exception and fall back to eager by setting:
[rank0]:[rank0]: import torch._dynamo
[rank0]:[rank0]: torch._dynamo.config.suppress_errors = True
[rank0]:
W0215 17:39:06.601000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321633 closing signal SIGTERM
W0215 17:39:06.602000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321634 closing signal SIGTERM
W0215 17:39:06.603000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321636 closing signal SIGTERM
W0215 17:39:06.604000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321637 closing signal SIGTERM
W0215 17:39:06.605000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321638 closing signal SIGTERM
W0215 17:39:06.606000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321639 closing signal SIGTERM
W0215 17:39:06.608000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:694 Sending process 2321641 closing signal SIGTERM
E0215 17:39:09.856000 140337690436736 torch/distributed/elastic/multiprocessing/api.py:669 failed (exitcode: 1) local_rank: 0 (pid: 2321629) of binary: /home/lty/.conda/envs/pytorch-3.11/bin/python
Traceback (most recent call last):
File "/home/lty/.conda/envs/pytorch-3.11/bin/torchrun", line 33, in <module>
sys.exit(load_entry_point('torch', 'console_scripts', 'torchrun')())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lty/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/home/lty/pytorch/torch/distributed/run.py", line 834, in main
run(args)
File "/home/lty/pytorch/torch/distributed/run.py", line 825, in run
elastic_launch(
File "/home/lty/pytorch/torch/distributed/launcher/api.py", line 137, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lty/pytorch/torch/distributed/launcher/api.py", line 271, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2024-02-15_17:39:06
host : devgpu051.cln3.facebook.com
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 2321629)
error_file:
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
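Reading the last frames of the traceback: `_gen_transform_infos(current_spec, target_spec)` goes straight into `DTensorSpec.__hash__`, then `_hash_impl`, then a `hash(...)` call that ends in `SymInt.__hash__` raising for the symbolic sequence length `s0` in `size=(s0, 256)`. The plain-Python sketch below mimics that chain so it can run without GPUs. It assumes (inferred from the frames, not checked against the PyTorch source) that `_gen_transform_infos` is memoized with `functools.lru_cache` and that the spec hash covers the tensor shape. `SymbolicDim`, `Spec`, and `plan_transform` are hypothetical stand-ins, not PyTorch APIs, and the placements are placeholder strings rather than real `Placement` objects.

```python
from functools import lru_cache


class SymbolicDim:
    """Hypothetical stand-in for a dynamic size such as s0.

    Like a non-singleton SymInt in the log above, it refuses to be hashed.
    """

    def __init__(self, name: str) -> None:
        self.name = name

    def __eq__(self, other: object) -> bool:
        return isinstance(other, SymbolicDim) and other.name == self.name

    def __hash__(self) -> int:
        raise TypeError("unhashable type: non-singleton SymInt")


class Spec:
    """Hypothetical simplified sharding spec whose hash covers placements AND shape."""

    def __init__(self, placements, shape) -> None:
        self.placements = tuple(placements)  # placeholder strings, not real Placement objects
        self.shape = tuple(shape)
        self._hash = None  # cached lazily, loosely mirroring the `self._hash = ...` frame

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Spec) and (self.placements, self.shape) == (
            other.placements,
            other.shape,
        )

    def __hash__(self) -> int:
        if self._hash is None:
            # Hashing the shape tuple hashes every dimension, including symbolic ones.
            self._hash = hash((self.placements, self.shape))
        return self._hash


@lru_cache(maxsize=None)
def plan_transform(current: Spec, target: Spec):
    """Hypothetical memoized planner: lru_cache hashes both specs before the body runs."""
    return list(zip(current.placements, target.placements))


if __name__ == "__main__":
    # Eager-style call: every size is a plain int, so the cache key can be built.
    concrete = Spec(["Shard(0)"], [4096, 256])
    print(plan_transform(concrete, Spec(["Replicate()"], [4096, 256])))
    # prints [('Shard(0)', 'Replicate()')]

    # Compiled-style call: the sequence dimension is symbolic, so building the key fails.
    symbolic = Spec(["Shard(0)"], [SymbolicDim("s0"), 256])
    try:
        plan_transform(symbolic, Spec(["Replicate()"], [SymbolicDim("s0"), 256]))
    except TypeError as exc:
        print(f"TypeError: {exc}")
        # prints TypeError: unhashable type: non-singleton SymInt
```

In this sketch the failure happens while the cache key is being built, before any redistribution logic runs. That is consistent with the report: eager mode (concrete sizes) works, while `--compile` (symbolic `s0`) fails inside the `PrepareModuleInput` pre-hook's `redistribute` call.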

Metadata

Assignees

No one assigned

Labels

bug (Something isn't working)