#!/usr/bin/env python3
# coding: utf-8


########################################################################################################################
#Author: Tanja Krüger
#Aim: This file compares the protein length distribution (and, in currently commented-out sections, the pI and the aromaticity) of two protein sets in fasta format
#Input: fasta file: animal toxins
#Input: fasta file: control toxins
#Input: fasta file: bacterial toxins (argument currently disabled)
#Input: fasta file: bacterial control (argument currently disabled)
#Output: a series of files that show the length, pI, instability and aromaticity distributions


########################################################################################################################
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import re, argparse, csv, collections,random
from datetime import datetime
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from Bio import SeqIO

# #################################################################################################
# Path prefix from the current working directory to the project root, chosen
# by where the code is run from. "" is meant for the default case: running via
# make from the project folder. "../../" is meant for running from Code/python.
cl=""
# NOTE(review): this override is currently active, so all paths resolve as if
# the script is run from Code/python. Comment it out when running via make from
# the project folder.
cl="../../"

########################################################################################################################
# Read the two input FASTA paths (positional arguments f1 and f2) from the
# command line. Positional arguments for the bacterial toxin / control files
# (f3, f4) are not enabled in this version.
parser = argparse.ArgumentParser(
    prog="data_analysis_7.py",
    description="visualizing protein length, PI and aromaticity between two protein sets in shared plots",
)
_POSITIONALS = (
    ("f1", "fasta file animal toxins"),
    ("f2", "fasta file animal controls"),
)
for _arg_name, _arg_help in _POSITIONALS:
    parser.add_argument(_arg_name, type=str, help=_arg_help)
args = parser.parse_args()


########################################################################################################################
# Extract the mmseqs2 reduction level ("SST<n>") from the input file names.
# If either name carries no such tag, fall back to "full unreduced" for both.
try:
    sst_level1 = re.search(r"SST(\d+)", args.f1).group(1)  # mmseqs2 reduction level
    sst_level2 = re.search(r"SST(\d+)", args.f2).group(1)  # mmseqs2 reduction level
except AttributeError:  # re.search returned None: no "SST<n>" tag in a file name
    sst_level1 = "full unreduced"
    sst_level2 = "full unreduced"
data_name1 = ""
# Derive a human-readable dataset label and its plot colour from a file path.
# The input file names follow no strict naming pattern, so the classification
# relies only on whether the substrings "toxin" and "animal" occur.
def name_extractor(set):
    """Return ``(data_name, col)`` for the input path ``set``.

    ``data_name`` is one of "animal toxins", "bacterial toxins",
    "animal control" or "bacterial control". ``col`` is the hex colour used
    for that dataset in the plots. Anything that does not contain "animal"
    is treated as bacterial; anything without "toxin" is treated as a control.
    """
    has_toxin = "toxin" in set
    has_animal = "animal" in set
    label_and_colour = {
        (True, True): ("animal toxins", "#B1041B"),
        (True, False): ("bacterial toxins", "#2156B5"),
        (False, True): ("animal control", "#EABA49"),
        (False, False): ("bacterial control", "#61BDD2"),
    }
    return label_and_colour[(has_toxin, has_animal)]

# ``length_comp`` is defined once, further down after ``smd_calc`` (which it
# uses). No stub is needed here.


# Open the general log file and the exploratory log for this reduction level.
# Both are opened in append mode, so every run adds to them.
out_file = open(f"{cl}Data/derived/log.log", "a")
explore_file = open(f"{cl}Exploratory/smd_{sst_level1}.log", "a")

# Get the date and time
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")

# Record this run in the general log file.
print(f"""########## \n
program {parser.prog} was executed at {dt_string} \n
argument passed: {args.f1} and {args.f2}\n
number of required arguments:2 """, file=out_file)

########################################################################################################################
# Parse both FASTA files into DataFrames indexed by record id, with the
# columns "info" (full description line) and "seq" (sequence as a string).
# Records sharing an id collapse to the last one read (they are dict keys).
# ``from_dict(orient="index", columns=...)`` also yields a well-formed empty
# frame for an empty file. Transposing an empty frame and assigning column
# names would raise a length-mismatch ValueError instead.
with open(args.f1) as handle:
    df_1 = pd.DataFrame.from_dict(
        {record.id: [record.description, str(record.seq)] for record in SeqIO.parse(handle, "fasta")},
        orient="index",
        columns=["info", "seq"],
    )
with open(args.f2) as handle:
    df_2 = pd.DataFrame.from_dict(
        {record.id: [record.description, str(record.seq)] for record in SeqIO.parse(handle, "fasta")},
        orient="index",
        columns=["info", "seq"],
    )


