import argparse
import pandas as pd

def ArgumentParser(args=None):
    """Build the command-line parser for the PHBR script.

    ``args`` is accepted for backward compatibility and is not used; pass the
    argument list to ``parse_args`` on the returned parser instead.

    Returns:
        argparse.ArgumentParser with --mhc-predictions, --output-file,
        --homozygous-loci and the mutually exclusive --mhci / --mhcii flags.
    """
    parser = argparse.ArgumentParser(description="Compute PHBR score for each peptide")

    # required arguments
    parser.add_argument('--mhc-predictions',
                        required=True,
                        help='Path to the MHC predictions (generated from mhc2phbr.py)')
    parser.add_argument('--output-file',
                        required=True,
                        help='Path to the TSV output file')

    # optional arguments
    parser.add_argument('--homozygous-loci',
                        required=False,
                        default='',
                        help='A comma-separated list of loci that are homozygous (e.g. "A,C" for MHC-I or "DRB,DPA" for MHC-II) (default: %(default)s)')

    # Asking for both classes at once is contradictory, so reject it instead of
    # silently picking one; the script falls back to MHC-I when neither is given.
    mhc_class = parser.add_mutually_exclusive_group()
    mhc_class.add_argument('--mhci',
                           action='store_true',
                           help='MHC-I PHBR score')
    mhc_class.add_argument('--mhcii',
                           action='store_true',
                           help='MHC-II PHBR score')

    return parser


# get best rank for each peptide-allele pair (including homozygous loci)
def GetBestRank(mhc_pred_df, mhc='i', homozygous_list=None):
    """Return the best (minimum) rank of every peptide-allele pair, duplicating homozygous loci.

    Args:
        mhc_pred_df: DataFrame with at least 'peptide', 'allele' and 'rank' columns.
        mhc: 'i' for MHC-I; any other value is treated as MHC-II.
        homozygous_list: loci that are homozygous. For MHC-I these are locus
            letters (e.g. 'A'); for MHC-II they are chain names ending in 'A' or
            'B' (e.g. 'DPA', 'DRB'). Defaults to no homozygous loci.

    Returns:
        DataFrame with columns 'peptide', 'allele', 'rank' and 'loci'. Rows of
        homozygous loci appear more than once so that later per-locus counts
        and the harmonic mean weight them accordingly.
    """
    if homozygous_list is None:
        homozygous_list = []

    # best rank for each peptide-allele pair
    best_rank_df = mhc_pred_df.groupby(['peptide', 'allele'])['rank'].min().reset_index()

    # The locus label is read from fixed offsets of the text before '*'.
    # NOTE(review): this assumes a 4-character prefix such as 'HLA-' before the locus name.
    if mhc == 'i':
        best_rank_df['loci'] = best_rank_df['allele'].apply(lambda x: x.split('*')[0][4])   # MHC-I (A, B, C)
        dup_rows = best_rank_df[best_rank_df['loci'].isin(homozygous_list)]
        if not dup_rows.empty:
            best_rank_df = pd.concat([best_rank_df, dup_rows], axis=0, ignore_index=True)
    else:
        best_rank_df['loci'] = best_rank_df['allele'].apply(lambda x: x.split('*')[0][4:6]) # MHC-II (DP, DQ, DR)
        # A chain: DPA, DQA; B chain: DPB, DQB, DRB. The duplication is cumulative:
        # the B pass also copies rows already duplicated by the A pass, so a locus
        # homozygous in both chains ends up with its rows quadrupled.
        for chain in ['A', 'B']:
            homozygous_chain_list = [s[:-1] for s in homozygous_list if s.endswith(chain)]  # 'DPA' -> 'DP'
            dup_rows = best_rank_df[best_rank_df['loci'].isin(homozygous_chain_list)]
            if not dup_rows.empty:
                best_rank_df = pd.concat([best_rank_df, dup_rows], axis=0, ignore_index=True)

    return best_rank_df


