


import tempfile
import os
import math
import operator
from collections import OrderedDict, namedtuple
from mhcipredictor import MHCIPredictor
# Key for every score dictionary in this module: which score type
# (method_name) was computed for which sequence / allele / peptide length.
MHCIPredictionParams = namedtuple('MHCIPredictionParams', 'method_name sequence allele binding_length')


class ProcessingPredictor(object):
    """Generate proteasome, TAP, MHC and combined antigen-processing scores.

    Every ``calculate_*`` method returns an ``OrderedDict`` sorted by key.
    Keys are ``MHCIPredictionParams(method_name, sequence, allele,
    binding_length)``, where ``method_name`` names the score type ("mhc",
    "tap", "proteasome", "processing" or "total"). TAP and proteasome values
    hold one score per peptide start position in ``sequence``.

    Scoring matrices are read from text files next to this module. The
    instance keeps only the most recently loaded matrix in
    ``mat``/``offset``/``length``.
    """

    def __init__(self, data, allele_length_list, method_name=None):
        """Validate the request options and record the alleles/lengths to score.

        :param data: mapping of options. Reads ``tap_precursor``
            (int-convertible, >= 0), ``tap_alpha`` (float-convertible),
            ``proteasome``, ``hla_seq`` and ``hla_len``. ``hla_len`` is read
            with ``getlist`` when *data* provides one (e.g. a form multi-dict),
            otherwise with ``get``.
        :param allele_length_list: iterable of ``(allele, binding_length)`` pairs.
        :param method_name: stored on the instance; not used by this class.
        :raises InputError: if ``tap_precursor`` or ``tap_alpha`` is invalid.
        """
        try:
            self.tap_precursor = int(data.get('tap_precursor'))
        except (TypeError, ValueError, OverflowError):
            raise InputError('Please enter an integer number in the tap precursor field')
        if self.tap_precursor < 0:
            # score_tap needs at least one loop iteration to define its divisor.
            raise InputError('Please enter a non-negative integer in the tap precursor field')
        try:
            self.tap_alpha = float(data.get('tap_alpha'))
        except (TypeError, ValueError, OverflowError):
            raise InputError('Please enter a number in the tap alpha field')
        self.proteasome_type = data.get('proteasome')

        self.allele_length_list = allele_length_list
        self.method_name = method_name

        # State of the most recently loaded scoring matrix (see load_text_file).
        self.offset = None
        self.length = None
        self.mat = {}

        # Quick fix: when a single HLA sequence is given instead of alleles from
        # the drop-down, the peptide lengths come from the 'hla_len' field.
        single_mhc_seq = data.get('hla_seq')
        if single_mhc_seq == '':
            single_mhc_seq = None

        if single_mhc_seq is not None:
            self.peplengths = self._get_list(data, 'hla_len')
        else:
            self.peplengths = sorted(set(length for _allele, length in self.allele_length_list))

    @staticmethod
    def _get_list(data, key):
        """Return all values stored under *key* as a list.

        Objects exposing ``getlist`` (multi-dicts) are read with it. For plain
        mappings, a list/tuple value is copied, a missing or empty-string value
        gives ``[]``, and any other value is wrapped in a one-element list.
        """
        getlist = getattr(data, 'getlist', None)
        if callable(getlist):
            return getlist(key)
        value = data.get(key)
        if value is None or value == '':
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    @staticmethod
    def _index_by_target(scores):
        """Re-key a score dict by ``(sequence, allele, binding_length)``, ignoring method_name."""
        return {(params.sequence, params.allele, params.binding_length): value
                for params, value in scores.items()}

    def _combine_scores(self, method_name, template, *others):
        """Sum *template* with *others* entry by entry, matched on the prediction target.

        Entries are paired by ``(sequence, allele, binding_length)`` rather than
        by dict position, so differently ordered inputs still line up.

        :raises KeyError: if an entry of *template* has no counterpart in *others*.
        """
        indexed_others = [self._index_by_target(scores) for scores in others]
        prediction_input = []
        input_scores = []
        for params, scores in template.items():
            target = (params.sequence, params.allele, params.binding_length)
            prediction_input.append(MHCIPredictionParams(method_name, *target))
            input_scores.append((scores,) + tuple(index[target] for index in indexed_others))
        return self.calculate_score(prediction_input, input_scores)

    def calculate_processing_score(self, tap_score, proteasome_score):
        """Return TAP + proteasome score per peptide, keyed with method_name "processing".

        :param tap_score: result of ``calculate_tap_score``.
        :param proteasome_score: result of ``calculate_proteasome_score``.
        """
        return self._combine_scores("processing", tap_score, proteasome_score)

    def calculate_total_score(self, mhc_score, tap_score, proteasome_score):
        """Return MHC + TAP + proteasome score per peptide, keyed with method_name "total".

        *mhc_score* decides which entries appear in the result.
        """
        return self._combine_scores("total", mhc_score, tap_score, proteasome_score)

    def calculate_score(self, _prediction_input, _input_scores):
        """Sum aligned per-peptide score sequences position by position.

        :param _prediction_input: keys; the i-th key labels the i-th entry of
            *_input_scores*.
        :param _input_scores: iterable of tuples of per-peptide score sequences.
            Summation stops at the shortest sequence in each tuple.
        :returns: OrderedDict of key -> list of summed scores, sorted by key.
        """
        results = {}
        for i, score_lists in enumerate(_input_scores):
            results[_prediction_input[i]] = [sum(column) for column in zip(*score_lists)]
        return OrderedDict(sorted(results.items(), key=operator.itemgetter(0)))

    def calculate_mhc_score(self, ic50_result):
        """ | *author*: Dorjee
            | *created*: 04-26-2017
            | *brief*: Convert IC50 values into MHC scores (-log10(IC50)).

        :param ic50_result: mapping of params -> iterable of IC50 values. Every
            value must be positive (log10 of 0 or a negative number raises ValueError).
        :returns: OrderedDict keyed with method_name "mhc", values are tuples.
        """
        method_name = "mhc"
        results = {}
        for params, ic50_scores in ic50_result.items():
            prediction_input = MHCIPredictionParams(
                method_name, params.sequence, params.allele, params.binding_length
            )
            results[prediction_input] = tuple(-math.log10(float(score)) for score in ic50_scores)
        return OrderedDict(sorted(results.items(), key=operator.itemgetter(0)))

    def _cleavage_scores(self, sequence):
        """Return proteasome cleavage scores for positions 5 .. len(sequence)-1.

        Entry ``k`` belongs to cleavage position ``k + 5``. It is ``self.offset``
        plus the matrix weights of ``sequence[pos-5:pos+5]``, with the window
        truncated at the end of the sequence.
        """
        scores = []
        for cleavage_pos in range(5, len(sequence)):
            start = cleavage_pos - 5
            end = min(cleavage_pos + 5, len(sequence))
            score = self.offset
            for pos in range(start, end):
                score += self.mat[sequence[pos]][pos - start]
            scores.append(score)
        return scores

    def calculate_proteasome_score(self, sequence_list):
        """ | *author*: Dorjee
            | *created*: 04-25-2017
            | *brief*: Method to perform proteasome prediction.

        Loads the immuno- or constitutive-proteasome matrix (``self.proteasome_type``)
        and scores, for each peptide, the cleavage at its C-terminal residue.

        :raises InputError: for an unknown proteasome type, or a binding length
            below 6 (no cleavage score exists for C-termini before index 5).
        """
        method_name = "proteasome"
        if self.proteasome_type == 'immuno':
            mat_file = 'mat-prot-ec-i.txt'
        elif self.proteasome_type == 'constitutive':
            mat_file = 'mat-prot-ec-c.txt'
        else:
            # str() so that a missing (None) selection still yields this message.
            raise InputError('Invalid selection for proteasome type: ' + str(self.proteasome_type))
        mat_name = os.path.join(os.path.dirname(os.path.realpath(__file__)), mat_file)
        with open(mat_name, 'r') as matfile:
            self.load_text_file(matfile)

        # Cleavage scores depend only on the sequence, so compute each once.
        cleavage_cache = {}
        results = {}
        for allele_name, binding_length in self.allele_length_list:
            if binding_length < 6:
                # A negative index would silently wrap to the wrong cleavage score.
                raise InputError('Proteasome scoring requires a peptide length of at least 6, got %s'
                                 % (binding_length,))
            # scores[k] is cleavage position k + 5, so the peptide's C-terminus
            # (nterm + binding_length - 1) is at index nterm + binding_length - 6.
            prot_offset = binding_length - 6
            for sequence in sequence_list:
                if sequence not in cleavage_cache:
                    cleavage_cache[sequence] = self._cleavage_scores(sequence)
                cleavage_scores = cleavage_cache[sequence]
                proteasome_scores = [cleavage_scores[nterm + prot_offset]
                                     for nterm in range(len(sequence) - (binding_length - 1))]
                prediction_input = MHCIPredictionParams(
                    method_name, sequence, allele_name, binding_length
                )
                results[prediction_input] = proteasome_scores
        return OrderedDict(sorted(results.items(), key=operator.itemgetter(0)))

    def calculate_tap_score(self, sequence_list):
        """ | *author*: Dorjee
            | *created*: 04-25-2017
            | *brief*: Method to perform TAP prediction.

        Loads ``mat-tap.txt`` and scores every peptide with ``score_tap``.
        """
        method_name = "tap"
        mat_name = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mat-tap.txt")
        with open(mat_name, 'r') as matfile:
            self.load_text_file(matfile)

        results = {}
        for allele_name, binding_length in self.allele_length_list:
            for sequence in sequence_list:
                prediction_input = MHCIPredictionParams(
                    method_name, sequence, allele_name, binding_length
                )
                results[prediction_input] = self.score_tap(sequence, binding_length)
        return OrderedDict(sorted(results.items(), key=operator.itemgetter(0)))

    def score_tap(self, sequence, length):
        """Return the TAP score of every peptide of *length* in *sequence*.

        Each score is ``self.offset``, plus the C-terminal residue's column-8
        weight, plus ``tap_alpha`` times the summed N-terminal weights (columns
        0-2 for the first three residues). Those weights are summed over the
        peptide and up to ``tap_precursor`` N-terminally extended precursors,
        then divided as described in the note below.
        """
        result = []
        for nterm in range(len(sequence) - (length - 1)):
            cterm = nterm + length - 1
            score = self.offset
            score += self.mat[sequence[cterm]][8]
            tap_n = 0.0
            for extension in range(self.tap_precursor + 1):
                n = nterm - extension
                if n < 0:
                    break
                tap_n += self.mat[sequence[n + 0]][0]
                tap_n += self.mat[sequence[n + 1]][1]
                tap_n += self.mat[sequence[n + 2]][2]
            # NOTE(review): when the loop breaks on n < 0, ``extension`` is one
            # past the last window actually summed, so this divisor is one larger
            # than the number of summed windows. It matches only when the loop
            # runs to completion. Kept unchanged to preserve existing scores;
            # confirm against the reference method.
            tap_n /= extension + 1.0
            score += tap_n * self.tap_alpha
            result.append(score)
        return result

    def load_text_file(self, infile):
        """Load a scoring matrix from an open text file into mat/length/offset.

        Expected layout: the second token of line 1 is the column count. The
        next 20 lines are one row per residue, keyed by the line's first
        character and followed by the column values. Line 22 is the offset.

        :raises InputError: if a row does not have the declared column count.
        """
        lines = infile.readlines()
        self.mat.clear()
        self.length = int(lines[0].split()[1])
        for line in lines[1:21]:
            numbers = [float(entry) for entry in line.split()[1:]]
            if len(numbers) != self.length:
                raise InputError("Invalid number of columns in SMM matrix: %s, expected: %s." % (str(len(numbers)), str(self.length)))
            self.mat[line[0]] = tuple(numbers)
        self.offset = float(lines[21])


