@@ -15,7 +15,7 @@ class Basis:
1515 Define a base class for expanding 2D particle images and 3D structure volumes
1616
1717 """
18- def __init__ (self , size , ell_max = None ):
18+ def __init__ (self , size , ell_max = None , dtype = np . float64 ):
1919 """
2020 Initialize an object for the base of basis class
2121
@@ -36,6 +36,9 @@ def __init__(self, size, ell_max=None):
3636 self .count = 0
3737 self .ell_max = ell_max
3838 self .ndim = ndim
39+ self .dtype = dtype
40+ if self .dtype != np .float64 :
41+ raise NotImplementedError ("Currently only implemented for default double (np.float64) type" )
3942
4043 self ._build ()
4144
@@ -149,46 +152,46 @@ def mat_evaluate_t(self, X):
149152 """
150153 return mdim_mat_fun_conj (X , len (self .sz ), 1 , self .evaluate_t )
151154
152- def expand (self , v ):
155+ def expand (self , x ):
153156 """
154- Expand array in basis
157+ Obtain coefficients in the basis from those in standard coordinate basis
155158
156- This is a similar function to ` evaluate_t` but with more accuracy by
157- using the cg optimizing of linear equation, Ax=b.
159+ This is a similar function to evaluate_t but with more accuracy by using
160+ the cg optimizing of linear equation, Ax=b.
158161
159- If `v` is a matrix of size `basis.ct`-by-..., `B` is the change-of-basis
160- matrix of this basis, and `x` is a matrix of size `self.sz`-by-...,
161- the function calculates v = (B' * B)^(-1) * B' * x, where the rows
162- of `B` and columns of `x` are read as vectorized arrays.
163-
164- :param v: An array whose first few dimensions are to be expanded in this basis.
165- These dimensions must equal `self.sz`.
166- :return: The coefficients of `v` expanded in this basis. If more than
167- one array of size `self.sz` is found in `v`, the second and higher
168- dimensions of the return value correspond to those higher dimensions of `v`.
162+ :param x: An array whose first two or three dimensions are to be expanded
163+ the desired basis. These dimensions must equal `self.sz`.
164+ :return : The coefficients of `v` expanded in the desired basis.
165+ The first dimension of `v` is with size of `count` and the
166+ second and higher dimensions of the return value correspond to
167+ those higher dimensions of `x`.
169168
170169 """
171- ensure (v .shape [:self .ndim ] == self .sz , f'First { self .ndim } dimensions of v must match { self .sz } .' )
170+ # ensure the first dimensions with size of self.sz
171+ x , sz_roll = unroll_dim (x , self .ndim + 1 )
172+ ensure (x .shape [:self .ndim ] == self .sz ,
173+ f'First { self .ndim } dimensions of x must match { self .sz } .' )
172174
173- v , sz_roll = unroll_dim (v , self .ndim + 1 )
174- b = self .evaluate_t (v )
175- operator = LinearOperator (
176- shape = (self .count , self .count ),
177- matvec = lambda x : self .evaluate_t (self .evaluate (x ))
178- )
175+ operator = LinearOperator (shape = (self .count , self .count ),
176+ matvec = lambda v : self .evaluate_t (self .evaluate (v )))
179177
180178 # TODO: (from MATLAB implementation) - Check that this tolerance make sense for multiple columns in v
181- tol = 10 * np .finfo (v .dtype ).eps
179+ tol = 10 * np .finfo (x .dtype ).eps
182180 logger .info ('Expanding array in basis' )
183- v , info = cg (operator , b , tol = tol )
184181
185- v = v [..., np .newaxis ]
182+ # number of image samples
183+ n_data = np .size (x , self .ndim )
184+ v = np .zeros ((self .count , n_data ), dtype = x .dtype )
186185
187- if info != 0 :
188- raise RuntimeError ('Unable to converge!' )
186+ for isample in range (0 , n_data ):
187+ b = self .evaluate_t (x [..., isample ])
188+ # TODO: need check the initial condition x0 can improve the results or not.
189+ v [..., isample ], info = cg (operator , b , tol = tol )
190+ if info != 0 :
191+ raise RuntimeError ('Unable to converge!' )
189192
193+ # return v coefficients with the first dimension of self.count
190194 v = roll_dim (v , sz_roll )
191-
192195 return v
193196
194197 def expand_t (self , v ):
0 commit comments