% Bullon Tarraso et al. 2025
% Code for figure 3b (+inset)

ft_defaults                 % initialise FieldTrip
dirRoot = 'data';           % root folder searched for entries named '<subject>*'

subjects = {'P02' 'P05' 'P07' 'P09' 'P11' 'P12' 'P13' 'P14' 'P15' ...
            'P16' 'P19' 'P21' 'P22' 'P25' 'P26' 'P29' 'P30' 'P31'};

% Event the decision values are time-locked to; it also selects which
% auxiliary dvals files are loaded and (for "cue") the cropping done later
locking = "cue";

% Classifier decision values; these files are expected to define
% acc_dvals_rem / acc_dvals_nrem, which the subject loop reads
load(fullfile("auxiliary", "classi_dvals_rem_" + locking));
load(fullfile("auxiliary", "classi_dvals_nrem_" + locking));

%% Config
remove_artifacts = false;   % forwarded to dataloader.load_segm_phase
use_hilbert      = false;   % forwarded to dataloader.load_segm_phase
nbin             = 60;      % NOTE(review): not referenced elsewhere in this script
norm_dvals       = false;   % if true, z-score each subject's dvals before use

% Per-subject results, one entry per subject (row vectors). Preallocated with
% NaN so that a rerun in a workspace still holding longer vectors from an
% earlier run cannot leave stale entries behind, and so that statistics which
% cannot be computed (empty trial sets) stay NaN; corr(..., 'Rows',
% 'complete') below skips such entries.
nSubjects    = numel(subjects);
accs_behav   = nan(1, nSubjects);
meanDvalRem  = nan(1, nSubjects);
meanDvalNRem = nan(1, nSubjects);
meanDvalAll  = nan(1, nSubjects);
meanRem      = nan(1, nSubjects);
vLengthRem   = nan(1, nSubjects);
meanNRem     = nan(1, nSubjects);
vLengthNRem  = nan(1, nSubjects);
meanAll      = nan(1, nSubjects);
vLengthAll   = nan(1, nSubjects);

