import os
import pickle
import math
import re
# Package identifier string for this predictor (not referenced elsewhere in this file).
package_name = 'mhci_comblib_predictor'
# Name of the data directory that sits next to this module file.
data_dir = 'data'
# Path of that data directory, built from the directory containing this module.
data_dir_path = os.path.join(os.path.split(__file__)[0], data_dir)


class CombinatorialLibrary:
    """Score peptides with position-specific scoring matrices (PSSMs) derived
    from combinatorial peptide libraries.

    Call :meth:`initialize` with an allele name and peptide length to select a
    matrix, then score peptides with :meth:`predict_peptide_list`.  The selected
    matrix can be saved to / restored from a pickle file with
    :meth:`pickle_dump` and :meth:`pickle_load`.
    """

    # Supported matrix sources: lib_source -> (sub-directory of path_data,
    # pickle file name, factor multiplied into every matrix element on load).
    # Further sources (e.g. the disabled udaka2000 set) can be added here.
    _LIB_SOURCES = {
        'comblib_sidney2008': ('comblib_sidney2008', 'dic_pssm_sidney2008.cPickle', -1.0),
    }

    # Amino-acid order used for the columns of the flat matrix vectors.
    _AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

    def __init__(self, path_data=data_dir_path, lib_source='comblib_sidney2008'):
        """Create a predictor; no matrix is loaded until :meth:`initialize`.

        :param path_data: directory holding one sub-directory per library
            source (e.g. ``<path_data>/comblib_sidney2008/``).
        :param lib_source: which matrix set to use; must be a key of
            ``_LIB_SOURCES``.
        """
        self.dic_pssm = None  # all matrices of lib_source, set by initialize()
        self.mhc = None       # allele name passed to initialize()
        self.offset = None    # constant term of the selected matrix
        self.length = None    # peptide length of the selected matrix
        self.mat = {}         # amino acid -> list of per-position values
        self.path_data = path_data
        self.lib_source = lib_source

    @staticmethod
    def _pssm_key(mhc, length):
        """Translate an allele name into the (name, length) key used in the PSSM pickle."""
        if re.search('H-2.*', mhc):
            # Mouse alleles: split just after the first digit and replace the
            # remaining hyphens, e.g. 'H-2-Kb' -> 'H-2_Kb'.
            i = re.search(r'(?<=\d)', mhc).start()
            return (mhc[:i] + mhc[i:].replace("-", "_"), length)
        # Other alleles, e.g. 'HLA-A*02:01' -> 'HLA_A-0201'.
        return (mhc.replace('-', '_').replace('*', '-').replace(':', ''), length)

    def initialize(self, mhc, length):
        """Load the matrix for allele ``mhc`` and peptide length ``length``.

        :raises ValueError: if ``self.lib_source`` is not supported.
        :raises KeyError: if the library has no matrix for this allele/length.
        """
        self.dic_pssm = self.read_pssm_comblib(self.lib_source)
        self.mhc = mhc
        self.length = length
        key = self._pssm_key(mhc, length)
        try:
            w = self.dic_pssm[key]
        except KeyError:
            raise KeyError("No %s matrix for allele '%s' and length %s (lookup key %r)."
                           % (self.lib_source, mhc, length, key)) from None
        (self.mat, self.offset) = self.get_dic_mat(w)

    def get_score_unit(self):
        """The unit of prediction scores"""
        return 'Score'

    def predict_sequence(self, sequence, pred):
        """Score every ``self.length``-long peptide of ``sequence``.

        Returns a list of (score, percentile) pairs.  ``pred`` is accepted for
        interface compatibility but is not used.

        NOTE(review): ``get_peptides`` and ``PercentileScore`` are not imported
        in this module; presumably they come from a shared utility module of
        the predictor package. Confirm and import them explicitly.
        """
        peptide_list = get_peptides(sequence, self.length)
        scores = self.predict_peptide_list(peptide_list)

        # get percentile scores
        args = ('comblib_sidney2008', self.mhc.replace("*", ""), self.length)
        ps = PercentileScore(self.path_data, 'consensus', args)
        percentile = ps.get_percentile_score(scores)
        return list(zip(scores, percentile))

    def predict_peptide_list(self, peptide_list):
        """Return a tuple with one score per peptide.

        Each score is ``10 ** (offset + sum of the matrix values)`` over the
        first ``self.length`` residues of the peptide.

        :raises PredictorError: if a residue is not a key of ``self.mat``
            (NOTE(review): ``PredictorError`` is not imported in this module).
        :raises IndexError: if a peptide is shorter than ``self.length``.
        """
        mat = self.mat
        length = self.length
        scores = []
        for peptide in peptide_list:
            score = self.offset
            for pos in range(length):
                amino_acid = peptide[pos]
                try:
                    score += mat[amino_acid][pos]
                except KeyError:
                    # Only unknown residues are reported as bad input; other
                    # failures (e.g. predictor not initialized) propagate as-is.
                    raise PredictorError("""Invalid character '%c' in sequence '%s'.""" % (amino_acid, peptide))
            scores.append(math.pow(10, score))
        return tuple(scores)

    def read_data_cpickle(self, fname):
        """Unpickle and return the object stored in ``fname``.

        Security: pickle can execute arbitrary code while loading, so only
        use this on data files shipped with the package, never on untrusted paths.
        """
        with open(fname, 'rb') as f:
            return pickle.load(f)

    def read_pssm_comblib(self, lib_source):
        """Read all PSSMs of ``lib_source`` and scale every element by its factor.

        Returns a new dict with the same keys as the pickled dict (looked up by
        :meth:`initialize` with ``(allele name, length)`` tuples). Each value is
        a flat list of numbers, every element multiplied by the source's factor.

        :raises ValueError: if ``lib_source`` is not in ``_LIB_SOURCES``.
        """
        try:
            sub_dir, file_name, factor = self._LIB_SOURCES[lib_source]
        except KeyError:
            raise ValueError("Unsupported combinatorial library source '%s'; expected one of: %s."
                             % (lib_source, ', '.join(sorted(self._LIB_SOURCES)))) from None
        dic_pssm = self.read_data_cpickle(os.path.join(self.path_data, sub_dir, file_name))
        return {key: [factor * val for val in w] for key, w in dic_pssm.items()}

    def get_dic_mat(self, w):
        """Split a flat matrix vector into ``(dict_of_rows, offset)``.

        ``w[0]`` is the offset.  The value for the amino acid at index ``a``
        of ``ACDEFGHIKLMNPQRSTVWY``, at position ``p``, is ``w[1 + 20*p + a]``.
        Only the first ``self.length`` positions are read.
        """
        offset = w[0]
        n_aa = len(self._AMINO_ACIDS)
        dic_mat = {
            aa: [w[1 + n_aa * pos + aa_index] for pos in range(self.length)]
            for aa_index, aa in enumerate(self._AMINO_ACIDS)
        }
        return (dic_mat, offset)

    def pickle_dump(self, file_name):
        """Write length, matrix and offset (in that order) as three pickles to ``file_name``."""
        with open(file_name, "wb") as fout:
            pickle.dump(self.length, fout)
            pickle.dump(self.mat, fout)
            pickle.dump(self.offset, fout)

    def pickle_load(self, file_name):
        """Restore length, matrix and offset as written by :meth:`pickle_dump`.

        Security: only load files you trust; unpickling can execute code.
        """
        with open(file_name, "rb") as fin:
            self.length = pickle.load(fin)
            self.mat = pickle.load(fin)
            self.offset = pickle.load(fin)

def single_prediction_comblib_peptides(allele, length, peptide_list):
    """Score ``peptide_list`` against the combinatorial-library matrix of ``allele``.

    ``length`` may be given as any value accepted by ``int()``.  Returns the
    tuple of scores produced by ``CombinatorialLibrary.predict_peptide_list``.
    """
    comblib = CombinatorialLibrary()
    comblib.initialize(allele, int(length))
    return comblib.predict_peptide_list(peptide_list)
