% rigidBodyCorrection() - Compute a robust rigid-body template from the good
%                         part of the data, and apply it to the rest.
%                         The procedure follows these steps.
%                         1. Unless all LED markers are simultaneously present
%                            for more than 10 s, keep only the channel combination
%                            that is present in the most frames (among frames with
%                            the maximum, or maximum-1, number of markers) and
%                            exclude the other channels.
%                         2. Compute a robust template from data points
%                            with all LED markers.
%                         3. Fit the template where all LED markers are present
%                            (Lv.1 correction). 
%                         4. Fit the template where at least 3 LED markers
%                            are present (Lv.2 correction).
%                         5. Spline-interpolate where fewer than 3 LED
%                            markers are available, then fit the rigid-body template
%                            (Lv.3 correction).
%                         6. Detect high-frequency (>2 Hz) spikes in the overall data,
%                            spline-interpolate them, and fit the rigid-body template
%                            (Lv.4 correction).
%
% Use:
%   >> [correctedCoordinates, lessThanThreeLedMarkersIdx, correctionLevelIndicator, meanRmsError, rollPitchYaw] = rigidBodyCorrection(inputCoordinates)
%
% Inputs:
%   'inputCoordinates'           - 3xnxk for xyz coordinates, n channels, and k time points.
%                                  Zeros are treated as missing samples (NaN).
%                                  An EEG structure (EEG.srate) must exist in the base
%                                  workspace; the sampling rate is read from it.
%
% Outputs:
%   'correctedCoordinates'       - 3xnxk for xyz coordinates, n channels, and k time points.
%                                  Excluded channels, and frames after the last
%                                  recoverable frame, are NaN.
%   'lessThanThreeLedMarkersIdx' - 1xm indices of frames where fewer than 3
%                                  LED markers are available, i.e., the rigid-body fit
%                                  could not be applied to the original data; they are
%                                  spline-interpolated first, then fit with the rigid body.
%   'correctionLevelIndicator'   - kx1 frame-by-frame correction level
%                                  (Lv.1-Lv.4; see above for the definition of the levels;
%                                  0 after the last recoverable frame).
%   'meanRmsError'               - kx1 RMS error of the rigid-body fit for each frame, as
%                                  returned by rigidBodyTransformation(); 0 for frames
%                                  with fewer than 3 LED markers.
%   'rollPitchYaw'               - mx3 roll, pitch, yaw of the fitted rigid body for frames
%                                  1..m (m = last recoverable frame), computed with t2x().

% Author:
%    Makoto Miyakoshi. SCCN, INC, UCSD. mmiyakoshi@ucsd.edu
%
% History:
%    12/20/2017 Makoto. Use convhull() to compute area when only 3 points are available.
%    12/19/2017 Makoto. Replace distance matrix calculation with volume calculation using convhull().
%    12/18/2017 Makoto. When building a template, frames are rotated (how did it drop off...)
%    09/07/2017 Makoto. >50% NaN channel is excluded. 
%    08/10/2017 Makoto. Created.

