import sys
import re
import pandas as pd
import numpy as np
from pathlib import Path
# Put the directory one level above this script's folder on the import path.
PROJECT_DIR = str(Path(__file__).resolve().parents[1])
sys.path.insert(1, PROJECT_DIR)

# Data file locations, all derived from the folder that holds this script.
PARENT_DIR = Path(__file__).parent
DATA_DIR = PARENT_DIR.parent.joinpath("data")
NETMHCIIPAN_43 = DATA_DIR.joinpath("netmhciipan-4.3")
ORIG_TM_FILE = DATA_DIR.joinpath("Tools_MRO_mapping.xlsx")
TOOLS_MAPPING_FILE = DATA_DIR.joinpath("tools-mapping.tsv")
MRO_MOLECULES_FILE = DATA_DIR.joinpath("mro_molecules.tsv")

# Normalise an allele name so that differently punctuated spellings compare equal.
def clean_mhcii_alleles(allele):
    """Return *allele* reduced to ASCII letters, digits and whitespace.

    When the name starts with 'HLA-', every 'HLA-' occurrence in it is
    removed first; names that do not start with 'HLA-' are left untouched at
    this step.  All remaining characters other than a-z, A-Z, 0-9 and
    whitespace are then dropped.
    """
    without_hla = allele.replace('HLA-', '') if allele.startswith('HLA-') else allele

    # Keeping only the allowed characters is the same as deleting the rest.
    kept_chars = re.findall(r'[a-zA-Z0-9\s]', without_hla)
    return ''.join(kept_chars)


def _read_allelelist(path):
    """Return the allele names listed in an ``allelelist.txt`` style file.

    Every non-blank line must hold exactly two whitespace-separated columns.
    The two columns are expected to be identical; reading stops at the first
    line where they differ (the mismatch is printed), so only the alleles
    found before that line are returned.

    Raises:
        ValueError: if a non-blank line does not have exactly two columns.
    """
    alleles = []
    with open(path, 'r') as f:
        for line in f:
            columns = line.split()
            if not columns:
                # Tolerate blank lines, e.g. a trailing empty line.
                continue

            allele_left, allele_right = columns
            if allele_left != allele_right:
                print(f'Mismatch was found: {allele_left}, {allele_right}')
                break

            alleles.append(allele_left)
    return alleles


def _first_mro_id_by_label(df):
    """Map each 'cleaned_labels' value of *df* to the 'MRO ID' of its first row."""
    mro_id_by_label = {}
    for label, mro_id in zip(df['cleaned_labels'].values, df['MRO ID'].values):
        # setdefault keeps the first occurrence, i.e. "first match wins".
        mro_id_by_label.setdefault(label, mro_id)
    return mro_id_by_label


def _split_by_mro_match(named_alleles, mro_id_by_label):
    """Split ``(cleaned_name, tool_name)`` pairs by whether the cleaned name is known.

    Returns:
        ``(mapped, unknown)`` where ``mapped`` is a list of
        ``(tool_name, mro_id)`` tuples and ``unknown`` is a list of the tool
        names whose cleaned name is not a key of *mro_id_by_label*.
    """
    mapped = []
    unknown = []
    for cleaned_name, tool_name in named_alleles:
        if cleaned_name in mro_id_by_label:
            mapped.append((tool_name, mro_id_by_label[cleaned_name]))
        else:
            unknown.append(tool_name)
    return mapped, unknown


