#!/usr/bin/env python
from __future__ import print_function
import os
import sys
import re
import tempfile
import logging

# Put this script's own directory on the import path so the bare module
# imports below (netmhcpan_4_1_executable, allele_validator, util, ...)
# resolve regardless of the working directory -- presumably those modules
# live alongside this file.
script_dir = os.path.split(os.path.realpath(__file__))[0]
sys.path.append(script_dir)


from netmhcpan_4_1_executable import is_user_defined_allele
from allele_validator import Allele_Validator

from PercentilesCalculators import MHCIPercentilesCalculator
from util import InputError, UnexpectedInputError, PredictorError, PeptideSequenceInput, get_species, InputData, get_peptides, MethodSet, get_mhc_list

from functools import reduce

def eprint(*args, **kwargs):
    """Behave like print(), but write to standard error instead of stdout."""
    stream = sys.stderr
    print(*args, file=stream, **kwargs)

def get_percentile_for_score(score, allele, peptide, method, score_distributions=None):
    '''Return the percentile rank of a single raw prediction score.

    Args:
        score: raw predictor score for *peptide*.
        allele: allele name. For the 'ann' and 'comblib_sidney2008' methods
            every '*' is removed before the lookup.
        peptide: peptide sequence; only its length is used, to pick the
            distribution for that binding length.
        method: method name forwarded to the percentile calculator.
        score_distributions: distribution data handed to
            MHCIPercentilesCalculator.

    Returns:
        The first (only) value returned by
        MHCIPercentilesCalculator.get_percentile_scores.

    Raises:
        ValueError: may propagate from the percentile calculator.
    '''
    if method in ('ann', 'comblib_sidney2008'):
        allele = allele.replace("*", "")

    calculator = MHCIPercentilesCalculator(score_distributions)
    return calculator.get_percentile_scores([score], method, allele, len(peptide))[0]

def get_percentiles_for_scores(raw_scores, alleles, peptides, method_name, score_distributions=None):
    '''Return percentile ranks for parallel sequences of scores, alleles and peptides.

    The three sequences are walked in lockstep (zip semantics: the result is
    as long as the shortest input). One MHCIPercentilesCalculator is built and
    reused for every entry; each peptide's length selects the distribution.

    Args:
        raw_scores: raw predictor scores.
        alleles: allele name for each score (used as given, no normalization).
        peptides: peptide for each score; only its length is used.
        method_name: method name forwarded to the percentile calculator.
        score_distributions: distribution data for MHCIPercentilesCalculator.

    Returns:
        list of percentiles, one per (score, allele, peptide) triple.

    Raises:
        ValueError: may propagate from the percentile calculator.
    '''
    calculator = MHCIPercentilesCalculator(score_distributions)
    return [
        calculator.get_percentile_scores([score], method_name, allele, len(peptide))[0]
        for score, allele, peptide in zip(raw_scores, alleles, peptides)
    ]

def read_peptides(fname):
    """Return the non-blank lines of *fname*, each stripped of surrounding whitespace."""
    with open(fname, 'r') as handle:
        stripped_lines = (line.strip() for line in handle)
        return [pep for pep in stripped_lines if pep]

def group_peptides_by_length(peptide_list):
    """Split peptides into lists that share the same length.

    Groups are returned in ascending order of peptide length (deterministic,
    unlike iterating a set of lengths), and every group keeps the peptides in
    their original input order. Runs in a single pass over the input.

    Args:
        peptide_list: iterable of peptide strings.

    Returns:
        list of lists of peptides; empty input gives [].
    """
    groups = {}
    for peptide in peptide_list:
        groups.setdefault(len(peptide), []).append(peptide)
    return [groups[length] for length in sorted(groups)]

