import re
import logging

# Stand-in predictor classes. They let the lookup tables further down be built
# even when the real predictor packages cannot be imported; the try/except
# import block that follows rebinds a name when its import line is reached
# and succeeds.

# --- MHC-I stand-ins (mhci.predictors.*) ---
class ConsensusPredictor:
    """Stand-in for mhci.predictors.ConsensusPredictor."""
class ANNPredictor:
    """Stand-in for mhci.predictors.ANNPredictor."""
class ANN34Predictor:
    """Stand-in for mhci.predictors.ANN34Predictor (import currently disabled)."""
class SMMPredictor:
    """Stand-in for mhci.predictors.SMMPredictor."""
class SMMPMBECPredictor:
    """Stand-in for mhci.predictors.SMMPMBECPredictor."""
class ComblibSidney2008Predictor:
    """Stand-in for mhci.predictors.ComblibSidney2008Predictor."""
class NetMHCPanPredictor:
    """Stand-in for mhci.predictors.NetMHCPanPredictor (import currently disabled)."""
class NetMHCPan28Predictor:
    """Stand-in for mhci.predictors.NetMHCPan28Predictor (import currently disabled)."""
class NetMHCPan4Predictor:
    """Stand-in for mhci.predictors.NetMHCPan4Predictor."""
class NetMHCPan4elPredictor:
    """Stand-in for mhci.predictors.NetMHCPan4elPredictor."""
class NetMHCPan41Predictor:
    """Stand-in for mhci.predictors.NetMHCPan41Predictor."""
class NetMHCPan41elPredictor:
    """Stand-in for mhci.predictors.NetMHCPan41elPredictor."""
class PickPocketPredictor:
    """Stand-in for mhci.predictors.PickPocketPredictor."""
class NetMHCConsPredictor:
    """Stand-in for mhci.predictors.NetMHCConsPredictor."""
class NetMHCStabPanPredictor:
    """Stand-in for mhci.predictors.NetMHCStabPanPredictor."""

# --- Processing stand-ins (processing.predictors.*, imports currently disabled) ---
class ProcessingANNPredictor:
    """Stand-in for processing.predictors.ANNPredictor."""
class ProcessingSMMPredictor:
    """Stand-in for processing.predictors.SMMPredictor."""
class ProcessingComblibSidney2008Predictor:
    """Stand-in for processing.predictors.ComblibSidney2008Predictor."""
class ProcessingNetMHCPanPredictor:
    """Stand-in for processing.predictors.NetMHCPanPredictor."""
class ProcessingPickPocketPredictor:
    """Stand-in for processing.predictors.PickPocketPredictor."""
class ProcessingNetMHCConsPredictor:
    """Stand-in for processing.predictors.NetMHCConsPredictor."""
class ProcessingNetMHCStabPanPredictor:
    """Stand-in for processing.predictors.NetMHCStabPanPredictor."""

