Skip to content

Commit 00fe92a

Browse files
authored
Merge pull request #334 from ComputationalCryoEM/polar_2d_utest
tighted up polar2d unit test
2 parents 551aa32 + 754703d commit 00fe92a

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

tests/test_PolarBasis2D.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from unittest import TestCase
23

34
import numpy as np
@@ -6,12 +7,18 @@
67
from aspire.image import Image
78
from aspire.utils import complex_type, utest_tolerance
89
from aspire.utils.matlab_compat import m_reshape
10+
from aspire.utils.random import randn
11+
12+
logger = logging.getLogger(__name__)
913

1014

1115
class PolarBasis2DTestCase(TestCase):
1216
def setUp(self):
1317
self.dtype = np.float32
1418
self.basis = PolarBasis2D((8, 8), 4, 32, dtype=self.dtype)
19+
# Note, in practice we got a degenerate random array around 1%
20+
# of the time, so we fix a seed for the randn calls.
21+
self.seed = 8675309
1522

1623
def tearDown(self):
1724
pass
@@ -469,7 +476,7 @@ def testPolarBasis2DAdjoint(self):
469476
# The evaluate function should be the adjoint operator of evaluate_t.
470477
# Namely, if A = evaluate, B = evaluate_t, and B=A^t, we will have
471478
# (y, A*x) = (A^t*y, x) = (B*y, x)
472-
x = np.random.randn(self.basis.count).astype(self.dtype)
479+
x = randn(self.basis.count, seed=self.seed).astype(self.dtype)
473480

474481
x = m_reshape(x, (self.basis.nrad, self.basis.ntheta))
475482

@@ -483,14 +490,15 @@ def testPolarBasis2DAdjoint(self):
483490
x = m_reshape(x, (self.basis.nrad * self.basis.ntheta,))
484491

485492
x_t = self.basis.evaluate(x).asnumpy()
486-
y = np.random.randn(np.prod(self.basis.sz)).astype(self.dtype)
493+
y = randn(np.prod(self.basis.sz), seed=self.seed).astype(self.dtype)
487494
y_t = self.basis.evaluate_t(
488495
Image(m_reshape(y, self.basis.sz)[np.newaxis, :])
489496
) # RCOPT
490-
self.assertTrue(
491-
np.isclose(
492-
np.dot(y, m_reshape(x_t, (np.prod(self.basis.sz),))),
493-
np.dot(y_t, x),
494-
atol=utest_tolerance(self.dtype),
495-
)
497+
498+
lhs = np.dot(y, m_reshape(x_t, (np.prod(self.basis.sz),)))
499+
rhs = np.real(np.dot(y_t, x))
500+
logging.debug(
501+
f"lhs: {lhs} rhs: {rhs} absdiff: {np.abs(lhs-rhs)} atol: {utest_tolerance(self.dtype)}"
496502
)
503+
504+
self.assertTrue(np.isclose(lhs, rhs, atol=utest_tolerance(self.dtype)))

0 commit comments

Comments
 (0)