#!/usr/bin/env python3
# coding: utf-8


########################################################################################################################
# Author: Tanja Krüger
# Aim: This file builds several predictors (sklearn and xgboost) based on APAAC features. The APAAC data are already split into test and training sets.
# Input: X_train: APAAC training split
# Input: X_test: APAAC test split
# Input: y_train: labels of the training split
# Input: y_test: labels of the test split
# Output: saved model and visualizations of training and test results and logging of performance on train and test set


########################################################################################################################
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import re, argparse, csv, collections,random
from datetime import datetime
from sklearn import svm
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler
from sklearn.model_selection import GridSearchCV,train_test_split
from sklearn.metrics import roc_curve, make_scorer, matthews_corrcoef, roc_auc_score, precision_recall_curve, classification_report, f1_score, confusion_matrix, accuracy_score, precision_score, recall_score, roc_auc_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.decomposition import PCA
from support_functions_splitting_predictor import  CV_vis_embeddings, pretty_cnf_matrix, bootstrap_metric, vis_features_imp_embs
import pickle
from sklearn.inspection import permutation_importance
from sklearn.model_selection import learning_curve
import xgboost as xgb
from joblib import dump

# #################################################################################################
# Path prefix that is put in front of every Data/, Predictor/ and Figures/ location used below.
# "../../" is the setting for executing this file from inside the Code/python folder; switch to ""
# when the code is run from the project folder (e.g. with make).
cl = "../../"

########################################################################################################################
# Read the four input paths from the command line (all positional, all plain strings).
parser = argparse.ArgumentParser(
    prog="APAAC_predictor.py",
    description="using APAAC to build of five sklearn predictors+ visual results",
)
for arg_name, arg_help in (
        ("X_train", " X_train data, APAAC"),
        ("X_test", " X_test data,  APAAC"),
        ("y_train", " y_training labels"),
        ("y_test", "y_testing labels"),
):
    parser.add_argument(arg_name, type=str, help=arg_help)

args = parser.parse_args()

# Set parameters for running the code
cv = 10  # number of cross validation folds used during the grid search
cl_weight = 'balanced'  # class_weight can be balanced or None
where = "sklearn"  # under which folder the results are stored
# MMseqs2 reduction level: the digits following "SST" in the y_train path. A raw string keeps "\d"
# a regex escape, and a missing match is reported explicitly instead of failing on None.group().
sst_match = re.search(r"SST(\d+)", args.y_train)
if sst_match is None:
    raise ValueError(f"cannot determine the SST level: no 'SST<number>' found in the y_train path {args.y_train!r}")
sst_level = sst_match.group(1)
learn_curve = False  # set if you want to visualize the learning curve [True, False]
random_seed = 7
random_state = 7
########################################################################################################################
# Open the predictor logfile and the general logfile in append mode. Both handles stay open for the
# whole run; they are closed at the very end of the script.
pred_log_file = open(f"{cl}Data/derived/pseudoAAC/{where}_{sst_level}_{cl_weight}_pred_log.log", "a")
out_file = open(f"{cl}Data/derived/log.log", "a")

# Get the date and time
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")

# Write the run header to the predictor log file and to the general logfile. The execution line is
# written once (it used to be duplicated) and no stray quote/comma ends up in the log any more.
print(f"""################################################################## \n
predictor was executed at {dt_string}""",file=pred_log_file)
print(f"""##########\n
program {parser.prog} was executed at {dt_string}\n
argument passed:{args.X_train}, {args.X_test}, {args.y_train} and {args.y_test}\n
number of required arguments:4 """,file=out_file)

#######################################################################################################################
########################################################################################################################

# Step 1: Open the data and reset the index for each file using the first unnamed column.
# Step 1.0: Check that the files were provided in the right order before anything is read. Explicit
# exceptions are used because assert statements are skipped when Python runs with -O.
for expected_name, supplied_path in (("y_test", args.y_test), ("y_train", args.y_train),
                                     ("X_test", args.X_test), ("X_train", args.X_train)):
    if expected_name not in supplied_path:
        raise ValueError(f"the wrong file was provided for {expected_name}")