def update_allele_name_to_iedblable(distances, method):
    '''Re-key a NetMHCpan allele-distance map by IEDB allele labels.

    Each value is a pair (closest_allele, distance), for example::

        {"HLA-A02:01": ["HLA-A02:01", "0.000"]}

    Both the key and the closest allele are passed through
    Allele_Validator.convert_synonym_to_iedblabel. Whenever a conversion
    yields a falsy value or NaN, the original name is kept. The distance is
    converted with float().

    The ``method`` argument is accepted but not used.

    Returns:
        dict mapping allele label -> (closest allele label, float distance).
    '''
    validator = Allele_Validator()

    def _no_label(label):
        # Falsy, or NaN (the only value that is not equal to itself).
        return not label or label != label

    relabelled = {}
    for name, pair in distances.items():
        iedb_name = validator.convert_synonym_to_iedblabel(name)
        distance = float(pair[1])
        if _no_label(iedb_name):
            iedb_name = name
        iedb_closest = validator.convert_synonym_to_iedblabel(pair[0])
        closest = pair[0] if _no_label(iedb_closest) else iedb_closest
        relabelled[iedb_name] = (closest, distance)

    return relabelled


class NetMHCpanPredictor():
    """MHC-I binding prediction with NetMHCpan 4.1 (EL or BA mode)."""

    # Shared by all instances; translates between IEDB allele labels and the
    # allele labels NetMHCpan itself uses.
    validator = Allele_Validator()

    def is_valid(self, **kwargs):
        """Predictor-specific validation; none is implemented, so always None (no error)."""
        return None

    def predict(self, args):
        """Run NetMHCpan 4.1 on the peptides in a file and attach percentile ranks.

        Args:
            args: tuple (method, input_allele, input_length, fname, seq_file_type).
                method 'netmhcpan_el' (or 'IEDB_recommended', treated the same)
                selects the EL model; any other value selects BA. input_allele
                is a comma-separated list of IEDB allele labels. input_length
                and seq_file_type are currently unused.

        Returns:
            (rows, distances): rows is a list whose first element is the header
            ('allele', 'peptide', 'core', 'icore', <'score' or 'ic50'>, 'percentile')
            followed by result tuples sorted by ascending percentile (header
            only when nothing was predicted); distances maps each IEDB allele
            label to a one-item dict {closest_allele: distance}.
        """
        # TODO: consider removing unnecessary arguments (e.g., input_length, seq_file_type)
        method, input_allele, input_length, fname, seq_file_type = args
        # MHCIPredictor.get_predictor routes 'IEDB_recommended' here as
        # 'netmhcpan_el'; honour that so the EL model and EL percentile data are
        # used instead of silently falling through to the BA branch.
        if method == 'IEDB_recommended':
            method = 'netmhcpan_el'
        allele_list = input_allele.split(",")

        if method == 'netmhcpan_el':
            from mhci_netmhcpan_4_1_el_percentile_data import score_distributions
            method_name = 'netmhcpan_el'
            score_unit = 'score'
            el = True
        else:
            from mhci_netmhcpan_4_1_ba_percentile_data import score_distributions
            method_name = 'netmhcpan_ba'
            score_unit = 'ic50'
            el = False
        header = ('allele', 'peptide', 'core', 'icore', score_unit, 'percentile')

        from netmhcpan_4_1_executable import predict_many_peptides_file
        allele_list = self.validator.convert_iedblabel_to_methodlabel(iedb_labels=allele_list, method=method)
        scores_by_peptides, distances = predict_many_peptides_file(fname, allele_list, el=el, with_distance_info=True)
        distances = update_allele_name_to_iedblable(distances, method)
        # (closest_allele, distance) pair -> {closest_allele: distance},
        # e.g. ("HLA-A02:01", 0.0) -> {"HLA-A02:01": 0.0}
        distances = {k: dict((v,)) for k, v in distances.items()}

        if not scores_by_peptides:
            # Nothing predicted: the zip(*...) unpacking below would fail.
            return [header], distances

        # Keys are (allele, peptide) pairs; values start with (score, core, icore).
        alleles, peptides = zip(*scores_by_peptides.keys())
        # TODO: investigate why there could be 6 columns
        scores, cores, icores = list(zip(*scores_by_peptides.values()))[:3]
        percentiles = get_percentiles_for_scores(scores, alleles, peptides, method_name, score_distributions)

        # Convert method labels back to IEDB labels, once per distinct allele.
        iedblabel_map = {}
        for allele in alleles:
            if allele not in iedblabel_map:
                iedblabel_map[allele] = self.validator.convert_methodlabel_to_iedblabel(allele, method)
        iedblabels = [iedblabel_map[allele] for allele in alleles]

        combined_table_rows = list(zip(iedblabels, peptides, cores, icores, scores, percentiles))
        # TODO: Let's change the default sorting to use 'percentile' throughout
        combined_table_rows.sort(key=lambda tup: tup[5])
        combined_table_rows.insert(0, header)

        return combined_table_rows, distances



