% interpolateRigidBodyAcrossBlocks() - performs rigid body interpolation across blocks
%                                      wherever a subset of channels in a rigid body
%                                      are missing.
%
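% Usage : [ALLEEG, numMissingMarkers] = interpolateRigidBodyAcrossBlocks(ALLEEG, rigidBodyChannelIdx)
%
% Input: ALLEEG             -- within-subject EEG structures, one per block.
%                              Requires EEG.etc.audiomaze.phaseSpaceCorrected, a
%                              (presumably [x,y,z,reliability]) x ch x time array;
%                              only rows 1:3 (x,y,z) are read.
%        rigidBodyChannelIdx -- channel indices of the rigid body.
%
% Output: ALLEEG            -- within-subject EEG structures, one per block. Rigid-body
%                              markers (channels) that are NaN in a frame are interpolated
%                              in EEG.etc.audiomaze.phaseSpaceCorrected; modified blocks
%                              have that field converted to single.
%         numMissingMarkers -- number of missing markers counted in each block's first
%                              usable frame, summed over blocks (0 = nothing was changed).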
%
% Author: Makoto Miyakoshi, SCCN, INC, UCSD.

% This program is free software; you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation; either version 2 of the License, or
% (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

% History: 
% 01/29/2020 Makoto. Report how many channels are missing during the blocks.
% 01/24/2018 Makoto. Changed size(currentRigidBody,2) to nonNanChannel
% 01/19/2018 Makoto. Skipping the initial NaN frames to choose a template. 
% 12/28/2017 Makoto. Adding only missing channels.
% 12/22/2017 Makoto. Created.

function [ALLEEG, numMissingMarkers] = interpolateRigidBodyAcrossBlocks(ALLEEG, rigidBodyChannelIdx)

numChannels = length(rigidBodyChannelIdx);
numBlocks   = length(ALLEEG);

% Overview: (1) take one reference frame per block, (2) align every block's
% reference frame to the block with the most markers, (3) average them into a
% block-mean rigid-body template, (4) in blocks whose reference frame lacks
% markers, fit that template to the visible markers of every frame and copy
% the fitted positions into the missing markers.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Obtain rigid bodies from all the blocks. %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% One reference frame per block: the first frame outside the initial/ending
% NaN runs. NaN-initialized so that a block without any usable frame is
% treated as "all markers missing" instead of crashing on an empty index.
allRigidBodies = nan(3, numChannels, numBlocks);
for eegIdx = 1:numBlocks
    currentRigidBody = double(ALLEEG(eegIdx).etc.audiomaze.phaseSpaceCorrected(1:3, rigidBodyChannelIdx, :));
    
    % Max number of channels available at a time. reshape (not squeeze) keeps
    % the channel x frame orientation even for one channel or one frame.
    xAvailability    = ~isnan(reshape(currentRigidBody(1,:,:), numChannels, []));
    numNonNanChannel = max(sum(xAvailability, 1));
    
    [initialNanFrameIdx, endingNanFrameIdx] = findInitialAndEndingNanFrames(currentRigidBody, numNonNanChannel);
    goodFrameIdx = setdiff(1:size(currentRigidBody,3), [initialNanFrameIdx endingNanFrameIdx]);
    if isempty(goodFrameIdx)
        warning('Block %d has no usable rigid-body frame; all its markers are treated as missing.', eegIdx)
        continue
    end
    allRigidBodies(:,:,eegIdx) = currentRigidBody(:,:,goodFrameIdx(1));
end

% availabilityMask is channel x block: true where the marker exists in that
% block's reference frame. reshape keeps this orientation even with a single
% block (squeeze would return a 1 x channel row there).
availabilityMask  = ~isnan(reshape(allRigidBodies(1,:,:), numChannels, numBlocks));
numMissingMarkers = nnz(~availabilityMask);


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Exit if no missing marker detected. %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if numMissingMarkers == 0
    return
end


% If missing marker is detected, proceed to the recovery stage.
% The block with the most available markers becomes the rotation template.
[maxAvailable, rotationTemplateIdx] = max(sum(availabilityMask, 1));
if maxAvailable <=2
   error('Less than 3 available.') 
end
rotationTemplate = allRigidBodies(:,:,rotationTemplateIdx);
rotationTemplateAvailabilityIdx = find(availabilityMask(:,rotationTemplateIdx));

% Rotate all the rigid bodies to the rotationTemplate.
rotatedRigidBodies = nan(3, numChannels, numBlocks);
for eegIdx = 1:numBlocks
    commonChannelIdx = intersect(rotationTemplateAvailabilityIdx, find(availabilityMask(:,eegIdx)));
    
    % A 3-D rigid transform is not determined by fewer than 3 points; such
    % blocks stay NaN and are therefore ignored by nanmean below.
    if numel(commonChannelIdx) < 3
        continue
    end
    rotatedRigidBodies(:,:,eegIdx) = alignRigidBody(rotationTemplate, allRigidBodies(:,:,eegIdx), commonChannelIdx);
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Calculate the block-mean rigid body. %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% A marker is available in the template if any aligned block has it.
blockMeanRigidBody = nanmean(rotatedRigidBodies,3);
blockMeanRigidBodyAvailabilityIdx = find(~isnan(blockMeanRigidBody(1,:)));


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Fit the block-mean rigid-body template to those with missing channels (only for missing ones). %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
for eegIdx = 1:numBlocks
    
    % Skip the block if its reference frame already has every template marker.
    if all(ismember(blockMeanRigidBodyAvailabilityIdx, find(availabilityMask(:,eegIdx))))
        continue
    end
    
    % Obtain the current data.
    currentDataForCorrection = double(ALLEEG(eegIdx).etc.audiomaze.phaseSpaceCorrected(1:3, rigidBodyChannelIdx, :));
    numFrames = size(currentDataForCorrection,3);
    
    % Perform rigid-body fitting frame by frame.
    tStart = tic;
    for frameIdx = 1:numFrames
        
        % Estimate process time (remaining thousands of frames x time per thousand).
        if mod(frameIdx,1000)==0
            tElapsed = toc(tStart);
            fprintf('%d/%d(%.0f s left)...\n', frameIdx, numFrames, round(tElapsed*(numFrames - frameIdx)/1000));
            tStart = tic;
        end
        
        % Obtain the current rigid body and its availability.
        currentRigidBody = currentDataForCorrection(:,:,frameIdx);
        currentRigidBodyAvailabilityIdx = find(~isnan(currentRigidBody(1,:)));
        commonChannelIdx    = intersect(blockMeanRigidBodyAvailabilityIdx, currentRigidBodyAvailabilityIdx);
        channelToBeFixedIdx = setdiff(blockMeanRigidBodyAvailabilityIdx, commonChannelIdx);
        
        % Nothing to fill (every template marker present), or too few markers
        % to determine a rigid transform (this also covers all-NaN frames).
        if isempty(channelToBeFixedIdx) || numel(commonChannelIdx) < 3
            continue
        end
        
        % Fit the template to the current frame and copy only the missing markers.
        fittedTemplate = alignRigidBody(currentRigidBody, blockMeanRigidBody, commonChannelIdx);
        currentDataForCorrection(:,channelToBeFixedIdx,frameIdx) = fittedTemplate(:,channelToBeFixedIdx);
    end
    
    % Put it back to ALLEEG
    ALLEEG(eegIdx).etc.audiomaze.phaseSpaceCorrected(1:3, rigidBodyChannelIdx, :) = currentDataForCorrection;
    ALLEEG(eegIdx).etc.audiomaze.phaseSpaceCorrected = single(ALLEEG(eegIdx).etc.audiomaze.phaseSpaceCorrected);
end



function aligned = alignRigidBody(target, source, commonChannelIdx)
% Rigidly move all markers of source (3 x ch) so that its commonChannelIdx
% markers best match the same markers of target (3 x ch). The rotation comes
% from rigidBodyTransformation(target, source). The translation is the centroid
% difference of the common markers after rotation (the formulation adopted for
% the per-frame fit on 01/29/2020), so it does not rely on the centroidOffset
% output. NaN columns of source stay NaN (matrix products act column-wise).
[~, ~, rotationMatrix, ~] = rigidBodyTransformation(target(:,commonChannelIdx), source(:,commonChannelIdx));
rotated     = rotationMatrix*source;
translation = mean(target(:,commonChannelIdx),2) - mean(rotated(:,commonChannelIdx),2);
aligned     = bsxfun(@plus, rotated, translation);