#!/usr/bin/env python3
# coding: utf-8


# ######################################################################################################################
# Author: Tanja Krüger
# Aim: This file visualizes the amino acid composition, the average length, the aromaticity and the logos in joined
# graphics.
# Input: four fasta files with the separate bacterial and animal toxins, and four matrices that hold the surprise of
# the amino acids to occur in one of two sets.
# Output: a series of files that show the length distribution, the pI, instability, aromaticity distribution and logos
# in combined graphs.


########################################################################################################################
# downloaded
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import re, argparse
from datetime import datetime
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from Bio import SeqIO
import logomaker
from scipy.stats import gaussian_kde
import csv
# #################################################################################################
# Path prefix prepended to the log, csv and figure output paths; it depends on where the user runs the code from.
# Default (empty prefix): the code is run with make from the project folder.
cl=""
# If one wants to execute this file from the Code/python folder, uncomment the next line.
#cl="../../"

########################################################################################################################
# Get the arguments from the command line.
parser = argparse.ArgumentParser(prog="data_analysis_3.py",
                                 description="shared visualizing protein length, PI and aromaticity")

# Positional arguments in the order they must be given on the command line.
# NOTE: argparse expands help strings with %-formatting, so a literal percent sign has to be written as "%%";
# a bare "%" garbles (or breaks) the output of --help.
_POSITIONAL_ARGUMENTS = (
    # the four redundancy-reduced fasta files
    ("at100", "fasta file animal toxins 100%% reduced"),
    ("ac100", "fasta file animal control 100%% reduced"),
    ("bt100", "fasta file bacterial toxins 100%% reduced"),
    ("bc100", "fasta file bacterial control 100%% reduced"),
    # the four surprise matrices
    ("matrix_atac", "surprise matrix animal toxins v animal control"),
    ("matrix_atbt", "surprise matrix animal toxins v bacterial toxins"),
    ("matrix_acbc", "surprise matrix animal control v bacterial control"),
    ("matrix_btbc", "surprise matrix bacterial toxins v bacterial control"),
)
for _arg_name, _arg_help in _POSITIONAL_ARGUMENTS:
    parser.add_argument(_arg_name, type=str, help=_arg_help)

args = parser.parse_args()


########################################################################################################################
# Step 1: Check whether all datasets carry an mmseqs2 redundancy reduction tag and, if so, whether they share the same level.
# The mmseqs2 reduction level is encoded in the file names as "SST<digits>".
_SST_PATTERN = re.compile(r"SST(\d+)")
try:
    # re.search returns None when a name has no tag, so .group raises AttributeError.
    sst_level1 = _SST_PATTERN.search(args.at100).group(1)
    sst_level2 = _SST_PATTERN.search(args.ac100).group(1)
    sst_level3 = _SST_PATTERN.search(args.bt100).group(1)
    sst_level4 = _SST_PATTERN.search(args.bc100).group(1)
except AttributeError:
    # At least one file name carries no reduction tag: treat the data as unreduced.
    sst_level1 = "full unreduced"
else:
    # All four sets are reduced; they must share one level. This is raised explicitly (instead of an assert inside
    # a bare except) so that the mismatch is reported rather than silently relabelled as "full unreduced".
    if not sst_level1 == sst_level2 == sst_level3 == sst_level4:
        raise ValueError("all sets are redundancy reduce but they do not share the same level of reduction")

# Step2: Log
# Step2.1: Open the predictor logfile and the general logfile.
# Both log files are opened in append mode and stay open for the rest of the script;
# they are closed explicitly at the very end of the file.
out_file = open(f"{cl}Data/derived/log.log", "a")
explore_file=open(f"{cl}Exploratory/smd_{sst_level1}.log","a")

# Step2.2: Get the date and time
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")