class SMMPredictor():
    """MHC-I binding prediction with the SMM or SMMPMBEC scoring matrices."""

    def predict(self, args):
        """Make predictions for a user-provided file of peptides (one per line).

        Args:
            args: tuple (method, input_allele, input_length, fname, seq_file_type).
                method must be 'smm' or 'smmpmbec'; input_allele is a
                comma-separated list of IEDB allele labels. input_length and
                seq_file_type are not used.

        Returns:
            list whose first element is the header
            ('allele', 'peptide', 'ic50', 'percentile'), followed by one row per
            valid allele and peptide, sorted by ascending predicted IC50.

        Raises:
            ValueError: if method is not 'smm'/'smmpmbec' or the file has no peptides.
        """
        (method, input_allele, input_length, fname, seq_file_type) = args
        # Fail fast: any other method would otherwise reach an unbound
        # `score_unit`/`scores` further down.
        if method not in ('smm', 'smmpmbec'):
            raise ValueError('Method: %s is not SMM or SMMPMBEC' % method)
        from smm_predictor import single_prediction_smm, single_prediction_smmpmbec
        from mhci_smm_percentile_data import score_distributions
        predict_scores = single_prediction_smm if method == 'smm' else single_prediction_smmpmbec
        score_unit = 'ic50'
        allele_list = input_allele.split(",")
        validator = Allele_Validator()

        peptide_list = read_peptides(fname)
        if not peptide_list:
            raise ValueError('smm prediction requires peptides input')
        # Bucket peptides by their actual length; each length is scored separately.
        peptides_by_length = {}
        for peptide in peptide_list:
            peptides_by_length.setdefault(len(peptide), []).append(peptide)

        combined_table_rows = []
        for allele in allele_list:
            if not validator.validate_alleles(allele, method, tools_group='mhci'):
                continue
            converted_allele = validator.convert_iedblabel_to_methodlabel(allele, method, tools_group='mhci')
            for peptide_length, same_length_peptides in peptides_by_length.items():
                scores = predict_scores(converted_allele, peptide_length, same_length_peptides)
                for peptide, score in zip(same_length_peptides, scores):
                    score = float(score)
                    # NOTE(review): percentiles always use method 'smm' with
                    # mhci_smm_percentile_data, also for smmpmbec -- confirm intended.
                    percentile = get_percentile_for_score(score, converted_allele, peptide, 'smm', score_distributions)
                    combined_table_rows.append((allele, peptide, score, percentile))

        header = ('allele', 'peptide', score_unit, 'percentile')
        combined_table_rows.sort(key=lambda tup: tup[2])
        combined_table_rows.insert(0, header)

        return combined_table_rows

    def is_valid(self, **kwargs):
        """Predictor-specific validation; none is implemented, so always None (no error)."""
        return None


