Skip to content

Commit 9be9769

Browse files
authored
Merge pull request #298 from ComputationalCryoEM/replace_fft
includ centered fft functions in facade modules and replace FFT functions in other submodules
2 parents 0c62236 + 94925ad commit 9be9769

File tree

12 files changed

+209
-109
lines changed

12 files changed

+209
-109
lines changed

src/aspire/basis/ffb_2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
import numpy as np
44
from numpy import pi
5-
from scipy.fftpack import fft, ifft
65
from scipy.special import jv
76

87
from aspire.basis import FBBasis2D
98
from aspire.basis.basis_utils import lgwt
109
from aspire.image import Image
1110
from aspire.nufft import anufft, nufft
11+
from aspire.numeric import fft, xp
1212
from aspire.utils import complex_type
1313
from aspire.utils.matlab_compat import m_reshape
1414

@@ -165,7 +165,7 @@ def evaluate(self, v):
165165
ind_pos = ind_pos + 2 * self.k_max[ell]
166166

167167
# 1D inverse FFT in the degree of polar angle
168-
pf = 2 * pi * ifft(pf, axis=1, overwrite_x=True)
168+
pf = 2 * pi * xp.asnumpy(fft.ifft(xp.asarray(pf), axis=1))
169169

170170
# Only need "positive" frequencies.
171171
hsize = int(np.size(pf, 1) / 2)
@@ -235,7 +235,7 @@ def evaluate_t(self, x):
235235
)
236236

237237
# 1D FFT on the angular dimension for each concentric circle
238-
pf = 2 * pi / (2 * n_theta) * fft(pf, 2 * n_theta, 2)
238+
pf = 2 * pi / (2 * n_theta) * xp.asnumpy(fft.fft(xp.asarray(pf)))
239239

240240
# This only makes it easier to slice the array later.
241241
v = np.zeros((n_images, self.count), dtype=x.dtype)

src/aspire/basis/fpswf_2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import numpy as np
44
from numpy import pi
55
from numpy.linalg import lstsq
6-
from scipy.fftpack import fft
76
from scipy.optimize import least_squares
87
from scipy.special import jn
98

109
from aspire.basis.basis_utils import lgwt, t_x_mat, t_x_mat_dot
1110
from aspire.basis.pswf_2d import PSWFBasis2D
1211
from aspire.nufft import nufft
12+
from aspire.numeric import fft, xp
1313
from aspire.utils import complex_type
1414

1515
logger = logging.getLogger(__name__)
@@ -400,7 +400,7 @@ def _pswf_integration(self, images_nufft):
400400
:,
401401
]
402402
curr_r_mat = np.concatenate((curr_r_mat, np.conj(curr_r_mat)))
403-
fft_plan = fft(curr_r_mat, curr_r_mat.shape[0], axis=0)
403+
fft_plan = xp.asnumpy(fft.fft(xp.asarray(curr_r_mat), axis=0))
404404
angular_eval = fft_plan * self.quad_rule_radial_wts[i]
405405

