import unittest
import subprocess
import random
import pandas as pd
import os
def env_flag(name, default=False):
    """Return whether environment variable *name* is set to a truthy flag.

    ``1``, ``true``, ``yes`` and ``on`` (case-insensitive, surrounding
    whitespace ignored) are true; any other set value is false. When the
    variable is unset, *default* is returned.

    The value is parsed as text rather than passed to ``eval()``. ``eval()``
    would execute arbitrary environment content as Python and raise NameError
    on common spellings such as ``true`` or ``yes``.
    """
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ('1', 'true', 'yes', 'on')


SKIP_GITLAB_CI = env_flag('SKIP_GITLAB_CI')


@unittest.skipIf(SKIP_GITLAB_CI, "MHCI Standalone package is not included as part of this package.")
class TestMHCIMethods(unittest.TestCase):
    '''
    Tests up to NUM_SAMPLES (50) random alleles from each method and checks that
    the predictor returns output for each of them. It uses MHCI standalone.

    All paths below are relative to the current working directory.
    '''
    NUM_SAMPLES = 50
    TOOLS_MAPPING_PATH = 'data/Tools_MRO_Mapping_VFYD.xlsx'
    MHC_ALLELES_PATH = 'data/mhc_alleles.tsv'
    MHCI_PATH = '../mhc_i'
    PREDICTOR_PATH = '%s/src/predict_binding.py' % (MHCI_PATH)
    SEQUENCE_FILE = '%s/examples/input_sequence.fasta' % (MHCI_PATH)
    PEPTIDE_LENGTH = '9'

    # Loaded in setUpClass rather than in the class body. The class body runs at
    # import time even when skipIf() skips the class, so reading the files here
    # made a plain import fail whenever the data files were absent.
    TOOLS_MAPPING_DF = None
    MHC_ALLELES_DF = None

    @classmethod
    def setUpClass(cls):
        '''Load the tool/MRO mapping sheet and the MHC allele table once per class.'''
        super().setUpClass()
        cls.TOOLS_MAPPING_DF = pd.read_excel(cls.TOOLS_MAPPING_PATH, engine='openpyxl')
        cls.MHC_ALLELES_DF = pd.read_table(cls.MHC_ALLELES_PATH)

    def _run_predictor(self, method, allele):
        '''Run predict_binding.py for *method* and *allele* on the example FASTA.

        Returns the captured stdout as bytes; stderr is discarded.
        subprocess.run() drains stdout while waiting for the process to exit.
        Calling wait() before communicate() could deadlock once the predictor's
        output filled the OS pipe buffer.
        '''
        cmd_args = [self.PREDICTOR_PATH, method, allele, self.PEPTIDE_LENGTH, self.SEQUENCE_FILE]
        completed = subprocess.run(cmd_args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
        return completed.stdout

    def test_ann_example(self):
        '''Run 'ann' for HLA-A*02:01 on the example sequence and expect non-empty output.'''
        allele = 'HLA-A*02:01'
        result = self._run_predictor('ann', allele)

        # Check if the result table exists
        self.assertTrue(result, msg='ann produced no output for %s' % (allele))

    def _mhci_consensus_alleles(self, alleles_info_dict):
        '''Return the tool labels whose MRO ID is an available MHC class I allele.

        *alleles_info_dict* maps tool label -> MRO ID. A label is kept when
        MHC_ALLELES_DF has a row with that MRO ID, 'Predictor Availability' == 1
        and 'Tool Group' == 'mhci'. The input order is preserved.
        '''
        alleles_df = self.MHC_ALLELES_DF
        mask = (alleles_df['Predictor Availability'] == 1) & (alleles_df['Tool Group'] == 'mhci')
        # Build the lookup set once instead of re-filtering the frame per allele.
        # NaN is dropped so a missing ID never matches, mirroring NaN != NaN.
        mhci_ids = set(alleles_df.loc[mask, 'MRO ID'].dropna())
        return [name for name, mro_id in alleles_info_dict.items() if mro_id in mhci_ids]

    def run_prediction_with_random_alleles(self, method):
        '''Run *method* for up to NUM_SAMPLES random alleles and assert each returns output.

        Candidate alleles are the 'Tool Label' values of TOOLS_MAPPING_DF rows
        whose 'Tool' equals *method*. For 'consensus', candidates are further
        restricted to available MHC class I alleles, to tell them apart from
        MHC class II.
        '''
        print('\n===================================== %s =====================================' % (method))
        # Grab all tool labels (and their MRO IDs) mapped to this method
        df = self.TOOLS_MAPPING_DF[self.TOOLS_MAPPING_DF['Tool'] == method]
        alleles = df['Tool Label'].tolist()

        if method == 'consensus':
            # Have to distinguish MHCI from MHCII
            alleles_info_dict = dict(zip(alleles, df['MRO ID'].tolist()))
            candidates = self._mhci_consensus_alleles(alleles_info_dict)
        else:
            candidates = alleles

        # A method with no candidate alleles would otherwise pass without checking anything.
        self.assertTrue(candidates, msg='No alleles found for method %s' % (method))

        # Sample at most NUM_SAMPLES. random.sample raises ValueError when asked
        # for more items than exist, which broke 'consensus' whenever fewer than
        # NUM_SAMPLES class I alleles were available.
        if self.NUM_SAMPLES < len(candidates):
            random_alleles = random.sample(candidates, self.NUM_SAMPLES)
        else:
            random_alleles = candidates

        # The mapping sheet calls it 'netmhcpan-4.1'; the predictor is invoked with 'netmhcpan_el'.
        predictor_method = 'netmhcpan_el' if method == 'netmhcpan-4.1' else method

        # Test to see if they all output results
        for counter, allele in enumerate(random_alleles, start=1):
            result = self._run_predictor(predictor_method, allele)
            print('%s. %s -- fetching results...' % (counter, allele))
            self.assertTrue(result, msg='%s produced no output for %s' % (predictor_method, allele))

    def test_random_alleles(self):
        '''Run the random-allele check for each MHC-I method listed below.

        Each method runs in its own subTest, so a failure in one method is
        reported without skipping the remaining methods.
        '''
        methods = [
            'ann',
            'comblib_sidney2008',
            'consensus',
            # 'netmhcpan-4.1',
            'pickpocket',
            'smm',
            'smmpmbec',
        ]

        for method in methods:
            with self.subTest(method=method):
                self.run_prediction_with_random_alleles(method)
            




        


if __name__ == '__main__':
    # Executed directly (not imported): hand control to unittest's CLI runner.
    unittest.main()