% Schreiner et al. 2024: Figure_3;

%% Classification during retrieval

% Data locations. `mother` (the root data directory) is not assigned in
% this script; it must be available before this section runs.
datadir     = fullfile(mother,'/wake/');
TMRdir      = fullfile(mother,'/sleep/');

subjects    = {'Pmtl02'};
sleep_EEG   = '_TMR';       % file-name suffix of the sleep (TMR) data
wake_EEG    = '_ret.mat';   % retrieval data (pre- and post-sleep)
% NOTE(review): the ripple sections of this script also read `wake_folder`
% and `wake_ripples`, which are not assigned anywhere in this script --
% define them before running those sections.

% time settings
TOI             = [-.5 1];                   % classification window, passed to nearest() on the data time axis
acc             = cell(1,numel(subjects));   % per-subject classifier performance
classi_timeaxis = TOI(1):0.01:TOI(2);        % 10 ms steps, matching the sample step used in the classification loop


for iSub = 1:numel(subjects)

    subID = subjects{iSub};

    % --- Head directions cued during sleep (TMR) ---------------------------
    % Loads eeg_TMR from <TMRdir>/<subID>/<subID>_TMR; using a full path
    % avoids changing MATLAB's current folder.
    load(fullfile(TMRdir, subID, [subID sleep_EEG]));

    % Display values of the first six TMR trials; unique() also sorts them.
    % NOTE(review): assumes the first six TMR trials contain the cued
    % directions and that the two smallest unique values are the ones to
    % keep -- confirm against the experiment design.
    hd_react = unique(cellfun(@(t) t.Display, eeg_TMR.trialinfo(1:6,1)));
    if numel(hd_react) < 2
        error('Subject %s: expected at least 2 distinct cued head directions, found %d.', ...
            subID, numel(hd_react));
    end
    hd_react = hd_react(1:2);
    clear data_clean

    % --- Retrieval data ----------------------------------------------------
    % Loads `eeg` from <datadir>/<subID>/<subID>_ret.mat.
    load(fullfile(datadir, subID, [subID wake_EEG]));

    % Downsample to 200 Hz and demean.
    cfg            = [];
    cfg.resamplefs = 200;
    cfg.demean     = 'yes';
    data_trl_ret   = ft_resampledata(cfg, eeg);

    % Flag each retrieval trial: React = 0 when its head_angle is one of the
    % directions in hd_react (cued during sleep), React = 1 otherwise.
    for jj = 1:size(data_trl_ret.trialinfo,1)
        if ismember(data_trl_ret.trialinfo{jj,1}.head_angle, hd_react)
            data_trl_ret.trialinfo{jj,1}.React = 0;
        else
            data_trl_ret.trialinfo{jj,1}.React = 1;
        end
    end

    % Filtering: 48-52 Hz band-stop, 0.1 Hz high-pass, 40 Hz low-pass, demean.
    cfg                    = [];
    cfg.bsfilter           = 'yes';
    cfg.bsfreq             = [48 52];
    cfg.demean             = 'yes';
    cfg.removemean         = 'no';
    cfg.hpfilter           = 'yes';
    cfg.hpfreq             = 0.1;
    cfg.hpinstabilityfix   = 'reduce';
    cfg.lpfilter           = 'yes';
    cfg.lpfreq             = 40;
    data_trl_ret           = ft_preprocessing(cfg, data_trl_ret);

    % Single-trial timelock (trials in the first dimension of .trial), then
    % z-score with MVPA-Light's default z-score parameters.
    cfg                    = [];
    cfg.keeptrials         = 'yes';
    cfg.removemean         = 'no';
    data_trl_ret           = ft_timelockanalysis(cfg, data_trl_ret);

    preprocess_param        = mv_get_preprocess_param('zscore');
    [~, data_trl_ret.trial] = mv_preprocess_zscore(preprocess_param, data_trl_ret.trial);

    % Keep trials with React ~= 0, i.e. head directions NOT in hd_react.
    % NOTE(review): an earlier comment described this subset as the "cued"
    % (left + right) trials, but React == 0 marks the cued directions --
    % confirm which subset is intended.
    ntrl        = size(data_trl_ret.trial, 1);
    react_flag  = cellfun(@(t) t.React, data_trl_ret.trialinfo(1:ntrl,1));
    trs_exp_ret = find(react_flag ~= 0);

    cfg          = [];
    cfg.trials   = trs_exp_ret;
    data_trl_ret = ft_redefinetrial(cfg, data_trl_ret);

    cfg      = [];
    data_trl = ft_appenddata(cfg, data_trl_ret);

    % PCA rank reduction: decompose into ncomp components and back-project.
    % Components beyond ncomp are rejected; with numcomponent = ncomp that
    % list is empty, so all ncomp components are kept.
    ncomp            = 30;
    cfg              = [];
    cfg.method       = 'pca';
    cfg.updatesens   = 'no';
    cfg.numcomponent = ncomp;
    comp             = ft_componentanalysis(cfg, data_trl);

    cfg              = [];
    cfg.updatesens   = 'no';
    cfg.component    = comp.label(ncomp+1:end);
    data_trl         = ft_rejectcomponent(cfg, comp);

    % Temporal smoothing: running mean over 0.15 s, expressed in samples.
    for itrial = 1:numel(data_trl.trial)
        data_trl.trial{itrial} = smoothdata(data_trl.trial{itrial}, 2, 'movmean', 0.15/(1/data_trl.fsample));
    end

    % Back to a trials x channels x time array.
    cfg             = [];
    cfg.keeptrials  = 'yes';
    cfg.removemean  = 'no';
    tmp             = ft_timelockanalysis(cfg, data_trl);
    dat             = tmp.trial;
    tmp             = [];

    % Two classes: head_angle codes 1-2 vs 3-4, correct responses only.
    trlinfo  = cell2mat(data_trl.trialinfo);
    category = [trlinfo.head_angle];
    behav    = [trlinfo.headang_acc];

    left     = [1,2];
    right    = [3,4];
    correct  = 1;

    sel1  = ismember(category, left)  & ismember(behav, correct);
    sel2  = ismember(category, right) & ismember(behav, correct);

    dat4class = cat(1, dat(sel1,:,:), dat(sel2,:,:));
    classcode = cat(1, 1*ones(sum(sel1),1), 2*ones(sum(sel2),1));

    % Samples within TOI in 10 ms steps; round() keeps the step an integer
    % number of samples for any sampling rate.
    step      = round(0.01*data_trl.fsample);
    samples   = nearest(data_trl.time{1}, TOI(1)):step:nearest(data_trl.time{1}, TOI(2));
    dat4class = dat4class(:,:,samples);

    % LDA classification across time, AUC, 5-fold CV repeated 15 times.
    cfg             = [];
    cfg.classifier  = 'lda';
    cfg.metric      = 'auc';
    cfg.repeat      = 15;
    cfg.k           = 5;
    perf            = mv_classify_across_time(cfg, dat4class, classcode);
    acc{iSub}       = perf;

