diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 74cd0d415fc..d27e6066fe3 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -929,8 +929,10 @@ def _gen_affine_grid( d = 0.5 base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device) - base_grid[..., 0].copy_(torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow)) - base_grid[..., 1].copy_(torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh).unsqueeze_(-1)) + x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device) + base_grid[..., 0].copy_(x_grid) + y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) base_grid[..., 2].fill_(1) rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device) @@ -1065,8 +1067,10 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, d = 0.5 base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) - base_grid[..., 0].copy_(torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow)) - base_grid[..., 1].copy_(torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh).unsqueeze_(-1)) + x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device) + base_grid[..., 0].copy_(x_grid) + y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) base_grid[..., 2].fill_(1) rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)