@@ -88,24 +88,7 @@ def is_iterable(obj):
8888class TestCase (unittest .TestCase ):
8989 precision = 1e-5
9090
91- def assertExpected (self , output , subname = None , prec = None , strip_suffix = None ):
92- r"""
93- Test that a python value matches the recorded contents of a file
94- derived from the name of this test and subname. The value must be
95- pickable with `torch.save`. This file
96- is placed in the 'expect' directory in the same directory
97- as the test script. You can automatically update the recorded test
98- output using --accept.
99-
100- If you call this multiple times in a single function, you must
101- give a unique subname each time.
102-
103- strip_suffix allows different tests that expect similar numerics, e.g.
104- "test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
105- test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
106- strip_suffix="_cpu", and they would both use a data file name based on
107- "test_xyz".
108- """
91+ def _get_expected_file (self , subname = None , strip_suffix = None ):
10992 def remove_prefix_suffix (text , prefix , suffix ):
11093 if text .startswith (prefix ):
11194 text = text [len (prefix ):]
@@ -128,33 +111,41 @@ def remove_prefix_suffix(text, prefix, suffix):
128111 subname_output = " ({})" .format (subname )
129112 expected_file += "_expect.pkl"
130113
131- def accept_output (update_type ):
132- print ("Accepting {} for {}{}:\n \n {}" .format (update_type , munged_id , subname_output , output ))
114+ if not ACCEPT and not os .path .exists (expected_file ):
115+ raise RuntimeError (
116+ ("No expect file exists for {}{}; to accept the current output, run:\n "
117+ "python {} {} --accept" ).format (munged_id , subname_output , __main__ .__file__ , munged_id ))
118+
119+ return expected_file
120+
121+ def assertExpected (self , output , subname = None , prec = None , strip_suffix = None ):
122+ r"""
123+ Test that a python value matches the recorded contents of a file
124+ derived from the name of this test and subname. The value must be
125+ pickable with `torch.save`. This file
126+ is placed in the 'expect' directory in the same directory
127+ as the test script. You can automatically update the recorded test
128+ output using --accept.
129+
130+ If you call this multiple times in a single function, you must
131+ give a unique subname each time.
132+
133+ strip_suffix allows different tests that expect similar numerics, e.g.
134+ "test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
135+ test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
136+ strip_suffix="_cpu", and they would both use a data file name based on
137+ "test_xyz".
138+ """
139+ expected_file = self ._get_expected_file (subname , strip_suffix )
140+
141+ if ACCEPT :
142+ print ("Accepting updated output for {}:\n \n {}" .format (os .path .basename (expected_file ), output ))
133143 torch .save (output , expected_file )
134144 MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
135145 binary_size = os .path .getsize (expected_file )
136146 self .assertTrue (binary_size <= MAX_PICKLE_SIZE )
137-
138- try :
139- expected = torch .load (expected_file )
140- except IOError as e :
141- if e .errno != errno .ENOENT :
142- raise
143- elif ACCEPT :
144- accept_output ("output" )
145- return
146- else :
147- raise RuntimeError (
148- ("I got this output for {}{}:\n \n {}\n \n "
149- "No expect file exists; to accept the current output, run:\n "
150- "python {} {} --accept" ).format (munged_id , subname_output , output , __main__ .__file__ , munged_id ))
151-
152- if ACCEPT :
153- try :
154- self .assertEqual (output , expected , prec = prec )
155- except Exception :
156- accept_output ("updated output" )
157147 else :
148+ expected = torch .load (expected_file )
158149 self .assertEqual (output , expected , prec = prec )
159150
160151 def assertEqual (self , x , y , prec = None , message = '' , allow_inf = False ):
0 commit comments