# Y-axis of the histograms: "density" plots a probability density
# (density=True in ax.hist); "frequency" plots raw counts (density=False).
case = "density"
if case == "density":
    density = True
elif case == "frequency":
    density = False
else:
    # Fail fast rather than continue with ``density`` undefined, which would
    # only surface later as a NameError at the first plot.
    raise ValueError(f"case supplied must either be density or frequency, got {case!r}")

# Standardized median difference between two samples.
def smd_calc(pop1,pop2):
    """Return the standardized median difference of ``pop1`` and ``pop2``.

    Computed as ``|median(pop2) - median(pop1)| / sqrt(var(pop1) + var(pop2))``,
    using NumPy's default population variance (ddof=0).

    NOTE(review): the denominator is ``sqrt(var1 + var2)``, not the pooled
    standard deviation ``sqrt((var1 + var2) / 2)`` used by Cohen's d. Confirm
    this is intended before comparing values with the literature.
    """
    centre1, centre2 = np.median(pop1), np.median(pop2)
    spread = np.sqrt(np.var(pop1) + np.var(pop2))
    return abs(centre2 - centre1) / spread

def _fasta_to_frame(path):
    """Parse the FASTA file at ``path`` into a DataFrame indexed by record id.

    Columns are ``info`` (full description line) and ``seq`` (sequence as a
    string). Records sharing an id collapse to the last one read. An empty
    file yields an empty frame with the same two columns.
    """
    with open(path) as handle:
        records = {record.id: [record.description, str(record.seq)]
                   for record in SeqIO.parse(handle, "fasta")}
    return pd.DataFrame.from_dict(records, orient="index", columns=["info", "seq"])


def length_comp(set1,set2):
    """Plot overlaid protein-length histograms for two FASTA files.

    Parameters
    ----------
    set1, set2 : str
        Paths to FASTA files. Their names also select the legend label and
        colour via ``name_extractor``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, which is also saved as a PNG under ``{cl}Figures/``.

    Raises
    ------
    ValueError
        If either file contains no sequences.

    Relies on the module-level settings ``density``/``case`` (y-axis type),
    ``cl`` (path prefix) and ``sst_level1`` (tag in the output file name).
    """
    data_name1, col1 = name_extractor(set1)
    data_name2, col2 = name_extractor(set2)
    df_1 = _fasta_to_frame(set1)
    df_2 = _fasta_to_frame(set2)

    lens_1 = [len(seq) for seq in df_1.seq]
    lens_2 = [len(seq) for seq in df_2.seq]
    if not lens_1 or not lens_2:
        raise ValueError(f"no sequences found in {set1!r} and/or {set2!r}")
    smd_len = smd_calc(lens_1, lens_2)

    # Shared bin edges so the two histograms are directly comparable.
    bin_size = 50
    all_lens = lens_1 + lens_2
    bins = np.arange(min(all_lens), max(all_lens) + bin_size, bin_size)

    # Matplotlib 3.6 renamed the "seaborn" style to "seaborn-v0_8", and 3.8
    # removed the old name. Prefer the new name when it is available.
    style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
    plt.style.use(style)

    fig, ax = plt.subplots()
    ax.hist(lens_1, bins=bins, label=data_name1, alpha=0.7, color=col1, density=density)
    ax.hist(lens_2, bins=bins, label=data_name2, alpha=0.7, color=col2, density=density)
    ax.set_xlabel(" length of protein", fontsize=15)
    ax.set_ylabel(case, fontsize=15)
    fig.suptitle("protein length", fontsize=20)
    ax.set_title(f"standardized median difference: {round(smd_len,2)}", fontsize=15)
    ax.legend(fontsize=15)
    plt.rcParams['savefig.dpi'] = 300
    fig.tight_layout()
    fig.savefig(f"{cl}Figures/SST{sst_level1}_length_distribution_{case}_{data_name1}_{data_name2}.png", bbox_inches="tight")
    return fig

# NOTE(review): in the code visible here nothing at module level calls
# ``length_comp``, so no figure exists yet and ``plt.show()`` displays nothing.
# Call e.g. ``length_comp(args.f1, args.f2)`` first if the length plot is wanted.
plt.show()

