#! /usr/bin/env python

import sys, os, math

# Base directory of NetMHCcons, read from the NETMHCcons environment variable
# (presumably the installation root -- confirm). A KeyError is raised right
# here if the variable is not set.
NetMHCcons = os.environ["NETMHCcons"]

from optparse import OptionParser

# The path must be extended before "import sequences" below so that the
# module directory under the NetMHCcons tree is searched.
sys.path.append(NetMHCcons+'/bin/python/') ### add local lib to lib path

# NOTE(review): nothing in the visible part of this script references
# "sequences"; confirm the import is still needed before touching it.
import sequences ### use local module sequences

# Command-line interface: option definitions, their defaults, and the choice
# of input stream. The default of each option is repeated at the end of its
# help text.
parser = OptionParser(usage="usage: %prog [options] FILE", version="1.0")

# Input data files.
parser.add_option(
    "-m", "--matfile",
    dest="matfilename", type="string", metavar="STR",
    help="substitution matrix: /home/projects/projects/vaccine/data/matrices/BLOSUM62")
parser.add_option(
    "-a", "--alleles",
    dest="alleles", type="string", metavar="STR",
    help="',' -separated list of alleles to check if FILE is not used")
parser.add_option(
    "-r", "--reffile",
    dest="reffilename", type="string", metavar="STR",
    help="Hit search file: /home/projects/projects/vaccine/data/pseudoseqs/HLA_AB_refNN.pseudoseqs")
parser.add_option(
    "-p", "--psseqfile",
    dest="psseqfilename", type="string", metavar="STR",
    help="Pseudo sequence lookup file: /home/projects/projects/vaccine/data/pseudoseqs/HLA_AB.pseudoseqs")

# Search behaviour.
parser.add_option(
    "-n", "--nearest",
    dest="nnearest", type="int", metavar="NUM",
    help="number of nearest alleles to return: 1")
parser.add_option(
    "-s", "--symdist",
    dest="sdist", action="store_true",
    help="symmetric distance on: False")
parser.add_option(
    "-S", "--self",
    dest="ego", action="store_true",
    help="returning self if exist: False")

parser.set_defaults(
    matfilename='/home/projects/projects/vaccine/data/matrices/BLOSUM62',
    reffilename='/home/projects/projects/vaccine/data/pseudoseqs/HLA_AB_refNN.pseudoseqs',
    psseqfilename='/home/projects/projects/vaccine/data/pseudoseqs/HLA_AB.pseudoseqs',
    nnearest=1,
    sdist=False,
    ego=False,
    alleles="")

opts, args = parser.parse_args()

# Read from FILE when a positional argument is given, otherwise from stdin.
infile = open(args[0]) if args else sys.stdin

def readsubmat(blfilename):
    """Read a whitespace-delimited substitution matrix into a nested dict.

    Lines starting with "#" are comments; blank lines are ignored. The first
    remaining line is the header naming the column residues. Every later
    line is a row label followed by one integer score per column, in header
    order.

    Args:
        blfilename: path of the matrix file.

    Returns:
        A dict of dicts where result[row][column] is the integer score.

    Raises:
        SystemExit: with status 1, after a message on stderr, if the file
            cannot be opened.
        KeyError: if a row label does not appear in the header.
        ValueError: if a score is not an integer.
    """
    try:
        blfile = open(blfilename, 'r')
    except (IOError, OSError):
        sys.stderr.write('Could not open file '+blfilename+'\n')
        # Non-zero status so calling shell scripts can detect the failure.
        sys.exit(1)

    blmat = {}
    order = None  # column residues; set by the header line
    try:
        for line in blfile:
            if line.startswith("#"):
                continue
            fields = line.split()
            if not fields:
                continue  # blank line
            if order is None:
                # Header: create one (initially empty) row per residue.
                order = fields
                for aa in order:
                    blmat[aa] = {}
            else:
                aa1 = fields[0]
                for aa2, score in zip(order, fields[1:]):
                    blmat[aa1][aa2] = int(score)
    finally:
        blfile.close()
    return blmat



def getdist(seq1, seq2, mat, sym):
    """Return a substitution-score based distance between two sequences.

    Positions 0 .. len(seq1)-1 are compared. Over those positions the
    function sums the cross score mat[a][b] and the two self scores
    mat[a][a] and mat[b][b].

    Args:
        seq1: query sequence; its length decides how many positions are used.
        seq2: sequence compared against seq1, position by position.
        mat: nested dict of scores, indexed as mat[residue1][residue2].
        sym: if true, normalise by sqrt(self1 * self2); otherwise by the
            self score of seq1 alone.

    Returns:
        1.0 minus the normalised cross score, as a float.
    """
    cross_score = 0
    self_score1 = 0
    self_score2 = 0
    for idx in range(len(seq1)):
        res1 = seq1[idx:idx+1]
        res2 = seq2[idx:idx+1]
        row1 = mat[res1]
        cross_score += row1[res2]
        self_score1 += row1[res1]
        self_score2 += mat[res2][res2]

    if not sym:
        return 1.0 - (float(cross_score) / float(self_score1))
    return 1.0 - cross_score / math.sqrt(self_score1 * self_score2)