# Step2.3: Write one entry to the general logfile: program name, time stamp and every argument passed.
# The argument count is derived from the parsed namespace so it cannot drift from the parser definition.
print(f"""########## \n
program {parser.prog} was executed at {dt_string} \n
arguments passed: the 100 reduced animal toxins {args.at100} \n
                the 100 reduced bacterial toxins {args.bt100} \n
                the 100 reduced animal controls {args.ac100}\n
                the 100 reduced bacterial controls {args.bc100}\n
                the four surprise matrices {args.matrix_atac}, {args.matrix_atbt},{args.matrix_acbc},and {args.matrix_btbc},\n
number of required arguments:{len(vars(args))}""",file=out_file)

########################################################################################################################
# Step 3: Open the data
# Step 3.1: Open fasta files and modify it to  dataframe.
def _fasta_to_frame(fasta_path):
    """Read a fasta file into a dataframe indexed by record id with the columns 'info' and 'seq'."""
    with open(fasta_path) as handle:
        frame = pd.DataFrame(
            {record.id: [record.description, str(record.seq)] for record in SeqIO.parse(handle, "fasta")}).T
    # Step 3.2: name the two columns (record description and sequence string)
    frame.columns = ["info", "seq"]
    return frame


def _length_frame(frame, origin):
    """Return one row per sequence of `frame`: its length plus the dataset label `origin`."""
    rows = [(len(sequence), origin) for sequence in frame.seq]
    return pd.DataFrame(rows, columns=["length", "origin"])


at100 = _fasta_to_frame(args.at100)
ac100 = _fasta_to_frame(args.ac100)
bt100 = _fasta_to_frame(args.bt100)
bc100 = _fasta_to_frame(args.bc100)

# Step 3.3: Define the amino acids for the labels in the plot (one-letter codes, alphabetical).
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")

# Step 3.4: Open the matrices with the logo information
df_atac, df_atbt, df_acbc, df_btbc = (
    pd.read_csv(matrix_path)
    for matrix_path in (args.matrix_atac, args.matrix_atbt, args.matrix_acbc, args.matrix_btbc)
)

# Sequence lengths per dataset; the labels contain a line break.
lens_at = _length_frame(at100, "animal\ntoxins")
lens_ac = _length_frame(ac100, "animal\ncontrols")
lens_bt = _length_frame(bt100, "bacterial\ntoxins")
lens_bc = _length_frame(bc100, "bacterial\ncontrols")
# Stack the four tables (controls before toxins within each kingdom) under a fresh 0..n-1 index.
df_all_lens = pd.concat([lens_ac, lens_at, lens_bc, lens_bt], ignore_index=True)


# Write a dataframe row by row to a csv file (no header, no index column).
def writer(df, path=None):
    """Write every row of `df` to a csv file.

    Parameters:
        df (pandas.DataFrame): table to write; only the cell values are written, neither header nor index.
        path (str, optional): target file. Defaults to "<cl>Data/derived/length_output.csv".
    """
    if path is None:
        path = f"{cl}Data/derived/length_output.csv"
    # newline="" is what the csv module asks for: it stops the platform from translating the line endings the
    # csv writer emits (which otherwise yields blank rows on Windows).
    with open(path, "w", newline="") as f:
        csv_writer = csv.writer(f)
        # Iterate rows by position so the function also works when the index is not 0..n-1.
        csv_writer.writerows(df.itertuples(index=False, name=None))

# Write the combined sequence-length table of all four datasets to the csv file.
writer(df_all_lens)

# Step 4: Set the type of y_axis in the plots: "density" for probability density, "frequency" for raw counts.
case = "density"
# `density` is read by the histogram functions and passed on to Axes.hist.
_Y_AXIS_MODES = {"density": True, "frequency": False}
if case not in _Y_AXIS_MODES:
    # Fail right here: continuing would leave `density` undefined and crash later inside the plotting code.
    raise ValueError("case supplied must either be density or frequency")
density = _Y_AXIS_MODES[case]


