import argparse
import sys

parser = argparse.ArgumentParser(description="Process peptide and sequence data with MHC predictions")

# required arguments: input tables and the output path
parser.add_argument("--peptide-output",
                    required=True,
                    help="Path to the peptide TSV")
parser.add_argument("--sequence-output",
                    required=True,
                    help="Path to the sequence TSV")
parser.add_argument("--phbr-input",
                    required=True,
                    help="Path to the output file to be used as input for PHBR")

# optional arguments: column names in the peptide (prediction) table
parser.add_argument("--seqnum-colname",
                    default='seq #',
                    help="Name of the peptide table column to pull the sequence number from (default: %(default)s)")
parser.add_argument("--peptide-colname",
                    default='peptide',
                    help="Name of the peptide table column to pull the peptide from (default: %(default)s)")
parser.add_argument("--start-colname",
                    default='start',
                    help="Name of the peptide table column to pull the start position from (default: %(default)s)")
parser.add_argument("--end-colname",
                    default='end',
                    help="Name of the peptide table column to pull the end position from (default: %(default)s)")
parser.add_argument("--allele-colname",
                    default='allele',
                    help="Name of the peptide table column to pull the allele from (default: %(default)s)")
parser.add_argument("--rank-colname",
                    default='netmhcpan_el percentile',
                    help="Name of the peptide table column to pull the predicted rank from (default: %(default)s)")

# optional arguments: column names in the sequence table
parser.add_argument("--sequence-seqnum-colname",
                    default='seq #',
                    help="Name of the sequence table column to pull the sequence number from (default: %(default)s)")
parser.add_argument("--sequence-sequence-colname",
                    default='sequence',
                    help="Name of the sequence table column to pull the sequence from (default: %(default)s)")

# optional arguments: where the mutation is. --mutation-position wins if both are given;
# if neither is given, the central position of each sequence is used.
parser.add_argument("--sequence-mutation-position-colname",
                    help="Name of the sequence table column holding each sequence's mutation position, "
                         "either a single position or a comma-separated 'start,end' pair. "
                         "Ignored if --mutation-position is also given.")
parser.add_argument("--mutation-position",
                    help="Mutation position used for every sequence, either a single position or a "
                         "comma-separated 'start,end' pair. Takes precedence over "
                         "--sequence-mutation-position-colname. If neither is given, the central "
                         "position of each sequence is used.")
parser.add_argument("--keep-unmutated",
                    action='store_true',
                    help="Keep sub-peptides that do not overlap the mutation (default: %(default)s)")

# TODO: add validation of input file formats

args = parser.parse_args()

# pandas is imported only after argument parsing -- presumably so that --help and usage
# errors return without paying pandas' import cost.
import pandas as pd

# Unpack the parsed options into module-level names used by the rest of the script.
# paths
peptide_output, sequence_output, phbr_input = (
    args.peptide_output, args.sequence_output, args.phbr_input)
# column names in the peptide (prediction) table
seqnum_colname, peptide_colname, start_colname, end_colname, allele_colname, rank_colname = (
    args.seqnum_colname, args.peptide_colname, args.start_colname,
    args.end_colname, args.allele_colname, args.rank_colname)
# column names in the sequence table
sequence_seqnum_colname, sequence_sequence_colname = (
    args.sequence_seqnum_colname, args.sequence_sequence_colname)
# mutation-position handling
sequence_mutation_position_colname, mutation_position, keep_unmutated = (
    args.sequence_mutation_position_colname, args.mutation_position, args.keep_unmutated)


# read in the predictions and the input sequences
preds = pd.read_csv(peptide_output, sep="\t")
seqs = pd.read_csv(sequence_output, sep="\t")

# Determine each sequence's mutation position as a "start" or "start,end" string:
#   * --mutation-position given: the same value for every sequence (it wins, with a
#     warning, if --sequence-mutation-position-colname is also given)
#   * --sequence-mutation-position-colname given: the per-sequence value from that column
#   * neither given: the middle residue, (len(sequence) + 1) // 2
# Warnings go to stderr so they are not mixed into anything written on stdout.
if mutation_position is not None:
    if sequence_mutation_position_colname is not None:
        print("Warning: both --mutation-position and --sequence-mutation-position-colname were specified.  Using --mutation-position",
              file=sys.stderr)
    seqs['mutation_position'] = mutation_position
elif sequence_mutation_position_colname is not None:
    position_col = seqs[sequence_mutation_position_colname]
    # Keep missing cells missing: astype(str) alone would turn NaN into the text 'nan'.
    seqs['mutation_position'] = position_col.astype(str).where(position_col.notna())
