Skip to content

Commit 2806e3d

Browse files
Fourier-Bessel Mixin class (#599)
Fourier Bessel Mixin Class: fb.py
1 parent ef9f3bc commit 2806e3d

File tree

9 files changed

+68
-53
lines changed

9 files changed

+68
-53
lines changed

src/aspire/basis/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from .basis import Basis
55
from .steerable import SteerableBasis2D
6+
from .fb import FBBasisMixin
67

78
# isort: on
89

src/aspire/basis/basis.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
import numpy as np
44
from scipy.sparse.linalg import LinearOperator, cg
55

6-
from aspire.basis.basis_utils import num_besselj_zeros
76
from aspire.image import Image
87
from aspire.utils import mdim_mat_fun_conj
9-
from aspire.utils.matlab_compat import m_reshape
108
from aspire.volume import Volume
119

1210
logger = logging.getLogger(__name__)
@@ -47,46 +45,6 @@ def __init__(self, size, ell_max=None, dtype=np.float32):
4745

4846
self._build()
4947

50-
def _getfbzeros(self):
51-
"""
52-
Generate zeros of Bessel functions
53-
"""
54-
# get upper_bound of zeros of Bessel functions
55-
upper_bound = min(self.ell_max + 1, 2 * self.nres + 1)
56-
57-
# List of number of zeros
58-
n = []
59-
# List of zero values (each entry is an ndarray; all of possibly different lengths)
60-
zeros = []
61-
62-
# generate zeros of Bessel functions for each ell
63-
for ell in range(upper_bound):
64-
# for each ell, num_besselj_zeros returns the zeros of the
65-
# order ell Bessel function which are less than 2*pi*c*R = nres*pi/2,
66-
# the truncation rule for the Fourier-Bessel expansion
67-
_n, _zeros = num_besselj_zeros(
68-
ell + (self.ndim - 2) / 2, self.nres * np.pi / 2
69-
)
70-
if _n == 0:
71-
break
72-
else:
73-
n.append(_n)
74-
zeros.append(_zeros)
75-
76-
# get maximum number of ell
77-
self.ell_max = len(n) - 1
78-
79-
# set the maximum of k for each ell
80-
self.k_max = np.array(n, dtype=int)
81-
82-
max_num_zeros = max(len(z) for z in zeros)
83-
for i, z in enumerate(zeros):
84-
zeros[i] = np.hstack(
85-
(z, np.zeros(max_num_zeros - len(z), dtype=self.dtype))
86-
)
87-
88-
self.r0 = m_reshape(np.hstack(zeros), (-1, self.ell_max + 1)).astype(self.dtype)
89-
9048
def _build(self):
9149
"""
9250
Build the internal data structure to represent basis

src/aspire/basis/basis_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def besselj_zeros(nu, k):
247247
return z
248248

249249

250-
def num_besselj_zeros(ell, r):
250+
def all_besselj_zeros(ell, r):
251251
"""
252252
Compute the zeros of the order `ell` Bessel function which are less than `r`.
253253

src/aspire/basis/fb.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import logging
2+
3+
import numpy as np
4+
5+
from aspire.basis.basis_utils import all_besselj_zeros
6+
from aspire.utils.matlab_compat import m_reshape
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
class FBBasisMixin(object):
12+
"""
13+
FBBasisMixin is a mixin implementing methods specific to Fourier-Bessel expansions,
14+
to be inherited by Fourier-Bessel subclasses of Basis.
15+
"""
16+
17+
def _calc_k_max(self):
18+
"""
19+
Generate zeros of Bessel functions
20+
"""
21+
# get upper_bound of zeros of Bessel functions
22+
upper_bound = min(self.ell_max + 1, 2 * self.nres + 1)
23+
24+
# List of number of zeros
25+
n = []
26+
# List of zero values (each entry is an ndarray; all of possibly different lengths)
27+
zeros = []
28+
29+
for ell in range(upper_bound):
30+
# for each ell, num_besselj_zeros returns the zeros of the
31+
# order ell Bessel function which are less than 2*pi*c*R = nres*pi/2,
32+
# the truncation rule for the Fourier-Bessel expansion
33+
if self.ndim == 2:
34+
bessel_order = ell
35+
elif self.ndim == 3:
36+
bessel_order = ell + 1 / 2
37+
_n, _zeros = all_besselj_zeros(bessel_order, self.nres * np.pi / 2)
38+
if _n == 0:
39+
break
40+
else:
41+
n.append(_n)
42+
zeros.append(_zeros)
43+
44+
# get maximum number of ell
45+
self.ell_max = len(n) - 1
46+
47+
# set the maximum of k for each ell
48+
self.k_max = np.array(n, dtype=int)
49+
50+
max_num_zeros = max(len(z) for z in zeros)
51+
for i, z in enumerate(zeros):
52+
zeros[i] = np.hstack(
53+
(z, np.zeros(max_num_zeros - len(z), dtype=self.dtype))
54+
)
55+
56+
self.r0 = m_reshape(np.hstack(zeros), (-1, self.ell_max + 1)).astype(self.dtype)

src/aspire/basis/fb_2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from scipy.special import jv
55

6-
from aspire.basis import SteerableBasis2D
6+
from aspire.basis import FBBasisMixin, SteerableBasis2D
77
from aspire.basis.basis_utils import unique_coords_nd
88
from aspire.image import Image
99
from aspire.utils import complex_type, real_type, roll_dim, unroll_dim
@@ -12,7 +12,7 @@
1212
logger = logging.getLogger(__name__)
1313

1414

15-
class FBBasis2D(SteerableBasis2D):
15+
class FBBasis2D(SteerableBasis2D, FBBasisMixin):
1616
"""
1717
Define a derived class using the Fourier-Bessel basis for mapping 2D images
1818
@@ -54,7 +54,7 @@ def _build(self):
5454
)
5555

5656
# get upper bound of zeros, ells, and ks of Bessel functions
57-
self._getfbzeros()
57+
self._calc_k_max()
5858

5959
# calculate total number of basis functions
6060
self.count = self.k_max[0] + sum(2 * self.k_max[1:])

src/aspire/basis/fb_3d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
import numpy as np
44

5-
from aspire.basis import Basis
5+
from aspire.basis import Basis, FBBasisMixin
66
from aspire.basis.basis_utils import real_sph_harmonic, sph_bessel, unique_coords_nd
77
from aspire.utils import roll_dim, unroll_dim
88
from aspire.utils.matlab_compat import m_flatten, m_reshape
99

1010
logger = logging.getLogger(__name__)
1111

1212

13-
class FBBasis3D(Basis):
13+
class FBBasis3D(Basis, FBBasisMixin):
1414
"""
1515
Define a derived class for direct spherical Harmonics Bessel basis expanding 3D volumes
1616
@@ -49,7 +49,7 @@ def _build(self):
4949
)
5050

5151
# get upper bound of zeros, ells, and ks of Bessel functions
52-
self._getfbzeros()
52+
self._calc_k_max()
5353

5454
# calculate total number of basis functions
5555
self.count = sum(self.k_max * (2 * np.arange(0, self.ell_max + 1) + 1))

src/aspire/basis/ffb_2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _build(self):
4444
self.n_theta = int((n_theta + np.mod(n_theta, 2)) / 2)
4545

4646
# get upper bound of zeros, ells, and ks of Bessel functions
47-
self._getfbzeros()
47+
self._calc_k_max()
4848

4949
# calculate total number of basis functions
5050
self.count = self.k_max[0] + sum(2 * self.k_max[1:])

src/aspire/basis/ffb_3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _build(self):
3333
self.kcut = 0.5
3434

3535
# get upper bound of zeros, ells, and ks of Bessel functions
36-
self._getfbzeros()
36+
self._calc_k_max()
3737

3838
# calculate total number of basis functions
3939
self.count = sum(self.k_max * (2 * np.arange(0, self.ell_max + 1) + 1))

tests/test_basis_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import numpy as np
44

55
from aspire.basis.basis_utils import (
6+
all_besselj_zeros,
67
besselj_zeros,
78
lgwt,
89
norm_assoc_legendre,
9-
num_besselj_zeros,
1010
real_sph_harmonic,
1111
sph_bessel,
1212
unique_coords_nd,
@@ -41,7 +41,7 @@ def testBesselJZeros(self):
4141
)
4242

4343
def testNumBesselJZeros(self):
44-
n, zeros = num_besselj_zeros(10, 20)
44+
n, zeros = all_besselj_zeros(10, 20)
4545
self.assertEqual(2, n)
4646
self.assertTrue(np.allclose(zeros, [14.47550069, 18.43346367]))
4747

0 commit comments

Comments
 (0)