import os
from pathlib import Path
from shutil import copy2
from Bio import SeqIO

def load_fasta_ids(fasta_path):
    """
    Return the set of record IDs found in a FASTA file.

    Each ID is the ``record.id`` produced by Biopython's FASTA parser, i.e.
    the first word of each header line (after '>').
    """
    records = SeqIO.parse(fasta_path, "fasta")
    return {record.id for record in records}

def copy_matching_pdbs(pdb_folder, target_ids, output_folder, log_missing=True):
    """
    Copy PDB files from pdb_folder to output_folder if their basename
    (without .pdb) matches an ID in target_ids.

    Only regular files matching ``*.pdb`` directly inside pdb_folder are
    considered (the scan is not recursive).

    Args:
        pdb_folder: Directory to scan for ``*.pdb`` files.
        target_ids: Iterable of IDs to select; converted to a set, so any
            iterable (set, list, tuple, ...) is accepted.
        output_folder: Destination directory; created (with parents) if it
            does not exist.
        log_missing: If True, print how many target IDs had no matching
            PDB file in pdb_folder.

    Returns:
        set: The target IDs for which no ``<id>.pdb`` file was found in
        pdb_folder.
    """
    target_ids = set(target_ids)
    output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    copied_ids = set()
    for pdb_file in Path(pdb_folder).glob("*.pdb"):
        # A directory named "*.pdb" would make copy2 fail; skip non-files.
        if not pdb_file.is_file():
            continue
        pdb_id = pdb_file.stem  # filename without .pdb
        if pdb_id in target_ids:
            copy2(pdb_file, output_folder / pdb_file.name)
            copied_ids.add(pdb_id)

    # Derive "missing" from what was actually found in the input folder.
    # Inspecting output_folder instead would be fooled by files left over
    # from a previous run.
    missing = target_ids - copied_ids

    print(f"Copied {len(copied_ids)} matching PDB files to: {output_folder}")
    if log_missing:
        print(f"{len(missing)} PDBs missing from input folder")
    return missing

def main(pdb_dir, fasta_a, fasta_b, out_dir_a, out_dir_b):
    """
    Split the PDB files in pdb_dir into two output folders.

    IDs are read from fasta_a and fasta_b first; then PDBs matching the
    group A IDs are copied to out_dir_a, followed by those matching the
    group B IDs into out_dir_b.
    """
    print("Loading identifiers from FASTA files...")
    groups = [
        ("A", load_fasta_ids(fasta_a), out_dir_a),
        ("B", load_fasta_ids(fasta_b), out_dir_b),
    ]

    for label, group_ids, destination in groups:
        print(f"Copying group {label} PDBs...")
        copy_matching_pdbs(pdb_dir, group_ids, destination)

if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(
        description="Split a folder of PDB files based on FASTA groupings."
    )

    # Required inputs.
    cli.add_argument("--pdb_dir", required=True, help="Directory containing .pdb files")
    cli.add_argument("--fasta_a", required=True, help="FASTA file for group A")
    cli.add_argument("--fasta_b", required=True, help="FASTA file for group B")

    # Output directories (relative to the current working directory by default).
    cli.add_argument("--out_dir_a", default="groupA_pdbs", help="Output directory for group A PDBs")
    cli.add_argument("--out_dir_b", default="groupB_pdbs", help="Output directory for group B PDBs")

    # Each option's dest name matches a main() parameter one-for-one.
    main(**vars(cli.parse_args()))
