#!/usr/bin/env python3
# coding: utf-8


# ######################################################################################################################
# Author: Tanja Krüger Aim: This file visualizes the amino acid composition, the average length, the aromaticity and the
# logos in joined graphics - it's an adaptation of the already established 2x2 plots
# Input: four fasta files with the separate bacterial and animal toxins, and four matrices that
# hold the surprise of the amino acids to occur in one of two sets
# Output: a series of files that show the length distribution, the pI, instability, aromaticity distribution and logos
# in combined graphs


########################################################################################################################
# downloaded
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import re, argparse
from datetime import datetime
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from Bio import SeqIO
from scipy.stats import gaussian_kde
import csv
# #################################################################################################
# Option depending on where the user wants to run the code from; by default the code is run with make from the
# project folder.
cl=""
# If one wants to execute this file from the Code/python folder uncomment the next line.
#cl="../../"

########################################################################################################################
# Get the arguments from the command line.
# NOTE: argparse applies %-formatting to every help string, so a literal percent sign must be escaped as "%%".
parser = argparse.ArgumentParser(prog="data_analysis_6.py",
                                 description="updated shared visualizing protein length, PI and aromaticity")
# The four fasta files
parser.add_argument("at100",
                    type=str,
                    help="fasta file animal toxins 100%% reduced")
parser.add_argument("ac100",
                    type=str,
                    help="fasta file animal control 100%% reduced")
parser.add_argument("bt100",
                    type=str,
                    help="fasta file bacterial toxins 100%% reduced")
parser.add_argument("bc100",
                    type=str,
                    help="fasta file bacterial control 100%% reduced")

# The four surprise matrices
parser.add_argument("matrix_atac",
                    type=str,
                    help="surprise matrix animal toxins v animal control")
parser.add_argument("matrix_atbt",
                    type=str,
                    help="surprise matrix animal toxins v bacterial toxins")
parser.add_argument("matrix_acbc",
                    type=str,
                    help="surprise matrix animal control v bacterial control")
parser.add_argument("matrix_btbc",
                    type=str,
                    help="surprise matrix bacterial toxins v bacterail control")

args = parser.parse_args()


########################################################################################################################
# Step 1: Check whether all four fasta file names carry an mmseqs2 redundancy-reduction tag ("SST<level>") and,
# if they do, whether they all share the same level.
# Raw string: "\d" inside a normal string literal is an invalid escape sequence.
sst_pattern = re.compile(r"SST(\d+)")
try:
    sst_level1 = sst_pattern.search(args.at100).group(1)  # mmseqs2 reduction level
    sst_level2 = sst_pattern.search(args.ac100).group(1)
    sst_level3 = sst_pattern.search(args.bt100).group(1)
    sst_level4 = sst_pattern.search(args.bc100).group(1)
except AttributeError:
    # search() returned None for at least one file name, i.e. that file name carries no SST tag.
    sst_level1 = "full unreduced"
else:
    if not (sst_level1 == sst_level2 == sst_level3 == sst_level4):
        # Keep the previous fallback value, but report the mismatch instead of hiding it.
        print("all sets are redundancy reduce but they do not share the same level of reduction; "
              "falling back to 'full unreduced'")
        sst_level1 = "full unreduced"

# Step 2: Log
# Step 2.1: Open the general logfile and the exploratory logfile, both in append mode.
# Both handles are kept open as module-level names; they are not closed in this section.
out_file = open(f"{cl}Data/derived/log.log", "a")
explore_file=open(f"{cl}Exploratory/smd_{sst_level1}.log","a")

# Step 2.2: Get the date and time
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")

# Step 2.3: Write the run header to the general logfile.
# The "executed at" line is written once, and the argument count matches the eight positionals of the parser.
print(f"""########## \n
program {parser.prog} was executed at {dt_string} \n
argments passed: the 100 reduced animal toxins {args.at100} \n
                the 100 reduced bacterial toxins {args.bt100} \n
                the 100 reduced animal controls {args.ac100}\n
                the 100 reudced bacterail controls {args.bc100}\n
                the four surprise matrices {args.matrix_atac}, {args.matrix_atbt},{args.matrix_acbc},and {args.matrix_btbc},\n
number of required arguments:8""",file=out_file)

