% Bullon Tarraso et al. 2025
% Code for figure 2d

ft_defaults
dirRoot = 'data';

subjects    = {'P02' 'P05' 'P07' 'P09' 'P11' 'P12' 'P13' 'P14' 'P15' 'P16' 'P19' 'P21' 'P22' 'P25' 'P26' 'P29' 'P30' 'P31'};

% Channel neighbourhood and layout files. These are expected to define
% `neighbours` (used by the cluster test) and `lay` (used by the topoplot).
% Paths are built with fullfile so they work on every platform
% (a hard-coded backslash only works on Windows).
load(fullfile('auxiliary', 'Bham-64CH-Neighbours.mat'))
load(fullfile('auxiliary', 'Bham-64CH-Lay.mat'))

% Trial time-locking. Any value other than "cue" adds a random per-trial
% circular shift to the all-trials permutation null.
locking = "cue";

% Config
remove_artifacts = false;    % passed to dataloader.load_segm_phase
use_hilbert      = false;    % passed to dataloader.load_segm_phase
nbin             = 60;       % number of equal-width respiratory-phase bins over [-pi, pi)
numPerms         = 1000;     % permutations per subject for the null distributions
freqRange        = [8, 20];  % frequency band (Hz) averaged before phase binning
avg_window       = 1;        % window argument passed to math.smooth_circular

%% RUN