else:
    print("Warning: neither --mutation-position nor --sequence-mutation-position-colname were specified.  Using central position",
          file=sys.stderr)
    seqs['mutation_position'] = seqs[sequence_sequence_colname].apply(lambda x: (len(x) + 1) // 2).astype(str)

# Split "start[,end]" into its parts; more than two parts is an input error.
split = seqs['mutation_position'].str.split(',', expand=True)
if split.shape[1] > 2:
    sys.exit("Error, mutation positions must be a single position or a 'start,end' pair")

# Pad to exactly two columns (0 = start, 1 = end); a missing end becomes NaN. This also
# covers an empty sequence table, where the split may yield no columns at all.
split = split.reindex(columns=[0, 1])

# pd.to_numeric (instead of a direct string -> Int64 cast) also accepts float-formatted
# values such as "12.0", which astype(str) produces for a float-typed column, and keeps
# missing parts as NA; non-integer or non-numeric text still raises.
seqs[['mp_start', 'mp_end']] = split.apply(pd.to_numeric).astype('Int64')

# A single position means the mutation starts and ends at the same residue.
seqs['mp_end'] = seqs['mp_end'].fillna(seqs['mp_start'])

n_without_position = int(seqs['mp_start'].isna().sum())
if n_without_position:
    print(f"Warning: {n_without_position} sequence(s) have no mutation position", file=sys.stderr)


# Attach each prediction row to its source sequence (and that sequence's mutation bounds)
# by sequence number. The left join keeps every prediction; predictions without a
# matching sequence get NaN in the sequence-table columns.
merged_df = pd.merge(preds, seqs,
                     left_on=seqnum_colname,
                     right_on=sequence_seqnum_colname,
                     how='left')

# Fail with a readable message if a configured column is absent (e.g. a misspelled
# --rank-colname, or a name present in both tables that pandas suffixed with _x/_y),
# instead of a bare KeyError later. Start/end columns are only needed for filtering.
required_cols = [sequence_sequence_colname, peptide_colname, allele_colname, rank_colname]
if not keep_unmutated:
    required_cols += [start_colname, end_colname]
missing_cols = [col for col in required_cols if col not in merged_df.columns]
if missing_cols:
    sys.exit(f"Error, column(s) {missing_cols} not found after merging the peptide and "
             f"sequence tables; available columns: {list(merged_df.columns)}")

# confirm that all rows were matched: after the left merge, a missing sequence means the
# prediction's sequence number was not found in the sequence table (or that sequence is blank)
unmatched = merged_df[sequence_sequence_colname].isnull()
if unmatched.any():
    # report at most the first 10 offending sequence numbers to keep the message short
    bad_ids = merged_df.loc[unmatched, seqnum_colname].astype(str).unique()[:10]
    # sys.exit with a string prints it to stderr and exits with status 1
    sys.exit("Error, peptide and sequence inputs are not compatible "
             f"(no sequence found for {seqnum_colname} value(s): {', '.join(bad_ids)})")

# Unless --keep-unmutated was given, drop sub-peptides that do not overlap the mutation.
# The inclusive spans [start, end] and [mp_start, mp_end] overlap exactly when
# start <= mp_end and end >= mp_start. Unlike testing only whether either mutation endpoint
# falls inside the sub-peptide, this also keeps sub-peptides lying entirely within a
# multi-residue mutation span. Assumes mp_start <= mp_end ("start,end" order).
# Rows with a missing mutation position compare as NA and are treated as non-overlapping.
if not keep_unmutated:
    overlaps = ((merged_df[start_colname] <= merged_df['mp_end']) &
                (merged_df[end_colname] >= merged_df['mp_start']))
    merged_df = merged_df[overlaps.fillna(False).astype(bool)]

# Build the PHBR input: map the configured column names onto the fixed output schema,
# keep only those four columns, order the rows, and write a tab-separated file without
# the DataFrame index.
column_map = {
    sequence_sequence_colname: 'peptide',
    peptide_colname: 'sub-peptide',
    allele_colname: 'allele',
    rank_colname: 'rank',
}
phbr_columns = ['peptide', 'sub-peptide', 'allele', 'rank']

output_df = (merged_df
             .rename(columns=column_map)[phbr_columns]
             .sort_values(['peptide', 'allele', 'rank']))

output_df.to_csv(phbr_input, sep="\t", index=False)
