% Schreiner et al. 2024: Figure_4;

%% Ripple ERP

subjects     = {'Pmtl02'};

% Ripple-locked ERP: for every subject, filter the ripple-centred segments,
% keep the medial temporal lobe contacts, average across segments and plot
% the mean +/- SEM across contacts.
%
% Expects datadir, sleep_folder and TMR_ripples in the workspace; the loaded
% file is expected to provide the variable "eeg".

rip      = cell(1, numel(subjects));   % one (channels x time) average per subject
erp_time = [];                         % time axis of the timelocked average

for iSub=1:numel(subjects)

    subID      = subjects{iSub};

    cd(fullfile(datadir, subID, sleep_folder))

    load ([subID TMR_ripples]);

    % 50 Hz band-stop plus 200 Hz low-pass
    cfg              = [];
    cfg.bsfilter     = 'yes';
    cfg.bsfreq       = [48 52];
    cfg.lpfilter     = 'yes';
    cfg.lpfreq       = 200;

    cfg.removemean   = 'no';
    eeg              = ft_preprocessing(cfg, eeg);

    % restrict to MTL contacts and +/- 0.5 s around the event
    cfg              = [];
    cfg.channel      = {'Hipp*', 'Ento*', 'Para*'};
    cfg.latency      = [-.5 .5];
    eeg              = ft_selectdata(cfg, eeg);

    cfg              = [];
    erp              = ft_timelockanalysis(cfg, eeg);
    rip{1,iSub}      = erp.avg;

    % Keep the numeric time axis of the average for plotting; eeg.time
    % belongs to the trial-based data (presumably a cell array, one entry
    % per trial) and is not a valid x-axis for boundedline.
    erp_time         = erp.time;

end

% Stack all contacts of all subjects: rows are contacts, columns are samples.
all_rip   = vertcat(rip{:});

% Explicit dimension so a single contact is not collapsed into a scalar.
mean_alli = mean(all_rip, 1);

% SEM across contacts: standard deviation over rows divided by the square
% root of the number of rows (not the number of time points).
SEM       = std(all_rip, 0, 1) / sqrt(size(all_rip, 1));

color_ = [0 0 0];
hFig2 = figure();
[hl,~] = boundedline(erp_time, mean_alli, SEM(:), 'cmap',color_, 'alpha','transparency', 0.15);
set(hl, 'linewidth', 1.25);

yline (0,'--','LineWidth',1)
ylim([-10 10])

    
%%  Power spectral density (PSD) across all detected SWRs

subjects     = {'Pmtl02'};

% Power spectral density of the ripple-centred segments: filter, keep the
% MTL contacts, attenuate the 1/f component by taking the temporal
% derivative, compute a Hanning-tapered FFT spectrum per contact and plot
% the mean +/- SEM across contacts.
%
% Expects datadir, sleep_folder and TMR_ripples in the workspace; the loaded
% file is expected to provide the variable "eeg".

all_ripple_segm = cell(1, numel(subjects));   % one (channels x freq) PSD per subject

for iSub=1:numel(subjects)

    subID      = subjects{iSub};

    cd(fullfile(datadir, subID, sleep_folder))

    load ([subID TMR_ripples]);

    % band-stop at 50 Hz and its harmonics, 0.1 Hz high-pass, 300 Hz low-pass
    cfg                    = [];
    cfg.bsfilter           = 'yes';
    cfg.bsfreq             = [48 52; 98 102; 148 152; 198 202];
    cfg.hpfilter           = 'yes';
    cfg.hpfreq             = .1;
    cfg.hpinstabilityfix   = 'reduce';
    cfg.lpfilter           = 'yes';
    cfg.lpfreq             = 300;

    cfg.removemean         = 'no';
    eeg                    = ft_preprocessing(cfg, eeg);

    cfg              = [];
    cfg.channel      = {'Hipp*', 'Ento*', 'Para*'};
    cfg.latency      = [-.5 1.5];
    eeg              = ft_selectdata(cfg, eeg);

    %remove 1/f
    cfg            = [];
    cfg.derivative ='yes';
    eeg            =  ft_preprocessing(cfg, eeg);

    % One call covers all selected channels. The original repeated this
    % identical call once per channel and tried to store the complete
    % result in a single row, which fails for more than one channel.
    cfg             = [];
    cfg.output      = 'pow';
    cfg.method      = 'mtmfft';
    cfg.taper       = 'hanning';
    cfg.foilim      = [1 200];
    psd_hann        = ft_freqanalysis(cfg, eeg);

    % Assigned fresh for every subject so no rows carry over.
    all_chan              = psd_hann.powspctrm;
    all_ripple_segm{iSub} = all_chan;

