% Schreiner et al. 2024: Figure_2;

% Figure_2a: Time-frequency representations of experimental and control cues
% Figure_2b: Classification of retrieval-related EEG activity during TMR
% Figure_2c: Searchlight decoding in source space
% Figure_2d: Correlation between classification performance and TMR-triggered power


% The FieldTrip and MVPA-Light toolboxes need to be installed
%
%
% Input:
% dirRoot   root directory 'data' containing the folders
%           'EEG' and 'auxiliary'

function Fig_2(dirRoot)

% Analysis switches: set a flag to 1 to run that panel, 0 to skip it.
do_Fig2a = 1;
do_Fig2b = 0;
do_Fig2c = 0;
do_Fig2d = 0;

% Input folders below the data root ('EEG' holds the per-subject data,
% 'auxiliary' holds shared files such as head model and leadfield).
aux_dir = fullfile(dirRoot, '/auxiliary/');
datadir = fullfile(dirRoot, '/EEG/');

% Participant IDs; each ID is also the name of a sub-folder of datadir.
subjects = { ...
    'phd03', 'phd05', 'phd06', 'phd07', 'phd08', ...
    'phd12', 'phd13', 'phd15', 'phd16', 'phd17', ...
    'phd20', 'phd21', 'phd22', 'phd23', 'phd24', ...
    'phd25', 'phd26', 'phd28', 'phd29', 'phd30', ...
    'phd32', 'phd35', 'phd36', 'phd39', 'phd40'};


if any (do_Fig2a)
    % Figure 2a: time-frequency power of experimental vs. control TMR cues,
    % contrasted across subjects with a cluster-based permutation test.

    nSubj           = numel(subjects);
    ALL_TFR_control = cell(1, nSubj);
    ALL_TFR_exp     = cell(1, nSubj);

    for iSub = 1:nSubj

        % read data
        tmp_sleep = load(fullfile(datadir,subjects{iSub},sprintf('%s_TMR',subjects{iSub})));
        data_trl  = tmp_sleep.data_trl;
        clear tmp_sleep

        % re-reference to the average of the mastoids ...
        cfg            = [];
        cfg.reref      = 'yes';
        cfg.refchannel = {'M1', 'M2'};
        cfg.refmethod  = 'avg';
        data_trl       = ft_preprocessing(cfg, data_trl);

        % ... then drop the mastoid channels themselves
        cfg            = [];
        cfg.channel    = {'all', '-M1', '-M2'};
        data_trl       = ft_selectdata(cfg, data_trl);

        % define conditions: Display == 999 marks a control cue, every other
        % value an experimental cue
        is_control  = cellfun(@(ti) ti.Display == 999, data_trl.trialinfo(:,1));
        trs_control = find(is_control);
        trs_exp     = find(~is_control);

        % time-frequency decomposition, single trials kept
        cfg              = [];
        cfg.keeptrials   = 'yes';
        cfg.output       = 'pow';
        cfg.method       = 'mtmconvol';
        cfg.taper        = 'hanning';
        cfg.foi          = 1:1:50;     % frequency range (Hz)
        cfg.t_ftimwin    = 5./cfg.foi; % 5 cycles per time window
        cfg.toi          = -1:0.05:4;  % time range (s)
        tfr_all          = ft_freqanalysis(cfg, data_trl);

        % z-score power per channel and frequency, pooling all trials and all
        % time points between -1 and 4 s (NaN bins are ignored)
        startT = nearest(tfr_all.time, -1);
        endT   = nearest(tfr_all.time,  4);

        pow  = tfr_all.powspctrm;   % indexed below as trial x chan x freq x time
        powz = nan(size(pow));
        for ichan = 1:size(pow,2)
            for ifreq = 1:size(pow,3)
                d = pow(:,ichan,ifreq,startT:endT);
                m = mean(d(:), 'omitnan');
                s = std(d(:), 0, 'omitnan');
                powz(:,ichan,ifreq,:) = (pow(:,ichan,ifreq,:) - m) ./ s;
            end
        end
        tfr_all.powspctrm = powz;
        clear pow powz d

        % split trials by condition
        cfg = []; cfg.trials = trs_control; tfr_control = ft_selectdata(cfg, tfr_all);
        cfg = []; cfg.trials = trs_exp;     tfr_exp     = ft_selectdata(cfg, tfr_all);

        % average over trials; permute drops only the (now singleton) trial
        % dimension, unlike squeeze, which would also drop a singleton
        % channel or frequency dimension
        tfr_control.powspctrm = permute(mean(tfr_control.powspctrm, 1, 'omitnan'), [2 3 4 1]);
        tfr_exp.powspctrm     = permute(mean(tfr_exp.powspctrm,     1, 'omitnan'), [2 3 4 1]);
        tfr_control.dimord    = 'chan_freq_time';
        tfr_exp.dimord        = 'chan_freq_time';

        % collect all subjects
        ALL_TFR_control{iSub} = tfr_control;
        ALL_TFR_exp{iSub}     = tfr_exp;
    end

    % neighbouring channels, used to form spatial clusters
    cfg             = [];
    cfg.method      = 'triangulation';
    cfg.elec        = 'elec_ant64_fin.mat';
    neighbours      = ft_prepare_neighbours(cfg);

    % cluster-based permutation test (dependent samples): exp vs. control
    cfg                  = [];
    cfg.latency          = [-0.5 1.5];
    cfg.avgovertime      = 'no';
    cfg.spmversion       = 'spm12';
    cfg.channel          = {'all'};
    cfg.frequency        = [2 25];
    cfg.method           = 'montecarlo';
    cfg.statistic        = 'ft_statfun_depsamplesT';
    cfg.correctm         = 'cluster';
    cfg.clusteralpha     = 0.05;
    cfg.clusterstatistic = 'maxsum';
    cfg.minnbchan        = 3;
    cfg.tail             = 0;
    cfg.clustertail      = 0;
    cfg.alpha            = 0.05;
    cfg.numrandomization = 1000;
    cfg.computecritval   = 'yes';
    cfg.correcttail      = 'alpha';
    cfg.neighbours       = neighbours;   % which sensors may form clusters together

    % design: row 1 = subject (unit of observation), row 2 = condition
    % (1 = experimental, 2 = control; matches the order of the data below)
    cfg.design = [1:nSubj, 1:nSubj; ones(1,nSubj), 2*ones(1,nSubj)];
    cfg.uvar   = 1;
    cfg.ivar   = 2;

    Fieldtripstat = ft_freqstatistics(cfg, ALL_TFR_exp{:}, ALL_TFR_control{:});
    fprintf('Fig 2a: %d significant channel x frequency x time bins\n', nnz(Fieldtripstat.mask));

    % plot_contour must be on the path; only add the auxiliary folder when it
    % is not found. NOTE(review): assumes plot_contour.m lives in aux_dir --
    % TODO confirm.
    if isempty(which('plot_contour'))
        addpath(aux_dir);
    end

    % keep t-values inside significant clusters only, then sum over channels
    stat      = Fieldtripstat;
    stat.stat = stat.stat .* stat.mask;
    t_sums    = squeeze(sum(stat.stat, 1));   % freq x time

    figure;
    pcolor(stat.time, stat.freq, t_sums);
    shading interp;
    axis xy;
    xlabel('time [seconds]', 'Fontsize', 17);
    ylabel('Frequency [Hz]', 'Fontsize', 17);
    set(gca, 'layer', 'top', 'TickDir', 'in', 'Linewidth', 2);
    caxis([0 100]);
    set(gcf, 'Color', 'w');
    colorbar;

    set(gca, 'FontSize', 14);
    ax            = gca;
    ax.XTick      = [-0.49 0 0.5 1 1.49];
    ax.XTickLabel = {'-0.5', '0', '0.5', '1', '1.5'};
    ax.YTick      = [5 10 15 20 25];
    set(gcf, 'Position', [100, 400, 570, 200]);

    % binary map of bins with a non-zero summed t-value, drawn as contour
    stats_time = nearest(stat.time, cfg.latency(1)):nearest(stat.time, cfg.latency(2));
    stats_freq = nearest(stat.freq, cfg.frequency(1)):nearest(stat.freq, cfg.frequency(2));

    sigmap                        = zeros(numel(stat.freq), numel(stat.time));
    sigmap(stats_freq,stats_time) = abs(t_sums);
    sigmap(sigmap > 0)            = 1;

    plot_contour(stat.time, stat.freq, sigmap)
    xline(0, '--', 'linewidth', 2, 'color', 'k');

