import logging
from method_version_info import MHCIMethod,MHCIIMethod


class Alleles(object):
    """
    Created  on 2016-08-25.  Yan
    @brief: Alleles Class for MHC II which turns input data to alleles.

    Accepted keyword arguments (all optional):
      allele   -- a comma-separated string, or a list of allele names.
      allele_a -- alpha chains (comma-separated string or list), paired
                  positionally with
      allele_b -- beta chains (comma-separated string or list); each pair
                  becomes 'alpha-beta'.
    A one-element list is treated exactly like its single element.

    Attributes set on the instance:
      allele_list -- cleaned, non-empty allele names in input order
                     ('HLA-' removed, '/' between chains turned into '-').
      duplicates  -- names that occur more than once in allele_list,
                     in order of first repetition.
      negatives   -- always an empty list in this class.

    >>> a=Alleles(allele='DPA1*01/DPB1*04:01    ')
    >>> a.allele_list
    ['DPA1*01-DPB1*04:01']

    >>> a=Alleles(allele=['  HLA-DRB1*01:01', ' H2-IAb  '])
    >>> a.allele_list
    ['DRB1*01:01', 'H2-IAb']

    >>> a=Alleles(allele_a=['DPA1*01'], allele_b=['DPB1*04:01'])
    >>> a.allele_list
    ['DPA1*01-DPB1*04:01']

    >>> Alleles(allele='A,B,A').duplicates
    ['A']

    >>> a=Alleles()
    >>> a.allele_list
    []

    >>> a=Alleles(allele='')
    >>> a.allele_list
    []
    """

    def __init__(self, **kwargs):
        """
        Function which could accept 0-3 keyword parameters:
        ``allele``, ``allele_a`` and ``allele_b``.

        Raises:
            TypeError: if ``allele`` (after unwrapping a one-element list)
            is neither a str nor a list.
        """
        allele = self._unwrap_single(kwargs.get('allele', []))
        if not isinstance(allele, (list, str)):
            raise TypeError('alleles input must be str or list. But it is %s, and its type is %s.' % (allele, type(allele)))

        allele = self._split_names(allele)
        # Alpha/beta chains get the same str-or-list handling as ``allele``;
        # a lone string must NOT be iterated character by character.
        allele_a = self._split_names(self._unwrap_single(kwargs.get('allele_a', [])))
        allele_b = self._split_names(self._unwrap_single(kwargs.get('allele_b', [])))

        # Pair chains positionally; zip stops at the shorter of the two lists.
        allele_ab = ['-'.join(ab) for ab in zip(allele_a, allele_b)]
        allele = allele + allele_ab

        # remove any 'HLA-' prefix in allele name, and replace '/' with '-' in between alpha-beta chains
        allele = [a.replace("HLA-", "").replace("/", "-") for a in allele]

        # A real list (not a lazy ``filter`` object) so it can be read repeatedly.
        allele_list = [a for a in allele if a]

        # Report repeated names once each, in order of first repetition;
        # empty entries were already dropped so they are never "duplicates".
        seen = set()
        duplicates = []
        for name in allele_list:
            if name in seen and name not in duplicates:
                duplicates.append(name)
            seen.add(name)

        self.duplicates = duplicates
        self.negatives = []
        self.allele_list = allele_list

    @staticmethod
    def _unwrap_single(value):
        """Return the sole element of a one-element list, else ``value`` unchanged."""
        if isinstance(value, list) and len(value) == 1:
            return value[0]
        return value

    @staticmethod
    def _split_names(value):
        """Split a comma-separated string (or take an iterable) into stripped names."""
        items = value.split(',') if isinstance(value, str) else value
        return [item.strip() for item in items]
        
        

