#!/usr/bin/env python3
# coding: utf-8

# Author: Tanja Krueger
# Aim: this script contains functions for splitting data and visualizing predictor performance.

########################################################################################################################
# Import libraries needed.
import os
from datetime import datetime
import matplotlib.pyplot as plt
from Bio import SeqIO
import pandas as pd
from sklearn.model_selection import train_test_split
import sklearn
import numpy as np
import pandas as pd
from numpy import mean, std
from numpy.random import choice, seed
import matplotlib.pyplot as plt
import seaborn as sns
import statistics as stat
import re, argparse, csv, collections,random
from datetime import datetime
from sklearn import svm
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import roc_curve, make_scorer, matthews_corrcoef, roc_auc_score, precision_recall_curve, classification_report, f1_score, confusion_matrix
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from typing import Dict
from sklearn import metrics
####################################################################################################################################
# Function 1: adds labels for splitting the data. 

def labeler(data, training_ratio, val_ratio):
    """Add a ``binning_label`` column that groups consecutive rows into bins.

    Rows are assigned, in their current order, to bins of
    ``training_ratio + val_ratio`` rows each (labels 0, 1, 2, ...). The label
    is later used as the ``stratify`` key, so every bin is split between the
    two output sets in the requested ratio.

    Leftover rows that do not fill a whole bin form one extra bin. A leftover
    "bin" of a single row is merged into the previous bin instead, because a
    stratification class with only one member cannot be split.

    Parameters
    ----------
    data : pandas.DataFrame
        Data to label. It is modified in place (a ``binning_label`` column is
        added or overwritten) and also returned.
    training_ratio, val_ratio : int
        Relative sizes of the two sets; their sum is the bin size.

    Returns
    -------
    pandas.DataFrame
        ``data`` with the ``binning_label`` column set.

    Raises
    ------
    ValueError
        If ``data`` has fewer than two rows, since it cannot be split.
    """
    n_rows = len(data)
    # The former ``try: len(data) >= 2`` never raised, so it checked nothing.
    if n_rows < 2:
        raise ValueError("the input file needs at least two entries otherwise it cant be split")

    bin_size = training_ratio + val_ratio
    n_whole_bins, last_bin_len = divmod(n_rows, bin_size)
    labels = [label for label in range(n_whole_bins) for _ in range(bin_size)]

    if last_bin_len == 1:
        # A lone leftover row joins the last whole bin. n_rows >= 2 with a
        # remainder of 1 guarantees at least one whole bin exists.
        labels.append(n_whole_bins - 1)
    elif last_bin_len > 1:
        # Leftover rows form a new bin. Using n_whole_bins (not
        # labels[-1] + 1) also works when there is no whole bin at all.
        labels.extend([n_whole_bins] * last_bin_len)

    data["binning_label"] = labels
    return data
##################################################################################################################################
# Function 2: splits data into test and training datasets.

def sorted_stratified_train_test_split2(x_data, train_val_ratio, test_ratio):
    """Split a dataset into two sets whose sizes follow ``train_val_ratio:test_ratio``.

    The rows are first labelled with ``labeler`` (bins of
    ``train_val_ratio + test_ratio`` consecutive rows). The bin label is then
    used as the stratification key, so every bin contributes to both sets in
    the requested ratio. Example: ``train_val_ratio=9, test_ratio=1`` puts
    one row of every ten into the second set.

    Parameters
    ----------
    x_data : pandas.DataFrame
        Data to split. Note: ``labeler`` adds a ``binning_label`` column to
        it in place, and that column is carried into both outputs.
    train_val_ratio, test_ratio : int
        Relative sizes of the first and second output set.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(X_train, X_test)``, split with a fixed ``random_state=7``.

    Raises
    ------
    ValueError
        If ``x_data`` has fewer than two rows.
    """
    # The former ``try: len(x_data) >= 2`` never raised, so it checked nothing.
    if len(x_data) < 2:
        raise ValueError("the input file needs at least two entries otherwise it cant be split")

    X = labeler(x_data, train_val_ratio, test_ratio)

    # Fraction of rows for the second set: test_ratio parts out of
    # train_val_ratio + test_ratio parts (9:1 -> 0.1). The former
    # test_ratio / train_val_ratio gave 1/9 for 9:1, and 1.0 for 1:1, which
    # is not a valid float test_size for sklearn.
    test_fraction = test_ratio / (train_val_ratio + test_ratio)

    X_train, X_test = sklearn.model_selection.train_test_split(
        X,
        test_size=test_fraction,
        stratify=X["binning_label"],
        random_state=7,
    )
    return X_train, X_test