for isubject = 1:numel(subjects)

    % The subject glob must match at least two session folders. Element 1 is
    % used as the "obj" session and element 2 as the "sce" session.
    dirSubject = dir(fullfile(dirRoot, strcat(subjects(isubject), "*")));
    if numel(dirSubject) < 2
        error('Expected two session folders for %s in %s, found %d.', ...
            subjects{isubject}, dirRoot, numel(dirSubject));
    end

    %% Load data and merge sessions
    % Segmented phase data for both sessions
    data_phase_obj = dataloader.load_segm_phase(dirRoot, dirSubject, 1, use_hilbert, remove_artifacts);
    data_phase_sce = dataloader.load_segm_phase(dirRoot, dirSubject, 2, use_hilbert, remove_artifacts);

    % TFRs (all / remembered / not remembered trials) for both sessions
    dir_obj = fullfile(dirSubject(1).folder, dirSubject(1).name);
    dir_sce = fullfile(dirSubject(2).folder, dirSubject(2).name);

    TFR_obj      = load(fullfile(dir_obj, "TFR_all")).TFR_all;
    TFR_rem_obj  = load(fullfile(dir_obj, "TFR_rem")).TFR_rem;
    TFR_nrem_obj = load(fullfile(dir_obj, "TFR_nrem")).TFR_nrem;

    TFR_sce      = load(fullfile(dir_sce, "TFR_all")).TFR_all;
    TFR_rem_sce  = load(fullfile(dir_sce, "TFR_rem")).TFR_rem;
    TFR_nrem_sce = load(fullfile(dir_sce, "TFR_nrem")).TFR_nrem;

    cfg      = [];
    TFR_all  = ft_appendfreq(cfg, TFR_obj, TFR_sce);
    TFR_rem  = ft_appendfreq(cfg, TFR_rem_obj, TFR_rem_sce);
    TFR_nrem = ft_appendfreq(cfg, TFR_nrem_obj, TFR_nrem_sce);

    % Concatenate phase data. trialinfo is re-attached explicitly after
    % ft_appenddata.
    trialinfo_all      = [data_phase_obj.trialinfo; data_phase_sce.trialinfo];
    cfg                = [];
    data_all           = ft_appenddata(cfg, data_phase_obj, data_phase_sce);
    data_all.fsample   = 100;   % NOTE(review): hard-coded phase sampling rate - confirm
    data_all.trialinfo = trialinfo_all;

    % Split phase trials by ExemplarAccuracy (OldNew == 1 trials only)
    trlinfo  = cell2mat(data_all.trialinfo);
    accuracy = [trlinfo.ExemplarAccuracy];   % 1 [correct], 0 [wrong]
    oldnew   = [trlinfo.OldNew];

    cfg         = [];
    cfg.trials  = accuracy == 1 & oldnew == 1;
    rem_phases  = ft_selectdata(cfg, data_all);
    cfg.trials  = accuracy == 0 & oldnew == 1;
    nrem_phases = ft_selectdata(cfg, data_all);

    %% Prepare power data
    % z-score every channel/frequency against the window [-1 0] s before cue
    % onset, pooling all trials and baseline time points. Each TFR is
    % normalised by its own baseline. Assumes powspctrm is
    % rpt x chan x freq x time, as the per-element loop this replaces did.
    cfg_base         = [];
    cfg_base.latency = [-1 0];
    zscore_to_base   = @(pow, base) (pow - mean(base, [1 4], 'omitnan')) ./ std(base, 0, [1 4], 'omitnan');

    TFR_all.powspctrm  = zscore_to_base(TFR_all.powspctrm,  ft_selectdata(cfg_base, TFR_all).powspctrm);
    TFR_rem.powspctrm  = zscore_to_base(TFR_rem.powspctrm,  ft_selectdata(cfg_base, TFR_rem).powspctrm);
    TFR_nrem.powspctrm = zscore_to_base(TFR_nrem.powspctrm, ft_selectdata(cfg_base, TFR_nrem).powspctrm);

    %% Average over the frequency band of interest
    cfg             = [];
    cfg.frequency   = freqRange;
    cfg.avgoverfreq = 'yes';
    cfg.nanmean     = 'yes';
    pow_subject_all  = ft_selectdata(cfg, TFR_all);
    pow_subject_rem  = ft_selectdata(cfg, TFR_rem);
    pow_subject_nrem = ft_selectdata(cfg, TFR_nrem);

    % Store as rpt x chan x time in "avg". permute drops only the singleton
    % frequency dimension; squeeze would also drop a singleton trial or
    % channel dimension. The dimord now describes all three dimensions.
    pow_subject_all.avg     = permute(pow_subject_all.powspctrm, [1 2 4 3]);
    pow_subject_all.dimord  = 'rpt_chan_time';
    pow_subject_all         = rmfield(pow_subject_all, ["freq", "powspctrm"]);
    pow_subject_rem.avg     = permute(pow_subject_rem.powspctrm, [1 2 4 3]);
    pow_subject_rem.dimord  = 'rpt_chan_time';
    pow_subject_rem         = rmfield(pow_subject_rem, ["freq", "powspctrm"]);
    pow_subject_nrem.avg    = permute(pow_subject_nrem.powspctrm, [1 2 4 3]);
    pow_subject_nrem.dimord = 'rpt_chan_time';
    pow_subject_nrem        = rmfield(pow_subject_nrem, ["freq", "powspctrm"]);

    %% Reduce everything to the same time window and sampling
    cfg         = [];
    cfg.latency = [-0.5 3];

    pow_subject_all  = ft_selectdata(cfg, pow_subject_all);
    pow_subject_rem  = ft_selectdata(cfg, pow_subject_rem);
    pow_subject_nrem = ft_selectdata(cfg, pow_subject_nrem);
    phases_all       = ft_selectdata(cfg, data_all);
    phases_rem       = ft_selectdata(cfg, rem_phases);
    phases_nrem      = ft_selectdata(cfg, nrem_phases);

    % Resample every phase trial onto the power time axis (one time vector
    % per trial).
    % NOTE(review): the phase is presumably interpolated as an ordinary real
    % signal, so samples next to a -pi/pi wrap can get intermediate values -
    % confirm this is acceptable.
    cfg         = [];
    cfg.time    = num2cell(repmat(pow_subject_all.time, numel(phases_all.time), 1), 2);
    phases_all  = ft_resampledata(cfg, phases_all);
    cfg.time    = num2cell(repmat(pow_subject_rem.time, numel(phases_rem.time), 1), 2);
    phases_rem  = ft_resampledata(cfg, phases_rem);
    cfg.time    = num2cell(repmat(pow_subject_nrem.time, numel(phases_nrem.time), 1), 2);
    phases_nrem = ft_resampledata(cfg, phases_nrem);

    %% Get respiratory modulation of band-limited power
    % nbin equal-width bins; position(j) is the left edge of bin j and the
    % bin is [position(j), position(j) + winsize).
    winsize  = 2*pi/nbin;
    position = -pi + (0:nbin-1)*winsize;

    % One row per trial, one column per time point. Assumes a single phase
    % channel.
    % NOTE(review): the trial order must match the trial dimension of the
    % matching pow_subject_*.avg, because the logical masks below index the
    % power arrays element by element - confirm.
    phases_all_cat  = cat(1, phases_all.trial{:});
    phases_rem_cat  = cat(1, phases_rem.trial{:});
    phases_nrem_cat = cat(1, phases_nrem.trial{:});
    % Pool of remembered + not-remembered trials used for the rem/nrem nulls
    phases_pool_cat = cat(1, phases_rem_cat, phases_nrem_cat);
    nchans = size(pow_subject_all.label, 1);

    % Binned power for all trials
    for j = 1:nbin
        I = (phases_all_cat < position(j)+winsize) & (phases_all_cat >= position(j));
        for ichan = 1:nchans
            pow_chan = pow_subject_all.avg(:,ichan,:);
            pow_all_bin(isubject, ichan, j) = mean(pow_chan(I), 'all', 'omitnan');
        end
    end

    parfor k = 1:numPerms
        % Shuffle(X, 2) is assumed to permute whole rows (trials) of X.
        % NOTE(review): confirm against the Shuffle implementation on the path.
        trials_shuff = Shuffle(phases_all_cat, 2);

        if ~strcmp(locking, "cue")
            % One random circular lag per trial. circshift(A, K) with a vector
            % K shifts dimension i by K(i), so each row is shifted separately.
            shiftVec = randi(2*size(trials_shuff,2), size(trials_shuff,1), 1);
            for itrl = 1:size(trials_shuff,1)
                trials_shuff(itrl,:) = circshift(trials_shuff(itrl,:), shiftVec(itrl), 2);
            end
        end

        for j = 1:nbin
            I = (trials_shuff < position(j)+winsize) & (trials_shuff >= position(j));
            for ichan = 1:nchans
                pow_chan = pow_subject_all.avg(:,ichan,:);
                pow_all_bin_null(isubject, ichan, j, k) = mean(pow_chan(I), 'all', 'omitnan');
            end
        end
    end

    % Binned power for remembered trials
    for j = 1:nbin
        I = (phases_rem_cat < position(j)+winsize) & (phases_rem_cat >= position(j));
        for ichan = 1:nchans
            pow_chan = pow_subject_rem.avg(:,ichan,:);
            pow_rem_bin(isubject, ichan, j) = mean(pow_chan(I), 'all', 'omitnan');
        end
    end

    parfor k = 1:numPerms
        % Draw as many shuffled pool trials as there are remembered trials
        trials_shuff = Shuffle(phases_pool_cat, 2);
        trials_shuff = trials_shuff(1:size(phases_rem_cat,1),:);

        for j = 1:nbin
            I = (trials_shuff < position(j)+winsize) & (trials_shuff >= position(j));
            for ichan = 1:nchans
                pow_chan = pow_subject_rem.avg(:,ichan,:);
                pow_rem_bin_null(isubject, ichan, j, k) = mean(pow_chan(I), 'all', 'omitnan');
            end
        end
    end

    % Binned power for not-remembered trials
    for j = 1:nbin
        I = (phases_nrem_cat < position(j)+winsize) & (phases_nrem_cat >= position(j));
        for ichan = 1:nchans
            pow_chan = pow_subject_nrem.avg(:,ichan,:);
            pow_nrem_bin(isubject, ichan, j) = mean(pow_chan(I), 'all', 'omitnan');
        end
    end

    parfor k = 1:numPerms
        % Draw as many shuffled pool trials as there are not-remembered trials
        trials_shuff = Shuffle(phases_pool_cat, 2);
        trials_shuff = trials_shuff(1:size(phases_nrem_cat,1),:);

        for j = 1:nbin
            I = (trials_shuff < position(j)+winsize) & (trials_shuff >= position(j));
            for ichan = 1:nchans
                pow_chan = pow_subject_nrem.avg(:,ichan,:);
                pow_nrem_bin_null(isubject, ichan, j, k) = mean(pow_chan(I), 'all', 'omitnan');
            end
        end
    end
