import sys
import unittest
from pathlib import Path
PROJECT_DIR = str(Path(__file__).resolve().parents[1])
sys.path.insert(1, PROJECT_DIR)
from allele_validator import AlleleValidator

class TestValidLengths(unittest.TestCase):
    """Tests for ``AlleleValidator.validate_allele_lengths``.

    Each successful call returns a ``(valid_dict, invalid_dict)`` pair keyed
    by allele label, whose values are lists of peptide lengths as strings;
    every case below pins both dictionaries exactly.
    """

    # One validator shared by every test; created when the class body runs.
    validator = AlleleValidator()

    def _assert_lengths(self, iedb_label, method, lengths, expected_valid, expected_invalid):
        """Validate *lengths* for *iedb_label* under *method* and compare results.

        Runs inside ``subTest`` so a failing case is reported together with its
        inputs and does not prevent the remaining cases from being checked.
        """
        with self.subTest(iedb_label=iedb_label, method=method, lengths=lengths):
            valid_dict, invalid_dict = self.validator.validate_allele_lengths(
                iedb_label=iedb_label, method=method, lengths=lengths)
            self.assertDictEqual(expected_valid, valid_dict)
            self.assertDictEqual(expected_invalid, invalid_dict)

    def test_positive_cases(self):
        # Valid allele given as a plain string, all lengths valid.
        self._assert_lengths(
            "BoLA-1*009:01", "netmhcpan",
            ['8', '9', '10', '11', '12', '13', '14'],
            {"BoLA-1*009:01": ['8', '9', '10', '11', '12', '13', '14']},
            {})

        # Same allele given as a one-element list.
        self._assert_lengths(
            ["BoLA-1*009:01"], "netmhcpan",
            ['8', '9', '10', '11', '12', '13', '14'],
            {"BoLA-1*009:01": ['8', '9', '10', '11', '12', '13', '14']},
            {})

        # Valid label + valid subset of lengths.
        self._assert_lengths(
            ["BoLA-1*009:01"], "netmhcpan",
            ['9', '10', '13', '14'],
            {"BoLA-1*009:01": ['9', '10', '13', '14']},
            {})

        # Both allele and length are single values; the int length is
        # expected back as the string '15'.
        self._assert_lengths(
            "HLA-DRB1*01:01", "netmhciipan_el", 15,
            {'HLA-DRB1*01:01': ['15']},
            {})

        # A single length given as a string gives the same result.
        self._assert_lengths(
            "HLA-DRB1*01:01", "netmhciipan_el", '15',
            {'HLA-DRB1*01:01': ['15']},
            {})

    def test_negative_cases(self):
        # Malformed label ("BoLA-1:00901" rather than "BoLA-1*009:01") with
        # valid lengths and a valid method must raise ValueError. The method
        # was previously misspelled "metmhcpan", so the error could have come
        # from the unknown method instead of the bad label.
        with self.assertRaises(ValueError):
            self.validator.validate_allele_lengths(
                iedb_label=["BoLA-1:00901"], method="netmhcpan",
                lengths=['8', '9', '10', '11', '12', '13', '14'])

    def test_mixed_cases(self):
        # Valid label + a mix of valid and invalid lengths.
        self._assert_lengths(
            ["BoLA-1*009:01"], "netmhcpan",
            ['9', '10', '13', '14', '15', '16'],
            {"BoLA-1*009:01": ['9', '10', '13', '14']},
            {"BoLA-1*009:01": ['15', '16']})

        # Same inputs and expectations under the netmhcpan_el method.
        self._assert_lengths(
            ["BoLA-1*009:01"], "netmhcpan_el",
            ['9', '10', '13', '14', '15', '16'],
            {"BoLA-1*009:01": ['9', '10', '13', '14']},
            {"BoLA-1*009:01": ['15', '16']})

        # comblib_sidney2008 accepts only length 9 for HLA-A*02:01.
        self._assert_lengths(
            ['HLA-A*02:01'], "comblib_sidney2008",
            ['9', '10', '150'],
            {'HLA-A*02:01': ['9']},
            {'HLA-A*02:01': ['10', '150']})

        # HLA-A*01:01 is expected in neither dictionary; presumably it is not
        # available for comblib_sidney2008 -- confirm against the validator.
        self._assert_lengths(
            ['HLA-A*02:01', 'HLA-A*01:01'], "comblib_sidney2008",
            ['9', '10', '150'],
            {'HLA-A*02:01': ['9']},
            {'HLA-A*02:01': ['10', '150']})

        # Two supported alleles: each gets its own valid/invalid split.
        self._assert_lengths(
            ['HLA-A*02:01', 'HLA-B*27:05'], "comblib_sidney2008",
            ['9', '10', '150'],
            {'HLA-A*02:01': ['9'], 'HLA-B*27:05': ['9']},
            {'HLA-A*02:01': ['10', '150'], 'HLA-B*27:05': ['10', '150']})


if __name__ == "__main__":
    # Run the TestCase classes defined in this module via unittest's CLI entry point.
    unittest.main(module="__main__")