Skip to content

Commit d595829

Browse files
committed
tighten up types in polar2d adjoint test and fix random seed
1 parent cc6afc5 commit d595829

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

tests/test_PolarBasis2D.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from unittest import TestCase
23

34
import numpy as np
@@ -6,12 +7,16 @@
67
from aspire.image import Image
78
from aspire.utils import complex_type, utest_tolerance
89
from aspire.utils.matlab_compat import m_reshape
10+
from aspire.utils.random import randn
11+
12+
logger = logging.getLogger(__name__)
913

1014

1115
class PolarBasis2DTestCase(TestCase):
1216
def setUp(self):
1317
self.dtype = np.float32
1418
self.basis = PolarBasis2D((8, 8), 4, 32, dtype=self.dtype)
19+
self.seed = 8675309
1520

1621
def tearDown(self):
1722
pass
@@ -469,7 +474,7 @@ def testPolarBasis2DAdjoint(self):
469474
# The evaluate function should be the adjoint operator of evaluate_t.
470475
# Namely, if A = evaluate, B = evaluate_t, and B=A^t, we will have
471476
# (y, A*x) = (A^t*y, x) = (B*y, x)
472-
x = np.random.randn(self.basis.count).astype(self.dtype)
477+
x = randn(self.basis.count, seed=self.seed).astype(self.dtype)
473478

474479
x = m_reshape(x, (self.basis.nrad, self.basis.ntheta))
475480

@@ -483,14 +488,15 @@ def testPolarBasis2DAdjoint(self):
483488
x = m_reshape(x, (self.basis.nrad * self.basis.ntheta,))
484489

485490
x_t = self.basis.evaluate(x).asnumpy()
486-
y = np.random.randn(np.prod(self.basis.sz)).astype(self.dtype)
491+
y = randn(np.prod(self.basis.sz), seed=self.seed).astype(self.dtype)
487492
y_t = self.basis.evaluate_t(
488493
Image(m_reshape(y, self.basis.sz)[np.newaxis, :])
489494
) # RCOPT
490-
self.assertTrue(
491-
np.isclose(
492-
np.dot(y, m_reshape(x_t, (np.prod(self.basis.sz),))),
493-
np.dot(y_t, x),
494-
atol=utest_tolerance(self.dtype),
495-
)
495+
496+
lhs = np.dot(y, m_reshape(x_t, (np.prod(self.basis.sz),)))
497+
rhs = np.real(np.dot(y_t, x))
498+
logging.debug(
499+
f"lhs: {lhs} rhs: {rhs} absdiff: {np.abs(lhs-rhs)} atol: {utest_tolerance(self.dtype)}"
496500
)
501+
502+
self.assertTrue(np.isclose(lhs, rhs, atol=utest_tolerance(self.dtype)))

0 commit comments

Comments
 (0)