% Bullon Tarraso et al. 2025
% Code for figure 1b (and data from Supplementary table 1)

% Folder under which entries matching each participant ID are looked up
dirRoot = 'data';

% Initialise toolbox defaults (ft_defaults must be on the MATLAB path;
% presumably FieldTrip - confirm)
ft_defaults;

% Participant identifiers included in the analysis
subjects = {'P02', 'P05', 'P07', 'P09', 'P11', 'P12', 'P13', 'P14', 'P15', ...
            'P16', 'P19', 'P21', 'P22', 'P25', 'P26', 'P29', 'P30', 'P31'};

%% Compute metrics per participant
% For every participant, load the object and scene retrieval data, split the
% trials into the first 90 (pre-sleep) and the remainder (post-sleep), and
% derive five behavioural measures per condition.
%
% Output matrices (rows = participants, columns = [pre-objects,
% post-objects, pre-scenes, post-scenes]):
%   subject_accs_outHits - exemplar-correct trials as a percentage of hits
%   subject_dprimes      - value returned by math.dprime
%   subject_hits         - percentage of old items answered correctly
%   subject_cr           - percentage of new items answered correctly
%   subject_acc          - mean ExemplarAccuracy (values < 2 only), in percent

% Preallocate the result matrices. Assigning them fresh here also guarantees
% that rows left in the workspace by an earlier run (e.g. with a longer
% subject list) cannot leak into the statistics and figure further down.
subject_accs_outHits = NaN(numel(subjects), 4);
subject_dprimes      = NaN(numel(subjects), 4);
subject_hits         = NaN(numel(subjects), 4);
subject_cr           = NaN(numel(subjects), 4);
subject_acc          = NaN(numel(subjects), 4);