# Replace the stand-in classes above with the real predictors when the
# ``mhci`` package is importable. This is deliberately best-effort: the
# imports run in order, so the first failure leaves that name and every later
# name bound to its stand-in.
try:
    from mhci.predictors.ConsensusPredictor import ConsensusPredictor as MHCIConsensusPredictor
    from mhci.predictors.ANNPredictor import ANNPredictor as MHCIANNPredictor
    from mhci.predictors.SMMPredictor import SMMPredictor as MHCISMMPredictor
    from mhci.predictors.ComblibSidney2008Predictor import ComblibSidney2008Predictor as MHCIComblibSidney2008Predictor
    #from mhci.predictors.NetMHCPanPredictor import NetMHCPanPredictor as MHCINetMHCPanPredictor
    from mhci.predictors.PickPocketPredictor import PickPocketPredictor as MHCIPickPocketPredictor
    from mhci.predictors.NetMHCConsPredictor import NetMHCConsPredictor as MHCINetMHCConsPredictor
    from mhci.predictors.NetMHCStabPanPredictor import NetMHCStabPanPredictor as MHCINetMHCStabPanPredictor
    #from mhci.predictors.ANN34Predictor import ANN34Predictor as MHCIANN34Predictor
    from mhci.predictors.NetMHCPan4Predictor import NetMHCPan4Predictor
    from mhci.predictors.NetMHCPan4elPredictor import NetMHCPan4elPredictor
    from mhci.predictors.NetMHCPan41Predictor import NetMHCPan41Predictor
    from mhci.predictors.NetMHCPan41elPredictor import NetMHCPan41elPredictor
    #from mhci.predictors.NetMHCPan28Predictor import NetMHCPan28Predictor

    #from processing.predictors.ANNPredictor import ANNPredictor as ProcessingANNPredictor
    #from processing.predictors.SMMPredictor import SMMPredictor as ProcessingSMMPredictor
    #from mhci.predictors.SMMPMBECPredictor import SMMPMBECPredictor
    #from processing.predictors.ComblibSidney2008Predictor import ComblibSidney2008Predictor as ProcessingComblibSidney2008Predictor
    #from processing.predictors.NetMHCPanPredictor import NetMHCPanPredictor as ProcessingNetMHCPanPredictor
    #from processing.predictors.PickPocketPredictor import PickPocketPredictor as ProcessingPickPocketPredictor
    #from processing.predictors.NetMHCConsPredictor import NetMHCConsPredictor as ProcessingNetMHCConsPredictor
    #from processing.predictors.NetMHCStabPanPredictor import NetMHCStabPanPredictor as ProcessingNetMHCStabPanPredictor

    from mhci.predictors.ANNPredictor import ANNPredictor
    from mhci.predictors.ConsensusPredictor import ConsensusPredictor
    from mhci.predictors.SMMPredictor import SMMPredictor
    from mhci.predictors.SMMPMBECPredictor import SMMPMBECPredictor
    from mhci.predictors.ComblibSidney2008Predictor import ComblibSidney2008Predictor
    #from mhci.predictors.NetMHCPanPredictor import NetMHCPanPredictor
    from mhci.predictors.PickPocketPredictor import PickPocketPredictor
    from mhci.predictors.NetMHCConsPredictor import NetMHCConsPredictor
    from mhci.predictors.NetMHCStabPanPredictor import NetMHCStabPanPredictor

    #from mhci.predictors.ANN34Predictor import ANN34Predictor
    from mhci.predictors.NetMHCPan4Predictor import NetMHCPan4Predictor
    #from mhci.predictors.NetMHCPan28Predictor import NetMHCPan28Predictor
except Exception as err:
    # Keep the stand-ins, but leave a trace of why. ``except Exception`` (not a
    # bare ``except``) lets KeyboardInterrupt/SystemExit propagate, and a named
    # logger avoids the root-logger auto-configuration that the module-level
    # ``logging.debug()`` would trigger at import time.
    logging.getLogger(__name__).debug(
        'mhci predictors unavailable, keeping placeholder classes: %s', err)



# MHC-I predictor class for each "<method>" or "<method> <version>" key.
# MHCIMethod.predictor looks up "<name> <version>" first and falls back to the
# bare "<name>" key.
mhci_predictor_class_by_method_name = {
    'consensus':           ConsensusPredictor,
    'ann':                 ANNPredictor,
    'ann 3.4':             ANN34Predictor,
    'smm':                 SMMPredictor,
    'smmpmbec':            SMMPMBECPredictor,
    'comblib_sidney2008':  ComblibSidney2008Predictor,

    # NetMHCpan 2.8 and 4.0
    'netmhcpan 2.8':       NetMHCPan28Predictor,
    'netmhcpan 4.0':       NetMHCPan4Predictor,
    'netmhcpan_ba 4.0':    NetMHCPan4Predictor,
    'netmhcpan_el 4.0':    NetMHCPan4elPredictor,

    # NetMHCpan 4.1 -- the unversioned keys map to this default version
    'netmhcpan':           NetMHCPan41Predictor,
    'netmhcpan_ba':        NetMHCPan41Predictor,
    'netmhcpan_el':        NetMHCPan41elPredictor,
    'netmhcpan 4.1':       NetMHCPan41Predictor,
    'netmhcpan_ba 4.1':    NetMHCPan41Predictor,
    'netmhcpan_el 4.1':    NetMHCPan41elPredictor,

    'pickpocket':          PickPocketPredictor,
    'netmhccons':          NetMHCConsPredictor,
    'netmhcstabpan':       NetMHCStabPanPredictor,

    # "recommended" releases
    'recommended 2020.04': NetMHCPan4elPredictor,
    'recommended 2020.09': NetMHCPan41elPredictor,
    # These older releases have no predictor class; the entries hold a
    # message string instead of a class.
    'recommended 2.18':    'predictor for recommended 2.18 not exist',
    'recommended 2.19':    'predictor for recommended 2.19 not exist',
    'recommended 2.22':    'predictor for recommended 2.22 not exist',
}



