Skip to content

Commit addc43a

Browse files
correct modeling_ddpm
1 parent f9a4532 commit addc43a

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

models/vision/ddpm/modeling_ddpm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ def __init__(self, unet, noise_scheduler):
2727
super().__init__()
2828
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
2929

30-
def __call__(self, generator=None, torch_device=None):
31-
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
30+
def __call__(self, batch_size=1, generator=None, torch_device=None):
31+
if torch_device is None:
32+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
3233

3334
self.unet.to(torch_device)
3435
# 1. Sample gaussian noise

0 commit comments

Comments
 (0)