import unittest
import random
import pandas as pd
import os
import sys
from pathlib import Path
# Make the project root (the parent of this file's directory) importable
# *before* importing project modules. Inserting it after the import only worked
# when the tests were launched from the project root, and an ImportError here
# aborts the module before the skipIf decorator below can take effect.
PROJECT_DIR = str(Path(__file__).resolve().parents[1])
sys.path.insert(1, PROJECT_DIR)

from netmhcpan_4_1_executable import predict_many  # noqa: E402  (needs sys.path above)


def _env_flag(name, default=False):
    """Return whether environment variable *name* is set to a truthy value.

    '1', 'true', 'yes' and 'on' (case-insensitive, surrounding whitespace
    ignored) are truthy; any other value, e.g. 'False' or '0', is False.
    When the variable is unset, *default* is returned.
    """
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ('1', 'true', 'yes', 'on')


# Parsed explicitly instead of with eval(), which would execute any code
# placed in the environment variable.
SKIP_GITLAB_CI = _env_flag('SKIP_GITLAB_CI')

@unittest.skipIf(SKIP_GITLAB_CI, "CI Can't seem to find location of netmhcpan_4_1_executable")
class NetMHCpanExecutableTest(unittest.TestCase):
    '''Test cases for the netmhcpan 4.1 executable wrapper (``predict_many``).

    Every test requests eluted-ligand predictions (``el=True``) of length
    BINDING_LENGTH for the fixed SEQUENCE against one or more alleles.
    '''
    NUM_SAMPLES = 50
    # Seed used by test_random_alleles; None draws a fresh seed on every run.
    # The seed actually used is printed so a failing run can be reproduced.
    RANDOM_SEED = None
    SEQUENCE = "SLYNTVATLYCVHQRIDV"
    BINDING_LENGTH = '9'
    TOOLS_MAPPING_PATH = os.path.join(PROJECT_DIR, 'data', 'Tools_MRO_Mapping_VFYD.xlsx')
    # Populated by setUpClass (columns read below: 'Tool', 'Tool Label').
    TOOLS_MAPPING_DF = None

    @classmethod
    def setUpClass(cls):
        '''Load the tool/allele mapping spreadsheet once for the whole class.'''
        super().setUpClass()
        # Loaded here rather than in the class body, so importing this module
        # (or skipping the class in CI) does not require the spreadsheet.
        cls.TOOLS_MAPPING_DF = pd.read_excel(cls.TOOLS_MAPPING_PATH, engine='openpyxl')

    def _predict(self, allele_name):
        '''Run EL predictions of SEQUENCE for *allele_name*; return the result values as a list.'''
        results = predict_many([self.SEQUENCE], [(allele_name, self.BINDING_LENGTH)], el=True)
        return list(results.values())

    def _assert_has_predictions(self, scores, allele_name):
        '''Fail unless *scores* is non-empty and none of its entries is empty.'''
        # list(results.values()) is truthy as soon as predict_many returns any
        # key, even if the value under that key is empty; check both levels.
        self.assertTrue(scores, 'no results returned for %s' % allele_name)
        self.assertTrue(all(scores), 'empty result entry for %s' % allele_name)

    def test_base_case(self):
        '''HLA-A*02:01 predictions match previously recorded values exactly.'''
        scores = self._predict('HLA-A*02:01')
        expected_scores = [[0.828183, 'SLYNTVATL', 'SLYNTVATL', 0.000583, 'LYNTVATLY', 'LYNTVATLY', 6e-05, 'YNTVATLYC', 'YNTVATLYC', 0.039677, 'NTVATLYCV', 'NTVATLYCV', 5.5e-05, 'TVATLYCVH', 'TVATLYCVH', 3.1e-05, 'VATLYCVHQ', 'VATLYCVHQ', 0.000402, 'ATLYCVHQR', 'ATLYCVHQR', 0.091622, 'TLYCVHQRI', 'TLYCVHQRI', 0.0, 'LYCVHQRID', 'LYCVHQRID', 0.00061, 'YCVHQRIDV', 'YCVHQRIDV']]
        self.assertEqual(expected_scores, scores)

    def test_bola_allele(self):
        '''A bovine (BoLA) allele produces predictions.'''
        allele_name = 'BoLA-N:00101'
        self._assert_has_predictions(self._predict(allele_name), allele_name)

    def test_sla_allele(self):
        '''A swine (SLA) allele produces predictions.'''
        allele_name = 'SLA-1*0801'
        self._assert_has_predictions(self._predict(allele_name), allele_name)

    def test_h2_allele(self):
        '''A mouse (H-2) allele produces predictions.'''
        allele_name = 'H-2-Qa1'
        self._assert_has_predictions(self._predict(allele_name), allele_name)

    def test_random_alleles(self):
        '''Randomly sample up to NUM_SAMPLES netmhcpan-4.1 alleles and check each produces predictions.'''
        print('\n')
        # Grab all netmhcpan-4.1 tool labels as a list.
        mapping = self.TOOLS_MAPPING_DF
        netmhcpan_4_1_alleles = mapping.loc[mapping['Tool'] == 'netmhcpan-4.1', 'Tool Label'].tolist()
        self.assertTrue(netmhcpan_4_1_alleles, 'no netmhcpan-4.1 alleles found in the mapping file')

        seed = self.RANDOM_SEED if self.RANDOM_SEED is not None else random.randrange(2 ** 32)
        print('random seed: %s' % seed)
        # random.sample raises ValueError if asked for more items than exist.
        sample_size = min(self.NUM_SAMPLES, len(netmhcpan_4_1_alleles))
        random_alleles = random.Random(seed).sample(netmhcpan_4_1_alleles, sample_size)

        for counter, allele in enumerate(random_alleles, start=1):
            print('%s. %s -- fetching results...' % (counter, allele))
            # subTest records a failure for this allele and moves on, so a
            # single bad allele does not hide problems with the rest.
            with self.subTest(allele=allele):
                self._assert_has_predictions(self._predict(allele), allele)

if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()