class InputError(Exception):
    """Exception raised for errors in the input.

    :param value: the error message (or any object describing the error);
        available as ``.value`` and as ``args[0]``.
    """

    def __init__(self, value):
        # Forward to Exception so ``args`` is populated; pickling and
        # re-raising across processes rebuild the exception from ``args``.
        super(InputError, self).__init__(value)
        self.value = value

    def __str__(self):
        # str() so that non-string values do not make str(exc) raise TypeError.
        return str(self.value)

def get_peptide_list(sequence, length):
    """Return every contiguous window of *length* characters in *sequence*, left to right."""
    last_start = len(sequence) - length
    return [sequence[start:start + length] for start in range(last_start + 1)]

def mhc_binding_predict(mhc_binding_method, sequence_list, allele_list, lengths_list):
    """Run MHC-I binding prediction for every allele / sequence / length combination.

    Each sequence is cut into overlapping peptides of each length and scored
    with ``MHCIPredictor(mhc_binding_method)``.

    :returns: dict mapping ``MHCIPredictionParams("processing", sequence, allele,
        length)`` to the list of IC50 values, in the order the peptides occur in
        the sequence. The list is empty when the sequence is shorter than *length*.
    :raises ValueError: if the predictor output lacks a 'peptide' or 'ic50' column.
    """
    # percentile data could be retrieved here if needed
    seq_file_type = 'peptides'
    ic50_result = {}
    for allele in allele_list:
        for sequence in sequence_list:
            for length in lengths_list:
                key = MHCIPredictionParams("processing", sequence, allele, length)
                peptide_list = get_peptide_list(sequence, length)
                if not peptide_list:
                    # Nothing to predict; avoid calling the predictor on an empty file.
                    ic50_result[key] = []
                    continue
                peptide_ic50_dict = _predict_ic50_by_peptide(
                    mhc_binding_method, allele, peptide_list, seq_file_type
                )
                # return to original order
                ic50_result[key] = [peptide_ic50_dict[peptide] for peptide in peptide_list]
    return ic50_result


