clear all;
ft_defaults

% Input / output folders
dirRoot = 'D:\BreathRet\data';
dirSave = "D:\BreathRet\saves";
dirAux  = 'D:\BreathRet\auxiliary';   % channel neighbour / layout files

subjects    = {'P02' 'P05' 'P07' 'P09' 'P11' 'P12' 'P13' 'P14' 'P15' 'P16' 'P19' 'P21' 'P22' 'P25' 'P26' 'P29' 'P30' 'P31'};

% Channel neighbours and 64-channel layout. The layout file presumably
% provides `lay` (used when preparing the searchlight neighbours) - confirm.
% Both paths are built with fullfile so separators are consistent.
load(fullfile(dirAux, 'Bham-64CH-Neighbours.mat'))
load(fullfile(dirAux, 'Bham-64CH-Lay.mat'))

% Event the retrieval data are epoched around: "cue" epochs on the cue
% marker; any other value (e.g. "exhaleTroughs", "inhalePeaks") is passed
% to util.splitTrials_breathFeat as the breathing feature to lock to.
lockingRet = "exhaleTroughs";
removeArtifacts = false;   % NOTE(review): not referenced in the visible code

%% Prepare classification
% Time windows of interest (s) for the time x time classification.
% Training always uses the localizer window TOI_loc; the retrieval (test)
% window depends on the locking event.
step = 0.01;   % spacing (s) of the classification time axes, for every locking mode
if strcmp(lockingRet, "cue")
    TOI_trial = [-.2 1.5];
else
    TOI_trial = [-2 2];
end
TOI_loc = [-.2 1.5];

classi_timeaxis_train = TOI_loc(1):step:TOI_loc(2);
classi_timeaxis_test  = TOI_trial(1):step:TOI_trial(2);

% One entry per subject for the classifier output
acc = cell(1,numel(subjects));

%--- searchlight neighbours: distance-based channel neighbourhoods on the
%    layout `lay`, using EEG channels except M*, E* and Resp
cfg = struct( ...
    'method',        'distance', ...
    'layout',        lay, ...
    'neighbourdist', 0.15, ...
    'channel',       {{'eeg', '-M*', '-E*', '-Resp'}});
SL_neighbours = ft_prepare_neighbours(cfg);

%% Parameters passed to util.splitTrials_breathFeat
window_radii = 5;
switch lockingRet
    case "inhalePeaks"
        timeRange = [-1, 2];
    otherwise
        timeRange = [0, 3];
end