for isubject = 1:nSubjects

    dirSubject = dir(fullfile(dirRoot, strcat(subjects(isubject), "*")));

    %% Load data

    % Load segmented phase-retrieval data. The third argument (1 / 2) selects
    % the category -- presumably objects / scenes, judging by the variable
    % names (confirm in dataloader.load_segm_phase).
    data_phase_obj = dataloader.load_segm_phase(dirRoot, dirSubject, 1, use_hilbert, remove_artifacts);
    data_phase_sce = dataloader.load_segm_phase(dirRoot, dirSubject, 2, use_hilbert, remove_artifacts);

    % Concatenate phase data (category-1 trials first, then category-2)
    trialinfo_all      = [data_phase_obj.trialinfo; data_phase_sce.trialinfo];
    cfg                = [];
    data_all           = ft_appenddata(cfg, data_phase_obj, data_phase_sce);
    data_all.fsample   = 200;
    data_all.trialinfo = trialinfo_all;

    % Restrict to reactivation timepoints (0 to 1.5 s)
    if strcmp(locking, "cue")
        cfg         = [];
        cfg.latency = [0, 1.5];
        data_all    = ft_selectdata(cfg, data_all);
    end

    % Downsample to the dvals sampling rate (100 Hz, half of the 200 Hz above)
    cfg            = [];
    cfg.resamplefs = 100;
    cfg.method     = 'downsample';
    data_all       = ft_resampledata(cfg, data_all);

    % Class selection (2 classes) from the per-trial behavioural info
    trlinfo    = cell2mat(data_all.trialinfo);
    accuracy   = [trlinfo.ExemplarAccuracy];   % 1 [correct], 0 [wrong]
    oldnew     = [trlinfo.OldNew];
    oldnew_acc = [trlinfo.Answer1Accuracy];

    cfg         = [];
    cfg.trials  = accuracy == 1 & oldnew == 1;   % remembered
    rem_phases  = ft_selectdata(cfg, data_all);
    cfg.trials  = accuracy == 0 & oldnew == 1;   % not remembered
    nrem_phases = ft_selectdata(cfg, data_all);

    % Memory accuracy (%) to correlate with the coupling strength.
    % NOTE(review): the numerator counts every trial with ExemplarAccuracy == 1,
    % not only trials with OldNew == 1; confirm that other trials can never
    % have ExemplarAccuracy == 1.
    accs_behav(isubject) = sum(accuracy == 1)/sum((oldnew == 1).*(oldnew_acc == 1))*100;

    %% Zscore dvals (each category's matrix across all of its elements)
    if norm_dvals
        tmp_obj = zscore(acc_dvals_rem(isubject).dvals_object_avg, [], 'all');
        tmp_sce = zscore(acc_dvals_rem(isubject).dvals_scene_avg, [], 'all');

        acc_dvals_rem(isubject).all_dvals = cat(1, tmp_obj, tmp_sce);

        tmp_obj = zscore(acc_dvals_nrem(isubject).dvals_object_avg, [], 'all');
        tmp_sce = zscore(acc_dvals_nrem(isubject).dvals_scene_avg, [], 'all');

        acc_dvals_nrem(isubject).all_dvals = cat(1, tmp_obj, tmp_sce);
    end

    %% Remove period before cue (first 20 dvals samples)
    if strcmp(locking, "cue")
        acc_dvals_rem(isubject).all_dvals  = acc_dvals_rem(isubject).all_dvals(:, 21:end);
        acc_dvals_nrem(isubject).all_dvals = acc_dvals_nrem(isubject).all_dvals(:, 21:end);
    end

    %% Get index of maximum reactivation (per dvals row, i.e. per trial)
    [~, maxIdxRem]  = max(acc_dvals_rem(isubject).all_dvals, [], 2);
    [~, maxIdxNRem] = max(acc_dvals_nrem(isubject).all_dvals, [], 2);

    % Dvals row i is paired with phase trial i below, so both sides must
    % describe the same trials in the same order.
    if numel(maxIdxRem) ~= numel(rem_phases.trial) || numel(maxIdxNRem) ~= numel(nrem_phases.trial)
        warning('figure3b:trialCountMismatch', ...
            '%s: dvals rows (rem %d / nrem %d) do not match phase trials (rem %d / nrem %d).', ...
            subjects{isubject}, numel(maxIdxRem), numel(maxIdxNRem), ...
            numel(rem_phases.trial), numel(nrem_phases.trial));
    end

    %% Get mean dvalue per participant
    meanDvalRem(isubject)  = mean(acc_dvals_rem(isubject).all_dvals, "all");
    meanDvalNRem(isubject) = mean(acc_dvals_nrem(isubject).all_dvals, "all");
    meanDvalAll(isubject)  = mean([acc_dvals_rem(isubject).all_dvals; acc_dvals_nrem(isubject).all_dvals], "all");

    %% Get phases at such index
    % Rebuilt for every subject. Reusing the previous subject's arrays would
    % leave its trailing entries (or all of them, when this subject has no
    % trials in a condition) in the vectors passed to circ_mean / circ_r.
    nRem = numel(rem_phases.trial);
    maxPhasesRem = nan(1, nRem);
    for itrial = 1:nRem
        maxPhasesRem(itrial) = rem_phases.trial{itrial}(maxIdxRem(itrial));
    end

    nNRem = numel(nrem_phases.trial);
    maxPhasesNRem = nan(1, nNRem);
    for itrial = 1:nNRem
        maxPhasesNRem(itrial) = nrem_phases.trial{itrial}(maxIdxNRem(itrial));
    end

    maxPhasesAll = [maxPhasesRem, maxPhasesNRem];

    %% Extract mean and vector length (left NaN when a trial set is empty)
    if ~isempty(maxPhasesRem)
        meanRem(isubject)    = circ_mean(maxPhasesRem');
        vLengthRem(isubject) = circ_r(maxPhasesRem');
    end

    if ~isempty(maxPhasesNRem)
        meanNRem(isubject)    = circ_mean(maxPhasesNRem');
        vLengthNRem(isubject) = circ_r(maxPhasesNRem');
    end

    if ~isempty(maxPhasesAll)
        meanAll(isubject)    = circ_mean(maxPhasesAll');
        vLengthAll(isubject) = circ_r(maxPhasesAll');
    end
end

%% Get correlation with behavior and vector length
% One Spearman correlation and one scatter figure per trial subset
% (remembered, not remembered, all). After the loop, c and p hold the
% values of the last subset ("All trials").
couplingSets  = {vLengthRem, vLengthNRem, vLengthAll};
titlePrefixes = ["Remembered Corr = ", "Not remembered Corr = ", "All trials Corr = "];

for iset = 1:numel(couplingSets)
    coupling = couplingSets{iset};

    [c, p] = corr(coupling', accs_behav', 'Rows', 'complete', 'Type', 'Spearman');

    figure;
    scatter(coupling, accs_behav);
    xlabel("Respiration dvals coupling [vector length]");
    ylabel("Memory accuracy (%)");
    title(strcat(titlePrefixes(iset), string(c), ", p = ", string(p)));
end

%% Create fancy plot
% Fit accuracy = (100*v + b)/v (equivalently 100 + b/v) to the remembered-trial
% coupling strengths and overlay the fitted curve on the scatter.
fitfun = @(b,x) (100*x+ b(1))./(x);
beta0  = 1;
mdl    = fitnlm(vLengthRem, accs_behav, fitfun, beta0);

% Evaluate the fit on a regular grid. At v = 0 the expression divides by
% zero, so plot() does not draw that point.
vInterpol    = 0:0.05:0.45;
accsInterpol = fitfun(mdl.Coefficients.Estimate, vInterpol);

remColor = [0.4660 0.6740 0.1880];

figure;
scatter(vLengthRem, accs_behav, 150, remColor, "filled", "MarkerEdgeColor", "k", "LineWidth", 1);
hold on;
plot(vInterpol, accsInterpol, "k--", 'LineWidth', 6);
xlabel("Respiration-dvals coupling (vector length)");
ylabel("Memory accuracy (%)");

% Polar histogram (12 bins) of the per-subject circular mean phase, remembered trials
figure;
polarhistogram(meanRem, 12, 'DisplayStyle', 'bar', 'FaceColor', remColor, 'EdgeColor', [0 0 0], 'LineWidth', 3);

%% Create fancy plot (Supplementary)
% Same model as the main figure, fitted separately to the not-remembered
% and the all-trials coupling strengths; one scatter figure each.
% NOTE(review): accsInterpol is computed here but, unlike the main figure,
% never drawn -- confirm whether omitting the fitted curve is intended.
fitfun = @(b,x) (100*x+ b(1))./(x);
beta0  = 1;

suppCouplings = {vLengthNRem, vLengthAll};
suppColors    = {[1 0.2 0.2], [0.3 0.3 0.3]};
suppTitles    = ["Not remembered", "All trials"];

for isupp = 1:numel(suppCouplings)
    mdl = fitnlm(suppCouplings{isupp}, accs_behav, fitfun, beta0);

    vInterpol    = 0:0.05:0.4;
    accsInterpol = fitfun(mdl.Coefficients.Estimate, vInterpol);

    figure;
    scatter(suppCouplings{isupp}, accs_behav, 150, suppColors{isupp}, "filled", "MarkerEdgeColor", "k", "LineWidth", 1);
    xlabel("Respiration-dvals coupling (vector length)");
    ylabel("Memory accuracy (%)");
    title(suppTitles(isupp));
end