# Step 1.1: Open the y-train data.
y_train = pd.read_csv(args.y_train)
y_train.set_index("Unnamed: 0", drop=True, inplace=True)
# Step 1.2: Open the y-test data.
y_test = pd.read_csv(args.y_test)
y_test.set_index("Unnamed: 0", drop=True, inplace=True)
# Step 1.3: Open the X_train APAAC data (the file has no header row).
X_train = pd.read_csv(args.X_train, header=None)
# Step 1.4: Open the X_test APAAC data (the file has no header row).
X_test = pd.read_csv(args.X_test, header=None)

# Step 2: Check if the input lengths match. This has to happen BEFORE the y index is copied onto X:
# afterwards both indices are identical and the comparison could never fail.
if len(y_train.index) != len(X_train.index):
    raise ValueError("the training y and X dont have the same length")
if len(y_test.index) != len(X_test.index):
    raise ValueError("the test y and X dont have the same length")
# Step 2.1: Set an index for the X_training and testing data (needed downstream)
X_train.index = y_train.index
X_test.index = y_test.index

# Step 3: Miscellaneous
# Step 3.1: Reshape the column vectors to np.arrays
y_test_array = np.ravel(y_test)
y_train_array = np.ravel(y_train)
# Step 3.2: Add a metric not standardly used by sklearn.
mcc = make_scorer(matthews_corrcoef)
# Step 3.3: Get an empty data frame for the gridsearch results. The column order matters: the shared
# visualizations further down address some of these columns by position.
results_df = pd.DataFrame(columns=["test_roc_auc_mean","test_roc_auc_SE","valiation_mcc","test_mcc_mean","test_mcc_SE","test_fpr","test_tpr", "test_prec","test_rec"])


# Step 3.4: Fixing the random state (the estimators below hard-code the same value).
random_state = 7



# Step 4: Set up the pipeline for each model. Every pipeline reduces the input to 20 principal
# components; knn, svc and logistic regression standardize first, the tree based models run with
# the scaler step switched off (None).
knn_pipeline = Pipeline([
    ("scaler", StandardScaler()),
    ("pca", PCA(n_components=20)),
    ("knn20", KNeighborsClassifier()),
])
svc_pipeline = Pipeline([
    ("scaler", StandardScaler()),
    ("svc_pca", PCA(n_components=20)),
    ("svcPC20", svm.SVC(probability=True, class_weight=cl_weight)),
])
logreg_pipeline = Pipeline([
    ("scaler", StandardScaler()),
    ("pca", PCA(n_components=20)),
    ("logreg20", LogisticRegression(class_weight=cl_weight, penalty="l2")),
])
forest_pipeline = Pipeline([
    ("scaler", None),
    ("pca", PCA(n_components=20)),
    ("random_forest20", RandomForestClassifier(random_state=7, n_estimators=1000, class_weight=cl_weight)),
])
xgb_pipeline = Pipeline([
    ("scaler", None),
    ("pca", PCA(n_components=20)),
    ("xgb_model20", xgb.XGBClassifier(objective="binary:logistic", random_state=7, n_estimators=10)),
])
pipe = [knn_pipeline, svc_pipeline, logreg_pipeline, forest_pipeline, xgb_pipeline]

# Step 5: Parameter grid used during Grid Search, one dictionary per pipeline and in the same order
# as the pipelines above. Keys are "<name of the final pipeline step>__<parameter>".
knn_grid = {
    "knn20__n_neighbors": [1, 3, 5, 7, 9],
    "knn20__p": [1, 2, 3],
}
svc_grid = {
    "svcPC20__C": [1, 3, 10, 30, 100],
    "svcPC20__gamma": [0.0001, 0.0003, 0.001, 0.003, 0.01],
}
logreg_grid = {
    "logreg20__C": [0.0001, 0.001, 0.01, 0.1, 1, 10],
    "logreg20__solver": ["liblinear", "sag", "saga", "newton-cg"],
}
forest_grid = {
    "random_forest20__max_features": [4, 5, 6, 7],
    "random_forest20__max_leaf_nodes": [25, 50, 75, 100],
}
xgb_grid = {
    "xgb_model20__min_child_weight": [0.001, 0.01, 0.1, 1],
    "xgb_model20__reg_lambda": [0.1, 1, 10],
    "xgb_model20__max_depth": [7, 9, 11],
    "xgb_model20__subsample": [1],
    "xgb_model20__gamma": [0.001, 0.01, 1, 10],
    "xgb_model20__learning_rate": [0.2, 0.4, 1],
}
param_grid = [knn_grid, svc_grid, logreg_grid, forest_grid, xgb_grid]


