
Commit 51cc7a8

tgaddair, teddykoker, SeanNaren, and Borda authored
Horovod: fixed early stopping and added metrics aggregation (#3775)
* Fixed early stopping for Horovod
* Refactored to sync_dist_if_available
* Bump min Horovod version to support hvd.is_initialized
* Changelog
* Added back change for Horovod
* Removed redundant checks for initialization
* Implement metrics gathering for Horovod
* Added test for EvalResult
* Renamed ddp_sync_on_step -> dist_sync_on_step
* Added metric test for Horovod
* Added option to pass a callable allgather function to the metric base class
* Added dist_sync_fn
* Fixed calls to private _sync_dist
* Fixed Horovod test
* Added sync_tensor to the distributed backend
* Skip Windows
* Insert test path
* Removed redundant import
* Updated drone
* Unset HOROVOD_GPU_ALLREDUCE
* Unset
* No cache dir
* No uninstall
* Unset variables
* Uninstall Horovod during initialization
* Replaced more references to ddp_sync_on_step
* Fixed imports
* Fixed attribute
* Added back default
* Lint
* Added back docstring
* Made gather_all_tensors default
* Added whitespace
* Update tests/models/test_horovod.py
  Co-authored-by: Jirka Borovec <[email protected]>
* Update pytorch_lightning/metrics/metric.py
  Co-authored-by: Jirka Borovec <[email protected]>
* Update CHANGELOG.md

Co-authored-by: Teddy Koker <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent e81707b commit 51cc7a8

22 files changed: +324 -65 lines changed
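
The commit message above renames `ddp_sync_on_step` to `dist_sync_on_step` and adds a `dist_sync_fn` hook so a metric's state can be synchronized with a backend-specific gather (for example Horovod's allgather). A minimal sketch of how a user-defined metric might pass these options; it assumes the `Metric` base class of this release accepts both keyword arguments (as the commit message indicates), and the class and state names here are hypothetical:

    import torch
    from pytorch_lightning.metrics import Metric  # import path as of this release

    class SumAcrossWorkers(Metric):
        """Hypothetical metric: accumulates a sum and reduces it across processes."""

        def __init__(self, dist_sync_on_step=False, dist_sync_fn=None):
            # dist_sync_fn, when given, replaces the default gather
            # (e.g. a wrapper around hvd.allgather under Horovod)
            super().__init__(dist_sync_on_step=dist_sync_on_step, dist_sync_fn=dist_sync_fn)
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, values: torch.Tensor):
            self.total += values.sum()

        def compute(self):
            return self.total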

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `fsspec` to tuner ([#4458](https://github.com/PyTorchLightning/pytorch-lightning/pull/4458))
 
 
+- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
+
+
 ### Changed
 

pytorch_lightning/accelerators/accelerator.py

Lines changed: 23 additions & 1 deletion
@@ -14,7 +14,7 @@
 import os
 import math
 from enum import Enum
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import torch
 
@@ -30,6 +30,12 @@
 except ImportError:
     amp = None
 
+if torch.distributed.is_available():
+    from torch.distributed import ReduceOp
+else:
+    class ReduceOp:
+        SUM = None
+
 EPSILON = 1e-6
 EPSILON_FP16 = 1e-5
 
@@ -209,6 +215,22 @@ def init_ddp_connection(
             torch_backend, rank=global_rank, world_size=world_size
         )
 
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        """
+        Function to reduce a tensor from several distributed processes to one aggregated tensor.
+        Args:
+            tensor: the tensor to sync and reduce
+            group: the process group to gather results from. Defaults to all processes (world)
+            reduce_op: the reduction operation. Defaults to sum.
+                Can also be a string of 'avg', 'mean' to calculate the mean during reduction.
+        Return:
+            reduced value
+        """
+        raise NotImplementedError()
+
     def __getstate__(self):
         return {
             'trainer': self.trainer,
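
The base `Accelerator.sync_tensor` above only defines the contract and raises `NotImplementedError`; each backend supplies its own reduction. A minimal sketch of an override for a hypothetical single-process accelerator, where there are no peers and reduction is the identity:

    from typing import Any, Optional, Union

    import torch

    from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp

    class SingleProcessAccelerator(Accelerator):
        """Hypothetical accelerator: with no peer processes, there is nothing to reduce."""

        def sync_tensor(self,
                        tensor: Union[torch.Tensor],
                        group: Optional[Any] = None,
                        reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
            return tensor  # single process: the tensor is already the aggregate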

pytorch_lightning/accelerators/ddp_accelerator.py

Lines changed: 9 additions & 2 deletions
@@ -18,17 +18,18 @@
 import sys
 from os.path import abspath
 from time import sleep
-from typing import Optional, List
+from typing import Any, Optional, List, Union
 
 import numpy as np
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import find_free_network_port
 from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
 from torch.nn.parallel import DistributedDataParallel
@@ -298,3 +299,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)
 
         return model
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return sync_ddp_if_available(tensor, group, reduce_op)
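
This accelerator, like the other DDP-style accelerators below, delegates `sync_tensor` to `sync_ddp_if_available` from `pytorch_lightning.utilities.distributed`. A conceptual sketch of what such a helper is expected to do (not the library's exact implementation): all-reduce when a process group is initialized, otherwise return the tensor untouched.

    import torch
    import torch.distributed as torch_distrib

    def sync_ddp_if_available_sketch(tensor, group=None, reduce_op=None):
        # No-op when torch.distributed is unavailable or not initialized.
        if not (torch_distrib.is_available() and torch_distrib.is_initialized()):
            return tensor

        # Sum across processes, then divide by world size for the 'avg'/'mean' string ops.
        divide_by_world_size = isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean")
        group = group if group is not None else torch_distrib.group.WORLD
        torch_distrib.all_reduce(tensor, op=torch_distrib.ReduceOp.SUM, group=group, async_op=False)
        if divide_by_world_size:
            tensor = tensor / torch_distrib.get_world_size(group)
        return tensor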

pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py

Lines changed: 9 additions & 2 deletions
@@ -12,18 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import os
-from typing import List, Optional
+from typing import Any, List, Optional, Union
 
 import torch
 import torch.distributed as torch_distrib
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available
 from pytorch_lightning.distributed.dist import LightningDistributed
 
 
@@ -199,3 +200,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)
 
         return model
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return sync_ddp_if_available(tensor, group, reduce_op)

pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py

Lines changed: 9 additions & 3 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import os
-from typing import List, Optional
+from typing import Any, List, Optional, Union
 
 import torch
 import torch.distributed as torch_distrib
@@ -21,11 +21,11 @@
 from torch.nn.parallel import DistributedDataParallel
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
-from pytorch_lightning.utilities.distributed import find_free_network_port
+from pytorch_lightning.utilities.distributed import find_free_network_port, sync_ddp_if_available
 from pytorch_lightning.distributed.dist import LightningDistributed
 
 try:
@@ -229,3 +229,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)
 
         return model
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return sync_ddp_if_available(tensor, group, reduce_op)

pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py

Lines changed: 9 additions & 2 deletions
@@ -12,19 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import os
-from typing import List, Optional
+from typing import Any, List, Optional, Union
 
 import torch
 import torch.distributed as torch_distrib
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available
 
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -198,3 +199,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)
 
         return model
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return sync_ddp_if_available(tensor, group, reduce_op)

pytorch_lightning/accelerators/ddp_slurm_accelerator.py

Lines changed: 9 additions & 3 deletions
@@ -12,19 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import os
-from typing import List
+from typing import Any, List, Optional, Union
 
 import torch
 import torch.distributed as torch_distrib
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.utilities import AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
 from pytorch_lightning.utilities.seed import seed_everything
 
 try:
@@ -205,3 +205,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)
 
         return model
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return sync_ddp_if_available(tensor, group, reduce_op)

pytorch_lightning/accelerators/ddp_spawn_accelerator.py

Lines changed: 9 additions & 2 deletions
@@ -13,7 +13,7 @@
 # limitations under the License
 import os
 import re
-from typing import List, Optional
+from typing import Any, List, Optional, Union
 
 import torch
 import torch.multiprocessing as mp
@@ -22,11 +22,12 @@
 from torch.nn.parallel import DistributedDataParallel
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.cloud_io import atomic_save, load as pl_load
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, find_free_network_port
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available
 from pytorch_lightning.utilities.seed import seed_everything
 from pytorch_lightning.distributed.dist import LightningDistributed
 
@@ -254,3 +255,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)
 
         return model
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return sync_ddp_if_available(tensor, group, reduce_op)

pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py

Lines changed: 9 additions & 2 deletions
@@ -12,19 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import os
-from typing import List, Optional
+from typing import Any, List, Optional, Union
 
 import torch
 import torch.distributed as torch_distrib
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available
 
 
 try:
@@ -201,3 +202,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)
 
         return model
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return sync_ddp_if_available(tensor, group, reduce_op)

pytorch_lightning/accelerators/horovod_accelerator.py

Lines changed: 40 additions & 2 deletions
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import ExitStack
-from typing import Optional
+from typing import Any, Optional, Union
 
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only
 
@@ -161,3 +161,41 @@ def barrier(self, name: Optional[str] = None):
     def broadcast(self, obj, src=0):
         obj = hvd.broadcast_object(obj, src)
         return obj
+
+    def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None):
+        if group is not None:
+            raise ValueError(
+                "Horovod does not support allgather using a subcommunicator at this time. "
+                "Unset `group`."
+            )
+
+        if len(result.shape) == 0:
+            # Convert scalars to single dimension tensors
+            result = result.reshape(1)
+
+        # sync and gather all
+        hvd.join()
+        gathered = hvd.allgather(result)
+        gathered_result = list(gathered.split(1, dim=0))
+        return gathered_result
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        if group is not None:
+            raise ValueError(
+                "Horovod does not support allreduce using a subcommunicator at this time. "
+                "Unset `group`."
+            )
+
+        if reduce_op is None or reduce_op == "sum":
+            reduce_op = hvd.Sum
+        elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
+            reduce_op = hvd.Average
+        else:
+            raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")
+
+        # sync all processes before reduction
+        hvd.join()
+        return hvd.allreduce(tensor, op=reduce_op)
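
A short usage sketch of the two Horovod calls used above, run under a launcher such as `horovodrun -np 2 python script.py`; `local_value` is a hypothetical per-worker tensor:

    import horovod.torch as hvd
    import torch

    hvd.init()
    local_value = torch.tensor(float(hvd.rank()))

    # Mean across workers, as sync_tensor(..., reduce_op="mean") does via hvd.allreduce
    mean_value = hvd.allreduce(local_value, op=hvd.Average)

    # One tensor per worker, as gather_all_tensors does via hvd.allgather
    per_worker = list(hvd.allgather(local_value.reshape(1)).split(1, dim=0))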
