% Bullon Tarraso et al. 2025
% Code for figure 3c (+inset)

% Root folder searched for per-subject session folders (pattern '<subject>*').
dirRoot = 'data';

subjects    = {'P02' 'P05' 'P07' 'P09' 'P11' 'P12' 'P13' 'P14' 'P15' 'P16' 'P19' 'P21' 'P22' 'P25' 'P26' 'P29' 'P30' 'P31'};

% Channel neighbourhood and layout files; presumably these provide the
% `neighbours` (cluster statistics) and `lay` (topoplot) variables used below.
% fullfile keeps the path separator portable across operating systems.
load(fullfile('auxiliary', 'Bham-64CH-Neighbours.mat'))
load(fullfile('auxiliary', 'Bham-64CH-Lay.mat'))

locking = "exhaleTroughs";   % event the TFRs / d-values are time-locked to
timeRange = [0,3];           % NOTE(review): not used in this script
window_radii = 2;            % NOTE(review): not used in this script
avg_window = 5;              % moving-average width (samples) for correlation time courses

% Classifier d-values for remembered / not-remembered trials; presumably
% provide acc_dvals_rem and acc_dvals_nrem, which are indexed per subject below.
load(fullfile("auxiliary", "classi_dvals_rem_" + locking));
load(fullfile("auxiliary", "classi_dvals_nrem_" + locking));

% Trough-locked TFRs of all subjects. Reuse them when already in the
% workspace; only the exhale-trough file exists for loading here.
if ~exist("TFRs_exht_all", 'var')
    if strcmp(locking, "exhaleTroughs")
        load(fullfile("auxiliary", "TFRs_exht_all"));
    else
        error('TFRs_exht_all is only loaded for locking = "exhaleTroughs" (got "%s").', locking);
    end
end

% Config
remove_artifacts = false; % NOTE(review): not used in this script
numPerms = 200;           % shuffles per channel for the per-subject null correlations
freqRange = [8, 20]; % From TFR significance (slobscene_retrieval_TFR.m)
norm_dvals = false;       % if true, z-score object and scene d-values separately per subject