########################################################################################################################
#Step 5.1: Aromaticity visualization 2x2
def aro_viz(df_1,df_2,data_name1,data_name2,col1,col2,ax):
    """Draw two overlaid histograms of protein aromaticity on `ax`.

    The aromaticity of every sequence is computed with Bio's ProteinAnalysis. Both histograms share one set of
    bins (width 0.01) spanning the pooled value range. The module-level flag `density` selects probability
    density versus raw counts.

    Parameters:
        df_1 (DataFrame): The first dataframe containing sequences in the 'seq' column.
        df_2 (DataFrame): The second dataframe containing sequences in the 'seq' column.
        data_name1 (str): The name of the first dataset to be used as a label in the plot.
        data_name2 (str): The name of the second dataset to be used as a label in the plot.
        col1 (str): The color of the first histogram.
        col2 (str): The color of the second histogram.
        ax (Axes): The axes object to plot on.
    Returns:
        The figure that owns `ax`.
    """
    list_of_aro1 = [ProteinAnalysis(sequence).aromaticity() for sequence in df_1.seq]
    list_of_aro2 = [ProteinAnalysis(sequence).aromaticity() for sequence in df_2.seq]
    pooled = list_of_aro1 + list_of_aro2
    bin_size = 0.01  # set bin size
    # Common bin edges so the two histograms are directly comparable.
    bins = np.arange(min(pooled), max(pooled) + bin_size, bin_size)
    ax.hist(list_of_aro1, bins=bins, label=f"{data_name1}", alpha=0.7, color=col1, density=density)
    ax.hist(list_of_aro2, bins=bins, label=f"{data_name2}", alpha=0.7, color=col2, density=density)
    ax.legend(loc='upper right')
    # Return the figure of the axes that was drawn on instead of relying on a global named `fig`.
    return ax.figure
########################################################################################################################
# Step 5.2: Isoelectric Point visualization  2x2
def pI_viz(df_1,df_2,data_name1,data_name2,col1,col2,ax):
    """Draw two overlaid histograms of the protein isoelectric point on `ax`.

    The isoelectric point of every sequence is computed with Bio's ProteinAnalysis. Both histograms share one
    set of bins (width 0.15) spanning the pooled value range. The module-level flag `density` selects
    probability density versus raw counts.

    Parameters:
        df_1 (DataFrame): The first dataframe containing sequences in the 'seq' column.
        df_2 (DataFrame): The second dataframe containing sequences in the 'seq' column.
        data_name1 (str): The name of the first dataset to be used as a label in the plot.
        data_name2 (str): The name of the second dataset to be used as a label in the plot.
        col1 (str): The color of the first histogram.
        col2 (str): The color of the second histogram.
        ax (Axes): The axes object to plot on.
    Returns:
        The figure that owns `ax`.
    """
    list_of_pi1 = [ProteinAnalysis(sequence).isoelectric_point() for sequence in df_1.seq]
    list_of_pi2 = [ProteinAnalysis(sequence).isoelectric_point() for sequence in df_2.seq]
    pooled = list_of_pi1 + list_of_pi2
    bin_size = 0.15  # set bin size
    # Common bin edges so the two histograms are directly comparable.
    bins = np.arange(min(pooled), max(pooled) + bin_size, bin_size)
    ax.hist(list_of_pi1, bins=bins, label=f"{data_name1}", alpha=0.7, color=col1, density=density)
    ax.hist(list_of_pi2, bins=bins, label=f"{data_name2}", alpha=0.7, color=col2, density=density)
    ax.legend(loc='upper right')
    if density:
        # The fixed upper limit is on the probability-density scale; with raw counts it would clip every bar.
        ax.set_ylim(0, 0.6)
    # Return the figure of the axes that was drawn on instead of relying on a global named `fig`.
    return ax.figure

