% Bullon Tarraso et al. 2025
% Code for figure 2a

% --- Environment -----------------------------------------------------------
dirRoot = 'data';   % root folder searched for the per-subject directories
ft_defaults;        % FieldTrip path/default initialisation

% --- Participants ----------------------------------------------------------
% P01 and P18 are left out of the list because of their low number of hits.
subjects = {'P02', 'P05', 'P07', 'P09', 'P11', 'P12', ...
            'P13', 'P14', 'P15', 'P16', 'P19', 'P21', ...
            'P22', 'P25', 'P26', 'P29', 'P30', 'P31'};

% --- Analysis parameters ---------------------------------------------------
TOI   = [-0.5, 3];  % time of interest passed as cfg.latency below
tStep = 0.25;       % width of one analysis window, in seconds

% Flags forwarded unchanged to dataloader.load_segm_phase.
% NOTE(review): the name `hilbert` shadows the toolbox function of the same
% name for the rest of the script; it is kept because later sections read it.
removeArtifacts = false;
hilbert         = false;

%% Prepare subjectsTable
% Builds one row per subject x time window x respiratory phase with the
% columns listed in tableVars. Rows are written into a preallocated numeric
% matrix per subject and converted to a table once at the end, instead of
% growing a table inside the nested loops.
nTrialsTable = table();             % created here but not filled in this section
hists = cell(length(subjects), 1);  % created here but not filled in this section

tableVars = {'SubjectNr', 'Inspiration', 'TOI', 'Proportion',...
    'OldNew_acc', 'Image_acc', 'Mean_RT'};
subjectRows = cell(numel(subjects), 1);

for isubject=1:numel(subjects)

    dirSubject = dir(fullfile(dirRoot, strcat(subjects{isubject}, "*")));

    data_obj = dataloader.load_segm_phase(dirRoot, dirSubject, 1, hilbert, removeArtifacts);
    data_sce = dataloader.load_segm_phase(dirRoot, dirSubject, 2, hilbert, removeArtifacts);

    % append object and scene data
    trialinfo_all          = [data_obj.trialinfo; data_sce.trialinfo];
    cfg                    = [];
    data_all               = ft_appenddata(cfg, data_obj, data_sce);
    data_all.fsample       = 200;
    data_all.trialinfo     = trialinfo_all;

    % Reduce to TOI
    % NOTE(review): latency selection is documented for ft_selectdata; confirm
    % that the FieldTrip version in use applies cfg.latency in ft_preprocessing.
    cfg = [];
    cfg.latency = TOI;
    data_all = ft_preprocessing(cfg, data_all);

    % Prepare trial info
    trlinfo  = cell2mat(data_all.trialinfo);

    accuracy   = [trlinfo.ExemplarAccuracy];
    oldnew     = [trlinfo.OldNew];
    oldnew_acc = [trlinfo.Answer1Accuracy];
    rt1        = [trlinfo.RT1];
    subjectNr  = trlinfo(1).SubjectNumber;

    % Subject-level baseline: exemplar hits as a percentage of the trials
    % with oldnew == 1 and oldnew_acc == 1.
    subject_acc = sum(accuracy == 1)/sum((oldnew == 1).*(oldnew_acc == 1))*100;

    % Samples per window, rounded once so that the index ranges below are
    % exact integers even when tStep*fsample is not exactly representable.
    winSamples = round(data_all.fsample*tStep);
    numEpochs  = floor(size(data_all.time{1,1},2)/winSamples);

    rows = nan(2*numEpochs, numel(tableVars));
    for tWindow = 1:numEpochs

        % Get inspiration/expiration retrieval mask: a trial counts as
        % inspiration when the absolute circular mean of its samples in this
        % window is <= pi/2.
        winIdx = (tWindow-1)*winSamples+1 : tWindow*winSamples;
        ins = cellfun(@(x) abs(circ_mean(x(winIdx), [], 2))<=pi/2, data_all.trial);

        % Inspiration row first, expiration row second.
        phaseMasks = {ins, not(ins)};
        for iPhase = 1:2
            phaseMask     = phaseMasks{iPhase};
            isInspiration = (iPhase == 1);

            % Extract accuracy results (image accuracy is expressed relative
            % to the subject-level baseline computed above).
            oldnew_acc_phase = mean(oldnew_acc(phaseMask & (oldnew_acc <= 1)));
            image_acc_phase  = sum((accuracy == 1).*phaseMask)/...
                sum((oldnew == 1).*(oldnew_acc == 1).*phaseMask)*100 - subject_acc;

            % Extract response time results
            % NOTE(review): this is the root-mean-square of the outlier-cleaned
            % RTs although the column is called Mean_RT - confirm the intent.
            rt_phase      = rmoutliers(rt1(phaseMask), "median");
            mean_RT_phase = sqrt(mean(rt_phase.^2));

            % Stats: proportion of trials falling into this phase
            trials_prop = mean(phaseMask);

            rows(2*(tWindow-1)+iPhase, :) = [subjectNr, isInspiration, tWindow,...
                trials_prop, oldnew_acc_phase, image_acc_phase, mean_RT_phase];
        end
    end
    subjectRows{isubject} = rows;