end

% Stack the contacts of all subjects: rows are contacts, columns are frequencies.
alli = vertcat(all_ripple_segm{:});
freq = psd_hann.freq;

% Explicit dimension so a single contact is not collapsed into a scalar.
mm   = mean(alli, 1);

% SEM across contacts; length(alli) would return the largest dimension,
% which is normally the number of frequency bins.
SEM  = std(alli, 0, 1) / sqrt(size(alli, 1));

% Own figure so the plot does not overwrite whatever figure is current.
figure;
plot(freq,mm)

[hl,hp] = boundedline(freq,mm, SEM(:), 'cmap',[0,0,0], 'alpha','transparency', 0.15);
set(hl, 'linewidth', 1.25);
axis('tight')


%% TFR 

% Time-frequency representation per subject: compute single-trial power,
% z-score every channel/frequency across all trials and time points, split
% into cued (Display ~= 999) and control (Display == 999) trials and average
% over trials.
%
% Expects datadir, TMR_folder and sleep_EEG in the workspace; the loaded
% file is expected to provide the variable "data_trl".

All_tfr_exp     = cell(1, length(subjects));
All_tfr_control = cell(1, length(subjects));

for iSub  = 1:length(subjects)

    subID      = subjects{iSub};

    % load data
    cd(fullfile(datadir, subID, TMR_folder));
    load ([subID sleep_EEG]);

    % collect trial indices per condition
    trs_exp        = [];
    trs_control    = [];

    for tr = 1 : size(data_trl.trial,2)
        if data_trl.trialinfo{tr,1}.Display ~= 999
            trs_exp = cat(1,trs_exp,tr);
        elseif data_trl.trialinfo{tr,1}.Display == 999
            trs_control = cat(1,trs_control,tr);
        end
    end

    % time-domain data per condition (not used further in this section)
    cfg      = [];
    cfg.trials = trs_exp;
    dat_exp  = ft_selectdata(cfg, data_trl);

    cfg       = [];
    cfg.trials= trs_control;
    dat_cont  = ft_selectdata(cfg, data_trl);

    % single-trial power, Hanning taper, window length of 5 cycles per frequency
    cfg              = [];
    cfg.keeptrials   = 'yes';
    cfg.output       = 'pow';
    cfg.method       = 'mtmconvol';
    cfg.taper        = 'hanning';
    cfg.foi          = [2:1:40 44:4:100 108:8:200];
    cfg.t_ftimwin    = 5./cfg.foi;
    cfg.pad          = 16;
    cfg.toi          = -1:0.05:4;
    tfr_all          = ft_freqanalysis(cfg, data_trl);

    % zscore all trials
    startT  = nearest(tfr_all.time,-1);
    endT    = nearest(tfr_all.time, 4);

    pow  = tfr_all.powspctrm;       % trials x channels x freq x time
    powz = nan(size(pow));

    for ichan = 1:size(pow,2)
        for ifreq = 1:size(pow,3)
            % mean and standard deviation pooled over trials and time
            d = squeeze(pow(:,ichan,ifreq,startT:endT));
            m = nanmean(d(:));
            s = nanstd(d(:));

            powz(:,ichan,ifreq,:) = (pow(:,ichan,ifreq,:) - m)./s;
        end
    end

    pow                 = [];
    tfr_all.powspctrm   = powz;
    clear powz;

    % segment data according to conditions
    cfg = []; cfg.trials = trs_control; tfr_control = ft_selectdata(cfg,tfr_all);
    cfg = []; cfg.trials = trs_exp;     tfr_exp     = ft_selectdata(cfg,tfr_all);

    % average over single trials
    tfr_control.dimord    = 'chan_freq_time';
    tfr_exp.dimord        = 'chan_freq_time';

    % Move the averaged (singleton) trial dimension to the end instead of
    % squeezing: squeeze would also remove the channel dimension when only
    % one channel is present, which contradicts the dimord set above.
    tfr_control.powspctrm = permute(nanmean(tfr_control.powspctrm,1), [2 3 4 1]);
    tfr_exp.powspctrm     = permute(nanmean(tfr_exp.powspctrm,1),     [2 3 4 1]);

    All_tfr_exp{iSub}         = tfr_exp;
    All_tfr_control{iSub}     = tfr_control;

