import os
import pickle
import math
# Identifier of this predictor package.
package_name = 'smm_predictor'
# Name of the sub-directory, located next to this module, holding trained models.
data_dir = 'data'
# Model root: the directory containing this module file, joined with ``data_dir``.
data_dir_path = os.path.join(os.path.split(__file__)[0], data_dir)



class SMMMatrix:
    """Load and save SMM scoring matrices in multiple formats and use them to score peptides.

    The model state consists of three attributes:

    * ``length`` -- number of peptide positions the matrix covers.
    * ``mat``    -- dict mapping a one-character residue code to a tuple of
      ``length`` per-position contributions.
    * ``offset`` -- constant added to every peptide's summed contributions.

    A peptide's score is ``10 ** (offset + sum of mat[residue][pos])``,
    reported in the unit returned by :meth:`get_score_unit`.
    """
    # NOTE(review): PredictorError, PercentileScore and get_peptides are used
    # below but are not defined or imported in the visible part of this file --
    # confirm they are provided elsewhere in the module.

    def __init__(self, method_name='smm', path_data=None):
        """Create an empty matrix bound to the model directory of *method_name*.

        :param method_name: sub-directory of *path_data* holding the models
            (e.g. ``'smm'`` or ``'smmpmbec'``).
        :param path_data: root model directory; when omitted, the module-level
            ``data_dir_path`` is used.
        """
        if path_data is None:
            # Looked up at call time rather than frozen into the signature at
            # class-definition time, so re-pointing ``data_dir_path`` works.
            path_data = data_dir_path
        self.mhc = None
        self.offset = None
        self.length = None
        self.mat = {}
        self.path_data = os.path.join(path_data, method_name)

    def initialize(self, mhc, length):
        """Bind this instance to *mhc*/*length* and load the matching pickled model.

        The file path comes from :meth:`get_path_model`; loading overwrites
        ``length``, ``mat`` and ``offset`` with the stored values.
        """
        self.mhc = mhc
        self.length = length
        self.path_model = self.get_path_model(self.path_data, mhc, length)
        self.pickle_load(self.path_model)

    def get_score_unit(self):
        """The unit of prediction scores"""
        return 'IC50 (nM)'

    def get_path_model(self, path_data, mhc, length):
        """Return the pickle file path of the trained model for *mhc* and *length*.

        Used by ARB, SMM to read appropriate files containing trained models.
        In the allele name, ``*`` and spaces become ``-`` and ``:`` is dropped,
        e.g. ``HLA-A*02:01`` with length 9 -> ``<path_data>/HLA-A-0201-9.cpickle``.
        """
        model_name = mhc.replace('*', '-').replace(' ', '-').replace(':', '') + '-' + str(length)
        return os.path.join(path_data, model_name) + '.cpickle'

    def _percentile_scores(self, scores):
        """Look up the 'consensus' percentile scores for *scores* for this allele/length."""
        # NOTE(review): the table key is always 'smm', even for instances built
        # with method_name='smmpmbec' -- confirm this is intended.
        args = ('smm', self.mhc.replace("*", ""), self.length)
        ps = PercentileScore(os.path.dirname(self.path_data), 'consensus', args)
        return ps.get_percentile_score(scores)

    def predict_sequence(self, sequence, pred):
        """Given one protein sequence, break it up into peptides, return their predicted binding scores.

        :param sequence: protein sequence split into peptides by ``get_peptides``.
        :param pred: unused; kept for interface compatibility.
        :returns: a ``zip`` of ``(score, percentile)`` pairs.
        """
        peptide_list = get_peptides(sequence, self.length)
        scores = self.predict_peptide_list(peptide_list)
        return zip(scores, self._percentile_scores(scores))

    def predict_peptides_file(self, fname, method):
        """Score the peptides listed one per line in the text file *fname*.

        :param fname: path of a file with one peptide per line (blank lines skipped).
        :param method: unused; kept for interface compatibility.
        :returns: a ``zip`` of ``(peptide, score, percentile)`` triples.
        """
        peptide_list = self.read_peptides(fname)
        scores = self.predict_peptide_list(peptide_list)
        return zip(peptide_list, scores, self._percentile_scores(scores))

    def predict_peptide_list(self, peptide_list):
        """Score every peptide in *peptide_list* with the loaded matrix.

        Only the first ``self.length`` residues of each peptide are scored, so
        longer peptides are scored on that prefix.

        :returns: tuple of scores ``10 ** (offset + sum of contributions)``,
            in input order.
        :raises PredictorError: if a peptide is shorter than ``self.length`` or
            contains a residue that has no row in the matrix.
        """
        mat = self.mat
        length = self.length
        offset = self.offset
        scores = []
        for peptide in peptide_list:
            if len(peptide) < length:
                raise PredictorError("Peptide '%s' is shorter than the model length %d." % (peptide, length))
            score = offset
            for pos in range(length):
                amino_acid = peptide[pos]
                try:
                    score += mat[amino_acid][pos]
                except KeyError:
                    # Only an unknown residue is reported as an invalid character;
                    # other failures propagate unchanged instead of being masked.
                    raise PredictorError("""Invalid character '%c' in sequence '%s'.""" % (amino_acid, peptide))
            scores.append(math.pow(10, score))
        return tuple(scores)

    def load_text_file(self, infile):
        """Load a matrix from the text-format file object *infile*.

        Expected layout: a header whose second token is the length, 20 rows of
        ``<residue> <value> * length``, then a line holding the offset. The
        instance is updated only after the whole file has been parsed and
        validated, so a malformed file leaves the previous model intact.

        :raises PredictorError: if a row has the wrong number of columns.
        """
        lines = infile.readlines()
        length = int(lines[0].split()[1])
        mat = {}
        for line in lines[1:21]:
            entries = line.split()
            numbers = [float(e) for e in entries[1:]]
            if len(numbers) != length:
                raise PredictorError("Invalid number of columns in SMM matrix: " + str(len(numbers)) + " expected: " + str(length) + ".")
            # First character of the first token: same key as before for
            # well-formed rows, but an indented row no longer maps to ' '.
            mat[entries[0][0]] = tuple(numbers)
        offset = float(lines[21])
        self.length = length
        # Update in place, as the original did, so existing references to
        # ``self.mat`` see the new model.
        self.mat.clear()
        self.mat.update(mat)
        self.offset = offset

    def read_peptides(self, fname):
        """Return the stripped, non-blank lines of the text file *fname*."""
        with open(fname, 'r') as r_file:
            return [row.strip() for row in r_file if row.strip()]

    def pickle_dump(self, file_name):
        """Write ``length``, ``mat`` and ``offset`` to *file_name* as three consecutive pickles."""
        with open(file_name, "wb") as fout:
            pickle.dump(self.length, fout)
            pickle.dump(self.mat, fout)
            pickle.dump(self.offset, fout)

    def pickle_load(self, file_name):
        """Read ``length``, ``mat`` and ``offset`` written by :meth:`pickle_dump`.

        The attributes are assigned only after all three values have been read,
        so a truncated file does not leave a half-loaded model behind.
        """
        # SECURITY: pickle.load can execute arbitrary code; only load model
        # files from a trusted location.
        with open(file_name, "rb") as fin:
            length = pickle.load(fin)
            mat = pickle.load(fin)
            offset = pickle.load(fin)
        self.length = length
        self.mat = mat
        self.offset = offset

def _single_prediction(method_name, allele, length, peptide_list):
    """Load the *method_name* model for (*allele*, *length*) and score *peptide_list*.

    Shared implementation of the ``single_prediction_*`` helpers; returns the
    tuple produced by ``SMMMatrix.predict_peptide_list``.
    """
    predictor = SMMMatrix(method_name=method_name)
    predictor.initialize(allele, length)
    return predictor.predict_peptide_list(peptide_list)


def single_prediction_smm(allele, length, peptide_list):
    """Score *peptide_list* with the 'smm' model for *allele* and peptide *length*."""
    return _single_prediction('smm', allele, length, peptide_list)


def single_prediction_smmpmbec(allele, length, peptide_list):
    """Score *peptide_list* with the 'smmpmbec' model for *allele* and peptide *length*."""
    return _single_prediction('smmpmbec', allele, length, peptide_list)