for isubject=1:numel(subjects)
    
    % Load data: locate the entry for this participant ID under dirRoot
    dirSubject = dir(fullfile(dirRoot, strcat(subjects(isubject), "*")));

    % Third argument selects the image category (1 is used for objects,
    % 2 for scenes); the meaning of the trailing zeros is defined by
    % dataloader.load_segm_phase - TODO confirm
    data_phase_obj = dataloader.load_segm_phase(dirRoot, dirSubject, 1, 0, 0);
    data_phase_sce = dataloader.load_segm_phase(dirRoot, dirSubject, 2, 0, 0);
    
    % Separate pre and post
    % First objects: trials 1-90 are pre-sleep, 91-end are post-sleep
    trlinfo  = cell2mat(data_phase_obj.trialinfo);
    pre_obj = trlinfo(1:90); % 180 trials per image category
    post_obj = trlinfo(91:end);
    % Keep only trials with Answer1Accuracy <= 1; larger values mark the
    % rare invalid trials (e.g., too long response time)
    pre_obj = pre_obj([pre_obj.Answer1Accuracy] <=1);
    post_obj = post_obj([post_obj.Answer1Accuracy] <=1);
    
    % The same for scenes
    trlinfo  = cell2mat(data_phase_sce.trialinfo);
    pre_sce = trlinfo(1:90);
    post_sce = trlinfo(91:end);
    pre_sce = pre_sce([pre_sce.Answer1Accuracy] <=1);
    post_sce = post_sce([post_sce.Answer1Accuracy] <=1);
    
    % Calculate accuracies
    % In all four conditions below, OldNew == 1 is treated as an old item
    % (correct answers count as hits) and OldNew == 2 as a new item (correct
    % answers count as correct rejections).
    % NOTE(review): math.dprime receives the proportion correct on new items
    % (i.e. the correct-rejection rate) as its second argument - confirm that
    % this is what math.dprime expects rather than the false-alarm rate.

    % Objects of the pre-sleep task
    accuracy = [pre_obj.ExemplarAccuracy];   
    oldnew   = [pre_obj.OldNew];
    oldnew_acc = [pre_obj.Answer1Accuracy];
    
    pre_obj_dprime = math.dprime(mean(oldnew_acc(oldnew == 1)), ...
        mean(oldnew_acc(oldnew == 2)), sum(oldnew ==1), sum(oldnew ==2));   
    % Exemplar-correct trials relative to the number of hits
    pre_obj_acc_outHits = sum(accuracy == 1)/sum(oldnew == 1 & oldnew_acc == 1)*100;
    % Mean exemplar accuracy, ignoring trials coded 2 or above
    pre_obj_acc = mean(accuracy(accuracy < 2))*100;
    pre_obj_hits = sum(oldnew_acc(oldnew==1) == 1)/sum(oldnew==1)*100;
    pre_obj_cr = sum(oldnew_acc(oldnew==2) == 1)/sum(oldnew==2)*100;
    
    % Objects of the post-sleep task
    accuracy = [post_obj.ExemplarAccuracy];   
    oldnew   = [post_obj.OldNew];
    oldnew_acc = [post_obj.Answer1Accuracy];
    
    post_obj_dprime = math.dprime(mean(oldnew_acc(oldnew == 1)), ...
        mean(oldnew_acc(oldnew == 2)), sum(oldnew ==1), sum(oldnew ==2));
    post_obj_acc_outHits = sum(accuracy == 1)/sum(oldnew == 1 & oldnew_acc == 1)*100;
    post_obj_acc = mean(accuracy(accuracy < 2))*100;
    post_obj_hits = sum(oldnew_acc(oldnew==1) == 1)/sum(oldnew==1)*100;
    post_obj_cr = sum(oldnew_acc(oldnew==2) == 1)/sum(oldnew==2)*100;
    
    % Scenes of the pre-sleep task
    accuracy = [pre_sce.ExemplarAccuracy];   
    oldnew   = [pre_sce.OldNew];
    oldnew_acc = [pre_sce.Answer1Accuracy];
    
    pre_sce_dprime = math.dprime(mean(oldnew_acc(oldnew == 1)), ...
        mean(oldnew_acc(oldnew == 2)), sum(oldnew ==1), sum(oldnew ==2));
    pre_sce_acc_outHits = sum(accuracy == 1)/sum(oldnew == 1 & oldnew_acc == 1)*100;
    pre_sce_acc = mean(accuracy(accuracy < 2))*100;
    pre_sce_hits = sum(oldnew_acc(oldnew==1) == 1)/sum(oldnew==1)*100;
    pre_sce_cr = sum(oldnew_acc(oldnew==2) == 1)/sum(oldnew==2)*100;
    
    % Scenes of the post-sleep task
    accuracy = [post_sce.ExemplarAccuracy];   
    oldnew   = [post_sce.OldNew];
    oldnew_acc = [post_sce.Answer1Accuracy];
    
    post_sce_dprime = math.dprime(mean(oldnew_acc(oldnew == 1)), ...
        mean(oldnew_acc(oldnew == 2)), sum(oldnew ==1), sum(oldnew ==2));
    post_sce_acc_outHits = sum(accuracy == 1)/sum(oldnew == 1 & oldnew_acc == 1)*100;
    post_sce_acc = mean(accuracy(accuracy < 2))*100;
    post_sce_hits = sum(oldnew_acc(oldnew==1) == 1)/sum(oldnew==1)*100;
    post_sce_cr = sum(oldnew_acc(oldnew==2) == 1)/sum(oldnew==2)*100;
    
    % Save for later: one row per participant, columns ordered
    % [pre-objects, post-objects, pre-scenes, post-scenes]
    subject_accs_outHits(isubject,:) = [pre_obj_acc_outHits, post_obj_acc_outHits, pre_sce_acc_outHits, post_sce_acc_outHits];
    subject_dprimes(isubject,:) = [pre_obj_dprime, post_obj_dprime, pre_sce_dprime, post_sce_dprime];
    subject_hits(isubject,:) = [pre_obj_hits, post_obj_hits, pre_sce_hits, post_sce_hits];
    subject_cr(isubject,:) = [pre_obj_cr, post_obj_cr, pre_sce_cr, post_sce_cr];
    subject_acc(isubject,:) = [pre_obj_acc, post_obj_acc, pre_sce_acc, post_sce_acc];
end

%% Run RM-ANOVA for the behavioral analyses
% Part 1: associative memory (exemplar accuracy out of hits)

% Wrap the participant-by-condition matrix in a table, one variable per
% within-subject cell
accs = array2table(subject_accs_outHits, ...
    'VariableNames', ["acc1", "acc2", "acc3", "acc4"]);

% Within-subject design, one row per table variable:
% Object/Pre, Object/Post, Scene/Pre, Scene/Post
session   = repelem(["Object"; "Scene"], 2);
retrieval = repmat(["Pre"; "Post"], 2, 1);
w2design  = table(session, retrieval);

% Fit the repeated-measures model (intercept only between subjects) and test
% both within-subject factors plus their interaction
rm = fitrm(accs, "acc1-acc4~1", 'WithinDesign', w2design);
ranovatbl = ranova(rm, "WithinModel", "session*retrieval");
disp(ranovatbl)