########################################################################################################################################
# Function 3: visualization helper for the confusion matrix.
def pretty_cnf_matrix(cfm, model_type, sst_level, cl, cv, cl_weight, where):
    """Draw a 2x2 confusion matrix as an annotated heatmap and save it as a PNG.

    Parameters
    ----------
    cfm : nested list ``[[TN, FP], [FN, TP]]``
        Confusion matrix; rows are true labels, columns are predicted labels.
    model_type : str
        Model name, used in the title and the file name.
    sst_level : level of mmseqs2 reduction of the input data (used in the path).
    cl : path prefix to the project root, e.g. "../../" when run from the Code folder.
    cv : cross-validation fold identifier (used in the file name).
    cl_weight : class-weight setting (used in the file name).
    where : result group, e.g. "baseline", "sklearn" or "pytorch".
    """
    class_names = ["control", "toxin"]
    out_dir = f"{cl}Figures/{where}_predictor_hyperparameter/SSTlevel{sst_level}/confusionmatrix/"
    out_file = f"SST{sst_level}_confusion_matrix_{model_type}_CV{cv}_classweight{cl_weight}.png"

    figure, axes = plt.subplots(figsize=(4, 4))
    sns.heatmap(
        pd.DataFrame(cfm),
        annot=True,
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        fmt=".3g",  # keeps small counts out of scientific notation
        ax=axes,
    )
    # Put the "Predicted label" axis above the matrix, as in the usual layout.
    for move_to_top in (axes.xaxis.set_label_position, axes.xaxis.set_ticks_position):
        move_to_top("top")
    axes.set_ylabel("True label")
    axes.set_xlabel("Predicted label")
    axes.set_title(f'Confusion matrix for {model_type}')
    plt.savefig(out_dir + out_file, bbox_inches="tight")


########################################################################################################################################
# Function 4: visualization helpers for the hyperparameter-optimization grid search.
def CV_vis(data, train_or_crossval_score, sst_level, cl, cv, cl_weight, where):
    """Plot grid-search scores as a heatmap and save it as a PNG.

    Parameters
    ----------
    data : pandas.DataFrame
        Column 0 holds the mean training score, column 1 the mean
        cross-validation ("test") score, and columns 2 and 3 the values of
        the two hyperparameters that were searched.
    train_or_crossval_score : {"train", "crossval"}
        Which score column to plot.
    sst_level : level of mmseqs2 reduction of the input data (used in the path).
    cl : path prefix to the project root, e.g. "../../" when run from the Code folder.
    cv : cross-validation fold identifier (used in the file name).
    cl_weight : class-weight setting (used in the file name).
    where : result group, e.g. "baseline", "sklearn" or "pytorch".

    Raises
    ------
    ValueError
        If ``train_or_crossval_score`` is not "train" or "crossval".
        Previously an unknown value surfaced as an UnboundLocalError.
    """
    # Score kind -> (column index of the score, tag used in title/file name).
    score_kinds = {
        "train": (0, "trainingscore"),
        "crossval": (1, "crossvalidationscore"),
    }
    if train_or_crossval_score not in score_kinds:
        raise ValueError(
            "train_or_crossval_score must be 'train' or 'crossval', "
            f"got {train_or_crossval_score!r}"
        )
    col, title_add_on = score_kinds[train_or_crossval_score]

    # The two hyperparameters are in the third and fourth columns.
    param1_values = data.iloc[:, 2]
    param2_values = data.iloc[:, 3]
    n_param1 = len(set(param1_values))
    n_param2 = len(set(param2_values))

    fig, ax = plt.subplots(figsize=(4, 4))
    # NOTE(review): the reshape assumes rows are ordered by parameter 1, then
    # parameter 2, matching the sorted order np.unique uses for the tick
    # labels -- confirm against how ``data`` is built.
    scores = np.reshape(list(data.iloc[:, col]), [n_param1, n_param2])
    sns.heatmap(
        data=scores,
        cmap="Blues",
        annot=True,
        xticklabels=np.unique(param2_values),  # np.unique is ordered, set() is not
        yticklabels=np.unique(param1_values),
        cbar=True,
        ax=ax,
    )

    param1_name = str(data.columns[2])
    param2_name = str(data.columns[3])
    # Classifier name = first parameter's column name up to its last
    # underscore (greedy match). Fall back to the full name when there is no
    # underscore, instead of failing with AttributeError on ``None.group``.
    match = re.search(r"(.*)_", param1_name)
    classifier_name = match.group(1) if match else param1_name

    ax.set_title(f"{classifier_name} using {title_add_on}")
    ax.set_xlabel(param2_name)
    ax.set_ylabel(param1_name)
    plt.savefig(f"{cl}Figures/{where}_predictor_hyperparameter/SSTlevel{sst_level}/gridsearch/SST{sst_level}_gridsearch_{classifier_name}_{title_add_on}_mcc_CV{cv}_classweight{cl_weight}.png", bbox_inches="tight")