# Predictor class for each processing method name. Every processing.predictors
# import in the try block above is commented out, so these names refer to the
# Processing* stand-in classes defined at the top of the module.
processing_predictor_by_method_name = {
    'ann':                ProcessingANNPredictor,
    'smm':                ProcessingSMMPredictor,
    'comblib_sidney2008': ProcessingComblibSidney2008Predictor,
    'netmhcpan':          ProcessingNetMHCPanPredictor,
    'pickpocket':         ProcessingPickPocketPredictor,
    'netmhccons':         ProcessingNetMHCConsPredictor,
    'netmhcstabpan':      ProcessingNetMHCStabPanPredictor,
}

# In each of the three tables below, the first version listed for a method is
# that method's default: get_method_version returns it when only a bare method
# name is given.
MHCI_METHOD_VERSIONS = {
    'recommended':        ['2020.09', '2020.04'],
    'consensus':          ['2.18'],
    'ann':                ['4.0'],
    'netmhcpan':          ['4.1', '4.0'],
    'netmhcpan_el':       ['4.1', '4.0'],
    'netmhcpan_ba':       ['4.1', '4.0'],
    'pickpocket':         ['1.1'],
    'netmhccons':         ['1.1'],
    'netmhcstabpan':      ['1.0'],
    'comblib_sidney2008': ['1.0'],
    'smm':                ['1.0'],
    'smmpmbec':           ['1.0'],
    # 'arb':              ['1.0'],
}

MHCII_METHOD_VERSIONS = {
    'recommended':        ['2.22', '2.18'],
    'consensus':          ['2.22', '2.18'],
    'netmhciipan':        ['4.1', '4.3', '4.2', '4.0', '3.2', '3.1'],
    'netmhciipan_el':     ['4.1', '4.3', '4.2', '4.0'],
    'netmhciipan_ba':     ['4.1', '4.3', '4.2', '4.0'],
    'smm_align':          ['1.1'],
    'nn_align':           ['2.3', '2.2'],
    'comblib':            ['1.0'],
    'tepitope':           ['1.0'],
}

PROCESSING_METHOD_VERSIONS = {
    'recommended':        ['2.18'],
    'consensus':          ['2.18'],
    'ann':                ['4.0', '3.4'],
    'netmhcpan':          ['3.0', '2.8', '4.0'],
    'pickpocket':         ['1.1'],
    'netmhccons':         ['1.1'],
    'netmhcstabpan':      ['1.0'],
}


