import unittest
import sys
from pathlib import Path
PROJECT_DIR = str(Path(__file__).resolve().parents[1])
sys.path.insert(1, PROJECT_DIR)
from allele_validator import Allele_Validator


class TestConversion(unittest.TestCase):
    """Exercise the label/ID conversion methods of ``Allele_Validator``.

    Each case pins the exact value a converter is expected to return. The
    cases below expect a ``str`` input to yield a single value and a ``list``
    input to yield a list of the same length, with ``None`` standing in for
    any label that cannot be converted.
    """

    def setUp(self):
        # A fresh validator for every test (previously built inline at the top
        # of each test method), so each test starts from its own instance.
        self.validator = Allele_Validator()

    def _assert_converts(self, convert, expected, **kwargs):
        """Assert that ``convert(**kwargs)`` returns ``expected``.

        ``assertEqual`` dispatches to ``assertListEqual`` when both values are
        lists, so list and scalar results get the same checks as before. The
        call runs inside ``subTest`` so that, under ``unittest``'s runner, a
        failing case is reported together with its arguments and does not stop
        the remaining cases of the same test method.
        """
        # Pass the name positionally: kwargs may contain a ``method`` key.
        with self.subTest(convert.__name__, **kwargs):
            self.assertEqual(expected, convert(**kwargs))

    def test_convert_iedblabel_to_methodlabel(self):
        convert = self.validator.convert_iedblabel_to_methodlabel

        # Single input (str)
        self._assert_converts(convert, 'BoLA-1:02301',
                              iedb_labels="BoLA-1*023:01", method="netmhcpan")
        self._assert_converts(convert, 'SLA-1*0101',
                              iedb_labels="SLA-1*01:01", method="pickpocket")
        # netCTLpan method
        self._assert_converts(convert, 'HLA-A01:01',
                              iedb_labels="HLA-A*01:01", method="netctlpan")

        # Single input (list)
        self._assert_converts(convert, ['BoLA-1:02301'],
                              iedb_labels=["BoLA-1*023:01"],
                              method="netmhcpan", tools_group="mhci")

        # Unrecognised labels convert to None (element-wise for lists)
        self._assert_converts(convert, None,
                              iedb_labels="Wrong Label", method="netmhcpan")
        self._assert_converts(convert, [None],
                              iedb_labels=["Wrong Label"], method="netmhcpan")

        # Multiple inputs: all valid, then one invalid entry in the middle
        self._assert_converts(
            convert, ['BoLA-1:02301', 'BoLA-2:01601', 'BoLA-2:02602'],
            iedb_labels=["BoLA-1*023:01", "BoLA-2*016:01", "BoLA-2*026:02"],
            method="netmhcpan", tools_group="mhci")
        self._assert_converts(
            convert, ['BoLA-1:02301', None, 'BoLA-2:02602'],
            iedb_labels=["BoLA-1*023:01", "Wrong Label", "BoLA-2*026:02"],
            method="netmhcpan", tools_group="mhci")

        # Wrong tools group: these labels are expected to yield None under "mhcii"
        bola_labels = ["BoLA-1*023:01", "BoLA-2*016:01", "BoLA-2*026:02"]
        self._assert_converts(convert, [None, None, None],
                              iedb_labels=bola_labels,
                              method="netmhcpan", tools_group="mhcii")

        # An unknown tools group value is rejected with ValueError
        with self.assertRaises(ValueError):
            convert(iedb_labels=bola_labels, method="netmhcpan",
                    tools_group="mmhci")

    def test_convert_methodlabel_to_iedblabel(self):
        convert = self.validator.convert_methodlabel_to_iedblabel

        # Single input (str)
        self._assert_converts(convert, "BoLA-1*023:01",
                              method_labels='BoLA-1:02301', method="netmhcpan")
        self._assert_converts(convert, 'SLA-1*01:01',
                              method_labels='SLA-1:0101', method="netmhcpan")

        # Single input (list)
        self._assert_converts(convert, ["BoLA-1*023:01"],
                              method_labels=['BoLA-1:02301'],
                              method="netmhcpan", tools_group="mhci")

        # Unrecognised labels convert to None (element-wise for lists)
        self._assert_converts(convert, None,
                              method_labels="Wrong Label", method="netmhcpan")
        self._assert_converts(convert, [None],
                              method_labels=["Wrong Label"], method="netmhcpan")

        # Multiple inputs: all valid, then one invalid entry in the middle
        self._assert_converts(
            convert, ["BoLA-1*023:01", "BoLA-2*016:01", "BoLA-2*026:02"],
            method_labels=['BoLA-1:02301', 'BoLA-2:01601', 'BoLA-2:02602'],
            method="netmhcpan", tools_group="mhci")
        self._assert_converts(
            convert, ["BoLA-1*023:01", None, "BoLA-2*026:02"],
            method_labels=['BoLA-1:02301', "Wrong Label", 'BoLA-2:02602'],
            method="netmhcpan", tools_group="mhci")

        # Wrong tools group: these labels are expected to yield None under "mhcii"
        bola_labels = ['BoLA-1:02301', 'BoLA-2:01601', 'BoLA-2:02602']
        self._assert_converts(convert, [None, None, None],
                              method_labels=bola_labels,
                              method="netmhcpan", tools_group="mhcii")

        # An unknown tools group value is rejected with ValueError
        with self.assertRaises(ValueError):
            convert(method_labels=bola_labels, method="netmhcpan",
                    tools_group="mmhci")

    def test_convert_mro_to_methodlabel(self):
        convert = self.validator.convert_mroid_to_methodlabel

        # Single input (str and list)
        self._assert_converts(convert, 'BoLA-1:00901',
                              mro_ids="MRO:0036770",
                              method="netmhcpan", tools_group="mhci")
        self._assert_converts(convert, ['BoLA-1:00901'],
                              mro_ids=["MRO:0036770"],
                              method="netmhcpan", tools_group="mhci")

        # Unknown MRO ID
        self._assert_converts(convert, None,
                              mro_ids="MRO:000000", method="netmhcpan")

        # Multiple inputs: all valid, then one invalid entry in the middle
        self._assert_converts(
            convert, ['BoLA-1:00901', 'BoLA-1:00902', 'BoLA-1:02001'],
            mro_ids=["MRO:0036770", "MRO:0036638", "MRO:0036921"],
            method="netmhcpan", tools_group="mhci")
        self._assert_converts(
            convert, ['BoLA-1:00901', None, 'BoLA-1:02001'],
            mro_ids=["MRO:0036770", "MRO:000000", "MRO:0036921"],
            method="netmhcpan", tools_group="mhci")

    def test_convert_methodlabel_to_mroid(self):
        convert = self.validator.convert_methodlabel_to_mroid

        # Single input (str and list)
        self._assert_converts(convert, "MRO:0036770",
                              tools_label="BoLA-1:00901", method="netmhcpan")
        self._assert_converts(convert, "MRO:0040794",
                              tools_label="SLA-1:0101", method="netmhcpan")
        self._assert_converts(convert, ["MRO:0036770"],
                              tools_label=["BoLA-1:00901"], method="netmhcpan")

        # Incomplete label (locus only)
        self._assert_converts(convert, None,
                              tools_label="BoLA-1", method="netmhcpan")

        # Multiple inputs: all valid, then one invalid entry in the middle
        self._assert_converts(
            convert, ['MRO:0036770', 'MRO:0036418', 'MRO:0001041'],
            tools_label=["BoLA-1:00901", "BoLA-2:04402", "HLA-A29:01"],
            method="netmhcpan")
        self._assert_converts(
            convert, ['MRO:0036770', None, 'MRO:0001041'],
            tools_label=["BoLA-1:00901", "BoLA:000000", "HLA-A29:01"],
            method="netmhcpan")

    def test_convert_iedblabel_to_mroid(self):
        convert = self.validator.convert_iedblabel_to_mroid

        # Single input (str and list)
        self._assert_converts(convert, "MRO:0036770",
                              iedb_label="BoLA-1*009:01", method="netmhcpan")
        self._assert_converts(convert, ["MRO:0036770"],
                              iedb_label=["BoLA-1*009:01"], method="netmhcpan")

        # Incomplete label (locus only)
        self._assert_converts(convert, None,
                              iedb_label="BoLA-1", method="netmhcpan")

        # Multiple inputs: all valid, then one invalid entry in the middle
        self._assert_converts(
            convert, ['MRO:0036770', 'MRO:0036418', 'MRO:0001041'],
            iedb_label=["BoLA-1*009:01", "BoLA-2*044:02", "HLA-A*29:01"],
            method="netmhcpan")
        self._assert_converts(
            convert, ['MRO:0036770', None, 'MRO:0001041'],
            iedb_label=["BoLA-1*009:01", "BoLA*0000:00", "HLA-A*29:01"],
            method="netmhcpan")

    def test_convert_synonym_to_iedblabel(self):
        convert = self.validator.convert_synonym_to_iedblabel

        # Single input (str and list)
        self._assert_converts(convert, "HLA-A*02:01", synonym="HLA-A02:01")
        self._assert_converts(convert, "BoLA-2*018:01", synonym="N*01801")
        self._assert_converts(convert, ["BoLA-2*018:01"], synonym=["N*01801"])

        # Multiple inputs, including one unrecognised synonym
        self._assert_converts(
            convert, ["BoLA-2*018:01", None, "BoLA-DQA*010:01:01"],
            synonym=["N*01801", "invalid", "BoLA-DQA*10011"])

        # Inputs whose synonym is already the IEDB label
        self._assert_converts(convert, "HLA-C*05:238", synonym="HLA-C*05:238")
        self._assert_converts(
            convert, ["HLA-C*05:238", "HLA-C*06:02", None, "HLA-C*06:06"],
            synonym=["HLA-C*05:238", "HLA-Cw*0602", "invalid", "HLA-C*06:06"])

if __name__ == "__main__":
    # Executed directly rather than imported: run every TestCase in this module.
    unittest.main(module="__main__")