def readrefs(file, alldict):
    """Collect the pseudo sequences of the alleles listed in a reference file.

    The first whitespace-separated field of each line is taken as an allele
    name. Lines whose first field starts with "#" and blank lines are
    skipped. The file is closed before returning, even if a lookup fails.

    Args:
        file: open, iterable file object listing the reference alleles.
        alldict: dict mapping allele name to pseudo sequence.

    Returns:
        A dict mapping each listed allele to its entry in alldict.

    Raises:
        KeyError: if a listed allele is missing from alldict.
    """
    refs = {}
    try:
        for line in file:
            fields = line.split()
            if not fields or fields[0].startswith("#"):
                continue
            refs[fields[0]] = alldict[fields[0]]
    finally:
        file.close()

    return refs

def findbest(dict, self, seq, mat, symmetric, nhits, useself):
    """Return the nhits reference alleles closest to seq.

    The previous implementation tightened its distance cutoff to the worst
    hit collected so far even while fewer than nhits hits had been found.
    With nhits > 1 it could therefore return too few hits, and which ones
    depended on the iteration order of the dict. All candidates are now
    scored first and the nhits smallest are selected.

    Args:
        dict: mapping of reference allele name to pseudo sequence.
        self: name of the query allele.
        seq: pseudo sequence of the query allele.
        mat: substitution matrix passed through to getdist.
        symmetric: passed through to getdist to select the symmetric distance.
        nhits: maximum number of hits to return; values <= 0 yield [].
        useself: if false, the reference entry named like the query is skipped.

    Returns:
        A list of (distance, allele) tuples sorted by increasing distance
        (ties ordered by allele name). Only distances below 2.0 are
        reported, as in the original implementation.
    """
    import heapq  # local import so this fix does not rely on the file's import block

    candidates = []
    for allele in dict:
        if allele == self and not useself:
            continue
        dist = getdist(seq, dict[allele], mat, symmetric)
        if dist < 2.0:
            candidates.append((dist, allele))

    return heapq.nsmallest(nhits, candidates)

def readpsseqs(file):
    """Read an allele-name to pseudo-sequence table.

    Each data line holds the allele name in its first whitespace-separated
    field and the pseudo sequence in its second; further fields are ignored.
    Lines whose first field starts with "#" and blank lines are skipped. The
    file is closed before returning, even if parsing fails.

    Args:
        file: open, iterable file object.

    Returns:
        A dict mapping allele name to pseudo sequence.

    Raises:
        IndexError: if a data line has fewer than two fields.
    """
    seqs = {}
    try:
        for line in file:
            fields = line.split()
            if not fields or fields[0].startswith("#"):
                continue
            seqs[fields[0]] = fields[1]
    finally:
        file.close()
    return seqs





reffilename=opts.reffilename

# Echo the reference file as a "#" comment line at the top of the output.
# sys.stdout.write produces the same bytes as the former print statement and
# is valid syntax under both Python 2 and Python 3.
sys.stdout.write("#\t"+reffilename+"\n")

matfilename = opts.matfilename
allpsseqs   = opts.psseqfilename
submat      = readsubmat(matfilename)

# readpsseqs and readrefs close the files handed to them.
psfile=open(allpsseqs, 'r')
psdict=readpsseqs(psfile)

reffile=open(reffilename, 'r')
refdict=readrefs(reffile, psdict)


def _report_nearest(testhla, warn_missing):
    """Write one "query<TAB>hit<TAB>distance" line per nearest reference allele.

    If testhla is not a known allele name, the string itself is used as the
    pseudo sequence. A notice is written to stderr only when warn_missing is
    true.
    """
    try:
        testseq=psdict[testhla]
    except KeyError:
        if warn_missing:
            sys.stderr.write('# The allele %s does not exist in the file %s\n# Using input %s as a pseudo sequence\n\n' % (testhla, allpsseqs, testhla))
        testseq=testhla
    distlist=findbest(refdict, testhla, testseq, submat, opts.sdist, opts.nnearest, opts.ego)
    for hits in distlist:
        sys.stdout.write("%s\t%s\t%.3f\n" % (testhla, hits[1], hits[0]))


if opts.alleles=="":
    # Queries come from FILE/stdin: first field of each line; blank lines are
    # skipped. Unknown names fall back silently, as before.
    for line in infile:
        fields=line.split()
        if not fields:
            continue
        _report_nearest(fields[0], False)
else:
    # Queries come from --alleles; empty entries (e.g. from a trailing
    # comma) are ignored.
    hlas = opts.alleles.split(',')
    for testhla in hlas:
        if testhla:
            _report_nearest(testhla, True)









