% Bullon Tarraso et al. 2025
% Code for figure 1e

ft_defaults
dirRoot = 'data';

% Participants included in the analysis
subjects    = {'P02' 'P05' 'P07' 'P09' 'P11' 'P12' 'P13' 'P14' 'P15' 'P16' 'P19' 'P21' 'P22' 'P25' 'P26' 'P29' 'P30' 'P31'};

% Neighbour and layout files of the 64-channel cap. fullfile inserts the
% file separator of the current platform, so this works on Windows and Unix.
load(fullfile('auxiliary', 'Bham-64CH-Neighbours.mat'))
load(fullfile('auxiliary', 'Bham-64CH-Lay.mat'))   % expected to provide 'lay' (used below)


%% Prepare classification
% Time windows (s) of the retrieval trials (test) and localizer trials (train)
TOI_trial                 = [-.15 1.5];
TOI_loc                   = [-.15 1.5];

% Time axes of the time x time generalisation matrix. The step (s) should
% equal the sampling interval of the data after resampling (100 Hz -> 0.01 s).
step = 0.01;
classi_timeaxis_train     = TOI_loc(1):step:TOI_loc(2);
classi_timeaxis_test      = TOI_trial(1):step:TOI_trial(2);


% One result array per subject (remembered trials)
acc = cell(1,numel(subjects));

%--- prepare searchlight neighbours (distance-based, from the layout)
cfg                 = [];
cfg.method          = 'distance';
cfg.layout          = lay;
cfg.neighbourdist   = .15;
cfg.channel         = {'eeg','-M*', '-E*','-Resp'};
SL_neighbours       = ft_prepare_neighbours(cfg);

%% Run classification
% For every subject: load object and scene retrieval data plus the matching
% localizer data, preprocess both identically, then train an LDA classifier
% (class 1 = object, class 2 = scene) on the localizer and test it on the
% retrieval trials for every localizer x retrieval time pair. This is done
% separately within a searchlight around each channel.
% Outputs per subject, each an AUC array of channel x train time x test time:
%   acc{isubject}      - remembered trials     (ExemplarAccuracy == 1)
%   acc_nrem{isubject} - not-remembered trials (ExemplarAccuracy == 0)

% Optional re-mapping of the retrieval time axis to phase of the respiration
% signal (data_resp). A value already present in the workspace is kept;
% otherwise the plain time axis is used.
if ~exist('phaseAxis', 'var')
    phaseAxis = false;
end
if phaseAxis && ~exist('lockingRet', 'var')
    error('phaseAxis is enabled but lockingRet is not defined.');
end

acc_nrem   = cell(1, numel(subjects));
accs_behav = zeros(numel(subjects), 2);

% LDA classifier scored with AUC; identical for every subject and test set
cfg_mv            = [];
cfg_mv.classifier = 'lda';
cfg_mv.metric     = 'auc';