########################################################################################################################
# Step 5.3 :Logo visualization 2x2
def logo_viz(df,data_name1,data_name2,ax):
    """Draw a logomaker logo of the surprise matrix `df` on `ax`.

    Each amino acid letter gets one of two colors, chosen by the sign of the diagonal entry
    `df.iloc[k, k]` for the k-th letter of the module-level `amino_acids` list: the first color of the
    pair for values >= 0, the second otherwise. The color pair depends on which two datasets are compared.

    Parameters:
        df (pandas.DataFrame): A dataframe containing the surprise matrix for the logo plot.
        data_name1 (str): The name of the first data set (selects the color for values >= 0).
        data_name2 (str): The name of the second data set (selects the color for values < 0).
        ax (matplotlib.axes.Axes): The axes object to plot the logo on.
    Returns:
        matplotlib.axes.Axes: The axes object with the logo plot.
    """
    # (first set, second set) -> (color for surprise >= 0, color for surprise < 0)
    palettes = {
        ("animal_toxins", "animal_controls"): ('#B1041B', '#EABA49'),
        ("bacterial_toxins", "bacterial_controls"): ('#2156B5', '#61BDD2'),
        ("animal_toxins", "bacterial_toxins"): ('#B1041B', '#2156B5'),
        ("animal_controls", "bacterial_controls"): ('#EABA49', '#61BDD2'),
    }
    COLORS = {}
    palette = palettes.get((data_name1, data_name2))
    if palette is None:
        # Unknown comparison: warn once (not once per amino acid) and leave the color scheme empty.
        print("the wrong kind of file was provided, probably matrix in the wrong order")
    else:
        non_negative_color, negative_color = palette
        for num, amino_acid in enumerate(amino_acids):
            COLORS[amino_acid] = non_negative_color if df.iloc[num, num] >= 0 else negative_color
    # Create Logo object.
    lg = logomaker.Logo(df, flip_below=False, color_scheme=COLORS, figsize=[10, 5], ax=ax, zorder=1)
    # Show only the left and bottom spines.
    lg.style_spines(visible=False)
    lg.style_spines(spines=['left', 'bottom'], visible=True)
    lg.style_xticks(fmt='%d', anchor=0)
    return ax


########################################################################################################################
#Step 6: ACTUALLY CREATING THE SHARED PLOTS
# Step 6.1: aromaticity visualization
# The style "seaborn" was renamed to "seaborn-v0_8" in matplotlib 3.6 and the old name was removed in 3.8;
# use whichever name the installed matplotlib provides.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
fig, axs = plt.subplot_mosaic([["A", "B"], ["C", "D"]], layout='constrained', sharex=True, sharey=True)
# One comparison per panel: (first set, second set, first label, second label, first color, second color).
aro_panels = {
    "A": (bt100, bc100, "bacterial_toxins", "bacterial_controls", "#2156B5", "#61BDD2"),
    "B": (at100, ac100, "animal_toxins", "animal_controls", "#B1041B", "#EABA49"),
    "C": (at100, bt100, "animal_toxins", "bacterial_toxins", "#B1041B", "#2156B5"),
    "D": (ac100, bc100, "animal_controls", "bacterial_controls", "#EABA49", "#61BDD2"),
}
for panel, comparison in aro_panels.items():
    aro_viz(*comparison, axs[panel])
# Axis labels only on the outer panels: y label on the left column, x label on the bottom row.
for panel in ("A", "C"):
    axs[panel].set_ylabel(case)
for panel in ("C", "D"):
    axs[panel].set_xlabel(" aromaticity")
fig.suptitle("protein aromaticity", fontsize=20)
plt.rcParams['savefig.dpi'] = 300
# Put the panel letter in a grey box, slightly inset from the upper left corner of each axes.
for label, ax in axs.items():
    trans = mtransforms.ScaledTranslation(10/72, -5/72, fig.dpi_scale_trans)
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize='medium', verticalalignment='top', fontfamily='serif',
            bbox=dict(facecolor='0.7', edgecolor='none', pad=3.0))
plt.savefig(f"{cl}Figures/shared_aromaticity.png", bbox_inches="tight")