% This program is free software; you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation; either version 2 of the License, or
% (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

function [correctedCoordinates, lessThanThreeLedMarkersIdx, correctionLevelIndicator,...
                         meanRmsError, rollPitchYaw] = rigidBodyCorrection(inputCoordinates)

% Notes for maintainers:
%   - Zeros in inputCoordinates are treated as missing samples (NaN).
%   - The sampling rate is read from EEG.srate in the base workspace, so an EEG
%     structure must exist there when this function runs.
%   - rigidBodyTransformation(), customInterpolation(), customSpikeDetection()
%     and t2x() are external helpers not defined in this function.

% Frame-by-frame correction level (0 = none, 1-4 = Lv.1-Lv.4).
correctionLevelIndicator = zeros(size(inputCoordinates,3),1);

% Replace zeros with NaN.
nanInputCoordinates = inputCoordinates;
nanInputCoordinates(nanInputCoordinates==0) = NaN;

% Sampling rate of the recording (read once and reused below).
srate = evalin('base', 'EEG.srate');



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Keep all channels if they are simultaneously available for longer than 10 s;      %%%
%%% otherwise keep only the channel combination that is available in the most frames. %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% channelAvailable is channels x frames; permute() keeps that shape even for a
% single frame or a single channel, where squeeze() would not.
channelAvailable  = permute(~isnan(nanInputCoordinates(1,:,:)), [2 3 1]);
numNonNanChannels = sum(channelAvailable, 1);
maxAvailableIdx   = find(numNonNanChannels==max(numNonNanChannels));
if max(numNonNanChannels) == size(nanInputCoordinates,2) && length(maxAvailableIdx) > srate*10
    deadChannelFlag = 0;
else
    deadChannelFlag = 1;
    
    % If frames with the maximum channel count last less than 10 s, use frames
    % with (maximum-1) channels instead, hoping to capture a longer stretch (no
    % guarantee). Keep the maximum set if no (maximum-1) frame exists.
    if length(maxAvailableIdx) < srate*10
        fallbackIdx = find(numNonNanChannels==max(numNonNanChannels)-1);
        if ~isempty(fallbackIdx)
            maxAvailableIdx = fallbackIdx;
        end
    end
    
    % Find which combination of the channels has the most datapoints.
    % unique(...,'rows') labels each on/off pattern exactly; a weighted-sum
    % label such as sum((mask+i).^2) collides for different channel sets of
    % equal size (e.g. channels {1,5} and {2,4}).
    maxAvailableMatrix = channelAvailable(:, maxAvailableIdx);
    [channelCombinations, ~, combinationLabel] = unique(double(maxAvailableMatrix'), 'rows');
    combinationCount = accumarray(combinationLabel(:), 1);
    [~, maxIdx]      = max(combinationCount);
    goodChannelIdx   = find(channelCombinations(maxIdx,:));
    deadChannelIdx   = setdiff(1:size(nanInputCoordinates,2), goodChannelIdx);
    
    originalInputCoordinates = inputCoordinates;
    inputCoordinates = originalInputCoordinates(:, goodChannelIdx, :);
    
    % Overwrite the initial input data with the channel-rejected one.
    nanInputCoordinates = inputCoordinates;
    nanInputCoordinates(nanInputCoordinates==0) = NaN;
end



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Separate frames into 1) all LED markers are present, and 2) at least one LED channel has NaN. %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% The per-frame sum is NaN whenever any coordinate of the frame is NaN.
originalNanFrameIdx = squeeze(sum(sum(nanInputCoordinates,1),2));
allLedPresentIdx    = find(~isnan(originalNanFrameIdx));
someLedMissingIdx   = find( isnan(originalNanFrameIdx));



disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
disp('%%% 1. Build a rigid-body template. %%%')
disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')

% Compute the area (3 markers) or convex-hull volume (>= 4 markers) of every
% all-markers-present frame, and keep frames strictly between the 5th and 95th
% percentiles for building the template.
allLedPresentCoordinates = inputCoordinates(:,:,allLedPresentIdx);
numMarkers = size(allLedPresentCoordinates,2);
volumeList = zeros(size(allLedPresentCoordinates,3),1);

if numMarkers <= 2
    error('Unable to use minimum 3 points.')
end
if isempty(allLedPresentIdx)
    error('No frame has all selected LED markers present; unable to build a rigid-body template.')
end

if numMarkers == 3
    % Triangle area |cross(p1-p3, p2-p3)|/2; unlike a planar convex hull of raw
    % coordinate components, this does not depend on the body's orientation.
    for n = 1:size(allLedPresentCoordinates,3)
        tmpFrame = allLedPresentCoordinates(:,:,n); % xyz x markers
        volumeList(n) = norm(cross(tmpFrame(:,1)-tmpFrame(:,3), tmpFrame(:,2)-tmpFrame(:,3)))/2;
    end
    volumeList = volumeList*(100^2); % meter^2 -> cm^2, assuming input in meters.
else
    for n = 1:size(allLedPresentCoordinates,3)
        [~, volumeList(n)] = convhull(allLedPresentCoordinates(:,:,n)');
    end
    volumeList = volumeList*(100^3); % meter^3 -> cc, assuming input in meters.
end

cutoffValues     = prctile(volumeList, [5 95]);
robustResultsIdx = find(volumeList>cutoffValues(1) & volumeList<cutoffValues(2));
if isempty(robustResultsIdx)
    % Too few frames (or identical volumes) for a strict percentile window; use all frames.
    robustResultsIdx = (1:length(volumeList))';
end

% Build a template: fit every robust frame against the centered first robust
% frame with rigidBodyTransformation(), then average the fitted frames.
initialFrame = allLedPresentCoordinates(:,:,robustResultsIdx(1));
initialFrame = bsxfun(@minus, initialFrame, mean(initialFrame,2));
fitResult    = zeros(size(initialFrame,1), size(initialFrame,2), length(robustResultsIdx));
tStart = tic;
for frameIdx = 1:length(robustResultsIdx)
    if mod(frameIdx,1000)==0
        tElapsed = toc(tStart);
        disp(sprintf('%d/%d(%.0f s left)...', frameIdx, length(robustResultsIdx), round(tElapsed*(length(robustResultsIdx)/1000 - frameIdx/1000))))
        tStart = tic;
    end
    
    fitResult(:,:,frameIdx) = rigidBodyTransformation(initialFrame, allLedPresentCoordinates(:,:,robustResultsIdx(frameIdx)));
end
rigidBodyTemplate = mean(fitResult,3);
disp(sprintf('Step 1/5 done!\n\n\n'))



disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
disp('%%% 2. Fit the rigid body when all LED markers are present. %%%')
disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
meanRmsError                = zeros(size(inputCoordinates,3),1);
allLedPresentData           = zeros(size(rigidBodyTemplate,1), size(rigidBodyTemplate,2), length(allLedPresentIdx));
allLedPresentRmsError       = zeros(length(allLedPresentIdx),1);
allLedPresentRotationMatrix = zeros(3,3,length(allLedPresentIdx));
tStart = tic;
for frameIdx = 1:length(allLedPresentIdx)
    
    if mod(frameIdx,1000)==0
        tElapsed = toc(tStart);
        disp(sprintf('%d/%d(%.0f s left)...', frameIdx, length(allLedPresentIdx), round(tElapsed*(length(allLedPresentIdx)/1000 - frameIdx/1000))))
        tStart = tic;
    end
    currentData = inputCoordinates(:,:,allLedPresentIdx(frameIdx));
    [allLedPresentData(:,:,frameIdx), allLedPresentRmsError(frameIdx), allLedPresentRotationMatrix(:,:,frameIdx)] = rigidBodyTransformation(currentData, rigidBodyTemplate);
end
correctionLevelIndicator(allLedPresentIdx,1) = 1; % Level 1 correction: apply rigid body template where all LED markers are present.
meanRmsError(allLedPresentIdx)               = allLedPresentRmsError;
disp(sprintf('Step 2/5 done!\n\n\n'))



disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
disp('%%% 3. Fit the rigid body when minimum 3 LED markers are present. %%%')
disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
someMarkersMissingData           = zeros(size(rigidBodyTemplate,1), size(rigidBodyTemplate,2), length(someLedMissingIdx));
someMarkersMissingRmsError       = zeros(length(someLedMissingIdx),1);
someMarkersMissingRotationMatrix = zeros(3,3,length(someLedMissingIdx));
tStart = tic;
for frameIdx = 1:length(someLedMissingIdx)
    
    if mod(frameIdx,1000)==0
        tElapsed = toc(tStart);
        disp(sprintf('%d/%d(%.0f s left)...', frameIdx, length(someLedMissingIdx), round(tElapsed*(length(someLedMissingIdx)/1000 - frameIdx/1000))))
        tStart = tic;
    end

    currentData = nanInputCoordinates(:,:,someLedMissingIdx(frameIdx));
    nanChannelIdx = find(isnan(sum(currentData)));
    numberOfSurvivedChannels = size(currentData,2)-length(nanChannelIdx);
    if numberOfSurvivedChannels >= 3
        % Fit using only the surviving markers, then place the FULL template
        % with the resulting rotation and centroid offset.
        availableChannelIdx = setdiff(1:size(currentData,2), nanChannelIdx);
        currentDataReduced  = currentData(:,availableChannelIdx);
        rigidBodyTemplateReduced = rigidBodyTemplate(:,availableChannelIdx);
        [rotatedDataReduced, someMarkersMissingRmsError(frameIdx), someMarkersMissingRotationMatrix(:,:,frameIdx), centroidOffset] = rigidBodyTransformation(currentDataReduced, rigidBodyTemplateReduced); %#ok<ASGLU>
        templateFitData = someMarkersMissingRotationMatrix(:,:,frameIdx)*bsxfun(@minus, rigidBodyTemplate, mean(rigidBodyTemplate,2));
        templateFitData = bsxfun(@plus, templateFitData, centroidOffset);
        someMarkersMissingData(:,:,frameIdx) = templateFitData;
    else
        % Fewer than 3 markers: keep the NaN-containing frame for step 4.
        someMarkersMissingData(:,:,frameIdx) = currentData;
    end
end
correctionLevelIndicator(someLedMissingIdx,1) = 2; % Level 2 correction: some of the LED markers are missing but minimum 3 are still present.
meanRmsError(someLedMissingIdx)               = someMarkersMissingRmsError;

disp(sprintf('Step 3/5 done!\n\n\n'))



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Combine all-marker-available results and some-marker-missing results. %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
recoveredData                        = zeros(size(rigidBodyTemplate,1), size(rigidBodyTemplate,2), size(inputCoordinates,3));
recoveredData(:,:,allLedPresentIdx)  = allLedPresentData;
recoveredData(:,:,someLedMissingIdx) = someMarkersMissingData;
rotationMatrixForRecovered                        = zeros(3,3,size(inputCoordinates,3));
rotationMatrixForRecovered(:,:,allLedPresentIdx)  = allLedPresentRotationMatrix;
rotationMatrixForRecovered(:,:,someLedMissingIdx) = someMarkersMissingRotationMatrix;


    
disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
disp('%%% 4. Spline-intepolate the data when less than 3 LED markers are present. %%%')
disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
nanFrameFlag       = isnan(squeeze(sum(sum(recoveredData,1),2)));
nanDataframeIdx    = find( nanFrameFlag);
nonNanDataframeIdx = find(~nanFrameFlag);
if isempty(nonNanDataframeIdx)
    error('No frame could be recovered with at least 3 LED markers.')
end
theLastValidPoint          = nonNanDataframeIdx(end); % The rest of data are continuous NaN.
lessThanThreeLedMarkersIdx = nanDataframeIdx(nanDataframeIdx<theLastValidPoint)';
interpolatedData           = nan(size(recoveredData));

% If the data begin with NaN frames, fill that leading run with the first valid frame.
% The index list is sorted and distinct, so the leading run 1..L ends at the first
% position i where lessThanThreeLedMarkersIdx(i) ~= i; frame L+1 is then valid.
if ~isempty(lessThanThreeLedMarkersIdx) && lessThanThreeLedMarkersIdx(1) == 1
    leadingRunLength = find(lessThanThreeLedMarkersIdx ~= 1:length(lessThanThreeLedMarkersIdx), 1) - 1;
    if isempty(leadingRunLength)
        leadingRunLength = length(lessThanThreeLedMarkersIdx);
    end
    continuousNanIdx = 1:leadingRunLength;
    recoveredData(:,:,continuousNanIdx) = repmat(recoveredData(:,:,leadingRunLength+1), [1 1 leadingRunLength]);
end

% Apply the custom spline interpolation with 2-Hz low-pass.
interpolatedData(:,:,1:theLastValidPoint) = customInterpolation(recoveredData(:,:,1:theLastValidPoint), 2);

% Apply rigid-body template to the interpolated data.
rigidBodyOnSpline             = zeros(size(rigidBodyTemplate,1), size(rigidBodyTemplate,2), length(lessThanThreeLedMarkersIdx));
rigidBodySplineRotationMatrix = zeros(3,3,length(lessThanThreeLedMarkersIdx));
for frameIdx = 1:length(lessThanThreeLedMarkersIdx)
    currentData = interpolatedData(:,:,lessThanThreeLedMarkersIdx(frameIdx));
    [rigidBodyOnSpline(:,:,frameIdx), ~, rigidBodySplineRotationMatrix(:,:,frameIdx)] = rigidBodyTransformation(currentData, rigidBodyTemplate);
end

% Store results.
interpolatedData(:,:,lessThanThreeLedMarkersIdx)                 = rigidBodyOnSpline;
rotationMatrixForRigidBodySpline                                 = rotationMatrixForRecovered;
rotationMatrixForRigidBodySpline(:,:,lessThanThreeLedMarkersIdx) = rigidBodySplineRotationMatrix;
correctionLevelIndicator(lessThanThreeLedMarkersIdx,1)           = 3; % Level 3 correction: Coordinates of some of LED markers are made up by spline interpolation.
disp(sprintf('Step 4/5 done!\n\n\n'))


        
disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
disp('%%% 5. Spline-interpolate >2Hz spikes. %%%')
disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')

% Replace >2Hz spikes with NaN (arguments: dataWithSpikes, hpfCutoffHz, thresholdSd, pointSpreadInFrame).
spikeReplacedWithNaN = interpolatedData;
spikeReplacedWithNaN(:,:,1:theLastValidPoint) = customSpikeDetection(interpolatedData(:,:,1:theLastValidPoint), 2, 5, ceil(srate*0.1));

% Obtain NaN index.
spikeCorrectionIdx = find(isnan(squeeze(sum(sum(spikeReplacedWithNaN(:,:,1:theLastValidPoint),1),2))));

% Interpolate NaN with spline interpolation (arguments: dataWithNaN, lpfCutoffHz).
spikeCorrectedData = spikeReplacedWithNaN;
spikeCorrectedData(:,:,1:theLastValidPoint) = customInterpolation(spikeReplacedWithNaN(:,:,1:theLastValidPoint), 2);

% Apply rigid-body template to the interpolated data.
rigidBodyOnSpikeCorrection      = zeros(size(rigidBodyTemplate,1), size(rigidBodyTemplate,2), length(spikeCorrectionIdx));
spikeCorrectionRotationMatrix   = zeros(3,3,length(spikeCorrectionIdx));
for frameIdx = 1:length(spikeCorrectionIdx)
    currentData = spikeCorrectedData(:,:,spikeCorrectionIdx(frameIdx));
    [rigidBodyOnSpikeCorrection(:,:,frameIdx), ~, spikeCorrectionRotationMatrix(:,:,frameIdx)] = rigidBodyTransformation(currentData, rigidBodyTemplate);
end
rotationMatrixForSpikeCorrected = rotationMatrixForRigidBodySpline;
spikeCorrectedData(:,:,spikeCorrectionIdx)              = rigidBodyOnSpikeCorrection;
rotationMatrixForSpikeCorrected(:,:,spikeCorrectionIdx) = spikeCorrectionRotationMatrix;
correctionLevelIndicator(spikeCorrectionIdx,1)          = 4; % Level 4 correction: Coordinates of all of LED markers are made up by spline interpolation.
disp(sprintf('Step 5/5 done!\n\n\n'))



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Compute the rotation time series with roll, pitch, yaw after correction. %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Only when the 5th output is requested (the function takes a single input, so
% this must test nargout, not nargin).
if nargout >= 5
    T = zeros(4,4);
    T(4,4) = 1;
    rollPitchYaw = zeros(theLastValidPoint,3);
    for n = 1:theLastValidPoint
        T(1:3,1:3) = rotationMatrixForSpikeCorrected(:,:,n);
        x = t2x(T,'rpy');
        rollPitchYaw(n,:) = x(5:7)';
    end
end



% Prepare the output variables; excluded channels are returned as NaN.
if deadChannelFlag
    correctedCoordinates = nan(size(originalInputCoordinates));
    correctedCoordinates(:,goodChannelIdx,:) = spikeCorrectedData;
else
    correctedCoordinates = spikeCorrectedData;
end
correctionLevelIndicator(theLastValidPoint+1:end) = 0;