for isubject = 1:numel(subjects)

    dirSubject = dir(fullfile(dirRoot, strcat(subjects(isubject), "*")));
    if isempty(dirSubject)
        error('No data folder matching %s* found in %s.', subjects{isubject}, dirRoot);
    end

    % Load continuous retrieval data (1 = object, 2 = scene)
    data_all_obj = dataloader.load_cont_data(dirRoot, dirSubject, 1, [], 0, 0, 'nan');
    data_all_sce = dataloader.load_cont_data(dirRoot, dirSubject, 2, [], 0, 0, 'nan');

    %% Get epochs
    % Epoch the retrieval data on marker 'S 41' (cue locked)
    data_all_obj = util.splitTrials(data_all_obj, 'S 41', 3, 0, dirSubject, 1, 0);
    data_all_sce = util.splitTrials(data_all_sce, 'S 41', 3, 0, dirSubject, 2, 0);

    % Localizer data for both stimulus types
    data_loc_obj = dataloader.load_loc_data(dirRoot, dirSubject, 1, [], 0);
    data_loc_sce = dataloader.load_loc_data(dirRoot, dirSubject, 2, [], 0);

    % Tag every retrieval trial with its stimulus type (1 = object, 2 = scene)
    for i = 1:size(data_all_obj.trialinfo, 1)
        data_all_obj.trialinfo{i, 1}.Stimtype = 1;
    end
    for i = 1:size(data_all_sce.trialinfo, 1)
        data_all_sce.trialinfo{i, 1}.Stimtype = 2;
    end

    %% Some preprocessing

    % Append object and scene data; trialinfo is re-attached explicitly
    trialinfo_all          = [data_all_obj.trialinfo; data_all_sce.trialinfo];
    cfg                    = [];
    data_all               = ft_appenddata(cfg, data_all_obj, data_all_sce);
    data_all.fsample       = 200;
    data_all.trialinfo     = trialinfo_all;

    trialinfo_all_l        = [data_loc_obj.trialinfo; data_loc_sce.trialinfo];
    cfg                    = [];
    data_all_loc           = ft_appenddata(cfg, data_loc_obj, data_loc_sce);
    data_all_loc.fsample   = 200;
    data_all_loc.trialinfo = trialinfo_all_l;

    % Remove M1 and M2 from retrieval (the localizer does not have these channels)
    cfg = [];
    cfg.channel = {'all', '-M1', '-M2'};
    data_all = ft_selectdata(cfg, data_all);
    data_all_loc = ft_selectdata(cfg, data_all_loc);

    % Keep the respiration channel aside for later, then drop it from the EEG
    cfg = [];
    cfg.channel = {'resp'};
    data_resp = ft_selectdata(cfg, data_all);

    cfg = [];
    cfg.channel = {'all', '-resp'};
    data_all = ft_selectdata(cfg, data_all);
    data_all_loc = ft_selectdata(cfg, data_all_loc);

    % Re-reference to the average of all channels
    cfg            = [];
    cfg.reref      = 'yes';
    cfg.refmethod  = 'avg';
    cfg.refchannel = {'all'};
    cfg.removemean = 'no';
    data_all       = ft_preprocessing(cfg, data_all);
    data_all_loc   = ft_preprocessing(cfg, data_all_loc);

    % Demean each trial, keeping single trials
    cfg            = [];
    cfg.keeptrials = 'yes';
    cfg.removemean = 'yes';
    data_all       = ft_timelockanalysis(cfg, data_all);
    data_all_loc   = ft_timelockanalysis(cfg, data_all_loc);

    % Reduce to the time window of interest
    % NOTE(review): cfg.latency is an ft_selectdata option; confirm that
    % ft_preprocessing actually applies it here (later code expects the
    % output to be raw data with cell-array trials, so it was left as is).
    cfg = [];
    cfg.latency = TOI_trial;
    data_all = ft_preprocessing(cfg, data_all);
    cfg.latency = TOI_loc;
    data_all_loc = ft_preprocessing(cfg, data_all_loc);

    % Downsample to 100 Hz, the step of the classification time axes
    cfg = [];
    cfg.resamplefs = 100;
    cfg.method = 'downsample';
    data_all = ft_resampledata(cfg, data_all);
    data_all_loc = ft_resampledata(cfg, data_all_loc);
    data_resp = ft_resampledata(cfg, data_resp);

    % z-score localizer and retrieval data, each with its own statistics
    dat_all     = data_all.trial;
    dat_all_loc = data_all_loc.trial;

    preprocess_param   = mv_get_preprocess_param('zscore');
    [~, dat_all_loc]   = mv_preprocess_zscore(preprocess_param, dat_all_loc);
    data_all_loc.trial = dat_all_loc;

    [~, dat_all]       = mv_preprocess_zscore(preprocess_param, dat_all);
    data_all.trial     = dat_all;

    % Combined container: retrieval trials first, then localizer trials
    cfg          = [];
    data_all_all = ft_appenddata(cfg, data_all, data_all_loc);

    % Estimate the PCA on the localizer data only, then apply the same
    % transform to localizer and retrieval data, keeping the components
    % whose cumulative explained variance is below 95 (presumably %).
    cfg = [];
    cfg.step = "calculate";
    cfg.centered = 1;
    comp = math.my_PCA(cfg, data_all_loc);

    cfg = [];
    cfg.step = "transform";
    cfg.centered = 1;
    cfg.eigVects = comp.eigVects;
    cfg.chosen = comp.eigValsCum < 95; % set desired variance to keep
    cfg.updateSens = 0;

    comp_loc = math.my_PCA(cfg, data_all_loc);
    comp_ret = math.my_PCA(cfg, data_all);

    % Retrieval part of the combined container
    cfg              = [];
    cfg.trials       = 1:numel(data_all.trialinfo);
    data_all_ret     = ft_selectdata(cfg, data_all_all);

    % Localizer part of the combined container
    cfg              = [];
    cfg.trials       = numel(data_all.trialinfo)+1:numel(data_all_all.trialinfo);
    data_all_loc     = ft_selectdata(cfg, data_all_all);

    data_all_ret.trial = comp_ret.trial;
    data_all_loc.trial = comp_loc.trial;

    if phaseAxis
        % Switch the retrieval axis from time to phase
        data_all_ret = util.time2PhaseAxis(data_all_ret, data_resp, lockingRet, 'makima');
    end

    % Temporal smoothing (running average over 200 ms)
    for itrial = 1:numel(data_all_ret.trial)
        data_all_ret.trial{itrial} = smoothdata(data_all_ret.trial{itrial},2,'movmean',0.2/(1/data_all_ret.fsample));
    end

    for itrial = 1:numel(data_all_loc.trial)
        data_all_loc.trial{itrial} = smoothdata(data_all_loc.trial{itrial},2,'movmean',0.2/(1/data_all_loc.fsample));
    end


    %% Bring data into trials x channels x time arrays
    % Localizer data first
    cfg             = [];
    cfg.keeptrials  = 'yes';
    cfg.removemean  = 'no';
    tmp             = ft_timelockanalysis(cfg, data_all_loc);
    dat_loc         = tmp.trial;

    % Masks for both classes
    trlinfo  = cell2mat(data_all_loc.trialinfo);
    category = [trlinfo.Stimtype];

    train_sel1 = ismember(category, 1);
    train_sel2 = ismember(category, 2);

    % Training data (class 1 trials first) and matching training labels
    train_dat        = cat(1, dat_loc(train_sel1,:,:), dat_loc(train_sel2,:,:));
    classcode_train  = cat(1, 1*ones(sum(train_sel1),1), 2*ones(sum(train_sel2),1));

    % Now the same for retrieval
    cfg             = [];
    cfg.keeptrials  = 'yes';
    cfg.removemean  = 'no';
    tmp             = ft_timelockanalysis(cfg, data_all_ret);
    dat_ret         = tmp.trial;

    % Per-trial condition and behaviour
    trlinfo    = cell2mat(data_all_ret.trialinfo);
    category   = [trlinfo.Stimtype];
    accuracy   = [trlinfo.ExemplarAccuracy];   % 1 [correct], 0 [wrong]
    oldnew     = [trlinfo.OldNew];
    oldnew_acc = [trlinfo.Answer1Accuracy];

    % Behavioural accuracy, kept to correlate with reactivation (not used)
    accs_behav(isubject, 1) = mean(accuracy(accuracy<=1))*100;
    accs_behav(isubject, 2) = sum(accuracy == 1)/sum((oldnew == 1).*(oldnew_acc == 1))*100;

    % Searchlight masks: row c selects channel c and its neighbours. The
    % neighbour entry is looked up by label, so this does not rely on
    % SL_neighbours being ordered exactly like the data channels.
    chanLabels = data_all_ret.label;
    nChan      = numel(chanLabels);
    [hasNb, nbIdx] = ismember(chanLabels, {SL_neighbours.label});
    if ~all(hasNb)
        error('No searchlight neighbours for channel(s): %s', ...
              strjoin(reshape(chanLabels(~hasNb), 1, []), ', '));
    end
    slMask = false(nChan, nChan);
    for ichannel = 1:nChan
        nb = SL_neighbours(nbIdx(ichannel));
        these_channels = cat(1, nb.label, nb.neighblabel(:));
        slMask(ichannel, :) = ismember(chanLabels, these_channels);
    end

    % Test sets: remembered (accuracy 1) then not-remembered (accuracy 0)
    % old trials. Each test set is z-scored with its own statistics.
    for accLevel = [1 0]
        test_sel1 = ismember(category,1) & accuracy==accLevel & oldnew == 1;
        test_sel2 = ismember(category,2) & accuracy==accLevel & oldnew == 1;

        % Test data (class 1 trials first) and matching test labels
        test_dat       = cat(1, dat_ret(test_sel1,:,:), dat_ret(test_sel2,:,:));
        classcode_test = cat(1, 1*ones(sum(test_sel1),1), 2*ones(sum(test_sel2),1));
        [~, test_dat]  = mv_preprocess_zscore(preprocess_param, test_dat);

        % Time x time generalisation within each searchlight
        perfmat = zeros(nChan, length(classi_timeaxis_train), length(classi_timeaxis_test));
        parfor ichannel = 1:nChan
            chansel = slMask(ichannel, :);
            perfmat(ichannel,:,:) = mv_classify_timextime(cfg_mv, train_dat(:,chansel,:), classcode_train, ...
                                                          test_dat(:,chansel,:), classcode_test);
        end

        if accLevel == 1
            acc{isubject} = perfmat;
        else
            acc_nrem{isubject} = perfmat;
        end
    end

