
Support DistributedDataParallel and DataParallel, and publish Python package #30

@yoshitomo-matsubara

First of all, thank you for the great package!

1. Support DistributedDataParallel and DataParallel

I'm working on large-scale experiments that take a long time to train, and I'm wondering whether this framework can support DataParallel and DistributedDataParallel.

The current examples/train.py appears to support DataParallel through CustomDataParallel, but running it returned the following error:

Traceback (most recent call last):
  File "examples/train.py", line 369, in <module>
    main(sys.argv[1:])
  File "examples/train.py", line 348, in main
    args.clip_max_norm,
  File "examples/train.py", line 159, in train_one_epoch
    out_net = model(d)
  File "/home/yoshitom/.local/share/virtualenvs/yoshitom-lJAkl1qx/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/yoshitom/.local/share/virtualenvs/yoshitom-lJAkl1qx/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 160, in forward
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
  File "/home/yoshitom/.local/share/virtualenvs/yoshitom-lJAkl1qx/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in replicate
    return replicate(module, device_ids, not torch.is_grad_enabled())
  File "/home/yoshitom/.local/share/virtualenvs/yoshitom-lJAkl1qx/lib/python3.6/site-packages/torch/nn/parallel/replicate.py", line 140, in replicate
    param_idx = param_indices[param]
KeyError: Parameter containing:
tensor([[[-10.,   0.,  10.]],

        [[-10.,   0.,  10.]],

        ...

        [[-10.,   0.,  10.]]], device='cuda:0', requires_grad=True)
(128 identical rows of [[-10., 0., 10.]] in total; the repeated rows are elided here.)

(Command: pipenv run python examples/train.py --data ./dataset/ --batch-size 4 --cuda, run on a machine with 3 GPUs.)

After commenting out these two lines (https://github.com/InterDigitalInc/CompressAI/blob/master/examples/train.py#L333-L334), training appears to work well:

/home/yoshitom/.local/share/virtualenvs/yoshitom-lJAkl1qx/lib/python3.6/site-packages/torch/nn/modules/container.py:435: UserWarning: Setting attributes on ParameterList is not supported.
  warnings.warn("Setting attributes on ParameterList is not supported.")
Train epoch 0: [0/5000 (0%)]	Loss: 183.278 |	MSE loss: 0.278 |	Bpp loss: 2.70 |	Aux loss: 5276.71
Train epoch 0: [40/5000 (1%)]	Loss: 65.175 |	MSE loss: 0.096 |	Bpp loss: 2.70 |	Aux loss: 5273.95
Train epoch 0: [80/5000 (2%)]	Loss: 35.178 |	MSE loss: 0.050 |	Bpp loss: 2.69 |	Aux loss: 5271.21
Train epoch 0: [120/5000 (2%)]	Loss: 36.634 |	MSE loss: 0.052 |	Bpp loss: 2.68 |	Aux loss: 5268.45
Train epoch 0: [160/5000 (3%)]	Loss: 26.010 |	MSE loss: 0.036 |	Bpp loss: 2.68 |	Aux loss: 5265.67
...
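The KeyError is raised at `param_indices[param]` in `torch/nn/parallel/replicate.py`. In the PyTorch version shown in the traceback, that function indexes every tensor returned by `network.parameters()` and then looks up each submodule's registered parameters in that index, so the lookup fails whenever a registered parameter is missing from `parameters()`. The all-`[-10., 0., 10.]` tensor looks like an entropy bottleneck's `quantiles`. One plausible cause, which is an assumption here and not confirmed, is that the model's `parameters()` skips such auxiliary parameters. The pure-Python sketch below mimics that lookup with stand-in classes; `Param`, `Module`, `FilteringModel` and `replicate_lookup` are hypothetical names, not PyTorch or CompressAI APIs:

```python
class Param:
    """Stand-in for a torch Parameter: hashed by identity, like tensors."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"Param({self.name!r})"


class Module:
    """Tiny stand-in for nn.Module: own parameters plus child modules."""

    def __init__(self, params=(), children=()):
        self._parameters = {p.name: p for p in params}
        self._children = list(children)

    def modules(self):
        # Depth-first walk over this module and all of its submodules.
        yield self
        for child in self._children:
            yield from child.modules()

    def parameters(self):
        # Default behaviour: every registered parameter in every submodule.
        for module in self.modules():
            yield from module._parameters.values()


class FilteringModel(Module):
    """Hypothetical model whose parameters() hides auxiliary 'quantiles'."""

    def parameters(self):
        for p in super().parameters():
            if not p.name.endswith("quantiles"):
                yield p


def replicate_lookup(network):
    """Mimic the indexing step that fails in torch/nn/parallel/replicate.py."""
    param_indices = {p: i for i, p in enumerate(network.parameters())}
    for module in network.modules():
        for param in module._parameters.values():
            # Raises KeyError if network.parameters() skipped this parameter.
            _ = param_indices[param]
    return param_indices


bottleneck = Module([Param("quantiles"), Param("_matrix0")])
plain = Module([Param("conv.weight")], [bottleneck])
print(len(replicate_lookup(plain)))  # 3 parameters indexed, no error

filtered = FilteringModel([Param("conv.weight")], [bottleneck])
try:
    replicate_lookup(filtered)
except KeyError as err:
    print("KeyError:", err.args[0])  # KeyError: Param('quantiles')
```

This would also be consistent with training working once the DataParallel wrapper is removed, since no replication (and hence no lookup) happens then.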

Could you please fix the issue and also support DistributedDataParallel?
If you need more examples to identify the components causing this issue, let me know: I have a few more error messages for both DataParallel and DistributedDataParallel with different network architectures (containing CompressionModel).
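If the cause is indeed a filtered `parameters()`, one possible workaround (a sketch under that assumption, not the project's confirmed fix) is to leave `parameters()` untouched and separate main and auxiliary parameters by name only when building the optimizers, so that DataParallel and DistributedDataParallel see every registered parameter. The helper `split_main_aux`, the `.quantiles` suffix, `local_rank` and the learning rates are hypothetical choices for illustration:

```python
def split_main_aux(named_parameters, aux_suffix=".quantiles"):
    """Partition (name, parameter) pairs into main and auxiliary lists.

    Assumption (hypothetical for this sketch): auxiliary parameters are the
    ones whose name ends with `aux_suffix`; everything else is "main".
    """
    main, aux = [], []
    for name, param in named_parameters:
        if name.endswith(aux_suffix):
            aux.append(param)
        else:
            main.append(param)
    return main, aux


# Example usage with PyTorch (not executed here; values are placeholders):
#   main, aux = split_main_aux(net.named_parameters())
#   optimizer = torch.optim.Adam(main, lr=1e-4)
#   aux_optimizer = torch.optim.Adam(aux, lr=1e-3)
#   net = torch.nn.parallel.DistributedDataParallel(
#       net,
#       device_ids=[local_rank],  # local_rank: hypothetical per-process GPU index
#       # May be needed if some parameters get gradients only from a separate loss.
#       find_unused_parameters=True,
#   )
```

Because the model itself still reports all of its parameters, the wrappers can index them normally, while the two optimizers keep their separate parameter groups.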

2. Publish Python package

It would be much more useful if you could publish this framework as a Python package so that we can install it with pip install compressai.

Thank you!