#!/usr/bin/env python3
# coding: utf-8


######################################################################################################################
#Aim: this file visualizes the results of a fold cluster analysis
#Input: the file of the original IDs that belong to protein group1
#Input: the file of the original IDs that belong to protein group2
#Input: the file of the clustered proteins as tsv at 100% FST (fold similarity threshold)
#Input: the file of the clustered proteins as tsv at 50% FST
#Input: the file of the clustered proteins as tsv at 25% FST
#Output: the number of protein clusters at the different FSTs
#Output: the number of proteins per cluster
#Output: the purity per cluster
########################################################################################################################
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import re, argparse, csv, collections,random
from datetime import datetime
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from Bio import SeqIO
import operator

########################################################################################################################
parser = argparse.ArgumentParser(prog="fold_cluster_analysis",
                                 description="visualized the fold clustering results")
parser.add_argument("IDs1",
                    type=str,
                    help=" file containing the IDs of group1 ")
parser.add_argument("IDs2",
                    type=str,
                    help=" file containing the IDs of group2 ")
# argparse applies %-formatting to help strings, so a literal percent sign has
# to be written as "%%". A bare "100% fold" is parsed as the "% f" conversion
# and makes `--help` fail with a TypeError.
parser.add_argument("r100",
                    type=str,
                    help="  100%% fold similarity  ")
parser.add_argument("r50",
                    type=str,
                    help="  50%% fold similarity  ")
parser.add_argument("r25",
                    type=str,
                    help="  25%% fold similarity")

args = parser.parse_args()
# Relative prefix prepended to paths under Data/ (e.g. the log file).
cl = "../../"

########################################################################################################################
# Append a run header to the shared log file. The handle stays open for the
# rest of the script and is closed at the very end.
log_path = f"{cl}Data/derived/log.log"
out_file = open(log_path, "a")
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
run_message = f"program {parser.prog} was executed at {dt_string}"
print("##########", file=out_file)
print(run_message, file=out_file)
# The execution notice is echoed to the console as well.
print(run_message)
print("number of required arguments: 5", file=out_file)
########################################################################################################################
# Check which datasets are passed: a unique colour is chosen per dataset.
def _dataset_name_and_colour(path):
    """Return ``(data_name, colour)`` for an ID file path.

    The dataset is recognised purely from substrings of *path*: "toxin"
    selects a toxin dataset (otherwise a control dataset) and "animal"
    selects an animal dataset (otherwise a bacterial one).
    """
    if "toxin" in path:
        if "animal" in path:
            return "animal_toxins", "#B1041B"
        return "bacterial_toxins", "#2156B5"
    if "animal" in path:
        return "animal_control", "#EABA49"
    return "bacterial_control", "#61BDD2"


# NOTE(review): col1/col2 are not used by any of the plots in this script.
data_name1, col1 = _dataset_name_and_colour(args.IDs1)
data_name2, col2 = _dataset_name_and_colour(args.IDs2)
########################################################################################################################

# Open the files of the different fold similarity thresholds 100, 50 and 25.
# They are headerless TSVs; below, column 0 is treated as the cluster
# representative and column 1 as the clustered member.
r100 = pd.read_csv(args.r100, sep="\t", header=None)
r50 = pd.read_csv(args.r50, sep="\t", header=None)
r25 = pd.read_csv(args.r25, sep="\t", header=None)
# IDs of protein group1; used for the per-cluster purity further down.
IDs = pd.read_csv(args.IDs1, header=None)

# Distinct members (the unreduced set) and distinct representatives per threshold.
count_unreduced = len(set(r100.iloc[:, 1]))
count100 = len(set(r100.iloc[:, 0]))
count50 = len(set(r50.iloc[:, 0]))
count25 = len(set(r25.iloc[:, 0]))

# Check that the provided files have the fold reduction level in the right order.
# Assumes a lower similarity threshold never produces more clusters
# (NOTE(review): confirm for the clustering tool used). A violation most likely
# means the r100/r50/r25 files were swapped, so warn instead of aborting.
if not count100 >= count50 >= count25:
    order_warning = (f"WARNING: cluster counts do not decrease with the threshold "
                     f"(100%: {count100}, 50%: {count50}, 25%: {count25}); "
                     f"check the order of the r100/r50/r25 arguments")
    print(order_warning)
    print(order_warning, file=out_file)