class ComblibPredictor():
    """Predictor for the comblib_sidney2008 method (9-mer peptides only)."""

    def predict(self, args):
        """Score a peptide-per-line file with the Sidney 2008 combinatorial library.

        Args:
            args: tuple (method, input_allele, input_length, fname, seq_file_type).
                The method field is ignored (always 'comblib_sidney2008');
                input_length and seq_file_type are unused.

        Returns:
            list: header ('allele', 'peptide', 'score', 'percentile') followed by
            one row per valid allele and 9-mer, sorted by ascending score.
        """
        _, input_allele, input_length, fname, seq_file_type = args
        from mhci_comblib_predictor import single_prediction_comblib_peptides
        from mhci_comblib_sidney2008_percentile_data import score_distributions
        alleles = input_allele.split(",")
        validator = Allele_Validator()
        method = 'comblib_sidney2008'
        # Only 9-mers can be scored by this method; other lengths are dropped.
        nonamers = [pep for pep in read_peptides(fname) if len(pep) == 9]

        rows = []
        for allele in alleles:
            if not validator.validate_alleles(allele, method, tools_group='mhci'):
                continue
            method_allele = validator.convert_iedblabel_to_methodlabel(allele, method, tools_group='mhci')
            raw_scores = single_prediction_comblib_peptides(method_allele, 9, nonamers)
            for pep, raw in zip(nonamers, raw_scores):
                value = float(raw)
                pct = get_percentile_for_score(value, method_allele, pep, method, score_distributions)
                rows.append((allele, pep, value, pct))

        rows.sort(key=lambda row: row[2])
        return [('allele', 'peptide', 'score', 'percentile')] + rows

    def is_valid(self, **kwargs):
        """Predictor-specific validation; none is implemented, so always None (no error)."""
        return None

class MHCFlurryPredictor():
    """MHC-I binding prediction with MHCflurry."""

    def is_valid(self, **kwargs):
        """Predictor-specific validation hook.

        NOTE(review): the other predictors in this module return None here,
        and MHCIPredictor.is_valid forwards this value as its "error" result;
        confirm that returning True is intended.
        """
        return True

    def predict(self, args):
        """Make predictions for a user-provided file of peptides (one per line).

        Args:
            args: tuple (method, input_allele, input_length, fname, seq_file_type);
                only input_allele (comma-separated) and fname are used.

        Returns:
            Whatever mhcflurry_predict returns, or None when that result is falsy.

        Exits the process with status 1 (after writing to stderr) when a
        peptide contains a character outside the 20 standard amino acids.
        """
        (method, input_allele, input_length, fname, seq_file_type) = args

        peptide_list = read_peptides(fname)
        for peptide in peptide_list:
            for amino_acid in peptide:
                if amino_acid.upper() not in "ACDEFGHIKLMNPQRSTVWY":
                    # Report the offending peptide (previously referenced an
                    # undefined name and raised NameError instead).
                    sys.stderr.write("Sequence: '%s' contains an invalid character: '%c' at position %d.\n" % (peptide, amino_acid, peptide.find(amino_acid)))
                    sys.exit(1)

        # MHCflurry expects allele names without '*' and ':' separators.
        allele_list = [allele.strip().replace('*', '').replace(':', '') for allele in input_allele.split(",")]

        from mhcflurry_predictor import mhcflurry_predict

        result = mhcflurry_predict(peptide_list, allele_list)
        if not result:
            return None
        return result


class ANNPredictor():
    """MHC-I binding prediction with the ANN method (via netmhc_4_0_executable)."""

    def is_valid(self, **kwargs):
        """Predictor-specific validation; none is implemented, so always None (no error)."""
        return None

    def predict(self, args):
        """Score the peptides in *fname* for each valid allele with ANN.

        Args:
            args: tuple (method, input_allele, input_length, fname, seq_file_type);
                method must be 'ann'. input_length and seq_file_type are unused.

        Returns:
            list: header ('allele', 'peptide', 'ic50', 'percentile') followed by
            one row per result, sorted by ascending IC50.

        Raises:
            ValueError: if method is not 'ann'.
        """
        validator = Allele_Validator()
        method, input_allele, input_length, fname, seq_file_type = args

        from mhci_ann_predictor_percentile_data import score_distributions
        from netmhc_4_0_executable import pep_score_predict_from_peptide_file

        alleles = input_allele.split(",")
        if method != 'ann':
            raise ValueError('Method: %s is not ANN' % method)

        rows = []
        for allele in alleles:
            if not validator.validate_alleles(allele, method, tools_group='mhci'):
                continue
            method_allele = validator.convert_iedblabel_to_methodlabel(allele, method, tools_group='mhci')
            # NOTE(review): the length argument is fixed at 9 here even though the
            # file may hold other lengths -- confirm how the helper uses it.
            for peptide, raw in pep_score_predict_from_peptide_file(method_allele, 9, fname):
                value = float(raw)
                pct = get_percentile_for_score(value, method_allele, peptide, method, score_distributions)
                rows.append((allele, peptide, value, pct))

        rows.sort(key=lambda row: row[2])
        return [('allele', 'peptide', 'ic50', 'percentile')] + rows