def CV_vis_embeddings(data, train_or_crossval_score, sst_level, cl, cv, cl_weight, where):
    """Plot grid-search scores for the embedding-based models as a heatmap.

    The code of this function was a verbatim copy of ``CV_vis``: same score
    selection, same heatmap, same title and same output path. It now
    delegates to ``CV_vis`` so the two cannot drift apart.

    Parameters
    ----------
    data : pandas.DataFrame
        Column 0 holds the mean training score, column 1 the mean
        cross-validation score, and columns 2 and 3 the two searched
        hyperparameters.
    train_or_crossval_score : {"train", "crossval"}
        Which score column to plot.
    sst_level : level of mmseqs2 reduction of the input data.
    cl : path prefix to the project root, e.g. "../../" when run from the Code folder.
    cv : cross-validation fold identifier.
    cl_weight : class-weight setting.
    where : result group, e.g. "baseline", "sklearn" or "pytorch".
    """
    return CV_vis(data, train_or_crossval_score, sst_level, cl, cv, cl_weight, where)

########################################################################################################################################
# Function 5: bootstrapping the standard error of the Matthews correlation coefficient.
def bootstrap_mcc(y_true, y_pred, n_boot):
    """Bootstrap the Matthews correlation coefficient (MCC) of a set of predictions.

    ``n_boot`` resamples of the (true, predicted) pairs are drawn, each with
    replacement and the same size as the input. The MCC is computed on each
    resample. Sampling uses a private generator seeded with 7, so results are
    reproducible without touching NumPy's global random state.

    Parameters
    ----------
    y_true, y_pred : array-like
        True and predicted labels, paired by position.
    n_boot : int
        Number of bootstrap resamples.

    Returns
    -------
    tuple of (float, float)
        Mean bootstrap MCC, and 1.96 * standard deviation of the bootstrap
        MCCs (the 1.96*SE half-width shown in the graphs). Both are rounded
        to 4 decimals.
    """
    # Reset to a 0..n-1 index so that rows are paired by position. Label
    # lookups broke when y_true arrived as a Series with a custom index,
    # because y_pred was then indexed with y_true's labels.
    y_true = pd.Series(y_true).reset_index(drop=True)
    y_pred = pd.Series(y_pred).reset_index(drop=True)
    n_samples = len(y_true)

    # RandomState(7).choice(n, ...) draws the same resamples as the former
    # global seed(7) + choice(index, ...), without reseeding the global RNG.
    rng = np.random.RandomState(7)
    mcc = []
    for _ in range(n_boot):
        sample = rng.choice(n_samples, size=n_samples, replace=True)
        mcc.append(matthews_corrcoef(y_true.iloc[sample], y_pred.iloc[sample]))

    average_metric = round(float(mean(mcc)), 4)
    # Scale first, then round, so the returned half-width really is rounded to 4 decimals.
    metric_se = round(1.96 * float(std(mcc, ddof=1)), 4)
    return average_metric, metric_se