def get_method_version(mhc_class, method, version=None):
    """Resolve a method name and optional version to a ``(name, version)`` pair.

    :param mhc_class: which version table to consult: ``'mhci'``,
        ``'mhcii'`` or ``'processing'``.
    :param method: a bare method name (``'netmhcpan'``) or a name with an
        embedded version, separated by whitespace and/or hyphens
        (``'netmhcpan 4.0'``, ``'netmhcpan-4.0'``).
    :param version: optional explicit version. When truthy it replaces any
        version embedded in ``method``; only the name part of ``method`` is used.
    :returns: ``(method_name, version)`` if that version is listed for the
        method (a bare name with no ``version`` gets the method's default,
        i.e. first-listed, version); otherwise ``None``.
    :raises ValueError: if ``mhc_class`` is not one of the three known classes.
    """
    if mhc_class == 'mhci':
        method_version_dict = MHCI_METHOD_VERSIONS
    elif mhc_class == 'mhcii':
        method_version_dict = MHCII_METHOD_VERSIONS
    elif mhc_class == 'processing':
        method_version_dict = PROCESSING_METHOD_VERSIONS
    else:
        raise ValueError('wrong mhc class "%s"' % mhc_class)

    # Split on runs of whitespace and/or hyphens, dropping the empty pieces that
    # leading/trailing separators produce. The raw string avoids the
    # invalid-escape warning the plain literal triggered.
    tokens = [token for token in re.split(r'[-\s]+', method) if token]
    if version:
        # Explicit version wins; use the same splitter so "name-x.y" works too.
        if not tokens:
            return None
        tokens = [tokens[0], version]

    if len(tokens) == 2:
        method_name, method_version = tokens
        if method_version in method_version_dict.get(method_name, []):
            return (method_name, method_version)
        return None
    if len(tokens) == 1:
        method_name = tokens[0]
        if method_name in method_version_dict:
            return (method_name, method_version_dict[method_name][0])
        logging.warning('method "%s" is not in method version dict', method_name)
    return None

