import re
import numpy as np
import pandas as pd
from pathlib import Path
from netmhciipan_4_3_executable import single_prediction

def presumed_tool_label(allele):
    """Derive the presumed tool label from an IEDB-style allele label.

    Every '/' is turned into '-'. If the resulting text starts with 'HLA-',
    every occurrence of 'HLA-' (not only the leading one) is removed.

    Example: 'HLA-DQA1*01:01/DQB1*02:01' -> 'DQA1*01:01-DQB1*02:01'
    """
    hyphenated = allele.replace("/", "-")
    if not hyphenated.startswith('HLA-'):
        return hyphenated
    return hyphenated.replace('HLA-', '')

def clean_mhcii_alleles(allele):
    """Reduce an MHC-II allele name to a normalized key for label matching.

    Steps:
      1. every '/' becomes '-';
      2. if the text then starts with 'HLA-', every 'HLA-' is removed;
      3. every character that is not an ASCII letter, digit or whitespace
         is dropped.

    Example: 'HLA-DQA1*01:01/DQB1*02:01' -> 'DQA10101DQB10201'
    """
    label = allele.replace("/", "-")
    without_prefix = label.replace('HLA-', '') if label.startswith('HLA-') else label
    return re.sub(r'[^a-zA-Z0-9\s]', '', without_prefix)


def add_allelelist():
    """Append NetMHCIIpan 4.3 rows to tools-mapping for the alleles in ``allelelist.txt``.

    Inputs are read from the directory two levels above this file:
    ``tools-mapping.tsv``, ``mro_molecules.tsv`` and
    ``netmhciipan-4.3/allelelist.txt``. The allele name is the first
    whitespace-separated token of each non-blank allelelist line.

    Every allele is normalized with ``clean_mhcii_alleles`` and resolved as:

    * a ``netmhciipan`` row in tools-mapping (any ``Tool Version``) whose
      normalized ``Tool Label`` matches -> reuse the first such row's
      ``Tool Label`` and ``MRO ID``;
    * otherwise an ``MHC class II protein complex`` entry of
      ``mro_molecules.tsv`` whose normalized ``IEDB Label`` matches -> use the
      first such entry's ``MRO ID`` and ``presumed_tool_label(IEDB Label)``;
    * otherwise the allele (original allelelist spelling) is written to
      ``netmhciipan-4.3/allelelist.unknown.txt``.

    Resolved rows are appended with ``Tool Version`` ``'4.3'`` and
    ``tools-mapping.tsv`` is overwritten in place.

    NOTE: not idempotent -- every run appends the rows again, and 4.3 rows
    written by an earlier run are themselves treated as existing netmhciipan rows.
    """
    data_dir = Path(__file__).parent.parent
    tm_file = data_dir / 'tools-mapping.tsv'
    tm_df = pd.read_csv(tm_file, sep='\t', index_col=False)
    mol_df = pd.read_csv(data_dir / 'mro_molecules.tsv', sep='\t', index_col=False)

    def first_match(keys, *value_columns):
        """Map each key to its values at the key's FIRST occurrence (later repeats ignored)."""
        lookup = {}
        for key, *values in zip(keys, *value_columns):
            lookup.setdefault(key, tuple(values))
        return lookup

    mhcii_mol_df = mol_df[mol_df['Parent'] == 'MHC class II protein complex']
    mol_lookup = first_match(
        [clean_mhcii_alleles(label) for label in mhcii_mol_df['IEDB Label']],
        mhcii_mol_df['IEDB Label'],
        mhcii_mol_df['MRO ID'],
    )

    # Existing netmhciipan rows, whatever their Tool Version. .copy() so the
    # column added below goes to an independent frame, not a view of tm_df
    # (avoids pandas' SettingWithCopyWarning).
    netmhciipan_tm_df = tm_df[tm_df['Tool'] == 'netmhciipan'].copy()
    print(f'Number of NetMHCIIpan alleles \nfrom Tools-Mapping file: {len(netmhciipan_tm_df)}')

    with open(data_dir / 'netmhciipan-4.3' / 'allelelist.txt', 'r') as f:
        # Skip blank lines (they would otherwise become '' alleles); split()
        # also copes with tab- or multi-space-separated columns.
        alleles = [line.split()[0] for line in f if line.strip()]
    print(f'Number of alleles in allelelist: {len(alleles)}')
    print(f'The difference is {len(netmhciipan_tm_df) - len(alleles)} alleles.')

    netmhciipan_tm_df['cleaned_labels'] = [
        clean_mhcii_alleles(label) for label in netmhciipan_tm_df['Tool Label']
    ]
    netmhciipan_tm_df = netmhciipan_tm_df.reset_index(drop=True)
    print(netmhciipan_tm_df.head())
    tm_lookup = first_match(
        netmhciipan_tm_df['cleaned_labels'],
        netmhciipan_tm_df['Tool Label'],
        netmhciipan_tm_df['MRO ID'],
    )

    cleaned_alleles = [clean_mhcii_alleles(allele) for allele in alleles]
    # Normalized key -> original allelelist spelling (a later duplicate wins).
    original_spelling = dict(zip(cleaned_alleles, alleles))
    clean_alleles_found = [c for c in cleaned_alleles if c in tm_lookup]
    clean_alleles_not_found = [c for c in cleaned_alleles if c not in tm_lookup]
    print(f'Number of alleles from allelelist file that is not found in tools-mapping: {len(clean_alleles_not_found)}')
    print(f'Number of alleles from allelelist file that is found in tools-mapping: {len(clean_alleles_found)}')
    print(f'Total alleles: {len(clean_alleles_found) + len(clean_alleles_not_found)}')

    row_template = {
        'Tool Group': 'mhcii',
        'Tool': 'netmhciipan',
        'Tool Version': '4.3',
        # '11, 12, ..., 30'
        'Lengths': ', '.join(str(length) for length in range(11, 31)),
    }

    # Alleles already known to tools-mapping: reuse the existing label and MRO ID.
    new_rows = []
    for cleaned in clean_alleles_found:
        tlabel, mroid = tm_lookup[cleaned]
        new_rows.append({**row_template, 'Tool Label': tlabel, 'MRO ID': mroid})
    known_count = len(new_rows)

    # The rest are, by construction, absent from tools-mapping; try the MRO molecules.
    completely_unknown_alleles = []
    for cleaned in clean_alleles_not_found:
        if cleaned in mol_lookup:
            iedb_label, mroid = mol_lookup[cleaned]
            new_rows.append({
                **row_template,
                'Tool Label': presumed_tool_label(iedb_label),
                'MRO ID': mroid,
            })
        else:
            completely_unknown_alleles.append(original_spelling[cleaned])
    print(f'Rows to add: {known_count} from tools-mapping, '
          f'{len(new_rows) - known_count} from mro_molecules; '
          f'{len(completely_unknown_alleles)} alleles unresolved.')

    with open(data_dir / 'netmhciipan-4.3' / 'allelelist.unknown.txt', 'w') as f:
        f.writelines(f'{allele}\n' for allele in completely_unknown_alleles)

    # A single concat instead of one per row (which is quadratic in the row count).
    # columns=tm_df.columns keeps the file's column order; absent fields become NaN.
    if new_rows:
        tm_df = pd.concat(
            [tm_df, pd.DataFrame(new_rows, columns=tm_df.columns)],
            ignore_index=True,
        )
    tm_df.to_csv(tm_file, sep='\t', index=False)