########################################################################################################################################
# Function 5b: bootstrapping the standard error of an arbitrary metric.
def bootstrap_metric(metric, y_true, y_pred, n_boot):
    """Bootstrap an arbitrary metric and return its mean and 1.96 * SE.

    ``n_boot`` resamples of the (true, predicted) pairs are drawn, each with
    replacement and the same size as the input. ``metric`` is evaluated on
    each resample. Sampling uses a private generator seeded with 7, so
    results are reproducible without touching NumPy's global random state.

    Parameters
    ----------
    metric : callable
        ``metric(y_true_sample, y_pred_sample)`` returning a number, for
        example ``sklearn.metrics.f1_score``. It receives pandas Series.
    y_true, y_pred : array-like
        True and predicted labels, paired by position.
    n_boot : int
        Number of bootstrap resamples.

    Returns
    -------
    tuple of (float, float)
        Mean of the bootstrap scores, and 1.96 * their standard deviation
        (the 1.96*SE half-width shown in the graphs). Both are rounded to
        4 decimals and both are also printed.
    """
    # Reset to a 0..n-1 index so that rows are paired by position. Label
    # lookups broke when y_true arrived as a Series with a custom index,
    # because y_pred was then indexed with y_true's labels.
    y_true = pd.Series(y_true).reset_index(drop=True)
    y_pred = pd.Series(y_pred).reset_index(drop=True)
    n_samples = len(y_true)

    # RandomState(7).choice(n, ...) draws the same resamples as the former
    # global seed(7) + choice(index, ...), without reseeding the global RNG.
    rng = np.random.RandomState(7)
    metric_accumulated = []
    for _ in range(n_boot):
        sample = rng.choice(n_samples, size=n_samples, replace=True)
        metric_accumulated.append(metric(y_true.iloc[sample], y_pred.iloc[sample]))

    average_metric = round(float(mean(metric_accumulated)), 4)
    print(average_metric)
    # Scale first, then round, so the returned half-width really is rounded to 4 decimals.
    metric_se = round(1.96 * float(std(metric_accumulated, ddof=1)), 4)
    print(metric_se)
    return average_metric, metric_se



######################################################################################################################
# Function 6: Visualization of feature importance.

def vis_all_features_imp(coefs, sst_level, cl, cv, model_type, colnames, cl_weight, where):
    """Bar-plot the (signed) importance of every feature and save it as a PNG.

    This function was originally also named ``vis_features_imp``. The top-20
    plotting function defined later in this module reuses that name and
    replaces this one at import time, so this definition could never be
    called. It is renamed so the all-features plot is reachable. Callers of
    ``vis_features_imp`` are unaffected because they already got the later
    definition.

    Parameters
    ----------
    coefs : sequence of float
        Coefficient/importance of each feature of the optimized predictor,
        in the same order as ``colnames``. Plotted as given (not as absolute values).
    sst_level : level of mmseqs2 reduction of the input data.
    cl : path prefix to the project root (depends on whether the code is run
        from the command line or not).
    cv : cross-validation fold identifier.
    model_type : type of predictor (used in the title and file name).
    colnames : feature names, used as tick labels.
    cl_weight : class-weight setting (used in the file name).
    where : result group, e.g. "baseline", "sklearn" or "pytorch".
    """
    # matplotlib 3.6 renamed the "seaborn" style to "seaborn-v0_8", and 3.8
    # removed the old name. Try the new name first, then fall back to the
    # old one for matplotlib < 3.6.
    try:
        plt.style.use("seaborn-v0_8")
    except OSError:
        plt.style.use("seaborn")
    fig, ax = plt.subplots()
    ax.bar(x=np.arange(len(colnames)), height=coefs, tick_label=colnames)
    ax.set_title(f"feature importance of {model_type} using SST{sst_level}")
    ax.set_xlabel("features")
    ax.set_ylabel("feature importance")
    plt.savefig(f"{cl}Figures/{where}_predictor_hyperparameter/SSTlevel{sst_level}/feature_importance/SST{sst_level}_featureimportance_{model_type}_CV{cv}_classweight{cl_weight}.png", bbox_inches="tight")

