#!/usr/bin/env python3
# coding: utf-8


# ######################################################################################################################
# Author: Tanja Krüger 
# Aim: This file visualizes the amino acid composition, the isoelectric points, and the aromaticity, using fungal toxins as input.


########################################################################################################################
# downloaded
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import re, argparse
from datetime import datetime
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from Bio import SeqIO
from scipy.stats import gaussian_kde
import csv
import logomaker

# #################################################################################################
# Prefix that is put in front of every output path. The empty default assumes the script is started
# from the project folder (for example through make).
cl = ""
# To run the script from inside the Code/python folder instead, uncomment the next line.
#cl="../../"

########################################################################################################################
# Get the arguments from the command line.
# Command-line interface: four FASTA inputs followed by two surprise-matrix files, all positional.
# argparse applies %-formatting to help strings, so a literal percent sign must be written as "%%";
# an unescaped "100% reduced" is otherwise interpreted as a format specifier and garbles the -h output.
parser = argparse.ArgumentParser(prog="data_analysis_6.py",
                                 description="updated shared visualizing protein length, PI and aromaticity")
for _arg_name, _arg_help in (
    ("at100", "fasta file animal toxins 100%% reduced"),
    ("ft100", "fasta file fungal toxins 100%% reduced"),
    ("bt100", "fasta file bacterial toxins 100%% reduced"),
    ("bc100", "fasta file bacterial controls 100%% reduced"),
    # matrix_ftat is the fungal-vs-animal matrix (it is later plotted with the animal-toxin labels).
    ("matrix_ftat", "surprise matrix fungal toxins v animal toxins"),
    ("matrix_ftbt", "surprise matrix fungal toxins v bacterial toxins"),
):
    parser.add_argument(_arg_name, type=str, help=_arg_help)


# Parse the six positional command-line arguments into the `args` namespace.
args = parser.parse_args()


########################################################################################################################
# Step 0: Derive the shared mmseqs2 redundancy-reduction level from the input file names.
# Every FASTA path is searched for an "SST<digits>" tag. When all four paths carry a tag, the tags
# must agree and the common level is stored in sst_level1. When at least one path has no tag, the
# data are treated as unreduced. Only the expected "no tag" case falls back silently; a genuine
# mismatch between tagged files is reported instead of being hidden.
_sst_hits = [re.search(r"SST(\d+)", _path)
             for _path in (args.at100, args.ft100, args.bt100, args.bc100)]
if all(_sst_hits):
    _sst_levels = {_hit.group(1) for _hit in _sst_hits}
    if len(_sst_levels) > 1:
        raise ValueError(
            "all sets are redundancy reduced but they do not share the same level of reduction: "
            f"{sorted(_sst_levels)}")
    sst_level1 = _sst_hits[0].group(1)
else:
    sst_level1 = "full unreduced"



########################################################################################################################
# Step 1: Open the data
# Step 1.1: Read every FASTA file into a DataFrame with one row per record id.
def _load_fasta(path):
    """Return a DataFrame indexed by record id whose two columns hold description and sequence."""
    with open(path) as handle:
        entries = {}
        for record in SeqIO.parse(handle, "fasta"):
            entries[record.id] = [record.description, str(record.seq)]
        # The dict is keyed by record id, so the ids end up as columns; transpose to get them as rows.
        return pd.DataFrame(entries).T


at100 = _load_fasta(args.at100)
ft100 = _load_fasta(args.ft100)
bt100 = _load_fasta(args.bt100)
bc100 = _load_fasta(args.bc100)
# Step 1.2: Give the two columns of every dataframe their names.
for _frame in (at100, ft100, bt100, bc100):
    _frame.columns = ["info", "seq"]

# Step 1.3: One-letter codes of the 20 amino acids in alphabetical order, used for the plot labels.
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")