class MHCIMethod(object):
    """
    Resolve an MHC-I prediction method from an id, a name, or a name plus version.

    ``method_name_or_id`` may be:

    * a numeric id, as an int or a numeric string (``440`` or ``"440"``);
    * a method name, case-insensitive, optionally followed by a version and
      separated by a space or a hyphen (``"ann"``, ``"NetMHCpan 4.0"``,
      ``"netmhcpan-4.0"``);
    * a method name combined with the ``version`` argument
      (``MHCIMethod("ann", "4.0")``).

    If the name string is not itself a ``method_dict`` entry, it must be
    ``"<method> <version>"``, and that embedded version replaces ``version``.

    After construction, ``name()`` is the bare method name. ``version`` is the
    resolved version: the method's default when none was given (see
    ``get_method_version``). ``id()`` is the ``method_dict`` id of the bare
    name, e.g. 4 for "ann" even when built from id 440.

    Raises ValueError when the id, name or version is not recognised.

    >>> mm = MHCIMethod('netmhcpan 4.0')
    >>> print(mm.name())
    netmhcpan
    >>> print(mm.id())
    3
    >>> print(mm.version)
    4.0
    >>> print(mm.predictor.__name__)
    NetMHCPan4Predictor
    >>> mm = MHCIMethod('netmhcpan-4.0')
    >>> print(mm.name())
    netmhcpan
    >>> print(mm.version)
    4.0
    >>> mm = MHCIMethod('440')
    >>> print(mm.name())
    ann
    >>> print(mm.id())
    4
    >>> print(mm.version)
    4.0
    >>> print(mm.predictor.__name__)
    ANNPredictor
    >>> mm = MHCIMethod('ann')
    >>> print(mm.version)
    4.0
    >>> MHCIMethod('ann', '3.4')
    Traceback (most recent call last):
        ...
    ValueError: not valid method name "ann" or version "3.4"
    """
    # TODO: This should be pulled from a canonical source.
    #    (Something like a Database or data file in a shared dependency application.
    method_dict = {
        1: 'recommended', 2: 'consensus', 3: 'netmhcpan', 4: 'ann',
        5: 'smmpmbec', 6: 'smm', 7: 'comblib_sidney2008', 8: 'arb',
        9: 'pickpocket', 10: 'netmhccons', 11: 'netmhcstabpan',
        340: 'netmhcpan 4.0', 34: 'netmhcpan 4.1',
        341: 'netmhcpan_el', 342: 'netmhcpan_ba',
        3401: 'netmhcpan_el 4.0', 3402: 'netmhcpan_ba 4.0',
        3411: 'netmhcpan_el 4.1', 3412: 'netmhcpan_ba 4.1',
        328: 'netmhcpan 2.8', 330: 'netmhcpan 3.0',
        434: 'ann 3.4', 440: 'ann 4.0',
        1218: 'recommended 2.18', 1219: 'recommended 2.19',
        2218: 'consensus 2.18', 1222: 'recommended 2.22',
    }
    _method_id = None
    _method_name = None
    _version = None
    _predictor = None

    def __init__(self, method_name_or_id, version=None):
        super(MHCIMethod, self).__init__()
        logging.info('method_name_or_id: %s', method_name_or_id)
        logging.info('version: %s', version)
        # Try the id form first. Only non-numeric input is treated as a name;
        # the hyphen normalisation used to run first, which broke int ids
        # (int has no .replace) and turned "-3" into id 3.
        try:
            method_name_or_id = int(method_name_or_id)
        except (TypeError, ValueError):
            method_name_or_id = method_name_or_id.replace('-', ' ')

        if isinstance(method_name_or_id, int):
            if method_name_or_id not in self.method_dict:
                raise ValueError('{} is not a valid MHC-I method id'.format(method_name_or_id))
            self._method_name = self.method_dict[method_name_or_id]
        else:
            name = method_name_or_id.lower()
            if name in self.method_dict.values():
                self._method_name = name
            else:
                # Otherwise it must be "<method> <version>" with a listed
                # version. .get() makes an unknown method a ValueError rather
                # than a KeyError.
                parts = name.split()
                if (len(parts) != 2
                        or parts[1] not in MHCI_METHOD_VERSIONS.get(parts[0], [])):
                    raise ValueError('{} is not a valid MHC-I method name'.format(method_name_or_id))
                self._method_name, version = parts

        logging.debug('self._method_name: %s', self._method_name)
        logging.debug('version: %s', version)
        method_version_combo = get_method_version('mhci', self._method_name, version)
        if not method_version_combo:
            raise ValueError('not valid method name "%s" or version "%s"' % (self._method_name, version))
        self._method_name, self._version = method_version_combo

        # Re-derive the id from the bare method name (e.g. 4 rather than 434).
        self._method_id = {v: k for k, v in self.method_dict.items()}[self._method_name]

    def id(self):
        """Return the ``method_dict`` id of the resolved (bare) method name."""
        return self._method_id

    def name(self):
        """Return the resolved bare method name, e.g. ``'netmhcpan'``."""
        return self._method_name

    @property
    def version(self):
        """The resolved version string, e.g. ``'4.0'``."""
        return self._version

    @property
    def predictor(self):
        """Entry of ``mhci_predictor_class_by_method_name`` for this method.

        The ``"<name> <version>"`` key is tried first, then the bare name. The
        result is cached on the instance.

        :raises ValueError: if neither key is present.
        """
        if not self._predictor:
            if self._version and self._method_name+' '+self._version in mhci_predictor_class_by_method_name:
                self._predictor = mhci_predictor_class_by_method_name[self._method_name+' '+self._version]
            elif self._method_name in mhci_predictor_class_by_method_name:
                self._predictor = mhci_predictor_class_by_method_name[self._method_name]
            else:
                raise ValueError('could not find predictor for method name "%s" and version "%s"' % (self._method_name, self._version))
        return self._predictor