# Loop over each pipe in combination with the defined parameter grids. The next steps are carried out for each pair.
for model_pipe, model_grid in zip(pipe, param_grid):
    # Step 6: Extract the classifier name for labeling graphs further down. Grid keys have the form
    # "<pipeline step name>__<parameter>", so the text before the first "__" is the step name.
    first_grid_key = str(list(model_grid.keys())[0])
    print(first_grid_key)
    classifier_name = first_grid_key.partition("__")[0]
    print(classifier_name)

    # Step 7: Apply the pipe to the grid search; mcc is the metric the hyperparameters are selected on.
    opti = GridSearchCV(model_pipe, param_grid=model_grid, cv=cv, scoring=mcc, return_train_score=True, n_jobs=-1)
    # Step 8: Train the model on the training dataset.
    opti.fit(X_train, y_train_array)

    # Step 9: Save the trained predictor, once with pickle and once with joblib. The pickle target is
    # opened in a with-block so that the file handle is closed again.
    model_path = f"{cl}Predictor/{where}_{classifier_name}_SST{sst_level}_CV{cv}"
    with open(model_path, "wb") as model_file:
        pickle.dump(opti, model_file)
    dump(opti, f"{model_path}.joblib")

    # Step 10: Explore the cross validation results using the TRAINING set.
    # Step 10.1: Log the best performing hyperparameter combination.
    print(f"###########################################################\n"
          f"###########################################################\n"
          f"###########################################################\n"
          f"Model type: {classifier_name}\n"
          f"GRID-SEARCH\n"
          f"the best parameters using the grid-search are:{opti.best_params_}\n"
          f"the best parameters had the following score (mcc) during grid-search:{opti.best_score_}\n",
          file=pred_log_file)
    # Step 10.2: Tabulate the whole grid-search.
    # Step 10.2.1: Extract the training and "testing" score from the grid search results
    # (testing refers here to the cross validation hold out results)
    data_grid_search = pd.DataFrame(opti.cv_results_)[['mean_train_score', 'mean_test_score', 'params']]
    # Step 10.2.2: Flatten the results to get rid of the nested dictionary. The axis is passed by
    # keyword (columns=) because the positional form drop('params', 1) is gone in pandas 2.
    data_grid_search = data_grid_search.drop(columns='params').assign(
        **pd.DataFrame(data_grid_search.params.values.tolist()))
    # Step 10.2.3: Log the results
    print(f"###########################################################\n"
          f"Model type: {classifier_name}\n"
          f" GRID-SEARCH INTERMEDIATE RESULTS\n"
          f"during the gridsearch the following results were produced:\n"
          f"{data_grid_search.to_string()}\n",
          file=pred_log_file)
    print(data_grid_search.to_string())

    # Step 11: Explore the performance of the different models (using the optimized parameters) on the TEST-SET.
    # The predictions are computed once and reused by every metric below.
    y_test_pred = opti.predict(X_test)
    y_test_score = opti.predict_proba(X_test)[:, 1]
    # Step 11.1: confusion matrix
    cnf = confusion_matrix(y_test_array, y_test_pred)
    pretty_cnf_matrix(cnf, classifier_name, sst_level=sst_level, cl=cl, cv=cv, cl_weight=cl_weight,
                      where=where)  # custom function for visualization
    # Step 11.2: Calculate the most important metrics. Each bootstrap is run exactly once so that the
    # values in the logfile and the values used for the shared visualizations are the same numbers.
    roc_auc_boot = bootstrap_metric(metric=roc_auc_score, y_true=y_test_array, y_pred=y_test_score, n_boot=1000)
    mcc_boot = bootstrap_metric(metric=matthews_corrcoef, y_true=y_test_array, y_pred=y_test_pred, n_boot=1000)
    accuracy_boot = bootstrap_metric(metric=accuracy_score, y_true=y_test_array, y_pred=y_test_pred, n_boot=1000)
    precision_boot = bootstrap_metric(metric=precision_score, y_true=y_test_array, y_pred=y_test_pred, n_boot=1000)
    recall_boot = bootstrap_metric(metric=recall_score, y_true=y_test_array, y_pred=y_test_pred, n_boot=1000)
    roc_auc_score_mean, roc_auc_score_error = roc_auc_boot
    mcc_bootstrap_mean, mcc_bootstrap_error = mcc_boot
    fpr, tpr, _ = roc_curve(y_test_array, y_test_score)
    prec, rec, _ = precision_recall_curve(y_test_array, y_test_score)
    # Step 11.3: Save the results in the results dataframe for shared visualization. DataFrame.append
    # no longer exists in pandas 2, so the row is concatenated; the column order of results_df is kept
    # because the shared plots address columns by position.
    result_row = pd.DataFrame([{"test_roc_auc_mean": roc_auc_score_mean,
                                "test_roc_auc_SE": roc_auc_score_error,
                                "valiation_mcc": opti.best_score_,
                                "test_mcc_mean": mcc_bootstrap_mean,
                                "test_mcc_SE": mcc_bootstrap_error,
                                "test_fpr": fpr,
                                "test_tpr": tpr,
                                "test_prec": prec,
                                "test_rec": rec}],
                              index=[classifier_name], columns=results_df.columns)
    results_df = result_row if results_df.empty else pd.concat([results_df, result_row])
    # Step 11.4: Log the performance on the test set into the predictor logfile.
    print(f"###########################################################\n"
          f"Model type: {classifier_name}\n"
          f"\nTESTING:\nthe following metrics were calculated on the test set, the SE calculated using bootstrapping:\n"
          f"mcc and mcc_SE:\n "
          f"{mcc_boot}\n"
          f"accuracy and accuracy_SE:\n "
          f"{accuracy_boot}\n"
          f"precision and precision_SE:\n "
          f"{precision_boot}\n"
          f"recall and recall_SE:\n "
          f"{recall_boot}\n"
          f"ROC AUC:\n"
          f"{roc_auc_boot}\n"
          f"###########################################################\n",
          file=pred_log_file)

    # Step 12: Feature importance, for the different models the features importance or the closest equivalent is
    # stored under different names.
    # NOTE(review): coef_ / feature_importances_ belong to the final step, which is fitted on the 20 PCA
    # components, while X_test.columns names the original input columns - confirm that
    # vis_features_imp_embs copes with the two lengths being different.
    final_estimator = opti.best_estimator_.named_steps[classifier_name]
    if hasattr(final_estimator, "coef_"):
        print("has coef")
        coefs = final_estimator.coef_[0]
        vis_features_imp_embs(coefs, sst_level, cl, cv, classifier_name, X_test.columns, cl_weight, where=where)
    elif hasattr(final_estimator, "feature_importances_"):
        print("has feature_importance")
        coefs = final_estimator.feature_importances_
        vis_features_imp_embs(coefs, sst_level, cl, cv, classifier_name, X_test.columns, cl_weight, where=where)
    elif classifier_name.startswith("svc"):
        # The svc step is called "svcPC20", so the former comparison with the literal "svc" never
        # matched and the permutation importance was never computed.
        coefs = permutation_importance(opti, X_test, y_test_array).importances_mean
        print(coefs)
        vis_features_imp_embs(coefs, sst_level, cl, cv, classifier_name, X_test.columns, cl_weight, where=where)
    else:
        print(f"{classifier_name} doesnt have the attribute coef_ nor feature importance")
    # Step 13: Visualization of the learning curve for the optimized hyperparameter. This is a check to see
    # overfitting and if more data is useful.
    # The seaborn style sheet was renamed to "seaborn-v0_8" in newer matplotlib releases; use whichever exists.
    plt.style.use("seaborn" if "seaborn" in plt.style.available else "seaborn-v0_8")
    if learn_curve:
        # Step 13.1: Calculate the training and testing scores for different partitioning sizes of the dataset.
        # No scoring is passed, so learning_curve falls back to the score method of the estimator; for the
        # grid search object that is the scorer it was created with (mcc).
        train_sizes, train_scores, test_scores = learning_curve(estimator=opti, X=X_train, y=y_train_array,
                                                                train_sizes=[0.1, 0.25, 0.5, 0.75, 1], cv=5,
                                                                n_jobs=-1)
        # Step 13.2: Get the averages and the standard deviations from the cross validation scores
        train_mean = np.mean(train_scores, axis=1)
        train_std = np.std(train_scores, axis=1)
        test_mean = np.mean(test_scores, axis=1)
        test_std = np.std(test_scores, axis=1)
        # Step 13.3: Save the results as a figure. Means are plotted as separate lines, the standard deviation is
        # plotted as an interval around the line.
        fig, ax = plt.subplots(figsize=(4, 4))
        plt.plot(train_sizes, train_mean, label= "Training mcc", color="blue", linestyle='dashed')
        plt.fill_between(train_sizes, train_mean + train_std, train_mean - train_std, alpha=0.15, color="blue")
        plt.plot(train_sizes, test_mean, label="validation mcc", color="green", linestyle='dashed')
        plt.fill_between(train_sizes, test_mean + test_std, test_mean - test_std, alpha=0.15, color="green")
        plt.xlabel(" Number of training examples")
        plt.ylabel("mcc")
        plt.legend()
        plt.savefig(f"{cl}Figures/{where}_predictor_hyperparameter/SSTlevel{sst_level}/learning_curves/SST{sst_level}_learning_curves_{classifier_name}_CV{cv}_classweight{cl_weight}.png",bbox_inches="tight")