# Step 1.4: Load the two matrices that hold the information for the logo plots.
df_ftat, df_ftbt = (pd.read_csv(_matrix_path) for _matrix_path in (args.matrix_ftat, args.matrix_ftbt))

# Step 2: Set the font parameters shared by all visualizations.
_FONT_SETTINGS = (
    ('axes.labelsize', 16),    # x and y axis labels
    ('xtick.labelsize', 14),   # x-axis tick labels
    ('ytick.labelsize', 14),   # y-axis tick labels
    ('legend.fontsize', 14),   # legend entries
    ('axes.titlesize', 18),    # subplot titles
    ('font.family', 'serif'),  # font family
)
for _rc_key, _rc_value in _FONT_SETTINGS:
    plt.rcParams[_rc_key] = _rc_value

## Step 3: Choose the y-axis type of the plots: "density" (probability density) or "frequency" (raw counts).
case = "density"
# Flag value that belongs to each accepted setting.
_DENSITY_FLAG_BY_CASE = {"density": True, "frequency": False}
if case in _DENSITY_FLAG_BY_CASE:
    density = _DENSITY_FLAG_BY_CASE[case]
else:
    print("case supplied must either be density or frequency")




########################################################################################################################
# Step 4: Functions
def _overlay_attribute_kdes(attribute, bw_method, datasets, ax):
    ''' Draw one kernel-density curve per dataset on ax (shared implementation of viz2 and viz3).
        attribute (str): Name of a zero-argument ProteinAnalysis method, e.g. 'isoelectric_point' or 'aromaticity'.
        bw_method: Bandwidth setting forwarded to scipy's gaussian_kde.
        datasets: Sequence of (dataframe, label, color, linestyle) tuples; each dataframe needs a 'seq' column.
        ax (Axes): The axes object to plot on.
    Returns:
        The figure that owns ax.'''
    # Compute the attribute once per sequence and keep the values grouped by dataset.
    values_per_set = [[getattr(ProteinAnalysis(seq), attribute)() for seq in frame.seq]
                      for frame, _label, _color, _style in datasets]
    if attribute == 'aromaticity':
        # Aromaticity is always drawn on a fixed 0 to 0.3 window.
        x_range = np.linspace(0, 0.3, 1000)
    else:
        # Every curve is evaluated on the same grid, spanning the pooled minimum to maximum.
        pooled = [value for values in values_per_set for value in values]
        x_range = np.linspace(min(pooled), max(pooled), 1000)
        ax.set_ylim(0, 0.5)
    curves = [gaussian_kde(values, bw_method=bw_method) for values in values_per_set]
    for curve, (_frame, label, color, style) in zip(curves, datasets):
        ax.plot(x_range, curve(x_range), label=label, color=color, linewidth=2, linestyle=style)
    ax.legend(loc='upper right',fontsize=16)
    # Return the figure belonging to ax rather than whatever a global named `fig` currently points to.
    return ax.figure


def viz2(attribute,bw_method,df_1,df_2,data_name1,data_name2,col1,col2,ax):
    ''' Plot the kernel density estimates of one sequence attribute for two datasets on ax.
        attribute: The attribute that is plotted, can be isoelectric_point or aromaticity.
        bw_method: The smoothness of the curve, passed to gaussian_kde; larger numbers result in smoother curves.
        df_1 (DataFrame): The first dataframe containing sequences in the 'seq' column.
        df_2 (DataFrame): The second dataframe containing sequences in the 'seq' column.
        data_name1 (str): The name of the first dataset, used as legend label (solid line).
        data_name2 (str): The name of the second dataset, used as legend label (dashed line).
        col1 (str): The color of the first curve.
        col2 (str): The color of the second curve.
        ax (Axes): The axes object to plot on.
    Returns:
        The matplotlib figure that owns ax.'''
    return _overlay_attribute_kdes(
        attribute, bw_method,
        [(df_1, data_name1, col1, "-"), (df_2, data_name2, col2, "--")],
        ax)


