% Bullon Tarraso et al. 2025
% Code for figure 1c

ft_defaults
dirRoot = 'data';

subjects    = {'P02' 'P05' 'P07' 'P09' 'P11' 'P12' 'P13' 'P14' 'P15' 'P16' 'P19' 'P21' 'P22' 'P25' 'P26' 'P29' 'P30' 'P31'};

% Parameters of the script
getnull = 1; % Whether to create a label-shuffled null distribution (if 0, the null is a constant chance level of 0.5)
numPerms = 250; % Number of label permutations used for the null distribution

%% Prepare classification 
% Time of interest for the classification
TOI_loc                   = [-.2 3]; % In seconds, relative to the cue

classi_timeaxis_train     = TOI_loc(1):0.01:TOI_loc(2); % Time vector with 100 Hz fsample

% Per-subject results. Both containers are preallocated so that a re-run
% in a used workspace never keeps stale entries from a previous run.
acc       = cell(1,numel(subjects)); % observed classifier performance over time
acc_shuff = cell(1,numel(subjects)); % null performance (time x permutation)

%% Run
% For every subject: load both localizer data sets, preprocess them,
% reduce dimensionality with PCA, classify stimulus type across time and
% (optionally) build a label-shuffled null distribution.

for isubject=1:numel(subjects)
    
    dirSubject = dir(fullfile(dirRoot, strcat(subjects(isubject), "*")));  

    % We want cue-locked data, so we can use the already split version
    data_loc_obj = dataloader.load_loc_data(dirRoot, dirSubject, 1, [], 0);
    data_loc_sce = dataloader.load_loc_data(dirRoot, dirSubject, 2, [], 0);
 
    %% Some preprocessing   

    % Append data from both sessions; trialinfo is concatenated by hand and
    % re-attached after appending
    trialinfo_all_l        = [data_loc_obj.trialinfo; data_loc_sce.trialinfo];
    cfg                    = [];
    data_all_loc           = ft_appenddata(cfg, data_loc_obj, data_loc_sce);
    data_all_loc.fsample   = 200;
    data_all_loc.trialinfo = trialinfo_all_l;      
    
    % Respiration is not needed here
    cfg = [];
    cfg.channel = {'all', '-resp'};
    data_all_loc = ft_selectdata(cfg, data_all_loc);
    
    % Re-reference data to the average of all channels
    cfg            = [];
    cfg.reref      = 'yes';
    cfg.refmethod  = 'avg';
    cfg.refchannel = {'all'};
    cfg.removemean = 'no';
    data_all_loc   = ft_preprocessing(cfg, data_all_loc);

    cfg            = [];
    cfg.keeptrials = 'yes';
    cfg.removemean = 'no';        
    data_all_loc   = ft_timelockanalysis(cfg, data_all_loc);

    % Reduce to TOI
    % NOTE(review): cfg.latency is a documented ft_selectdata option; confirm
    % that ft_preprocessing really crops to TOI_loc in the FieldTrip version
    % used, otherwise the time axis assumed for the stats may not match.
    cfg = [];
    cfg.latency = TOI_loc;
    data_all_loc = ft_preprocessing(cfg, data_all_loc);
    
    % Resample data to dvals fsample (100 Hz)
    cfg = [];
    cfg.resamplefs = 100;
    cfg.method = 'downsample';
    data_all_loc = ft_resampledata(cfg, data_all_loc);
    
    % z-score
    dat_all_loc = data_all_loc.trial;
    
    preprocess_param   = mv_get_preprocess_param('zscore');
    [pparam, dat_all_loc ]  = mv_preprocess_zscore(preprocess_param, dat_all_loc);        
    data_all_loc.trial = dat_all_loc;

    % Single-input ft_appenddata call. The cfg is reset first so that the
    % resampling options from the previous step are not passed along.
    % NOTE(review): presumably this call only serves to bring the data back
    % into the cell-of-trials form indexed below - confirm.
    cfg = [];
    data_all_loc = ft_appenddata(cfg, data_all_loc);

    % Run PCA on the appended data and reduce rank
    cfg = [];
    cfg.step = "calculate";   
    cfg.centered = 1;
    comp = math.my_PCA(cfg, data_all_loc);
    
    cfg = [];
    cfg.step = "transform";   
    cfg.centered = 1;
    cfg.eigVects = comp.eigVects;
    cfg.chosen = comp.eigValsCum < 95; % set desired variance to keep
    cfg.updateSens = 0;
    
    comp_loc = math.my_PCA(cfg, data_all_loc);
    data_all_loc.trial = comp_loc.trial;
  
    % Smooth the data with a 200 ms moving-average window along dimension 2.
    % The window length in samples is 0.2 s * fsample.
    % NOTE(review): this relies on data_all_loc.fsample being 100 after the
    % resampling above - confirm.
    for itrial = 1:numel(data_all_loc.trial)
        data_all_loc.trial{itrial} = smoothdata(data_all_loc.trial{itrial},2,'movmean',0.2/(1/data_all_loc.fsample));
    end   
    
    %% Bring data into better format and perform classification
    cfg             = [];
    cfg.keeptrials  = 'yes';
    cfg.removemean  = 'no'; 
    tmp   = ft_timelockanalysis(cfg,data_all_loc);
    dat_loc         = tmp.trial;

    % Prepare the mask for both classes
    trlinfo  = cell2mat(data_all_loc.trialinfo);   
    category = [trlinfo.Stimtype];           

    train_sel1 = ismember(category,1);
    train_sel2 = ismember(category,2);   

    % Use the mask to get data and training labels (class 1 trials first,
    % then class 2 trials)
    train_dat        = cat(1,dat_loc(train_sel1,:,:),dat_loc(train_sel2,:,:));
    classcode_train  = cat(1,1*ones(sum(train_sel1),1),2*ones(sum(train_sel2),1));   

    test_train_dat         = train_dat;%(:,:,timeaxis_samples_train);
    
    % Here we perform the classification with cross-validation, repeated
    % 5 times. Metric is AUC and model is LDA.
    % NOTE(review): the number of folds is not set here, so it comes from
    % the MVPA-Light defaults - presumably 5-fold; confirm.
    cfg             = [];
    cfg.classifier  = 'lda';
    cfg.metric      = 'auc';
    cfg.repeat = 5;
    
    [perfmat, result, testlabels]  = mv_classify_across_time(cfg, test_train_dat, classcode_train);

    shuff_aucs = ones(length(perfmat), numPerms)*0.5; % Chance in case getnull = false
    if getnull
        parfor i=1:numPerms
            classcode_shuff = Shuffle(classcode_train); % Shuffle labels for null distribution
            % Second output is discarded so the parfor body does not reuse
            % the script-level variable 'result' as a temporary
            [perfmat_shuff, ~]  = mv_classify_across_time(cfg, test_train_dat, classcode_shuff);
            shuff_aucs(:,i) = perfmat_shuff;
        end
    end

    acc{isubject}   = perfmat;
    acc_shuff{isubject} = shuff_aucs;
    