# #  Amino acid ratios
# colnames=['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q','R', 'S', 'T', 'V', 'W', 'Y']
# list_of_ratios1=[[sequence.count(i)/len(sequence) for i in colnames] for sequence in df_1.seq]
# list_of_ratios2=[[sequence.count(i)/len(sequence) for i in colnames] for sequence in df_2.seq]
# #  Transform the nested list into a dataframe.
# # attention: some sequences contain letters other than those in colnames (often X for an unknown amino acid); in
# # that case the ratios calculated over the 20 amino acids do not add up to 1.
# df_ratios_per_aa1=pd.DataFrame(list_of_ratios1, columns=colnames)
# df_ratios_per_aa2=pd.DataFrame(list_of_ratios2, columns=colnames)
# # Visualize
# bin_size = 0.01 # set bin size
# bins = np.arange(0,0.2 + bin_size, bin_size)
# fig, ax = plt.subplots(4,5,figsize=(17,17))
# for num,i in enumerate(colnames):
#     yi=num%5 # location vertically
#     xi=num//5 # location horizontally
#     ax[xi, yi].hist(df_ratios_per_aa1.loc[:,i],bins=bins,alpha=0.7,color=col1,density=density) # histogram with 40 bins
#     ax[xi,yi].hist(df_ratios_per_aa2.loc[:,i],bins=bins,alpha=0.7,color=col2,density=density)
#     ax[xi,yi].set_title(i,fontsize=15) # title
#     ax[xi,yi].set_ylabel(case,fontsize=15)
# plt.rcParams['savefig.dpi'] = 300
# fig.suptitle(f'relative frequency of each amino acid for SST{sst_level1} data\n y_axis different scales',fontsize=20) # title over all plots
# plt.savefig(f"{cl}Figures/SST{sst_level1}_amino_acid_dist_{case}_{data_name1}_{data_name2}.png",bbox_inches="tight")
#
#
#
# # Step 4: aromaticity .
# list_of_aro1=[ProteinAnalysis(sequence).aromaticity() for sequence in df_1.seq]
# list_of_aro2=[ProteinAnalysis(sequence).aromaticity() for sequence in  df_2.seq]
# smd_aro=smd_calc(list_of_aro1,list_of_aro2)
# bin_size = 0.01 # set bin size
# bins = np.arange(min(list_of_aro1 + list_of_aro2), max(list_of_aro1 +list_of_aro2) + bin_size, bin_size)
# fig, ax = plt.subplots()
# ax.hist(list_of_aro1,bins=bins,label=f"{data_name1}",alpha=0.7,color=col1,density=density)
# ax.hist(list_of_aro2,bins=bins,label=f"{data_name2}",alpha=0.7,color=col2,density=density)
# ax.set_xlabel(" aromaticity of protein",fontsize=15)
# ax.set_ylabel(case,fontsize=15)
# fig.suptitle(f"protein aromaticity ",fontsize=20)
# plt.title(f"standardized median difference: {round(smd_aro,2)}",fontsize=15)
# plt.legend(fontsize=15)
# plt.rcParams['savefig.dpi'] = 300
# plt.tight_layout()
# plt.savefig(f"{cl}Figures/SST{sst_level1}_aromaticity_dist_{case}_{data_name1}_{data_name2}.png",bbox_inches="tight")
#
# #Step 5: Isoelectric point PI
# list_of_pi1=[ProteinAnalysis(sequence).isoelectric_point() for sequence in df_1.seq]
# list_of_pi2=[ProteinAnalysis(sequence).isoelectric_point() for sequence in df_2.seq ]
# smd_pI=smd_calc(list_of_pi1,list_of_pi2)
# bin_size = 0.15 # set bin size
# bins = np.arange(min(list_of_pi1 + list_of_pi2), max(list_of_pi1 +list_of_pi2) + bin_size, bin_size)
# fig, ax = plt.subplots()
# ax.hist(list_of_pi1,bins=bins,label=f"{data_name1}",alpha=0.7,color=col1,density=density)
# ax.hist(list_of_pi2,bins=bins,label=f"{data_name2}",alpha=0.7,color=col2,density=density)
# ax.set_xlabel(" PI of protein")
# ax.set_ylabel(case)
# fig.suptitle(f"protein isoelectric point",fontsize=20)
# plt.title(f"standardized median difference: {round(smd_pI,2)}",fontsize=15)
# plt.legend(fontsize=15)
# plt.rcParams['savefig.dpi'] = 300
# plt.tight_layout()
# plt.savefig(f"{cl}Figures/SST{sst_level1}_pi_dist_{case}_{data_name1}_{data_name2}.png",bbox_inches="tight")



# #log the results
# print(f"#####################################",file=explore_file)
# print(f"data SST{sst_level1}")
# print(f"smd_pI: {smd_pI}",file=explore_file)
# print(f"smd_aro: {smd_aro}",file=explore_file)
# print(f"smd_len: {smd_len}",file=explore_file)

# Release both log handles, then drop every figure pyplot still holds.
for _log_handle in (out_file, explore_file):
    _log_handle.close()
plt.close("all")