end
subjectsTable = array2table(vertcat(subjectRows{:}), 'VariableNames', tableVars);

% Human-readable phase label derived from the 0/1 Inspiration column.
retPhase = {'Expiration', 'Inspiration'};
retPhase = retPhase(subjectsTable.Inspiration+1);
subjectsTable.RetPhase = retPhase';

%% Get significant tpoints

% extract relevant data from table
% For every subject in pp, collect the Image_acc time course of the
% inspiration and expiration rows and smooth it with a 3-point moving mean.

pp = [2 5 7 9 11 12 13 14 15 16 19 21 22 25 26 29 30 31];

new_tab = sortrows(subjectsTable,'RetPhase');

% Number of time windows, taken from the window index stored in column TOI.
nWin = max(subjectsTable.TOI);

data_inh_all = nan(size(pp,2), nWin);
data_exp_all = nan(size(pp,2), nWin);
for isub = 1:size(pp,2)
    p1      = new_tab(new_tab.SubjectNr ==pp(isub),:);
    p1_exp  = p1(p1.Inspiration ==0,:);
    p1_inh  = p1(p1.Inspiration ==1,:);

    % Fail early with a readable message instead of silently storing an
    % empty or mis-sized row that only breaks later inside the statistics.
    if height(p1_inh) ~= nWin || height(p1_exp) ~= nWin
        error('fig2a:subjectRows', ...
            'Subject %d has %d inspiration and %d expiration rows, expected %d of each.', ...
            pp(isub), height(p1_inh), height(p1_exp), nWin);
    end

    dat_inh = [p1_inh.Image_acc]';
    dat_exp = [p1_exp.Image_acc]';

    data_inh_all(isub,:) = movmean(dat_inh,3);
    data_exp_all(isub,:) = movmean(dat_exp,3);
end

% bring into FT structure
% Wrap each subject's smoothed time course in a minimal single-channel
% structure (time / dimord / label / avg) for ft_timelockstatistics.

% Window start times covering TOI in steps of tStep. Built from an integer
% count so the length does not depend on floating-point colon endpoints.
nTimeBins = round((TOI(end) - TOI(1))/tStep);
if size(data_inh_all,2) ~= nTimeBins || size(data_exp_all,2) ~= nTimeBins
    error('fig2a:timeAxis', ...
        'Data have %d/%d time bins but TOI and tStep define %d.', ...
        size(data_inh_all,2), size(data_exp_all,2), nTimeBins);
end

% Start from clean variables so nothing left in the workspace by an earlier
% run can leak into the structures handed to FieldTrip.
temp        = struct();
temp.time   = TOI(1) + (0:nTimeBins-1)*tStep;
temp.dimord = 'chan_time';
temp.label  = {'resp'};

data_inh = cell(1, size(pp,2));
data_exp = cell(1, size(pp,2));
for isub = 1:size(pp,2)
    temp.avg = data_inh_all(isub,:);
    data_inh{isub} = temp;

    temp.avg = data_exp_all(isub,:);
    data_exp{isub} = temp;
end

