% rigidBodyCorrection() - Compute a robust rigid-body template from the good
%                         part of the data and apply it to the rest.
%                         The procedure follows these steps.
%                         1. Compute a robust template from the frames
%                            with the maximum number of available LED markers.
%                         2. Fit the template wherever at least 3 LED markers
%                            are present.
%                         3. For frames in which fewer than 3 LED markers
%                            are available, linear-interpolate the data,
%                            low-pass filter it, spline-interpolate it, and
%                            then fit the rigid-body template.
%
% Use:
%   >> [correctedCoordinates, lessThanThreeMarkerIdx, rollPitchYaw] = rigidBodyCorrection(inputCoordinates)
%
% Inputs:
%   'inputCoordinates'       - 3xnxk for xyz coordinates, n channels, and k time points.
%   'lowPassFilterCutoffHz'  - 1x1 cutoff frequency [Hz] of the low-pass filter
%                              applied after linear interpolation and before
%                              spline interpolation.
%
% Outputs:
%   'correctedCoordinates'   - 3xnxk for xyz coordinates, n channels, and k time points.
%   'lessThanThreeMarkerIdx' - Row vector of frame indices (out of k) in which fewer
%                              than 3 LED markers are available, i.e., a rigid-body
%                              fit was not possible on the original data; these
%                              frames are spline-interpolated first and then fitted
%                              with the rigid body.
%   'rollPitchYaw'           - [OPTIONAL] kx3 roll, pitch, and yaw of the
%                              template rigid body after correction.

% Author:
%    Makoto Miyakoshi. SCCN, INC, UCSD. mmiyakoshi@ucsd.edu
%
% History:
%    02/04/2020 Makoto. Further fix made on the spline part. Re-calculated.
%    01/28/2020 Makoto. Calculation error fixed. Now max error measured by p0400_rigidBodySimulationTest is 6.3804e-07.
%    01/15/2018 Makoto. Optimized the entire process. Removed some of the methods implemented previously.
%    12/20/2017 Makoto. Use convhull() to compute area when only 3 points are available.
%    12/19/2017 Makoto. Replace distance matrix calculation with volume calculation using convhull().
%    12/18/2017 Makoto. When building a template, frames are rotated (how did it drop off...)
%    09/07/2017 Makoto. >50% NaN channel is excluded. 
%    08/10/2017 Makoto. Created.

