diff --git a/ignite/metrics/gan/fid.py b/ignite/metrics/gan/fid.py index cea5264061e2..d124e5e9e398 100644 --- a/ignite/metrics/gan/fid.py +++ b/ignite/metrics/gan/fid.py @@ -31,13 +31,13 @@ def fid_score( except ImportError: raise ModuleNotFoundError("fid_score requires scipy to be installed.") - mu1, mu2 = mu1.cpu(), mu2.cpu() - sigma1, sigma2 = sigma1.cpu(), sigma2.cpu() + mu1, mu2 = mu1.detach().cpu(), mu2.detach().cpu() + sigma1, sigma2 = sigma1.detach().cpu(), sigma2.detach().cpu() diff = mu1 - mu2 # Product might be almost singular - covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2), disp=False) + covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2).numpy(), disp=False) # Numerical error might give slight imaginary component if np.iscomplexobj(covmean): if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):