end

%% Prepare ft data for permutation stats

% Template structure: the fields of the last subject's TFR_all, minus the
% frequency data, with the phase-bin left edges as the "time" axis.
% fsample is a rate, i.e. the reciprocal of the bin width (bins per radian).
% It is set before copying so both templates carry the same fields.
temp         = rmfield(TFR_all, ["freq", "dimord", "powspctrm"]);
temp.time    = position;
temp.fsample = 1/mean(diff(temp.time));
temp_shuff   = temp;

% Per-subject null reference: mean over permutations (subject x chan x bin).
% No squeeze, so a single-subject run keeps its 3-D layout.
tmp_pow_shuff      = mean(pow_all_bin_null, 4);
tmp_rem_pow_shuff  = mean(pow_rem_bin_null, 4);
tmp_nrem_pow_shuff = mean(pow_nrem_bin_null, 4);

% Circularly smoothed chan x bin structures, one per subject, for the observed
% and null data. permute([2 3 1]) removes only the subject dimension.
nSubjects     = numel(subjects);
all_pow       = cell(1, nSubjects);
all_pow_shuff = cell(1, nSubjects);
rem_pow       = cell(1, nSubjects);
rem_pow_shuff = cell(1, nSubjects);
nrem_pow       = cell(1, nSubjects);
nrem_pow_shuff = cell(1, nSubjects);