def check_tm(data_dir=None, expected=7699):
    """Check how many NetMHCIIpan 4.3 rows tools-mapping contains.

    Args:
        data_dir: Directory holding ``tools-mapping.tsv``. Defaults to the
            directory two levels above this file.
        expected: Row count to compare against. After ``add_allelelist`` a
            total of 7699 allele rows should have been added.

    Returns:
        True if the number of rows with ``Tool == 'netmhciipan'`` and
        ``Tool Version`` 4.3 equals *expected*. The same boolean is printed.
    """
    if data_dir is None:
        data_dir = Path(__file__).parent.parent
    tm_df = pd.read_csv(Path(data_dir) / 'tools-mapping.tsv', sep='\t', index_col=False)
    # Coerce to numbers so 4.3 matches whether pandas parsed the column as
    # floats or (because some other version is non-numeric) as strings; a
    # plain `== 4.3` on a string column silently matches nothing.
    versions = pd.to_numeric(tm_df['Tool Version'], errors='coerce')
    netmhciipan_tm_df = tm_df[(tm_df['Tool'] == 'netmhciipan') & (versions == 4.3)]
    ok = len(netmhciipan_tm_df) == expected
    print(ok)
    return ok


def add_allele_dot_list():
    """Append NetMHCIIpan 4.3 rows to tools-mapping for the alleles in ``allele.list``.

    Inputs are read from the directory two levels above this file:
    ``tools-mapping.tsv``, ``mro_molecules.tsv`` and
    ``netmhciipan-4.3/allele.list``. Each non-blank, stripped line of allele.list
    is one allele name.

    Every allele is normalized with ``clean_mhcii_alleles`` and resolved as:

    * a ``netmhciipan`` row in tools-mapping (any ``Tool Version``, including
      4.3 rows written by an earlier run) whose normalized ``Tool Label``
      matches -> reuse the first such row's ``Tool Label`` and ``MRO ID``;
    * otherwise an ``MHC class II protein complex`` entry of
      ``mro_molecules.tsv`` whose normalized ``IEDB Label`` matches -> use the
      first such entry's ``MRO ID`` and ``presumed_tool_label(IEDB Label)``;
    * otherwise the allele (original spelling) is written to
      ``netmhciipan-4.3/allele.list.unknown.txt``.

    Resolved rows are appended with ``Tool Version`` ``'4.3'``, duplicate rows
    are dropped, and ``tools-mapping.tsv`` is overwritten in place.
    """
    data_dir = Path(__file__).parent.parent
    tm_file = data_dir / 'tools-mapping.tsv'
    tm_df = pd.read_csv(tm_file, sep='\t', index_col=False)
    mol_df = pd.read_csv(data_dir / 'mro_molecules.tsv', sep='\t', index_col=False)

    def first_match(keys, *value_columns):
        """Map each key to its values at the key's FIRST occurrence (later repeats ignored)."""
        lookup = {}
        for key, *values in zip(keys, *value_columns):
            lookup.setdefault(key, tuple(values))
        return lookup

    mhcii_mol_df = mol_df[mol_df['Parent'] == 'MHC class II protein complex']
    mol_lookup = first_match(
        [clean_mhcii_alleles(label) for label in mhcii_mol_df['IEDB Label']],
        mhcii_mol_df['IEDB Label'],
        mhcii_mol_df['MRO ID'],
    )

    with open(data_dir / 'netmhciipan-4.3' / 'allele.list', 'r') as f:
        # Skip blank lines; they would otherwise become '' alleles.
        alleles = [line.strip() for line in f if line.strip()]
    print(f'Number of alleles in allele.list: {len(alleles)}')

    # .copy() so the column added below goes to an independent frame, not a
    # view of tm_df (avoids pandas' SettingWithCopyWarning).
    netmhciipan_tm_df = tm_df[tm_df['Tool'] == 'netmhciipan'].copy()
    print(f'Number of NetMHCIIpan alleles \nfrom Tools-Mapping file: {len(netmhciipan_tm_df)}')

    netmhciipan_tm_df['cleaned_labels'] = [
        clean_mhcii_alleles(label) for label in netmhciipan_tm_df['Tool Label']
    ]
    netmhciipan_tm_df = netmhciipan_tm_df.reset_index(drop=True)
    print(netmhciipan_tm_df.head())
    tm_lookup = first_match(
        netmhciipan_tm_df['cleaned_labels'],
        netmhciipan_tm_df['Tool Label'],
        netmhciipan_tm_df['MRO ID'],
    )

    cleaned_alleles = [clean_mhcii_alleles(allele) for allele in alleles]
    # Normalized key -> original allele.list spelling (a later duplicate wins).
    original_spelling = dict(zip(cleaned_alleles, alleles))
    clean_alleles_found = [c for c in cleaned_alleles if c in tm_lookup]
    clean_alleles_not_found = [c for c in cleaned_alleles if c not in tm_lookup]
    print(f'Number of alleles from allelelist file that is not found in tools-mapping: {len(clean_alleles_not_found)}')
    print(f'Number of alleles from allelelist file that is found in tools-mapping: {len(clean_alleles_found)}')
    print(f'Total alleles: {len(clean_alleles_found) + len(clean_alleles_not_found)}')

    row_template = {
        'Tool Group': 'mhcii',
        'Tool': 'netmhciipan',
        'Tool Version': '4.3',
        # '11, 12, ..., 30'
        'Lengths': ', '.join(str(length) for length in range(11, 31)),
    }

    # Alleles already known to tools-mapping: reuse the existing label and MRO ID.
    new_rows = []
    for cleaned in clean_alleles_found:
        tlabel, mroid = tm_lookup[cleaned]
        new_rows.append({**row_template, 'Tool Label': tlabel, 'MRO ID': mroid})
    known_count = len(new_rows)
    print(f'Rows from tools-mapping matches: {known_count}')

    # The rest are, by construction, absent from tools-mapping; try the MRO molecules.
    completely_unknown_alleles = []
    for cleaned in clean_alleles_not_found:
        if cleaned in mol_lookup:
            iedb_label, mroid = mol_lookup[cleaned]
            new_rows.append({
                **row_template,
                'Tool Label': presumed_tool_label(iedb_label),
                'MRO ID': mroid,
            })
        else:
            completely_unknown_alleles.append(original_spelling[cleaned])
    print(f'Rows from mro_molecules matches: {len(new_rows) - known_count}')
    print(f'Unresolved alleles: {len(completely_unknown_alleles)}')

    with open(data_dir / 'netmhciipan-4.3' / 'allele.list.unknown.txt', 'w') as f:
        f.writelines(f'{allele}\n' for allele in completely_unknown_alleles)

    # A single concat instead of one per row (which is quadratic in the row count).
    if new_rows:
        tm_df = pd.concat(
            [tm_df, pd.DataFrame(new_rows, columns=tm_df.columns)],
            ignore_index=True,
        )

    # New rows carry Tool Version as the string '4.3', while rows read from
    # the TSV may hold the float 4.3 (when pandas parses the column as
    # numeric), so a plain drop_duplicates() treats identical rows as
    # distinct. Compare the string form of every cell instead.
    tm_df = tm_df[~tm_df.astype(str).duplicated()]
    tm_df.to_csv(tm_file, sep='\t', index=False)

