3535def correct_inter_session_displacement (
3636 recordings_list : list [BaseRecording ],
3737 existing_motion_info : Optional [list [Dict ]] = None ,
38+ keep_channels_constant = False ,
3839 detect_kwargs = {}, # TODO: make non-mutable (same for motion.py)
3940 select_kwargs = {},
4041 localize_peaks_kwargs = {},
@@ -44,10 +45,10 @@ def correct_inter_session_displacement(
4445 from spikeinterface .sortingcomponents .peak_detection import detect_peaks , detect_peak_methods
4546 from spikeinterface .sortingcomponents .peak_selection import select_peaks
4647 from spikeinterface .sortingcomponents .peak_localization import localize_peaks , localize_peak_methods
47- from spikeinterface .sortingcomponents .motion_estimation import estimate_motion
48- from spikeinterface .sortingcomponents .motion_interpolation import InterpolateMotionRecording
48+ from spikeinterface .sortingcomponents .motion . motion_estimation import estimate_motion
49+ from spikeinterface .sortingcomponents .motion . motion_interpolation import InterpolateMotionRecording
4950 from spikeinterface .core .node_pipeline import ExtractDenseWaveforms , run_node_pipeline
50- from spikeinterface .sortingcomponents .motion_utils import Motion
51+ from spikeinterface .sortingcomponents .motion . motion_utils import Motion , get_spatial_windows
5152
5253 # TODO: do not accept multi-segment recordings.
5354 # TODO: check all recordings have the same probe dimensions!
@@ -101,12 +102,13 @@ def correct_inter_session_displacement(
101102 peaks_list = [info ["peaks" ] for info in existing_motion_info ]
102103 peak_locations_list = [info ["peak_locations" ] for info in existing_motion_info ]
103104
104- from spikeinterface .sortingcomponents .motion_estimation import make_2d_motion_histogram , make_3d_motion_histograms
105+ from spikeinterface .sortingcomponents .motion . motion_utils import make_2d_motion_histogram , make_3d_motion_histograms
105106
106107 # make motion histogram
107108 motion_histogram_dim = "2D" # "2D" or "3D", for now only handle 2D case
108109
109110 motion_histogram_list = []
111+ all_temporal_bin_edges = [] # TODO: fix naming
110112
111113 bin_um = 2 # TODO: critial paraneter. easier to take no binning and gaus smooth?
112114
@@ -125,13 +127,13 @@ def correct_inter_session_displacement(
125127 peak_locations ,
126128 weight_with_amplitude = False ,
127129 direction = "y" ,
128- bin_duration_s = recording .get_duration (segment_index = 0 ), # 1.0,
130+ bin_s = recording .get_duration (segment_index = 0 ), # 1.0,
129131 bin_um = bin_um ,
130- margin_um = 50 ,
132+ hist_margin_um = 50 ,
131133 spatial_bin_edges = None ,
132134 )
133135 else :
134- assert NotImplementedError
136+ assert NotImplementedError # TODO: might be old API pre-dredge
135137 motion_histogram = make_3d_motion_histograms (
136138 recording ,
137139 peaks ,
@@ -146,8 +148,8 @@ def correct_inter_session_displacement(
146148 )
147149 motion_histogram_list .append (motion_histogram [0 ].squeeze ())
148150 # store bin edges
149- temporal_bin_edges = motion_histogram [1 ]
150- spatial_bin_edges = motion_histogram [2 ]
151+ all_temporal_bin_edges . append ( motion_histogram [1 ])
152+ spatial_bin_edges_um = motion_histogram [2 ] # should be same across all recordings
151153
152154 # Do some checks on temporal and spatial bin edges that they are all the same?
153155 # TODO: do some smoothing? Try some other methds (e.g. NMI, KL divergence)
@@ -183,6 +185,12 @@ def correct_inter_session_displacement(
183185 # TODO: think will need to make this negative
184186 shifts [i ] = (midpoint - np .argmax (conv )) * bin_um # # TODO: the bin spacing is super important for resoltuion
185187
188+ # half
189+ # TODO: need to figure out interpolation to the center point, weird;y
190+ # the below does not work
191+ # shifts[0] = (shifts[1] / 2)
192+ # shifts[1] = (shifts[1] / 2) * -1
193+ # print("SHIFTS", shifts)
186194 # TODO: handle only the 2D case for now
187195 # TODO: do multi-session optimisation
188196
@@ -196,16 +204,37 @@ def correct_inter_session_displacement(
196204 for i , recording in enumerate (recordings_list ):
197205
198206 # TODO: direct copy, use 'get_window' from motion machinery
199- bin_centers = spatial_bin_edges [:- 1 ] + bin_um / 2.0
200- n = bin_centers .size
201- non_rigid_windows = [np .ones (n , dtype = "float64" )]
202- middle = (spatial_bin_edges [0 ] + spatial_bin_edges [- 1 ]) / 2.0
203- non_rigid_window_centers = np .array ([middle ])
204-
205- motion_array = shifts [i ] # TODO: this is the rigid case!
207+ if False :
208+ bin_centers = spatial_bin_edges [:- 1 ] + bin_um / 2.0
209+ n = bin_centers .size
210+ non_rigid_windows = [np .ones (n , dtype = "float64" )]
211+ middle = (spatial_bin_edges [0 ] + spatial_bin_edges [- 1 ]) / 2.0
212+ non_rigid_window_centers = np .array ([middle ])
213+
214+ dim = 1 # ["x", "y", "z"].index(direction)
215+ contact_depths = recording .get_channel_locations ()[:, dim ]
216+ spatial_bin_centers = 0.5 * (spatial_bin_edges_um [1 :] + spatial_bin_edges_um [:- 1 ])
217+
218+ _ , window_centers = get_spatial_windows (
219+ contact_depths , spatial_bin_centers , rigid = True # TODO: handle non-rigid case
220+ )
221+ # win_shape=win_shape, TODO: handle defaults better
222+ # win_step_um=win_step_um,
223+ # win_scale_um=win_scale_um,
224+ # win_margin_um=win_margin_um,
225+ # zero_threshold=1e-5,
226+
227+ # if shifts[i] == 0:
228+ ## all_recording_corrected.append(recording) # TODO
229+ # continue
230+ temporal_bin_edges = all_temporal_bin_edges [i ]
206231 temporal_bins = 0.5 * (temporal_bin_edges [1 :] + temporal_bin_edges [:- 1 ])
232+
233+ motion_array = np .zeros ((temporal_bins .size , window_centers .size )) # TODO: check this is the expected shape
234+ motion_array [:, :] = shifts [i ] # TODO: this is the rigid case!
235+
207236 motion = Motion (
208- [np . atleast_2d ( motion_array ) ], [temporal_bins ], non_rigid_window_centers , direction = "y"
237+ [motion_array ], [temporal_bins ], window_centers , direction = "y"
209238 ) # will be same for all except for shifts
210239 all_motion_info .append (motion ) # not certain on this
211240
@@ -225,4 +254,15 @@ def correct_inter_session_displacement(
225254 "all_motion_histograms" : motion_histogram_list , # TODO: naming
226255 "all_shifts" : shifts ,
227256 }
257+
258+ if keep_channels_constant :
259+ # TODO: use set
260+ import functools
261+
262+ common_channels = functools .reduce (
263+ np .intersect1d , [recording .channel_ids for recording in all_recording_corrected ]
264+ )
265+
266+ all_recording_corrected = [recording .channel_slice (common_channels ) for recording in all_recording_corrected ]
267+
228268 return all_recording_corrected , displacement_info # TODO: output more stuff later e.g. the Motion object
0 commit comments