end

%% Prepare for ft stats
% Stack per-subject results into subject x channel x train time x test time.
% cat() would silently skip empty cells, so missing subjects are rejected.
if any(cellfun(@isempty, acc)) || any(cellfun(@isempty, acc_nrem))
    error('Classification results are missing for at least one subject.');
end
accmat      = permute(cat(4, acc{:}),      [4 1 2 3]);
accmat_nrem = permute(cat(4, acc_nrem{:}), [4 1 2 3]);

% Compare remembered vs not-remembered trials (Figure 1e) as a single
% pseudo-channel time x time map (freq = localizer time, time = retrieval
% time). When the results hold one entry per searchlight centre, they are
% averaged across channels so that they fit this single-channel layout;
% with a single channel the mean leaves the values unchanged.
nChanStats = size(accmat, 2);
if nChanStats > 1
    warning('Averaging AUC across %d searchlight channels for the time x time statistics.', nChanStats);
end

condmats = {accmat, accmat_nrem};
cstat    = {};
for icond = 1:numel(condmats)
    s            = [];
    s.label      = {'Channels'};
    s.time       = classi_timeaxis_test;
    s.freq       = classi_timeaxis_train;
    s.individual = mean(condmats{icond}, 2);
    s.dimord     = 'subj_chan_freq_time';
    cstat{icond} = s;
