% 02/14/2020 Makoto. Updated. Save the first valid torso frame as the reference torso.
% 02/10/2020 Makoto. Renewed to use neural network.
% 01/19/2018 Makoto. Used.
% 01/03/2018 Makoto. Modified.

% Arm and knee correction: train neural networks that recover missing arm and knee markers, and save them with each dataset.
% Make the phase-space correction helper functions available on the path.
addpath('/data/mobi/Audiomaze/makoto/code/phaseSpaceCorrectionTools');

% Processing parameters.
ALLEEG             = [];  % EEGLAB dataset store, emptied before loading.
sdThreshold        = 1.5; % Upper cutoff for marker distances, in SDs above the mean.
downsamplingFactor = 40;  % 480/40 = 12 Hz.

% LED marker indices (marker columns of the phase-space data).
headIdx         = 1:7;
torsoIdx        = [8:11, 21, 27];
handIdx         = 16:20;
footR_Idx       = 23:26;
footL_Idx       = 29:32;
armIdx          = 12:15;
kneeR_Idx       = [21, 22, 23];
kneeL_Idx       = [27, 28, 29];
nonRigidBodyIdx = [12:15, 22, 28];

% List the input datasets and group them by subject.
% The subject ID is taken as the first 7 characters of each file name.
allSets         = dir('/data/mobi/Audiomaze/makoto/p04_interpolateAcrossSubjects/*.set');
allSetNames     = {allSets.name}';
subjNames       = cellfun(@(setName) setName(1:7), allSetNames, 'uniformoutput', false);
uniqueSubjNames = unique(subjNames);