% prepare stats
% Configuration for a two-sided, within-subject (dependent-samples T)
% Monte Carlo test with cluster correction.
cfg                     = [];
cfg.latency             = TOI;   % same window the data were cut to
cfg.spmversion          = 'spm12';
cfg.minnbchan           = [];
cfg.channel             = 'all';
cfg.statistic           = 'depsamplesT';
cfg.method              = 'montecarlo';
cfg.correctm            = 'cluster';
cfg.clusteralpha        = .05;
cfg.tail                = 0;
cfg.computecritval      = 'yes';
cfg.numrandomization    = 1000;

cfg.clusterstatistic    = 'maxsize';
cfg.clustertail         = cfg.tail;
cfg.parameter           = 'avg';

% Design matrix: row 1 is the subject index (unit of observation), row 2 the
% condition (1 = first nSub inputs, 2 = last nSub inputs).
nSub   = size (pp,2);
design = [1:nSub,        1:nSub; ...
          ones(1,nSub),  2*ones(1,nSub)];

cfg.design  = design;
cfg.uvar    = 1;
cfg.ivar    = 2;

% run stats
% Inspiration structures are passed first (condition 1), expiration second.
[Fieldtripstats] = ft_timelockstatistics(cfg, data_inh{:}, data_exp{:});

% Number of time bins flagged in the mask
disp(nnz(Fieldtripstats.mask))

% Report the first cluster's probability for each tail. A tail may have no
% cluster at all, in which case indexing element 1 would abort the script
% before the figure is drawn, so check for that explicitly.
disp("Best pos cluster:")
if isfield(Fieldtripstats, 'posclusters') && ~isempty(Fieldtripstats.posclusters)
    disp(Fieldtripstats.posclusters(1).prob)
else
    disp("none found")
end

disp("Best neg cluster:")
if isfield(Fieldtripstats, 'negclusters') && ~isempty(Fieldtripstats.negclusters)
    disp(Fieldtripstats.negclusters(1).prob)
else
    disp("none found")
end

%% plot
% Group mean +/- SEM of both conditions over time, plus a horizontal bar at
% y = 0 marking the time bins flagged in the statistics mask.

% Shift by half a step so each value is drawn at the centre of its window.
tVector = data_exp{1, 1}.time + tStep/2;

red   = [1 0 0];   % numeric triplet for the pure red previously requested as 'r'
blue  = [0 0.4470 0.7410];

% Reduce explicitly over subjects (dimension 1) so the result stays a
% per-time-bin vector even if only one subject is present.
nSubjInh = size(data_inh_all,1);
mean_inh = mean(data_inh_all, 1);
SEM_inh  = std(data_inh_all, 0, 1)/sqrt(nSubjInh);

nSubjExp = size(data_exp_all,1);
mean_exp = mean(data_exp_all, 1);
SEM_exp  = std(data_exp_all, 0, 1)/sqrt(nSubjExp);

figure;
% plot ERP
[hl_inh, hp_inh] = boundedline(tVector, mean_inh, SEM_inh, 'cmap',blue, 'alpha','transparency', 0.15);
set(hl_inh, 'linewidth', 3);

hold on
[hl_exp, hp_exp] = boundedline(tVector, mean_exp, SEM_exp, 'cmap',red, 'alpha','transparency', 0.15);
set(hl_exp, 'linewidth', 3);

% Extend every flagged bin by one bin to the right, then draw the flagged
% stretch at y = 0 (NaN elsewhere so nothing is drawn there).
extendMask = movmax(Fieldtripstats.mask, [1,0]);
sigline    = nan(1,numel(extendMask));
sigline(extendMask==1) = 0;

% ylim([60, 74]);
hold on
h_sig = plot(Fieldtripstats.time + tStep/2,sigline,'color','k','linewidth',3);

%ylim([50, 100])
xline(0, 'k--');
xlabel("Time retrieval (s)");
ylabel("Memory accuracy difference(%)");
% Pass the line handles explicitly: without them the labels are assigned to
% the first objects created in the axes, which can include the shaded
% patches drawn by boundedline.
legend([hl_inh(1), hl_exp(1), h_sig], "INHP", "EXHT", "p < 0.05");