# Plot the different numbers of clusters that the different FSTs result in.
# The plain "seaborn" style was renamed "seaborn-v0_8" in matplotlib 3.6 and
# removed in 3.8; use the new name whenever it is available.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
fig, ax = plt.subplots()
plt.bar(["unreduced","100%\nfoldsimilarity reduced","50%\nfoldsimilarity reduced","25%\nfoldsimilarity reduced"],[count_unreduced,count100,count50,count25])
plt.title(f"{data_name1} and {data_name2}  fold clustering")
plt.ylabel("number of clusters")
plt.xlabel("fold similarity reduction ")
plt.show()

# Cluster sizes: number of rows (members) per distinct representative in column 0.
# value_counts does this in one pass instead of one countOf scan per cluster.
cluster_sizes_100 = r100.iloc[:, 0].value_counts().tolist()
cluster_sizes_50 = r50.iloc[:, 0].value_counts().tolist()
cluster_sizes_25 = r25.iloc[:, 0].value_counts().tolist()
# [cluster size, number of clusters with that size] pairs.
# NOTE(review): despite the name these are lists, and nothing in this script reads them.
dict_100 = [[size, n] for size, n in collections.Counter(cluster_sizes_100).items()]
dict_50 = [[size, n] for size, n in collections.Counter(cluster_sizes_50).items()]
dict_25 = [[size, n] for size, n in collections.Counter(cluster_sizes_25).items()]


n_equal_bins = 100
# Span the bins up to the largest cluster at ANY threshold so no cluster falls
# outside the histogram. Keep stop > start so that an all-singleton input does
# not collapse every bin edge onto 1 (zero-width bins).
largest_cluster = max(cluster_sizes_25 + cluster_sizes_50 + cluster_sizes_100)
bin_edges = np.linspace(start=1, stop=max(largest_cluster, 2), num=n_equal_bins + 1, endpoint=True)
# "seaborn" was renamed "seaborn-v0_8" in matplotlib 3.6 and removed in 3.8.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
fig, ax = plt.subplots()
plt.hist(cluster_sizes_25, bin_edges, alpha=0.7, label="fold similarity cluster 25%")
plt.hist(cluster_sizes_50, bin_edges, alpha=0.7, label="fold similarity cluster 50%")
plt.hist(cluster_sizes_100, bin_edges, alpha=0.7, label="fold similarity cluster 100%")
plt.title(f"{data_name1} and {data_name2}  fold clustering")
plt.ylabel("log(number)")
plt.yscale('log')
plt.legend()
plt.xlabel("number of proteins per clusters ")
plt.show()

def _cluster_purities(table, group_ids):
    """Return the purity of every cluster in a representative/member table.

    The purity of a cluster is the fraction of its rows whose member ID
    (column 1 with the ".pdb" suffix stripped) occurs in *group_ids*.
    Clusters are keyed by column 0 and returned in order of first appearance.
    Members that do not match ``<id>.pdb`` count as outside the group.
    """
    # Escaped dot: only a literal ".pdb" is stripped from the member name.
    member_ids = table[1].str.extract(r"(.*)\.pdb")[0]
    in_group = member_ids.isin(group_ids)
    # The mean of a boolean series is (number of True) / (number of rows).
    return in_group.groupby(table[0], sort=False).mean().tolist()


# Calculate the purity of each cluster (share of group1 members) per threshold.
purity_100 = _cluster_purities(r100, IDs[0])
purity_50 = _cluster_purities(r50, IDs[0])
purity_25 = _cluster_purities(r25, IDs[0])


n_equal_bins = 25
bin_edges = np.linspace(start=0, stop=1, num=n_equal_bins + 1, endpoint=True)
# "seaborn" was renamed "seaborn-v0_8" in matplotlib 3.6 and removed in 3.8.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
# Previously `plt.subplots(fig_size=())`: "fig_size" is not a figure keyword
# (matplotlib's is "figsize"), so the call raised a TypeError.
fig, ax = plt.subplots()
plt.hist(purity_100, bin_edges, alpha=0.7, label="fold similarity cluster 100%")
plt.hist(purity_50, bin_edges, alpha=0.7, label="fold similarity cluster 50%")
plt.hist(purity_25, bin_edges, alpha=0.7, label="fold similarity cluster 25%")
plt.title(f"{data_name1} and {data_name2}  fold clustering")
plt.ylabel("log(number)")
# The y label promises a log scale, matching the cluster-size histogram.
plt.yscale('log')
plt.xlabel("purity per clusters ")
plt.legend()
plt.show()

# Repeat the visualization with the singles
# NOTE(review): not implemented yet; nothing in this script does this.

# Release the log file opened at the start of the run and every open figure.
out_file.close()
plt.close("all")

#naming extraction what was provided

