1+ import logging
12from unittest import TestCase
23
34import numpy as np
67from aspire .image import Image
78from aspire .utils import complex_type , utest_tolerance
89from aspire .utils .matlab_compat import m_reshape
10+ from aspire .utils .random import randn
11+
12+ logger = logging .getLogger (__name__ )
913
1014
1115class PolarBasis2DTestCase (TestCase ):
1216 def setUp (self ):
1317 self .dtype = np .float32
1418 self .basis = PolarBasis2D ((8 , 8 ), 4 , 32 , dtype = self .dtype )
19+ # Note, in practice we got a degenerate random array around 1%
20+ # of the time, so we fix a seed for the randn calls.
21+ self .seed = 8675309
1522
1623 def tearDown (self ):
1724 pass
@@ -469,7 +476,7 @@ def testPolarBasis2DAdjoint(self):
469476 # The evaluate function should be the adjoint operator of evaluate_t.
470477 # Namely, if A = evaluate, B = evaluate_t, and B=A^t, we will have
471478 # (y, A*x) = (A^t*y, x) = (B*y, x)
472- x = np . random . randn (self .basis .count ).astype (self .dtype )
479+ x = randn (self .basis .count , seed = self . seed ).astype (self .dtype )
473480
474481 x = m_reshape (x , (self .basis .nrad , self .basis .ntheta ))
475482
@@ -483,14 +490,15 @@ def testPolarBasis2DAdjoint(self):
483490 x = m_reshape (x , (self .basis .nrad * self .basis .ntheta ,))
484491
485492 x_t = self .basis .evaluate (x ).asnumpy ()
486- y = np . random . randn (np .prod (self .basis .sz )).astype (self .dtype )
493+ y = randn (np .prod (self .basis .sz ), seed = self . seed ).astype (self .dtype )
487494 y_t = self .basis .evaluate_t (
488495 Image (m_reshape (y , self .basis .sz )[np .newaxis , :])
489496 ) # RCOPT
490- self .assertTrue (
491- np .isclose (
492- np .dot (y , m_reshape (x_t , (np .prod (self .basis .sz ),))),
493- np .dot (y_t , x ),
494- atol = utest_tolerance (self .dtype ),
495- )
497+
498+ lhs = np .dot (y , m_reshape (x_t , (np .prod (self .basis .sz ),)))
499+ rhs = np .real (np .dot (y_t , x ))
500+ logging .debug (
501+ f"lhs: { lhs } rhs: { rhs } absdiff: { np .abs (lhs - rhs )} atol: { utest_tolerance (self .dtype )} "
496502 )
503+
504+ self .assertTrue (np .isclose (lhs , rhs , atol = utest_tolerance (self .dtype )))
0 commit comments