def vis_features_imp(coefs, sst_level, cl, cv, model_type, colnames, cl_weight, where):
    """Bar-plot the (up to) twenty features with the largest absolute importance.

    Features are ranked by ``abs(coef)``, so strongly negative coefficients
    count as important. Ties keep their original order. The plot is saved
    as a PNG.

    Parameters
    ----------
    coefs : sequence of float
        Coefficient/importance of each feature of the optimized predictor,
        in the same order as ``colnames``. It is NOT modified.
    sst_level : level of mmseqs2 reduction of the input data.
    cl : path prefix to the project root (depends on whether the code is run
        from the command line or not).
    cv : cross-validation fold identifier.
    model_type : type of predictor (used in the title and file name).
    colnames : feature names, used as tick labels.
    cl_weight : class-weight setting (used in the file name).
    where : result group, e.g. "baseline", "sklearn" or "pytorch".
    """
    # Build a new ranked list instead of overwriting ``coefs`` with abs()
    # in place. The old in-place write also changed the caller's data, e.g.
    # a view of a fitted model's coef_ array.
    ranked = sorted(
        zip(colnames, (abs(coef) for coef in coefs)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    top = ranked[:20]
    features = [name for name, _ in top]
    importances = [importance for _, importance in top]

    # matplotlib 3.6 renamed the "seaborn" style to "seaborn-v0_8", and 3.8
    # removed the old name. Try the new name first, then fall back to the
    # old one for matplotlib < 3.6.
    try:
        plt.style.use("seaborn-v0_8")
    except OSError:
        plt.style.use("seaborn")
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.bar(x=np.arange(len(features)), height=importances, tick_label=features)
    plt.xticks(rotation=90)  # vertical feature names
    ax.set_title(f"feature importance of {model_type} using SST{sst_level}")
    ax.set_xlabel("features")
    ax.set_ylabel("feature importance")
    # NOTE(review): "featureimportnace" in the file name looks like a typo.
    # It is kept so that existing outputs and any consumers of this path do
    # not change.
    plt.savefig(f"{cl}Figures/{where}_predictor_hyperparameter/SSTlevel{sst_level}/feature_importance/SST{sst_level}_featureimportnace_{model_type}_CV{cv}_classweight{cl_weight}.png", bbox_inches="tight")

def vis_features_imp_embs(coefs, sst_level, cl, cv, model_type, colnames, cl_weight, where):
    """Bar-plot the twenty most important features of an embedding-based predictor.

    The code of this function was a verbatim copy of the top-20
    ``vis_features_imp`` in this module: same ranking, same plot and same
    output path. It now delegates to that function so the two cannot drift
    apart.

    Parameters
    ----------
    coefs : sequence of float
        Coefficient/importance of each feature, in the same order as ``colnames``.
    sst_level : level of mmseqs2 reduction of the input data.
    cl : path prefix to the project root (depends on whether the code is run
        from the command line or not).
    cv : cross-validation fold identifier.
    model_type : type of predictor.
    colnames : feature names.
    cl_weight : class-weight setting.
    where : result group, e.g. "baseline", "sklearn" or "pytorch".
    """
    return vis_features_imp(coefs, sst_level, cl, cv, model_type, colnames, cl_weight, where)



# Example inputs for manually trying out the plotting helpers. They stay at
# module level because they are importable names.
classifier_name = "test"
cv = 4
cl = "../../"
sst_level = 25
model_type = "logreg"
colnames = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
coefs = [-0.24388894, 0.24811639, -0.44344147, 0.22529812, -0.04722234, 0.02296377,
         0.5050371, 0.17826091, 0.17199711, 0.1349233, -0.09380333, 0.09620798,
         -0.24017034, 0.08985451, 0.29003978, 0.180154, -0.3021584, -0.20250117,
         -0.38279059, -0.26075926]
cl_weight = "this is a test anyway"
where = "sklearn"

if __name__ == "__main__":
    # Plotting used to run at import time; it is now only a manual smoke test.
    # vis_features_imp(coefs, sst_level, cl, cv, model_type, colnames, cl_weight, where)
    plt.plot(1, 2)
