% customInterpolation() - Linearly interpolate NaNs, low-pass filter the data,
%                         return to the original NaN positions, then fill them
%                         again by spline interpolation. This avoids the
%                         high-frequency spikes that direct spline
%                         interpolation introduces on glitchy data. A variable
%                         'EEG' must exist in the base workspace; its
%                         EEG.srate is used as the sampling rate.
%
% Use:
%   >> nanInterpolated = customInterpolation(dataWithNaN, lpfCutoffHz)
%
% Inputs:
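%   'dataWithNaN'     - 3xjxk double. 3 is PhaseSpace xyz, j is channels,
%                       and k is data points. Data with NaN. A 2-D
%                       (channels x data points) matrix is also accepted.
%   'lpfCutoffHz'     - 1x1 double [Hz]. Low-pass cutoff frequency. The
%                       transition bandwidth (TBW) used to compute the
%                       filter order equals the cutoff frequency.
%
% Outputs:
%   'nanInterpolated' - Same size as 'dataWithNaN'. Interior NaNs are
%                       interpolated; leading and trailing NaNs (and
%                       channels that are entirely NaN) remain NaN.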

% Author:
%    Makoto Miyakoshi. SCCN, INC, UCSD. mmiyakoshi@ucsd.edu
%
% History: 
%    09/09/2017 Makoto. First data point NaN addressed.
%    08/25/2017 Makoto. Created.

% This program is free software; you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation; either version 2 of the License, or
% (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

function nanInterpolated = customInterpolation(dataWithNaN, lpfCutoffHz)

% Fill interior NaN gaps without introducing high-frequency spikes.
%
% For each time series (row after flattening the leading dimensions):
%   1. Leading and trailing NaN runs are excluded and left as NaN in the
%      output (nothing is extrapolated).
%   2. Interior NaNs are linearly interpolated.
%   3. The linear-interpolated series is low-pass filtered with
%      pop_eegfiltnew(); the filter order comes from pop_firwsord('hamming')
%      with the transition bandwidth set equal to lpfCutoffHz.
%   4. The original NaN positions are re-filled by spline interpolation
%      through the low-pass-filtered values at the valid samples. Valid
%      samples keep their original, unfiltered values.
%
% Inputs:
%   dataWithNaN - numeric array whose LAST dimension is time (data points),
%                 e.g. 3 x j x k (xyz x channels x frames) or
%                 channels x frames. All leading dimensions are flattened
%                 into rows, and each row is processed independently.
%   lpfCutoffHz - scalar low-pass cutoff [Hz], also used as the transition
%                 bandwidth for the filter-order estimate.
%
% Output:
%   nanInterpolated - double array, same size as dataWithNaN.
%
% Requires EEGLAB (pop_firwsord, eeg_emptyset, pop_eegfiltnew) and a
% variable 'EEG' in the base workspace, whose EEG.srate is the sampling rate.

% Nothing to interpolate in an empty input.
if isempty(dataWithNaN)
    nanInterpolated = dataWithNaN;
    return
end

% Flatten all leading dimensions into rows so that each row is one time
% series. For 2-D input this is a no-op; 3 x j x k input becomes (3*j) x k.
inputSize = size(dataWithNaN);
numFrames = inputSize(end);
inputData = reshape(dataWithNaN, [], numFrames);

% Visualize the input data.
% figure; plot(inputData')

% Read the sampling rate once from the base-workspace EEG structure.
srate = evalin('base', 'EEG.srate');

% Hamming-windowed FIR order; the transition bandwidth equals the cutoff.
filterOrder = pop_firwsord('hamming', srate, lpfCutoffHz);

% Rows/frames that are never filled below (all-NaN rows, leading and
% trailing NaN runs) stay NaN.
nanInterpolated = nan(size(inputData));
for chIdx = 1:size(inputData, 1)

    % Obtain the current single-channel time series.
    currentData = inputData(chIdx, :);
    validMask   = ~isnan(currentData);

    % If the current channel is all NaN, there is nothing to interpolate from.
    if ~any(validMask)
        continue
    end

    % Keep only the span from the first to the last valid sample, so the
    % segment starts and ends with non-NaN data.
    goodFrameIdx     = find(validMask, 1, 'first'):find(validMask, 1, 'last');
    currentDataShort = currentData(goodFrameIdx);
    nanMask          = isnan(currentDataShort);

    % No interior gaps: the data are already complete. This also avoids
    % interp1 failing when only a single valid sample exists.
    if ~any(nanMask)
        nanInterpolated(chIdx, goodFrameIdx) = currentDataShort;
        continue
    end

    % Sample positions relative to the trimmed segment (same length as nanMask).
    segmentIdx = 1:numel(currentDataShort);

    % Linear-interpolate NaN using interp1. This is more robust against
    % high-frequency noise than spline interpolation.
    linearInterpolatedData          = currentDataShort;
    linearInterpolatedData(nanMask) = interp1(segmentIdx(~nanMask), currentDataShort(~nanMask), segmentIdx(nanMask), 'linear');

    % Low-pass filter the linear-interpolated data. A dummy EEG structure
    % describing one continuous channel lets us use pop_eegfiltnew().
    lpfEEG        = eeg_emptyset();
    lpfEEG.data   = linearInterpolatedData;
    lpfEEG.srate  = srate;
    lpfEEG.pnts   = numel(linearInterpolatedData);
    lpfEEG.nbchan = 1;
    lpfEEG.trials = 1;
    lpfEEG        = pop_eegfiltnew(lpfEEG, 0, lpfCutoffHz, filterOrder);
    lpfData       = lpfEEG.data;

    % Spline-interpolate the original NaN positions through the low-pass
    % filtered values at valid samples. After low-pass filtering this does
    % not create high-frequency noise.
    splineInterpolatedData          = currentDataShort;
    splineInterpolatedData(nanMask) = interp1(segmentIdx(~nanMask), lpfData(~nanMask), segmentIdx(nanMask), 'spline');
    % figure; plot(linearInterpolatedData, 'r'); hold on; plot(splineInterpolatedData)

    % Store the result.
    nanInterpolated(chIdx, goodFrameIdx) = splineInterpolatedData;
end

% Restore the input's original shape (a no-op for 2-D input).
nanInterpolated = reshape(nanInterpolated, inputSize);