@@ -12,7 +12,7 @@ class TestSessionDisplacementGenerator:
1212 """
1313 This class tests the `generate_session_displacement_recordings` that
1414 returns a recordings / sorting in which the units are shifted
15- across sessions. This is acheived by shifting the unit locations
15+ across sessions. This is achieved by shifting the unit locations
1616 in both (x, y) on the generated templates that are used in
1717 `InjectTemplatesRecording()`.
1818 """
@@ -136,7 +136,7 @@ def test_recordings_length(self, options):
136136 for rec , expected_rec_length in zip (output_recordings , options ["kwargs" ]["recording_durations" ]):
137137 assert rec .get_total_duration () == expected_rec_length
138138
139- def test_spike_times_across_recordings (self , options ):
139+ def test_spike_times_and_firing_rates_across_recordings (self , options ):
140140 """
141141 Check the randomisation of spike times across recordings.
142142 When a seed is set, this is passed to `generate_sorting`
@@ -146,14 +146,17 @@ def test_spike_times_across_recordings(self, options):
146146 """
147147 options ["kwargs" ]["recording_durations" ] = (10 ,) * options ["num_recs" ]
148148
149- output_sortings_same = generate_session_displacement_recordings (** options ["kwargs" ])[1 ]
149+ output_sortings_same , extra_outputs_same = generate_session_displacement_recordings (** options ["kwargs" ])[1 : 3 ]
150150
151151 options ["kwargs" ]["seed" ] = None
152- output_sortings_different = generate_session_displacement_recordings (** options ["kwargs" ])[1 ]
152+ output_sortings_different , extra_outputs_different = generate_session_displacement_recordings (
153+ ** options ["kwargs" ]
154+ )[1 :3 ]
153155
154156 for unit_idx in range (options ["kwargs" ]["num_units" ]):
155157 for rec_idx in range (1 , options ["num_recs" ]):
156158
159+ # Exact spike times are not preserved when seed is None
157160 assert np .array_equal (
158161 output_sortings_same [0 ].get_unit_spike_train (unit_idx ),
159162 output_sortings_same [rec_idx ].get_unit_spike_train (unit_idx ),
@@ -162,6 +165,15 @@ def test_spike_times_across_recordings(self, options):
162165 output_sortings_different [0 ].get_unit_spike_train (unit_idx ),
163166 output_sortings_different [rec_idx ].get_unit_spike_train (unit_idx ),
164167 )
168+ # Firing rates should always be preserved.
169+ assert np .array_equal (
170+ extra_outputs_same ["firing_rates" ][0 ][unit_idx ],
171+ extra_outputs_same ["firing_rates" ][rec_idx ][unit_idx ],
172+ )
173+ assert np .array_equal (
174+ extra_outputs_different ["firing_rates" ][0 ][unit_idx ],
175+ extra_outputs_different ["firing_rates" ][rec_idx ][unit_idx ],
176+ )
165177
166178 @pytest .mark .parametrize ("dim_idx" , [0 , 1 ])
167179 def test_x_y_shift_non_rigid (self , options , dim_idx ):
@@ -271,32 +283,70 @@ def test_displacement_with_peak_detection(self, options):
271283 assert np .isclose (new_pos , first_pos + y_shift , rtol = 0 , atol = options ["y_bin_um" ])
272284
273285 def test_amplitude_scalings (self , options ):
274-
286+ """
287+ Test that the templates are scaled by the passed scaling factors
288+ in the specified order. The order can be in the passed order,
289+ in the order of highest-to-lowest firing unit, or in the order
290+ of (amplitude * firing_rate) (highest to lowest unit).
291+ """
292+ # Setup arguments to create an unshifted set of recordings
293+ # where the templates are to be scaled with `true_scalings`
275294 options ["kwargs" ]["recording_durations" ] = (10 , 10 )
276295 options ["kwargs" ]["recording_shifts" ] = ((0 , 0 ), (0 , 0 ))
277296 options ["kwargs" ]["num_units" ] == 5 ,
278297
298+ true_scalings = np .array ([0.1 , 0.2 , 0.3 , 0.4 , 0.5 ])
299+
279300 recording_amplitude_scalings = {
280301 "method" : "by_passed_order" ,
281- "scalings" : (np .ones (5 ), np . array ([ 0.1 , 0.2 , 0.3 , 0.4 , 0.5 ]) ),
302+ "scalings" : (np .ones (5 ), true_scalings ),
282303 }
283304
284305 _ , output_sortings , extra_outputs = generate_session_displacement_recordings (
285306 ** options ["kwargs" ],
286307 recording_amplitude_scalings = recording_amplitude_scalings ,
287308 )
288- breakpoint ()
289- first , second = extra_outputs ["templates_array_moved" ] # TODO: own function
290- first_min = np .min (np .min (first , axis = 2 ), axis = 1 )
291- second_min = np .min (np .min (second , axis = 2 ), axis = 1 )
292- scales = second_min / first_min
293309
294- assert np .allclose (scales , shifts )
310+ # Check that the unit templates are scaled in the order
311+ # the scalings were passed.
312+ test_scalings = self ._calculate_scalings_from_output (extra_outputs )
313+ assert np .allclose (test_scalings , true_scalings )
314+
315+ # Now run, again applying the scalings in the order of
316+ # unit firing rates (highest to lowest).
317+ firing_rates = np .array ([5 , 4 , 3 , 2 , 1 ])
318+ generate_sorting_kwargs = dict (firing_rates = firing_rates , refractory_period_ms = 4.0 )
319+ recording_amplitude_scalings ["method" ] = "by_firing_rate"
320+ _ , output_sortings , extra_outputs = generate_session_displacement_recordings (
321+ ** options ["kwargs" ],
322+ recording_amplitude_scalings = recording_amplitude_scalings ,
323+ generate_sorting_kwargs = generate_sorting_kwargs ,
324+ )
325+
326+ test_scalings = self ._calculate_scalings_from_output (extra_outputs )
327+ assert np .allclose (test_scalings , true_scalings [np .argsort (firing_rates )])
295328
296- # TODO: scale based on recording output
297- # check scaled by amplitude.
329+ # Finally, run again applying the scalings in the order of
330+ # unit amplitude * firing_rate
331+ recording_amplitude_scalings ["method" ] = "by_amplitude_and_firing_rate" # TODO: method -> order
332+ amplitudes = np .min (np .min (extra_outputs ["templates_array_moved" ][0 ], axis = 2 ), axis = 1 )
333+ firing_rate_by_amplitude = np .argsort (amplitudes * firing_rates )
298334
299- breakpoint ()
335+ _ , output_sortings , extra_outputs = generate_session_displacement_recordings (
336+ ** options ["kwargs" ],
337+ recording_amplitude_scalings = recording_amplitude_scalings ,
338+ generate_sorting_kwargs = generate_sorting_kwargs ,
339+ )
340+
341+ test_scalings = self ._calculate_scalings_from_output (extra_outputs )
342+ assert np .allclose (test_scalings , true_scalings [firing_rate_by_amplitude ])
343+
344+ def _calculate_scalings_from_output (self , extra_outputs ):
345+ first , second = extra_outputs ["templates_array_moved" ]
346+ first_min = np .min (np .min (first , axis = 2 ), axis = 1 )
347+ second_min = np .min (np .min (second , axis = 2 ), axis = 1 )
348+ test_scalings = second_min / first_min
349+ return test_scalings
300350
301351 def test_metadata (self , options ):
302352 """
@@ -339,7 +389,7 @@ def test_same_as_generate_ground_truth_recording(self):
339389 generate_probe_kwargs = None
340390 generate_unit_locations_kwargs = dict ()
341391 generate_templates_kwargs = dict (ms_before = 1.5 , ms_after = 3 )
342- generate_sorting_kwargs = dict ()
392+ generate_sorting_kwargs = dict (firing_rates = 1 )
343393 generate_noise_kwargs = dict ()
344394 seed = 42
345395
0 commit comments