% Bullon Tarraso et al. 2025
% Code for figure 2b

% ---------------------------------------------------------------------
% Configuration
% ---------------------------------------------------------------------
ft_defaults;        % initialise FieldTrip before any ft_* call

% Root folder that is searched for per-subject entries named '<subject>*'
dirRoot = 'data';

% Participants included in the analysis.
% P01 and P18 are left out because of their low hit number.
subjects = { ...
    'P02' 'P05' 'P07' 'P09' 'P11' 'P12' 'P13' 'P14' 'P15' ...
    'P16' 'P19' 'P21' 'P22' 'P25' 'P26' 'P29' 'P30' 'P31'};

% Loader options (forwarded to dataloader.load_segm_phase)
use_hilbert      = false;
remove_artifacts = false;

% Trial window kept for the analysis, on the data's time axis
% (presumably seconds relative to the trial event -- confirm)
TOI_trial = [-0.5, 3];
%% Run

% Per-subject exemplar accuracy (%) split by the phase events found in each
% trial. Preallocated so that a rerun in a non-empty workspace cannot keep
% stale trailing entries from an earlier, longer run.
n_subjects              = numel(subjects);
image_acc_exht          = nan(1, n_subjects);
image_acc_not_exht      = nan(1, n_subjects);
image_acc_inhp          = nan(1, n_subjects);
image_acc_not_inhp      = nan(1, n_subjects);
image_acc_inhp2exht     = nan(1, n_subjects);
image_acc_not_inhp2exht = nan(1, n_subjects);
image_acc_exht2inhp     = nan(1, n_subjects);
image_acc_not_exht2inhp = nan(1, n_subjects);

for isubject = 1:n_subjects

    subject    = subjects{isubject};
    dirSubject = dir(fullfile(dirRoot, [subject '*']));
    if isempty(dirSubject)
        error('No entry matching %s* found in %s', subject, dirRoot);
    end

    % Load the segmented phase data of both conditions
    % (the third argument, 1 or 2, selects object or scene data respectively).
    data_phase_obj = dataloader.load_segm_phase(dirRoot, dirSubject, 1, use_hilbert, remove_artifacts);
    data_phase_sce = dataloader.load_segm_phase(dirRoot, dirSubject, 2, use_hilbert, remove_artifacts);

    % Append object and scene data; trialinfo is concatenated by hand and
    % re-attached after ft_appenddata.
    trialinfo_all      = [data_phase_obj.trialinfo; data_phase_sce.trialinfo];
    cfg                = [];
    data_all           = ft_appenddata(cfg, data_phase_obj, data_phase_sce);
    data_all.fsample   = 200;
    data_all.trialinfo = trialinfo_all;

    % Reduce to TOI. cfg.latency is an ft_selectdata option; it is used here
    % directly so the crop is guaranteed to be applied.
    cfg         = [];
    cfg.latency = TOI_trial;
    data_all    = ft_selectdata(cfg, data_all);

    % Behavioural labels, one value per trial
    trlinfo    = cell2mat(data_all.trialinfo);
    accuracy   = [trlinfo.ExemplarAccuracy];
    oldnew     = [trlinfo.OldNew];
    oldnew_acc = [trlinfo.Answer1Accuracy];

    % Percentage of trials with ExemplarAccuracy == 1 relative to the old items
    % (OldNew == 1) answered correctly (Answer1Accuracy == 1), both restricted
    % to the trials selected by the logical MASK. An empty denominator yields
    % NaN (0/0) or Inf.
    % NOTE(review): the numerator is not restricted to old/correct trials;
    % this assumes ExemplarAccuracy is only 1 on such trials -- confirm.
    hit_pct = @(mask) sum((accuracy == 1) & mask) ...
                    / sum((oldnew == 1) & (oldnew_acc == 1) & mask) * 100;

    % EXHT trials: somewhere the signal drops by at least pi between two
    % consecutive samples (presumably the wrap of a phase in radians).
    exht_mask = cellfun(@(x) any(diff(x) <= -pi), data_all.trial);
    image_acc_exht(isubject)     = hit_pct(exht_mask);
    image_acc_not_exht(isubject) = hit_pct(~exht_mask);

    % INHP trials: somewhere the phase comes within 0.05 of 0.
    inhp_mask = cellfun(@(x) any(abs(x) <= 0.05), data_all.trial);
    image_acc_inhp(isubject)     = hit_pct(inhp_mask);
    image_acc_not_inhp(isubject) = hit_pct(~inhp_mask);

    % --- INHP->EXHT versus EXHT->INHP ordering ---------------------------
    % Sample index of each event: for EXHT the diff index of the drop, for
    % INHP the sample(s) where |phase| is minimal.
    exht_tp = cellfun(@(x) find(diff(x) <= -pi), data_all.trial, 'UniformOutput', false);
    inhp_tp = cellfun(@(x) find(abs(x) == min(abs(x))), data_all.trial, 'UniformOutput', false);

    % Trials without the event get NaN ...
    exht_tp(~exht_mask) = {nan};
    inhp_tp(~inhp_mask) = {nan};

    % ... and so do trials with more than one candidate sample
    exht_tp(cellfun(@numel, exht_tp) > 1) = {nan};
    inhp_tp(cellfun(@numel, inhp_tp) > 1) = {nan};

    % Every cell now holds one scalar -> numeric row vector
    exht_tp = cell2mat(exht_tp);
    inhp_tp = cell2mat(inhp_tp);

    % Comparisons with NaN are false, so trials lacking a usable event of
    % either kind belong to neither ordering (but to both complements).
    diff_tp   = exht_tp - inhp_tp;
    inhp2exht = diff_tp > 0;   % INHP occurs before EXHT
    exht2inhp = diff_tp < 0;   % EXHT occurs before INHP

    image_acc_inhp2exht(isubject)     = hit_pct(inhp2exht);
    image_acc_not_inhp2exht(isubject) = hit_pct(~inhp2exht);
    image_acc_exht2inhp(isubject)     = hit_pct(exht2inhp);
    image_acc_not_exht2inhp(isubject) = hit_pct(~exht2inhp);
end

%% Plot and ttest (Fig2b, INHP2EXHT and EXHT2INHP sequences)
% Violin plot of the per-subject accuracies for the two event orderings,
% then a paired t-test between them.

% Colour scheme: both violins use the same red/blue mix
red    = [255, 17, 0] / 255;
blue   = [0 0.4470 0.7410];
purple = blue + 0.5 * red;

figure;
plt = daviolinplot( ...
    [image_acc_inhp2exht(:), image_acc_exht2inhp(:)], ...   % one column per ordering
    'proximity', 0.7, ...
    'colors', purple);
ylabel("Accuracy (%)");
xticklabels(["INHP -> EXHT", "EXHT -> INHP"]);

% Paired test across subjects; the semicolon is deliberately omitted so the
% results are printed to the Command Window.
[h, p, ci, stats] = ttest(image_acc_inhp2exht, image_acc_exht2inhp)