% Part 2: old/new recognition (d-prime values)

% Wrap the participant-by-condition d-prime matrix in a table, one variable
% per within-subject cell
dprime_table = array2table(subject_dprimes, ...
    'VariableNames', ["dprime1", "dprime2", "dprime3", "dprime4"]);

% Within-subject design, one row per table variable:
% Object/Pre, Object/Post, Scene/Pre, Scene/Post
session   = repelem(["Object"; "Scene"], 2);
retrieval = repmat(["Pre"; "Post"], 2, 1);
w2design  = table(session, retrieval);

% Fit the repeated-measures model and test both within-subject factors plus
% their interaction
rm = fitrm(dprime_table, "dprime1-dprime4~1", 'WithinDesign', w2design);
ranovatbl_oldnew = ranova(rm, "WithinModel", "session*retrieval");
disp(ranovatbl_oldnew)

% Group means of the associative-memory scores arranged as a 2x2 matrix:
% rows = category (objects, scenes), columns = retrieval (pre, post)
mean_accs = reshape(mean(subject_accs_outHits), 2, 2).';

%% Plot result in bars (this creates figure 1b in the paper) (NEW)
% Collapse pre- and post-sleep retrieval into one score per category:
% column 1 = objects, column 2 = scenes, one row per participant.
% Built by direct assignment so no stale workspace content can survive.
mem = [(subject_accs_outHits(:,1) + subject_accs_outHits(:,2))/2, ...
       (subject_accs_outHits(:,3) + subject_accs_outHits(:,4))/2];

% Bar face colors: light grey for the first bar, dark grey for the second
colors = [0.9 0.9 0.9; ...
          0.4 0.4 0.4];

fig = figure;
hold on
% Draw the bars one at a time so that each can get its own face color
for b = 1:size(mem, 2)
    bb = bar(b, mean(mem(:,b)), 'FaceColor',  colors(b, : ), 'EdgeColor', 'k', 'BarWidth', 0.7,'LineWidth',2);
end

% Error bars show the standard error of the mean across participants
% (std / sqrt(number of participants)); dimensions are given explicitly so
% the result does not depend on the shape of mem
sem = std(mem, 0, 1)/sqrt(size(mem, 1));
h = ploterr(1:size(mem, 2), mean(mem, 1), [], sem, 'k.');
set(h(1), 'marker', 'none'); % remove marker
set(h,'Linewidth',1.5)
set(gca, 'Ylim', [0 100],'XTickLabel',[],'xtick',[])
set(gcf, 'Position', [360,278,400,420])

yticks([0 20 40 60 80 100 120])
yticklabels({'0' '20' '40' '60'  '80'  '100' '120'})
ax = gca;
ax.FontSize = 14;

% A single tick centred between the two bars carries both category labels
xticks([1.5])
xticklabels({' Objects    Scenes'})
b = get(gca,'XTickLabel');  
set(gca,'XTickLabel',b,'fontsize',18)
xlabel("Category");

% Thicken axis lines and tick marks (uses undocumented ruler properties)
ax = gca;
ax.XRuler.Axle.LineWidth = 2;
ax.YRuler.Axle.LineWidth = 2;
ax.XAxis.MajorTickChild.LineWidth = 2;
ax.YAxis.MajorTickChild.LineWidth = 2;

% Semicolon added so the label handle is not echoed to the command window
y = ylabel('Memory performance (%)', 'FontSize', 20);
set(y, 'Units', 'Normalized', 'Position', [-0.18, 0.5, 0]);

%%%%%%%%%%
% Overlay the single-participant scores as open black circles on the bars.
% The participant count is taken from the rows of mem instead of being
% hard-coded, so the overlay stays in sync with the subject list.
subjects = 1:size(mem, 1);
color   = cbrewer2('seq', 'Greys',50); % NOTE(review): not used in this section - confirm whether it can be removed
jitter = linspace(-0.1,0.1,48); % candidate x offsets for the single subject dots

hold on

for iparticipant = 1:numel(subjects)
    % One random offset per participant, shared by both of their dots.
    % The draw is not seeded here, so dot positions differ between runs
    % unless the random number generator is seeded beforehand.
    b = randsample(jitter,1);
    plot([1+b 2+b],mem(iparticipant,[1,2]), 'ok', 'MarkerSize', 6)
end