def check_tm_1(data_dir=None, expected=7699):
    """Drop duplicate rows from tools-mapping, check the 4.3 row count, and re-save.

    Args:
        data_dir: Directory holding ``tools-mapping.tsv``. Defaults to the
            directory two levels above this file.
        expected: Row count to compare against. A total of 7699 allele rows
            should have been added to tools-mapping.

    Returns:
        True if, after de-duplication, the number of rows with
        ``Tool == 'netmhciipan'`` and ``Tool Version`` 4.3 equals *expected*.
        The same boolean is printed. The de-duplicated table is written back
        to ``tools-mapping.tsv`` regardless of the result.
    """
    if data_dir is None:
        data_dir = Path(__file__).parent.parent
    tm_file = Path(data_dir) / 'tools-mapping.tsv'
    tm_df = pd.read_csv(tm_file, sep='\t', index_col=False)
    # Freshly read from the TSV, each column has one dtype, so exact duplicate
    # rows compare equal and drop_duplicates() works here.
    tm_df = tm_df.drop_duplicates()
    # Coerce to numbers so 4.3 matches whether the column was parsed as floats or strings.
    versions = pd.to_numeric(tm_df['Tool Version'], errors='coerce')
    netmhciipan_tm_df = tm_df[(tm_df['Tool'] == 'netmhciipan') & (versions == 4.3)]
    ok = len(netmhciipan_tm_df) == expected
    print(ok)

    tm_df.to_csv(tm_file, sep='\t', index=False)
    return ok