end


    if any (do_Fig2b)

    % Decoding windows (s): training on retrieval (wake) data, testing on
    % TMR (sleep) data; both time axes use a 10 ms grid.
    TOI_sleep = [-0.5 1.5];
    TOI_wake  = [-0.5 1];

    classi_timeaxis_test  = TOI_sleep(1) : 0.01 : TOI_sleep(2);
    classi_timeaxis_train = TOI_wake(1)  : 0.01 : TOI_wake(2);

    % one classification result per subject, filled in the subject loop
    acc = cell(1, numel(subjects));

    % enter subject loop
    % Per subject: train an LDA on correct retrieval trials (left vs. right
    % head directions) and test it, time x time, on TMR cues during sleep.
    for iSub = 1:numel(subjects)

        subj_dir = fullfile(datadir, subjects{iSub});

        % --- load data -----------------------------------------------------
        tmp        = load(fullfile(subj_dir, sprintf('%s_ret1.mat', subjects{iSub}))); % pre-sleep retrieval
        data_ret1  = tmp.eeg;
        tmp        = load(fullfile(subj_dir, sprintf('%s_ret2.mat', subjects{iSub}))); % post-sleep retrieval
        data_ret2  = tmp.eeg;
        tmp        = load(fullfile(subj_dir, sprintf('%s_TMR.mat', subjects{iSub})));  % TMR (sleep)
        data_sleep = tmp.data_trl;
        clear tmp

        % --- head directions cued during TMR -------------------------------
        % NOTE(review): assumes the first three TMR trials contain the two
        % cued head directions plus a control cue (Display 999, sorts last),
        % so the two smallest values are the cued directions -- TODO confirm.
        hd_react = sort(cellfun(@(ti) ti.Display, data_sleep.trialinfo(1:3,1)));
        hd_react = hd_react(1:2);

        % --- retrieval (wake) data: same pipeline for ret1 and ret2 --------
        wake = {data_ret1, data_ret2};
        clear data_ret1 data_ret2
        for iw = 1:numel(wake)
            % average re-reference
            cfg            = [];
            cfg.reref      = 'yes';
            cfg.refchannel = {'all'};
            cfg.refmethod  = 'avg';
            cfg.removemean = 'no';
            cfg.demean     = 'yes';
            wake{iw}       = ft_preprocessing(cfg, wake{iw});

            % flag trials whose head direction was cued during TMR (React 1/0)
            for itr = 1:size(wake{iw}.trialinfo, 1)
                wake{iw}.trialinfo{itr,1}.React = double(ismember(wake{iw}.trialinfo{itr,1}.head_angle, hd_react));
            end

            % 0.1-40 Hz band-pass on EEG channels
            cfg                  = [];
            cfg.channel          = 'eeg';
            cfg.hpfilter         = 'yes';
            cfg.hpfreq           = 0.1;
            cfg.lpfilter         = 'yes';
            cfg.lpfreq           = 40;
            cfg.hpinstabilityfix = 'reduce';
            cfg.demean           = 'yes';
            cfg.removemean       = 'no';
            wake{iw}             = ft_preprocessing(cfg, wake{iw});

            % single-trial timelock format, then z-score (MVPA-Light)
            cfg            = [];
            cfg.keeptrials = 'yes';
            cfg.removemean = 'no';
            wake{iw}       = ft_timelockanalysis(cfg, wake{iw});

            zparam         = mv_get_preprocess_param('zscore');
            [~, zdat]      = mv_preprocess_zscore(zparam, wake{iw}.trial);
            wake{iw}.trial = zdat;

            % keep trials of cued head directions only
            cfg        = [];
            cfg.trials = find(cellfun(@(ti) ti.React ~= 0, wake{iw}.trialinfo(:,1)));
            wake{iw}   = ft_redefinetrial(cfg, wake{iw});
        end

        % post-sleep (ret2) trials first, then pre-sleep (ret1) trials
        cfg       = [];
        data_wake = ft_appenddata(cfg, wake{2}, wake{1});
        clear wake zdat

        % --- TMR (sleep) data ----------------------------------------------
        cfg            = [];
        cfg.channel    = {'all','-M1', '-M2'};
        data_sleep     = ft_selectdata(cfg, data_sleep);

        cfg            = [];
        cfg.resamplefs = 200;
        data_sleep     = ft_resampledata(cfg, data_sleep);

        cfg            = [];
        cfg.reref      = 'yes';
        cfg.refchannel = {'all'}; % for detection we should use mastoids
        cfg.refmethod  = 'avg';
        cfg.removemean = 'no';
        cfg.demean     = 'yes';
        data_sleep     = ft_preprocessing(cfg, data_sleep);

        cfg                  = [];
        cfg.channel          = 'eeg';
        cfg.hpfilter         = 'yes';
        cfg.hpfreq           = 0.1;
        cfg.hpinstabilityfix = 'reduce';
        cfg.lpfilter         = 'yes';
        cfg.lpfreq           = 40;
        cfg.removemean       = 'no';
        data_sleep           = ft_preprocessing(cfg, data_sleep);

        cfg            = [];
        cfg.keeptrials = 'yes';
        cfg.removemean = 'no';
        data_sleep     = ft_timelockanalysis(cfg, data_sleep);

        zparam           = mv_get_preprocess_param('zscore');
        [~, zdat]        = mv_preprocess_zscore(zparam, data_sleep.trial);
        data_sleep.trial = zdat;
        clear zdat

        % keep experimental cues only (Display 999 = control cue)
        cfg        = [];
        cfg.trials = find(cellfun(@(ti) ti.Display ~= 999, data_sleep.trialinfo(:,1)));
        data_sleep = ft_redefinetrial(cfg, data_sleep);

        % --- joint PCA of wake + sleep data --------------------------------
        cfg      = [];
        data_app = ft_appenddata(cfg, data_wake, data_sleep);

        cfg              = [];
        cfg.method       = 'pca';
        cfg.updatesens   = 'no';
        cfg.numcomponent = 30;
        comp             = ft_componentanalysis(cfg, data_app);

        % back-project from the component representation; with 30 requested
        % components comp.label(31:end) is presumably empty, i.e. nothing is
        % rejected and the data are reduced to 30 dimensions -- confirm
        cfg            = [];
        cfg.updatesens = 'no';
        cfg.demean     = 'no';
        cfg.component  = comp.label(31:end);
        data_app       = ft_rejectcomponent(cfg, comp);

        % split back into wake (first n_wake trials) and sleep trials
        n_wake         = numel(data_wake.trialinfo);
        cfg            = [];
        cfg.trials     = 1:n_wake;
        data_all_wake  = ft_selectdata(cfg, data_app);
        cfg            = [];
        cfg.trials     = n_wake+1:numel(data_app.trialinfo);
        data_all_sleep = ft_selectdata(cfg, data_app);

        % 150 ms running average along time (dimension 2 of each trial)
        win_wake  = 0.15/(1/data_all_wake.fsample);
        win_sleep = 0.15/(1/data_all_sleep.fsample);
        data_all_wake.trial  = cellfun(@(x) smoothdata(x, 2, 'movmean', win_wake),  data_all_wake.trial,  'UniformOutput', false);
        data_all_sleep.trial = cellfun(@(x) smoothdata(x, 2, 'movmean', win_sleep), data_all_sleep.trial, 'UniformOutput', false);

        % --- training set: correct retrieval trials, left vs. right --------
        left    = [1,2];
        right   = [3,4];
        correct = 1;

        cfg            = [];
        cfg.keeptrials = 'yes';
        cfg.removemean = 'no';
        tmp            = ft_timelockanalysis(cfg, data_all_wake);
        dat_wake       = tmp.trial;

        trlinfo    = cell2mat(data_all_wake.trialinfo);
        category   = [trlinfo.head_angle];
        behav      = [trlinfo.headang_acc];
        train_sel1 = ismember(category,left)  & ismember(behav,correct);
        train_sel2 = ismember(category,right) & ismember(behav,correct);

        train_dat       = cat(1, dat_wake(train_sel1,:,:), dat_wake(train_sel2,:,:));
        classcode_train = cat(1, ones(sum(train_sel1),1), 2*ones(sum(train_sel2),1));

        % --- test set: TMR cues, left vs. right ----------------------------
        tmp       = ft_timelockanalysis(cfg, data_all_sleep);   % same cfg as above
        dat_sleep = tmp.trial;
        clear tmp

        trlinfo   = cell2mat(data_all_sleep.trialinfo);
        category  = [trlinfo.Display];
        test_sel1 = ismember(category,left);
        test_sel2 = ismember(category,right);

        test_dat       = cat(1, dat_sleep(test_sel1,:,:), dat_sleep(test_sel2,:,:));
        classcode_test = cat(1, ones(sum(test_sel1),1), 2*ones(sum(test_sel2),1));

        % --- restrict to the decoding windows on a 10 ms grid --------------
        timeaxis_samples_train = nearest(data_all_wake.time{1},TOI_wake(1)):round(0.01*data_all_wake.fsample):nearest(data_all_wake.time{1},TOI_wake(2));
        timeaxis_samples_test  = nearest(data_all_sleep.time{1},TOI_sleep(1)):round(0.01*data_all_sleep.fsample):nearest(data_all_sleep.time{1},TOI_sleep(2));

        test_train_dat = train_dat(:,:,timeaxis_samples_train);
        test_test_dat  = test_dat(:,:,timeaxis_samples_test);

        % --- time x time generalisation (LDA, AUC) -------------------------
        cfg            = [];
        cfg.classifier = 'lda';
        cfg.metric     = 'auc';
        [perf, ~]      = mv_classify_timextime(cfg, test_train_dat, classcode_train, test_test_dat, classcode_test);
        acc{iSub}      = perf;

    end

        % Group statistics: cluster-based permutation test of AUC against
        % chance. Each subject's AUC - 0.5 is compared with an all-zero
        % "condition". Train time is stored as the 'freq' dimension and test
        % time as the 'time' dimension of a one-channel freq structure.
        nSubj  = numel(acc);
        accmat = zeros([nSubj, size(acc{1})]);
        for isubject = 1:nSubj
            accmat(isubject,:,:) = acc{isubject} - .5;
        end

        cstat = cell(1, 2);
        cstat{1}.label               = {'Channels'};
        cstat{1}.time                = classi_timeaxis_test;
        cstat{1}.freq                = classi_timeaxis_train;
        cstat{1}.individual(:,1,:,:) = accmat;
        cstat{1}.dimord              = 'subj_chan_freq_time';
        cstat{1}.avg                 = squeeze(mean(cstat{1}.individual, 1));

        cstat{2}            = cstat{1};
        cstat{2}.individual = zeros(size(cstat{1}.individual));
        cstat{2}.avg        = squeeze(mean(cstat{2}.individual, 1));

        cfg                  = [];
        cfg.spmversion       = 'spm12';
        cfg.latency          = TOI_sleep;   % test (TMR) time window, s
        cfg.frequency        = TOI_wake;    % train (retrieval) time window, s
        cfg.statistic        = 'depsamplesT';
        cfg.method           = 'montecarlo';
        cfg.correctm         = 'cluster';
        cfg.alpha            = .05;
        cfg.clusteralpha     = .05;
        cfg.tail             = 0;
        cfg.neighbours       = [];
        cfg.minnbchan        = 0;
        cfg.computecritval   = 'yes';
        cfg.numrandomization = 1000;
        cfg.clusterstatistic = 'maxsum';
        cfg.clustertail      = cfg.tail;
        cfg.parameter        = 'individual';

        % design: row 1 = subject, row 2 = condition (1 = AUC-0.5, 2 = zero)
        cfg.design = [1:nSubj, 1:nSubj; ones(1,nSubj), 2*ones(1,nSubj)];
        cfg.uvar   = 1;
        cfg.ivar   = 2;

        Fieldtripstats = ft_freqstatistics(cfg, cstat{:});
        fprintf('Fig 2b: %d significant train x test time bins\n', nnz(Fieldtripstats.mask));

        % Axes are taken from the stats output, so image and significance
        % contour always cover exactly the tested window. This avoids
        % exact floating-point matching on the time vectors.
        t_test  = Fieldtripstats.time;   % TMR (test) time
        t_train = Fieldtripstats.freq;   % retrieval (train) time
        sigmap  = double(squeeze(Fieldtripstats.mask));

        figure;
        imagesc(t_test, t_train, squeeze(Fieldtripstats.stat));
        axis xy
        xlabel('time TMR (sec)','FontSize', 16)
        ylabel('time retrieval (sec)','FontSize', 16)
        caxis([-3 3])

        h = colorbar();
        ylabel(h, 't-values','FontSize', 16);

        % set the font size on the axes directly; re-assigning the tick labels
        % would switch them to manual mode
        set(gca, 'TickDir', 'in', 'linewidth', 1.5, 'fontsize', 16)

        plot_contour(t_test, t_train, sigmap, 'k-', 2)
        xline(0, '--', 'linewidth',1.5)
        yline(0, '--', 'linewidth',1.5)

    end   
    
  if any (do_Fig2c)
      
    % --- template data and anatomical files for source analysis -----------

    % The first subject's TMR data serve as a template for building the voxel
    % neighbourhood below. Folder and file name now both use subjects{1};
    % previously the file name relied on a leftover loop index. The loaded
    % file wraps the data in a 'data_trl' variable, which is unwrapped here.
    % NOTE(review): data_trl keeps this first-subject content for the rest of
    % this section -- confirm that no per-subject code relies on it.
    tmp_template = load(fullfile(datadir,subjects{1},sprintf('%s_TMR',subjects{1})));
    data_trl     = tmp_template.data_trl;
    clear tmp_template

    % retrieval (ret1/ret2) files live in the same per-subject folders as the
    % TMR files (as loaded for Figure 2b)
    datadir_wake = datadir;

    % NOTE(review): load() returns a struct of the saved variables; the code
    % below reads e.g. lf.pos / lf.inside directly, which assumes these files
    % were saved with fields as top-level variables -- TODO confirm.
    vol             = load(fullfile(aux_dir,'vol.mat'));
    lf              = load(fullfile(aux_dir,'lf.mat'));
    mri             = load(fullfile(aux_dir,'standard_mri.mat'));
    template_source = load(fullfile(aux_dir,'template_source.mat'));

    elec  = ft_read_sens('elec_ant64.mat');
    elec  = ft_convert_units(elec, 'cm');

    atlas = ft_read_atlas('ROI_MNI_V4.nii');
    atlas = ft_convert_units(atlas,'cm');
    mri   = ft_convert_units(mri,'cm');

    % --- neighbour structure in source space (for the searchlight) --------
    % Build a fake timelock structure whose "electrodes" are the grid points
    % inside the brain, so FieldTrip can compute neighbours between voxels.
    grid_inside = lf.pos(lf.inside,:,:);

    cfg  = [];
    data = ft_timelockanalysis(cfg, data_trl);

    data.time         = template_source.time;
    data.avg          = template_source.avg;
    data.elec.chanpos = grid_inside;

    % voxel labels '1', '2', ... (one per row of template_source.avg)
    label           = arrayfun(@num2str, (1:size(template_source.avg,1))', 'UniformOutput', false);
    data.label      = label;
    data.elec.label = label;

    cfg               = [];
    cfg.method        = 'distance';
    cfg.neighbourdist = 2;   % NOTE(review): in the units of lf.pos -- confirm
    SL_neighbours     = ft_prepare_neighbours(cfg, data);

    clear data
    clear template_source

    % --- decoding time windows (s), 10 ms grid ----------------------------
    TOI_wake              = [-0.5 1];
    TOI_sleep             = [-0.5 1.5];

    classi_timeaxis_train = TOI_wake(1):0.01:TOI_wake(2);
    classi_timeaxis_test  = TOI_sleep(1):0.01:TOI_sleep(2);
    acc                   = cell(1,numel(subjects));


    for iSub=1:numel(subjects)

        tmp_wake1    = load(fullfile(datadir_wake,subjects{iSub},sprintf('%s_ret1.mat',subjects{iSub}))); % pre-sleep retrieval
        data_wake1   = tmp_wake1.eeg;
        tmp_wake1    = [];

        tmp_wake2    = load(fullfile(datadir_wake,subjects{iSub},sprintf('%s_ret2.mat',subjects{iSub}))); % post-sleep retrieval
        data_wake2   = tmp_wake2.eeg;
        tmp_wake2    = [];   

        tmp_sleep         = load(fullfile(datadir,subjects{iSub},sprintf('%s_TMR.mat',subjects{iSub})));
        data_sleep        = tmp_sleep.data_trl;
        tmp_sleep         = [];        
        
             
        % check which head directions were cued
        for ii = 1:3
            hd_react(ii) = data_trl.trialinfo{ii,1}.Display;  
        end

        hd_react = sort(hd_react);
        hd_react = (hd_react(1:2));
  

        cfg                    = [];
        cfg.reref              = 'yes';
        cfg.refchannel         = {'all'}; 
        cfg.refmethod          = 'avg';
        cfg.removemean         = 'no';  
        cfg.demean             = 'yes';
        cfg.hpfilter           = 'yes';
        cfg.hpfreq             = 0.1;
        cfg.hpinstabilityfix   = 'reduce';    
        cfg.lpfilter           = 'yes';
        cfg.lpfreq             = 40;

        data_wake1             = ft_preprocessing(cfg, data_wake1);
        data_wake2             = ft_preprocessing(cfg, data_wake2);

        % add to trialinfo whether head-direction was cued before

        for jj = 1 :size (data_wake1.trialinfo,1)        
            if ismember(data_wake1.trialinfo{jj, 1}.head_angle, hd_react) ==1
               data_wake1.trialinfo{jj, 1}.React = 1;
            else
               data_wake1.trialinfo{jj, 1}.React = 0;
            end
        end

        for jjj = 1 :size (data_wake2.trialinfo,1)        
            if ismember(data_wake2.trialinfo{jjj, 1}.head_angle, hd_react) ==1
               data_wake2.trialinfo{jjj, 1}.React = 1;
            else
               data_wake2.trialinfo{jjj, 1}.React = 0;
            end
        end 

        % reduce data to experimental conditions 
        % ret1
        trs_exp_ret1 = [];   
        for tr = 1 : size(data_wake1.trial,2)
            if (data_wake1.trialinfo{tr,1}.React ~= 0)
        trs_exp_ret1 = cat(1,trs_exp_ret1,tr);  
            end
        end 

        % ret2 
        trs_exp_ret2    = [];   
        for tr = 1 : size(data_wake2.trial,2)
            if (data_wake2.trialinfo{tr,1}.React ~= 0)
        trs_exp_ret2 = cat(1,trs_exp_ret2,tr);  
            end
        end 

        % allocate data to conditions
        cfg             = []; 
        cfg.trials      = trs_exp_ret1; 
        data_wake1      = ft_redefinetrial(cfg,data_wake1);     

        cfg             = []; 
        cfg.trials      = trs_exp_ret2; 
        data_wake2      = ft_redefinetrial(cfg,data_wake2);  

        % bring data into better format and normalize (z-score)
        cfg             = [];
        cfg.keeptrials  = 'yes';
        cfg.removemean  = 'no';     
        data_trl_ret1   = ft_timelockanalysis(cfg,data_wake1);
        dat_trial_ret1  = data_trl_ret1.trial;

        data_trl_ret2   = ft_timelockanalysis(cfg,data_wake2);
        dat_trial_ret2  = data_trl_ret2.trial;

        % z-score
        preprocess_param            = mv_get_preprocess_param('zscore');
        [~, dat_trial_norm_ret1 ]   = mv_preprocess_zscore(preprocess_param, dat_trial_ret1 );        
        data_trl_ret1.trial         = dat_trial_norm_ret1;

        preprocess_param            = mv_get_preprocess_param('zscore');
        [~, dat_trial_norm_ret2 ]   = mv_preprocess_zscore(preprocess_param, dat_trial_ret2 );        
        data_trl_ret2.trial         = dat_trial_norm_ret2;    

        % append data 
        cfg       = [];
        data_wake = ft_appenddata(cfg, data_trl_ret1,data_trl_ret2);  

        % get rid of some channels and re-reference
        cfg            = [];
        cfg.channel    = {'all','-M1', '-M2'};
        data_sleep     = ft_selectdata(cfg, data_sleep);    

        cfg            = [];
        cfg.resamplefs = 200;   
        data_sleep     = ft_resampledata(cfg, data_sleep);   

        cfg                    = [];
        cfg.reref              = 'yes';
        cfg.refchannel         = {'all'}; 
        cfg.refmethod          = 'avg';
        cfg.removemean         = 'no';  
        cfg.demean             = 'yes';
        cfg.hpfilter           = 'yes';
        cfg.hpfreq             = 0.1;
        cfg.hpinstabilityfix   = 'reduce';
        cfg.lpfilter           = 'yes';
        cfg.lpfreq             = 40;
        data_sleep             = ft_preprocessing(cfg, data_sleep);

        % reduce data to experimental conditions (left & right; control sounds not needed)
        trs_exp       = [];   
        for tr = 1 : size(data_sleep.trial,2)
            if (data_sleep.trialinfo{tr,1}.Display ~= 999)
        trs_exp = cat(1,trs_exp,tr);  
            end
        end 

        % allocate data to conditions
        cfg                   = []; 
        cfg.trials            = trs_exp; 
        data_sleep            = ft_redefinetrial(cfg,data_sleep); 

        % bring data into better format and normalize (z-score)
        cfg                   = [];
        cfg.keeptrials        = 'yes';
        cfg.removemean        = 'no';     
        data_sleep            = ft_timelockanalysis(cfg,data_sleep);
        dat_sleep             = data_sleep.trial;

        % z-score
        preprocess_param      = mv_get_preprocess_param('zscore');
        [~, dat_trial_norm ]  = mv_preprocess_zscore(preprocess_param, dat_sleep);        
        data_sleep.trial      = dat_trial_norm;
        clear dat_trial_norm

        % append wake & sleep                
        cfg              = [];
        data_app         = ft_appenddata (cfg,data_wake, data_sleep);

        % run pca on appended data and reduce ranks
        cfg              = [];
        cfg.method       = 'pca';
        cfg.updatesens   = 'no';
        cfg.numcomponent = 30;
        cfg.demean       = 'no';
        comp             = ft_componentanalysis(cfg, data_app);

        cfg              = [];
        cfg.updatesens   = 'no';
        cfg.component    = comp.label(31:end);
        cfg.demean       = 'no';    
        data_app         = ft_rejectcomponent(cfg, comp);        

        % identify wake data
        cfg              = [];
        cfg.trials       = 1:numel(data_wake.trialinfo);
        data_all_wake    = ft_selectdata(cfg,data_app);

        % identify sleep part
        cfg              = [];
        cfg.trials       = numel(data_wake.trialinfo)+1:numel(data_app.trialinfo);
        data_all_sleep   = ft_selectdata(cfg,data_app);


        % temporal smoothing (running average)        
        for itrial = 1:numel(data_all_wake.trial)
            data_all_wake.trial{itrial} = smoothdata(data_all_wake.trial{itrial},2,'movmean',0.15/(1/data_all_wake.fsample));
        end

        for itrial = 1:numel(data_all_sleep.trial)
            data_all_sleep.trial{itrial} = smoothdata(data_all_sleep.trial{itrial},2,'movmean',0.15/(1/data_all_sleep.fsample));
        end

        %% now source
        cfg                  = [];
        cfg.covariance       = 'yes';
        cfg.covariancewindow = 'all';
        cfg.vartrllength     = 2;
        data2filt            = ft_timelockanalysis(cfg, data_all_wake);
        data2filt.elec       = elec;

        cfg                  = [];
        cfg.method           = 'lcmv';
        cfg.lcmv.keepfilter  = 'yes';
        cfg.grid             = lf;       
        cfg.headmodel        = vol;
        cfg.lcmv.fixedori    = 'yes';
        cfg.lcmv.lambda      = '0%';
        source_ret_all       = ft_sourceanalysis(cfg, data2filt);
        data2filt            = [];

        filtermat = cell2mat(source_ret_all.avg.filter(source_ret_all.inside)); 

        source_ret            = [];
        source_ret.sampleinfo = data_all_wake.sampleinfo; % transfer sample information
        source_ret.time       = data_all_wake.time;       % transfer time information
        source_ret.trialinfo  = data_all_wake.trialinfo;  % transfer trial information
        source_ret.fsample    = data_all_wake.fsample;    % set sampling rate

        % create labels for each virtual electrode
        label = {};
        for cc = 1 : numel(find(source_ret_all.inside)); label{cc,1} = [num2str(cc)]; end
        source_ret.label = label;

        % for each trial, apply filters to the recorded data
        for jj = 1 : numel(data_all_wake.trial); source_ret.trial{1,jj} = filtermat*data_all_wake.trial{1,jj}; end
        clear filtermat;

        % bring into nicer format
        cfg             = [];
        cfg.keeptrials  = 'yes';
        tmp             = ft_timelockanalysis(cfg,source_ret);
        dat_wake        = tmp.trial;
        tmp             = [];
        source_ret_all  = [];     

        % class selection (2 classes)
        trlinfo  = cell2mat(source_ret.trialinfo);    
        category = [trlinfo.head_angle];       
        behav    = [trlinfo.headang_acc];

        % training
        left     = [1,2];
        right    = [3,4];
        correct  = 1;

        train_sel1  = ismember(category,left) & ismember(behav,correct);
        train_sel2  = ismember(category,right)& ismember(behav,correct);
        clear behav

        % training Labels
        train_dat        = cat(1,dat_wake(train_sel1,:,:),dat_wake(train_sel2,:,:));
        classcode_train  = cat(1,1*ones(sum(train_sel1),1),2*ones(sum(train_sel2),1));

        % now the same for sleep data    
        cfg                  = [];
        cfg.covariance       = 'yes';
        cfg.covariancewindow = 'all';
        cfg.vartrllength     = 2;
        data2filt            = ft_timelockanalysis(cfg, data_all_sleep);
        data2filt.elec       = elec;

        cfg                  = [];
        cfg.method           = 'lcmv';
        cfg.lcmv.keepfilter  = 'yes';
        cfg.grid             = lf;       
        cfg.headmodel        = vol;
        cfg.lcmv.fixedori    = 'yes';
        cfg.lcmv.lambda      = '0%';
        source_tmr_all       = ft_sourceanalysis(cfg, data2filt);

        filtermat = cell2mat(source_tmr_all.avg.filter(source_tmr_all.inside)); 

        source_tmr            = [];
        source_tmr.sampleinfo = data_all_sleep.sampleinfo; % transfer sample information
        source_tmr.time       = data_all_sleep.time;       % transfer time information
        source_tmr.trialinfo  = data_all_sleep.trialinfo;  % transfer trial information
        source_tmr.fsample    = data_all_sleep.fsample;    % set sampling rate

        % create labels for each virtual electrode
        label = {};
        for cc = 1 : numel(find(source_tmr_all.inside)); label{cc,1} = [num2str(cc)]; end
        source_tmr.label = label;

        % for each trial, apply filters to the recorded data
        for jj = 1 : numel(data_all_sleep.trial); source_tmr.trial{1,jj} = filtermat*data_all_sleep.trial{1,jj}; end
        clear filtermat;
        data_all_sleep = [];

        % bring into nicer format
        cfg             = [];
        cfg.keeptrials  = 'yes';
        cfg.removemean  = 'no';
        tmp             = ft_timelockanalysis(cfg,source_tmr);
        dat_sleep       = tmp.trial;
        tmp             = [];
        source_tmr_all  = [];       

        % class selection (2 classes)    
        trlinfo  = cell2mat(source_tmr.trialinfo);    
        category = [trlinfo.Display];       

        % Testing data 
        left  = [1,2];
        right = [3,4];

        test_sel1  = ismember(category,left);
        test_sel2  = ismember(category,right);    

        % Test Labels
        test_dat       = cat(1,dat_sleep(test_sel1,:,:),dat_sleep(test_sel2,:,:));
        classcode_test = cat(1,1*ones(sum(test_sel1),1),2*ones(sum(test_sel2),1));

        % time-info
        timeaxis_samples_train = nearest(source_ret.time{1},TOI_wake(1)):round(0.01*source_ret.fsample):nearest(source_ret.time{1},TOI_wake(2));
        timeaxis_samples_test  = nearest(source_tmr.time{1},TOI_sleep(1)):round(0.01*source_tmr.fsample):nearest(source_tmr.time{1},TOI_sleep(2));

        test_train_dat         = train_dat(:,:,timeaxis_samples_train);
        test_test_dat          = test_dat(:,:,timeaxis_samples_test);

        labelz = source_tmr.label;

        % classification    
        perfmat = nan(numel(labelz),numel(classi_timeaxis_train),numel(classi_timeaxis_test));

        cfg             = [];
        cfg.classifier  = 'lda';
        cfg.metric      = 'auc';
        cfg.k           = 5;

        for ichannel = 1:numel(labelz) % searchlight across voxels 

            these_channels  = cat(1,SL_neighbours(ichannel).label,SL_neighbours(ichannel).neighblabel);
            chansel         = ismember(labelz,these_channels);

            [perf, ~]  = mv_classify_timextime(cfg,test_train_dat(:,chansel,:),classcode_train,test_test_dat(:,chansel,:), classcode_test);
            perfmat(ichannel,:,:) = perf;
        end

        data_SL{iSub}.perfmat = perfmat;

    end   
  end
  
  if any (do_Fig2d)
    % Figure 2d: across-participant correlation between TMR-triggered power
    % (averaged over the significant TFR cluster) and classification
    % performance (averaged over the significant decoding cluster), plotted
    % as a scatter with a least-squares regression line.
    %
    % Group data and statistics are read from the 'auxiliary' folder under
    % dirRoot (aux_dir), not from a machine-specific absolute path.

    load(fullfile(aux_dir, 'ga_TFR.mat'))          % load TFR group data
    load(fullfile(aux_dir, 'stat_TFR.mat'))        % load TFR stats
    load(fullfile(aux_dir, 'stat_decoding.mat'))   % load decoding stats
    load(fullfile(aux_dir, 'accmat_decoding.mat')) % load decoding group data

    % per-subject power, averaged across the significant TFR cluster
    % (assumes subjects are the first dimension of powspctrm — TODO confirm)
    mask      = Fieldtripstat_TFR.mask;
    ga_red    = ga_exp.powspctrm;
    brain_TFR = mean(ga_red(:,mask==1),2);

    % crop decoding matrix to the train x test time range used for the
    % cluster statistics (hard-coded; must match posclusterslabelmat),
    % then average across the significant decoding cluster
    accmat2     = accmat(:,1:151,1:201);
    brain_decod = mean(accmat2(:,Fieldtripstats.posclusterslabelmat==1),2);
    % add 0.5 back — presumably accmat stores performance minus chance; confirm
    brain_decod = brain_decod+0.5;

    % rank correlation between the two per-subject measures
    type     = 'Spearman';
    [rho, p] = corr(brain_TFR, brain_decod, 'type', type);
    fprintf('Fig. 2d: %s rho = %.3f, p = %.4f\n', type, rho, p);

    % least-squares regression line (same fit lsline would draw; NaN pairs
    % are excluded from the fit)
    valid     = ~isnan(brain_TFR) & ~isnan(brain_decod);
    coeffs    = polyfit(brain_TFR(valid), brain_decod(valid), 1);
    slope     = coeffs(1);
    intercept = coeffs(2);

    % plot
    color = [0,0,0];
    figure('Color', [1 1 1]);

    x  = 0.01:0.01:0.21;
    h3 = plot(x, slope*x+intercept, 'LineWidth',6, 'color',color);
    ylim([0.465 0.58])
    xlim([0 .22])
    h3.Color(4) = 0.9; % 4th element of a line Color sets its transparency
    hold on

    scatter(brain_TFR,brain_decod,[],color, 'o', 'filled','SizeData',300);

    alpha(.75)
    xticks([0.05 0.1 0.15 0.2])
    xticklabels({'0.05' '0.1' '0.15' '0.2'})

    yticks([0.48 0.5 0.52 0.54 0.56 0.58])
    yticklabels({'0.48' '0.5' '0.52' '0.54' '0.56' '0.58'})

    xlabel('TFR power', 'FontSize', 20)
    ylabel('mean react', 'FontSize', 20)

    set(gca, 'FontSize', 20, 'Box','off', 'TickDir','in', ...
        'TickLength',[0.01, 0.04], 'LineWidth',2)
    set(gcf,'Position',[100 100 500 400])

  end
  
  
end