end


%% Ripple ERP
% Average (ERP) of the Hipp*/Ento*/Para* channels within -0.5..0.5 s of
% the data in `wake_ripples`, pooled over channels and subjects and
% plotted as mean +/- SEM.
% NOTE(review): `wake_folder` and `wake_ripples` are not assigned in this
% script; they must exist in the workspace before this section runs.

subjects = {'Pmtl02'};
rip      = cell(1, numel(subjects));   % per-subject channel x time averages
rip_time = [];                         % time axis of the averages

for iSub = 1:numel(subjects)

    subID = subjects{iSub};

    % Loads `eeg` from <datadir>/<subID>/<wake_folder>/<subID><wake_ripples>
    % without changing the current folder.
    load(fullfile(datadir, subID, wake_folder, [subID wake_ripples]));

    cfg              = [];
    cfg.bsfilter     = 'yes';
    cfg.bsfreq       = [48 52];
    cfg.lpfilter     = 'yes';
    cfg.lpfreq       = 200;
    cfg.removemean   = 'no';
    eeg              = ft_preprocessing(cfg, eeg);

    cfg              = [];
    cfg.channel      = {'Hipp*', 'Ento*', 'Para*'};
    cfg.latency      = [-.5 .5];
    eeg              = ft_selectdata(cfg, eeg);

    cfg              = [];
    erp              = ft_timelockanalysis(cfg, eeg);
    rip{1,iSub}      = erp.avg;
    % erp.time is a numeric vector, unlike eeg.time (a cell array per trial).
    % NOTE(review): assumes all subjects share the same time axis.
    rip_time         = erp.time;

end