406406
r_n_eval_mat[i, :, :] = np.tile(

src/aspire/image/image.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22

33
import mrcfile
44
import numpy as np
5-
from scipy.fftpack import fft2, ifft2, ifftshift
65
from scipy.interpolate import RegularGridInterpolator
76
from scipy.linalg import lstsq
87

98
import aspire.volume
109
from aspire.nufft import anufft
11-
from aspire.utils import anorm, ensure
10+
from aspire.numeric import fft, xp
11+
from aspire.utils import ensure
1212
from aspire.utils.coor_trans import grid_2d
13-
from aspire.utils.fft import centered_fft2, centered_ifft2
1413
from aspire.utils.matlab_compat import m_reshape
14+
from aspire.utils.matrix import anorm
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -46,7 +46,9 @@ def _im_translate2(im, shifts):
4646
raise ValueError("The number of shifts must be 1 or match the number of images")
4747

4848
resolution = im.res
49-
grid = np.fft.ifftshift(np.ceil(np.arange(-resolution / 2, resolution / 2)))
49+
grid = xp.asnumpy(
50+
fft.ifftshift(xp.asarray(np.ceil(np.arange(-resolution / 2, resolution / 2))))
51+
)
5052
om_y, om_x = np.meshgrid(grid, grid)
5153
phase_shifts = np.einsum("ij, k -> ijk", om_x, shifts[:, 0]) + np.einsum(
5254
"ij, k -> ijk", om_y, shifts[:, 1]
@@ -56,9 +58,9 @@ def _im_translate2(im, shifts):
5658
phase_shifts /= resolution
5759

5860
mult_f = np.exp(-2 * np.pi * 1j * phase_shifts)
59-
im_f = np.fft.fft2(im.asnumpy())
61+
im_f = xp.asnumpy(fft.fft2(xp.asarray(im.asnumpy())))
6062
im_translated_f = im_f * mult_f
61-
im_translated = np.real(np.fft.ifft2(im_translated_f))
63+
im_translated = np.real(xp.asnumpy(fft.ifft2(xp.asarray(im_translated_f))))
6264

6365
return Image(im_translated)
6466

@@ -200,7 +202,10 @@ def downsample(self, ds_res):
200202
mask = (np.abs(grid["x"]) < ds_res / self.res) & (
201203
np.abs(grid["y"]) < ds_res / self.res
202204
)
203-
im = np.real(centered_ifft2(centered_fft2(self.data) * mask))
205+
im_shifted = fft.centered_ifft2(
206+
fft.centered_fft2(xp.asarray(self.data)) * xp.asarray(mask)
207+
)
208+
im = np.real(xp.asnumpy(im_shifted))
204209

205210
for s in range(im_ds.shape[0]):
206211
interpolator = RegularGridInterpolator(
@@ -219,12 +224,13 @@ def filter(self, filter):
219224
"""
220225
filter_values = filter.evaluate_grid(self.res)
221226

222-
im_f = centered_fft2(self.data)
227+
im_f = xp.asnumpy(fft.centered_fft2(xp.asarray(self.data)))
228+
223229
if im_f.ndim > filter_values.ndim:
224230
im_f *= filter_values
225231
else:
226232
im_f = filter_values * im_f
227-
im = centered_ifft2(im_f)
233+
im = xp.asnumpy(fft.centered_ifft2(xp.asarray(im_f)))
228234
im = np.real(im)
229235

230236
return Image(im)
@@ -263,13 +269,11 @@ def _im_translate(self, shifts):
263269
shifts = shifts.astype(self.dtype)
264270

265271
L = self.res
266-
im_f = fft2(im, axes=(1, 2))
267-
grid_1d = (
268-
ifftshift(np.ceil(np.arange(-L / 2, L / 2, dtype=self.dtype)))
269-
* 2
270-
* np.pi
271-
/ L
272+
im_f = xp.asnumpy(fft.fft2(xp.asarray(im)))
273+
grid_shifted = fft.ifftshift(
274+
xp.asarray(np.ceil(np.arange(-L / 2, L / 2, dtype=self.dtype)))
272275
)
276+
grid_1d = xp.asnumpy(grid_shifted) * 2 * np.pi / L
273277
om_x, om_y = np.meshgrid(grid_1d, grid_1d, indexing="ij")
274278

275279
phase_shifts_x = -shifts[:, 0].reshape((n_shifts, 1, 1))
@@ -281,7 +285,7 @@ def _im_translate(self, shifts):
281285
)
282286
mult_f = np.exp(-1j * phase_shifts)
283287
im_translated_f = im_f * mult_f
284-
im_translated = ifft2(im_translated_f, axes=(1, 2))
288+
im_translated = xp.asnumpy(fft.ifft2(xp.asarray(im_translated_f)))
285289
im_translated = np.real(im_translated)
286290

287291
return Image(im_translated)
@@ -316,7 +320,7 @@ def backproject(self, rot_matrices):
316320
pts_rot = np.moveaxis(pts_rot, 1, 2)
317321
pts_rot = m_reshape(pts_rot, (3, -1))
318322

319-
im_f = centered_fft2(self.data) / (L ** 2)
323+
im_f = xp.asnumpy(fft.centered_fft2(xp.asarray(self.data))) / (L ** 2)
320324
if L % 2 == 0:
321325
im_f[:, 0, :] = 0
322326
im_f[:, :, 0] = 0

src/aspire/image/preprocess.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import math
33

44
import numpy as np
5-
from scipy.fftpack import fft, fft2, fftn, fftshift, ifft, ifft2, ifftn, ifftshift
65
from scipy.special import erf
76

7+
from aspire.numeric import fft, xp
88
from aspire.utils import ensure
99

1010
logger = logging.getLogger(__name__)
@@ -148,23 +148,36 @@ def downsample(insamples, szout, mask=None):
148148
# stack of one dimension objects
149149

150150
for idata in range(ndata):
151-
insamples_fft = crop_pad(fftshift(fft(insamples[idata])), L_out) * mask
152-
outsamples[idata] = np.real(ifft(ifftshift(insamples_fft)) * (L_out / L_in))
151+
insamples_shifted = fft.fftshift(fft.fft(xp.asarray(insamples[idata])))
152+
insamples_fft = crop_pad(insamples_shifted, L_out) * mask
153+
154+
outsamples_shifted = fft.ifft(fft.ifftshift(xp.asarray(insamples_fft)))
155+
outsamples[idata] = np.real(xp.asnumpy(outsamples_shifted) * (L_out / L_in))
153156

154157
elif insamples.ndim == 3:
155158
# stack of two dimension objects
156159
for idata in range(ndata):
157-
insamples_fft = crop_pad(fftshift(fft2(insamples[idata])), L_out) * mask
160+
insamples_shifted = fft.fftshift(fft.fft2(xp.asarray(insamples[idata])))
161+
insamples_fft = crop_pad(insamples_shifted, L_out) * mask
162+
163+
outsamples_shifted = fft.ifft2(fft.ifftshift(xp.asarray(insamples_fft)))
158164
outsamples[idata] = np.real(
159-
ifft2(ifftshift(insamples_fft)) * (L_out ** 2 / L_in ** 2)
165+
xp.asnumpy(outsamples_shifted) * (L_out ** 2 / L_in ** 2)
160166
)
161167

162168
elif insamples.ndim == 4:
163169
# stack of three dimension objects
164170
for idata in range(ndata):
165-
insamples_fft = crop_pad(fftshift(fftn(insamples[idata])), L_out) * mask
171+
insamples_shifted = fft.fftshift(
172+
fft.fftn(xp.asarray(insamples[idata]), axes=(0, 1, 2))
173+
)
174+
insamples_fft = crop_pad(insamples_shifted, L_out) * mask
175+
176+
outsamples_shifted = fft.ifftn(
177+
fft.ifftshift(xp.asarray(insamples_fft)), axes=(0, 1, 2)
178+
)
166179
outsamples[idata] = np.real(
167-
ifftn(ifftshift(insamples_fft)) * (L_out ** 3 / L_in ** 3)
180+
xp.asnumpy(outsamples_shifted) * (L_out ** 3 / L_in ** 3)
168181
)
169182

170183
else:

src/aspire/noise/noise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import numpy as np
44

5+
from aspire.numeric import fft, xp
56
from aspire.operators import ArrayFilter, ScalarFilter
67
from aspire.utils.coor_trans import grid_2d
7-
from aspire.utils.fft import centered_fft2
88

99
logger = logging.getLogger(__name__)
1010

@@ -122,7 +122,7 @@ def estimate_noise_psd(self):
122122

123123
_denominator = self.n * np.sum(mask)
124124
mean_est += np.sum(images_masked) / _denominator
125-
im_masked_f = centered_fft2(images_masked)
125+
im_masked_f = xp.asnumpy(fft.centered_fft2(xp.asarray(images_masked)))
126126
noise_psd_est += np.sum(np.abs(im_masked_f ** 2), axis=0) / _denominator
127127

128128
mid = self.L // 2

src/aspire/numeric/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
xp = NumericClass()
1313

1414

15-
def fft_class(which):
15+
def fft_object(which):
1616
if which == "pyfftw":
1717
from .pyfftw_fft import PyfftwFFT as FFTClass
1818
elif which == "cupy":
@@ -21,7 +21,7 @@ def fft_class(which):
2121
from .scipy_fft import ScipyFFT as FFTClass
2222
else:
2323
raise RuntimeError(f"Invalid selection for fft class: {which}")
24-
return FFTClass
24+
return FFTClass()
2525

2626

27-
fft = fft_class(config.common.fft)
27+
fft = fft_object(config.common.fft)

src/aspire/numeric/base_fft.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
class FFT:
2+
"""
3+
Define a customized interface for FFT functions
4+
5+
To make consistent among Pyfftw, Scipyfft and cupy fft,
6+
not all arguments are included.
7+
"""
8+
9+
def fft(self, x, axis=-1, workers=-1):
10+
raise NotImplementedError("subclasses must implement this")
11+
12+
def ifft(self, x, axis=-1, workers=-1):
13+
raise NotImplementedError("subclasses must implement this")
14+
15+
def fft2(self, x, axes=(-2, -1), workers=-1):
16+
raise NotImplementedError("subclasses must implement this")
17+
18+
def ifft2(self, x, axes=(-2, -1), workers=-1):
19+
raise NotImplementedError("subclasses must implement this")
20+
21+
def fftn(self, x, axes=None, workers=-1):
22+
raise NotImplementedError("subclasses must implement this")
23+
24+
def ifftn(self, x, axes=None, workers=-1):
25+
raise NotImplementedError("subclasses must implement this")
26+
27+
def fftshift(self, x, axes=None):
28+
raise NotImplementedError("subclasses must implement this")
29+
30+
def ifftshift(self, x, axes=None):
31+
raise NotImplementedError("subclasses must implement this")
32+
33+
def centered_ifft(self, x, axis=-1, workers=-1):
34+
x = self.ifftshift(x, axes=axis)
35+
x = self.ifft(x, axis=axis, workers=workers)
36+
x = self.fftshift(x, axes=axis)
37+
return x
38+
39+
def centered_fft(self, x, axis=-1, workers=-1):
40+
x = self.ifftshift(x, axes=axis)
41+
x = self.fft(x, axis=axis, workers=workers)
42+
x = self.fftshift(x, axes=axis)
43+
return x
44+
45+
def centered_ifft2(self, x, axes=(-2, -1), workers=-1):
46+
x = self.ifftshift(x, axes=axes)
47+
x = self.ifft2(x, axes=axes, workers=workers)
48+
x = self.fftshift(x, axes=axes)
49+
return x
50+
51+
def centered_fft2(self, x, axes=(-2, -1), workers=-1):
52+
x = self.ifftshift(x, axes=axes)
53+
x = self.fft2(x, axes=axes, workers=workers)
54+
x = self.fftshift(x, axes=axes)
55+
return x
56+
57+
def centered_ifftn(self, x, axes=None, workers=-1):
58+
x = self.ifftshift(x, axes=axes)
59+
x = self.ifftn(x, axes=axes, workers=workers)
60+
x = self.fftshift(x, axes=axes)
61+
return x
62+
63+
def centered_fftn(self, x, axes=None, workers=-1):
64+
x = self.ifftshift(x, axes=axes)
65+
x = self.fftn(x, axes=axes, workers=workers)
66+
x = self.fftshift(x, axes=axes)
67+
return x

src/aspire/numeric/cupy_fft.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,35 @@
11
import cupy as cp
22

3+
from aspire.numeric.base_fft import FFT
34

4-
class CupyFFT:
5+
6+
class CupyFFT(FFT):
57
"""
68
Define a unified wrapper class for Cupy FFT functions
79
810
To be consistent with Scipy and Pyfftw, not all arguments are included.
911
"""
1012

11-
@staticmethod
12-
def fft(x, axis=-1, workers=-1):
13+
def fft(self, x, axis=-1, workers=-1):
1314
return cp.fft.fft(x, axis=axis)
1415

15-
@staticmethod
16-
def ifft(x, axis=-1, workers=-1):
16+
def ifft(self, x, axis=-1, workers=-1):
1717
return cp.fft.ifft(x, axis=axis)
1818

19-
@staticmethod
20-
def fft2(x, axes=(-2, -1), workers=-1):
19+
def fft2(self, x, axes=(-2, -1), workers=-1):
2120
return cp.fft.fft2(x, axes=axes)
2221

23-
@staticmethod
24-
def ifft2(x, axes=(-2, -1), workers=-1):
22+
def ifft2(self, x, axes=(-2, -1), workers=-1):
2523
return cp.fft.ifft2(x, axes=axes)
2624

27-
@staticmethod
28-
def fftn(x, axes=None, workers=-1):
25+
def fftn(self, x, axes=None, workers=-1):
2926
return cp.fft.fftn(x, axes=axes)
3027

31-
@staticmethod
32-
def ifftn(x, axes=None, workers=-1):
33-
return cp.fft.ifft2(x, axes=axes)
28+
def ifftn(self, x, axes=None, workers=-1):
29+
return cp.fft.ifftn(x, axes=axes)
3430

35-
@staticmethod
36-
def fftshift(x, axes=None):
31+
def fftshift(self, x, axes=None):
3732
return cp.fft.fftshift(x, axes=axes)
3833

39-
@staticmethod
40-
def ifftshift(x, axes=None):
34+
def ifftshift(self, x, axes=None):
4135
return cp.fft.ifftshift(x, axes=axes)

0 commit comments

Comments
 (0)