% Schreiner et al. 2024
% Figure 1C: classification of head-orientation-related activity during retrieval.
% Requires the FieldTrip and MVPA-Light toolboxes (and boundedline for plotting).
%
%
% Input:
% dirRoot   root directory 'data' containing the folders
%           'EEG' and 'auxiliary'

function Fig_1(dirRoot)
% FIG_1  Figure 1C: time-resolved classification (LDA, AUC) of
% head-orientation-related EEG activity during retrieval.
%
% For each subject, both retrieval sessions (ret1, ret2) go through the
% following steps:
%   - average-referenced, low-pass filtered and z-scored;
%   - restricted to the head orientations that were cued during sleep
%     (TMR);
%   - appended, rank-reduced via PCA and smoothed.
% Correctly remembered trials are then classified across time as left
% (head angles 1,2) vs right (head angles 3,4). The group AUC time course
% is tested against chance (0.5) with a cluster-based permutation test.
% It is plotted with the significant time points marked.
%
% Input:
%   dirRoot   root directory 'data' containing the folders 'EEG' and
%             'auxiliary' (only 'EEG' is read here)
%
% Requires FieldTrip, MVPA-Light and boundedline on the MATLAB path.

datadir = fullfile(dirRoot, 'EEG');

subjects = {'phd03', 'phd05', 'phd06', 'phd07', 'phd08', 'phd12', 'phd13', 'phd15', 'phd16', 'phd17', 'phd20', 'phd21', 'phd22', 'phd23', 'phd24', 'phd25', 'phd26','phd28','phd29','phd30','phd32','phd35','phd36', 'phd39', 'phd40'};

TOI             = [-1 3];               % classification window (s)
classi_timeaxis = TOI(1):0.01:TOI(2);   % 10 ms classification time grid
latency         = [-.5 3];              % statistics/plot window (s); alternative: TOI
nComp           = 30;                   % PCA components kept for rank reduction
smoothWinSec    = 0.15;                 % running-average window (s)

left    = [1, 2];                       % head angles of class 1
right   = [3, 4];                       % head angles of class 2
correct = 1;                            % headang_acc value of remembered items

