Merged

67 commits
66ae344
Reducing memory footprint
yger Sep 29, 2025
11958d7
WIP
yger Sep 29, 2025
76fd5d1
WIP
yger Sep 29, 2025
2e1098a
WIP
yger Sep 29, 2025
a37b8f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2025
40b1f6c
Fixing tests
yger Sep 29, 2025
98ed633
Fixing tests
yger Sep 29, 2025
d7c2e89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2025
b844f3e
WIP
yger Sep 30, 2025
f8e3ba9
WIP
yger Sep 30, 2025
794102a
Merge branch 'memory_template_similarity' of github.com:yger/spikeint…
yger Sep 30, 2025
0aa76a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2025
9858fc6
WIP
yger Sep 30, 2025
b51432e
Merge branch 'memory_template_similarity' of github.com:yger/spikeint…
yger Sep 30, 2025
341d980
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2025
ecfee8d
Merge branch 'main' of https://github.com/SpikeInterface/spikeinterface
yger Oct 3, 2025
4f6a7f1
WIP
yger Oct 3, 2025
76b9a7b
WIP
yger Oct 3, 2025
6a29e3f
Reducing memory footprint for large number of templates/channels
yger Oct 3, 2025
bb3421d
Merge branch 'memory_template_similarity'
yger Oct 3, 2025
5f0e02b
improve iterative_isosplit and remove warnings
samuelgarcia Oct 6, 2025
82223a9
"n_pca_features" 6 > 3
samuelgarcia Oct 6, 2025
50b4143
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Oct 8, 2025
8ccee0e
Merge branch 'main' of github.com:spikeinterface/spikeinterface
yger Oct 9, 2025
426b61d
Merge branch 'main' of github.com:yger/spikeinterface
yger Oct 10, 2025
b76552a
iterative isosplit params
samuelgarcia Oct 13, 2025
22aa5cd
oups
samuelgarcia Oct 15, 2025
61a570e
wip
samuelgarcia Oct 15, 2025
d671acc
Merge branch 'SpikeInterface:main' into main
yger Oct 16, 2025
4ec8408
various try on iterative_isosplit
samuelgarcia Oct 21, 2025
19e77fa
fix git bug
samuelgarcia Oct 22, 2025
d4a2124
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Oct 24, 2025
b6df1cb
Merge branch 'main' of github.com:spikeinterface/spikeinterface
yger Oct 26, 2025
8330277
improve isocut and tdc2
samuelgarcia Oct 27, 2025
9f9bddb
WIP
yger Oct 27, 2025
84aeb92
tdc2 improvement
samuelgarcia Oct 28, 2025
41b5d6b
WIP
yger Oct 28, 2025
ab470f5
WIP
yger Oct 28, 2025
936c31b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
1b5ef48
Alignment during merging
yger Oct 28, 2025
3524264
Merge branch 'circus2_paper' of github.com:yger/spikeinterface into c…
yger Oct 28, 2025
8bff173
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
f5869d7
WIP
yger Oct 28, 2025
17691f8
Merge branch 'circus2_paper' of github.com:yger/spikeinterface into c…
yger Oct 28, 2025
27eb077
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
4026a07
WIP
yger Oct 28, 2025
903e85e
Merge branch 'circus2_paper' of github.com:yger/spikeinterface into c…
yger Oct 28, 2025
c6f4708
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
356508d
fix nan in plot perf vs snr
samuelgarcia Oct 29, 2025
c53a887
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Oct 29, 2025
3f22619
WIP
yger Oct 29, 2025
911ea27
WIP
yger Oct 29, 2025
04f4c09
WIP
yger Oct 29, 2025
15d64dc
Merge branch 'circus2_paper' of github.com:yger/spikeinterface into m…
samuelgarcia Oct 30, 2025
88e3081
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Oct 30, 2025
9ddc138
rename spitting_tools to tersplit_tools to avoid double file with sam…
samuelgarcia Oct 30, 2025
dce3b96
compute_similarity_with_templates_array returan lags always
samuelgarcia Oct 30, 2025
055176e
tdc2 params ajustement
samuelgarcia Oct 30, 2025
85eabb3
tdc sc versions
samuelgarcia Oct 31, 2025
e149f1a
More seed in isosplit to avoid test fails
samuelgarcia Nov 3, 2025
24a4195
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2025
1a8ec68
WIP
yger Nov 3, 2025
fbeb941
Merge pull request #25 from yger/w_before
samuelgarcia Nov 3, 2025
860538f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2025
172ce9f
comments complicated tests for isosplit that depend on seed
samuelgarcia Nov 3, 2025
072f6ad
Merge branch 'more_isosplit' of github.com:samuelgarcia/spikeinterfac…
samuelgarcia Nov 3, 2025
6dab522
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2025
5 changes: 3 additions & 2 deletions src/spikeinterface/benchmark/benchmark_plot_tools.py
```diff
@@ -478,8 +478,9 @@ def _plot_performances_vs_metric(
             .get_performance()[performance_name]
             .to_numpy(dtype="float64")
         )
-        all_xs.append(x)
-        all_ys.append(y)
+        mask = ~np.isnan(x) & ~np.isnan(y)
+        all_xs.append(x[mask])
+        all_ys.append(y[mask])

     if with_sigmoid_fit:
         max_snr = max(np.max(x) for x in all_xs)
```
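Reviewer note: the change above drops unit rows where either the metric or the performance value is NaN, so the downstream sigmoid fit never receives NaN. A minimal sketch of the same masking idea on toy arrays (not the benchmark code itself):

```python
import numpy as np

# Toy version of the joint NaN filtering added above: a unit is kept only if
# both its metric (x) and its performance (y) are defined.
x = np.array([1.0, np.nan, 3.0, 4.0])   # e.g. SNR per unit
y = np.array([0.2, 0.5, np.nan, 0.9])   # e.g. accuracy per unit
mask = ~np.isnan(x) & ~np.isnan(y)
print(x[mask], y[mask])                 # [1. 4.] [0.2 0.9]
```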
19 changes: 13 additions & 6 deletions src/spikeinterface/postprocessing/template_similarity.py
```diff
@@ -94,7 +94,7 @@ def _merge_extension_data(
             new_sorting_analyzer.sparsity.mask[keep, :], new_unit_ids, new_sorting_analyzer.channel_ids
         )

-        new_similarity = compute_similarity_with_templates_array(
+        new_similarity, _ = compute_similarity_with_templates_array(
             new_templates_array,
             all_templates_array,
             method=self.params["method"],
@@ -146,7 +146,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer,
             new_sorting_analyzer.sparsity.mask[keep, :], new_unit_ids_f, new_sorting_analyzer.channel_ids
         )

-        new_similarity = compute_similarity_with_templates_array(
+        new_similarity, _ = compute_similarity_with_templates_array(
             new_templates_array,
             all_templates_array,
             method=self.params["method"],
@@ -188,7 +188,7 @@ def _run(self, verbose=False):
             self.sorting_analyzer, return_in_uV=self.sorting_analyzer.return_in_uV
         )
         sparsity = self.sorting_analyzer.sparsity
-        similarity = compute_similarity_with_templates_array(
+        similarity, _ = compute_similarity_with_templates_array(
             templates_array,
             templates_array,
             method=self.params["method"],
@@ -393,7 +393,13 @@ def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsity):


 def compute_similarity_with_templates_array(
-    templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
+    templates_array,
+    other_templates_array,
+    method,
+    support="union",
+    num_shifts=0,
+    sparsity=None,
+    other_sparsity=None,
 ):

     if method == "cosine_similarity":
@@ -432,10 +438,11 @@ def compute_similarity_with_templates_array(
         templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support
     )

+    lags = np.argmin(distances, axis=0) - num_shifts
     distances = np.min(distances, axis=0)
     similarity = 1 - distances

-    return similarity
+    return similarity, lags


 def compute_template_similarity_by_pair(
@@ -445,7 +452,7 @@ def compute_template_similarity_by_pair(
     templates_array_2 = get_dense_templates_array(sorting_analyzer_2, return_in_uV=True)
     sparsity_1 = sorting_analyzer_1.sparsity
     sparsity_2 = sorting_analyzer_2.sparsity
-    similarity = compute_similarity_with_templates_array(
+    similarity, _ = compute_similarity_with_templates_array(
         templates_array_1,
         templates_array_2,
         method=method,
```
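Reviewer note: `compute_similarity_with_templates_array` now always returns a `(similarity, lags)` pair, and existing callers unpack and discard the lags with `_`. The lag logic mirrors the two lines added above: the shift scan produces one distance row per tested shift, and the best lag per template pair is the argmin over that axis, re-centered by `num_shifts`. A toy sketch with fabricated distances (the `(2 * num_shifts + 1, n, m)` shape is assumed from the axis-0 reduction):

```python
import numpy as np

num_shifts = 2
rng = np.random.default_rng(0)
# one distance row per tested shift: shifts -2, -1, 0, +1, +2
distances = rng.random((2 * num_shifts + 1, 3, 4))

lags = np.argmin(distances, axis=0) - num_shifts  # best shift per pair, in samples
similarity = 1 - np.min(distances, axis=0)        # same convention as the diff
assert lags.min() >= -num_shifts and lags.max() <= num_shifts
```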
src/spikeinterface/postprocessing/tests/test_template_similarity.py

```diff
@@ -82,8 +82,9 @@ def test_compute_similarity_with_templates_array(params):
     templates_array = rng.random(size=(2, 20, 5))
     other_templates_array = rng.random(size=(4, 20, 5))

-    similarity = compute_similarity_with_templates_array(templates_array, other_templates_array, **params)
+    similarity, lags = compute_similarity_with_templates_array(templates_array, other_templates_array, **params)
     print(similarity.shape)
+    print(lags)


 pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")
@@ -141,5 +142,5 @@ def test_equal_results_numba(params):
     test.cache_folder = Path("./cache_folder")
     test.test_extension(params=dict(method="l2"))

-    # params = dict(method="cosine", num_shifts=8)
-    # test_compute_similarity_with_templates_array(params)
+    params = dict(method="cosine", num_shifts=8)
+    test_compute_similarity_with_templates_array(params)
```
97 changes: 62 additions & 35 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
```diff
@@ -15,15 +15,15 @@
     _set_optimal_chunk_size,
 )
 from spikeinterface.core.basesorting import minimum_spike_dtype
+from spikeinterface.core import compute_sparsity


 class Spykingcircus2Sorter(ComponentsBasedSorter):
     sorter_name = "spykingcircus2"

     _default_params = {
-        "general": {"ms_before": 0.5, "ms_after": 1.5, "radius_um": 100},
-        "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
-        "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10},
+        "general": {"ms_before": 0.5, "ms_after": 1.5, "radius_um": 100.0},
+        "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 20},
         "whitening": {"mode": "local", "regularize": False},
         "detection": {
             "method": "matched_filtering",
@@ -38,8 +38,10 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "motion_correction": {"preset": "dredge_fast"},
         "merging": {"max_distance_um": 50},
         "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()},
+        "cleaning": {"min_snr": 5, "max_jitter_ms": 0.1, "sparsify_threshold": None},
         "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()},
         "apply_preprocessing": True,
+        "apply_whitening": True,
         "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
         "chunk_preprocessing": {"memory_limit": None},
         "multi_units_only": False,
```
```diff
@@ -85,7 +87,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):

     @classmethod
     def get_sorter_version(cls):
-        return "2025.09"
+        return "2025.10"

     @classmethod
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
@@ -114,30 +116,50 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         num_channels = recording.get_num_channels()
         ms_before = params["general"].get("ms_before", 0.5)
         ms_after = params["general"].get("ms_after", 1.5)
-        radius_um = params["general"].get("radius_um", 100)
+        radius_um = params["general"].get("radius_um", 100.0)
         detect_threshold = params["detection"]["method_kwargs"].get("detect_threshold", 5)
         peak_sign = params["detection"].get("peak_sign", "neg")
         deterministic = params["deterministic_peaks_detection"]
         debug = params["debug"]
         seed = params["seed"]
         apply_preprocessing = params["apply_preprocessing"]
+        apply_whitening = params["apply_whitening"]
         apply_motion_correction = params["apply_motion_correction"]
         exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after))

         ## First, we are filtering the data
         filtering_params = params["filtering"].copy()
         if apply_preprocessing:
             if verbose:
-                print("Preprocessing the recording (bandpass filtering + CMR + whitening)")
+                if apply_whitening:
+                    print("Preprocessing the recording (bandpass filtering + CMR + whitening)")
+                else:
+                    print("Preprocessing the recording (bandpass filtering + CMR)")
             recording_f = bandpass_filter(recording, **filtering_params, dtype="float32")
-            if num_channels > 1:
+            if num_channels >= 32:
                 recording_f = common_reference(recording_f)
         else:
             if verbose:
                 print("Skipping preprocessing (whitening only)")
             recording_f = recording
             recording_f.annotate(is_filtered=True)

+        if apply_whitening:
+            ## We need to whiten before the template matching step, to boost the results
+            # TODO add , regularize=True chen ready
+            whitening_kwargs = params["whitening"].copy()
+            whitening_kwargs["dtype"] = "float32"
+            whitening_kwargs["seed"] = params["seed"]
+            whitening_kwargs["regularize"] = whitening_kwargs.get("regularize", False)
+            if num_channels == 1:
+                whitening_kwargs["regularize"] = False
+            if whitening_kwargs["regularize"]:
+                whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"}
+                whitening_kwargs["apply_mean"] = True
+            recording_w = whiten(recording_f, **whitening_kwargs)
+        else:
+            recording_w = recording_f
+
         valid_geometry = check_probe_for_drift_correction(recording_f)
         if apply_motion_correction:
             if not valid_geometry:
@@ -151,27 +173,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             motion_correction_kwargs = params["motion_correction"].copy()
             motion_correction_kwargs.update({"folder": motion_folder})
             noise_levels = get_noise_levels(
-                recording_f, return_in_uV=False, random_slices_kwargs={"seed": seed}, **job_kwargs
+                recording_w, return_in_uV=False, random_slices_kwargs={"seed": seed}, **job_kwargs
             )
             motion_correction_kwargs["detect_kwargs"] = {"noise_levels": noise_levels}
-            recording_f = correct_motion(recording_f, **motion_correction_kwargs, **job_kwargs)
+            recording_w = correct_motion(recording_w, **motion_correction_kwargs, **job_kwargs)
         else:
             motion_folder = None

-        ## We need to whiten before the template matching step, to boost the results
-        # TODO add , regularize=True chen ready
-        whitening_kwargs = params["whitening"].copy()
-        whitening_kwargs["dtype"] = "float32"
-        whitening_kwargs["seed"] = params["seed"]
-        whitening_kwargs["regularize"] = whitening_kwargs.get("regularize", False)
-        if num_channels == 1:
-            whitening_kwargs["regularize"] = False
-        if whitening_kwargs["regularize"]:
-            whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"}
-            whitening_kwargs["apply_mean"] = True
-
-        recording_w = whiten(recording_f, **whitening_kwargs)
-
         noise_levels = get_noise_levels(
             recording_w, return_in_uV=False, random_slices_kwargs={"seed": seed}, **job_kwargs
         )
```
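Reviewer note: the net effect of these hunks is a reordering of the chain: whitening now runs right after filtering/CMR and before motion correction (which previously operated on `recording_f`), and CMR is now gated on `num_channels >= 32` instead of `> 1`. A condensed sketch of the new order, with the guards and kwargs trimmed (see the diff for the real logic):

```python
# Condensed sketch of the reordered preprocessing chain (kwargs trimmed; the
# real code also handles apply_preprocessing / apply_whitening / geometry checks).
from spikeinterface.core import generate_recording
from spikeinterface.preprocessing import bandpass_filter, common_reference, whiten

recording = generate_recording(num_channels=32, durations=[10.0])
recording_f = bandpass_filter(recording, freq_min=150, freq_max=7000, dtype="float32")
if recording_f.get_num_channels() >= 32:       # CMR now requires >= 32 channels
    recording_f = common_reference(recording_f)
recording_w = whiten(recording_f, mode="local", dtype="float32")
# ...motion correction (correct_motion) now runs on recording_w, not recording_f
```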
```diff
@@ -325,18 +333,33 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         if not clustering_from_svd:
             from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording

-            templates = get_templates_from_peaks_and_recording(
+            dense_templates = get_templates_from_peaks_and_recording(
                 recording_w,
                 selected_peaks,
                 peak_labels,
                 ms_before,
                 ms_after,
                 job_kwargs=job_kwargs,
             )

+            sparsity = compute_sparsity(dense_templates, method="radius", radius_um=radius_um)
+            threshold = params["cleaning"].get("sparsify_threshold", None)
+            if threshold is not None:
+                sparsity_snr = compute_sparsity(
+                    dense_templates,
+                    method="snr",
+                    amplitude_mode="peak_to_peak",
+                    noise_levels=noise_levels,
+                    threshold=threshold,
+                )
+                sparsity.mask = sparsity.mask & sparsity_snr.mask
+
+            templates = dense_templates.to_sparse(sparsity)

         else:
             from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd

-            templates, _ = get_templates_from_peaks_and_svd(
+            dense_templates, new_sparse_mask = get_templates_from_peaks_and_svd(
                 recording_w,
                 selected_peaks,
                 peak_labels,
@@ -348,16 +371,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 operator="median",
             )
             # this release the peak_svd memmap file
+            templates = dense_templates.to_sparse(new_sparse_mask)

             del more_outs

-        templates = clean_templates(
-            templates,
-            noise_levels=noise_levels,
-            min_snr=detect_threshold,
-            max_jitter_ms=0.1,
-            remove_empty=True,
-        )
+        cleaning_kwargs = params.get("cleaning", {}).copy()
+        cleaning_kwargs["noise_levels"] = noise_levels
+        cleaning_kwargs["remove_empty"] = True
+        templates = clean_templates(templates, **cleaning_kwargs)

         if verbose:
             print("Kept %d clean clusters" % len(templates.unit_ids))
```
```diff
@@ -416,7 +437,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

         if sorting.get_non_empty_unit_ids().size > 0:
             final_analyzer = final_cleaning_circus(
-                recording_w, sorting, templates, job_kwargs=job_kwargs, **merging_params
+                recording_w,
+                sorting,
+                templates,
+                noise_levels=noise_levels,
+                job_kwargs=job_kwargs,
+                **merging_params,
             )
             final_analyzer.save_as(format="binary_folder", folder=sorter_output_folder / "final_analyzer")

@@ -451,14 +477,15 @@ def final_cleaning_circus(
     max_distance_um=50,
     template_diff_thresh=np.arange(0.05, 0.5, 0.05),
     debug_folder=None,
-    job_kwargs=None,
+    noise_levels=None,
+    job_kwargs=dict(),
 ):

     from spikeinterface.sortingcomponents.tools import create_sorting_analyzer_with_existing_templates
     from spikeinterface.curation.auto_merge import auto_merge_units

     # First we compute the needed extensions
-    analyzer = create_sorting_analyzer_with_existing_templates(sorting, recording, templates)
+    analyzer = create_sorting_analyzer_with_existing_templates(sorting, recording, templates, noise_levels=noise_levels)
     analyzer.compute("unit_locations", method="center_of_mass", **job_kwargs)
     analyzer.compute("template_similarity", **similarity_kwargs)
```
33 changes: 20 additions & 13 deletions src/spikeinterface/sorters/internal/tridesclous2.py
```diff
@@ -41,15 +41,15 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
         },
         "filtering": {
             "freq_min": 150.0,
-            "freq_max": 5000.0,
+            "freq_max": 6000.0,
             "ftype": "bessel",
             "filter_order": 2,
         },
         "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0},
         "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000},
-        "svd": {"n_components": 4},
+        "svd": {"n_components": 10},
         "clustering": {
-            "recursive_depth": 5,
+            "recursive_depth": 3,
         },
         "templates": {
             "ms_before": 2.0,
@@ -84,7 +84,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):

     @classmethod
     def get_sorter_version(cls):
-        return "2025.09"
+        return "2025.10"

     @classmethod
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
@@ -130,7 +130,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         if verbose:
             print("Done correct_motion()")

-        recording = bandpass_filter(recording_raw, **params["filtering"], dtype="float32")
+        recording = bandpass_filter(recording_raw, **params["filtering"], margin_ms=20.0, dtype="float32")
         if apply_cmr:
             recording = common_reference(recording)

@@ -197,6 +197,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         #     "You want to run tridesclous2 with the isosplit6 (the C++) implementation, but this is not installed, please `pip install isosplit6`"
         # )

+        # whitenning do not improve in tdc2
+        # recording_w = whiten(recording, mode="global")
+
         unit_ids, clustering_label, more_outs = find_clusters_from_peaks(
             recording,
             peaks,
@@ -206,9 +209,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             job_kwargs=job_kwargs,
         )

-        # peak_shifts = extra_out["peak_shifts"]
-        # new_peaks = peaks.copy()
-        # new_peaks["sample_index"] -= peak_shifts
         new_peaks = peaks

         mask = clustering_label >= 0
@@ -252,16 +252,23 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             is_in_uV=False,
         )

+        # sparsity is a mix between radius and
         sparsity_threshold = params["templates"]["sparsity_threshold"]
-        sparsity = compute_sparsity(
-            templates_dense, method="snr", noise_levels=noise_levels, threshold=sparsity_threshold
+        radius_um = params["waveforms"]["radius_um"]
+        sparsity = compute_sparsity(templates_dense, method="radius", radius_um=radius_um)
+        sparsity_snr = compute_sparsity(
+            templates_dense,
+            method="snr",
+            amplitude_mode="peak_to_peak",
+            noise_levels=noise_levels,
+            threshold=sparsity_threshold,
         )
+        sparsity.mask = sparsity.mask & sparsity_snr.mask
+        templates = templates_dense.to_sparse(sparsity)
+        # templates = remove_empty_templates(templates)

         templates = clean_templates(
-            templates_dense,
-            sparsify_threshold=params["templates"]["sparsity_threshold"],
+            templates,
+            sparsify_threshold=None,
             noise_levels=noise_levels,
             min_snr=params["templates"]["min_snr"],
             max_jitter_ms=None,
```
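Reviewer note: tridesclous2 gets the same radius-AND-SNR sparsity treatment as spykingcircus2 above, plus parameter retuning (`freq_max` 5000 → 6000 Hz, SVD `n_components` 4 → 10, `recursive_depth` 5 → 3). For context on the SVD bump: the features feeding the iterative clustering are per-snippet projections, so more components means a richer feature space per peak. A toy sketch of what 10 components means (fabricated waveforms; sklearn's TruncatedSVD stands in for the internal implementation):

```python
import numpy as np
from sklearn.decomposition import TruncatedSVD

# Toy sketch: project waveform snippets onto 10 SVD components, the new
# tdc2 default (was 4). The real code operates on detected peak waveforms.
rng = np.random.default_rng(0)
snippets = rng.standard_normal((5000, 60))    # (n_peaks, n_samples), fabricated
features = TruncatedSVD(n_components=10).fit_transform(snippets)
print(features.shape)                          # (5000, 10)
```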