import pandas as pd
import argparse

def process_foldseek_results(input_path: str, output_path: str) -> None:
    """Keep the highest-``prob`` alignment for each query of a Foldseek TSV.

    The input is read as a tab-separated file *without* a header row (the
    first line is treated as data) and its columns are labelled with the
    twelve names listed in ``cols`` below.  The result, one row per distinct
    ``query`` value, is written to ``output_path`` as a tab-separated file
    with a header row and no index column.  Progress counts are printed to
    stdout.

    When several alignments of one query share the same ``prob``, the one
    appearing first in the input file is kept, so the outcome is
    deterministic.

    Args:
        input_path: Path of the Foldseek result TSV to read.
        output_path: Path the top-hit TSV is written to.
    """
    # Column labels applied to the file.
    # NOTE(review): these must match the --format-output used for the
    # Foldseek run that produced the file; confirm against the pipeline.
    cols = [
        "query", "target", "lddt", "lddtfull", "alntmscore", "qtmscore", "ttmscore",
        "alnlen", "qcov", "tcov", "prob", "evalue"
    ]

    # Load the TSV file
    df = pd.read_csv(input_path, sep="\t", names=cols)
    print(f" Total alignments loaded: {len(df)}")
    print(f" Unique queries in results: {df['query'].nunique()}")

    # Rank by probability, best first.  A stable sort is required here:
    # "prob" values tie frequently, and with the default quicksort the
    # relative order of tied rows (and therefore which row survives the
    # de-duplication below) is arbitrary.  With a stable sort, ties keep
    # their original file order.
    ranked = df.sort_values(by="prob", ascending=False, kind="mergesort")

    # The first row seen for each query is now its best hit.
    top_hits = ranked.drop_duplicates(subset="query", keep="first").reset_index(drop=True)

    print(f"Top hits retained: {len(top_hits)}")

    # Save result
    top_hits.to_csv(output_path, sep="\t", index=False)
    print(f" Top hits saved to: {output_path}")

if __name__ == "__main__":
    # Command-line entry point: both paths are mandatory options and are
    # handed straight to process_foldseek_results.
    cli = argparse.ArgumentParser(
        description="Extract top hits per query from Foldseek result file."
    )
    for flag, help_text in (
        ("--input", "Path to Foldseek .tsv result file"),
        ("--output", "Output path to save top hits TSV"),
    ):
        cli.add_argument(flag, required=True, help=help_text)

    options = cli.parse_args()
    process_foldseek_results(input_path=options.input, output_path=options.output)
