% samePostureInterpolation() ----  performs data interpolation on bad
%                                  datapoints in phaseSpace data using
%                                  'same-posture' frames i.e., same kinematic
%                                  chains across time.
%                                  1) Identify the xyz coordinates of the
%                                     reference channel of the frames in
%                                     question.
%                                  2) Find other frames, which are clean,
%                                     in which the reference
%                                     channel shows similar xyz coordinates.
%                                  3) Compute weighted mean (by distance to
%                                     the reference channel position) of
%                                     the corresponding other channel
%                                     positions.
%                                  4) Correct the noise channel and frame
%                                     by using the weighted mean computed above. 
%
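% Use  : interpolated = samePostureInterpolation(dataToBeInterpolated, dataToBeReferenced, frameMaskToBeInterpolated, goodFrameIdx, lpfCutoffHz)
%
% Input: dataToBeInterpolated      -- 3xnxm tensor corresponding to
%                                     xzy coordinates of the phase
%                                     space data, n channels, m frames.
%        dataToBeReferenced        -- 3x1xm tensor corresponding to 
%                                     xyz coordinates of the reference channel
%                                     distance to which serves as cost
%                                     function to be minimized, a single
%                                     channel index that is usually assumed to
%                                     be a part of a rigid body, m frames.
%        frameMaskToBeInterpolated -- nxm logical mask which corresponds to
%                                     n channels, m frames. True entries mark
%                                     the channel/frame samples to be replaced.
%        goodFrameIdx              -- vector of frame indices (into the m
%                                     frames) that are used as candidate
%                                     'same-posture' frames.
%        lpfCutoffHz               -- cutoff frequency (Hz) passed to
%                                     pop_eegfiltnew() to low-pass filter the
%                                     data before interpolation. The sampling
%                                     rate is read from EEG.srate in the base
%                                     workspace.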
%
% Output: interpolated             -- 3xnxm tensor corresponding to
%                                    xzy coordinates of the phase
%                                    space data, n channels, m frames.
%
% Author: Makoto Miyakoshi, SCCN, INC, UCSD.

% This program is free software; you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation; either version 2 of the License, or
% (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

% History: 
%   01/15/2018 Makoto. Created.

function interpolated = samePostureInterpolation(dataToBeInterpolated, dataToBeReferenced, frameMaskToBeInterpolated, goodFrameIdx, lpfCutoffHz, srate)
% Replace masked samples with a distance-weighted mean of 'same-posture' frames.
%
% Inputs (see also the file header):
%   dataToBeInterpolated      -- 3 x n x m coordinates, n channels, m frames.
%   dataToBeReferenced        -- 3 x 1 x m coordinates of the reference channel.
%   frameMaskToBeInterpolated -- n x m mask; non-zero entries are replaced.
%   goodFrameIdx              -- frame indices used as candidate donor frames.
%   lpfCutoffHz               -- low-pass cutoff (Hz) handed to pop_eegfiltnew().
%   srate                     -- (optional) sampling rate in Hz. When omitted or
%                                empty it is read from EEG.srate in the base
%                                workspace, as in earlier versions.
%
% Output:
%   interpolated              -- copy of the (unfiltered) input in which every
%                                masked sample is replaced by the weighted mean
%                                of the low-pass filtered donor frames.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Replace outliers with 'same-posture (i.e., kinematic chain) interpolation. %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
interpolated = dataToBeInterpolated;
badFrameIdx  = find(any(frameMaskToBeInterpolated, 1));

% Nothing is masked: the output equals the input, so skip the filtering.
if isempty(badFrameIdx)
    return
end

% Without donor frames the weighted sum below would run over an empty set and
% silently write all-zero coordinates. Leave the data untouched instead.
if isempty(goodFrameIdx)
    warning('samePostureInterpolation:noGoodFrames', ...
        'No good frames were supplied; data are returned without interpolation.');
    return
end

% Backward-compatible default: take the sampling rate from the base workspace.
if nargin < 6 || isempty(srate)
    srate = evalin('base', 'EEG.srate');
end

% NOTE(review): the good-frame reference positions are taken BEFORE the
% low-pass filter is applied, so they are compared against the filtered
% position of the current frame below. This ordering is kept from the original
% implementation -- confirm whether it is intended.
goodFrameData = dataToBeReferenced(:,:,goodFrameIdx);

% Low-pass filter the linear-interpolated data and the reference channel.
dataToBeInterpolated = lowPassFilterFrames(dataToBeInterpolated, srate, lpfCutoffHz);
dataToBeReferenced   = lowPassFilterFrames(dataToBeReferenced,   srate, lpfCutoffHz);

numBadFrames  = length(badFrameIdx);
progressTimer = tic; % Private timer: started before first use and does not reset the caller's tic.
for badFrameIdxIdx = 1:numBadFrames

    % Report progress and a remaining-time estimate every 1000 frames.
    if mod(badFrameIdxIdx, 1000) == 0
        timeElapsed = toc(progressTimer);
        fprintf('Same-posture kinematic chain interpolation on %d/%d frames processed (%.0f sec left)...\n', badFrameIdxIdx, numBadFrames, timeElapsed/1000*(numBadFrames-badFrameIdxIdx));
        progressTimer = tic;
    end

    % Select the current frame to fix.
    currentFrameIdx      = badFrameIdx(badFrameIdxIdx);
    currentBadChannelIdx = find(frameMaskToBeInterpolated(:,currentFrameIdx));

    % Euclidean distance between the reference channel position of the current
    % frame and that of every good frame (column vector, one entry per good frame).
    positionDifference = bsxfun(@minus, goodFrameData, dataToBeReferenced(:,:,currentFrameIdx));
    distanceToGoodFrames = reshape(sqrt(sum(positionDifference.^2, 1)), [], 1);

    % 'Same posture frames' are the good frames whose distance falls within
    % the 1st percentile of all distances.
    criticalValue       = prctile(distanceToGoodFrames, 1);
    samePostureFrameIdx = find(distanceToGoodFrames <= criticalValue);

    % Inverse-distance weights normalized to sum to one. Exact zeros are
    % replaced with eps to avoid division by zero.
    distanceListInMeter = distanceToGoodFrames(samePostureFrameIdx);
    distanceListInMeter(distanceListInMeter==0) = eps;
    inverseDistance = 1./distanceListInMeter;
    weightList      = reshape(inverseDistance/sum(inverseDistance), [1 1 length(inverseDistance)]);

    % Weighted mean of the (low-pass filtered) same-posture frames.
    samePostureFrames = dataToBeInterpolated(:,:,goodFrameIdx(samePostureFrameIdx));
    interpolatedFrame = sum(bsxfun(@times, samePostureFrames, weightList), 3);

    % Replace bad points with synthesized ones.
    interpolated(:, currentBadChannelIdx, currentFrameIdx) = interpolatedFrame(:, currentBadChannelIdx);
end



function filtered = lowPassFilterFrames(data, srate, cutoffHz)
% Low-pass filter a 3 x channels x frames array along the frame dimension.
% A dummy EEG structure is used so that pop_eegfiltnew() can be applied.

originalSize = size(data);
numFrames    = size(data, 3);

lpfData        = eeg_emptyset();
lpfData.data   = reshape(data, [], numFrames); % rows: coordinate-channel pairs, columns: frames
lpfData.srate  = srate;
lpfData.nbchan = size(lpfData.data, 1);
lpfData.pnts   = numFrames; % length() would return the larger dimension, not necessarily the frame count
lpfData        = pop_eegfiltnew(lpfData, 0, cutoffHz);

filtered = reshape(lpfData.data, originalSize);