class MHCIPredictor():

    def __init__(self, method):
        '''Remember *method* and resolve the predictor object that implements it.'''
        self.version = '20130222'
        self.row_data = []
        self.method = method
        self.predictor = self.get_predictor(method)

    def get_predictor(self, method):
        '''Return a predictor instance for *method*, or None if no class handles it.'''
        # 'netmhcpan_ba' is served by the 'netmhcpan' predictor, and
        # 'IEDB_recommended' is routed to the netmhcpan_el predictor.
        normalized = method.replace('netmhcpan_ba', 'netmhcpan')
        normalized = normalized.replace('IEDB_recommended', 'netmhcpan_el')

        if normalized in ('smm', 'smmpmbec'):
            return SMMPredictor()
        if normalized == 'mhcflurry':
            return MHCFlurryPredictor()
        if normalized in ('netmhcpan_ba', 'netmhcpan', 'netmhcpan_el'):
            return NetMHCpanPredictor()
        if normalized == 'ann':
            return ANNPredictor()
        if normalized in ('comblib_sidney2008', 'comblib'):
            return ComblibPredictor()
        return None

    def is_valid(self, input_allele, input_length, fname, seq_file_type='peptide'):
        '''Return a file-level error message, or else the predictor's own validation result.'''
        file_error = self.is_valid_file(fname, seq_file_type)
        if not file_error:
            return self.predictor.is_valid(input_allele=input_allele, input_length=input_length, fname=fname)
        return file_error


    def is_valid_file(self, fname, seq_file_type='peptide'):
        """Return an error message if *fname* is missing or holds invalid peptides, else None.

        For seq_file_type 'peptide', every character of every peptide must be
        one of the 20 standard amino acids (case-insensitive); one message per
        offending character is joined with ', '. Other file types only get the
        existence check.
        """
        if not os.path.exists(os.path.abspath(fname)):
            return "The file {} does not exist!\n".format(fname)

        problems = []
        if seq_file_type == 'peptide':
            allowed = "ACDEFGHIKLMNPQRSTVWY"
            problems = [
                """Invalid character '%c' in sequence '%s'.""" % (residue, peptide)
                for peptide in read_peptides(fname)
                for residue in peptide
                if residue.upper() not in allowed
            ]
        # 'fasta' (and any other type) content is not validated here.

        return ', '.join(problems) if problems else None


    def predict(self, input_allele, input_length, fname, seq_file_type):
        '''Run the prediction, delegating to the method-specific predictor when one exists.'''
        packed = (self.method, input_allele, input_length, fname, seq_file_type)
        if not self.predictor:
            return self.predict_others(packed)
        return self.predictor.predict(packed)


    @staticmethod
    def read_protein(fname):
        '''Read *fname* and wrap its whole content in a PeptideSequenceInput.

        The file is opened in a ``with`` block so the handle is always closed.
        '''
        with open(fname, 'r') as handle:
            file_contents = handle.read()
        return PeptideSequenceInput(file_contents)

    @staticmethod
    def insert_dash(method, actual_methods_used, score_list):
        scores = []
        consensus_methods = ['ann', 'smm', 'comblib_sidney2008']
        m = map(lambda v: v in actual_methods_used, consensus_methods)
        m = list(m)
        dashes = ['-', '-']
        for score in map(list, score_list):
            if not m[0]:
                score.extend(dashes)
            if not m[1]:
                score.extend(dashes)
            if not m[2]:
                score.extend(dashes)
            if method == 'IEDB_recommended':
                score.extend(dashes)
            scores.append(tuple(score))
        return scores


    def predict_others(self, args):
        """Legacy prediction path for methods without a dedicated predictor class.

        Args:
            args: (method, input_allele, input_length, fname[, seq_file_type]).
                ``predict`` passes a 5-tuple; anything after ``fname`` is
                ignored. input_allele and input_length are comma-separated and
                must have the same number of entries.

        Returns:
            list of result rows (without header); the header is printed to stdout.

        Exits the process with status 1 (after writing to stderr) on invalid input.

        NOTE(review): MHCBindingPredictions (used here) and MHCIAlleleData (used
        by check_for_negative_inputs) are not imported in the visible part of
        this module -- confirm they are in scope before relying on this path.
        """
        # predict() builds a 5-tuple; only the first four fields are used here,
        # so slice instead of unpacking exactly four (which always failed).
        method, input_allele, input_length, fname = args[:4]

        alleles = input_allele.split(",")
        lengths = input_length.split(",")

        # check if number of alleles and lengths are same
        if len(alleles) != len(lengths):
            sys.stderr.write("ERROR: Number of alleles and corresponding lengths are not equal!\n")
            sys.exit(1)

        # Missing file or non-amino-acid characters in any peptide.
        file_error = self.is_valid_file(fname)
        if file_error:
            sys.stderr.write("ERROR: %s\n" % file_error.rstrip("\n"))
            sys.exit(1)

        species = [get_species(allele) for allele in alleles]
        negative_inputs = self.check_for_negative_inputs(method, alleles, lengths, species) if not method == "IEDB_recommended" else []
        for allele, length, _ in negative_inputs:
            sys.stderr.write("ERROR: length '{}' for allele '{}' doesn't exist!\n".format(length, allele))
            sys.exit(1)

        peptide_list = read_peptides(fname)

        table_rows = []
        mhc_predictor = None
        for peptides in group_peptides_by_length(peptide_list):
            length = str(len(peptides[0]))
            # comblib_sidney2008 is only run on 9-mers.
            if method == 'comblib_sidney2008' and length != '9':
                continue
            hla_seq = ''
            proteins = PeptideSequenceInput('\n'.join(peptides))

            for allele in alleles:
                use_cutoff = cutoff_value = None
                input_data = InputData(self.version, method, allele, hla_seq, length, proteins, species)
                mhc_predictor = MHCBindingPredictions(input_data)
                mhc_scores = mhc_predictor.predict(input_data.input_protein.as_amino_acid_text())
                logging.debug('mhc_scores:%s', mhc_scores)
                # format_binding appends to self.row_data and returns ALL rows
                # accumulated so far, so the last table_rows holds every row.
                table_rows = self.format_binding(input_data, mhc_scores, method, use_cutoff, cutoff_value)
                logging.debug('table_rows:%s', table_rows)
                table_rows.sort(key=lambda tup: tup[6])
                table_rows = self.add_method_used(table_rows, method)
                logging.debug('table_rows:%s', table_rows)

        if mhc_predictor is None:
            # No peptides (or none of a usable length): there is nothing to
            # report and no predictor to ask for the score unit below.
            sys.stderr.write("ERROR: No peptides of a predictable length were found in the input!\n")
            sys.exit(1)

        combined_table_rows = table_rows
        combined_table_rows.sort(key=lambda tup: tup[2])
        logging.debug('combined_table_rows:%s', combined_table_rows)
        # headers for different methods
        if method == 'IEDB_recommended':
            header = ('allele','seq_num','start','end','length','peptide','method',mhc_predictor.get_score_unit(),'ann_ic50','ann_percentile','smm_ic50','smm_percentile','comblib_sidney2008_score','comblib_sidney2008_percentile','netmhcpan_el_score','netmhcpan_percentile')
            combined_table_rows.sort(key=lambda tup: tup[7])
        elif method == 'consensus':
            header = ('allele','seq_num','start','end','length','peptide',mhc_predictor.get_score_unit(),'ann_ic50','ann_percentile','smm_ic50','smm_percentile','comblib_sidney2008_score','comblib_sidney2008_percentile')
            combined_table_rows.sort(key=lambda tup: tup[6])
            header = header[:1]+header[5:]
            combined_table_rows = [row[:1]+row[5:] for row in combined_table_rows]
        elif method in ['netmhcpan','netmhcpan_el', 'netmhcpan_ba']:
            header = ('allele','seq_num','start','end','length','peptide','core','icore',mhc_predictor.get_score_unit(),'percentile')
            combined_table_rows.sort(key=lambda tup: tup[8])
        else:
            header = ('allele','seq_num','start','end','length','peptide',mhc_predictor.get_score_unit(),'percentile')
            combined_table_rows.sort(key=lambda tup: tup[6])
            header = header[:1]+header[5:]
            combined_table_rows = [row[:1]+row[5:] for row in combined_table_rows]
        print('\t'.join(header))

        # EL scores: higher is better, so present them in descending order.
        if method in [ 'netmhcpan_el' ]:
            combined_table_rows.reverse()
        return combined_table_rows


    def modify(self, lst):
        '''Return every row of *lst* as a tuple with nested tuples flattened one level.'''
        flattened_rows = []
        for row in lst:
            flattened_rows.append(tuple(self.flatten(row)))
        return flattened_rows

    @staticmethod
    def flatten(tup):
        from itertools import chain
        return list(chain(*(i if isinstance(i, tuple) else (i,) for i in tup)))
 
    def format_binding(self, proteins, results, method, cutoff, value):
        """Turn raw prediction results into table rows appended to self.row_data.

        Args:
            proteins: object exposing ``input_protein.as_amino_acid_text()``.
            results: iterable of (length, allele, score, methods) where
                ``methods`` is a comma-separated string of the methods that ran.
            method: requested method; 'consensus' and 'IEDB_recommended' get
                consensus-style score rows, anything else passes ``score`` through.
            cutoff, value: forwarded unchanged to add_rows_binding.

        Returns:
            The whole accumulated self.row_data, flattened via modify(). Note
            that self.row_data is not cleared here, so repeated calls return
            rows from earlier calls too.
        """
        for length, allele, score, methods in results:
            actual_methods_used = methods.split(",")
            if method == 'consensus' or method == 'IEDB_recommended':
                if any(m in actual_methods_used for m in ['ann', 'smm', 'comblib_sidney2008']):
                    logging.debug("any(m in actual_methods_used for m in ['ann', 'smm', 'comblib_sidney2008'])")
                    score_list = []
                    for s in score:
                        ranks_scores = reduce(lambda x, y: x + y, s[1])
                        logging.debug('ranks_scores: %s', ranks_scores)
                        scores = list(zip(s[0], list(zip(*ranks_scores))))
                        logging.debug('s[0]:%s', s[0])
                        logging.debug('scores:%s', scores)
                        scores = self.insert_dash(method, actual_methods_used, self.modify(scores))
                        score_list.append(scores)
                elif 'netmhcpan' in actual_methods_used:
                    score_list = []
                    for per_sequence in score:
                        # Each element unpacks as (core, icore, score, percentile)
                        # for one sequence/allele-length prediction.
                        logging.debug('netmhcpan results: %r', per_sequence)
                        score_list.append([
                            (pct, '-', '-', '-', '-', '-', '-', sc, pct)
                            for core, icore, sc, pct in per_sequence
                        ])
                else:
                    # Unknown method combination: skip it rather than reuse a
                    # stale (or unbound) score_list from a previous entry.
                    logging.warning('actual_methods_used:%s', actual_methods_used)
                    continue
                self.add_rows_binding(allele, length, proteins, score_list, actual_methods_used, cutoff, value)
            else:
                self.add_rows_binding(allele, length, proteins, score, actual_methods_used, cutoff, value)
        return self.modify(self.row_data)

    def add_rows_binding(self, allele, length, proteins, score_list, actual_methods_used, cutoff, value):
        """Append one row per scanned peptide window to self.row_data.

        Each row is (allele, seq_num, start, end, length, peptide, prediction,
        methods) where seq_num/start/end are 1-based strings and ``methods`` is
        actual_methods_used joined with '-'. ``length`` may be an int or a
        numeric string (predict_others passes str(len(...))); it is stored as
        given but converted with int() for slicing. ``cutoff`` and ``value``
        are not used.
        """
        window = int(length)
        methods_label = '-'.join(actual_methods_used)
        sequences = proteins.input_protein.as_amino_acid_text()
        for i, (sequence, predictions) in enumerate(zip(sequences, score_list)):
            sequence_source = "%s" % (i + 1)
            for k, prediction in enumerate(predictions):
                sequence_start = "%s" % (k + 1)
                sequence_end = "%s" % (k + window)
                scanned_sequence = sequence[k:k + window]
                self.row_data.append((allele, sequence_source, sequence_start, sequence_end, length, scanned_sequence, prediction, methods_label))

    @staticmethod
    def cons_netmhcpan(scores):
        score_list = []
        for score in scores:
            lis = list(score)
            del lis[-1]
            item2 = list(score[1])
            item2.append(score[0])
            lis.append(tuple(item2))
            score_list.append(tuple(lis))
        return score_list

    @staticmethod
    def add_method_used(table_rows, method):
        formated_data = []
        for row in table_rows:
            lis = list(row)
            if method == 'IEDB_recommended':
                if '-' not in lis[-1]:
                    lis.insert(6, lis[-1])
                else:
                    lis.insert(6, "Consensus ("+lis[-1].replace("-","/")+")")
                del lis[-1]
                formated_data.append(tuple(lis))
            else: 
                del lis[-1]
                formated_data.append(tuple(lis))
        return formated_data
    
    @staticmethod
    def check_fasta(fasta):
        seq_list = filter(None, [x.strip() for x in fasta.split('>')])
        seq_list = list(seq_list)
        if len(seq_list) > 1:
            print( "File must contain a single MHC sequence in fasta format.")
            sys.exit(1)
        
        for seq in seq_list:
            sequence = seq.split("\n")[1]
            for amino_acid in sequence.strip():
                if not amino_acid.upper() in "ACDEFGHIKLMNPQRSTVWY":
                    print( "Sequence: '%s' contains an invalid character: '%c' at position %d." %(sequence, amino_acid, sequence.find(amino_acid)))
                    sys.exit(1)

    def check_for_negative_inputs(self, _method_name, allele_list, length_list, species_list):
        '''Return (allele, length, species) triples whose length is not allowed for the allele.

        Lengths are converted with int(). If any allele is user-defined the
        check is skipped and [] is returned, because user-defined alleles carry
        no length/species info. Allowed lengths come from
        MHCIAlleleData.get_allowed_peptide_lengths, with 'netmhcpan_el' looked
        up as 'netmhcpan'.
        '''
        int_lengths = map(int, length_list)

        if any([is_user_defined_allele(a) for a in allele_list]):
            return []

        allele_data = MHCIAlleleData()
        rejected = []
        for allele, length, species in zip(allele_list, int_lengths, species_list):
            allowed_lengths = allele_data.get_allowed_peptide_lengths(
                method_name=_method_name.replace('netmhcpan_el', 'netmhcpan'),
                allele_name=allele)
            if length not in allowed_lengths:
                rejected.append((allele, length, species))
        return rejected

    @staticmethod
    def query_yes_no(question, default="yes"):
        """Ask a yes/no question via raw_input() and return their answer.
    
        "question" is a string that is presented to the user.
        "default" is the presumed answer if the user just hits <Enter>.
            It must be "yes" (the default), "no" or None (meaning
            an answer is required of the user).
    
        The "answer" return value is one of "yes" or "no".

        NOTE(review): in the lines visible here, the body only builds the
        ``valid`` answer mapping and the ``prompt`` suffix (raising ValueError
        for an unrecognised ``default``); it never reads input and never
        returns a value, so as written it returns None. If the input loop
        lives further down, disregard; otherwise it looks truncated -- confirm.
        """
        # Accepted answer spellings mapped to their boolean meaning.
        valid = {"yes": True, "y": True,  "ye": True,
                 "no": False, "n": False}
        # Prompt suffix; the capitalised letter marks the default answer.
        if default == None:
            prompt = " [y/n] "
        elif default == "yes":
            prompt = " [Y/n] "
        elif default == "no":
            prompt = " [y/N] "
        else:
            raise ValueError("invalid default answer: '%s'" % default)