# Step 6.2: pI visualization
# The style "seaborn" was renamed to "seaborn-v0_8" in matplotlib 3.6 and the old name was removed in 3.8;
# use whichever name the installed matplotlib provides.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
fig, axs = plt.subplot_mosaic([["A", "B"], ["C", "D"]], layout='constrained', sharex=True, sharey=True)
# One comparison per panel: (first set, second set, first label, second label, first color, second color).
pi_panels = {
    "A": (bt100, bc100, "bacterial_toxins", "bacterial_controls", "#2156B5", "#61BDD2"),
    "B": (at100, ac100, "animal_toxins", "animal_controls", "#B1041B", "#EABA49"),
    "C": (at100, bt100, "animal_toxins", "bacterial_toxins", "#B1041B", "#2156B5"),
    "D": (ac100, bc100, "animal_controls", "bacterial_controls", "#EABA49", "#61BDD2"),
}
for panel, comparison in pi_panels.items():
    pI_viz(*comparison, axs[panel])
# Axis labels only on the outer panels: y label on the left column, x label on the bottom row.
for panel in ("A", "C"):
    axs[panel].set_ylabel(case)
for panel in ("C", "D"):
    axs[panel].set_xlabel(" isoelectric point")
fig.suptitle("isoelectric point", fontsize=20)
plt.rcParams['savefig.dpi'] = 300
# Put the panel letter in a grey box, slightly inset from the upper left corner of each axes.
for label, ax in axs.items():
    trans = mtransforms.ScaledTranslation(10/72, -5/72, fig.dpi_scale_trans)
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize='medium', verticalalignment='top', fontfamily='serif',
            bbox=dict(facecolor='0.7', edgecolor='none', pad=3.0))
plt.savefig(f"{cl}Figures/shared_isoelectric_point.png", bbox_inches="tight")

# Step 6.3: amino acid visualization per surprise  Logo 2x2
# The style "seaborn" was renamed to "seaborn-v0_8" in matplotlib 3.6 and the old name was removed in 3.8;
# use whichever name the installed matplotlib provides.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
fig, axs = plt.subplot_mosaic([["A", "B"], ["C", "D"]], layout='constrained')
# One surprise matrix per panel: (matrix, first set, second set).
logo_panels = {
    "A": (df_btbc, "bacterial_toxins", "bacterial_controls"),
    "B": (df_atac, "animal_toxins", "animal_controls"),      # animal v animal
    "C": (df_atbt, "animal_toxins", "bacterial_toxins"),     # toxins v toxins
    "D": (df_acbc, "animal_controls", "bacterial_controls"),  # controls v controls
}
for panel, (matrix, first_set, second_set) in logo_panels.items():
    logo_viz(matrix, first_set, second_set, axs[panel])
    # Grid below the letters (the logo is drawn with zorder=1); tick labels on the x axis are hidden.
    axs[panel].grid(True, zorder=0)
    axs[panel].set_xticklabels([])
# Axis labels only on the outer panels: y label on the left column, x label on the bottom row.
for panel in ("A", "C"):
    axs[panel].set_ylabel("Surprise")
for panel in ("C", "D"):
    axs[panel].set_xlabel(" Amino acids")
fig.suptitle("amino acid usage", fontsize=20)
plt.rcParams['savefig.dpi'] = 300
# Put the panel letter in a grey box, slightly inset from the upper left corner of each axes.
for label, ax in axs.items():
    trans = mtransforms.ScaledTranslation(10/72, -5/72, fig.dpi_scale_trans)
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize='medium', verticalalignment='top', fontfamily='serif',
            bbox=dict(facecolor='0.7', edgecolor='none', pad=3.0))
plt.savefig(f"{cl}Figures/shared_logos.png", bbox_inches="tight")
#