end

%%

% Pool the scalp-region contacts of all subjects and wrap every contact in
% its own FieldTrip freq structure, so that contacts serve as the unit of
% observation in the grand average and the statistics.
%
% Uses All_tfr_exp and All_tfr_control (one chan_freq_time structure per
% subject) from the TFR section.

nSets   = numel(All_tfr_exp);
exp_tfr = cell(1, nSets);
con_tfr = cell(1, nSets);

for ii = 1:nSets

    % subjects without a TFR stay empty and are dropped below
    if isempty(All_tfr_exp{ii}) || isempty(All_tfr_control{ii})
        continue
    end

    cfg               = [];
    cfg.channel       = {'Front*','Pari*', 'Temp*'};
    tfr_exp           = ft_selectdata(cfg, All_tfr_exp{ii});
    tfr_cont          = ft_selectdata(cfg, All_tfr_control{ii});

    exp_tfr{ii} = tfr_exp.powspctrm;
    con_tfr{ii} = tfr_cont.powspctrm;

end

% Remove empty cells (jointly, so both conditions stay paired)
keep    = ~cellfun('isempty',exp_tfr) & ~cellfun('isempty',con_tfr);
exp_tfr = exp_tfr(keep);
con_tfr = con_tfr(keep);

% append contacts: concatenate along the channel dimension
% (contacts x freq x time)
exp_pow     = cat(1, exp_tfr{:});
con_pow     = cat(1, con_tfr{:});

% build template: a single-channel freq structure that supplies the
% freq/time axes for every contact
firstSet    = find(keep, 1);
cfg         = [];
cfg.channel = All_tfr_exp{firstSet}.label(1);
temp        = ft_selectdata(cfg, All_tfr_exp{firstSet});

% per-trial information does not apply to a trial-averaged contact
field = 'trialinfo';
if isfield(temp, field)
    temp  = rmfield(temp,field);
end

% All contacts get the same label so the grand average and the statistics
% treat them as repeated observations of one channel.
nExp    = size(exp_pow,1);
exp_tfr = cell(1, nExp);
for ii =  1:nExp
    temp.label     = {'frontal'};
    temp.powspctrm = exp_pow(ii,:,:);
    exp_tfr{ii}    = temp;
end

nCon    = size(con_pow,1);
con_tfr = cell(1, nCon);
for ii =  1:nCon
    temp.label     = {'frontal'};
    temp.powspctrm = con_pow(ii,:,:);
    con_tfr{ii}    = temp;
end

% do grand average
cfg    = [];
cfg.keepindividual  = 'yes';
ga_exp = ft_freqgrandaverage(cfg, exp_tfr{:});
ga_con = ft_freqgrandaverage(cfg, con_tfr{:});
    
% stats
% Cluster-based permutation test (dependent-samples t) comparing ga_exp
% against ga_con; the first dimension of powspctrm holds the observations.
nObs                    = size(ga_exp.powspctrm,1);

cfg                     = [];
cfg.spmversion          = 'spm12';
cfg.parameter           = 'powspctrm';

% analysis window
cfg.latency             = [-.5 1.5];
cfg.frequency           = [3 25];

% test statistic and permutation scheme
cfg.method              = 'montecarlo';
cfg.statistic           = 'depsamplesT';
cfg.numrandomization    = 1000;
cfg.computecritval      = 'yes';
cfg.alpha               = .05;
cfg.tail                = 0;
cfg.correcttail         = 'alpha';

