Skip to content

Cannot run FSDP2 with low bit optim from AO #1189

@nighting0le01

Description

@nighting0le01

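Cannot run FSDP2 with the low-bit optimizers from torchao (AO). Training fails when a distributed checkpoint is saved: `torch.distributed.checkpoint` calls `is_pinned()` on the local shard of the optimizer state, and the `OptimState8bit` tensor subclass does not implement `aten.is_pinned`. Full traceback: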

[rank7]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank7]:   File "<frozen runpy>", line 88, in _run_code
[rank7]:   File "/nfs/asahni/multi_parallel/oct_28/training/scripts/train.py", line 226, in <module>
[rank7]:     main()
[rank7]:   File "/nfs/asahni/multi_parallel/oct_28/training/scripts/train.py", line 218, in main
[rank7]:     trainer.fit(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
[rank7]:     call._call_and_handle_interrupt(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
[rank7]:     return trainer_fn(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
[rank7]:     self._run(model, ckpt_path=ckpt_path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
[rank7]:     results = self._run_stage()
[rank7]:               ^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
[rank7]:     self.fit_loop.run()
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
[rank7]:     self.advance()
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
[rank7]:     self.epoch_loop.run(self._data_fetcher)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
[rank7]:     self.advance(data_fetcher)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 269, in advance
[rank7]:     call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 218, in _call_callback_hooks
[rank7]:     fn(trainer, trainer.lightning_module, *args, **kwargs)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 316, in on_train_batch_end
[rank7]:     self._save_topk_checkpoint(trainer, monitor_candidates)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 387, in _save_topk_checkpoint
[rank7]:     self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 715, in _save_none_monitor_checkpoint
[rank7]:     self._save_checkpoint(trainer, filepath)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 390, in _save_checkpoint
[rank7]:     trainer.save_checkpoint(filepath, self.save_weights_only)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1365, in save_checkpoint
[rank7]:     self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/strategies/model_parallel.py", line 321, in save_checkpoint
[rank7]:     _distributed_checkpoint_save(converted_state, path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/fabric/strategies/fsdp.py", line 867, in _distributed_checkpoint_save
[rank7]:     save(converted_state, checkpoint_id=path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 429, in inner_func
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 152, in save
[rank7]:     return _save_state_dict(
[rank7]:            ^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 316, in _save_state_dict
[rank7]:     central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step)
[rank7]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 191, in reduce_scatter
[rank7]:     raise result
[rank7]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0, 1, 2, 3, 4, 5, 6, 7])
[rank7]: Traceback (most recent call last): (RANK 0)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 164, in reduce_scatter
[rank7]:     local_data = map_fun()
[rank7]:                  ^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 303, in local_step
[rank7]:     local_plan = planner.create_local_plan()
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 101, in create_local_plan
[rank7]:     plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 399, in create_default_local_save_plan
[rank7]:     requests += _create_write_items(fqn, obj)
[rank7]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/planner_helpers.py", line 222, in _create_write_items
[rank7]:     return object.__create_write_items__(fqn, object)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 598, in __create_write_items__
[rank7]:     return [_create_write_items_for_dtensor(fqn, object)]
[rank7]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/planner_helpers.py", line 86, in _create_write_items_for_dtensor
[rank7]:     properties=TensorProperties.create_from_tensor(tensor.to_local()),
[rank7]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/metadata.py", line 108, in create_from_tensor
[rank7]:     pin_memory=tensor.is_pinned(),
[rank7]:                ^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torchao/utils.py", line 377, in _dispatch__torch_function__
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torchao/utils.py", line 393, in _dispatch__torch_dispatch__
[rank7]:     raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}")
[rank7]: NotImplementedError: OptimState8bit dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.prototype.low_bit_optim.subclass_8bit.OptimState8bit'>,), arg_types=(<class 'torchao.prototype.low_bit_optim.subclass_8bit.OptimState8bit'>,), kwarg_types={}

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions