Skip to content

Commit d975ab9

Browse files
committed
replace scipy calls with interop wrapper
1 parent b2536f1 commit d975ab9

File tree

4 files changed

+12
-11
lines changed

4 files changed

+12
-11
lines changed

src/aspire/basis/basis.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import logging
22

33
import numpy as np
4-
from scipy.sparse.linalg import LinearOperator, cg
4+
from scipy.sparse.linalg import LinearOperator
55

66
from aspire.image import Image
7+
from aspire.numeric.scipy import cg
78
from aspire.utils import mdim_mat_fun_conj
89
from aspire.volume import Volume
910

@@ -580,7 +581,7 @@ def expand(self, x, tol=None, atol=0):
580581
for isample in range(0, n_data):
581582
b = self.evaluate_t(self._cls(x[isample])).asnumpy().T
582583
# TODO: need check the initial condition x0 can improve the results or not.
583-
v[isample], info = cg(operator, b, tol=tol, atol=atol)
584+
v[isample], info = cg(operator, b, rtol=tol, atol=atol)
584585
if info != 0:
585586
raise RuntimeError(f"Unable to converge! cg info={info}")
586587

src/aspire/covariance/covar.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from functools import partial
33

44
import numpy as np
5-
import scipy.sparse.linalg
65
from scipy.linalg import norm
76
from scipy.sparse.linalg import LinearOperator
87

98
from aspire.nufft import anufft
109
from aspire.numeric import fft
10+
from aspire.numeric.scipy import cg
1111
from aspire.operators import evaluate_src_filters_on_grid
1212
from aspire.reconstruction import Estimator, FourierKernel, MeanEstimator
1313
from aspire.utils import (
@@ -127,9 +127,7 @@ def cb(xk):
127127
f"Delta {norm(b_coef - self.apply_kernel(xk, packed=True))} (target {target_residual})"
128128
)
129129

130-
x, info = scipy.sparse.linalg.cg(
131-
operator, b_coef, M=M, callback=cb, tol=tol, atol=0
132-
)
130+
x, info = cg(operator, b_coef, M=M, callback=cb, rtol=tol, atol=0)
133131

134132
if info != 0:
135133
raise RuntimeError("Unable to converge!")

src/aspire/numeric/scipy.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
Utility wrappers for scipy methods.
33
"""
44

5+
import scipy
6+
from packaging.version import Version
57

6-
from scipy.sparse.linalg import cg
78

8-
def cg(*args,**kwargs):
9+
def cg(*args, **kwargs):
910
"""
1011
Supports scipy cg before and after 1.14.0.
1112
"""
1213

1314
# older scipy cg interface uses `tol` instead of `rtol`
1415
if Version(scipy.__version__) < Version("1.14.0"):
1516
kwargs["tol"] = kwargs.pop("rtol", None)
16-
return cg(*args,**kwargs)
17+
return scipy.sparse.linalg.cg(*args, **kwargs)

src/aspire/reconstruction/mean.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33

44
import numpy as np
55
from scipy.linalg import norm
6-
from scipy.sparse.linalg import LinearOperator, cg
6+
from scipy.sparse.linalg import LinearOperator
77

88
from aspire import config
99
from aspire.basis import Coef
1010
from aspire.nufft import anufft
1111
from aspire.numeric import fft
12+
from aspire.numeric.scipy import cg
1213
from aspire.operators import evaluate_src_filters_on_grid
1314
from aspire.reconstruction import Estimator, FourierKernel, FourierKernelMatrix
1415
from aspire.volume import Volume, rotated_grids
@@ -251,7 +252,7 @@ def cb(xk):
251252
x0=x0,
252253
M=M,
253254
callback=cb,
254-
tol=tol,
255+
rtol=tol,
255256
atol=0,
256257
maxiter=self.maxiter,
257258
)

0 commit comments

Comments
 (0)