########################################################################################################################
# Step 3: Open the data
# Step 3.1 / 3.2: Read each fasta file into a dataframe with one row per record id and the columns
# "info" (record description) and "seq" (sequence as a plain string).
def _fasta_to_frame(path):
    """Return the records of the fasta file at ``path`` as a dataframe indexed by record id."""
    with open(path) as fasta_handle:
        records = {}
        for entry in SeqIO.parse(fasta_handle, "fasta"):
            records[entry.id] = [entry.description, str(entry.seq)]
    frame = pd.DataFrame(records).T
    frame.columns = ["info", "seq"]
    return frame


at100 = _fasta_to_frame(args.at100)
ac100 = _fasta_to_frame(args.ac100)
bt100 = _fasta_to_frame(args.bt100)
bc100 = _fasta_to_frame(args.bc100)

# Step 3.3: Define the amino acids for the labels in the plot (one-letter codes in alphabetical order).
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")

# Step 3.4: Open the matrices with the logo information
df_atac, df_atbt, df_acbc, df_btbc = (
    pd.read_csv(matrix_path)
    for matrix_path in (args.matrix_atac, args.matrix_atbt, args.matrix_acbc, args.matrix_btbc)
)

# One (length, origin) row per sequence for each of the four data sets; the origin labels contain a line break.
length_columns = ["length", "origin"]
lens_at, lens_ac, lens_bt, lens_bc = (
    pd.DataFrame([(len(sequence), origin) for sequence in frame.seq], columns=length_columns)
    for frame, origin in (
        (at100, "animal\ntoxins"),
        (ac100, "animal\ncontrols"),
        (bt100, "bacterial\ntoxins"),
        (bc100, "bacterial\ncontrols"),
    )
)
# Stack the four tables (controls before toxins within each kingdom) under a fresh 0..n-1 index.
df_all_lens = pd.concat([lens_ac, lens_at, lens_bc, lens_bt], ignore_index=True)


# Open a new csv file in write mode
def writer(df):
    """Write the rows of ``df`` to ``Data/derived/length_output.csv`` (below the ``cl`` prefix).

    The file is overwritten on every call. Neither a header line nor the index is written.

    Parameters:
        df (DataFrame): the table to write; each row becomes one CSV record, in row order.
    """
    # newline="" is what the csv module asks for: the csv writer emits its own line terminator and the
    # file object must not translate it a second time.
    with open(f"{cl}Data/derived/length_output.csv", "w", newline="") as f:
        # Create a csv writer object (named so that it does not shadow this function)
        csv_writer = csv.writer(f)
        # Walk the rows themselves rather than feeding index labels into the positional .iloc accessor,
        # which is only correct for a default 0..n-1 index.
        for _, row in df.iterrows():
            csv_writer.writerow(row)

# Write the combined length table to disk.
writer(df_all_lens)

# Step 4: Set the type of y-axis in the plots: "density" selects probability density (density=True),
# "frequency" selects raw counts (density=False).
case= "density"
density_by_case = {"density": True, "frequency": False}
if case in density_by_case:
    density = density_by_case[case]
else:
    # Any other value only triggers this message; `density` is left unset.
    print("case supplied must either be density or frequency")