% cluster correction
cfg.correctm            = 'cluster';
cfg.clusteralpha        = .05;
cfg.clusterstatistic    = 'maxsum';
cfg.clustertail         = cfg.tail;
cfg.neighbours          = [];
cfg.minnbchan           = 0;

% design matrix: row 1 = unit of observation, row 2 = condition (1 or 2)
design = [1:nObs,       1:nObs; ...
          ones(1,nObs), 2*ones(1,nObs)];

cfg.design  = design;
cfg.uvar    = 1;
cfg.ivar    = 2;

[Fieldtripstats] = ft_freqstatistics(cfg, ga_exp, ga_con);

% plot
% t-values multiplied by the mask returned by the statistics, so entries
% outside the mask become zero
stat.stat = Fieldtripstats.mask .* Fieldtripstats.stat; % mask data
t_val     = squeeze(stat.stat);

figure;
pcolor(Fieldtripstats.time, Fieldtripstats.freq, t_val);
shading interp;
axis xy;

ll = xlabel('time [seconds]', 'Fontsize', 17);
yy = ylabel('Frequency [Hz]', 'Fontsize', 17);

% axes cosmetics: draw the axes on top of the surface, ticks pointing inwards
set(gca, 'layer','top', 'TickDir','in', 'Linewidth',2);


%% Classification during TMR

% Settings for the wake -> sleep classification.
% "mother" (project root) must already exist in the workspace.

% which TMR trials to test on: high or low SO-spindle activity
do_lowpow          = 0;
do_highpow         = 1;

% data locations
TMRdir     = fullfile(mother,'/sleep/');
datadir    = fullfile(mother,'/wake/');

% file-name suffixes, appended to the subject ID when loading
wake_EEG   = '_ret.mat';          % retrieval data (pre- and post-sleep)
sleep_EEG  = '_TMR';              % TMR data
low_trls   = '_trls_TFR_lowPow';  % trial indices for low SO-spindle activity trials
high_trls  = '_trls_TFR_highPow'; % trial indices for high SO-spindle activity trials

subjects   = {'Pmtl02'};

% time of interest (seconds) for training (wake) and testing (sleep)
TOI_sleep             = [-0.5 1.5];
TOI_wake              = [-0.5 1] ;

% time axes of the classification, 10 ms steps
classi_timeaxis_test  = TOI_sleep(1):0.01:TOI_sleep(2);
classi_timeaxis_train = TOI_wake(1):0.01:TOI_wake(2);

% one classifier performance result per subject
acc                   = cell(1,numel(subjects));

% subject loop
%
% For every subject: train an LDA classifier on wake retrieval data (left
% vs. right head directions, correct trials only) and test it on the TMR
% data recorded during sleep, for every combination of training and testing
% time point. The performance (AUC) is stored in acc{iSub}.
%
% Expects datadir, TMRdir, wake_EEG, sleep_EEG, high_trls, low_trls,
% do_highpow, do_lowpow, TOI_wake, TOI_sleep and acc in the workspace. The
% loaded files are expected to provide "eeg" (retrieval) and "eeg_TMR"
% (sleep); the trial-index files are expected to provide ind_trls_high /
% ind_trls_low.

% Both selections index the same trial list; applying one after the other
% would apply the second index list to an already reduced data set.
if any(do_highpow) && any(do_lowpow)
    error('Set either do_highpow or do_lowpow, not both.');
end

