Skip to content

Commit 6dd7298

Browse files
add single device training
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 96bb8a4 commit 6dd7298

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from typing import Any, Union
2+
3+
import torch
4+
from torch._C import device
5+
6+
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
7+
8+
9+
class SingleDevicePlugin(TrainingTypePlugin):
10+
def __init__(self, device: torch.device) -> bool:
11+
super().__init__()
12+
self.device: torch.device = device
13+
14+
@property
15+
def on_tpu(self) -> bool:
16+
return False
17+
18+
@property
19+
def on_gpu(self) -> bool:
20+
return self.device.type == "cuda" and torch.cuda.is_available()
21+
22+
def reduce(self, output: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
23+
return output
24+
25+
@property
26+
def root_device(self) -> torch.device:
27+
return self.device
28+
29+
def model_to_device(self) -> None:
30+
if self.on_gpu:
31+
torch.cuda.set_device(self.root_device)
32+
33+
self._model.to(self.root_device)
34+
35+
def connect(self, model: torch.nn.Module) -> torch.nn.Module:
36+
self._model = model
37+
self.model_to_device()
38+
return self.model
39+
40+
@property
41+
def is_global_zero(self) -> bool:
42+
return True
43+
44+
def barrier(self, *args, **kwargs) -> None:
45+
pass
46+
47+
def broadcast(self, obj: object, src: int = 0) -> object:
48+
return obj

0 commit comments

Comments
 (0)