class MHCIIMethod(object):
    """
    Resolve an MHC-II prediction method from an id, a name, or a name plus version.

    ``method_name_or_id`` may be:

    * a numeric id, as an int or a numeric string (``331`` or ``"331"``);
    * a method name, case-insensitive, optionally followed by a version and
      separated by a space or a hyphen (``"NetMHCIIpan 3.2"``,
      ``"NetMHCIIpan-3.1"``);
    * a method name combined with the ``version`` argument
      (``MHCIIMethod("nn_align", "2.2")``).

    NetMHCIIpan 4.1 is reported as its "netmhciipan_el" variant, so for example
    ``MHCIIMethod('netmhciipan')`` (default version 4.1) has name
    "netmhciipan_el".

    ``predictor`` returns ``_predictor``, which this class never sets (None).

    Raises ValueError when the id, name or version is not recognised.

    >>> mm = MHCIIMethod('NetMHCIIpan 3.2')
    >>> print(mm.name())
    netmhciipan
    >>> print(mm.id())
    3
    >>> print(mm.version)
    3.2
    >>> print(mm.predictor)
    None
    >>> mm = MHCIIMethod('NetMHCIIpan-3.1')
    >>> print(mm.version)
    3.1
    >>> mm = MHCIIMethod('331')
    >>> print(mm.name())
    netmhciipan
    >>> print(mm.id())
    3
    >>> print(mm.version)
    3.1
    >>> mm = MHCIIMethod('nn_align')
    >>> print(mm.id())
    4
    >>> print(mm.version)
    2.3
    >>> mm = MHCIIMethod('nn_align', '2.2')
    >>> print(mm.version)
    2.2
    >>> mm = MHCIIMethod('netmhciipan')
    >>> print(mm.name())
    netmhciipan_el
    >>> print(mm.id())
    3401
    >>> print(mm.version)
    4.1
    """
    # TODO: This should be pulled from a canonical source.
    #    (Something like a Database or data file in a shared dependency application.
    method_dict = {
        1: 'recommended', 2: 'consensus', 3: 'netmhciipan', 4: 'nn_align',
        5: 'smm_align', 6: 'comblib', 7: 'tepitope',
        331: 'netmhciipan 3.1', 332: 'netmhciipan 3.2',
        3401: 'netmhciipan_el', 3402: 'netmhciipan_ba',
        422: 'nn_align 2.2', 423: 'nn_align 2.3',
        1222: 'recommended 2.22', 1218: 'recommended 2.18',
        2222: 'consensus 2.22', 2218: 'consensus 2.18',
    }
    _method_id = None
    _method_name = None
    _version = None
    _predictor = None

    def __init__(self, method_name_or_id, version=None):
        super(MHCIIMethod, self).__init__()
        # Try the id form first; only non-numeric input is treated as a name.
        try:
            method_name_or_id = int(method_name_or_id)
        except (TypeError, ValueError):
            method_name_or_id = method_name_or_id.replace('-', ' ')

        if isinstance(method_name_or_id, int):
            if method_name_or_id not in self.method_dict:
                raise ValueError('{} is not a valid MHC-II method id'.format(method_name_or_id))
            self._method_name = self.method_dict[method_name_or_id]
        else:
            name = method_name_or_id.lower()
            parts = name.split()
            if name in self.method_dict.values():
                self._method_name = name
            elif len(parts) > 1:
                # Must be "<method> <version>" with a listed version; .get()
                # makes an unknown method a ValueError rather than a KeyError.
                if (len(parts) != 2
                        or parts[1] not in MHCII_METHOD_VERSIONS.get(parts[0], [])):
                    raise ValueError('{} is not a valid MHC-II method name'.format(method_name_or_id))
                self._method_name, version = parts
            else:
                raise ValueError('Selected prediction method "%s" does not exist.' % method_name_or_id)

        method_version_combo = get_method_version('mhcii', self._method_name, version)
        if not method_version_combo:
            raise ValueError('not valid method name "%s" or version "%s"' % (self._method_name, version))
        self._method_name, self._version = method_version_combo
        # NetMHCIIpan 4.1 is served by its "_el" variant. This applies however
        # 4.1 was selected: explicitly, via the string, or as the default.
        if self._method_name == 'netmhciipan' and self._version == '4.1':
            self._method_name = 'netmhciipan_el'

        # Re-derive the id from the resolved bare name (e.g. 3 rather than 331).
        self._method_id = {v: k for k, v in self.method_dict.items()}[self._method_name]

    def id(self):
        """Return the ``method_dict`` id of the resolved (bare) method name."""
        return self._method_id

    def name(self):
        """Return the resolved bare method name, e.g. ``'netmhciipan'``."""
        return self._method_name

    @property
    def version(self):
        """The resolved version string, e.g. ``'3.2'``."""
        return self._version

    @property
    def predictor(self):
        """Return ``_predictor``; this class never assigns it, so it is None."""
        return self._predictor


if __name__ == '__main__':
    # Running the module directly executes the examples embedded in the
    # class docstrings.
    from doctest import testmod
    testmod()