for iSub=1:numel(subjects)

    subID      = subjects{iSub};

    % load data
    cd(fullfile(datadir, subID))
    load ([subID wake_EEG]);

    cd(fullfile(TMRdir, subID))
    load ([subID sleep_EEG]);

    % check which head directions were cued
    % (scan at most the first 20 TMR trials; reset per subject and never
    % index beyond the available trials)
    n_check  = min(20, numel(eeg_TMR.trialinfo));
    hd_react = nan(1, n_check);
    for ii = 1:n_check
        hd_react(ii) = eeg_TMR.trialinfo{ii,1}.Display;
    end

    % unique returns sorted values; keep the two smallest codes
    hd_react = unique(hd_react);
    hd_react = (hd_react(1:2));
    clear data_trl

    % resample
    cfg            = [];
    cfg.resamplefs = 200;
    eeg            = ft_resampledata(cfg, eeg);

    cfg         = [];
    cfg.latency = [-1 4];
    eeg         = ft_selectdata(cfg, eeg);

    % add to trialinfo of retrieval whether head-direction was cued before
    for jj = 1 :size (eeg.trialinfo,1)
        if ismember(eeg.trialinfo{jj, 1}.head_angle, hd_react) ==1
           eeg.trialinfo{jj, 1}.React = 1;
        else
           eeg.trialinfo{jj, 1}.React = 0;
        end
    end

    % bring data into better format and normalize (z-score)
    cfg                    = [];
    cfg.hpfilter           = 'yes';
    cfg.hpfreq             = 0.1;
    cfg.lpfilter           = 'yes';
    cfg.lpfreq             = 40;
    cfg.hpinstabilityfix   = 'reduce';
    cfg.bsfilter           = 'yes';
    cfg.bsfreq             = [48 52];
    cfg.demean             = 'yes';
    cfg.removemean         = 'no';
    data_trl_ret1          = ft_preprocessing(cfg, eeg);

    % convert to a trials x channels x time array
    cfg                    = [];
    cfg.keeptrials         = 'yes';
    cfg.removemean         = 'no';
    data_trl_ret1          = ft_timelockanalysis(cfg,data_trl_ret1);
    dat_trial_ret1         = data_trl_ret1.trial;

    % z-score
    preprocess_param           = mv_get_preprocess_param('zscore');
    [~, dat_trial_norm_ret1 ]  = mv_preprocess_zscore(preprocess_param, dat_trial_ret1 );
    data_trl_ret1.trial        = dat_trial_norm_ret1;

    % reduce data to experimental conditions (left + right react)
    trs_exp_ret1 = [];
    for tr = 1 : size(data_trl_ret1.trial,1)
        if (data_trl_ret1.trialinfo{tr,1}.React ~= 0)
            trs_exp_ret1 = cat(1,trs_exp_ret1,tr);
        end
    end

    % allocate data to conditions
    cfg             = [];
    cfg.trials      = trs_exp_ret1;
    data_trl_ret1   = ft_redefinetrial(cfg,data_trl_ret1);

    % bring into format
    cfg       = [];
    data_wake = ft_appenddata(cfg,data_trl_ret1);

    % do the same for sleep data
    cfg            = [];
    cfg.resamplefs = 200;
    eeg_TMR        = ft_resampledata(cfg, eeg_TMR);

    % bring data into better format and normalize (z-score)
    % NOTE(review): the band-stop filter is switched off here but on for
    % the wake data above - confirm this asymmetry is intended.
    cfg                    = [];
    cfg.bsfilter           = 'no';
    cfg.bsfreq             = [48 52];
    cfg.hpfilter           = 'yes';
    cfg.hpfreq             = 0.1;
    cfg.hpinstabilityfix   = 'reduce';
    cfg.lpfilter           = 'yes';
    cfg.lpfreq             = 40;
    cfg.removemean         = 'no';
    eeg_TMR                = ft_preprocessing(cfg, eeg_TMR);

    cfg                    = [];
    cfg.keeptrials         = 'yes';
    cfg.removemean         = 'no';
    eeg_TMR                = ft_timelockanalysis(cfg,eeg_TMR);
    dat_sleep              = eeg_TMR.trial;

    % z-score
    preprocess_param       = mv_get_preprocess_param('zscore');
    [~, dat_trial_norm ]   = mv_preprocess_zscore(preprocess_param, dat_sleep);
    eeg_TMR.trial          = dat_trial_norm;

    % reduce data to experimental conditions (left & right)
    % (Display == 999 marks the trials that are excluded here)
    trs_exp       = [];
    for tr = 1 : size(eeg_TMR.trial,1)
        if (eeg_TMR.trialinfo{tr,1}.Display ~= 999)
            trs_exp = cat(1,trs_exp,tr);
        end
    end

    % allocate data to conditions
    cfg             = [];
    cfg.trials      = trs_exp;
    eeg_TMR      = ft_redefinetrial(cfg,eeg_TMR);

    % go for low or high power trials (SO-spindle range)
    % (the indices refer to the trials that remain after the reduction above)
    if any (do_highpow)
       load ([subID high_trls]);
       cfg = [];
       cfg.trials = ind_trls_high;
       eeg_TMR = ft_selectdata(cfg, eeg_TMR);
    end

    if any (do_lowpow)
       load ([subID low_trls]);
       cfg = [];
       cfg.trials = ind_trls_low;
       eeg_TMR = ft_selectdata(cfg, eeg_TMR);
    end

    % prepare and append wake & sleep
    cfg         = [];
    eeg_TMR     = ft_appenddata (cfg, eeg_TMR);

    % keep only the channels present in both data sets
    cfg         = [];
    cfg.channel = eeg_TMR.label;
    data_wake   = ft_selectdata(cfg, data_wake);

    cfg         = [];
    cfg.channel = data_wake.label;
    eeg_TMR     = ft_selectdata(cfg, eeg_TMR);

    cfg         = [];
    data_app    = ft_appenddata (cfg, data_wake, eeg_TMR);

    % run pca on appended data and reduce ranks
    cfg              = [];
    cfg.method       = 'pca';
    cfg.updatesens   = 'no';
    cfg.numcomponent = 30;
    comp             = ft_componentanalysis(cfg, data_app);

    % NOTE(review): with numcomponent = 30, comp.label presumably holds 30
    % entries, so (31:end) selects nothing and the data are back-projected
    % from the 30 retained components - confirm.
    cfg              = [];
    cfg.updatesens   = 'no';
    cfg.component    = comp.label(31:end);
    data_app         = ft_rejectcomponent(cfg, comp);

    % identify wake data (wake trials come first in the appended data)
    cfg              = [];
    cfg.trials       = 1:numel(data_wake.trialinfo);
    data_all_wake    = ft_selectdata(cfg,data_app);

    % identify sleep part
    cfg              = [];
    cfg.trials       = numel(data_wake.trialinfo)+1:numel(data_app.trialinfo);
    data_all_sleep   = ft_selectdata(cfg,data_app);

    % temporal smoothing (running average), window of 0.15 s in samples
    for itrial = 1:numel(data_all_wake.trial)
        data_all_wake.trial{itrial} = smoothdata(data_all_wake.trial{itrial},2,'movmean',0.15/(1/data_all_wake.fsample));
    end

    for itrial = 1:numel(data_all_sleep.trial)
        data_all_sleep.trial{itrial} = smoothdata(data_all_sleep.trial{itrial},2,'movmean',0.15/(1/data_all_sleep.fsample));
    end

    % prepare data for classification
    cfg             = [];
    cfg.keeptrials  = 'yes';
    cfg.removemean  = 'no';
    tmp             = ft_timelockanalysis(cfg,data_all_wake);
    dat_wake        = tmp.trial;
    tmp             = [];

    % class selection (2 classes)
    trlinfo  = cell2mat(data_all_wake.trialinfo);
    category = [trlinfo.head_angle];
    behav    = [trlinfo.headang_acc];

    % Training (define classes)
    left        = [1,2];
    right       = [3,4];
    correct     = 1;

    train_sel1  = ismember(category,left) & ismember(behav,correct);
    train_sel2  = ismember(category,right)& ismember(behav,correct);

    % Training Labels (class 1 = left, class 2 = right)
    train_dat        = cat(1,dat_wake(train_sel1,:,:),dat_wake(train_sel2,:,:));
    classcode_train  = cat(1,1*ones(sum(train_sel1),1),2*ones(sum(train_sel2),1));

    trialinfo     = [];
    category      = [];

    % now the same for sleep data
    % bring data into better format
    cfg             = [];
    cfg.keeptrials  = 'yes';
    cfg.removemean  = 'no';
    tmp             = ft_timelockanalysis(cfg,data_all_sleep);
    dat_sleep       = tmp.trial;
    tmp             = [];

    % class selection (2 classes)
    trlinfo  = cell2mat(data_all_sleep.trialinfo);
    category = [trlinfo.Display];

    % Testing data (define classes)
    left       = [1,2];
    right      = [3,4];

    test_sel1  = ismember(category,left);
    test_sel2  = ismember(category,right);

    % Test Labels (class 1 = left, class 2 = right)
    test_dat       = cat(1,dat_sleep(test_sel1,:,:),dat_sleep(test_sel2,:,:));
    classcode_test = cat(1,1*ones(sum(test_sel1),1),2*ones(sum(test_sel2),1));

    trialinfo = [];
    category  = [];

    % time-info: sample indices in 10 ms steps within the time of interest
    timeaxis_samples_train = nearest(data_all_wake.time{1},TOI_wake(1)):round(0.01*data_all_wake.fsample):nearest(data_all_wake.time{1},TOI_wake(2));
    timeaxis_samples_test  = nearest(data_all_sleep.time{1},TOI_sleep(1)):round(0.01*data_all_sleep.fsample):nearest(data_all_sleep.time{1},TOI_sleep(2));

    test_train_dat         = train_dat(:,:,timeaxis_samples_train);
    test_test_dat          = test_dat(:,:,timeaxis_samples_test);

    % classification: train on wake, test on sleep
    cfg             = [];
    cfg.classifier  = 'lda';
    cfg.metric      = 'auc';

    [perf, result]  = mv_classify_timextime(cfg, test_train_dat, classcode_train, test_test_dat, classcode_test);
    acc{iSub}       = perf;