def viz3(attribute,bw_method,df_1,df_2,df_3,data_name1,data_name2,data_name3,col1,col2,col3,ax):
    ''' Plot the kernel density estimates of one sequence attribute for three datasets on ax.
        attribute: The attribute that is plotted, can be isoelectric_point or aromaticity.
        bw_method: The smoothness of the curve, passed to gaussian_kde; larger numbers result in smoother curves.
        df_1 (DataFrame): The first dataframe containing sequences in the 'seq' column.
        df_2 (DataFrame): The second dataframe containing sequences in the 'seq' column.
        df_3 (DataFrame): The third dataframe containing sequences in the 'seq' column.
        data_name1 (str): The name of the first dataset, used as legend label (solid line).
        data_name2 (str): The name of the second dataset, used as legend label (dashed line).
        data_name3 (str): The name of the third dataset, used as legend label (dashed line).
        col1 (str): The color of the first curve.
        col2 (str): The color of the second curve.
        col3 (str): The color of the third curve.
        ax (Axes): The axes object to plot on.
    Returns:
        The matplotlib figure that owns ax.'''
    return _overlay_attribute_kdes(
        attribute, bw_method,
        [(df_1, data_name1, col1, "-"), (df_2, data_name2, col2, "--"), (df_3, data_name3, col3, "--")],
        ax)

def calculate_kde(sequences, bw_method, attribute):
    """ Build a kernel density estimation (KDE) of one attribute over a collection of sequences.
    arguments:
    sequences: iterable of protein sequences; each one is wrapped in ProteinAnalysis.
    bw_method: bandwidth setting handed to gaussian_kde (level of detail of the KDE line).
    attribute: name of the ProteinAnalysis method to call, e.g. isoelectric_point or aromaticity.
    returns:
    the gaussian_kde object fitted on the attribute values """
    attribute_values = []
    for sequence in sequences:
        compute = getattr(ProteinAnalysis(sequence), attribute)
        attribute_values.append(compute())
    estimate = gaussian_kde(attribute_values, bw_method=bw_method)
    return estimate

def plot_data(ax, kde, x_range, label, color, linestyle):
    """ Draw the curve of one kernel density estimation on the given axes.
    Parameters:
    - ax: The axes the line is drawn on.
    - kde: Callable kernel density estimation; it is evaluated on x_range.
    - x_range: The x values (pI or aromaticity) at which the curve is evaluated.
    - label: Legend name of this group of data (toxins, controls..).
    - color: The color of the line.
    - linestyle: The style of the line (solid, dashed, etc.)."""
    curve_heights = kde(x_range)
    line_options = {"label": label, "color": color, "linewidth": 2, "linestyle": linestyle}
    ax.plot(x_range, curve_heights, **line_options)

def save_plot(fig, filename, axis_name):
    """ Add legend and axis labels to a figure and save it in the Figures folder.
    Parameters:
    - fig: The figure to finish and save; its current axes receive the legend and labels.
    - filename: File name (including extension) used inside the Figures folder.
    - axis_name: Text of the x-axis label.
    The y-axis label is always "density". The output path is prefixed with the module-level `cl`."""
    # Work on the figure that was passed in instead of whatever pyplot considers the current figure.
    ax = fig.gca()
    ax.legend(loc='upper right',fontsize=16)
    ax.set_xlabel(axis_name,fontsize=16)
    ax.set_ylabel("density",fontsize=16)
    # Set globally (not only for this call) so that later saves keep using 300 dpi as well.
    plt.rcParams['savefig.dpi'] = 300
    ax.tick_params(axis='both', which='major', labelsize=14)
    fig.savefig(f"{cl}Figures/{filename}", bbox_inches="tight")

