11import logging
22
3- import finufftpy
3+ import finufft
44import numpy as np
55
66from aspire .nufft import Plan
1010
1111
1212class FinufftPlan (Plan ):
13- def __init__ (self , sz , fourier_pts , epsilon = 1e-15 , ntransforms = 1 , ** kwargs ):
13+ def __init__ (self , sz , fourier_pts , epsilon = 1e-8 , ntransforms = 1 , ** kwargs ):
1414 """
1515 A plan for non-uniform FFT in 2D or 3D.
1616
17- :param sz: A tuple indicating the geometry of the signal
18- :param fourier_pts: The points in Fourier space where the Fourier transform is to be calculated,
19- arranged as a dimension-by-K array. These need to be in the range [-pi, pi] in each dimension.
20- :param epsilon: The desired precision of the NUFFT
21- :param ntransforms: Optional integer indicating if you would like to compute a batch of `ntransforms`
22- transforms. Implies vol_f.shape is (..., `ntransforms`). Defaults to 0 which disables batching.
17+ :param sz: A tuple indicating the geometry of the signal.
18+ :param fourier_pts: The points in Fourier space where the Fourier
19+ transform is to be calculated, arranged as a dimension-by-K array.
20+ These need to be in the range [-pi, pi] in each dimension.
21+ :param epsilon: The desired precision of the NUFFT.
22+ :param ntransforms: Optional integer indicating if you would like
23+ to compute a batch of `ntransforms`.
24+ transforms. Implies vol_f.shape is (`ntransforms`, ...).
2325 """
2426
2527 self .ntransforms = ntransforms
26- manystr = ""
27- if self .ntransforms > 1 :
28- manystr = "many"
2928
3029 self .sz = sz
3130 self .dim = len (sz )
3231
3332 self .dtype = fourier_pts .dtype
3433
35- # TODO: Currently/historically finufftpy is hardcoded as doubles only (inside the binding layer).
36- # This has been changed in their GuruV2 work,
37- # and an updated package should be released and integrated very soon.
38- # The following casting business is only to facilitate the transition.
39- # I don't want to send one precision in and get a different one out.
40- # We have enough code that does that already.
41- # (Potentially anything that used this, for example).
42- # I would error, but I know a bunch of code ASPIRE wants to work,
43- # the cov2d tutorial for example, would fail and require hacks to run
44- # if I was strict about dtypes rn.
45- # I would rather deal with that in other, more targeted PRs.
46- # This preserves the legacy behavior of admitting singles,
47- # but I will correct it slightly, to return the precision given as input.
48- # This approach should contain the hacks here to a single place on the edge,
49- # instead of spread through the code. ASPIRE code should focus on being
50- # internally consistent.
51- # Admittedly not ideal, but ignoring these problems wasn't sustainable.
52-
53- self .cast_output = False
54- if self .dtype != np .float64 :
55- logger .debug (
56- "This version of finufftpy is hardcoded to doubles internally"
57- " casting input to doubles, results cast back to singles."
58- )
59- self .cast_output = True
60- self .dtype = np .float64
61-
6234 self .complex_dtype = complex_type (self .dtype )
6335
64- # TODO: Things get messed up unless we ensure a 'C' ordering here - investigate why
65- self .fourier_pts = np .asarray (
66- np .mod (fourier_pts + np .pi , 2 * np .pi ) - np .pi , order = "C" , dtype = self .dtype
36+ self .fourier_pts = np .ascontiguousarray (
37+ np .mod (fourier_pts + np .pi , 2 * np .pi ) - np .pi
6738 )
39+
6840 self .num_pts = fourier_pts .shape [1 ]
69- self .epsilon = epsilon
7041
71- # Get a handle on the appropriate 1d/2d/3d forward transform function in finufftpy
72- self .transform_function = getattr (finufftpy , f"nufft{ self .dim } d2{ manystr } " )
42+ self .epsilon = max (epsilon , np .finfo (self .dtype ).eps )
43+ if self .epsilon != epsilon :
44+ logger .debug (
45+ f"FinufftPlan adjusted eps={ self .epsilon } " f" from requested { epsilon } ."
46+ )
47+
48+ self ._transform_plan = finufft .Plan (
49+ nufft_type = 2 ,
50+ n_modes_or_dim = self .sz ,
51+ eps = self .epsilon ,
52+ n_trans = self .ntransforms ,
53+ dtype = self .dtype ,
54+ )
55+
56+ self ._adjoint_plan = finufft .Plan (
57+ nufft_type = 1 ,
58+ n_modes_or_dim = self .sz ,
59+ eps = self .epsilon ,
60+ n_trans = self .ntransforms ,
61+ dtype = self .dtype ,
62+ )
7363
74- # Get a handle on the appropriate 1d/2d/3d adjoint function in finufftpy
75- self .adjoint_function = getattr ( finufftpy , f"nufft { self .dim } d1 { manystr } " )
64+ self . _transform_plan . setpts ( * self . fourier_pts )
65+ self ._adjoint_plan . setpts ( * self .fourier_pts )
7666
7767 def transform (self , signal ):
7868 """
7969 Compute the NUFFT transform using this plan instance.
8070
8171 :param signal: Signal to be transformed. For a single transform,
8272 this should be a a 1, 2, or 3D array matching the plan `sz`.
83- For a batch, signal should have shape `(*sz, ntransforms )`.
73+ For a batch, signal should have shape `(ntransforms, *sz )`.
8474
8575 :returns: Transformed signal of shape `num_pts` or
8676 `(ntransforms, num_pts)`.
8777 """
8878
89- sig_shape = signal .shape
90- res_shape = self .num_pts
79+ sig_frame_shape = signal .shape
9180 # Note, there is a corner case for ntransforms == 1.
9281 if self .ntransforms > 1 or (
9382 self .ntransforms == 1 and len (signal .shape ) == self .dim + 1
@@ -103,37 +92,18 @@ def transform(self, signal):
10392 f" should match ntransforms { self .ntransforms } ." ,
10493 )
10594
106- sig_shape = signal .shape [1 :]
107- res_shape = (self .ntransforms , self .num_pts )
95+ sig_frame_shape = signal .shape [1 :]
96+
97+ # finufft expects signal.ndim == dim for ntransforms = 1.
98+ if self .ntransforms == 1 :
99+ signal = signal .reshape (self .sz )
108100
109101 ensure (
110- sig_shape == self .sz ,
102+ sig_frame_shape == self .sz ,
111103 f"Signal frame to be transformed must have shape { self .sz } " ,
112104 )
113105
114- epsilon = max (self .epsilon , np .finfo (signal .dtype ).eps )
115-
116- # Forward transform functions in finufftpy have signatures of the form:
117- # (x, y, z, c, isign, eps, f, ...)
118- # (x, y c, isign, eps, f, ...)
119- # (x, c, isign, eps, f, ...)
120- # Where f is a Fortran-order ndarray of the appropriate dimensions
121- # We form these function signatures here by tuple-unpacking
122-
123- result = np .zeros (res_shape , dtype = self .complex_dtype )
124-
125- result_code = self .transform_function (
126- * self .fourier_pts ,
127- result ,
128- - 1 ,
129- epsilon ,
130- signal .T , # RCOPT, currently F ordered, should change in gv2
131- )
132-
133- if result_code != 0 :
134- raise RuntimeError (f"FINufft transform failed. Result code { result_code } " )
135- if self .cast_output :
136- result = result .astype (np .complex64 )
106+ result = self ._transform_plan .execute (signal )
137107
138108 return result
139109
@@ -145,19 +115,9 @@ def adjoint(self, signal):
145115 this should be a a 1D array of len `num_pts`.
146116 For a batch, signal should have shape `(ntransforms, num_pts)`.
147117
148- :returns: Transformed signal `(sz)` or `(sz, ntransforms )`.
118+ :returns: Transformed signal `(sz)` or `(ntransforms, sz )`.
149119 """
150120
151- epsilon = max (self .epsilon , np .finfo (signal .dtype ).eps )
152-
153- # Adjoint functions in finufftpy have signatures of the form:
154- # (x, y, z, c, isign, eps, ms, mt, mu, f, ...)
155- # (x, y c, isign, eps, ms, mt f, ...)
156- # (x, c, isign, eps, ms, f, ...)
157- # Where f is a Fortran-order ndarray of the appropriate dimensions
158- # We form these function signatures here by tuple-unpacking
159-
160- res_shape = self .sz
161121 # Note, there is a corner case for ntransforms == 1.
162122 if self .ntransforms > 1 or (self .ntransforms == 1 and len (signal .shape ) == 2 ):
163123 ensure (
@@ -170,31 +130,11 @@ def adjoint(self, signal):
170130 "For multiple transforms, signal stack length"
171131 f" should match ntransforms { self .ntransforms } ." ,
172132 )
173- res_shape = (
174- self .ntransforms ,
175- * self .sz ,
176- )
177133
178- result = np .zeros (res_shape , dtype = self .complex_dtype )
179-
180- # FINUFFT is F order at this time. The bindings
181- # will pickup the fact `signal` is C_Contiguous,
182- # and transpose the data; we just need to transpose
183- # the indices. I think the next release addresses this.
184- # Note in the 2020 hackathon this was changed directly in FFB,
185- # which worked because GPU arrays just need the pointer anyway...
186- # This is a quirk of this version of FINUFFT, and
187- # so probably belongs here at the edge,
188- # away from other implementations.
189- signal = signal .reshape (signal .shape [::- 1 ])
190-
191- result_code = self .adjoint_function (
192- * self .fourier_pts , signal , 1 , epsilon , * self .sz , result
193- )
194- if result_code != 0 :
195- raise RuntimeError (f"FINufft adjoint failed. Result code { result_code } " )
134+ # finufft is expecting flat array for 1D case.
135+ if self .ntransforms == 1 :
136+ signal = signal .reshape (self .num_pts )
196137
197- if self .cast_output :
198- result = result .astype (np .complex64 )
138+ result = self ._adjoint_plan .execute (signal )
199139
200140 return result
0 commit comments