class Proteins(object):
    """
    Contains a list of protein sequences and names and several conversion functions.

    ``sequences`` and ``names`` are parallel lists: ``names[i]`` labels
    ``sequences[i]``. Stored sequences are upper-cased and may only contain
    the 20 standard amino-acid letters.

    >>> p=Proteins(">TestProtein\\nFNCLGMSNRDFLEGVSG")
    >>> p.sequences
    ['FNCLGMSNRDFLEGVSG']
    >>> p.names
    ['TestProtein']
    >>> p=Proteins(">TestProtein1\\nFNCLGMSNRDFLEGVSG\\n>TestProtein2\\nFNCLGMSNRDFLEGVSG")
    >>> p.transfer_to_fasta_list()
    ['>TestProtein1\\nFNCLGMSNRDFLEGVSG', '>TestProtein2\\nFNCLGMSNRDFLEGVSG']

    """
    # One-letter codes accepted by add_protein.
    _AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")

    def __init__(self, input_sequences, sequence_format='auto'):
        """
        Parse ``input_sequences`` into this instance.

        ``sequence_format`` is one of 'fasta', 'one_sequence',
        'space_separated', or 'auto' (default). In 'auto' mode, input that
        contains '>' is read as FASTA and anything else as whitespace-separated
        sequences. ``None`` produces an empty instance.
        """
        self.sequences = []
        self.names = []
        if input_sequences is not None:
            self.extractForm(input_sequences.strip(), sequence_format)

    def add_protein(self, sequence, name=""):
        """
        Validate one protein sequence and append it together with its name.

        The sequence is stripped and upper-cased first. A falsy ``name`` is
        replaced by "sequence N", N being the new number of stored sequences.

        Raises:
            ValueError: if the sequence contains a character outside the
            20 standard amino-acid letters (the reported position is 0-based).
        """
        sequence = sequence.strip().upper()
        for position, amino_acid in enumerate(sequence):
            if amino_acid not in self._AMINO_ACIDS:
                raise ValueError("Sequence: '%s' contains an invalid character: '%c' at position %d." % (sequence, amino_acid, position))
        self.sequences.append(sequence)
        if not name:
            name = "sequence %d" % (len(self.sequences))
        self.names.append(str(name))

    def extractFasta(self, fasta):
        """
        Extract named sequences from a FASTA-formatted string.

        Each record is '>' + a name line + sequence lines. All whitespace in
        the sequence lines is removed. The name is stripped of surrounding
        whitespace, which also drops a carriage return left by CRLF input.

        Raises:
            ValueError: if there is no '>' at all, or a record has no
            sequence after its name.
        """
        records = fasta.split(">")
        if len(records) < 2:
            raise ValueError("Invalid FASTA format: No '>' found.")
        for record in records[1:]:
            if not record:
                continue
            end_of_name = record.find("\n")
            seq = "".join(record[end_of_name:].split()) if end_of_name != -1 else ""
            # Covers both "name without newline" and "header directly followed
            # by another header"; the latter used to add an empty sequence.
            if not seq:
                raise ValueError("Invalid FASTA format: No Protein sequence found between two names.")
            self.add_protein(seq, record[:end_of_name].strip())

    def convert_first_seq(self, seq):
        """
        To convert the first sequence into FASTA by prepending a temporary
        header (">sequence 1") when the text does not already start with '>'.

        Blank lines are dropped, and CRLF line breaks are normalised to '\\n'.
        """
        non_blank = [line for line in seq.split("\n") if line.strip()]
        lines = "\n".join(non_blank).split("\r\n")
        if not lines[0].startswith(">"):
            lines.insert(0, ">sequence 1")
        return "\n".join(lines)

    def transfer_to_fasta_list(self):
        """
        To transfer the instance to a list of FASTA-format strings,
        one ">name\\nsequence" entry per stored protein.
        """
        return [">%s\n%s" % (name, seq) for name, seq in zip(self.names, self.sequences)]

    def extractForm(self, input_sequences, sequence_format):
        """
        To extract data from the input parameters.
        This is called by __init__.

        Raises:
            ValueError: if ``sequence_format`` is not 'auto', 'fasta',
            'one_sequence' or 'space_separated'.
        """
        if sequence_format == "auto":
            if ">" in input_sequences:
                # convert the first sequence into FASTA by adding a temporary header
                input_sequences = self.convert_first_seq(input_sequences)
                sequence_format = "fasta"
            else:
                sequence_format = "space_separated"

        if sequence_format == "fasta":
            self.extractFasta(input_sequences)
        elif sequence_format == "one_sequence":
            # Whitespace (including line breaks) inside the single sequence is ignored.
            self.add_protein("".join(input_sequences.split()))
        elif sequence_format == "space_separated":
            for seq in input_sequences.split():
                self.add_protein(seq)
        else:
            raise ValueError("Unknown sequence format: %r" % (sequence_format,))

    def all_seq_len_less_than(self, length):
        """
        Return True if every stored sequence is shorter than ``length``
        (vacuously True when there are no sequences).
        """
        return all(len(seq) < length for seq in self.sequences)

    def get_protein_seq(self, seq_nums=None, blocklen=50):
        """
        Return a list of (peptide_index, name, sequence) tuples.

        ``peptide_index`` is ``seq_nums[i]`` when ``seq_nums`` is given
        (non-empty), otherwise the 1-based position. Each sequence is broken
        into lines of at most ``blocklen`` characters joined with '\\n'.
        """
        result = []
        for i, (name, sequence) in enumerate(zip(self.names, self.sequences)):
            peptide_index = seq_nums[i] if seq_nums else i + 1
            seqblock = "\n".join(sequence[start:start + blocklen]
                                 for start in range(0, len(sequence), blocklen))
            result.append((peptide_index, name, seqblock))
        return result
        
if __name__ == "__main__":
    from doctest import testmod

    # Execute the >>> examples embedded in this module's docstrings.
    testmod()