% For each selected subject:
%   1) load all blocks and concatenate the phase-space data along frames,
%   2) rotate/translate every frame so the torso matches a reference torso,
%   3) select clean arm and knee frames and train neural networks that
%      reconstruct missing arm/knee markers from the remaining ones,
%   4) store the networks in every block's EEG.etc and save the datasets.
%
% Subjects to process. Set to 1:length(uniqueSubjNames) to process everyone.
% Notes: 819_C3 (idx85) has no right knee. 829 has one NaN in xyz which is a
% real zero in coordinate.
subjIdxToProcess = 14;
for subjIdx = subjIdxToProcess

    currentSubjName     = uniqueSubjNames{subjIdx};
    currentSubjBlockIdx = find(strcmp(subjNames, currentSubjName));

    % Obtain phaseSpaceCorrected data from all the blocks of the same subject,
    % concatenated along the 3rd (frame) dimension.
    phaseSpaceData = [];
    for blockIdx = 1:length(currentSubjBlockIdx)
        EEG = pop_loadset('filename', allSets(currentSubjBlockIdx(blockIdx)).name, 'filepath', '/data/mobi/Audiomaze/makoto/p04_interpolateAcrossSubjects', 'loadmode', 'info');
        [ALLEEG, EEG, CURRENTSET] = eeg_store( ALLEEG, EEG, 0 );
        phaseSpaceData = cat(3, phaseSpaceData, EEG.etc.audiomaze.phaseSpaceCorrected);
    end

    % Convert zeros into NaN (a zero coordinate is treated as missing).
    phaseSpaceData(phaseSpaceData==0) = NaN;

    % Replace zeros with NaNs if all xyz are zeros. (02/10/2020 Makoto)
    % NOTE: after the global zero-to-NaN conversion above, torsoData holds no
    % zeros, so this step currently changes nothing; it is kept as a safeguard.
    % The mask is applied per (marker, frame) pair. Indexing with separate
    % marker and frame index vectors would blank their full Cartesian product.
    torsoData  = phaseSpaceData(1:3, torsoIdx, :);
    allZeroXyz = sum(torsoData==0, 1) == size(torsoData,1); % 1 x nTorso x nFrames
    torsoData(repmat(allZeroXyz, [size(torsoData,1) 1 1])) = NaN;

    % Detect frames in which every torso coordinate is NaN; they are skipped.
    torsoFrameMatrix = reshape(torsoData, [], size(torsoData,3)); % (3*nTorso) x nFrames
    allNanTorsoIdx   = find(all(isnan(torsoFrameMatrix), 1));
    nonNanFrameIdx   = setdiff(1:size(torsoData,3), allNanTorsoIdx);
    if isempty(nonNanFrameIdx)
        error('No frame with a valid torso found for %s.', currentSubjName);
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%% Rotate all the frames into the first frame using torso rigid body rotation. %%%
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    wholeBodyRotated   = nan(3, size(phaseSpaceData,2), size(phaseSpaceData,3));
    rotationMatrixList = nan(3,3,size(torsoData,3));
    referenceTorso     = []; % Set from the first valid frame (after zero restoration).
    for goodFrameIdxIdx = 1:length(nonNanFrameIdx)
        frameIdx = nonNanFrameIdx(goodFrameIdxIdx);

        % A torso marker with only some of x/y/z NaN had a genuine zero
        % coordinate (converted to NaN above); put the zero back. A marker
        % with all of x/y/z NaN is truly missing, which is not allowed here.
        currentTorsoData = torsoData(:,:,frameIdx);
        nanMask = isnan(currentTorsoData);
        if any(nanMask(:))
            if any(all(nanMask, 1))
                error('All-NaN frames found in Torso.')
            end
            currentTorsoData(nanMask) = 0;
        end

        % Use the restored first valid frame as the reference so that a
        % genuine zero in it does not propagate NaN into every rotation.
        if isempty(referenceTorso)
            referenceTorso = currentTorsoData;
        end

        [~, ~, rotationMatrix] = rigidBodyTransformation(referenceTorso, currentTorsoData);
        rotatedTorso      = rotationMatrix*currentTorsoData;
        translationVector = mean(referenceTorso,2) - mean(rotatedTorso,2);
        wholeBodyRotated(:,:,  frameIdx) = bsxfun(@plus, rotationMatrix*phaseSpaceData(1:3,:,frameIdx), translationVector);
        rotationMatrixList(:,:,frameIdx) = rotationMatrix;
    end

    % vis_mocapMovie(wholeBodyRotated, 500)


    %%%%%%%%%%%%%%%%%%%%%%%%%
    %%% Learning the arm. %%%
    %%%%%%%%%%%%%%%%%%%%%%%%%

    % Interpolate short interruptions (within 1 s). Mocap data sampling rate = 480 Hz.
    armShortInterruptionFixed = customInterpolationShortWindow(wholeBodyRotated(:,armIdx,:), 480, 5, 480);

    % Evaluate arm marker intervals: marker 8 -> arm 1 -> 2 -> 3 -> 4 -> hand centroid.
    armLengthMatrix = [squeeze(sqrt(sum((wholeBodyRotated(:,8,:) -armShortInterruptionFixed(:,1,:)).^2)))';
                       squeeze(sqrt(sum((armShortInterruptionFixed(:,1,:)-armShortInterruptionFixed(:,2,:)).^2)))';
                       squeeze(sqrt(sum((armShortInterruptionFixed(:,2,:)-armShortInterruptionFixed(:,3,:)).^2)))';
                       squeeze(sqrt(sum((armShortInterruptionFixed(:,3,:)-armShortInterruptionFixed(:,4,:)).^2)))';
                       squeeze(sqrt(sum((armShortInterruptionFixed(:,4,:)-mean(wholeBodyRotated(:,handIdx,:),2)).^2)))'];

    armSumVector = squeeze(sum(sum(armShortInterruptionFixed,1),2))';
    nonNanIdx    = find(~isnan(armSumVector));

    % Diagnostics: frames with all markers missing, and frames with 1-4 arm markers missing.
    wholeBodyNanMatrix  = isnan(squeeze(wholeBodyRotated(1,:,:)));
    allNanIdx           = find(sum(wholeBodyNanMatrix)==size(wholeBodyNanMatrix,1));
    armNanMatrix        = isnan(squeeze(armShortInterruptionFixed(1,:,:)));
    oneMarkerMissing    = find(sum(armNanMatrix)==1);
    twoMarkersMissing   = find(sum(armNanMatrix)==2);
    threeMarkersMissing = find(sum(armNanMatrix)==3);
    fourMarkersMissing  = find(sum(armNanMatrix)==4);

    % Find unbiased mean and SD: statistics use only lengths below the 95th
    % percentile; frames whose length is below mean + sdThreshold*SD are kept.
    goodArmLengthMask = zeros(size(armLengthMatrix));
    for lengthIdx = 1:size(armLengthMatrix,1)
        nonNanLengthVector  = armLengthMatrix(lengthIdx,nonNanIdx);
        goodLengthThreshold = prctile(nonNanLengthVector, 95);
        goodLengthData      = nonNanLengthVector(nonNanLengthVector<goodLengthThreshold);
        goodLengthMean      = mean(goodLengthData);
        goodLengthStd       = std(goodLengthData);
        nonNanGoodIdxIdx    = find(nonNanLengthVector<(goodLengthMean+sdThreshold*goodLengthStd));
        goodArmLengthMask(lengthIdx, nonNanIdx(nonNanGoodIdxIdx)) = 1;
    end
    completeMarkerIdx = find(sum(goodArmLengthMask)==size(armLengthMatrix,1));
    clear good*

    % Another level of cleaning: keep frames where every arm marker has all
    % x/y/z within +/-4 z-scores.
    goodArmXyz = armShortInterruptionFixed(:,:,completeMarkerIdx);
    zscoreMask = zeros(size(goodArmXyz,2), size(goodArmXyz,3));
    for chIdx = 1:size(goodArmXyz,2)
        zscoredXyz          = zscore(squeeze(goodArmXyz(:,chIdx,:)), [], 2);
        goodZscoreMask      = zscoredXyz>-4 & zscoredXyz<4;
        zscoreMask(chIdx,:) = sum(goodZscoreMask)==3;
    end
    completeMarkerIdx = completeMarkerIdx(all(zscoreMask, 1));

    % Downsample data: arm markers plus hand centroid, every downsamplingFactor-th clean frame.
    downsamplingIdx = 1:downsamplingFactor:length(completeMarkerIdx);
    armMarkers      = armShortInterruptionFixed(:,:,completeMarkerIdx);
    handCentroid    = mean(wholeBodyRotated(:, handIdx, completeMarkerIdx),2);
    armHandCombined = cat(2, armMarkers, handCentroid);
    armHandCombined = armHandCombined(:,:,downsamplingIdx);
    clear armMarkers handCentroid

    % Train neural networks. Column k (1-4) of armHandCombined is arm marker
    % 11+k; column 5 is the hand centroid. Each net predicts the named missing
    % markers from the remaining columns.

    % Single-channel recovery.
    disp('1/17...'); net_12ch = customNeuralNetwork(armHandCombined(:,[2 3 4 5],:), armHandCombined(:,1,:));
    disp('2/17...'); net_13ch = customNeuralNetwork(armHandCombined(:,[1 3 4 5],:), armHandCombined(:,2,:));
    disp('3/17...'); net_14ch = customNeuralNetwork(armHandCombined(:,[1 2 4 5],:), armHandCombined(:,3,:));
    disp('4/17...'); net_15ch = customNeuralNetwork(armHandCombined(:,[1 2 3 5],:), armHandCombined(:,4,:));

    % Two-channel recovery.
    disp('5/17...'); net_12_13ch = customNeuralNetwork(armHandCombined(:,[3 4 5],:), armHandCombined(:,[1 2],:));
    disp('6/17...'); net_12_14ch = customNeuralNetwork(armHandCombined(:,[2 4 5],:), armHandCombined(:,[1 3],:));
    disp('7/17...'); net_12_15ch = customNeuralNetwork(armHandCombined(:,[2 3 5],:), armHandCombined(:,[1 4],:));
    disp('8/17...'); net_13_14ch = customNeuralNetwork(armHandCombined(:,[1 4 5],:), armHandCombined(:,[2 3],:));
    disp('9/17...'); net_13_15ch = customNeuralNetwork(armHandCombined(:,[1 3 5],:), armHandCombined(:,[2 4],:));
    disp('10/17...'); net_14_15ch = customNeuralNetwork(armHandCombined(:,[1 2 5],:), armHandCombined(:,[3 4],:));

    % Three-channel recovery.
    disp('11/17...'); net_12_13_14ch = customNeuralNetwork(armHandCombined(:,[4 5],:), armHandCombined(:,[1 2 3],:));
    disp('12/17...'); net_12_13_15ch = customNeuralNetwork(armHandCombined(:,[3 5],:), armHandCombined(:,[1 2 4],:));
    disp('13/17...'); net_12_14_15ch = customNeuralNetwork(armHandCombined(:,[2 5],:), armHandCombined(:,[1 3 4],:));
    disp('14/17...'); net_13_14_15ch = customNeuralNetwork(armHandCombined(:,[1 5],:), armHandCombined(:,[2 3 4],:));

    % All-channel recovery.
    disp('15/17...'); net_12_13_14_15ch = customNeuralNetwork(armHandCombined(:,5,:), armHandCombined(:,[1 2 3 4],:));


    %%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%% Learning the knees. %%%
    %%%%%%%%%%%%%%%%%%%%%%%%%%%

    % Interpolate short interruptions (within 1 s). Mocap data sampling rate = 480 Hz.
    rightKneeShortInterruptionFixed = customInterpolationShortWindow(wholeBodyRotated(:, 22, :), 480, 5, 480);
    leftKneeShortInterruptionFixed  = customInterpolationShortWindow(wholeBodyRotated(:, 28, :), 480, 5, 480);

    % Evaluate knee marker intervals: marker 21/27 -> knee -> foot centroid.
    rightKneeLengthMatrix = [squeeze(sqrt(sum((wholeBodyRotated(:,21,:)-rightKneeShortInterruptionFixed).^2)))';
                             squeeze(sqrt(sum((mean(wholeBodyRotated(:,footR_Idx,:),2)-rightKneeShortInterruptionFixed).^2)))'];
    leftKneeLengthMatrix  = [squeeze(sqrt(sum((wholeBodyRotated(:,27,:)-leftKneeShortInterruptionFixed).^2)))';
                             squeeze(sqrt(sum((mean(wholeBodyRotated(:,footL_Idx,:),2)-leftKneeShortInterruptionFixed).^2)))'];
    rightLegLength    = sum(rightKneeLengthMatrix);
    leftLegLength     = sum(leftKneeLengthMatrix);
    rightLegNonNanIdx = find(~isnan(rightLegLength));
    leftLegNonNanIdx  = find(~isnan(leftLegLength));

    % Find unbiased mean and SD for the right leg (same rule as for the arm).
    goodRightLegLengthMask = zeros(size(rightLegLength));
    nonNanLengthVector  = rightLegLength(rightLegNonNanIdx);
    goodLengthThreshold = prctile(nonNanLengthVector, 95);
    goodLengthData      = nonNanLengthVector(nonNanLengthVector<goodLengthThreshold);
    goodLengthMean      = mean(goodLengthData);
    goodLengthStd       = std(goodLengthData);
    nonNanGoodIdxIdx    = find(nonNanLengthVector<(goodLengthMean+sdThreshold*goodLengthStd));
    goodRightLegLengthMask(rightLegNonNanIdx(nonNanGoodIdxIdx)) = 1;
    goodRightLegLengthIdx = find(goodRightLegLengthMask);

    % Find unbiased mean and SD for the left leg.
    goodLeftLegLengthMask = zeros(size(leftLegLength));
    nonNanLengthVector  = leftLegLength(leftLegNonNanIdx);
    goodLengthThreshold = prctile(nonNanLengthVector, 95);
    goodLengthData      = nonNanLengthVector(nonNanLengthVector<goodLengthThreshold);
    goodLengthMean      = mean(goodLengthData);
    goodLengthStd       = std(goodLengthData);
    nonNanGoodIdxIdx    = find(nonNanLengthVector<(goodLengthMean+sdThreshold*goodLengthStd));
    goodLeftLegLengthMask(leftLegNonNanIdx(nonNanGoodIdxIdx)) = 1;
    goodLeftLegLengthIdx = find(goodLeftLegLengthMask);

    % Another level of cleaning: keep frames whose knee x/y/z are all within +/-4 z-scores.
    goodKneeXyz = rightKneeShortInterruptionFixed(:,:,goodRightLegLengthIdx);
    zscoreMask = zeros(size(goodKneeXyz,2), size(goodKneeXyz,3));
    for chIdx = 1:size(goodKneeXyz,2)
        zscoredXyz          = zscore(squeeze(goodKneeXyz(:,chIdx,:)), [], 2);
        goodZscoreMask      = zscoredXyz>-4 & zscoredXyz<4;
        zscoreMask(chIdx,:) = sum(goodZscoreMask)==3;
    end
    goodRightLegLengthIdx = goodRightLegLengthIdx(zscoreMask==1);

    goodKneeXyz = leftKneeShortInterruptionFixed(:,:,goodLeftLegLengthIdx);
    zscoreMask = zeros(size(goodKneeXyz,2), size(goodKneeXyz,3));
    for chIdx = 1:size(goodKneeXyz,2)
        zscoredXyz          = zscore(squeeze(goodKneeXyz(:,chIdx,:)), [], 2);
        goodZscoreMask      = zscoredXyz>-4 & zscoredXyz<4;
        zscoreMask(chIdx,:) = sum(goodZscoreMask)==3;
    end
    goodLeftLegLengthIdx = goodLeftLegLengthIdx(zscoreMask==1);

    % Downsample data.
    rightKneeDownsamplingIdx = 1:downsamplingFactor:length(goodRightLegLengthIdx);
    leftKneeDownsamplingIdx  = 1:downsamplingFactor:length(goodLeftLegLengthIdx);
    rightKneeMarkers = rightKneeShortInterruptionFixed(:,:,goodRightLegLengthIdx(rightKneeDownsamplingIdx));
    leftKneeMarkers  = leftKneeShortInterruptionFixed( :,:,goodLeftLegLengthIdx(leftKneeDownsamplingIdx));
    footR_Centroid   = mean(wholeBodyRotated(:, footR_Idx, :),2);
    footL_Centroid   = mean(wholeBodyRotated(:, footL_Idx, :),2);

    % Train neural networks: each knee marker is predicted from both foot
    % centroids. A side with no clean frames (e.g., Subj819's right knee) is skipped.
    if ~isempty(rightKneeDownsamplingIdx) % For Subj819.
        rightKneeCombined = cat(2, rightKneeMarkers, footR_Centroid(:,:,goodRightLegLengthIdx(rightKneeDownsamplingIdx)), footL_Centroid(:,:,goodRightLegLengthIdx(rightKneeDownsamplingIdx)));
        disp('16/17...'); net_22ch = customNeuralNetwork(rightKneeCombined(:,[2 3],:), rightKneeCombined(:,1,:));
    end

    if ~isempty(leftKneeDownsamplingIdx)
        leftKneeCombined  = cat(2, leftKneeMarkers,  footR_Centroid(:,:,goodLeftLegLengthIdx(leftKneeDownsamplingIdx)), footL_Centroid(:,:,goodLeftLegLengthIdx(leftKneeDownsamplingIdx)));
        disp('17/17...'); net_28ch = customNeuralNetwork(leftKneeCombined(:,[2 3],:), leftKneeCombined(:,1,:));
    end


    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%% Store the learned neural networks. %%%
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Every block of this subject gets the same reference torso and networks.
    for eegIdx = 1:length(ALLEEG)
        EEG = ALLEEG(1,eegIdx);
        EEG.etc.audiomaze.neuralNetworks.referenceTorso    = referenceTorso;
        EEG.etc.audiomaze.neuralNetworks.net_12ch          = net_12ch;
        EEG.etc.audiomaze.neuralNetworks.net_13ch          = net_13ch;
        EEG.etc.audiomaze.neuralNetworks.net_14ch          = net_14ch;
        EEG.etc.audiomaze.neuralNetworks.net_15ch          = net_15ch;
        EEG.etc.audiomaze.neuralNetworks.net_12_13ch       = net_12_13ch;
        EEG.etc.audiomaze.neuralNetworks.net_12_14ch       = net_12_14ch;
        EEG.etc.audiomaze.neuralNetworks.net_12_15ch       = net_12_15ch;
        EEG.etc.audiomaze.neuralNetworks.net_13_14ch       = net_13_14ch;
        EEG.etc.audiomaze.neuralNetworks.net_13_15ch       = net_13_15ch;
        EEG.etc.audiomaze.neuralNetworks.net_14_15ch       = net_14_15ch;
        EEG.etc.audiomaze.neuralNetworks.net_12_13_14ch    = net_12_13_14ch;
        EEG.etc.audiomaze.neuralNetworks.net_12_13_15ch    = net_12_13_15ch;
        EEG.etc.audiomaze.neuralNetworks.net_12_14_15ch    = net_12_14_15ch;
        EEG.etc.audiomaze.neuralNetworks.net_13_14_15ch    = net_13_14_15ch;
        EEG.etc.audiomaze.neuralNetworks.net_12_13_14_15ch = net_12_13_14_15ch;
        if ~isempty(rightKneeDownsamplingIdx) % For Subj819.
            EEG.etc.audiomaze.neuralNetworks.net_22ch = net_22ch;
        end
        if ~isempty(leftKneeDownsamplingIdx)
            EEG.etc.audiomaze.neuralNetworks.net_28ch = net_28ch;
        end

        pop_saveset(EEG, 'filename', EEG.filename, 'filepath', '/data/mobi/Audiomaze/makoto/p05_neuralNetworkForNonRigidBodies');
    end
    ALLEEG = [];
end