def logo_viz(df,data_name1,data_name2,ax):
    """ Draw a logomaker logo of a surprise matrix on the given axes.
        Each amino acid letter is colored by the sign of its diagonal entry in df: green when the value
        is >= 0, otherwise the color that belongs to the second dataset (red for animal toxins, blue for
        bacterial toxins).
        Parameters:
        df (pandas.DataFrame): The surprise matrix for the logo plot.
            NOTE(review): the coloring reads df.iloc[i, i] for the i-th entry of amino_acids, so it
            assumes rows and columns follow that order - confirm against the matrix files.
        data_name1 (str): The name of the first data set; only "fungal_toxins" is recognized.
        data_name2 (str): The name of the second data set; "animal_toxins" or "bacterial_toxins".
        ax (matplotlib.axes.Axes): The axes object to plot the logo on.
        Returns:
        matplotlib.axes.Axes: The axes object with the logo plot.
        """
    positive_color = '#00A86B'
    negative_color_by_pair = {
        ("fungal_toxins", "animal_toxins"): '#B1041B',
        ("fungal_toxins", "bacterial_toxins"): '#2156B5',
    }
    # The dataset pair does not change per amino acid, so it is resolved once before the loop.
    negative_color = negative_color_by_pair.get((data_name1, data_name2))
    COLORS = {}
    if negative_color is None:
        # Unknown pair: report it once and continue with an empty color scheme.
        print("the wrong kind of file was provided, probably matrix in the wrong order")
    else:
        for num, residue in enumerate(amino_acids):
            COLORS[residue] = positive_color if df.iloc[num, num] >= 0 else negative_color
    # Create Logo object.
    lg=logomaker.Logo(df,flip_below=False,color_scheme=COLORS,figsize=[10, 5],ax=ax,zorder=1)
    # Hide all spines, then show only the left and bottom ones again.
    lg.style_spines(visible=False)
    lg.style_spines(spines=['left', 'bottom'], visible=True)
    lg.style_xticks(fmt='%d', anchor=0)
    return ax


########################################################################################################################
#Step 5: ACTUALLY CREATING THE SHARED PLOTS
# Step 5.1: Isoelectric point visualization in a figure with two stacked panels ("A" on top of "B").
# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8..." and newer releases no longer
# accept the old name, so the new name is used whenever this installation offers it.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
fig,  axs= plt.subplot_mosaic([["A"],[ "B"]],layout='constrained')
# Panel A: fungal toxins against animal and bacterial toxins.
viz3('isoelectric_point',0.075,ft100,at100,bt100,"fungal toxins","animal toxins","bacterial toxins","#00A86B","#B1041B","#2156B5",axs['A'])
axs['A'].set_ylabel(case,fontsize=16)
axs['A'].tick_params(axis='both', which='major', labelsize=14)
# Panel B: fungal toxins against the bacterial controls.
viz2('isoelectric_point',0.075,ft100,bc100,"fungal toxins","bacterial controls","#00A86B","#61BDD2",axs["B"])
axs['B'].set_ylabel(case,fontsize=16)
axs['B'].set_xlabel("isoelectric point",fontsize=16)
axs['B'].tick_params(axis='both', which='major', labelsize=14)
plt.rcParams['savefig.dpi'] = 300
# Write the panel letter into the upper-left corner of each panel, shifted 10 points right and
# 5 points down from the corner. The offset is the same for every panel, so it is built once.
trans = mtransforms.ScaledTranslation(10/72, -5/72, fig.dpi_scale_trans)
for label, ax in axs.items():
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize=18, verticalalignment='top', fontfamily='serif',
            bbox=dict(facecolor='0.7', edgecolor='none', pad=3.0))
plt.savefig(f"{cl}Figures/shared_pIs_1by2_fungi.png", bbox_inches="tight")