def _predict_ic50_by_peptide(mhc_binding_method, allele, peptide_list, seq_file_type):
    """Score *peptide_list* for *allele* and return a ``{peptide: ic50}`` dict.

    The peptides are written one per line to a temporary file. The file is
    passed to ``MHCIPredictor.predict`` and removed afterwards.

    :raises ValueError: if the predictor output has no 'peptide' or 'ic50' column.
    """
    tmp_peptides_file = tempfile.NamedTemporaryFile(mode='w', delete=False)
    try:
        with tmp_peptides_file:
            tmp_peptides_file.write('\n'.join(peptide_list))
        mhci_predictor = MHCIPredictor(mhc_binding_method)
        result = mhci_predictor.predict(allele, "", tmp_peptides_file.name, seq_file_type)
    finally:
        # Best-effort cleanup of the temporary input; a failed delete must not
        # mask the prediction result or the original error.
        try:
            os.remove(tmp_peptides_file.name)
        except OSError:
            pass

    if isinstance(result, tuple) and len(result) == 2:
        # Some predictors return (table, distances); only the table is used here.
        result = result[0]
    header = [str(column).lower() for column in result[0]]
    if 'peptide' not in header or 'ic50' not in header:
        raise ValueError("MHC binding output lacks a 'peptide' or 'ic50' column: %s" % (list(result[0]),))
    peptide_index = header.index('peptide')
    ic50_index = header.index('ic50')
    return dict((row[peptide_index], row[ic50_index]) for row in result[1:])