for isubject = 1:numel(subjects)

    % Session folders of this subject. The first match is treated as the
    % object-cue session and the second as the scene-cue session (dir order).
    dirSubject = dir(fullfile(dirRoot, [subjects{isubject} '*']));
    if numel(dirSubject) < 2
        error('Expected 2 folders matching "%s*" in "%s", found %d.', ...
            subjects{isubject}, dirRoot, numel(dirSubject));
    end

    % Exhale-trough-locked TFRs: column 1 = object cue, column 2 = scene cue
    TFR_exht_obj = TFRs_exht_all{isubject, 1}.TFR;
    TFR_exht_sce = TFRs_exht_all{isubject, 2}.TFR;

    %% Split into remembered / not-remembered old trials
    % Remembered: ExemplarAccuracy == 1, not remembered: ExemplarAccuracy == 0,
    % both restricted to OldNew == 1.
    trlinfo  = cell2mat(TFR_exht_obj.trialinfo);
    accuracy = [trlinfo.ExemplarAccuracy];
    oldnew   = [trlinfo.OldNew];

    cfg = [];
    cfg.trials   = find(accuracy == 1 & oldnew == 1);
    TFR_rem_obj  = ft_selectdata(cfg, TFR_exht_obj);
    cfg.trials   = find(accuracy == 0 & oldnew == 1);
    TFR_nrem_obj = ft_selectdata(cfg, TFR_exht_obj);

    trlinfo  = cell2mat(TFR_exht_sce.trialinfo);
    accuracy = [trlinfo.ExemplarAccuracy];
    oldnew   = [trlinfo.OldNew];

    cfg = [];
    cfg.trials   = find(accuracy == 1 & oldnew == 1);
    TFR_rem_sce  = ft_selectdata(cfg, TFR_exht_sce);
    cfg.trials   = find(accuracy == 0 & oldnew == 1);
    TFR_nrem_sce = ft_selectdata(cfg, TFR_exht_sce);

    % Pool object and scene trials (object first). NOTE(review): assumes this
    % trial order matches acc_dvals_*(isubject).all_dvals.
    cfg = [];
    TFR_rem  = ft_appendfreq(cfg, TFR_rem_obj, TFR_rem_sce);
    TFR_nrem = ft_appendfreq(cfg, TFR_nrem_obj, TFR_nrem_sce);

    %% Behavioural scores for later correlation
    cfg = [];
    TFR_exht = ft_appendfreq(cfg, TFR_exht_obj, TFR_exht_sce);

    trlinfo    = cell2mat(TFR_exht.trialinfo);
    accuracy   = [trlinfo.ExemplarAccuracy];
    oldnew     = [trlinfo.OldNew];
    oldnew_acc = [trlinfo.Answer1Accuracy];

    % Column 1: % of trials with ExemplarAccuracy == 1 among trials coded <= 1.
    % Column 2: remembered trials as % of old trials with Answer1Accuracy == 1.
    accs_behav(isubject, 1) = mean(accuracy(accuracy<=1))*100;
    accs_behav(isubject, 2) = sum(accuracy == 1)/sum((oldnew == 1).*(oldnew_acc == 1))*100;

    %% Prepare power data
    % Cue-locked TFRs of both sessions.
    % NOTE(review): TFR_all is z-scored below but not used afterwards in this
    % script; rem/nrem are normalised against their own [-1 0] s window.
    % Confirm whether a cue-locked baseline was intended.
    TFR_obj = load(fullfile(dirSubject(1).folder, dirSubject(1).name, "TFR_all")).TFR_all;
    TFR_sce = load(fullfile(dirSubject(2).folder, dirSubject(2).name, "TFR_all")).TFR_all;

    cfg = [];
    TFR_all = ft_appendfreq(cfg, TFR_obj, TFR_sce);

    % z-score every channel x frequency bin against the mean / SD (NaNs
    % ignored) over all trials and samples in the [-1 0] s window of the same
    % data set. powspctrm is indexed as rpt x chan x freq x time.
    cfg         = [];
    cfg.latency = [-1 0];
    tfrs = {TFR_all, TFR_rem, TFR_nrem};
    for k = 1:numel(tfrs)
        base   = ft_selectdata(cfg, tfrs{k});
        base_m = mean(base.powspctrm, [1 4], 'omitnan');
        base_s = std(base.powspctrm, 0, [1 4], 'omitnan');
        tfrs{k}.powspctrm = (tfrs{k}.powspctrm - base_m) ./ base_s;
    end
    [TFR_all, TFR_rem, TFR_nrem] = tfrs{:};

    %% Average over frequencies (channels kept)
    cfg = [];
    cfg.frequency   = freqRange;
    cfg.avgoverchan = 'no';
    cfg.avgoverfreq = 'yes';
    cfg.nanmean     = 'yes';
    pow_subject_rem  = ft_selectdata(cfg, TFR_rem);
    pow_subject_nrem = ft_selectdata(cfg, TFR_nrem);

    % Re-package as timelock-like structs (.avg = rpt x chan x time).
    % squeeze drops the singleton frequency dimension; with a single trial it
    % would also drop the trial dimension.
    pow_subject_rem.avg     = squeeze(pow_subject_rem.powspctrm);
    pow_subject_rem.dimord  = 'rpt_chan_time';
    pow_subject_rem         = rmfield(pow_subject_rem, ["freq", "powspctrm"]);
    pow_subject_nrem.avg    = squeeze(pow_subject_nrem.powspctrm);
    pow_subject_nrem.dimord = 'rpt_chan_time';
    pow_subject_nrem        = rmfield(pow_subject_nrem, ["freq", "powspctrm"]);

    % Mean over trials, then channels (full time range, before cropping below)
    powi_r(isubject,:)  = squeeze(mean(mean(pow_subject_rem.avg, 'omitnan'), 'omitnan'));
    powi_nr(isubject,:) = squeeze(mean(mean(pow_subject_nrem.avg, 'omitnan'), 'omitnan'));

    %% Optionally z-score d-values (object and scene separately)
    if norm_dvals
        tmp_obj = zscore(acc_dvals_rem(isubject).dvals_object_avg,[], 'all');
        tmp_sce = zscore(acc_dvals_rem(isubject).dvals_scene_avg,[], 'all');
        acc_dvals_rem(isubject).all_dvals = cat(1, tmp_obj, tmp_sce);

        tmp_obj = zscore(acc_dvals_nrem(isubject).dvals_object_avg,[], 'all');
        tmp_sce = zscore(acc_dvals_nrem(isubject).dvals_scene_avg,[], 'all');
        acc_dvals_nrem(isubject).all_dvals = cat(1, tmp_obj, tmp_sce);
    end

    cfg = [];
    cfg.latency      = [-2 2];
    pow_subject_rem  = ft_selectdata(cfg, pow_subject_rem);
    pow_subject_nrem = ft_selectdata(cfg, pow_subject_nrem);

    % Downsample d-values (time along dim 2) to the TFR time axis. The ratio
    % of sample counts must be an integer for downsample().
    dsf_rem  = (size(acc_dvals_rem(isubject).all_dvals, 2) - 1) / (numel(pow_subject_rem.time) - 1);
    dsf_nrem = (size(acc_dvals_nrem(isubject).all_dvals, 2) - 1) / (numel(pow_subject_nrem.time) - 1);
    dvals_rem  = downsample(acc_dvals_rem(isubject).all_dvals', dsf_rem)';
    dvals_nrem = downsample(acc_dvals_nrem(isubject).all_dvals', dsf_nrem)';

    %% Time-resolved Spearman correlations across trials, per channel
    % math.getDiagonal presumably takes the diagonal of the time x time
    % correlation matrix, i.e. power(t) vs d-value(t).
    pow_all_rem{isubject}   = squeeze(mean(pow_subject_rem.avg));
    dvals_all_rem{isubject} = mean(dvals_rem);

    clear corr_rem_null   % start each subject fresh; never reuse the previous subject's null
    for ii = 1:size(pow_subject_rem.avg, 2)
        pow_rem = squeeze(pow_subject_rem.avg(:,ii,:));
        tmp = corr(pow_rem, dvals_rem, 'Rows', 'complete', 'Type', 'Spearman');
        corr_all_rem{isubject}(ii, :) = math.getDiagonal(tmp);

        % Null: correlate against shuffled d-values
        parfor p = 1:numPerms
            dvals_shuff = Shuffle(dvals_rem);
            tmp_null = corr(pow_rem, dvals_shuff, 'Rows', 'complete', 'Type', 'Spearman');
            corr_rem_null(ii, :, p) = math.getDiagonal(tmp_null);
        end
    end
    corr_all_rem_null{isubject} = corr_rem_null;

    pow_all_nrem{isubject}   = squeeze(mean(pow_subject_nrem.avg));
    dvals_all_nrem{isubject} = mean(dvals_nrem);

    clear corr_nrem_null
    for ii = 1:size(pow_subject_nrem.avg, 2)
        pow_nrem = squeeze(pow_subject_nrem.avg(:,ii,:));
        tmp = corr(pow_nrem, dvals_nrem, 'Rows', 'complete', 'Type', 'Spearman');
        corr_all_nrem{isubject}(ii, :) = math.getDiagonal(tmp);

        parfor p = 1:numPerms
            dvals_shuff = Shuffle(dvals_nrem);
            tmp_null = corr(pow_nrem, dvals_shuff, 'Rows', 'complete', 'Type', 'Spearman');
            corr_nrem_null(ii, :, p) = math.getDiagonal(tmp_null);
        end
    end
    corr_all_nrem_null{isubject} = corr_nrem_null;

end

%% Prepare data for statistics
% Wrap each subject's correlation time courses (channels x time) in
% timelock-like structs so they can be passed to ft_timelockstatistics.
nSubj = numel(subjects);

temp = [];
temp.time  = pow_subject_nrem.time;
temp.label = pow_subject_rem.label;

% Per-subject channel means of the remembered correlations, replicated over time.
% NOTE(review): all_rep_mean is not used further in this script.
all_rep_mean = cell(1, nSubj);
for i = 1:nSubj
    mean_i          = mean(corr_all_rem{1,i}, 2);
    rep_mean        = repmat(mean_i, 1, size(corr_all_rem{1,i}, 2));
    all_rep_mean{i} = rep_mean;
end

% All-zero data with the same channels x time size as a subject's correlations
zz = zeros(size(corr_all_rem{1, 1}));

% Preallocate so that stale entries from an earlier run cannot leak into the stats
tempus_r    = cell(1, nSubj);
tempus_n    = cell(1, nSubj);
tempus_zero = cell(1, nSubj);
tempus_null = cell(1, nSubj);
for i = 1:nSubj
    % Moving average over time (dim 2) with width avg_window
    tempus_r{1,i} = temp;
    tempus_r{1,i}.avg = movmean(corr_all_rem{1,i}, avg_window, 2);

    tempus_n{1,i} = temp;
    tempus_n{1,i}.avg = movmean(corr_all_nrem{1,i}, avg_window, 2);

    tempus_zero{1,i} = temp;
    tempus_zero{1,i}.avg = zz;

    % Mean over permutations of the remembered-trial null correlations
    tempus_null{1,i} = temp;
    tempus_null{1,i}.avg = movmean(squeeze(mean(corr_all_rem_null{1,i},3)), avg_window, 2);
end
    
%% Run stats
% Cluster-based permutation test (dependent-samples t) comparing remembered
% vs. not-remembered correlation time courses across subjects.
cfg                  = [];
cfg.latency          = [-2 2];
cfg.method           = 'montecarlo';
cfg.statistic        = 'depsamplesT';
cfg.correctm         = 'cluster';
cfg.clusteralpha     = 0.05;
cfg.clusterstatistic = 'maxsum';
cfg.minnbchan        = 3;
cfg.tail             = 0;
cfg.alpha            = 0.05;
cfg.numrandomization = 1000;
cfg.neighbours       = neighbours;

% Design matrix: row 1 = subject (unit variable), row 2 = condition
% (1 = remembered, 2 = not remembered). Sized from the actual inputs.
Nsubj  = numel(tempus_r);
design = zeros(2, Nsubj*2);
design(1,:) = [1:Nsubj 1:Nsubj];
design(2,:) = [ones(1,Nsubj) ones(1,Nsubj)*2];

cfg.design = design;
cfg.uvar   = 1;
cfg.ivar   = 2;

[stat] = ft_timelockstatistics(cfg, tempus_r{:}, tempus_n{:});

% Show the p-value of the first negative cluster, if any exists
if isfield(stat, 'negclusters') && ~isempty(stat.negclusters)
    stat.negclusters(1).prob
else
    disp('No negative clusters found.');
end

%% Plot results
red   = [255, 17, 0]/255;
grey  = [0.6 0.6 0.6];
green = [35, 175, 35]/255;

% Channels that belong to a significant cluster at any time point
maski = find(any(stat.mask, 2));
if isempty(maski)
    warning('No channel lies inside a significant cluster; averaging over all channels.');
    maski = (1:numel(stat.label))';
end

% Each subject's rem / nrem correlation time courses over the selected
% channels; all_rem / all_nrem are subjects x channels x time.
clear all_rem all_nrem;
cfg         = [];
cfg.latency = [stat.time(1) stat.time(end)];
cfg.channel = stat.label(maski);   % select by label instead of relying on index order
for i = 1:numel(tempus_r)
    subject_ft      = ft_selectdata(cfg, tempus_r{i});
    all_rem(i,:,:)  = subject_ft.avg;
    subject_ft      = ft_selectdata(cfg, tempus_n{i});
    all_nrem(i,:,:) = subject_ft.avg;
end

% Mean and SEM over subjects and selected channels ("null_*" = not remembered).
% NOTE(review): the SD pools subjects and channels but is divided by
% sqrt(number of subjects); confirm this is the intended error measure.
r_mean    = squeeze(mean(all_rem,  [1,2]));
null_mean = squeeze(mean(all_nrem, [1,2]));

r_SEM    = squeeze(std(all_rem,  0, [1,2]) / sqrt(size(all_rem,1)));
null_SEM = squeeze(std(all_nrem, 0, [1,2]) / sqrt(size(all_nrem,1)));

figure()
[hl,hp] = boundedline(stat.time, r_mean, r_SEM, 'cmap', green, 'alpha', 'transparency', 0.15);
set(hl, 'linewidth', 3);

hold on

[hl,hp] = boundedline(stat.time, null_mean, null_SEM, 'cmap', red, 'alpha', 'transparency', 0.15);
set(hl, 'linewidth', 3);

% Black bar at y = 0 over time points where any channel is significant
if any(stat.mask(:))
    sigMask = nan(size(stat.mask, 2), 1);
    sigMask(any(stat.mask, 1)) = 0;
    plot(stat.time, sigMask, '-k', 'LineWidth', 3)
end

xline(0, 'k--');

xlabel("Time exhalation trough (s)");
ylabel("Spearman Correlation");
legend("Remembered", "Not remembered", "p < 0.05");


%% topo
% Topography of summed significant t-values per channel. A masked copy is
% used so that stat.stat keeps the unmasked t-values.
stat_masked = stat.stat .* stat.mask;                        % zero t-values outside clusters
t_sums_topo = squeeze(sum(stat_masked, [2 3], 'omitnan'));   % sum of sig. t-values per channel

% Fake grand average: borrow the channel structure of the subject data and
% replace its values with the t-value sums for plotting.
cfg             = [];
cfg.avgovertime = 'yes';
ga_topo         = ft_selectdata(cfg, tempus_r{:});

ga_topo.powspctrm = t_sums_topo;

figure;
topo             = [];
topo.zlim        = [-15 15];
topo.layout      = lay;
topo.parameter   = 'powspctrm';
topo.gridscale   = 360;
topo.marker      = 'off';
topo.comment     = 'no';
topo.style       = 'both';
topo.contournum  = 2;
ft_topoplotTFR(topo, ga_topo);
set(gcf, 'Color', 'w');
hcb = colorbar;
ylabel(hcb, 'summed t-values');
set(findall(gcf,'-property','FontSize'),'FontSize',22)