# Step 5.2: Keep the sequences, colors, labels and line styles in dictionaries that share the same keys.
seq_lists = {'ft100': ft100.seq, 'bt100': bt100.seq, 'at100': at100.seq}
colors = {'ft100': "#00A86B", 'bt100': "#2156B5", 'at100': "#B1041B"}
labels = {'ft100': 'Fungal toxins', 'bt100': 'Bacterial toxins', 'at100': 'Animal toxins'}
line_style = {'ft100': '-', 'bt100': '--', 'at100': '--'}

# Step 5.3: Attributes calculated on the protein sequences (pI or aromaticity), each paired with the
# bandwidth that controls the fineness of the resulting kde line.
attributes_bw = [('isoelectric_point', 0.05), ('isoelectric_point', 0.1), ('aromaticity', 0.1)]

# Step 5.4: Loop over attribute and bw_method simultaneously.
# The attribute values depend only on the sequences, not on the bandwidth, so they are computed once
# per attribute and reused for the KDEs, for the x range, and for further bandwidths.
_attribute_values = {}
for attribute, bw_method in attributes_bw:
    if attribute not in _attribute_values:
        _attribute_values[attribute] = {
            key: [getattr(ProteinAnalysis(seq), attribute)() for seq in sequences]
            for key, sequences in seq_lists.items()
        }
    values_by_set = _attribute_values[attribute]
    # One kde for each of the three data sets in seq_lists.
    kdes = {key: gaussian_kde(values, bw_method=bw_method) for key, values in values_by_set.items()}
    if attribute == 'aromaticity':
        # Aromaticity is always drawn on a fixed window.
        x_range = np.linspace(0, 0.3, 1000)
    else:
        # Span the pooled minimum to maximum of all data sets in 1000 steps.
        all_values = [value for values in values_by_set.values() for value in values]
        x_range = np.linspace(min(all_values), max(all_values), 1000)
    # Plot each data set with its own color, label and line style.
    fig, ax = plt.subplots()
    for key in seq_lists:
        plot_data(ax, kdes[key], x_range, labels[key], colors[key], line_style[key])
    # The name under which the plot is saved.
    filename = f"shared_{attribute}_{bw_method}.png"
    # The axis name recycles the attribute name but replaces the underscore with a space.
    axis_name = attribute.replace('_', ' ')
    # The plot is saved.
    save_plot(fig, filename, axis_name)


# Step 5.5: Amino acid usage as surprise logos in a figure with two stacked panels ("A" on top of "B").
# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8..." and newer releases no longer
# accept the old name, so the new name is used whenever this installation offers it.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
fig,  axs= plt.subplot_mosaic([["A"],[ "B"]],layout='constrained')
# Panel A: fungal toxins versus animal toxins.
logo_viz(df_ftat,"fungal_toxins","animal_toxins",axs['A'])
axs['A'].set_ylabel("Surprise",fontsize=16)
axs['A'].grid(True,zorder=0)
axs['A'].set_xticklabels([])
axs['A'].tick_params(axis='both', which='major', labelsize=14)
# Panel B: fungal toxins versus bacterial toxins.
logo_viz(df_ftbt,"fungal_toxins","bacterial_toxins",axs["B"])
axs['B'].grid(True,zorder=0)
axs['B'].set_xticklabels([])
axs['B'].set_xlabel(" Amino acids",fontsize=16)
axs['B'].tick_params(axis='both', which='major', labelsize=14)
axs['B'].set_ylabel("Surprise",fontsize=16)
fig.suptitle("amino acid usage", fontsize=20)
plt.rcParams['savefig.dpi'] = 300
# Write the panel letter into the upper-left corner of each panel, shifted 10 points right and
# 5 points down from the corner. The offset is the same for every panel, so it is built once.
trans = mtransforms.ScaledTranslation(10/72, -5/72, fig.dpi_scale_trans)
for label, ax in axs.items():
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize=18, verticalalignment='top', fontfamily='serif',
            bbox=dict(facecolor='0.7', edgecolor='none', pad=3.0))
plt.savefig(f"{cl}Figures/shared_logos_fungi.png", bbox_inches="tight")