# Step 14: Shared visualizations of all baseline models:
# Line styles and colors are picked by model position, so there must be at least as many entries as models.
line_list = ["solid", "dotted", "dashed", "dashdot", "dotted", "dashed", "dashdot", "solid", "dotted", "dashed",
             "dashdot", ]
color_list = ['#377eb8', '#ff7f00', '#4daf4a', '#f781bf', '#a65628', '#984ea3', '#999999', '#e41a1c', '#dede00']



# Step 14.1:  Shared visualization of the ROC curve on test data.
# The seaborn style sheet was renamed to "seaborn-v0_8" in newer matplotlib releases; use whichever exists.
plt.style.use("seaborn" if "seaborn" in plt.style.available else "seaborn-v0_8")
fig, ax = plt.subplots(figsize=(4, 4))
for num, i in enumerate(results_df.index):
    print(f"num: {num}")
    print(f"i: {i}")
    l_style = line_list[num]
    col = color_list[num]
    # Columns are addressed by name, which keeps the plot correct even if the column order changes.
    fpr = results_df["test_fpr"].iloc[num]
    tpr = results_df["test_tpr"].iloc[num]
    auc_mean = results_df["test_roc_auc_mean"].iloc[num]
    auc_se = results_df["test_roc_auc_SE"].iloc[num]
    # Both numbers are shown with two decimals (the error used to be rounded to 21 digits, i.e. not at all).
    plt.plot(fpr, tpr, label=f"{str(i)}: {round(auc_mean,2)} (+/-{round(auc_se,2)}) ROC AUC", linestyle=l_style, color=col)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title(f"ROC for SST{sst_level} data")