def get_processing_socres(mhc_binding_method, sequence_list, allele_list, lengths_list, proteasome='immuno', tap_precursor=1, tap_alpha=0.2, hla_seq='', hla_len=''):
    """Compute IC50, MHC, TAP, proteasome, processing and total scores.

    (The misspelled function name is kept for existing callers.)

    IC50 values come from ``mhc_binding_predict``. All other scores come from a
    ``ProcessingPredictor`` built for every (allele, length) combination.

    :returns: dict with the keys 'ic50_result', 'mhc_result', 'tap_result',
        'proteasome_result', 'total_result' and 'processing_result'.
    """
    ic50_result = mhc_binding_predict(mhc_binding_method, sequence_list, allele_list, lengths_list)

    allele_lengths_list = [(allele, length) for allele in allele_list for length in lengths_list]
    predictor = ProcessingPredictor(
        {
            "tap_precursor": tap_precursor,
            "tap_alpha": tap_alpha,
            "proteasome": proteasome,
            "hla_seq": hla_seq,
            "hla_len": hla_len,
        },
        allele_lengths_list,
    )

    mhc_result = predictor.calculate_mhc_score(ic50_result)
    tap_result = predictor.calculate_tap_score(sequence_list)
    proteasome_result = predictor.calculate_proteasome_score(sequence_list)
    # Values are evaluated in the order written: total before processing.
    return {
        'ic50_result': ic50_result,
        'mhc_result': mhc_result,
        'tap_result': tap_result,
        'proteasome_result': proteasome_result,
        'total_result': predictor.calculate_total_score(mhc_result, tap_result, proteasome_result),
        'processing_result': predictor.calculate_processing_score(tap_result, proteasome_result),
    }


def fasta_to_sequence_list(fasta_file):
    """Read a FASTA file and return its sequences in file order (names are dropped).

    Whitespace inside a sequence, including line breaks, is removed.

    :param fasta_file: path to a FASTA-formatted text file.
    :raises ValueError: if the file contains no '>', or a record has no name
        or no sequence.
    """
    with open(fasta_file, "r") as infile:
        file_content = infile.read()
    if '>' not in file_content:
        raise ValueError('Expected ">" not found. Input file must be in fasta format.')

    sequence_list = []
    for raw_record in file_content.split('>'):
        if not raw_record.strip():
            continue
        # partition() keeps a header-only record intact. find('\n') returned -1
        # for one, which cut the name's last character off as the "sequence".
        header, _, body = raw_record.partition('\n')
        sequence_name = header.strip()
        if not sequence_name:
            raise ValueError("No sequence_name. Please check your fasta file.")
        sequence = ''.join(body.split())
        if not sequence:
            raise ValueError("No sequence found for %s. Please check your fasta file." % (sequence_name,))
        sequence_list.append(sequence)
    return sequence_list

