import unittest
import sys
from pathlib import Path
# Make the parent of the directory containing this test file importable so
# the ``from allele_validator import ...`` line below resolves without
# installing the package. Inserting at index 1 (not 0) leaves sys.path[0]
# in first position.
PROJECT_DIR = str(Path(__file__).resolve().parents[1])
sys.path.insert(1, PROJECT_DIR)
from allele_validator import Allele_Validator


class TestValidity(unittest.TestCase):
    """Unit tests for ``Allele_Validator``.

    Covers allele-label validation, per-allele peptide-length validation,
    allele listing per prediction method/version, and available-length
    lookup.
    """

    # Expected output of ``get_alleles(method="smm")`` (compared ignoring order).
    SMM_ALLELES = (
        'BoLA-3*001:01', 'BoLA-1*023:01', 'BoLA-6*013:01', 'BoLA-3*002:01', 'BoLA-2*012:01', 'BoLA-6*041:01',
        'H2-Db', 'H2-Dd', 'H2-Kb', 'H2-Kd', 'H2-Kk', 'H2-Ld', 'HLA-A*01:01', 'HLA-A*02:01', 'HLA-A*02:02',
        'HLA-A*02:03', 'HLA-A*02:06', 'HLA-A*02:11', 'HLA-A*02:12', 'HLA-A*02:16', 'HLA-A*02:17', 'HLA-A*02:19',
        'HLA-A*02:50', 'HLA-A*03:01', 'HLA-A*11:01', 'HLA-A*23:01', 'HLA-A*24:02', 'HLA-A*24:03', 'HLA-A*25:01',
        'HLA-A*26:01', 'HLA-A*26:02', 'HLA-A*26:03', 'HLA-A*29:02', 'HLA-A*30:01', 'HLA-A*30:02', 'HLA-A*31:01',
        'HLA-A*32:01', 'HLA-A*32:07', 'HLA-A*32:15', 'HLA-A*33:01', 'HLA-A*66:01', 'HLA-A*68:01', 'HLA-A*68:02',
        'HLA-A*68:23', 'HLA-A*69:01', 'HLA-A*80:01', 'HLA-B*07:02', 'HLA-B*08:01', 'HLA-B*08:02', 'HLA-B*08:03',
        'HLA-B*14:02', 'HLA-B*15:01', 'HLA-B*15:02', 'HLA-B*15:03', 'HLA-B*15:09', 'HLA-B*15:17', 'HLA-B*15:42',
        'HLA-B*18:01', 'HLA-B*27:05', 'HLA-B*27:20', 'HLA-B*35:01', 'HLA-B*35:03', 'HLA-B*38:01', 'HLA-B*39:01',
        'HLA-B*40:01', 'HLA-B*40:02', 'HLA-B*40:13', 'HLA-B*42:01', 'HLA-B*44:02', 'HLA-B*44:03', 'HLA-B*45:01',
        'HLA-B*46:01', 'HLA-B*48:01', 'HLA-B*51:01', 'HLA-B*53:01', 'HLA-B*54:01', 'HLA-B*57:01', 'HLA-B*58:01',
        'HLA-B*58:02', 'HLA-B*73:01', 'HLA-B*83:01', 'HLA-C*03:03', 'HLA-C*04:01', 'HLA-C*05:01', 'HLA-C*06:02',
        'HLA-C*07:01', 'HLA-C*07:02', 'HLA-C*08:02', 'HLA-C*12:03', 'HLA-C*14:02', 'HLA-C*15:02', 'HLA-E*01:01',
        'HLA-E*01:03', 'Mamu-A1*001:01', 'Mamu-A1*002:01', 'Mamu-A1*007:01', 'Mamu-A1*011:01', 'Mamu-A1*022:01',
        'Mamu-A1*026:01', 'Mamu-A2*01:02', 'Mamu-A7*01:03', 'Mamu-B*001:01', 'Mamu-B*003:01', 'Mamu-B*008:01',
        'Mamu-B*010:01', 'Mamu-B*017:01', 'Mamu-B*039:01', 'Mamu-B*052:01', 'Mamu-B*066:01', 'Mamu-B*083:01',
        'Mamu-B*087:01', 'Patr-A*01:01', 'Patr-A*03:01', 'Patr-A*04:01', 'Patr-A*07:01', 'Patr-A*09:01',
        'Patr-B*01:01', 'Patr-B*13:01', 'Patr-B*24:01', 'SLA-1*04:01', 'SLA-2*04:01', 'SLA-3*04:01',
    )

    # Expected output of ``get_alleles(method="ann")`` (compared ignoring
    # order). The same list is expected whether the version is omitted or
    # given explicitly as "4.0".
    ANN_ALLELES = (
        'BoLA-3*001:01','BoLA-1*023:01','BoLA-6*013:01','BoLA-3*002:01','BoLA-2*012:01','BoLA-6*041:01',
        'H2-Db','H2-Dd','H2-Kb','H2-Kd','H2-Kk','H2-Ld','HLA-A*01:01','HLA-A*02:01','HLA-A*02:02','HLA-A*02:03',
        'HLA-A*02:05','HLA-A*02:06','HLA-A*02:07','HLA-A*02:11','HLA-A*02:12','HLA-A*02:16','HLA-A*02:17','HLA-A*02:19',
        'HLA-A*02:50','HLA-A*03:01','HLA-A*03:02','HLA-A*03:19','HLA-A*11:01','HLA-A*23:01','HLA-A*24:02','HLA-A*24:03',
        'HLA-A*25:01','HLA-A*26:01','HLA-A*26:02','HLA-A*26:03','HLA-A*29:02','HLA-A*30:01','HLA-A*30:02','HLA-A*31:01','HLA-A*32:01',
        'HLA-A*32:07','HLA-A*32:15','HLA-A*33:01','HLA-A*66:01','HLA-A*68:01','HLA-A*68:02','HLA-A*68:23','HLA-A*69:01',
        'HLA-A*80:01','HLA-B*07:02','HLA-B*08:01','HLA-B*08:02','HLA-B*08:03','HLA-B*14:01','HLA-B*14:02','HLA-B*15:01',
        'HLA-B*15:02','HLA-B*15:03','HLA-B*15:09','HLA-B*15:17','HLA-B*18:01','HLA-B*27:05','HLA-B*27:20','HLA-B*35:01',
        'HLA-B*35:03','HLA-B*37:01','HLA-B*38:01','HLA-B*39:01','HLA-B*40:01','HLA-B*40:02','HLA-B*40:13','HLA-B*42:01',
        'HLA-B*44:02','HLA-B*44:03','HLA-B*45:01','HLA-B*45:06','HLA-B*46:01','HLA-B*48:01','HLA-B*51:01','HLA-B*53:01','HLA-B*54:01',
        'HLA-B*57:01','HLA-B*57:03','HLA-B*58:01','HLA-B*58:02','HLA-B*73:01','HLA-B*81:01','HLA-B*83:01','HLA-C*03:03',
        'HLA-C*04:01','HLA-C*05:01','HLA-C*06:02','HLA-C*07:01','HLA-C*07:02','HLA-C*08:02','HLA-C*12:03','HLA-C*14:02',
        'HLA-C*15:02','HLA-E*01:01','HLA-E*01:03','Mamu-A1*001:01','Mamu-A1*002:01','Mamu-A1*007:01','Mamu-A1*011:01',
        'Mamu-A1*022:01','Mamu-A1*026:01','Mamu-A2*01:02','Mamu-A7*01:03','Mamu-B*001:01','Mamu-B*003:01','Mamu-B*008:01',
        'Mamu-B*010:01','Mamu-B*017:01','Mamu-B*039:01','Mamu-B*052:01','Mamu-B*066:01','Mamu-B*083:01','Mamu-B*087:01',
        'Patr-A*01:01','Patr-A*03:01','Patr-A*04:01','Patr-A*07:01','Patr-A*09:01','Patr-B*01:01','Patr-B*13:01',
        'Patr-B*24:01','SLA-1*04:01','SLA-1*07:01','SLA-2*04:01','SLA-3*04:01',
    )

    def setUp(self):
        # A fresh validator for every test method keeps the tests independent.
        self.validator = Allele_Validator()

    def test_validate_alleles(self):
        """validate_alleles: a bool for one label, a list of bools for a list."""
        validator = self.validator

        # The same HLA label is checked against each NetMHCpan method variant.
        for method in ('netmhcpan', 'netmhcpan_el', 'netmhcpan_ba'):
            with self.subTest(method=method):
                result = validator.validate_alleles('HLA-A*02:01', method)
                self.assertEqual(True, result)

        # Single label without a method argument.
        self.assertEqual(True, validator.validate_alleles("BoLA-1*009:01"))

        # A list input yields one result per label, matched positionally to
        # the input, so the comparison must be order-sensitive
        # (assertCountEqual would also accept a permuted result).
        result = validator.validate_alleles(["BoLA-1*009:01"])
        self.assertListEqual([True], list(result))

        # A label expected to be rejected.
        self.assertEqual(False, validator.validate_alleles("BoLA-1:00901"))

        result = validator.validate_alleles(
            ["BoLA-1*009:01", "BoLA-1*021:01", "BoLA-1:00901"])
        self.assertListEqual([True, True, False], list(result))

    def _check_allele_lengths(self, iedb_label, method, lengths,
                              expected_valid, expected_invalid):
        """Call ``validate_allele_lengths`` and compare both returned dicts.

        ``validate_allele_lengths`` returns a ``(valid_dict, invalid_dict)``
        pair. Both dicts are keyed by allele label and map to the requested
        lengths that were accepted or rejected for that allele.
        """
        valid_dict, invalid_dict = self.validator.validate_allele_lengths(
            iedb_label=iedb_label, method=method, lengths=lengths)
        self.assertDictEqual(expected_valid, valid_dict)
        self.assertDictEqual(expected_invalid, invalid_dict)

    def test_validate_allele_lengths(self):
        """validate_allele_lengths splits requested lengths into valid/invalid."""
        # Valid allele given as a plain string; every requested length is valid.
        self._check_allele_lengths(
            "BoLA-1*009:01", "netmhcpan",
            ['8', '9', '10', '11', '12', '13', '14'],
            {"BoLA-1*009:01": ['8', '9', '10', '11', '12', '13', '14']},
            {})

        # Same allele given as a one-element list.
        self._check_allele_lengths(
            ["BoLA-1*009:01"], "netmhcpan",
            ['8', '9', '10', '11', '12', '13', '14'],
            {"BoLA-1*009:01": ['8', '9', '10', '11', '12', '13', '14']},
            {})

        # Wrong label + valid method + valid lengths must raise. This is kept
        # separate from the unknown-method case below so that a ValueError
        # raised for the method cannot mask a missing label check.
        self.assertRaises(
            ValueError, self.validator.validate_allele_lengths,
            iedb_label=["BoLA-1:00901"], method="netmhcpan",
            lengths=['8', '9', '10', '11', '12', '13', '14'])

        # Wrong label combined with an unknown method name ("metmhcpan").
        self.assertRaises(
            ValueError, self.validator.validate_allele_lengths,
            iedb_label=["BoLA-1:00901"], method="metmhcpan",
            lengths=['8', '9', '10', '11', '12', '13', '14'])

        # Valid label + a subset of the valid lengths.
        self._check_allele_lengths(
            ["BoLA-1*009:01"], "netmhcpan", ['9', '10', '13', '14'],
            {"BoLA-1*009:01": ['9', '10', '13', '14']},
            {})

        # Valid label + a mix of valid and invalid lengths, for two methods.
        for method in ("netmhcpan", "netmhcpan_el"):
            with self.subTest(method=method):
                self._check_allele_lengths(
                    ["BoLA-1*009:01"], method,
                    ['9', '10', '13', '14', '15', '16'],
                    {"BoLA-1*009:01": ['9', '10', '13', '14']},
                    {"BoLA-1*009:01": ['15', '16']})

        # comblib_sidney2008: only '9' is expected to be valid for HLA-A*02:01.
        self._check_allele_lengths(
            ['HLA-A*02:01'], "comblib_sidney2008", ['9', '10', '150'],
            {'HLA-A*02:01': ['9']},
            {'HLA-A*02:01': ['10', '150']})

        # An allele with no valid lengths appears only in the invalid dict.
        self._check_allele_lengths(
            ['HLA-A*02:01', 'HLA-A*01:01'], "comblib_sidney2008",
            ['9', '10', '150'],
            {'HLA-A*02:01': ['9']},
            {'HLA-A*02:01': ['10', '150'],
             'HLA-A*01:01': ['9', '10', '150']})

    def test_get_alleles(self):
        """get_alleles lists alleles per method/version and rejects unknown ones."""
        validator = self.validator

        # --- Method name only ---
        self.assertCountEqual(self.SMM_ALLELES, validator.get_alleles(method="smm"))
        self.assertCountEqual(self.ANN_ALLELES, validator.get_alleles(method="ann"))

        # Unknown method name.
        self.assertRaises(ValueError, validator.get_alleles, method="aaa")

        # --- Method name and version ---
        # Valid method, valid version.
        self.assertCountEqual(
            self.ANN_ALLELES, validator.get_alleles(method="ann", version="4.0"))

        # Valid method, unknown version.
        self.assertRaises(ValueError, validator.get_alleles, method="ann", version="4.2")

        # Unknown method and unknown version.
        self.assertRaises(ValueError, validator.get_alleles, method="aaa", version="4.2")

        # Unknown tools_group name.
        self.assertRaises(ValueError, validator.get_alleles, tools_group="mhhhh")

    def test_get_available_lengths(self):
        """get_available_lengths returns lengths per allele, or None when unavailable."""
        validator = self.validator
        pickpocket_lengths = ['8', '9', '10', '11', '12', '13', '14', '15']

        # Single allele given as a string: a list of lengths, or None.
        result = validator.get_available_lengths(iedb_label="BoLA-1*009:01", method="pickpocket")
        self.assertListEqual(pickpocket_lengths, result)

        result = validator.get_available_lengths(iedb_label="Mamu-A2*01:02", method="smm")
        self.assertListEqual(['9'], result)

        result = validator.get_available_lengths(iedb_label="DQA1*06:02/DQB1*06:11", method="smm")
        self.assertIsNone(result)

        # Single allele given as a one-element list: a one-element result list.
        result = validator.get_available_lengths(iedb_label=["Mamu-A2*01:02"], method="smm")
        self.assertListEqual([['9']], result)

        result = validator.get_available_lengths(iedb_label=["DQA1*06:02/DQB1*06:11"], method="smm")
        self.assertListEqual([None], result)

        # Multiple alleles, all valid.
        result = validator.get_available_lengths(
            iedb_label=["BoLA-1*009:01", "Mamu-A2*01:02"], method="pickpocket")
        self.assertListEqual([pickpocket_lengths, pickpocket_lengths], result)

        result = validator.get_available_lengths(
            iedb_label=["Patr-A*09:01", "Mamu-B*039:01"], method="smm")
        self.assertListEqual([['9', '10', '11'], ['8', '9']], result)

        # Multiple alleles, one valid and one invalid.
        result = validator.get_available_lengths(
            iedb_label=["BoLA-1*009:01", "DQA1*06:02/DQB1*06:11"], method="pickpocket")
        self.assertListEqual([pickpocket_lengths, None], result)

        # Unknown method name.
        self.assertRaises(ValueError, validator.get_available_lengths,
                          iedb_label="BoLA-1*009:01", method="mhhhh")


if __name__ == "__main__":
    # Runs only when this file is executed directly (not when it is imported
    # by a test runner); hands control to unittest's command-line runner for
    # the TestCase classes in this module.
    unittest.main(module="__main__")
