|
1 | 1 | """ |
2 | 2 | Changing default device |
3 | 3 | ======================= |
4 | | -It is common to want to write PyTorch code in a device agnostic way, |
| 4 | +
|
| 5 | +It is common practice to write PyTorch code in a device-agnostic way, |
5 | 6 | and then switch between CPU and CUDA depending on what hardware is available. |
6 | | -Traditionally, to do this you might have used if-statements and cuda() calls |
| 7 | +Typically, to do this you might have used if-statements and cuda() calls |
7 | 8 | to do this: |
8 | | -""" |
9 | 9 |
|
| 10 | +""" |
10 | 11 | import torch |
11 | 12 |
|
12 | 13 | USE_CUDA = False |
|
21 | 22 | inp = torch.randn(128, 20, device=device) |
22 | 23 | print(mod(inp).device) |
23 | 24 |
|
| 25 | +################################################################### |
24 | 26 | # PyTorch now also has a context manager which can take care of the |
25 | | -# device transfer automatically. |
| 27 | +# device transfer automatically. Here is an example: |
26 | 28 |
|
27 | 29 | with torch.device('cuda'): |
28 | 30 | mod = torch.nn.Linear(20, 30) |
29 | 31 | print(mod.weight.device) |
30 | 32 | print(mod(torch.randn(128, 20)).device) |
31 | 33 |
|
32 | | -# You can also set it globally |
| 34 | +######################################### |
| 35 | +# You can also set it globally like this: |
33 | 36 |
|
34 | 37 | torch.set_default_device('cuda') |
35 | 38 |
|
36 | 39 | mod = torch.nn.Linear(20, 30) |
37 | 40 | print(mod.weight.device) |
38 | 41 | print(mod(torch.randn(128, 20)).device) |
39 | 42 |
|
| 43 | +################################################################ |
40 | 44 | # This function imposes a slight performance cost on every Python |
41 | | -# call to the torch API (not just factory functions). If this |
| 45 | +# call to the torch API (not just factory functions). If this |
42 | 46 | # is causing problems for you, please comment on |
43 | | -# https://github.com/pytorch/pytorch/issues/92701 |
| 47 | +# `this issue <https://github.com/pytorch/pytorch/issues/92701>`__ |
0 commit comments