nSubj   = numel(subjects);
acc     = cell(1, nSubj);               % per-subject AUC time courses
trlnums = cell(1, nSubj);               % per-subject [#class1 #class2] trials (for inspection)

for iSub = 1:nSubj
    subjDir = fullfile(datadir, subjects{iSub});

    % head orientations cued during sleep (decoding is restricted to these)
    tmp_sleep = load(fullfile(subjDir, sprintf('%s_TMR', subjects{iSub})));
    hd_react  = get_cued_orientations(tmp_sleep.data_trl);
    clear tmp_sleep

    % preprocess both retrieval sessions identically and keep cued trials
    tmp_ret1      = load(fullfile(subjDir, sprintf('%s_ret1', subjects{iSub})));
    data_trl_ret1 = prepare_retrieval(tmp_ret1.eeg, hd_react);
    clear tmp_ret1

    tmp_ret2      = load(fullfile(subjDir, sprintf('%s_ret2', subjects{iSub})));
    data_trl_ret2 = prepare_retrieval(tmp_ret2.eeg, hd_react);
    clear tmp_ret2

    % append ret1 and ret2
    cfg      = [];
    data_trl = ft_appenddata(cfg, data_trl_ret1, data_trl_ret2);
    clear data_trl_ret1 data_trl_ret2

    % PCA rank reduction: decompose into nComp components ...
    cfg              = [];
    cfg.method       = 'pca';
    cfg.updatesens   = 'no';
    cfg.numcomponent = nComp;
    comp             = ft_componentanalysis(cfg, data_trl);

    % ... and back-project. Only labels beyond nComp are rejected; presumably
    % there are none, so this yields the reduced-rank data (NOTE(review): confirm)
    cfg              = [];
    cfg.updatesens   = 'no';
    cfg.component    = comp.label(nComp+1:end);
    data_trl         = ft_rejectcomponent(cfg, comp);

    % temporal smoothing (running average over smoothWinSec, expressed in samples)
    winLen = smoothWinSec/(1/data_trl.fsample);
    for itrial = 1:numel(data_trl.trial)
        data_trl.trial{itrial} = smoothdata(data_trl.trial{itrial}, 2, 'movmean', winLen);
    end

    % trials x channels x time array
    cfg            = [];
    cfg.keeptrials = 'yes';
    cfg.removemean = 'no';
    tmp            = ft_timelockanalysis(cfg, data_trl);
    dat            = tmp.trial;

    % class selection (2 classes), remembered items only
    trlinfo  = cell2mat(data_trl.trialinfo);
    category = [trlinfo.head_angle];
    behav    = [trlinfo.headang_acc];

    sel1 = ismember(category, left)  & ismember(behav, correct);
    sel2 = ismember(category, right) & ismember(behav, correct);
    trlnums{iSub} = [sum(sel1) sum(sel2)];

    dat4class = cat(1, dat(sel1,:,:), dat(sel2,:,:));
    classcode = [ones(sum(sel1),1); 2*ones(sum(sel2),1)];

    % Pick the nearest data sample for every point of the classification grid.
    % Unlike a colon range with step 0.01*fsample, this gives integer indices
    % for any sampling rate and always numel(classi_timeaxis) time points,
    % which accmat and the statistics below rely on.
    samples   = arrayfun(@(t) nearest(data_trl.time{1}, t), classi_timeaxis);
    dat4class = dat4class(:,:,samples);

    % classification
    cfg            = [];
    cfg.classifier = 'lda';
    cfg.metric     = 'auc';
    cfg.repeat     = 5;
    cfg.k          = 5;
    acc{iSub}      = mv_classify_across_time(cfg, dat4class, classcode);
end

%% stats
% subjects x time matrix of AUC values
accmat = nan(nSubj, numel(classi_timeaxis));
for isubject = 1:nSubj
    accmat(isubject,:) = acc{isubject};
end

Fieldtripstats = run_cluster_stats(accmat, classi_timeaxis, latency);
fprintf('%d significant time points\n', nnz(Fieldtripstats.mask));

%% plot
plot_decoding(accmat, classi_timeaxis, latency, Fieldtripstats.mask);

end


function hd_react = get_cued_orientations(data_trl)
% Return the head orientations cued during sleep: the two smallest
% 'Display' values among the first three TMR trials, sorted ascending.
% NOTE(review): assumes the first three TMR trials carry the relevant
% orientations and the cued ones are the two smallest values; confirm
% against the TMR protocol.
hd_react = nan(1, 3);
for ii = 1:3
    hd_react(ii) = data_trl.trialinfo{ii,1}.Display;
end
hd_react = sort(hd_react);
hd_react = hd_react(1:2);
end


function data = prepare_retrieval(eeg, hd_react)
% Preprocess one retrieval session:
%   - average re-reference;
%   - flag in trialinfo whether each trial's head_angle was cued
%     (.React = 1/0);
%   - 40 Hz low-pass on EEG channels;
%   - MVPA-Light z-scoring;
%   - keep only the cued trials.

% average re-reference
cfg            = [];
cfg.reref      = 'yes';
cfg.refchannel = {'all'};
cfg.refmethod  = 'avg';
cfg.removemean = 'no';
cfg.demean     = 'yes';
data           = ft_preprocessing(cfg, eeg);

% add to trialinfo whether the head direction was cued before
for jj = 1:size(data.trialinfo, 1)
    if ismember(data.trialinfo{jj,1}.head_angle, hd_react) == 1
        data.trialinfo{jj,1}.React = 1;
    else
        data.trialinfo{jj,1}.React = 0;
    end
end

% 40 Hz low-pass on EEG channels
cfg            = [];
cfg.channel    = 'eeg';
cfg.demean     = 'yes';
cfg.removemean = 'no';
cfg.lpfilter   = 'yes';
cfg.lpfreq     = 40;
data           = ft_preprocessing(cfg, data);

% trials x channels x time array, then z-score it
cfg            = [];
cfg.keeptrials = 'yes';
cfg.removemean = 'no';
data           = ft_timelockanalysis(cfg, data);

preprocess_param = mv_get_preprocess_param('zscore');
[~, trial_norm]  = mv_preprocess_zscore(preprocess_param, data.trial);
data.trial       = trial_norm;

% reduce data to cued conditions
cued       = find(cellfun(@(ti) ti.React ~= 0, data.trialinfo(:,1)));
cfg        = [];
cfg.trials = cued;
data       = ft_redefinetrial(cfg, data);
end


function stats = run_cluster_stats(accmat, timeaxis, latency)
% Two-tailed cluster-based permutation test of the subjects' AUC time
% courses against a constant chance level of 0.5, within LATENCY (s).
% Settings: dependent-samples t, 1000 randomizations, maxsum clusters.
%
%   accmat    subjects x time matrix of AUC values
%   timeaxis  time (s) of each column of accmat
%   latency   [start end] window (s) passed to ft_timelockstatistics
%   stats     output of ft_timelockstatistics (stats.mask marks significance)

nSub = size(accmat, 1);

% observed and chance-level data as FieldTrip timelock structures
% with a single pseudo-channel
cstat               = cell(1, 2);
cstat{1}.label      = {'Channels'};
cstat{1}.time       = timeaxis;
cstat{1}.individual = reshape(accmat, nSub, 1, []);
cstat{1}.dimord     = 'subj_chan_time';
cstat{1}.avg        = squeeze(mean(cstat{1}.individual, 1));

cstat{2}            = cstat{1};
cstat{2}.individual = 0.5*ones(size(cstat{1}.individual));
cstat{2}.avg        = squeeze(mean(cstat{2}.individual, 1));

cfg                  = [];
cfg.latency          = latency;
cfg.spmversion       = 'spm12';
cfg.channel          = 'all';
cfg.statistic        = 'depsamplesT';
cfg.method           = 'montecarlo';
cfg.correctm         = 'cluster';
cfg.alpha            = .05;
cfg.clusteralpha     = .05;
cfg.tail             = 0;
cfg.correcttail      = 'alpha';
cfg.neighbours       = [];
cfg.minnbchan        = 0;
cfg.computecritval   = 'yes';
cfg.numrandomization = 1000;
cfg.clusterstatistic = 'maxsum';
cfg.clustertail      = cfg.tail;
cfg.parameter        = 'individual';

% row 1: subject (unit of observation), row 2: condition (data vs chance)
cfg.design = [1:nSub, 1:nSub; ones(1,nSub), 2*ones(1,nSub)];
cfg.uvar   = 1;
cfg.ivar   = 2;

stats = ft_timelockstatistics(cfg, cstat{:});
end


function plot_decoding(accmat, timeaxis, latency, mask)
% Plot the group-mean AUC (+/- SEM across subjects) within LATENCY.
% Time points flagged in MASK (cluster-test result) are drawn as a grey
% bar at AUC = 0.49.

% Window indices via nearest(): exact '==' matching on a colon-generated
% time axis can fail through floating-point round-off.
idx  = nearest(timeaxis, latency(1)):nearest(timeaxis, latency(2));
tWin = timeaxis(idx);
nSub = size(accmat, 1);

m = mean(accmat(:,idx), 1, 'omitnan');
s = std(accmat(:,idx), 0, 1, 'omitnan') ./ sqrt(nSub);

boundedline(tWin, m, s, 'cmap', [0,0,0], 'alpha');
hold on
plot(tWin, m, 'color', [0,0,0], 'linewidth', 4);
ylim([0.455 0.64])

sigline = nan(1, numel(timeaxis));
sigline(idx(mask == 1)) = .49;
plot(timeaxis, sigline, 'color', [0.6 0.6 0.6], 'linewidth', 6);
end