end

    % plot it
    %
    % Draws the t-map in Fieldtripstats as a (TMR time x retrieval time)
    % image and outlines the entries flagged in Fieldtripstats.mask.
    %
    % NOTE(review): this section reads cfg.latency (TMR time window) and
    % cfg.frequency (retrieval time window) as well as cstat and
    % Fieldtripstats. None of these is produced by the classification loop,
    % which leaves cfg holding only the classifier settings. They presumably
    % come from a statistics step on the classification results that is not
    % part of this file - confirm before running.
    if ~isstruct(cfg) || ~isfield(cfg,'latency') || ~isfield(cfg,'frequency')
        error('cfg.latency and cfg.frequency are required for plotting: cfg must be the statistics configuration, not the classifier configuration.');
    end

    %define time
    % nearest() instead of an exact == comparison: the axes are built with
    % the colon operator, so exact floating-point matches are not guaranteed.
    start_test = nearest(classi_timeaxis_test, cfg.latency(1));
    fin_test   = nearest(classi_timeaxis_test, cfg.latency(2));

    % the training window has to be looked up on the training axis
    start_train = nearest(classi_timeaxis_train, cfg.frequency(1));
    fin_train   = nearest(classi_timeaxis_train, cfg.frequency(2));

    %prepare contour plot
    stats_time = nearest(cstat{1}.time,cfg.latency(1)):nearest(cstat{1}.time,cfg.latency(2));
    stats_freq = nearest(cstat{1}.freq,cfg.frequency(1)):nearest(cstat{1}.freq,cfg.frequency(2));

    sigmap    = zeros(numel(Fieldtripstats.freq),numel(Fieldtripstats.time));
    sigval    = Fieldtripstats.mask;
    sigmap(stats_freq,stats_time) = sigval;

    figure;
    imagesc(classi_timeaxis_test(start_test:fin_test),classi_timeaxis_train(start_train:fin_train),squeeze(Fieldtripstats.stat));

    axis xy
    xlabel('time TMR (sec)','FontSize', 16)
    ylabel('time retrieval (sec)','FontSize', 16)
    caxis([-3 3])

    h = colorbar();
    ylabel(h, 't-values','FontSize', 16);
    set(gca,'TickDir','in');
    set(gca,'linewidth',1.5)

    a = get(gca,'XTickLabel');
    set(gca,'XTickLabel',a, 'fontsize',16)

    a = get(gca,'YTickLabel');
    set(gca,'YTickLabel',a,'fontsize',16)

    % outline the masked entries and mark time zero on both axes
    plot_contour(classi_timeaxis_test(start_test:fin_test),classi_timeaxis_train(start_train:fin_train),sigmap,'k-',2)
    xline(0, '--', 'linewidth',1.5)
    yline(0, '--', 'linewidth',1.5)