% This program is free software; you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation; either version 2 of the License, or
% (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

function [correctedCoordinates, lessThanThreeMarkerIdx, rollPitchYaw] = rigidBodyCorrection(inputCoordinates, lowPassFilterCutoffHz)

% Implementation notes:
%   - A marker (channel) counts as present in a frame only when all of its
%     coordinates are non-NaN; exact zeros in the input are treated as
%     missing. The same rule is used in every step, so the frames handled by
%     Step 2 (>=3 markers) and Step 3 (<3 markers) partition the recoverable span.
%   - The rotation matrix applied to the template at each fitted frame is
%     stored so that the optional third output can be derived without refitting.
%   - rollPitchYaw (only computed when requested) is k x 3 in radians:
%     [roll pitch yaw] of the rotation R that maps the template onto each
%     frame, decomposed as R = Rz(yaw)*Ry(pitch)*Rx(roll). Rows of frames that
%     were not fitted (including untrimmable leading/trailing NaN frames) are NaN.
%     NOTE(review): an earlier draft used an external t2x(T,'rpy'); confirm
%     that this convention matches what downstream code expects.

% Fail early: Step 3 needs the cutoff, and Steps 1-2 can take a long time.
narginchk(2, 2);

% Replace zeros with NaN.
nanInputCoordinates = double(inputCoordinates);
nanInputCoordinates(nanInputCoordinates==0) = NaN;

% Detect initial and ending NaN chunks, which cannot be interpolated.
[initialNanFrameIdx, endingNanFrameIdx] = findInitialAndEndingNanFrames(nanInputCoordinates, 3);

% Keep only the span that starts and ends with recoverable data.
numFrames    = size(nanInputCoordinates,3);
goodFrameIdx = setdiff(1:numFrames, [initialNanFrameIdx endingNanFrameIdx]);
if isempty(goodFrameIdx)
    error('rigidBodyCorrection:noRecoverableFrames', ...
          'No recoverable frames found; cannot build a rigid-body template.')
end
nanInputCoordinatesShort = nanInputCoordinates(:,:,goodFrameIdx);
numDims        = size(nanInputCoordinatesShort,1);
numShortFrames = size(nanInputCoordinatesShort,3);

% Number of present markers per frame (1 x numShortFrames). reshape() rather
% than squeeze() keeps this a row vector even with a single channel.
numNonNanChannels = reshape(sum(all(~isnan(nanInputCoordinatesShort),1),2), 1, []);

% Rotation matrix applied to the template at each frame (NaN = not fitted).
rotationMatrices = nan(numDims, numDims, numShortFrames);



disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
disp('%%% 1. Create a rigid-body template. %%%')
disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
% Create a seed template from the first frame with the most markers.
% Even if the members of maxChannelsIdx differ (e.g., ABDE and ABCE), the
% nanmedian below recovers ABCDE.
maxChannelsIdx = find(numNonNanChannels==max(numNonNanChannels));
numSeedFrames  = length(maxChannelsIdx);
templateToFit  = nanInputCoordinatesShort(:,:,maxChannelsIdx(1));
templateToFit  = bsxfun(@minus, templateToFit, nanmean(templateToFit,2));
isTemplateChannelPresent = all(~isnan(templateToFit),1);
seedTemplates  = nan(size(templateToFit,1), size(templateToFit,2), numSeedFrames);
disp('Calculating seed template...')
tStart = tic;
for frameIdx = 1:numSeedFrames

    % Report the estimated remaining time every 1000 frames.
    if mod(frameIdx,1000)==0
        tElapsed = toc(tStart);
        fprintf('%d/%d(%.0f s left)...\n', frameIdx, numSeedFrames, round(tElapsed*(numSeedFrames/1000 - frameIdx/1000)))
        tStart = tic;
    end

    % Markers shared by the seed template and the current frame.
    currentRigidBody = nanInputCoordinatesShort(:,:,maxChannelsIdx(frameIdx));
    commonChannelIdx = find(isTemplateChannelPresent & all(~isnan(currentRigidBody),1));

    % A rigid-body fit needs at least 3 shared markers.
    if length(commonChannelIdx)<3
        continue
    end

    % Fit the current frame to the seed template.
    [~, ~, rotationMatrix, centroidOffset] = rigidBodyTransformation(templateToFit(:,commonChannelIdx), currentRigidBody(:,commonChannelIdx));
    templateFitData = rotationMatrix*bsxfun(@minus, currentRigidBody, mean(currentRigidBody(:,commonChannelIdx),2));
    seedTemplates(:,:,frameIdx) = bsxfun(@plus, templateFitData, centroidOffset);
end

% The element-wise median over the aligned seed frames is robust enough.
rigidBodyTemplate = nanmedian(seedTemplates,3);
rigidBodyTemplate = bsxfun(@minus, rigidBodyTemplate, nanmean(rigidBodyTemplate,2));
rigidBodyTemplateAvailableChannels = find(all(~isnan(rigidBodyTemplate),1));
if length(rigidBodyTemplateAvailableChannels) < 3
    warning('rigidBodyCorrection:degenerateTemplate', ...
            'The rigid-body template has fewer than 3 markers; no frame can be fitted.')
end
fprintf('Step 1/3 done!\n\n\n\n')



disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
disp('%%% 2. Fit the rigid body template when >=3 markers are present. %%%')
disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
moreThanThreeMarkerIdx  = find(numNonNanChannels >= 3);
moreThanThreeMarkerData = nanInputCoordinatesShort(:,:,moreThanThreeMarkerIdx);
numFitFrames = length(moreThanThreeMarkerIdx);
tStart = tic;
for frameIdx = 1:numFitFrames

    % Report the estimated remaining time every 1000 frames.
    if mod(frameIdx,1000)==0
        tElapsed = toc(tStart);
        fprintf('%d/%d(%.0f s left)...\n', frameIdx, numFitFrames, round(tElapsed*(numFitFrames/1000 - frameIdx/1000)))
        tStart = tic;
    end

    % Markers present both in this frame and in the template.
    currentData      = moreThanThreeMarkerData(:,:,frameIdx);
    commonChannelIdx = intersect(find(all(~isnan(currentData),1)), rigidBodyTemplateAvailableChannels);

    % Skip the frame if <3 channels are available (frame keeps its raw data).
    if length(commonChannelIdx)<3
        continue
    end

    % Rotation that maps the template onto the current frame.
    % (Calculation error fixed by Makoto, 01/28/2020.)
    [~, ~, rotationMatrix] = rigidBodyTransformation(currentData(:,commonChannelIdx), rigidBodyTemplate(:,commonChannelIdx));

    % Rotate the whole template, then translate it so that the centroid of
    % the shared markers coincides with the observed one.
    identicallyRotatedData = rotationMatrix*rigidBodyTemplate;
    translation = mean(currentData(:,commonChannelIdx),2) - mean(identicallyRotatedData(:,commonChannelIdx),2);
    moreThanThreeMarkerData(:,:,frameIdx) = bsxfun(@plus, identicallyRotatedData, translation);
    rotationMatrices(:,:,moreThanThreeMarkerIdx(frameIdx)) = rotationMatrix;
end
correctedCoordinatesShort = nanInputCoordinatesShort;
correctedCoordinatesShort(:,:,moreThanThreeMarkerIdx) = moreThanThreeMarkerData;
fprintf('Step 2/3 done!\n\n\n\n')



disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
disp('%%% 3. Fit the rigid body template when <3 markers are present. %%%')
disp('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
lessThanThreeMarkerIdx = find(numNonNanChannels < 3);

% Apply the custom spline interpolation with the specified low-pass filter.
interpolatedData = customInterpolation(correctedCoordinatesShort, lowPassFilterCutoffHz);

% Fit the rigid-body template to the interpolated frames.
rigidBodyOnSpline = nan(size(rigidBodyTemplate,1), size(rigidBodyTemplate,2), length(lessThanThreeMarkerIdx));
for frameIdx = 1:length(lessThanThreeMarkerIdx)
    shortFrameIdx    = lessThanThreeMarkerIdx(frameIdx);
    currentData      = interpolatedData(:,:,shortFrameIdx);
    commonChannelIdx = intersect(find(all(~isnan(currentData),1)), rigidBodyTemplateAvailableChannels);

    % Skip the frame if <3 channels are available (frame stays NaN).
    if length(commonChannelIdx)<3
        continue
    end

    % Rotation that maps the template onto the interpolated frame.
    % (Fixed by Makoto, 02/04/2020.)
    [~, ~, rotationMatrix] = rigidBodyTransformation(currentData(:,commonChannelIdx), rigidBodyTemplate(:,commonChannelIdx));

    % Rotate the template and match the shared-marker centroid.
    identicallyRotatedData = rotationMatrix*rigidBodyTemplate;
    translation = mean(currentData(:,commonChannelIdx),2) - mean(identicallyRotatedData(:,commonChannelIdx),2);
    rigidBodyOnSpline(:,:,frameIdx) = bsxfun(@plus, identicallyRotatedData, translation);
    rotationMatrices(:,:,shortFrameIdx) = rotationMatrix;
end
correctedCoordinatesShort(:,:,lessThanThreeMarkerIdx) = rigidBodyOnSpline;
fprintf('Step 3/3 done!\n\n\n\n')



%%%%%%%%%%%%%%%%%%%%%%%
%%% Prepare output. %%%
%%%%%%%%%%%%%%%%%%%%%%%
correctedCoordinates = nan(size(inputCoordinates));
correctedCoordinates(:,:,goodFrameIdx) = correctedCoordinatesShort;

% Map indices from the trimmed span back to original frame numbers, and add
% the untrimmable leading AND trailing NaN frames (neither was fitted).
lessThanThreeMarkerIdx = sort([initialNanFrameIdx(:); ...
                               reshape(goodFrameIdx(lessThanThreeMarkerIdx), [], 1); ...
                               endingNanFrameIdx(:)]).';

% OPTION: roll, pitch, yaw of the fitted template (see implementation notes).
if nargout >= 3
    rollPitchYaw = nan(numFrames, 3);
    if numDims == 3
        r11 = reshape(rotationMatrices(1,1,:), [], 1);
        r21 = reshape(rotationMatrices(2,1,:), [], 1);
        r31 = reshape(rotationMatrices(3,1,:), [], 1);
        r32 = reshape(rotationMatrices(3,2,:), [], 1);
        r33 = reshape(rotationMatrices(3,3,:), [], 1);
        rollPitchYaw(goodFrameIdx,:) = [atan2(r32, r33), atan2(-r31, hypot(r32, r33)), atan2(r21, r11)];
    end
end