for isubject = 1:numel(subjects)

    % Folder(s) under dirRoot whose name starts with the subject ID
    dirSubject = dir(fullfile(dirRoot, [subjects{isubject} '*']));

    %% Load data

    % Continuous retrieval data (3rd argument 1 -> objects, 2 -> scenes)
    data_all_obj = dataloader.load_cont_data(dirRoot, dirSubject, 1, [], 0, 0, 'nan');
    data_all_sce = dataloader.load_cont_data(dirRoot, dirSubject, 2, [], 0, 0, 'nan');

    % Epoch retrieval data around breathing events, or around marker 'S 41'
    if ~strcmp(lockingRet, "cue")
        [data_all_obj, timings_obj] = util.splitTrials_breathFeat(data_all_obj, lockingRet, timeRange, window_radii, 1, 0, dirSubject, 1); %#ok<ASGLU>
        [data_all_sce, timings_sce] = util.splitTrials_breathFeat(data_all_sce, lockingRet, timeRange, window_radii, 1, 0, dirSubject, 2); %#ok<ASGLU>
    else
        data_all_obj = util.splitTrials(data_all_obj, 'S 41', 3, 0, dirSubject, 1, 0);
        data_all_sce = util.splitTrials(data_all_sce, 'S 41', 3, 0, dirSubject, 2, 0);
    end

    % Localizer (training) data
    data_loc_obj = dataloader.load_loc_data(dirRoot, dirSubject, 1, [], 0);
    data_loc_sce = dataloader.load_loc_data(dirRoot, dirSubject, 2, [], 0);

    % Label retrieval trials by stimulus type (1 = object, 2 = scene)
    for i = 1:size(data_all_obj.trialinfo,1)
        data_all_obj.trialinfo{i,1}.Stimtype = 1;
    end
    for i = 1:size(data_all_sce.trialinfo,1)
        data_all_sce.trialinfo{i,1}.Stimtype = 2;
    end

    %% Some preprocessing

    % Append object and scene data; trialinfo is concatenated by hand
    trialinfo_all          = [data_all_obj.trialinfo; data_all_sce.trialinfo];
    cfg                    = [];
    data_all               = ft_appenddata(cfg, data_all_obj, data_all_sce);
    data_all.fsample       = 200;
    data_all.trialinfo     = trialinfo_all;

    trialinfo_all_l        = [data_loc_obj.trialinfo; data_loc_sce.trialinfo];
    cfg                    = [];
    data_all_loc           = ft_appenddata(cfg, data_loc_obj, data_loc_sce);
    data_all_loc.fsample   = 200;
    data_all_loc.trialinfo = trialinfo_all_l;

    % Remove M1 and M2 (the localizer does not have these channels)
    cfg = [];
    cfg.channel = {'all', '-M1', '-M2'};
    data_all     = ft_selectdata(cfg, data_all);
    data_all_loc = ft_selectdata(cfg, data_all_loc);

    % Keep the respiration channel aside, then drop it from the EEG data
    cfg = [];
    cfg.channel = {'resp'};
    data_resp = ft_selectdata(cfg, data_all); %#ok<NASGU>

    cfg = [];
    cfg.channel = {'all', '-resp'};
    data_all     = ft_selectdata(cfg, data_all);
    data_all_loc = ft_selectdata(cfg, data_all_loc);

    % Re-reference to the common average
    cfg            = [];
    cfg.reref      = 'yes';
    cfg.refmethod  = 'avg';
    cfg.refchannel = {'all'};
    cfg.removemean = 'no';
    data_all       = ft_preprocessing(cfg, data_all);
    data_all_loc   = ft_preprocessing(cfg, data_all_loc);

    cfg            = [];
    cfg.keeptrials = 'yes';
    cfg.removemean = 'no';
    data_all       = ft_timelockanalysis(cfg, data_all);
    data_all_loc   = ft_timelockanalysis(cfg, data_all_loc);

    % Reduce to time of interest
    % NOTE(review): `latency` is an ft_selectdata option; confirm that
    % ft_preprocessing honours it. The classification time points are picked
    % explicitly with nearest() further below either way.
    cfg = [];
    cfg.latency = TOI_trial;
    data_all = ft_preprocessing(cfg, data_all);
    cfg.latency = TOI_loc;
    data_all_loc = ft_preprocessing(cfg, data_all_loc);

    % Downsample to 100 Hz
    cfg = [];
    cfg.resamplefs = 100;
    cfg.method = 'downsample';
    data_all     = ft_resampledata(cfg, data_all);
    data_all_loc = ft_resampledata(cfg, data_all_loc);

    % z-score localizer and retrieval data separately
    preprocess_param = mv_get_preprocess_param('zscore');
    [~, data_all_loc.trial] = mv_preprocess_zscore(preprocess_param, data_all_loc.trial);
    [~, data_all.trial]     = mv_preprocess_zscore(preprocess_param, data_all.trial);

    cfg          = [];
    data_all_all = ft_appenddata(cfg, data_all, data_all_loc);

    % PCA: components are estimated on the localizer data only ...
    cfg = [];
    cfg.step = "calculate";
    cfg.centered = 1;
    comp = math.my_PCA(cfg, data_all_loc);

    % ... and applied to localizer and retrieval data, keeping the
    % components whose cumulative value is below 95
    cfg = [];
    cfg.step = "transform";
    cfg.centered = 1;
    cfg.eigVects = comp.eigVects;
    cfg.chosen = comp.eigValsCum < 95; % set desired variance to keep
    cfg.updateSens = 0;

    comp_loc = math.my_PCA(cfg, data_all_loc);
    comp_ret = math.my_PCA(cfg, data_all);

    % run pca on appended data and reduce ranks