# #  Amino acid ratios
# colnames=['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q','R', 'S', 'T', 'V', 'W', 'Y']
# list_of_ratios1=[[sequence.count(i)/len(sequence) for i in colnames] for sequence in at100.seq]
# list_of_ratios2=[[sequence.count(i)/len(sequence) for i in colnames] for sequence in ac100.seq]
# list_of_ratios3 = [[sequence.count(i) / len(sequence) for i in colnames] for sequence in bt100.seq]
# list_of_ratios4 = [[sequence.count(i) / len(sequence) for i in colnames] for sequence in bc100.seq]
# df_ratios_per_aa1=pd.DataFrame(list_of_ratios1, columns=colnames)
# df_ratios_per_aa2=pd.DataFrame(list_of_ratios2, columns=colnames)
# df_ratios_per_aa3 = pd.DataFrame(list_of_ratios3, columns=colnames)
# df_ratios_per_aa4 = pd.DataFrame(list_of_ratios4, columns=colnames)
# bin_size = 0.005 # set bin size
# bins = np.arange(0,0.2 + bin_size, bin_size)
# plt.style.use("seaborn")
# fig, ax = plt.subplots(4,5,figsize=(17,17),sharex=True,)
# for num,i in enumerate(colnames):
#     yi=num%5 # location vertically
#     xi=num//5 # location horizontally
#     ax[xi, yi].hist(df_ratios_per_aa1.loc[:, i], bins=bins, alpha=0.5, color="#B1041B", density=density)
#     ax[xi, yi].hist(df_ratios_per_aa2.loc[:, i], bins=bins, alpha=0.5, color="#EABA49", density=density)
#     ax[xi, yi].hist(df_ratios_per_aa3.loc[:, i], bins=bins, alpha=0.5, color="#2156B5", density=density)
#     ax[xi, yi].hist(df_ratios_per_aa4.loc[:, i], bins=bins, alpha=0.5, color="#61BDD2", density=density)
#     ax[xi,yi].set_title(i,fontsize=15) # title
#     plt.rcParams['savefig.dpi'] = 300
#     fig.suptitle(f'relative frequency of each amino acid for  data\n y_axis different scales',fontsize=20) # title over all plots

#plt.show()

# # Calculate the upper 90 and 10 % quantile
# list_of_len_at=[len(i) for i in at100.seq]
# list_of_len_ac=[len(i) for i in ac100.seq]
# list_of_len_bt=[len(i)for i in bt100.seq]
# list_of_len_bc=[len(i) for i in bc100.seq]
#
#
# print(f"at:median {np.median(list_of_len_at)}")
# print(f"at:25percentile {np.percentile(list_of_len_at, 25)}")
# print(f"at:75percentile {np.percentile(list_of_len_at, 75)}")
# print(f"at:max {np.max(list_of_len_at)}")
# print("XXXXXXXXXXXXXXXXXXXXXXX")
# print(f"ac:median {np.median(list_of_len_ac)}")
# print(f"ac:25percentile {np.percentile(list_of_len_ac, 25)}")
# print(f"ac:75percentile {np.percentile(list_of_len_ac, 75)}")
# print(f"ac:max {np.max(list_of_len_ac)}")
# print("XXXXXXXXXXXXXXXXXXXXXXX")
# print(f"bt:median {np.median(list_of_len_bt)}")
# print(f"bt:25percentile {np.percentile(list_of_len_bt, 25)}")
# print(f"bt:75percentile {np.percentile(list_of_len_bt, 75)}")
# print(f"bt:max {np.max(list_of_len_bt)}")
# print("XXXXXXXXXXXXXXXXXXXXXXX")
# print(f"bc:median {np.median(list_of_len_bc)}")
# print(f"bc:25percentile {np.percentile(list_of_len_bc, 25)}")
# print(f"bc:75percentile {np.percentile(list_of_len_bc, 75)}")
# print(f"bc:max {np.max(list_of_len_bc)}")
# print("XXXXXXXXXXXXXXXXXXXXXXX")




#log the results
# Both lines belong in the exploratory log; previously the dataset label was printed to stdout instead.
print("#####################################",file=explore_file)
print(f"data SST{sst_level1}",file=explore_file)

# Release the two log files opened at the start of the script and all matplotlib figures.
out_file.close()
explore_file.close()
plt.close("all")



#