########################################################################################################################
# Step 5: Functions
def viz2(attribute,bw_method,df_1,df_2,data_name1,data_name2,col1,col2,ax):
    ''' Draw the kernel density estimates (KDE) of one sequence attribute for two data sets on the same axes.

        attribute: name of the ProteinAnalysis method that is evaluated per sequence, e.g. isoelectric_point
                   or aromaticity
        bw_method: passed on to gaussian_kde as bw_method; larger numbers result in smoother curves.
        df_1 (DataFrame): The first dataframe containing sequences in the 'seq' column.
        df_2 (DataFrame): The second dataframe containing sequences in the 'seq' column.
        data_name1 (str): The name of the first dataset to be used as a label in the plot.
        data_name2 (str): The name of the second dataset to be used as a label in the plot.
        col1 (str): The color of the first curve (drawn as a solid line).
        col2 (str): The color of the second curve (drawn as a dashed line).
        ax (Axes): The axes object to plot on.
    Returns:
        The figure that ``ax`` belongs to.'''
    values1 = [getattr(ProteinAnalysis(seq), attribute)() for seq in df_1.seq]
    values2 = [getattr(ProteinAnalysis(seq), attribute)() for seq in df_2.seq]
    if attribute == 'aromaticity':
        # Aromaticity is always drawn on the fixed interval [0, 0.3].
        x_range = np.linspace(0, 0.3, 1000)
    else:
        # Every other attribute spans the observed values of both sets and gets a fixed y-axis.
        combined = values1 + values2
        x_range = np.linspace(min(combined), max(combined), 1000)
        ax.set_ylim(0, 0.5)
    kde1 = gaussian_kde(values1, bw_method=bw_method)
    kde2 = gaussian_kde(values2, bw_method=bw_method)
    ax.plot(x_range, kde1(x_range), label=data_name1, color=col1, linewidth=2, linestyle="-")
    ax.plot(x_range, kde2(x_range), label=data_name2, color=col2, linewidth=2, linestyle="--")
    ax.legend(loc='upper right')
    # Return the figure owning `ax` instead of relying on a module-level variable called `fig`.
    return ax.figure

def calculate_kde(sequences, bw_method, attribute):
    """ Generalized KDE function.
    Evaluates one ProteinAnalysis attribute for every sequence and fits a Gaussian kernel density
    estimation (KDE) to the resulting values.
    arguments:
    sequences: iterable of protein sequences
    bw_method: handed to gaussian_kde as bw_method; controls the level of detail of the KDE line
    attribute: name of the ProteinAnalysis method to call, e.g. isoelectric_point or aromaticity
    returns:
    the gaussian_kde object fitted to the attribute values """
    attribute_values = []
    for sequence in sequences:
        compute = getattr(ProteinAnalysis(sequence), attribute)
        attribute_values.append(compute())
    density_estimate = gaussian_kde(attribute_values, bw_method=bw_method)
    return density_estimate

def plot_data(ax, kde, x_range, label, color, linestyle):
    """ Draw one kernel density estimation as a line of width 2.
    Parameters:
    - ax: The axes to draw the line on.
    - kde: Callable that maps x_range to the density values to plot.
    - x_range: The x values at which the density is evaluated and drawn.
    - label: The legend entry for this group of data (toxins, controls..)
    - color: The color of the line.
    - linestyle: The style of the line (solid, dashed, etc.)."""
    density_values = kde(x_range)
    line_options = {"label": label, "color": color, "linewidth": 2, "linestyle": linestyle}
    ax.plot(x_range, density_values, **line_options)

def save_plot(fig, filename, axis_name):
    """ Add legend and axis labels to a figure and save it as Figures/<filename> (below the ``cl`` prefix).
    Parameters:
    - fig: The figure to decorate and save; its current axes receive the legend and the labels.
    - filename: File name of the image inside the Figures folder.
    - axis_name: The label of the x-axis; the y-axis is always labelled "density"."""
    # Work on the figure that was passed in rather than on pyplot's implicit "current figure".
    ax = fig.gca()
    ax.legend(loc='upper right')
    ax.set_xlabel(axis_name)
    ax.set_ylabel("density")
    plt.rcParams['savefig.dpi'] = 300
    fig.savefig(f"{cl}Figures/{filename}", bbox_inches="tight")



########################################################################################################################
#Step 6: ACTUALLY CREATING THE SHARED PLOTS
# Step 6.1: isoelectric point visualization in 1by2 plot
# Step 6.2: aromaticity visualization in 1by2 plot
# Both figures share one layout: panel "A" compares the bacterial sets, panel "B" the animal sets.

# The "seaborn" style sheet was renamed to "seaborn-v0_8" in matplotlib 3.6 and the old name was removed
# afterwards, so pick whichever of the two names this matplotlib installation offers.
seaborn_style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"