%     cfg              = [];
%     cfg.method       = 'pca';
%     cfg.updatesens   = 'no';
%     cfg.numcomponent = 30;
%     cfg.demean       = 'no';
%     comp             = ft_componentanalysis(cfg, data_all_all);
%
%     cfg              = [];
%     cfg.updatesens   = 'no';
%     cfg.component    = comp.label(31:end);
%     data_all_all     = ft_rejectcomponent(cfg, comp);
%     data_all_all.fsample = round(data_all_all.fsample);

    % Retrieval trials come first in data_all_all, localizer trials after
    cfg              = [];
    cfg.trials       = 1:numel(data_all.trialinfo);
    data_all_ret     = ft_selectdata(cfg, data_all_all);

    cfg              = [];
    cfg.trials       = numel(data_all.trialinfo)+1:numel(data_all_all.trialinfo);
    data_all_locs    = ft_selectdata(cfg, data_all_all);

    % Use the PCA-transformed signals for both sets. The training set is
    % data_all_locs, so the localizer PCA output must go there (data_all_loc
    % is not used after this point).
    data_all_ret.trial  = comp_ret.trial;
    data_all_locs.trial = comp_loc.trial;

    % Temporal smoothing: 200 ms moving average along time (dim 2), each set
    % using its own sampling rate
    for itrial = 1:numel(data_all_ret.trial)
        data_all_ret.trial{itrial} = smoothdata(data_all_ret.trial{itrial},2,'movmean',0.2/(1/data_all_ret.fsample));
    end
    for itrial = 1:numel(data_all_locs.trial)
        data_all_locs.trial{itrial} = smoothdata(data_all_locs.trial{itrial},2,'movmean',0.2/(1/data_all_locs.fsample));
    end

    %% Training set (localizer): trials x channels x time
    cfg             = [];
    cfg.keeptrials  = 'yes';
    cfg.removemean  = 'no';
    tmp             = ft_timelockanalysis(cfg, data_all_locs);
    dat_loc         = tmp.trial;

    trlinfo  = cell2mat(data_all_locs.trialinfo);
    category = [trlinfo.Stimtype];
    train_sel1 = ismember(category,1);
    train_sel2 = ismember(category,2);

    train_dat        = cat(1, dat_loc(train_sel1,:,:), dat_loc(train_sel2,:,:));
    classcode_train  = cat(1, 1*ones(sum(train_sel1),1), 2*ones(sum(train_sel2),1));

    %% Test data (retrieval)
    cfg             = [];
    cfg.keeptrials  = 'yes';
    cfg.removemean  = 'no';
    tmp             = ft_timelockanalysis(cfg, data_all_ret);
    dat_ret         = tmp.trial;

    trlinfo  = cell2mat(data_all_ret.trialinfo);
    category = [trlinfo.Stimtype];
    accuracy = [trlinfo.ExemplarAccuracy];   % 1 [correct], 0 [wrong]
    oldnew   = [trlinfo.OldNew];

    % Sample indices of the classification time axes (0.01 s steps);
    % identical for both memory outcomes, so computed once
    timeaxis_samples_train = nearest(data_all_locs.time{1},TOI_loc(1)):round(0.01*data_all_locs.fsample):nearest(data_all_locs.time{1},TOI_loc(2));
    timeaxis_samples_test  = nearest(data_all_ret.time{1},TOI_trial(1)):round(0.01*data_all_ret.fsample):nearest(data_all_ret.time{1},TOI_trial(2));
    test_train_dat         = train_dat(:,:,timeaxis_samples_train);

    cfg             = [];
    cfg.classifier  = 'lda';
    cfg.metric      = 'none';
    cfg.output_type = 'dval';

    % First localizer time index averaged over. With 0.01 s steps starting at
    % TOI_loc(1) = -0.2 s, index 36 corresponds to 0.15 s (the original
    % comment calls this the "sign." - presumably significant - window).
    iLocStart = 36;

    % Memory outcomes classified with the same localizer-trained LDA:
    % ExemplarAccuracy == 1 (remembered) and == 0 (not remembered), old items only
    memAccuracy = [1 0];
    memResults  = cell(1, numel(memAccuracy));

    for imem = 1:numel(memAccuracy)
        res = struct('acc', [], 'dvals_object', [], 'dvals_scene', [], ...
                     'dvals_object_avg', [], 'dvals_scene_avg', [], 'all_dvals', []);

        test_sel1 = ismember(category,1) & accuracy == memAccuracy(imem) & oldnew == 1;
        test_sel2 = ismember(category,2) & accuracy == memAccuracy(imem) & oldnew == 1;

        test_dat       = cat(1, dat_ret(test_sel1,:,:), dat_ret(test_sel2,:,:));
        classcode_test = cat(1, 1*ones(sum(test_sel1),1), 2*ones(sum(test_sel2),1));

        if isempty(classcode_test)
            % Nothing to classify: keep empty fields instead of aborting the run
            warning('BreathRet:emptyTestSet', ...
                'Subject %s: no test trials with ExemplarAccuracy == %d; skipping.', ...
                subjects{isubject}, memAccuracy(imem));
            memResults{imem} = res;
            continue
        end

        [~, test_dat] = mv_preprocess_zscore(preprocess_param, test_dat);
        test_test_dat = test_dat(:,:,timeaxis_samples_test);

        [perf, ~]     = mv_classify_timextime(cfg, test_train_dat, classcode_train, test_test_dat, classcode_test);
        acc{isubject} = perf;

        % Proceed with dvals. Permute to trials x time localizer x time retrieval.
        % res.acc = zscore(res.acc, [], 'all');
        res.acc = permute(cell2mat(perf), [1 3 2]);

        % Split by condition; scene dvals are sign-flipped so that positive
        % values mean evidence for the presented category in both conditions
        obj = find(classcode_test==1)';
        sce = find(classcode_test==2)';
        res.dvals_object = res.acc(obj,:,:);
        res.dvals_scene  = -res.acc(sce,:,:);

        % Average over localizer time from iLocStart onward. permute (not
        % squeeze) keeps the result trials x time retrieval even for a single
        % trial, where squeeze would return a column vector.
        res.dvals_object_avg = permute(mean(res.dvals_object(:,iLocStart:end,:),2), [1 3 2]);
        res.dvals_scene_avg  = permute(mean(res.dvals_scene(:,iLocStart:end,:),2),  [1 3 2]);

        res.all_dvals = [res.dvals_object_avg; res.dvals_scene_avg]; % append conditions

        memResults{imem} = res;
    end

    acc_dvals_rem(isubject)  = memResults{1};
    acc_dvals_nrem(isubject) = memResults{2};

end

% Write the per-subject dval structs (remembered / not remembered) to dirSave,
% tagging each file name with the locking event
remFile  = fullfile(dirSave, "classi_dvals_rem_" + lockingRet);
nremFile = fullfile(dirSave, "classi_dvals_nrem_" + lockingRet);
save(remFile,  "acc_dvals_rem");
save(nremFile, "acc_dvals_nrem");