plt.legend()
plt.rcParams['savefig.dpi'] = 300
plt.savefig(f"{cl}Figures/{where}_predictor_hyperparameter/SSTlevel{sst_level}/SST{sst_level}_rocauc_testset_mcc_CV{cv}_classweight{cl_weight}.png",bbox_inches="tight")

# Step 14.2: Shared visualization of the precision recall curve on test data.
# One line per model; precision is drawn on the x axis and recall on the y axis.
fig, ax = plt.subplots(figsize=(4, 4))
for curve_idx, model_name in enumerate(results_df.index):
    model_precision = results_df["test_prec"].iloc[curve_idx]
    model_recall = results_df["test_rec"].iloc[curve_idx]
    ax.plot(model_precision, model_recall,
            label=f"{str(model_name)}",
            linestyle=line_list[curve_idx],
            color=color_list[curve_idx])
ax.set_xlabel("Precision")
ax.set_ylabel("Recall=True Positive Rate")
ax.set_title(f"Precision Recall Curve for SST{sst_level} data ")
ax.legend()
plt.rcParams['savefig.dpi'] = 300
pr_curve_path = f"{cl}Figures/{where}_predictor_hyperparameter/SSTlevel{sst_level}/SST{sst_level}_precisionrecall_testset_mcc_CV{cv}_classweight{cl_weight}.png"
fig.savefig(pr_curve_path, bbox_inches="tight")