def main():
    """Map the netMHCIIpan 4.3 allele list to MRO IDs and extend the tools-mapping file.

    Steps:
      1. Read the allele names from ``netmhciipan-4.3/allelelist.txt``.
      2. Match them, by cleaned name, against the "MHC class II protein
         complex" rows of the MRO molecules file.
      3. Look up the alleles left unmatched in the ``netmhciipan`` rows of
         the Tools/MRO mapping spreadsheet (the result is only reported).
      4. For every allele matched in step 2 whose MRO ID is not already
         listed for netmhciipan 4.3, append a row to the tools-mapping
         table, then write the table back to ``tools-mapping.tsv``.

    Progress and sanity-check information is printed to stdout.
    """
    allelelist_list = _read_allelelist(DATA_DIR / "netmhciipan-4.3" / "allelelist.txt")
    print(f"Total alleles from 'allelelist.txt' file: {len(allelelist_list)}")

    # Read molecule file as a dataframe and keep only the MHC class II rows.
    mol_df = pd.read_csv(MRO_MOLECULES_FILE, sep='\t', index_col=False)
    print(mol_df.head())
    parent = "MHC class II protein complex"
    mhcii_mol_df = mol_df[mol_df['Parent'] == parent].copy()
    print(mhcii_mol_df.head())
    cleaned_mol_alleles = [clean_mhcii_alleles(label) for label in mhcii_mol_df['IEDB Label'].values]
    print(f'Total alleles from molecules file: {len(cleaned_mol_alleles)}')
    mhcii_mol_df.loc[:, 'cleaned_labels'] = cleaned_mol_alleles
    mhcii_mol_df = mhcii_mol_df.reset_index(drop=True)
    print(mhcii_mol_df.head())
    print(len(mhcii_mol_df))

    cleaned_allelelist_alleles = [clean_mhcii_alleles(name) for name in allelelist_list]
    # Keyed by cleaned name: tool names that clean to the same string collapse
    # into one entry (the last one wins).
    allele_clean_dict = dict(zip(cleaned_allelelist_alleles, allelelist_list))
    # NOTE(review): 11048 is presumably the expected number of distinct
    # cleaned allele names for this release - confirm.
    print(len(allele_clean_dict) == 11048)

    mapped_alleles, unknown_alleles = _split_by_mro_match(
        allele_clean_dict.items(),
        _first_mro_id_by_label(mhcii_mol_df),
    )

    print(f"Number of alleles that that were mapped in 'mro_molecules.tsv' file: {len(mapped_alleles)}")
    print(f"Number of alleles that are not found in 'mro_molecules.tsv' file: {len(unknown_alleles)}")

    # Read the Tools_MRO_mapping.xlsx file, and filter it so that it only includes 'netmhciipan' alleles.
    original_tm_df = pd.read_excel(ORIG_TM_FILE, engine='openpyxl', index_col=False)

    # Performing copy() to prevent SettingWithCopyWarning from pandas
    netmhciipan_orig_tm_df = original_tm_df[original_tm_df['tool'] == 'netmhciipan'].copy()
    print(netmhciipan_orig_tm_df.head())
    print(len(netmhciipan_orig_tm_df))  # 5622
    cleaned_orig_tm_alleles = [clean_mhcii_alleles(term) for term in netmhciipan_orig_tm_df['term'].values]
    print(f"Total alleles from 'Tools_MRO_mapping.xlsx' file: {len(cleaned_orig_tm_alleles)}")
    netmhciipan_orig_tm_df.loc[:, 'cleaned_labels'] = cleaned_orig_tm_alleles
    netmhciipan_orig_tm_df = netmhciipan_orig_tm_df.reset_index(drop=True)
    print(netmhciipan_orig_tm_df.head())
    print(len(netmhciipan_orig_tm_df))

    # Creating a dict where key is the original tool name, and value is the stripped allele names.
    allele_clean_dict_reverse = dict(zip(allelelist_list, cleaned_allelelist_alleles))

    # Second chance for the alleles the molecules file did not know. Each
    # match is recorded under its own allele name.
    tm_mapped_alleles, still_unknown_alleles = _split_by_mro_match(
        ((allele_clean_dict_reverse[allele], allele) for allele in unknown_alleles),
        _first_mro_id_by_label(netmhciipan_orig_tm_df),
    )

    print(f"Number of alleles that that were mapped in 'Tools_MRO_mapping.xlsx' file: {len(tm_mapped_alleles)}")
    print(f"Number of alleles that are not found in 'Tools_MRO_mapping.xlsx' file: {len(still_unknown_alleles)}")

    complete_mapped_alleles = mapped_alleles + tm_mapped_alleles
    print(len(complete_mapped_alleles))

    tools_mapping_df = pd.read_csv(TOOLS_MAPPING_FILE, skipinitialspace=True, sep='\t')

    # Compare the version as text so the filter works whether pandas parsed
    # the 'Tool Version' column as floats (4.3) or as strings ('4.3').
    is_netmhciipan_43 = (
        (tools_mapping_df['Tool'] == 'netmhciipan')
        & (tools_mapping_df['Tool Version'].astype(str).str.strip() == '4.3')
    )
    # MRO IDs already listed for netmhciipan 4.3 when the file was read.
    existing_mro_ids = set(tools_mapping_df.loc[is_netmhciipan_43, 'MRO ID'])

    length = '11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30'

    # NOTE(review): only the alleles matched through the molecules file are
    # written; the spreadsheet matches in complete_mapped_alleles are merely
    # counted above - confirm this is intended.
    new_rows = []
    for allele, mroid in mapped_alleles:
        if allele == 'HLA-DPA10303-DPB18201':
            allele = 'DPA1*03:03/DPB1*82:01'

        # Skip MRO IDs that the tools mapping already has for this tool/version.
        if mroid not in existing_mro_ids:
            new_rows.append(['mhcii', 'netmhciipan', '4.3', allele, mroid, length])

    if new_rows:
        # One concat for all rows instead of re-copying the frame per row.
        tools_mapping_df = pd.concat(
            [tools_mapping_df, pd.DataFrame(new_rows, columns=tools_mapping_df.columns)],
            ignore_index=True,
        )

    tools_mapping_df.to_csv(TOOLS_MAPPING_FILE, sep='\t', index=False)


