#!/usr/bin/env python

'''
Created on 08.10.2015
@author: Dorjee Gyaltsen
'''

import os, sys
from optparse import OptionParser

# Allele -> comma-separated, 1-based peptide positions to mask when
# mask_choice == 'by_allele'. Keys use the normalised form produced by
# ImmunogenicityPredictor.get_position_to_mask ("*" and ":" removed,
# "H2" spelled "H-2").
immunogenicity_allele_dict = {
    "H-2-Db": "2,5,9",
    "H-2-Dd": "2,3,5",
    "H-2-Kb": "2,3,9",
    "H-2-Kd": "2,5,9",
    "H-2-Kk": "2,8,9",
    "H-2-Ld": "2,5,9",
    "HLA-A0101": "2,3,9",
    "HLA-A0201": "1,2,9",
    "HLA-A0202": "1,2,9",
    "HLA-A0203": "1,2,9",
    "HLA-A0206": "1,2,9",
    "HLA-A0211": "1,2,9",
    "HLA-A0301": "1,2,9",
    "HLA-A1101": "1,2,9",
    "HLA-A2301": "2,7,9",
    "HLA-A2402": "2,7,9",
    "HLA-A2601": "1,2,9",
    "HLA-A2902": "2,7,9",
    "HLA-A3001": "1,3,9",
    "HLA-A3002": "2,7,9",
    "HLA-A3101": "1,2,9",
    "HLA-A3201": "1,2,9",
    "HLA-A3301": "1,2,9",
    "HLA-A6801": "1,2,9",
    "HLA-A6802": "1,2,9",
    "HLA-A6901": "1,2,9",
    "HLA-B0702": "1,2,9",
    "HLA-B0801": "2,5,9",
    "HLA-B1501": "1,2,9",
    "HLA-B1502": "1,2,9",
    "HLA-B1801": "1,2,9",
    "HLA-B2705": "2,3,9",
    "HLA-B3501": "1,2,9",
    "HLA-B3901": "1,2,9",
    "HLA-B4001": "1,2,9",
    "HLA-B4002": "1,2,9",
    "HLA-B4402": "2,3,9",
    "HLA-B4403": "2,3,9",
    "HLA-B4501": "1,2,9",
    "HLA-B4601": "1,2,9",
    "HLA-B5101": "1,2,9",
    "HLA-B5301": "1,2,9",
    "HLA-B5401": "1,2,9",
    "HLA-B5701": "1,2,9",
    "HLA-B5801": "1,2,9",
}