# Step 14.4: Shared visualization of the cross validation mcc (best grid-search score per model),
# no error available.
fig, ax = plt.subplots(figsize=(4, 4))
ax.stem(results_df.index, results_df["valiation_mcc"])
ax.set_xlabel("model")
plt.xticks(rotation=90)
ax.set_ylabel("mcc ")
ax.set_title(f"mcc during cross validation based on SST{sst_level}")
ax.legend()
ax.set_ylim(0, 1)
plt.rcParams['savefig.dpi'] = 300
validation_mcc_path = f"{cl}Figures/{where}_predictor_hyperparameter/SSTlevel{sst_level}/SST{sst_level}_mcc_trainset_mcc_CV{cv}_classweight{cl_weight}.png"
fig.savefig(validation_mcc_path, bbox_inches="tight")


# Step 14.5: Shared visualization of the test mcc.
fig, ax = plt.subplots(figsize=(4, 4))
test_mcc_means = results_df["test_mcc_mean"]
# First layer: the lollipop plot of the bootstrapped mean mcc per model.
ax.stem(results_df.index, test_mcc_means)
# Second layer: only the error bars (bootstrapped standard error), no markers or lines.
ax.errorbar(results_df.index, test_mcc_means,
            yerr=results_df["test_mcc_SE"],
            fmt="none",
            ecolor="grey",
            capsize=10,
            markeredgewidth=1)
ax.set_xlabel("model")
plt.xticks(rotation=90)
ax.set_ylabel("mcc ")
ax.set_title(f"mcc based on hold out test set SST{sst_level}")
ax.legend()
plt.rcParams['savefig.dpi'] = 300
test_mcc_path = f"{cl}Figures/{where}_predictor_hyperparameter/SSTlevel{sst_level}/SST{sst_level}_mcc_testset_mcc_CV{cv}_classweight{cl_weight}.png"
fig.savefig(test_mcc_path, bbox_inches="tight")

#
def autolabel(rects):
    """Write the height of each bar in *rects*, rounded to two decimals, inside the top of the bar.

    Args:
        rects: iterable of bar patches (e.g. the container returned by ``plt.bar``). Each element
            must provide ``get_x()``, ``get_width()``, ``get_height()`` and an ``axes`` attribute.
    """
    for rect in rects:
        height = rect.get_height()
        # Annotate on the axes that owns the bar instead of relying on a module-level ``ax``, and
        # anchor the text at the true bar height; only the displayed number is rounded.
        rect.axes.annotate('{}'.format(round(height, 2)),
                           xy=(rect.get_x() + rect.get_width() / 2, height),
                           xytext=(0, -12),  # shift the text 12 points down from the bar top
                           textcoords="offset points",
                           ha='center', va='bottom', color="white")
#
# Step 14.6: Shared visualization of the test mcc in bar format.
fig, ax = plt.subplots(figsize=(4, 4))
bar_heights = results_df["test_mcc_mean"]
# First layer: one bar per model showing the bootstrapped mean mcc.
bars = ax.bar(results_df.index, bar_heights)
# Second layer: only the error bars (bootstrapped standard error), no markers or lines.
ax.errorbar(results_df.index, bar_heights,
            yerr=results_df["test_mcc_SE"],
            fmt="none",
            ecolor="grey",
            capsize=10,
            markeredgewidth=1)
ax.set_xlabel("model")
plt.xticks(rotation=90)
ax.set_ylabel("mcc ")
ax.set_title(f"mcc based on hold out test set SST{sst_level}")
ax.legend()
plt.rcParams['savefig.dpi'] = 300
# Only the range from 0.6 to 1 is shown on the y axis.
ax.set_ylim(0.6, 1)
# Print the rounded mcc value into each bar.
autolabel(bars)

bar_plot_path = f"{cl}Figures/{where}_predictor_hyperparameter/SSTlevel{sst_level}/bar_SST{sst_level}_mcc_testset_mcc_CV{cv}_classweight{cl_weight}.png"
fig.savefig(bar_plot_path, bbox_inches="tight")


# Close both logfiles (general log first, then the predictor log) and release every open figure.
for log_handle in (out_file, pred_log_file):
    log_handle.close()
plt.close('all')