for i = 1:nSubjects
    temp.avg       = math.smooth_circular(permute(pow_all_bin(i,:,:), [2 3 1]), avg_window, 2);
    temp_shuff.avg = math.smooth_circular(permute(tmp_pow_shuff(i,:,:), [2 3 1]), avg_window, 2);
    all_pow{i}       = temp;
    all_pow_shuff{i} = temp_shuff;

    temp.avg       = math.smooth_circular(permute(pow_rem_bin(i,:,:), [2 3 1]), avg_window, 2);
    temp_shuff.avg = math.smooth_circular(permute(tmp_rem_pow_shuff(i,:,:), [2 3 1]), avg_window, 2);
    rem_pow{i}       = temp;
    rem_pow_shuff{i} = temp_shuff;

    temp.avg       = math.smooth_circular(permute(pow_nrem_bin(i,:,:), [2 3 1]), avg_window, 2);
    temp_shuff.avg = math.smooth_circular(permute(tmp_nrem_pow_shuff(i,:,:), [2 3 1]), avg_window, 2);
    nrem_pow{i}       = temp;
    nrem_pow_shuff{i} = temp_shuff;
end

%% Run permutation test

% Cluster-based permutation test (dependent-samples t) over channels x phase
% bins: each subject's observed binned power vs. that subject's null mean.
cfg = [];
cfg.spmversion       = 'spm12';
cfg.minnbchan        = 2;
cfg.channel          = 'all';
cfg.statistic        = 'depsamplesT';
cfg.method           = 'montecarlo';
cfg.correctm         = 'cluster';
cfg.clusteralpha     = 0.05;
cfg.tail             = 0;             % two-sided
cfg.computecritval   = 'yes';
cfg.numrandomization = 300;
cfg.neighbours       = neighbours;
cfg.wrapTimeDim      = 1;             % NOTE(review): non-standard option, presumably lets clusters wrap around the circular phase axis
cfg.clusterstatistic = 'maxsum';
cfg.clustertail      = cfg.tail;
cfg.parameter        = 'avg';

% Design: row 1 = subject (unit variable), row 2 = condition
% (1 = observed, 2 = null). Observed structures come first in the call below.
nSub   = numel(subjects);
design = [1:nSub, 1:nSub; ones(1, nSub), 2*ones(1, nSub)];

cfg.design = design;
cfg.uvar   = 1;
cfg.ivar   = 2;

% run stats and report the number of significant channel x bin samples
stats = ft_timelockstatistics(cfg, all_pow{:}, all_pow_shuff{:});
nnz(stats.mask)

%% plot stats

% Channels that are significant in any phase bin, and the bins where any
% channel is significant.
mask_chan = any(stats.mask, 2);
signMask  = any(stats.mask);
if ~any(mask_chan)
    warning('No significant channels in stats.mask; channel-averaged curves will be NaN.');
end

% Subject x bin power averaged over the significant channels. permute
% removes only the channel dimension, so the layout holds for one subject.
tmp_pow_sign       = permute(mean(pow_all_bin(:, mask_chan, :), 2), [1 3 2]);
tmp_pow_shuff_sign = permute(mean(tmp_pow_shuff(:, mask_chan, :), 2), [1 3 2]);

plot.resp_modulation2_null(stats.time, tmp_pow_sign, tmp_pow_shuff_sign, signMask, mask_chan);


%% Plot topo

% Topography of the summed significant t-values per channel. The mask zeroes
% non-significant samples, and NaN t-values are ignored in the sum
% (core sum with 'omitnan', no Statistics Toolbox nansum).
stats_temp  = stats.stat .* stats.mask;
t_sums_topo = sum(stats_temp, 2, 'omitnan');

% Fake grand average: reuse the first subject's structure with a single
% "time" point holding the per-channel t-sums.
ga_topo      = all_pow{1};
ga_topo.time = 1;
ga_topo.avg  = t_sums_topo;

figure;
topo            = [];
topo.zlim       = [-20 20];
topo.layout     = lay;
topo.parameter  = 'avg';
topo.gridscale  = 360;
topo.marker     = 'off';
topo.comment    = 'no';
topo.style      = 'both';
topo.contournum = 2;
ft_topoplotTFR(topo, ga_topo);
set(gcf, 'Color', 'w');
h = colorbar;
ylabel(h, "summed t-values");
set(findall(gcf, '-property', 'FontSize'), 'FontSize', 16)
fig = gcf;