end

%% Prepare data for stats
% Stack the per-subject results into subject x time matrices and wrap
% them in two structures with a 'subj_chan_time' layout (single dummy
% channel) for the statistics.

% Observed performance: one row per subject
accmat = cell2mat(cellfun(@(perf) reshape(perf, 1, []), ...
                          reshape(acc, [], 1), 'UniformOutput', false));

% Null performance: average over permutations (dimension 2), one row per subject
accmat_shuff = cell2mat(cellfun(@(nullPerf) reshape(mean(nullPerf, 2), 1, []), ...
                                reshape(acc_shuff(1:numel(acc)), [], 1), 'UniformOutput', false));

cstat = cell(1, 2);

% Condition 1: observed classifier performance
cstat{1} = struct( ...
    'label',      {{'Channels'}}, ...
    'time',       classi_timeaxis_train, ...
    'individual', permute(accmat, [1 3 2]), ...        % subj x 1 x time
    'dimord',     'subj_chan_time');

% Condition 2: label-shuffled null performance
cstat{2} = struct( ...
    'label',      {{'Channels'}}, ...
    'time',       classi_timeaxis_train, ...
    'individual', permute(accmat_shuff, [1 3 2]), ...  % subj x 1 x time
    'dimord',     'subj_chan_time');

%% Run stats
% Cluster-based permutation test over time: dependent-samples t-test of
% the first condition in cstat against the second.
nSubj = numel(subjects);

% Design matrix: row 1 = subject index (unit of observation),
%                row 2 = condition index (1 for cstat{1}, 2 for cstat{2})
design = [repmat(1:nSubj, 1, 2); ...
          repelem([1 2], nSubj)];

cfg                     = [];
cfg.spmversion          = 'spm12';

% Test statistic and Monte Carlo settings
cfg.method              = 'montecarlo';
cfg.statistic           = 'depsamplesT';
cfg.numrandomization    = 1000;
cfg.computecritval      = 'yes';
cfg.parameter           = 'individual';

% Two-sided test.
% NOTE(review): no cfg.correcttail is set, so alpha is not split over the
% two tails - confirm this is intended.
cfg.tail                = 0;
cfg.alpha               = .05;

% Cluster correction (no spatial neighbours are defined)
cfg.correctm            = 'cluster';
cfg.clusteralpha        = .05;
cfg.clusterstatistic    = 'maxsum';
cfg.clustertail         = cfg.tail;
cfg.neighbours          = [];
cfg.minnbchan           = 0;

% Within-subject design
cfg.design  = design;
cfg.uvar    = 1;
cfg.ivar    = 2;

% run stats
stat = ft_timelockstatistics(cfg, cstat{1}, cstat{2});

% Display the number of significant time points
nnz(stat.mask)

%% Plot figure 1c
% Group-average observed and null AUC over time (mean +/- SEM across
% subjects), with a horizontal bar marking significant time points.
grey       = [0.6 0.6 0.6];

% SEM across subjects; the dimension is given explicitly so the result
% stays per time point even if only one subject is present
accmat_SEM = std(accmat, 0, 1)/sqrt(size(accmat,1));
shuff_SEM  = std(accmat_shuff, 0, 1)/sqrt(size(accmat_shuff,1));

figure()

% Observed performance (solid black)
[hl, ~] = boundedline(stat.time, mean(accmat, 1), accmat_SEM, 'alpha','transparency', 0.15, 'cmap','k');
set(hl, 'linewidth', 3);

hold on 

% Null performance (dashed grey)
[hl, ~] = boundedline(stat.time, mean(accmat_shuff, 1), shuff_SEM, '--', 'cmap',grey, 'alpha','transparency', 0.15);
set(hl, 'linewidth', 3);

% Significance bar: NaN everywhere except at significant time points, so
% only those samples are drawn (at AUC = 0.49, just below chance)
sigMask = nan(length(stat.mask),1);
sigMask(stat.mask) = 0.49;
plot(stat.time, sigMask, '-k', 'LineWidth', 3)

ylabel("Classification (AUC)");
xlabel("Time localizer (s)");
xline(0, 'k--', 'linewidth',2); % cue onset

xlim([-0.2, 3]);
ylim([0.48, 0.7]);
set(findall(gcf,'-property','FontSize'),'FontSize',16)