# calculate PHBR for each peptide
# fill rank for missing allele with if required
def GetPHBR(best_rank_df, gene_list=['A','B','C'], expected_allele_counts=[2,2,2], fill_missing=False, missing_rank=50):
    loci_counts = best_rank_df.groupby(['peptide', 'loci']).size().to_dict() # get allele counts of loci
    peptides = best_rank_df['peptide'].unique().tolist() # unique peptides

    # compute PHBR
    outputs = list()
    for peptide in peptides: # for each peptide
        d = {'peptide': peptide}
        ranks = best_rank_df[best_rank_df['peptide']==peptide]['rank'].tolist() # rank list
        
        for i,gene in enumerate(gene_list): # for each gene
            allele_count = loci_counts.get((peptide, gene), 0) # allele count for the gene
            d[f'#{gene}'] = allele_count
            
            if allele_count < expected_allele_counts[i]: # allele count < expectation
                print(f"Warning: HLA-{gene}'s allele count of peptide {peptide} is {allele_count} which is fewer than expectation={expected_allele_counts[i]}")
                if fill_missing: ranks.append(missing_rank) # fill rank for missing allele
            
            if allele_count > expected_allele_counts[i]: # allele count > expectation
                print(f"Warning: HLA-{gene}'s allele count of peptide {peptide} is {allele_count} which is more than expectation={expected_allele_counts[i]}")
        
        phbr = harmonic_mean(ranks)
        d['PHBR'] = phbr
        outputs.append(d)
    
    out_df = pd.DataFrame(outputs)

    return out_df


# calculate harmonic mean for a list of numbers
def harmonic_mean(array):
    if (len(array)==0) or (any(x==0 for x in array)):
        raise ValueError('The list must not contain zero and must have at least one element.')
    return len(array) / sum(1/x for x in array)


def Main(mhc_pred_file, output_file, mhc='i', homozygous_list=None, gene_list=None, expected_allele_counts=None, fill_missing=False, missing_rank=50):
    """Read MHC predictions, compute per-peptide PHBR scores and write them as TSV.

    Args:
        mhc_pred_file: path to a tab-separated predictions file with (at least)
            'peptide', 'allele' and 'rank' columns.
        output_file: path of the tab-separated output (written without index).
        mhc: 'i' for MHC-I; any other value is treated as MHC-II by GetBestRank.
        homozygous_list: homozygous loci (default: none).
        gene_list: loci whose allele counts are checked (default ['A', 'B', 'C']).
        expected_allele_counts: expected allele count per gene_list entry
            (default [2, 2, 2]).
        fill_missing: fill missing alleles with ``missing_rank`` before averaging.
        missing_rank: rank used for each missing allele.
    """
    # Resolve defaults here (instead of shared mutable defaults) and always pass
    # concrete lists downstream.
    if homozygous_list is None:
        homozygous_list = []
    if gene_list is None:
        gene_list = ['A', 'B', 'C']
    if expected_allele_counts is None:
        expected_allele_counts = [2, 2, 2]

    # load predictions
    mhc_pred_df = pd.read_csv(mhc_pred_file, sep='\t')

    # get best rank for each peptide-allele pair (including homozygous loci)
    best_rank_df = GetBestRank(mhc_pred_df, mhc=mhc, homozygous_list=homozygous_list)

    # calculate PHBR for each peptide
    phbr_df = GetPHBR(best_rank_df,
                      gene_list=gene_list,
                      expected_allele_counts=expected_allele_counts,
                      fill_missing=fill_missing,
                      missing_rank=missing_rank)

    phbr_df.to_csv(output_file, sep='\t', index=False)


if __name__ == '__main__':
    args = ArgumentParser().parse_args()

    # if neither --mhci nor --mhcii is specified, the default is MHC-I
    if args.mhcii:
        mhc = 'ii'
        gene_list = ['DP', 'DQ', 'DR']
        expected_allele_counts = [4, 4, 2]  # DPA/DPB, DQA/DQB have 4 allele combinations
        known_loci = ['DRB', 'DPA', 'DPB', 'DQA', 'DQB']
    else:
        mhc = 'i'
        gene_list = ['A', 'B', 'C']
        expected_allele_counts = [2, 2, 2]
        known_loci = ['A', 'B', 'C']

    # Tolerate "A, C" / "a,c" and drop empty tokens (the default '' yields none);
    # otherwise ' C' would never match a locus label.
    homozygous_list = [s.strip().upper() for s in args.homozygous_loci.split(',') if s.strip()]
    for locus in homozygous_list:
        if locus not in known_loci:
            print(f"Warning: homozygous locus {locus} is not one of the expected loci ({', '.join(known_loci)}) for MHC-{mhc.upper()}")

    Main(args.mhc_predictions,
         args.output_file,
         mhc=mhc,
         homozygous_list=homozygous_list,
         gene_list=gene_list,
         expected_allele_counts=expected_allele_counts)