def add_4_3_alleles_name_list_data():
    """Add netmhciipan 4.3 rows to the tools-mapping file for a fixed list of allele names.

    Each name in the hard-coded list below is cleaned with
    ``clean_mhcii_alleles`` and matched against the cleaned 'IEDB Label' of
    the "MHC class II protein complex" rows in the MRO molecules file (first
    match wins).  For every matched name whose MRO ID is not already listed
    for netmhciipan 4.3 in the tools-mapping table, a new row is appended.
    The table is then written back to ``tools-mapping.tsv``.

    Progress information is printed to stdout.
    """
    # NOTE(review): judging by the variable name, these were presumably
    # collected from the DTU web server's allele list - confirm the source.
    unmapped_dtu_web_alleles = ['DPA1*03:03', 'DPB1*84:01', 'DQB1*04:34', 'DPB1*102:01', 'DQB1*02:71', 'DQA1*01:10', 'DPB1*35:01', 'DPB1*68:01', 'DQB1*03:139', 'DQB1*03:05', 'DQB1*02:07', 'DQB1*03:52', 'DPB1*91:01', 'DQB1*03:06', 'DPB1*100:01', 'DQB1*02:05', 'DQB1*05:28', 'DPB1*39:01', 'DPB1*98:01', 'DQB1*03:232', 'DPB1*63:01', 'DPB1*125:01', 'DQA1*01:07', 'DPB1*10:01', 'DQB1*03:154', 'DPA1*01:10', 'DQB1*05:02', 'DQB1*03:108', 'DPB1*17:01', 'DQB1*04:01', 'DQB1*03:17', 'DPB1*18:01', 'DQB1*03:68', 'DPB1*109:01', 'DPB1*28:01', 'DPB1*26:01', 'DPB1*60:01', 'DPB1*131:01', 'DPA1*02:02', 'DQB1*03:251', 'DQB1*02:06', 'DPB1*81:01', 'DQB1*03:75', 'DQB1*03:231', 'DQB1*03:184', 'DQB1*06:09', 'DPB1*51:01', 'DPB1*40:01', 'DPB1*110:01', 'DPB1*112:01', 'DQB1*05:04', 'DPA1*02:01', 'DQB1*03:242', 'DQB1*03:13', 'DPB1*87:01', 'DPB1*127:01', 'DPB1*11:01', 'DPB1*02:01', 'DQB1*06:48', 'DPB1*89:01', 'DPB1*30:01', 'DQB1*06:28', 'DPB1*02:02', 'DQB1*05:23', 'DQA1*02:01', 'DQB1*03:196', 'DPB1*25:01', 'DQB1*03:158', 'DPB1*09:01', 'DPB1*06:01', 'DQB1*03:186', 'DPB1*119:01', 'DQB1*06:126', 'DQB1*03:114', 'DQB1*03:109', 'DQB1*06:16', 'DPB1*36:01', 'DPA1*01:08', 'DPB1*76:01', 'DQB1*06:14', 'DQB1*03:12', 'DPB1*108:01', 'DPB1*14:01', 'DQB1*03:30', 'DPB1*133:01', 'DQB1*06:01', 'DQB1*06:15', 'DPB1*107:01', 'DQB1*05:42', 'DQB1*03:166', 'DPB1*27:01', 'DPB1*78:01', 'DQB1*05:24', 'DQB1*03:138', 'DPB1*56:01', 'DPB1*86:01', 'DPB1*99:01', 'DQB1*06:153', 'DPA1*01:09', 'DPB1*58:01', 'DPB1*134:01', 'DPA1*01:04', 'DPB1*111:01', 'DQB1*03:135', 'DPB1*75:01', 'DQB1*03:34', 'DPB1*65:01', 'DPB1*106:01', 'DQB1*03:129', 'DPB1*117:01', 'DQA1*01:02', 'DQB1*06:03', 'DPB1*93:01', 'DQB1*06:123', 'DQB1*04:04', 'DQB1*03:01', 'DPA1*01:06', 'DQB1*06:31', 'DQA1*06:02', 'DQB1*06:02', 'DQB1*03:131', 'DQA1*05:07', 'DQB1*03:147', 'DPB1*54:01', 'DPB1*05:01', 'DQA1*03:03', 'DPB1*115:01', 'DQB1*05:103', 'DQB1*03:265', 'DPB1*92:01', 'DQB1*06:20', 'DPB1*20:01', 'DPB1*130:01', 'DPB1*97:01', 'DPB1*08:01', 'DQB1*06:17', 'DPA1*01:07', 
        'DQA1*05:11', 'DPB1*55:01', 'DQB1*02:12', 'DPB1*122:01', 'DPB1*77:01', 'DQB1*06:19', 'DQA1*01:03', 'DQB1*03:27', 'DQB1*03:72', 'DQB1*03:241', 'DQA1*01:08', 'DPB1*48:01', 'DQB1*06:22', 'DPB1*38:01', 'DPB1*79:01', 'DPA1*01:03', 'DPB1*52:01', 'DQB1*03:21', 'DQB1*05:52', 'DQB1*03:10', 'DPA1*02:03', 'DPB1*114:01', 'DPB1*49:01', 'DPA1*03:02', 'DQA1*06:01', 'DPA1*04:01', 'DPB1*118:01', 'DQB1*03:82', 'DQB1*03:31', 'DPB1*82:01', 'DQB1*03:185', 'DQB1*05:37', 'DQB1*03:119', 'DQA1*04:01', 'DQB1*06:07', 'DPB1*01:01', 'DQB1*06:10', 'DPB1*29:01', 'DQA1*05:10', 'DQA1*01:09', 'DQB1*03:195', 'DQA1*05:01', 'DQB1*04:02', 'DQB1*03:167', 'DPB1*62:01', 'DPB1*132:01', 'DQA1*03:02', 'DPB1*72:01', 'DPB1*66:01', 'DPA1*01:05', 'DPB1*23:01', 'DPB1*70:01', 'DQB1*05:01', 'DQB1*03:69', 'DQA1*05:06', 'DQB1*04:08', 'DQA1*05:04', 'DQB1*03:04', 'DPB1*103:01', 'DPB1*34:01', 'DPB1*128:01', 'DPB1*03:01', 'DQB1*03:172', 'DQB1*03:142', 'DQB1*03:20', 'DQB1*06:140', 'DQB1*03:28', 'DQB1*06:05', 'DQB1*03:08', 'DQB1*03:180', 'DPB1*15:01', 'DQA1*05:08', 'DQB1*02:85', 'DPB1*50:01', 'DQB1*02:25', 'DQB1*06:13', 'DQB1*03:02', 'DPB1*24:01', 'DPB1*90:01', 'DPB1*104:01', 'DQB1*03:134', 'DPB1*46:01', 'DPB1*74:01', 'DQB1*03:14', 'DQB1*06:108', 'DQB1*02:02', 'DPB1*41:01', 'DQB1*03:104', 'DQB1*03:22', 'DPB1*95:01', 'DQB1*03:96', 'DQB1*05:05', 'DQB1*03:11', 'DPA1*03:01', 'DPB1*37:01', 'DQB1*03:113', 'DQA1*03:01', 'DPB1*47:01', 'DQB1*03:121', 'DQA1*01:01', 'DPB1*73:01', 'DQB1*02:01', 'DPB1*04:01', 'DQB1*02:04', 'DPB1*123:01', 'DPB1*116:01', 'DPB1*22:01', 'DQB1*03:235', 'DQA1*01:06', 'DQA1*05:09', 'DQB1*03:120', 'DQB1*03:133', 'DPB1*88:01', 'DQB1*03:194', 'DQB1*05:83', 'DPB1*45:01', 'DQA1*04:02', 'DQB1*05:10', 'DQB1*03:163', 'DQB1*03:254', 'DQB1*06:110', 'DQB1*03:157', 'DQB1*02:03', 'DQB1*03:38', 'DPB1*31:01', 'DQB1*06:04', 'DPB1*124:01', 'DQB1*03:208', 'DPB1*59:01', 'DQB1*03:105', 'DQB1*05:46', 'DPB1*67:01', 'DPB1*85:01', 'DQB1*03:264', 'DQB1*03:188', 'DQB1*05:08', 'DPB1*105:01', 'DPB1*101:01', 'DQB1*03:25', 'DQB1*03:87', 
        'DQB1*03:191', 'DQB1*05:35', 'DQB1*03:57', 'DQB1*02:48', 'DQB1*04:23', 'DQB1*03:207', 'DQB1*06:214', 'DQB1*04:03', 'DQB1*03:47', 'DQB1*03:16', 'DPB1*21:01', 'DQA1*05:05', 'DQB1*03:132', 'DPB1*94:01', 'DQA1*05:02', 'DQB1*03:19', 'DQB1*05:11', 'DPB1*129:01', 'DQB1*06:11', 'DQB1*02:30', 'DQB1*05:03', 'DQB1*06:52', 'DPB1*126:01', 'DQB1*03:03', 'DQA1*05:03', 'DPB1*33:01', 'DQB1*06:12', 'DQB1*03:62', 'DPB1*71:01', 'DPB1*69:01', 'DPB1*96:01', 'DQB1*05:117', 'DPB1*80:01', 'DQB1*06:39', 'DPB1*13:01', 'DQB1*03:29', 'DPA1*02:04', 'DPB1*04:02', 'DQB1*06:08', 'DPB1*16:01', 'DPB1*113:01', 'DQB1*06:41', 'DQB1*05:18', 'DPB1*83:01', 'DPB1*121:01', 'DQA1*01:04', 'DQA1*01:05', 'DPB1*19:01', 'DPB1*44:01', 'DPB1*53:01', 'DPB1*32:01', 'DQB1*04:22', 'DQB1*03:201']
    print(len(unmapped_dtu_web_alleles))

    # Load the MRO molecules and keep only the MHC class II rows.
    mol_df = pd.read_csv(MRO_MOLECULES_FILE, sep='\t', index_col=False)
    parent = "MHC class II protein complex"
    mhcii_mol_df = mol_df[mol_df['Parent'] == parent].copy()
    print(mhcii_mol_df.head())
    cleaned_mol_alleles = [clean_mhcii_alleles(label) for label in mhcii_mol_df['IEDB Label'].values]
    print(f'Total alleles from molecules file: {len(cleaned_mol_alleles)}')
    mhcii_mol_df.loc[:, 'cleaned_labels'] = cleaned_mol_alleles
    mhcii_mol_df = mhcii_mol_df.reset_index(drop=True)
    print(mhcii_mol_df.head())
    print(len(mhcii_mol_df))

    # Cleaned label -> MRO ID of the first row carrying that label, built once
    # so each allele lookup below is a dict access instead of a column scan.
    mro_id_by_label = {}
    for label, mro_id in zip(mhcii_mol_df['cleaned_labels'].values, mhcii_mol_df['MRO ID'].values):
        mro_id_by_label.setdefault(label, mro_id)

    unknown_alleles = []
    mapped_alleles = []
    for tools_allele_name in unmapped_dtu_web_alleles:
        stripped_allele_name = clean_mhcii_alleles(tools_allele_name)
        if stripped_allele_name not in mro_id_by_label:
            # Alleles that were not found in the molecules file
            unknown_alleles.append(tools_allele_name)
            continue

        mapped_alleles.append((tools_allele_name, mro_id_by_label[stripped_allele_name]))

    print(f"Number of alleles that that were mapped in 'mro_molecules.tsv' file: {len(mapped_alleles)}")
    print(f"Number of alleles that are not found in 'mro_molecules.tsv' file: {len(unknown_alleles)}")

    tools_mapping_df = pd.read_csv(TOOLS_MAPPING_FILE, skipinitialspace=True, sep='\t')

    # Compare the version as text so the filter works whether pandas parsed
    # the 'Tool Version' column as floats (4.3) or as strings ('4.3').
    is_netmhciipan_43 = (
        (tools_mapping_df['Tool'] == 'netmhciipan')
        & (tools_mapping_df['Tool Version'].astype(str).str.strip() == '4.3')
    )
    # MRO IDs already listed for netmhciipan 4.3 when the file was read.
    existing_mro_ids = set(tools_mapping_df.loc[is_netmhciipan_43, 'MRO ID'])

    length = '11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30'

    # Original author's note: "Out of 7343 alleles that had MRO ID, everything
    # already existed in the tools mapping file except for 1 allele:
    # DPA1*03:03/DPB1*82:01".
    # NOTE(review): the count does not match the list above; the note may
    # have been carried over from main() - confirm.
    new_rows = []
    for allele, mroid in mapped_alleles:
        # Skip MRO IDs that the tools mapping already has for this tool/version.
        if mroid not in existing_mro_ids:
            new_rows.append(['mhcii', 'netmhciipan', '4.3', allele, mroid, length])

    if new_rows:
        # One concat for all rows instead of re-copying the frame per row.
        tools_mapping_df = pd.concat(
            [tools_mapping_df, pd.DataFrame(new_rows, columns=tools_mapping_df.columns)],
            ignore_index=True,
        )

    tools_mapping_df.to_csv(TOOLS_MAPPING_FILE, sep='\t', index=False)


if __name__ == '__main__':
    # Script entry point. The call to main() is commented out, so only the
    # allele-name-list step below runs when this file is executed directly.
    # main()

    add_4_3_alleles_name_list_data()