# panel -> (toxins, controls, toxin label, control label, toxin color, control color)
panel_specs = {
    "A": (bt100, bc100, "bacterial_toxins", "bacterial_controls", "#2156B5", "#61BDD2"),
    "B": (at100, ac100, "animal_toxins", "animal_controls", "#B1041B", "#EABA49"),
}
# (attribute, kde bandwidth, x-axis label, output file name)
figure_specs = [
    ("isoelectric_point", 0.075, "isoelectric point", "shared_pIs_1by2.png"),
    ("aromaticity", 0.1, "aromaticity", "shared_aro_1by2.png"),
]

for attribute, bw_method, x_label, filename in figure_specs:
    plt.style.use(seaborn_style)
    # `fig` and `axs` are deliberately kept as module-level names.
    fig,  axs= plt.subplot_mosaic([["A"],[ "B"]],layout='constrained')
    for panel, (toxins, controls, toxin_name, control_name, toxin_col, control_col) in panel_specs.items():
        viz2(attribute, bw_method, toxins, controls, toxin_name, control_name, toxin_col, control_col, axs[panel])
        axs[panel].set_ylabel(case)
    # Only the lower panel carries the x-axis label.
    axs['B'].set_xlabel(x_label)
    plt.rcParams['savefig.dpi'] = 300
    # Put the panel letter into the upper left corner of each panel.
    for label, ax in axs.items():
        trans = mtransforms.ScaledTranslation(10/72, -5/72, fig.dpi_scale_trans)
        ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
                fontsize='medium', verticalalignment='top', fontfamily='serif',
                bbox=dict(facecolor='0.7', edgecolor='none', pad=3.0))
    plt.savefig(f"{cl}Figures/{filename}", bbox_inches="tight")


# Save the sequences, the colors, the labels and the line styles in dictionaries keyed by data set.
seq_lists = {'bt100': bt100.seq,'bc100': bc100.seq,'at100': at100.seq,'ac100': ac100.seq}
colors = {'bt100': "#2156B5",'bc100': "#61BDD2",'at100': "#B1041B",'ac100': "#EABA49"}
labels = {'bt100': 'Bacterial toxins','bc100': 'Bacterial control','at100': 'Animal toxins','ac100': 'Animal control'}
line_style={'bt100': '-','bc100': '-','at100': '--','ac100': '--'}

# Define attributes that are calculated on the protein sequences (pI or aromaticity) and the fineness of the
# resulting kde line
attributes_bw = [('isoelectric_point', 0.05), ('isoelectric_point', 0.1), ('aromaticity', 0.1)]

# attribute -> {data set key -> list of attribute values}. Each attribute is evaluated only once per sequence
# and reused for every bandwidth, for the KDEs as well as for the x-range.
attribute_values = {}

# Loop over attribute and bw_method simultaneously
for attribute, bw_method in attributes_bw:
    if attribute not in attribute_values:
        attribute_values[attribute] = {
            key: [getattr(ProteinAnalysis(seq), attribute)() for seq in sequences]
            for key, sequences in seq_lists.items()
        }
    values_by_set = attribute_values[attribute]
    # Get the kde for each of the four data sets in seq_lists
    kdes = {key: gaussian_kde(values, bw_method=bw_method) for key, values in values_by_set.items()}
    # Make a list of all observed values of the attribute
    all_values = [value for values in values_by_set.values() for value in values]
    # Calculate the range and then divide it in 1000 steps
    x_range = np.linspace(min(all_values), max(all_values), 1000)
    if attribute == 'aromaticity':
        # Aromaticity is drawn on the fixed interval [0, 0.3]
        x_range = np.linspace(0, 0.3, 1000)
    # Plot each data set with its color, label and line style
    fig, ax = plt.subplots()
    for key in seq_lists:
        plot_data(ax, kdes[key], x_range, labels[key], colors[key], line_style[key])
    # The name under which the plot is saved
    filename = f"shared_{attribute}_{bw_method}.png"
    # The axis name recycles the attribute name but replaces the underscore with a space
    axis_name = attribute.replace('_', ' ')
    # The plot is saved
    save_plot(fig, filename, axis_name)




# Log the results: a separator line goes to the exploratory logfile.
print("#####################################", file=explore_file)
# NOTE(review): this line is printed to stdout, not to explore_file - confirm that this is intended.
print(f"data SST{sst_level1}")
# Close both logfiles and all figures
for log_handle in (out_file, explore_file):
    log_handle.close()
plt.close("all")



#