@@ -47,67 +47,10 @@ def device(self) -> torch.device:
4747 return device
4848
4949 def to (self , * args : Any , ** kwargs : Any ) -> Self : # type: ignore[valid-type]
50- """Moves and/or casts the parameters and buffers.
51-
52- This can be called as
53- .. function:: to(device=None, dtype=None, non_blocking=False)
54- .. function:: to(dtype, non_blocking=False)
55- .. function:: to(tensor, non_blocking=False)
56- Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
57- floating point desired :attr:`dtype` s. In addition, this method will
58- only cast the floating point parameters and buffers to :attr:`dtype`
59- (if given). The integral parameters and buffers will be moved
60- :attr:`device`, if that is given, but with dtypes unchanged. When
61- :attr:`non_blocking` is set, it tries to convert/move asynchronously
62- with respect to the host if possible, e.g., moving CPU Tensors with
63- pinned memory to CUDA devices.
64- See below for examples.
65-
66- Note:
67- This method modifies the module in-place.
68-
69- Args:
70- device: the desired device of the parameters
71- and buffers in this module
72- dtype: the desired floating point type of
73- the floating point parameters and buffers in this module
74- tensor: Tensor whose dtype and device are the desired
75- dtype and device for all parameters and buffers in this module
76-
77- Returns:
78- Module: self
79-
80- Example::
81- >>> from torch import Tensor
82- >>> class ExampleModule(_DeviceDtypeModuleMixin):
83- ... def __init__(self, weight: Tensor):
84- ... super().__init__()
85- ... self.register_buffer('weight', weight)
86- >>> _ = torch.manual_seed(0)
87- >>> module = ExampleModule(torch.rand(3, 4))
88- >>> module.weight #doctest: +ELLIPSIS
89- tensor([[...]])
90- >>> module.to(torch.double)
91- ExampleModule()
92- >>> module.weight #doctest: +ELLIPSIS
93- tensor([[...]], dtype=torch.float64)
94- >>> cpu = torch.device('cpu')
95- >>> module.to(cpu, dtype=torch.half, non_blocking=True)
96- ExampleModule()
97- >>> module.weight #doctest: +ELLIPSIS
98- tensor([[...]], dtype=torch.float16)
99- >>> module.to(cpu)
100- ExampleModule()
101- >>> module.weight #doctest: +ELLIPSIS
102- tensor([[...]], dtype=torch.float16)
103- >>> module.device
104- device(type='cpu')
105- >>> module.dtype
106- torch.float16
107- """
108- # there is diff nb vars in PT 1.5
109- out = torch ._C ._nn ._parse_to (* args , ** kwargs )
110- self .__update_properties (device = out [0 ], dtype = out [1 ])
50+ """See :meth:`torch.nn.Module.to`."""
51+ # this converts `str` device to `torch.device`
52+ device , dtype = torch ._C ._nn ._parse_to (* args , ** kwargs )[:2 ]
53+ self .__update_properties (device = device , dtype = dtype )
11154 return super ().to (* args , ** kwargs )
11255
11356 def cuda (self , device : Optional [Union [torch .device , int ]] = None ) -> Self : # type: ignore[valid-type]
@@ -130,50 +73,27 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # ty
13073 return super ().cuda (device = device )
13174
13275 def cpu (self ) -> Self : # type: ignore[valid-type]
133- """Moves all model parameters and buffers to the CPU.
134-
135- Returns:
136- Module: self
137- """
76+ """See :meth:`torch.nn.Module.cpu`."""
13877 self .__update_properties (device = torch .device ("cpu" ))
13978 return super ().cpu ()
14079
14180 def type (self , dst_type : Union [str , torch .dtype ]) -> Self : # type: ignore[valid-type]
142- """Casts all parameters and buffers to :attr:`dst_type`.
143-
144- Arguments:
145- dst_type (type or string): the desired type
146-
147- Returns:
148- Module: self
149- """
81+ """See :meth:`torch.nn.Module.type`."""
15082 self .__update_properties (dtype = dst_type )
15183 return super ().type (dst_type = dst_type )
15284
15385 def float (self ) -> Self : # type: ignore[valid-type]
154- """Casts all floating point parameters and buffers to ``float`` datatype.
155-
156- Returns:
157- Module: self
158- """
86+ """See :meth:`torch.nn.Module.float`."""
15987 self .__update_properties (dtype = torch .float )
16088 return super ().float ()
16189
16290 def double (self ) -> Self : # type: ignore[valid-type]
163- """Casts all floating point parameters and buffers to ``double`` datatype.
164-
165- Returns:
166- Module: self
167- """
91+ """See :meth:`torch.nn.Module.double`."""
16892 self .__update_properties (dtype = torch .double )
16993 return super ().double ()
17094
17195 def half (self ) -> Self : # type: ignore[valid-type]
172- """Casts all floating point parameters and buffers to ``half`` datatype.
173-
174- Returns:
175- Module: self
176- """
96+ """See :meth:`torch.nn.Module.half`."""
17797 self .__update_properties (dtype = torch .half )
17898 return super ().half ()
17999
0 commit comments