class ImmunogenicityPredictor:
    '''Score peptides with the position-weighted immunogenicity model.

    A peptide's score is the sum, over its unmasked positions, of the position
    weight times the residue's immunoscale value. Prediction methods return a
    list whose first row is the header ``['peptide', 'allele', 'score']``,
    followed by ``[peptide, allele, score]`` rows sorted by descending score.
    '''

    # Residues accepted in input peptides (the 20 standard amino acids). A set
    # is used so each character is checked on its own; a substring test would
    # accept characters whose upper-case form is several letters (e.g. 'ﬆ').
    _VALID_RESIDUES = frozenset("ACDEFGHIKLMNPQRSTVWY")

    # Per-residue immunoscale value, multiplied by the position weight for
    # every unmasked position.
    _IMMUNOSCALE = {"A": 0.127, "C": -0.175, "D": 0.072, "E": 0.325, "F": 0.380,
                    "G": 0.110, "H": 0.105, "I": 0.432, "K": -0.700, "L": -0.036,
                    "M": -0.570, "N": -0.021, "P": -0.036, "Q": -0.376, "R": 0.168,
                    "S": -0.537, "T": 0.126, "V": 0.134, "W": 0.719, "Y": -0.012}

    # Position weights for a 9-mer; see _position_weights for longer peptides.
    # Treated as read-only.
    _IMMUNOWEIGHT = [0.00, 0.00, 0.10, 0.31, 0.30, 0.29, 0.26, 0.18, 0.00]

    def isint(self, x):
        '''Return True if *x* converts to a float with no fractional part.

        Non-numeric values (including ``None``) and non-finite values
        (``inf``/``nan``) return False instead of raising.
        '''
        try:
            value = float(x)
            return value == int(value)
        except (TypeError, ValueError, OverflowError):
            return False

    def get_position_to_mask(self, allele, mask_choice, position_to_mask=None):
        '''Resolve which positions to mask for *allele* under *mask_choice*.

        :param allele: allele name such as ``"HLA-A*02:01"`` or ``"H2-Kb"``;
            for ``'by_allele'`` it is normalised by removing ``*`` and ``:``
            and spelling ``H2`` as ``H-2``.
        :param mask_choice: ``'default'`` or empty -> None (the caller applies
            its default mask); ``'by_allele'`` -> the mask listed for the
            allele in ``immunogenicity_allele_dict``, or None (with a printed
            message) if the allele is not listed; any other value (e.g.
            ``'custom'``) -> *position_to_mask* unchanged.
        :param position_to_mask: custom mask passed through for choices other
            than ``'default'`` and ``'by_allele'``.
        :returns: a mask such as ``"1,2,9"``, or None.
        '''
        if mask_choice == 'default' or not mask_choice:
            return None
        if mask_choice != 'by_allele':
            return position_to_mask

        if allele:
            allele = allele.replace("*", "").replace(":", "").replace('H2', 'H-2')
        else:
            allele = None

        # Reuse the module-level table rather than keeping a second copy here.
        if allele in immunogenicity_allele_dict:
            return immunogenicity_allele_dict[allele]

        # TODO: add warnings for this
        print("Allele {} is not available.".format(allele))
        return None

    # works for multiple alleles input
    def predict(self, **kwargs):
        '''Run predict_single for each allele in a comma-separated list.

        ``input_allele`` may hold several comma-separated alleles; blank
        entries are ignored. All other keyword arguments are forwarded to
        predict_single unchanged. Rows from every allele are merged, then
        sorted by descending score under a single header row.
        '''
        alleles = kwargs.get('input_allele')
        if not alleles:
            return self.predict_single(**kwargs)

        allele_list = [a.strip() for a in alleles.split(',') if a.strip()]
        if not allele_list:
            # Only separators/whitespace were given: treat as "no allele".
            kwargs['input_allele'] = None
            return self.predict_single(**kwargs)

        rows = []
        for allele in allele_list:
            kwargs['input_allele'] = allele
            result = self.predict_single(**kwargs)
            if result:
                # Drop each per-allele header; one header is added below.
                rows.extend(result[1:])
        # Sort by the last column value (score)
        rows.sort(key=lambda row: row[-1], reverse=True)
        return [['peptide', 'allele', 'score']] + rows

    # only for 1 allele input
    def predict_single(self, **kwargs):
        '''Return the prediction result for a single allele.

        Keyword arguments:
            input_allele: allele label copied into every result row; also used
                to look up the mask when ``mask_choice == 'by_allele'``.
            fname: path of a file of whitespace-separated peptides. An invalid
                residue prints a message and exits the process (status 1).
            mask_choice: ``'default'`` (the default), ``'by_allele'`` or
                ``'custom'``; see get_position_to_mask.
            position_to_mask: 1-based positions to mask for ``'custom'``,
                either ``"2,5,9"`` or an iterable of ints.
            method: accepted for interface compatibility; not used.

        With no mask, positions 1, 2 and the C-terminal residue are masked.

        :returns: header row, then ``[peptide, allele, score]`` rows sorted by
            descending score (rounded to 5 decimals).
        :raises ValueError: if the resolved mask contains non-integer entries.
        '''
        allele = kwargs.get('input_allele')
        fname = kwargs.get('fname')
        position_to_mask = kwargs.get('position_to_mask', '')
        # mask_choice would be in ['default', 'by_allele', 'custom']
        mask_choice = kwargs.get('mask_choice', 'default')

        peptides = self._read_peptides(fname)

        # TODO: if position_to_mask == None, should the function use default or skip this prediction (empty result)
        position_to_mask = self.get_position_to_mask(allele, mask_choice, position_to_mask)
        # None means "use the default mask", which depends on peptide length.
        custom_mask = self._parse_mask(position_to_mask)

        result_list = []
        for pep in peptides:
            peptide = pep.upper()
            peplen = len(peptide)
            masked = custom_mask if custom_mask is not None else {0, 1, peplen - 1}
            weights = self._position_weights(peplen)
            score = sum(weights[index] * self._IMMUNOSCALE[residue]
                        for index, residue in enumerate(peptide)
                        if index not in masked)
            result_list.append([peptide, allele, round(score, 5)])

        # Sort by the last column value (score)
        result_list.sort(key=lambda row: row[-1], reverse=True)
        # Column headers for the result
        result_list.insert(0, ['peptide', 'allele', 'score'])
        return result_list

    def _read_peptides(self, fname):
        '''Return the whitespace-separated peptides in *fname*.

        Prints a message and exits with status 1 at the first residue that is
        not one of the 20 standard amino acids (case-insensitive).
        '''
        with open(fname, "r") as handle:
            peptides = handle.read().split()
        for peptide in peptides:
            for position, amino_acid in enumerate(peptide):
                if amino_acid.upper() not in self._VALID_RESIDUES:
                    print("Sequence: '%s' contains an invalid character: '%c' at position %d." % (peptide, amino_acid, position))
                    sys.exit(1)
        return peptides

    def _parse_mask(self, position_to_mask):
        '''Convert 1-based mask positions into a set of 0-based indices.

        Returns None for an empty or ``'default'`` mask so the caller can
        apply the length-dependent default. Accepts a comma-separated string
        (blank entries such as a trailing comma are ignored) or an iterable
        of integers.

        :raises ValueError: if an entry is not an integer or none is given.
        '''
        if position_to_mask == 'default' or not position_to_mask:
            return None
        try:
            if isinstance(position_to_mask, str):
                entries = [e for e in position_to_mask.split(",") if e.strip()]
            else:
                entries = list(position_to_mask)
            indices = {int(entry) - 1 for entry in entries}
        except (TypeError, ValueError) as err:
            raise ValueError("Invalid position_to_mask {!r}: expected 1-based integer "
                             "positions such as '2,5,9'.".format(position_to_mask)) from err
        if not indices:
            raise ValueError("Invalid position_to_mask {!r}: no positions given."
                             .format(position_to_mask))
        return indices

    def _position_weights(self, peplen):
        '''Return per-position weights; beyond 9 residues, 0.30 weights are inserted after position 5.'''
        if peplen > 9:
            return self._IMMUNOWEIGHT[:5] + (peplen - 9) * [0.30] + self._IMMUNOWEIGHT[5:]
        return self._IMMUNOWEIGHT

    def create_csv(self, mask_choice, mask_out, data, tmpdir='./output'):
        '''Write *data* to a new CSV file in *tmpdir* and return its path.

        Two metadata rows (the masking choice and the masked positions)
        precede the rows of *data*; *data* itself is not modified. *tmpdir*
        is created if missing. The file is kept after this method returns.
        '''
        import csv
        import tempfile

        os.makedirs(tmpdir, exist_ok=True)

        # Text mode with newline='' is what the csv module requires on Python 3.
        with tempfile.NamedTemporaryFile(mode='w', newline='', prefix="immunogenicity_",
                                         suffix=".csv", dir=tmpdir, delete=False) as result:
            writer = csv.writer(result, delimiter=',')
            writer.writerow(['masking: ', '{0}'.format(mask_choice)])
            writer.writerow(['masked variables: ', '{0}'.format(mask_out)])
            writer.writerows(data)
            path = result.name
        return path

    def is_valid(self, **kwargs):
        '''Validate prediction arguments; currently accepts everything.'''
        return True