def transfer_result_to_table(sequence_list, result):
    """Flatten the score dictionaries of get_processing_socres into table rows.

    One row is produced per IC50 value. Each row's scores are taken from the
    other score dicts at the same peptide index, under the same (sequence,
    allele, binding_length) key.

    :param sequence_list: the input sequences. ``sequence_number`` is the
        1-based position of a row's sequence in this list (the last occurrence
        if a sequence appears more than once).
    :param result: dict with the keys 'ic50_result', 'mhc_result', 'tap_result',
        'proteasome_result', 'total_result' and 'processing_result'.
    :returns: list of rows; the first row is the header.
    """
    header = ["sequence_number", "start", "peptide", "allele", 'proteasome_score', 'tap_score', 'mhc_score', 'processing_score', 'total_score',]
    output_table = [header]
    sequence_number_dict = dict((sequence, index + 1) for index, sequence in enumerate(sequence_list))

    # (method_name used in the score keys, entry in *result*), in column order.
    score_sources = (
        ('proteasome', 'proteasome_result'),
        ('tap', 'tap_result'),
        ('mhc', 'mhc_result'),
        ('processing', 'processing_result'),
        ('total', 'total_result'),
    )

    for key, ic50_scores in result.get('ic50_result').items():
        sequence = key.sequence
        sequence_number = sequence_number_dict[sequence]
        if not ic50_scores:
            continue
        length = key.binding_length
        # Look up each score list once per key instead of once per peptide.
        score_lists = [result.get(result_name)[key._replace(method_name=method_name)]
                       for method_name, result_name in score_sources]
        # TODO: check if all scores are in original order (input peptides order)
        for i in range(len(ic50_scores)):
            peptide = sequence[i:i + length]
            output_table.append([sequence_number, i + 1, peptide, key.allele]
                                + [scores[i] for scores in score_lists])

    return output_table


def processing_predict(predictor, fasta_file, alleles, peptide_length_range):
    """Run the full antigen-processing prediction and return a result table.

    :param predictor: dict of options; 'type' must be "processing". Reads
        'proteasome', 'tap_precursor' and 'tap_alpha' (defaulting to "immuno",
        1 and 0.2, the same defaults as get_processing_socres) and
        'mhc_binding_method'.
    :param fasta_file: path to the FASTA file with the input sequences.
    :param alleles: comma-separated allele names; spaces around names are ignored.
    :param peptide_length_range: ``(min_length, max_length)``, both inclusive.
    :returns: table rows from transfer_result_to_table (header row first).
    :raises InputError: if the predictor type is not "processing".
    """
    if predictor.get('type') != 'processing':
        # Raise rather than return, so callers cannot mistake the error for a result.
        raise InputError('predictor type must be "processing" not %s' % (predictor.get('type'),))

    proteasome = predictor.get('proteasome', 'immuno')
    tap_precursor = predictor.get('tap_precursor', 1)
    tap_alpha = predictor.get('tap_alpha', 0.2)
    mhc_binding_method = predictor.get('mhc_binding_method')
    sequence_list = fasta_to_sequence_list(fasta_file)
    lengths_list = list(range(peptide_length_range[0], peptide_length_range[1] + 1))
    # Tolerate "A, B" and trailing commas: strip names and drop empty entries.
    allele_list = [allele.strip() for allele in alleles.split(',') if allele.strip()]

    result = get_processing_socres(mhc_binding_method, sequence_list, allele_list, lengths_list, proteasome, tap_precursor, tap_alpha)
    return transfer_result_to_table(sequence_list, result)
         


if __name__ == '__main__':
    # Manual smoke test; needs the MHC binding predictor and the matrix files.
    # Usage: python <this file> [MHC_BINDING_METHOD]
    import sys

    predictor = {
        "type": "processing",
        "proteasome": "immuno",
        "tap_precursor": 1,
        "tap_alpha": 0.2,
        # NOTE(review): "netmhcpan_ba" is assumed to be a method name that
        # MHCIPredictor accepts; confirm before relying on this default.
        "mhc_binding_method": sys.argv[1] if len(sys.argv) > 1 else "netmhcpan_ba",
    }
    with tempfile.NamedTemporaryFile(mode='w', suffix='.fasta', delete=False) as tmp_fasta:
        tmp_fasta.write(">example\nTMDKSELVQKKMKGDYFRYFEILNSPEKAC\n")
        fasta_path = tmp_fasta.name
    try:
        table = processing_predict(predictor, fasta_path, "HLA-A*01:01", (9, 10))
    finally:
        os.remove(fasta_path)
    for row in table:
        print(row)



