diff --git a/examples/example_experimental_data.py b/examples/example_experimental_data.py
index 89feb1b..d60749f 100644
--- a/examples/example_experimental_data.py
+++ b/examples/example_experimental_data.py
@@ -9,14 +9,34 @@
 The available denoising methods are "nordic", "mp-pca", "hybrid-pca",
 "opt-fro", "opt-nuc" and "opt-op".
 """
-from patch_denoise.simulation.phantom import mr_shepp_logan_t2_star, g_factor_map
-from patch_denoise.simulation.activations import add_frames
-from patch_denoise.simulation.noise import add_temporal_gaussian_noise
+import timeit
+
+import nibabel as nib
+
+from patch_denoise.space_time.lowrank import OptimalSVDDenoiser
 
 # %%
 # Setup the parameters for the simulation and noise
-SHAPE = (64, 64, 64)
-N_FRAMES = 200
+# SHAPE = (64, 64, 64)
+# N_FRAMES = 200
 
-NOISE_LEVEL = 2
+# NOISE_LEVEL = 2
+
+base_path = "/data/parietal/store2/data/ibc/"
+# input_path = base_path + "3mm/sub-01/ses-00/func/wrdcsub-01_ses-00_task-ArchiSocial_dir-ap_bold.nii.gz"
+input_path = base_path + "sourcedata/sub-01/ses-00/func/sub-01_ses-00_task-ArchiSocial_dir-ap_bold.nii.gz"
+output_path = "/scratch/ymzayek/retreat_data/output.nii"
+
+img = nib.load(input_path)
+
+print(f"Data shape is {img.shape} with affine\n{img.affine}")
+
+patch_shape = (11, 11, 11)
+patch_overlap = (5, 5, 5)
+
+# Initialize the denoiser.
+optimal_llr = OptimalSVDDenoiser(patch_shape, patch_overlap)
+
+# Denoise the image on the GPU and report the elapsed time.
+time_start = timeit.default_timer()
+denoised = optimal_llr.denoise(img.get_fdata(), engine="gpu", batch_size=100)
+print(f"Denoising took {timeit.default_timer() - time_start:.2f} s")
+
+# denoise returns (output, weights, noise std estimate, rank map);
+# save the denoised volume with the original affine.
+nib.save(nib.Nifti1Image(denoised[0], img.affine), output_path)
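A quick sanity check for the new GPU path is to run the same denoiser with both engines on a small synthetic volume and compare the outputs. This is a minimal sketch, assuming CuPy and a CUDA device are available; the shapes, seed, and batch size are illustrative and not taken from the example above.

import numpy as np

from patch_denoise.space_time.lowrank import OptimalSVDDenoiser

rng = np.random.default_rng(0)
data = rng.standard_normal((32, 32, 32, 50)).astype(np.float32)

denoiser = OptimalSVDDenoiser((11, 11, 11), (5, 5, 5))
cpu_out = denoiser.denoise(data, engine="cpu")[0]
gpu_out = denoiser.denoise(data, engine="gpu", batch_size=100)[0]

# Both engines should agree up to floating-point differences.
print(np.max(np.abs(cpu_out - gpu_out)))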
diff --git a/src/patch_denoise/space_time/base.py b/src/patch_denoise/space_time/base.py
index 3b0a660..9e11d76 100644
--- a/src/patch_denoise/space_time/base.py
+++ b/src/patch_denoise/space_time/base.py
@@ -3,10 +3,11 @@
 import logging
 
 import numpy as np
 from tqdm.auto import tqdm
+import cupy as cp
 
 from .._docs import fill_doc
-from .utils import get_patch_locs
+from .utils import get_patch_locs, get_patches_gpu
 
 
 @fill_doc
@@ -33,7 +34,13 @@
         self.input_denoising_kwargs = dict()
 
     @fill_doc
-    def denoise(self, input_data, mask=None, mask_threshold=50, progbar=None):
+    def denoise(
+        self,
+        input_data,
+        mask=None,
+        mask_threshold=50,
+        progbar=None,
+    ):
         """Denoise the input_data, according to mask.
 
         Patches are extracted sequentially and processed by the implemented
@@ -129,8 +136,111 @@
             noise_std_estimate[patch_slice] += noise_var
             # the top left corner of the patch is used as id for the patch.
             rank_map[patch_center_img] = maxidx
+            if progbar:
+                progbar.update()
+
+        # Averaging the overlapping pixels.
+        # This is only needed for the 'average' and 'weighted' recombinations.
+        if self.recombination in ["average", "weighted"]:
+            output_data /= patchs_weight[..., None]
+            noise_std_estimate /= patchs_weight
+
+        output_data[~process_mask] = 0
+
+        return output_data, patchs_weight, noise_std_estimate, rank_map
+
+    def denoise_gpu(
+        self,
+        input_data,
+        mask=None,
+        mask_threshold=50,
+        progbar=None,
+        batch_size=None,
+    ):
+        """Denoise the input_data on the GPU, according to mask.
+
+        All patches are extracted at once, processed in batches of
+        ``batch_size`` by ``_patch_processing_gpu``, and recombined
+        exactly as in ``denoise``.
+        """
+        data_shape = input_data.shape
+        output_data = np.zeros_like(input_data)
+        rank_map = np.zeros(data_shape[:-1], dtype=np.int32)
+        # Create default mask
+        if mask is None:
+            process_mask = np.full(data_shape[:-1], True)
+        else:
+            process_mask = np.copy(mask)
+
+        patch_shape, patch_overlap = self.__get_patch_param(data_shape)
+        patch_size = np.prod(patch_shape)
+
+        if self.recombination == "center":
+            patch_center = (
+                *(slice(ps // 2, ps // 2 + 1) for ps in patch_shape),
+                slice(None, None, None),
+            )
+        patchs_weight = np.zeros(data_shape[:-1], np.float32)
+        noise_std_estimate = np.zeros(data_shape[:-1], dtype=np.float32)
+
+        # Discard patches that do not overlap enough with the mask.
+        patch_locs = get_patch_locs(patch_shape, patch_overlap, data_shape[:-1])
+        get_it = np.zeros(len(patch_locs), dtype=bool)
+
+        patch_slices = []
+        for i, patch_tl in enumerate(patch_locs):
+            patch_slice = tuple(
+                slice(tl, tl + ps) for tl, ps in zip(patch_tl, patch_shape)
+            )
+            if 100 * np.sum(process_mask[patch_slice]) / patch_size > mask_threshold:
+                get_it[i] = True
+                patch_slices.append(patch_slice)
+
+        logging.info(f"Denoise {100 * np.sum(get_it) / len(patch_locs):.2f}% patches")
+        patch_locs = np.ascontiguousarray(patch_locs[get_it])
+
+        if progbar is None:
+            progbar = tqdm(total=len(patch_locs))
+        elif progbar is not False:
+            progbar.reset(total=len(patch_locs))
+
+        # Extract all patches at once; this assumes get_patches_gpu yields one
+        # patch per location of get_patch_locs, in the same order, so the
+        # boolean mask computed above can be used to select them.
+        patches = get_patches_gpu(input_data, patch_shape, patch_overlap)
+        patches = patches[get_it]
+        # Replace NaN samples by the mean of the valid ones (np.mean over an
+        # array containing NaN would itself be NaN).
+        patches[np.isnan(patches)] = np.nanmean(patches)
+
+        patches_denoise, patches_maxidx, noise_var = self._patch_processing_gpu(
+            patches,
+            patch_slices=patch_slices,
+            batch_size=batch_size,
+            **self.input_denoising_kwargs,
+        )
+        patches_denoise = cp.asnumpy(patches_denoise)
+        patches_maxidx = cp.asnumpy(patches_maxidx)
+        for patch_tl, p_denoise, maxidx in zip(
+            patch_locs, patches_denoise, patches_maxidx
+        ):
+            patch_slice = tuple(
+                slice(tl, tl + ps) for tl, ps in zip(patch_tl, patch_shape)
+            )
+            process_mask[patch_slice] = 1
+            p_denoise = np.reshape(p_denoise, (*patch_shape, -1))
+            patch_center_img = tuple(
+                ptl + ps // 2 for ptl, ps in zip(patch_tl, patch_shape)
+            )
+            if self.recombination == "center":
+                output_data[patch_center_img] = p_denoise[patch_center]
+                noise_std_estimate[patch_center_img] += noise_var
+            elif self.recombination == "weighted":
+                theta = 1 / (2 + maxidx)
+                output_data[patch_slice] += p_denoise * theta
+                patchs_weight[patch_slice] += theta
+            elif self.recombination == "average":
+                output_data[patch_slice] += p_denoise
+                patchs_weight[patch_slice] += 1
+            else:
+                raise ValueError(
+                    "recombination must be one of 'weighted', 'average', "
+                    "'center'."
+                )
+            if not np.isnan(noise_var):
+                noise_std_estimate[patch_slice] += noise_var
+            # the top left corner of the patch is used as id for the patch.
+            rank_map[patch_center_img] = maxidx
             if progbar:
                 progbar.update()
+
         # Averaging the overlapping pixels.
         # This is only needed for the 'average' and 'weighted' recombinations.
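The weighted recombination used by both denoise paths is easiest to see on a toy 1-D example: each patch contributes with weight theta = 1 / (2 + rank), and overlapping voxels are divided by their accumulated weight at the end. A minimal NumPy sketch with made-up patches and ranks:

import numpy as np

signal = np.zeros(8)
weights = np.zeros(8)
# Two overlapping "patches": (location, denoised values, estimated rank).
patches = [(slice(0, 4), np.ones(4), 3), (slice(2, 6), 2 * np.ones(4), 1)]

for sl, p_denoise, maxidx in patches:
    theta = 1 / (2 + maxidx)  # low-rank patches get more weight
    signal[sl] += p_denoise * theta
    weights[sl] += theta

# Normalize the overlapping contributions.
signal[weights > 0] /= weights[weights > 0]
print(signal)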
@@ -140,6 +250,7 @@
         if self.recombination in ["average", "weighted"]:
             output_data /= patchs_weight[..., None]
             noise_std_estimate /= patchs_weight
 
         output_data[~process_mask] = 0
 
         return output_data, patchs_weight, noise_std_estimate, rank_map
+
     @abc.abstractmethod
     def _patch_processing(self, patch, patch_slice=None, **kwargs):
diff --git a/src/patch_denoise/space_time/lowrank.py b/src/patch_denoise/space_time/lowrank.py
index 69d6b26..b8ebb15 100644
--- a/src/patch_denoise/space_time/lowrank.py
+++ b/src/patch_denoise/space_time/lowrank.py
@@ -4,6 +4,7 @@
 import numpy as np
 from scipy.linalg import svd
 from scipy.optimize import minimize
+import cupy as cp
 
 from .base import BaseSpaceTimeDenoiser
 from .utils import (
@@ -11,6 +12,7 @@
     eig_synthesis,
     marshenko_pastur_median,
     svd_analysis,
+    svd_analysis_gpu,
     svd_synthesis,
 )
 from .._docs import fill_doc
@@ -320,6 +322,8 @@
         noise_std=None,
         eps_marshenko_pastur=1e-7,
         progbar=None,
+        engine="cpu",
+        batch_size=None,
     ):
         """
         Optimal thresholding denoising method.
@@ -364,7 +368,20 @@
         else:
             self.input_denoising_kwargs["var_apriori"] = noise_std**2
 
-        return super().denoise(input_data, mask, mask_threshold, progbar=progbar)
+        if engine == "cpu":
+            return super().denoise(
+                input_data, mask, mask_threshold, progbar=progbar
+            )
+        elif engine == "gpu":
+            return super().denoise_gpu(
+                input_data,
+                mask,
+                mask_threshold,
+                progbar=progbar,
+                batch_size=batch_size,
+            )
+        else:
+            raise ValueError(f"Unknown engine: {engine}. Use 'cpu' or 'gpu'.")
 
     def _patch_processing(
         self,
@@ -396,6 +413,58 @@
 
         return p_new, maxidx, np.NaN
 
+    def _patch_processing_gpu(
+        self,
+        patches,
+        patch_slices=None,
+        shrink_func=None,
+        mp_median=None,
+        var_apriori=None,
+        batch_size=None,
+    ):
+        """Process all masked patches at once on the GPU, in SVD batches."""
+        if batch_size is None:
+            batch_size = patches.shape[0]
+        u_vec, s_values, v_vec, p_tmean = svd_analysis_gpu(
+            patches, batch_size=batch_size
+        )
+        if var_apriori is not None:
+            # One noise level per patch, from the a-priori variance map.
+            sigma = cp.asarray(
+                [np.mean(np.sqrt(var_apriori[ps])) for ps in patch_slices]
+            )
+        else:
+            sigma = cp.median(s_values, axis=1) / np.sqrt(
+                patches.shape[-1] * mp_median
+            )
+
+        scale_factor = (np.sqrt(patches.shape[-1]) * sigma)[..., None]
+        thresh_s_values = scale_factor * shrink_func(
+            s_values / scale_factor,
+            beta=patches.shape[-1] / patches.shape[-2],
+        )
+        thresh_s_values[cp.isnan(thresh_s_values)] = 0
+
+        # Split the batch into patches with and without surviving components.
+        check_any = cp.any(thresh_s_values, axis=1)
+        indices_true = cp.asnumpy(cp.nonzero(check_any)[0])
+        indices_false = cp.asnumpy(cp.nonzero(~check_any)[0])
+        maxidx = cp.zeros(thresh_s_values.shape[0], dtype=cp.int32)
+        p_new = cp.zeros(patches.shape)
+
+        for i in indices_true:
+            # Keep the components up to the last nonzero singular value.
+            k = int(cp.nonzero(thresh_s_values[i])[0].max()) + 1
+            maxidx[i] = k
+            p_new[i] = (
+                u_vec[i, :, :k]
+                @ (thresh_s_values[i, :k, None] * v_vec[i, :k, :])
+            ) + p_tmean[i, :]
+        for i in indices_false:
+            # No component survived the threshold: keep the temporal mean.
+            p_new[i] = p_tmean[i, :]
+
+        return p_new, maxidx, np.nan
+
 
 def _sure_atn_cost(X, method, sing_vals, gamma, sigma=None, tau=None):
     """
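The core of _patch_processing_gpu is the same rank truncation as the CPU _patch_processing: shrink the singular values of the centered patch, keep the components up to the last nonzero one, and rebuild. A NumPy-only sketch of that step on a single patch, with a crude hard threshold standing in for the optimal shrinker used by the class:

import numpy as np

rng = np.random.default_rng(1)
patch = rng.standard_normal((121, 40))  # (patch size, time)

mean = patch.mean(axis=0)
u, s, v = np.linalg.svd(patch - mean, full_matrices=False)
s_thresh = np.where(s > s.mean(), s, 0.0)  # stand-in for shrink_func

k = 0
if s_thresh.any():
    k = int(np.nonzero(s_thresh)[0].max()) + 1
    rebuilt = u[:, :k] @ (s_thresh[:k, None] * v[:k, :]) + mean
else:
    # Nothing survived the threshold: fall back to the mean signal.
    rebuilt = np.zeros_like(patch) + mean
print(rebuilt.shape, k)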
diff --git a/src/patch_denoise/space_time/utils.py b/src/patch_denoise/space_time/utils.py
index 0ad13aa..9d9ad72 100644
--- a/src/patch_denoise/space_time/utils.py
+++ b/src/patch_denoise/space_time/utils.py
@@ -2,6 +2,7 @@
 import numpy as np
 from scipy.integrate import quad
 from scipy.linalg import eigh, svd
+import cupy as cp
 
 
 def svd_analysis(input_data):
@@ -19,14 +20,47 @@
     -------
     u_vec, s_vals, v_vec, mean
     """
+    # TODO benchmark svd vs svds and order of data.
     mean = np.mean(input_data, axis=0)
     data_centered = input_data - mean
-    # TODO benchmark svd vs svds and order of data.
     u_vec, s_vals, v_vec = svd(data_centered, full_matrices=False)
 
     return u_vec, s_vals, v_vec, mean
 
 
+def svd_analysis_gpu(input_data, batch_size):
+    """Batched, centered SVD on the GPU.
+
+    Parameters
+    ----------
+    input_data : numpy.ndarray
+        Patches of shape (total patches, patch size, time).
+    batch_size : int
+        Number of patches decomposed per GPU batch.
+
+    Returns
+    -------
+    U_batched, S_batched, V_batched, mean_batched
+    """
+    total_samples = input_data.shape[0]
+
+    # Initialize arrays to store the results.
+    m = input_data.shape[1]
+    n = input_data.shape[2]
+    k = min(m, n)
+    U_batched = cp.empty((total_samples, m, k), dtype=cp.float64)
+    S_batched = cp.empty((total_samples, k), dtype=cp.float64)
+    V_batched = cp.empty((total_samples, k, n), dtype=cp.float64)
+    mean_batched = cp.empty((total_samples, n), dtype=cp.float64)
+
+    # Compute the SVD batch by batch; the last batch may be smaller, but no
+    # sample is ever skipped.
+    for start_idx in range(0, total_samples, batch_size):
+        idx = slice(start_idx, min(start_idx + batch_size, total_samples))
+        batch = cp.asarray(input_data[idx])
+        # Center each patch over its voxel axis before the decomposition.
+        mean = cp.mean(batch, axis=1, keepdims=True)
+        u_vec, s_vals, v_vec = cp.linalg.svd(batch - mean, full_matrices=False)
+        U_batched[idx] = u_vec
+        S_batched[idx] = s_vals
+        V_batched[idx] = v_vec
+        mean_batched[idx] = mean[:, 0, :]
+    return U_batched, S_batched, V_batched, mean_batched
+
+
 def svd_synthesis(u_vec, s_vals, v_vec, mean, idx):
     """
     Reconstruct ``X = (U @ (S * V)) + M`` with only the max_idx greatest component.
@@ -197,6 +231,43 @@
     return patch_locs.reshape(-1, len(p_shape))
 
 
+def get_patches_gpu(input_data, patch_shape, patch_overlap):
+    """Extract all the patches from a volume.
+
+    Returns
+    -------
+    numpy.ndarray
+        All the patches, of shape (patches, patch size, time).
+    """
+    patch_size = np.prod(patch_shape)
+
+    input_data = cp.asarray(input_data)
+
+    c, h, w, t_s = input_data.shape
+    kc, kh, kw = patch_shape  # kernel size
+    # Per-axis stride between consecutive patches.
+    sc, sh, sw = (ks - ov for ks, ov in zip(patch_shape, patch_overlap))
+
+    # Pad each spatial axis so the strided windows cover the whole volume.
+    needed_c = int(np.ceil((c - kc) / sc) * sc - (c - kc))
+    needed_h = int(np.ceil((h - kh) / sh) * sh - (h - kh))
+    needed_w = int(np.ceil((w - kw) / sw) * sw - (w - kw))
+
+    input_data_padded = cp.pad(
+        input_data,
+        ((0, needed_c), (0, needed_h), (0, needed_w), (0, 0)),
+        mode="edge",
+    )
+
+    patches = cp.lib.stride_tricks.sliding_window_view(
+        input_data_padded, patch_shape, axis=(0, 1, 2)
+    )[::sc, ::sh, ::sw]
+
+    # (nc, nh, nw, t, kc, kh, kw) -> (nc, nh, nw, kc, kh, kw, t)
+    patches = patches.transpose((0, 1, 2, 4, 5, 6, 3))
+    patches = patches.reshape((np.prod(patches.shape[:3]), patch_size, t_s))
+
+    return cp.asnumpy(patches)
+
+
 def estimate_noise(noise_sequence, block_size=1):
     """Estimate the temporal noise standard deviation of a noise only sequence."""
     volume_shape = noise_sequence.shape[:-1]
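denoise_gpu assumes that get_patches_gpu returns exactly one patch per location produced by get_patch_locs, in the same order, so that the same boolean mask can index both. A small sketch to check that assumption on a given volume shape (requires CuPy and a CUDA device):

import numpy as np

from patch_denoise.space_time.utils import get_patch_locs, get_patches_gpu

shape = (32, 32, 32, 10)
patch_shape, patch_overlap = (11, 11, 11), (5, 5, 5)

data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
locs = get_patch_locs(patch_shape, patch_overlap, shape[:-1])
patches = get_patches_gpu(data, patch_shape, patch_overlap)

# The counts must match for the masking in denoise_gpu to be valid.
print(len(locs), patches.shape[0])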