% Rows: channels stacked across subjects; columns: time points.
all_rip   = vertcat(rip{:});
nobs      = size(all_rip, 1);
mean_alli = mean(all_rip, 1);
% Standard error of the mean across rows (sample std / sqrt(#rows)).
SEM       = std(all_rip, 0, 1) / sqrt(nobs);

color_ = [0 0 0];
hFig2  = figure();
hl     = boundedline(rip_time, mean_alli, SEM, 'cmap', color_, 'alpha', 'transparency', 0.15);
set(hl, 'linewidth', 1.25);

yline(0, '--', 'LineWidth', 1)
ylim([-10 10])

    
%% Ripple TFR
% Per-channel time-frequency power (1-200 Hz, Hanning mtmconvol) of the
% Hipp*/Ento*/Para* channels in `wake_ripples`. Power is z-scored per
% channel/frequency over all trials and the whole -1..1 s window, then
% averaged over trials, and finally grand-averaged across channels and
% subjects.
% NOTE(review): `wake_folder` and `wake_ripples` are not assigned in this
% script; they must exist in the workspace before this section runs.

subjects = {'Pmtl02'};
% Start from a fresh cell: the ERP section leaves a numeric `all_rip` in
% the workspace, which cannot be brace-indexed below.
all_rip  = cell(1, numel(subjects));

for iSub = 1:numel(subjects)

    subID = subjects{iSub};

    % Loads `eeg` from <datadir>/<subID>/<wake_folder>/<subID><wake_ripples>
    % without changing the current folder.
    load(fullfile(datadir, subID, wake_folder, [subID wake_ripples]));

    cfg              = [];
    cfg.bsfilter     = 'yes';
    cfg.bsfreq       = [48 52];
    cfg.lpfilter     = 'yes';
    cfg.lpfreq       = 200;
    cfg.removemean   = 'no';
    eeg              = ft_preprocessing(cfg, eeg);

    cfg              = [];
    cfg.channel      = {'Hipp*', 'Ento*', 'Para*'};
    cfg.latency      = [-1 1];
    eeg              = ft_selectdata(cfg, eeg);

    nchan         = numel(eeg.label);
    tfr_all_ripsi = cell(1, nchan);   % reset per subject: no stale channels

    for irip = 1:nchan

        cfg              = [];
        cfg.keeptrials   = 'yes';
        cfg.output       = 'pow';
        cfg.channel      = eeg.label(irip);
        cfg.method       = 'mtmconvol';
        cfg.taper        = 'hanning';
        cfg.foi          = 1:1:200;
        cfg.tapsmofrq    = 0.5 * cfg.foi;
        cfg.t_ftimwin    = 5 ./ cfg.foi;   % 5 cycles per frequency
        cfg.pad          = 16;
        cfg.toi          = -1:0.01:1;
        tfr_all_rip      = ft_freqanalysis(cfg, eeg);

        % z-score reference: the whole -1..1 s window.
        startT = nearest(tfr_all_rip.time, -1);
        endT   = nearest(tfr_all_rip.time,  1);

        pow  = tfr_all_rip.powspctrm;   % trials first (keeptrials), then chan, freq, time
        powz = nan(size(pow));
        for ichan = 1:size(pow,2)
            for ifreq = 1:size(pow,3)
                d = pow(:, ichan, ifreq, startT:endT);
                m = mean(d(:), 'omitnan');
                s = std(d(:), 0, 'omitnan');
                powz(:, ichan, ifreq, :) = (pow(:, ichan, ifreq, :) - m) ./ s;
            end
        end

        % Average z-scored power over trials -> chan x freq x time.
        tfr_all_rip.powspctrm = permute(mean(powz, 1, 'omitnan'), [2 3 4 1]);
        tfr_all_rip.dimord    = 'chan_freq_time';
        clear pow powz d

        cfg                   = [];
        cfg.latency           = [-.25 .25];  % time range
        tfr_all_rip           = ft_selectdata(cfg, tfr_all_rip);
        % Common label so every channel is averaged together below.
        tfr_all_rip.label     = {'ripple'};

        tfr_all_ripsi{irip}   = tfr_all_rip;

    end

    all_rip{iSub} = tfr_all_ripsi;

end

% Flatten the per-subject cells into one list of per-channel TFRs.
alli_r = horzcat(all_rip{:});

cfg        = [];
cfg.toilim = [-.1 .1];
ga         = ft_freqgrandaverage(cfg, alli_r{:});

% Use a new figure so the ERP plot is not overwritten.
hFig3 = figure();
h = pcolor(ga.time, ga.freq, squeeze(ga.powspctrm));
shading interp;
lighting phong;
axis xy;
set(h, 'EdgeColor', 'none');
caxis([0 2])
colorbar

set(hFig3, 'Position', [100, 400, 570, 200]);