def test_exec():
    """Run ``single_prediction`` once for every NetMHCIIpan 4.3 ``Tool Label``.

    Reads ``tools-mapping.tsv`` from the directory two levels above this file
    and prints whether it contains 7699 ``netmhciipan`` rows with
    ``Tool Version`` 4.3. It then calls ``single_prediction`` with a fixed
    16-residue peptide for each of those labels and prints a running count.

    Raises:
        ValueError: for the first label whose prediction result is falsy.
    """
    data_dir = Path(__file__).parent.parent
    tm_df = pd.read_csv(data_dir / 'tools-mapping.tsv', sep='\t', index_col=False)
    # Coerce to numbers so 4.3 matches whether the column was parsed as floats
    # or strings; a plain `== 4.3` on a string column matches nothing and the
    # loop below would silently test zero labels.
    versions = pd.to_numeric(tm_df['Tool Version'], errors='coerce')
    netmhciipan_tm_df = tm_df[(tm_df['Tool'] == 'netmhciipan') & (versions == 4.3)]
    print(len(netmhciipan_tm_df) == 7699)

    peptide = 'ASSASSSSAAAAAAAT'
    for counter, tlabel in enumerate(netmhciipan_tm_df['Tool Label'], start=1):
        # str() so a missing (NaN) label reaches single_prediction as 'nan'
        # instead of failing with AttributeError on .replace().
        tlabel = str(tlabel).replace("/", "-")
        # NOTE(review): (label, 16) presumably pairs the allele with the
        # peptide length -- confirm against single_prediction's signature.
        result = single_prediction([peptide], [(tlabel, 16)])

        if not result:
            raise ValueError(f"Tool Label ({tlabel}) is not compatible with netmhciipan-4.3 method.")

        print(counter)



if __name__ == '__main__':
    # Earlier pipeline steps, kept for reference (listed in the order they
    # appear in this script; uncomment to re-run):
    #   add_allelelist()       -- append 4.3 rows resolved from allelelist.txt
    #   check_tm()             -- print whether 7699 netmhciipan 4.3 rows exist
    #   add_allele_dot_list()  -- append 4.3 rows resolved from allele.list
    #   check_tm_1()           -- drop duplicate rows and re-save tools-mapping.tsv
    #
    # allele.list.unknown.txt and allelelist.unknown.txt are exactly the same.

    # Test all the newly added method labels on the executable.
    test_exec()