import pandas as pd
import os
import math
from sklearn.metrics import matthews_corrcoef, accuracy_score, precision_score, recall_score
from random import seed, choice
from numpy import std, mean
from support_functions_splitting_predictor import bootstrap_metric

# === Load Foldseek results ===
# Presumably one row per query holding its top-1 Foldseek hit (per the file
# name); the columns used below are 'query' and 'target'.
df = pd.read_csv(
    "../../Data/derived/foldseek_baseline/foldseek_top1_per_query_foldseek_results_reasonable.tsv",
    sep="\t"
)

# === Paths to toxin and control folders ===
toxin_folder = "../../Data/raw/exotoxins/rank_1"
control_folder = "../../Data/raw/secreted/rank_1"

# === Get protein IDs including .pdb extension ===
# IDs keep the ".pdb" suffix because the membership tests below compare them
# directly against df's 'query'/'target' values.
toxin_ids = {f for f in os.listdir(toxin_folder) if f.endswith(".pdb")}
control_ids = {f for f in os.listdir(control_folder) if f.endswith(".pdb")}

# === Check for overlap ===
# A file present in both folders would have an ambiguous label and silently
# corrupt every metric below. Raise explicitly instead of using `assert`,
# which is removed when Python runs with -O.
overlap = toxin_ids & control_ids
if overlap:
    raise ValueError(
        "Some PDBs appear in both toxin and control folders! "
        f"({len(overlap)} shared, e.g. {sorted(overlap)[:5]})"
    )

# === Classify each pair ===
def _label_pair(query, target):
    """Return the confusion-matrix label for one (query, target) hit.

    The query's folder is the ground truth and the target's folder is the
    prediction. A pair whose query or target is in neither folder is labelled
    "missing".
    """
    q_toxin, q_control = query in toxin_ids, query in control_ids
    t_toxin, t_control = target in toxin_ids, target in control_ids

    if not (q_toxin or q_control) or not (t_toxin or t_control):
        return "missing"
    if q_toxin and t_toxin:
        return "TP"
    if q_control and t_control:
        return "TN"
    if q_control and t_toxin:
        return "FP"
    if q_toxin and t_control:
        return "FN"
    return "undefined"


classifications = [
    _label_pair(q, t) for q, t in zip(df["query"], df["target"])
]
missing = [
    (q, t)
    for q, t, label in zip(df["query"], df["target"], classifications)
    if label == "missing"
]
TP = classifications.count("TP")
TN = classifications.count("TN")
FP = classifications.count("FP")
FN = classifications.count("FN")

df["classification"] = classifications
df.to_csv("classified_hits.tsv", sep="\t", index=False)

# === Print confusion matrix ===
# `missing` holds (query, target) *pairs* in which at least one member is in
# neither folder, so its length counts hits, not proteins. Report both the
# pair count and the number of distinct unlabelled protein IDs.
unlabelled_proteins = {
    pid
    for pair in missing
    for pid in pair
    if pid not in toxin_ids and pid not in control_ids
}
print("Confusion matrix summary:")
print(f"True Positives (TP): {TP}")
print(f"True Negatives (TN): {TN}")
print(f"False Positives (FP): {FP}")
print(f"False Negatives (FN): {FN}")
print(f"Missing pairs        : {len(missing)}")
print(f"Missing proteins     : {len(unlabelled_proteins)}")

# === Safe MCC function ===
def safe_mcc(tp, tn, fp, fn):
    """Compute the Matthews correlation coefficient from confusion counts.

    Returns 0 when any marginal (tp+fp, tp+fn, tn+fp, tn+fn) is zero, which
    would otherwise make the denominator zero.
    """
    marginal_product = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    if marginal_product == 0:
        return 0
    return ((tp * tn) - (fp * fn)) / math.sqrt(marginal_product)

# Point-estimate MCC over all labelled pairs, from the counts above.
mcc = safe_mcc(tp=TP, tn=TN, fp=FP, fn=FN)
print("Matthews Correlation Coefficient (manual MCC): " + format(mcc, ".4f"))

# === Bootstrapped metric evaluation ===
# Number of bootstrap resamples. It is used both in the header and in every
# call, so the two cannot drift apart (the header previously said n=1000
# while the calls used 10000).
N_BOOT = 10000

# Build label dictionary: PDB file name -> 1 (toxin) / 0 (control).
label_dict = {f: 1 for f in toxin_ids}
label_dict.update({f: 0 for f in control_ids})

# Keep only hits where both proteins carry a label. The "missing" check is
# already implied by the two isin() filters and is kept as an explicit guard.
valid_rows = df[
    df["query"].isin(label_dict) &
    df["target"].isin(label_dict) &
    (df["classification"] != "missing")
]

# Ground truth is the query's label; the "prediction" is its hit's label.
y_true = valid_rows["query"].map(label_dict).values
y_pred = valid_rows["target"].map(label_dict).values

if len(y_true) == 0:
    raise ValueError("No labelled (query, target) pairs left to evaluate.")

print(f"\nBootstrapped metrics (n={N_BOOT}):")
for heading, metric in (
    ("MCC and SE:", matthews_corrcoef),
    ("Accuracy and SE:", accuracy_score),
    ("Precision and SE:", precision_score),
    ("Recall and SE:", recall_score),
):
    print(heading)
    print(bootstrap_metric(metric=metric, y_true=y_true, y_pred=y_pred, n_boot=N_BOOT))