end

%% FT stats
% Cluster-based permutation test, dependent-samples t.
% tail = 1: one-sided, condition 1 (remembered) > condition 2 (not remembered)
cfg                     = [];
cfg.spmversion          = 'spm12';
cfg.minnbchan           = 0;
cfg.channel             = 'all';
% cfg.channel = signChanns;
cfg.statistic           = 'depsamplesT';
cfg.method              = 'montecarlo';
cfg.correctm            = 'cluster';
cfg.alpha               = .05;
cfg.clusteralpha        = .05;
cfg.tail                = 1;
cfg.computecritval      = 'yes';

cfg.numrandomization    = 1000;

cfg.clusterstatistic    = 'maxsum';
cfg.clustertail         = cfg.tail;
cfg.parameter           = 'individual';

% Design: row 1 = subject (unit variable), row 2 = condition (1 rem, 2 nrem)
nSub   = size(accmat, 1);
design = [1:nSub, 1:nSub; ones(1, nSub), 2*ones(1, nSub)];

cfg.design  = design;
cfg.uvar    = 1;
cfg.ivar    = 2;

% Run stats and report the number of significant time x time points
[Fieldtripstats] = ft_freqstatistics(cfg, cstat{:});
nnz(Fieldtripstats.mask)

%% Plot figure 1e or 3a
% t-map of the rem vs nrem comparison: x = retrieval time, y = localizer time
figure;
imagesc(classi_timeaxis_test, classi_timeaxis_train, squeeze(Fieldtripstats.stat));
axis xy
xlabel('Time retrieval (s)');
ylabel('Time localizer (s)')
caxis([-2.5 2.5])

% Single colorbar with a label; tick direction and axis line width
h = colorbar();
ylabel(h, 't-values');
set(gca, 'TickDir', 'in', 'LineWidth', 1.5);

% Dashed lines at time zero on both axes
xline(0, '--', 'linewidth', 2, 'color', 'k');
yline(0, '--', 'linewidth', 2, 'color', 'k');
set(gcf, 'Position', [100, 400, 570, 300]);

% Outline the significant cluster(s) from the FieldTrip mask
sigmap = double(squeeze(Fieldtripstats.mask));
plot_contour(classi_timeaxis_test, classi_timeaxis_train, sigmap, 'k-', 1)
set(findall(gcf, '-property', 'FontSize'), 'FontSize', 20)

