22
33import mrcfile
44import numpy as np
5- from scipy .fftpack import fft2 , ifft2 , ifftshift
65from scipy .interpolate import RegularGridInterpolator
76from scipy .linalg import lstsq
87
98import aspire .volume
109from aspire .nufft import anufft
11- from aspire .utils import anorm , ensure
10+ from aspire .numeric import fft , xp
11+ from aspire .utils import ensure
1212from aspire .utils .coor_trans import grid_2d
13- from aspire .utils .fft import centered_fft2 , centered_ifft2
1413from aspire .utils .matlab_compat import m_reshape
14+ from aspire .utils .matrix import anorm
1515
1616logger = logging .getLogger (__name__ )
1717
@@ -46,7 +46,9 @@ def _im_translate2(im, shifts):
4646 raise ValueError ("The number of shifts must be 1 or match the number of images" )
4747
4848 resolution = im .res
49- grid = np .fft .ifftshift (np .ceil (np .arange (- resolution / 2 , resolution / 2 )))
49+ grid = xp .asnumpy (
50+ fft .ifftshift (xp .asarray (np .ceil (np .arange (- resolution / 2 , resolution / 2 ))))
51+ )
5052 om_y , om_x = np .meshgrid (grid , grid )
5153 phase_shifts = np .einsum ("ij, k -> ijk" , om_x , shifts [:, 0 ]) + np .einsum (
5254 "ij, k -> ijk" , om_y , shifts [:, 1 ]
@@ -56,9 +58,9 @@ def _im_translate2(im, shifts):
5658 phase_shifts /= resolution
5759
5860 mult_f = np .exp (- 2 * np .pi * 1j * phase_shifts )
59- im_f = np . fft .fft2 (im .asnumpy ())
61+ im_f = xp . asnumpy ( fft .fft2 (xp . asarray ( im .asnumpy ()) ))
6062 im_translated_f = im_f * mult_f
61- im_translated = np .real (np . fft .ifft2 (im_translated_f ))
63+ im_translated = np .real (xp . asnumpy ( fft .ifft2 (xp . asarray ( im_translated_f )) ))
6264
6365 return Image (im_translated )
6466
@@ -200,7 +202,10 @@ def downsample(self, ds_res):
200202 mask = (np .abs (grid ["x" ]) < ds_res / self .res ) & (
201203 np .abs (grid ["y" ]) < ds_res / self .res
202204 )
203- im = np .real (centered_ifft2 (centered_fft2 (self .data ) * mask ))
205+ im_shifted = fft .centered_ifft2 (
206+ fft .centered_fft2 (xp .asarray (self .data )) * xp .asarray (mask )
207+ )
208+ im = np .real (xp .asnumpy (im_shifted ))
204209
205210 for s in range (im_ds .shape [0 ]):
206211 interpolator = RegularGridInterpolator (
@@ -219,12 +224,13 @@ def filter(self, filter):
219224 """
220225 filter_values = filter .evaluate_grid (self .res )
221226
222- im_f = centered_fft2 (self .data )
227+ im_f = xp .asnumpy (fft .centered_fft2 (xp .asarray (self .data )))
228+
223229 if im_f .ndim > filter_values .ndim :
224230 im_f *= filter_values
225231 else :
226232 im_f = filter_values * im_f
227- im = centered_ifft2 (im_f )
233+ im = xp . asnumpy ( fft . centered_ifft2 (xp . asarray ( im_f )) )
228234 im = np .real (im )
229235
230236 return Image (im )
@@ -263,13 +269,11 @@ def _im_translate(self, shifts):
263269 shifts = shifts .astype (self .dtype )
264270
265271 L = self .res
266- im_f = fft2 (im , axes = (1 , 2 ))
267- grid_1d = (
268- ifftshift (np .ceil (np .arange (- L / 2 , L / 2 , dtype = self .dtype )))
269- * 2
270- * np .pi
271- / L
272+ im_f = xp .asnumpy (fft .fft2 (xp .asarray (im )))
273+ grid_shifted = fft .ifftshift (
274+ xp .asarray (np .ceil (np .arange (- L / 2 , L / 2 , dtype = self .dtype )))
272275 )
276+ grid_1d = xp .asnumpy (grid_shifted ) * 2 * np .pi / L
273277 om_x , om_y = np .meshgrid (grid_1d , grid_1d , indexing = "ij" )
274278
275279 phase_shifts_x = - shifts [:, 0 ].reshape ((n_shifts , 1 , 1 ))
@@ -281,7 +285,7 @@ def _im_translate(self, shifts):
281285 )
282286 mult_f = np .exp (- 1j * phase_shifts )
283287 im_translated_f = im_f * mult_f
284- im_translated = ifft2 (im_translated_f , axes = ( 1 , 2 ))
288+ im_translated = xp . asnumpy ( fft . ifft2 (xp . asarray ( im_translated_f ) ))
285289 im_translated = np .real (im_translated )
286290
287291 return Image (im_translated )
@@ -316,7 +320,7 @@ def backproject(self, rot_matrices):
316320 pts_rot = np .moveaxis (pts_rot , 1 , 2 )
317321 pts_rot = m_reshape (pts_rot , (3 , - 1 ))
318322
319- im_f = centered_fft2 (self .data ) / (L ** 2 )
323+ im_f = xp . asnumpy ( fft . centered_fft2 (xp . asarray ( self .data )) ) / (L ** 2 )
320324 if L % 2 == 0 :
321325 im_f [:, 0 , :] = 0
322326 im_f [:, :, 0 ] = 0
0 commit comments