Skip to content

Commit b44dbe0

Browse files
committed
ruff
1 parent d2fa322 commit b44dbe0

File tree

3 files changed

+17
-6
lines changed

3 files changed

+17
-6
lines changed

tutorials/developer_api_guide/my_dtype_tensor_subclass.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,9 +394,17 @@ def _(func, types, args, kwargs):
394394
########
395395
# Test #
396396
########
397-
def test():
397+
def main():
398398
from torchao.utils import benchmark_model
399399

400+
class M(torch.nn.Module):
401+
def __init__(self) -> None:
402+
super().__init__()
403+
self.linear = torch.nn.Linear(1024, 128)
404+
405+
def forward(self, x: torch.Tensor) -> torch.Tensor:
406+
return self.linear(x)
407+
400408
m = M()
401409
example_inputs = (100 * torch.randn(512, 1024),)
402410
NUM_WARMUPS = 10
@@ -431,4 +439,4 @@ def test():
431439
print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs))
432440

433441
if __name__ == "__main__":
434-
test()
442+
main()

tutorials/developer_api_guide/my_trainable_tensor_subclass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def from_float(
6161
return _ToMyTrainableDTypeTensor.apply(input_float, layout_type)
6262

6363
class _ToMyTrainableDTypeTensor(torch.autograd.Function):
64-
"""
64+
"""
6565
Differentiable constructor for `MyTrainableDTypeTensor`.
6666
"""
6767

@@ -163,8 +163,8 @@ def _(func, types, args, kwargs):
163163
########
164164

165165
class M(torch.nn.Module):
166-
def __init__(self, *args, **kwargs) -> None:
167-
super().__init__(*args, **kwargs)
166+
def __init__(self) -> None:
167+
super().__init__()
168168
self.linear = torch.nn.Linear(512, 1024, bias=False)
169169

170170
def forward(self, x: torch.Tensor) -> torch.Tensor:

tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
145145
########
146146
# Test #
147147
########
148-
if __name__ == "__main__":
148+
def main():
149149
# To make sure different ranks create the same module
150150
torch.manual_seed(5)
151151

@@ -192,3 +192,6 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
192192
print("torch.compile works!")
193193

194194
dist.destroy_process_group()
195+
196+
if __name__ == "__main__":
197+
main()

0 commit comments

Comments
 (0)