from datetime import datetime as dt
import os, sys
import pandas as pd
import argparse
import tempfile

# Make the project's modules importable regardless of how the script is launched.
# NOTE: '../..' is resolved against the current working directory, not against this file.
module_path = os.path.abspath(os.path.join('../..'))
if module_path not in sys.path:
    sys.path.append(module_path)
# Directory one level above the directory containing this script (with a trailing slash).
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + '/'
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# Hard-coded absolute path of a "src" directory. Deprecated behaviour, kept as-is;
# it should not affect a run of the standalone version.
src_path = '/home/local/tools/src/ICERFIRE-1.0/src/'
# Add the hard-coded "src" directory to the Python module search path
sys.path.append(src_path)
# Directory containing this script; note that this rebinds src_path.
src_path = os.path.dirname(os.path.abspath(__file__))
# Add the script's own directory to the Python module search path
sys.path.append(src_path)
# print(module_path, sys.path)
from train_eval import evaluate_trained_models
from mutation_tools import pipeline_mutation_scores
from utils import str2bool, get_random_id, get_datetime_string, mkdirs, pkl_load

# Value written into the "total_gene_tpm" column by main() wherever no expression value is
# available (missing after the PepX merge, or when the PepX file could not be read).
# NOTE(review): presumably the median total gene TPM of a reference dataset, and "TMP" in the
# name looks like a typo for "TPM" — confirm before renaming.
MEDIAN_TOTAL_TMP = 6.071

def get_rank(pred, hp):
    """Return the percentage of values in ``hp[0]`` that are strictly greater than ``pred``.

    Args:
        pred: The score to rank.
        hp: Indexable container whose element/column ``0`` holds the reference
            scores (main() passes a header-less DataFrame).

    Returns:
        A value between 0 and 100; lower means fewer reference scores exceed ``pred``.
    """
    exceeds_pred = hp[0] > pred
    fraction_above = exceeds_pred.mean()
    return fraction_above * 100


def args_parser():
    """Build the command-line parser for the ICERFIRE runner and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``jobid``, ``infile``, ``pepxpath``,
        ``add_expression``, ``user_expression`` and ``tmpdir``. ``tmpdir`` is always a
        string: when ``-tmp`` is not given, a fresh temporary directory is created.
    """
    # The text must be passed as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, which would put the whole sentence in the usage line.
    parser = argparse.ArgumentParser(
        description='Runs the ICERFIRE model on preprocessed data, assuming data has been processed to return '
                    'NetMHCpan ranks, ICOREs, self-similarity score, and PepX expression scores (Optional).')
    parser.add_argument('-j', '--jobid', dest='jobid', type=str, help='Job ID from the server')
    parser.add_argument('-f', '--infile', dest='infile', type=str, required=True,
                        help='Full path to the file containing the icores/selfsimilarity scores')
    parser.add_argument('-pf', '--pepxpath', dest='pepxpath', type=str, required=False, default=None,
                        help='Full path to the file containing the PepX query of the test file')
    parser.add_argument('-ae', '--add_expression', dest='add_expression', type=str2bool,
                        required=False, default=True,
                        help='Whether to use the model that includes expression as a feature')
    parser.add_argument('-ue', '--user_expression', dest='user_expression', type=str2bool,
                        default=False, help='Whether the user provides their own expression values')
    parser.add_argument('-tmp', '--tmpdir', dest='tmpdir', type=str, required=False, default=None,
                        help='Output directory')
    args = parser.parse_args()
    # Create the fallback directory lazily so that no stray temp directory is left
    # behind when the caller supplies -tmp (or when parsing fails / --help is used).
    if args.tmpdir is None:
        args.tmpdir = tempfile.mkdtemp()
    return args


def _add_expression_values(data, pepxpath):
    """Attach a ``total_gene_tpm`` column to ``data`` using the PepX query file.

    The PepX table is left-merged on ``icore_wt_aligned`` (its ``peptide`` column is
    renamed for that purpose). Rows without a match get ``MEDIAN_TOTAL_TMP`` and are
    flagged in ``TPMFilledWithMedian``. If the file cannot be read, every row gets
    ``MEDIAN_TOTAL_TMP``. Remaining missing numeric values are filled with the
    per-column median.

    Args:
        data: Input DataFrame containing an ``icore_wt_aligned`` column.
        pepxpath: Path to the PepX CSV file.

    Returns:
        The DataFrame with expression values added.
    """
    try:
        pepx = pd.read_csv(pepxpath)
    except Exception as err:
        # Best-effort: fall back to the constant, but say so instead of failing silently.
        print(f'Could not read the PepX file {pepxpath} ({err}); using the default expression value.',
              file=sys.stderr)
        pepx = None

    if pepx is not None:
        data = pd.merge(data, pepx.rename(columns={'peptide': 'icore_wt_aligned'}), how='left',
                        left_on='icore_wt_aligned', right_on='icore_wt_aligned')
        # NOTE(review): if ``data`` already carries its own ``total_gene_tpm`` column the merge
        # will suffix both columns and the lookup below will fail — confirm with callers.
        data["TPMFilledWithMedian"] = data["total_gene_tpm"].isna()
        # Assign back instead of a chained in-place fillna, which is deprecated and has
        # no effect on ``data`` under pandas Copy-on-Write.
        data["total_gene_tpm"] = data["total_gene_tpm"].fillna(MEDIAN_TOTAL_TMP)
    else:
        data["total_gene_tpm"] = MEDIAN_TOTAL_TMP

    return data.fillna(data.median(skipna=True, numeric_only=True))


def main():
    """Run the ICERFIRE model on the input file and write the predictions as CSV.

    Steps: parse the command line, load the input table, the background predictions
    and the pickled models, optionally add expression values, compute the mutation
    scores, predict, compute the %Rank against the background predictions, save
    ``ICERFIRE_predictions.csv`` (and per-fold metrics when available), then remove
    the other ``.csv``/``.txt`` files from the temporary directory.

    Returns:
        Tuple ``(predictions, run_name, jobid)`` where ``predictions`` is the DataFrame
        of predictions (original column names), ``run_name`` is the generated run
        label and ``jobid`` is ``str(args.jobid)`` (``'None'`` when no job ID was given).
    """
    args = vars(args_parser())

    # Run label: datetime, a tag telling whether the expression model was used, and an ID
    run_dt = get_datetime_string()
    run_id = str(args['jobid']) if args['jobid'] is not None else get_random_id(6)
    run_tag = 'AddExpr' if args['add_expression'] else 'NoExpr'
    run_name = f'{run_dt}_{run_tag}_{run_id}'
    # Directory one level above the script
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + '/'
    # Use tmpdir as the output directory when it exists, otherwise fall back to a
    # run-specific sub-directory of a fresh temporary directory
    if os.path.exists(args['tmpdir']):
        outdir = args['tmpdir']
    else:
        outdir = os.path.join(tempfile.mkdtemp(), f'{run_name}/')
        mkdirs(outdir)
    jobid = str(args['jobid'])

    # Load appropriate model and data
    data = pd.read_csv(args['infile'], sep=' ')
    preds_100k = pd.read_csv(f'{parent_dir}data/human_proteome/preds_100k.txt', header=None)

    unpickle = pkl_load(f'{parent_dir}saved_models/ICERFIRE_Expr{args["add_expression"]}.pkl')
    models, kwargs, ics = unpickle['model'], unpickle['kwargs'], unpickle['ics']

    # Check for None / "None" first: os.path.exists(None) raises TypeError, and None is
    # the default value of --pepxpath.
    pepxpath = args['pepxpath']
    has_pepx = pepxpath is not None and pepxpath != "None" and os.path.exists(pepxpath)
    if args['add_expression'] and has_pepx:
        if args['user_expression'] and 'total_gene_tpm' not in data.columns:
            print('\nAdd expression and User-provided expression were selected but no TPM values found in the data.\n' \
                  'Continuing with a model using queried expression values.\n')
        # TODO : DEAL WITH case where PepX is not used and maybe expression is still enabled (and provided)
        data = _add_expression_values(data, pepxpath)

    data = pipeline_mutation_scores(data, 'icore_mut', 'icore_wt_aligned', ics,
                                    threshold=kwargs['threshold'], prefix='icore_')
    data['seq_id'] = [f'seq_{i}' for i in range(1, len(data) + 1)]

    predictions, test_results = evaluate_trained_models(data, models, ics, encoding_kwargs=kwargs, test_mode=True,
                                                        n_jobs=-1)
    # Saving results as CSV table
    predictions.sort_values('Peptide', ascending=True, inplace=True)
    predictions.rename(columns={'mean_pred': 'prediction'}, inplace=True)
    predictions.reset_index(drop=True, inplace=True)
    predictions['%Rank'] = predictions['prediction'].apply(get_rank, hp=preds_100k)

    cols_to_save = ['Peptide', 'wild_type', 'HLA', 'Pep', 'Core', 'icore_start_pos', 'icore_mut', 'icore_wt_aligned', 'EL_rank_mut',
                    'EL_rank_wt_aligned']
    cols_to_save = cols_to_save + kwargs['mut_col'] + ['prediction', '%Rank']
    # Output column names; columns absent from the map keep their name.
    column_rename_map = {
        # wild_type is peptideA as we swapped first two elements wild-type peptide and mutant peptide of input
        'Peptide': 'peptide-peptideB',
        'wild_type': 'peptide-peptideA',
        'HLA': 'allele',
        'Pep': 'peptide_b_core',
        'Core': 'peptide_a_core',
        'icore_mut': 'peptide_b_icore',
        'icore_wt_aligned': 'peptide_a_icore',
        'EL_rank_mut': 'peptide_b_el_rank',
        'EL_rank_wt_aligned': 'peptide_a_el_rank'
    }
    cols_to_save = [column_rename_map.get(k, k) for k in cols_to_save]

    # Build paths with os.path.join so the files land inside outdir even when it has no
    # trailing slash (as is the case for directories returned by tempfile.mkdtemp()).
    predictions_name = 'ICERFIRE_predictions.csv'
    metrics_name = 'ICERFIRE_metrics_per_fold.csv'
    predictions_path = os.path.join(outdir, predictions_name)
    written_outputs = {predictions_name}
    predictions.rename(columns=column_rename_map).to_csv(predictions_path, columns=cols_to_save, index=False)

    print(f'final csv result saved to: {predictions_path}')
    if test_results is not None:
        # Map column i to 'fold_{i+1}' for every fold, including the last one.
        fold_names = {i: f'fold_{i + 1}' for i in range(len(test_results.keys()))}
        pd.DataFrame(test_results).rename(columns=fold_names) \
            .to_csv(os.path.join(outdir, metrics_name), index=False)
        written_outputs.add(metrics_name)

    # Cleaning input/temporary files, keeping the results written by this run.
    # tmpdir may not exist (fallback output directory above), in which case there is
    # nothing to clean.
    if os.path.isdir(args['tmpdir']):
        for f in os.listdir(args['tmpdir']):
            if f.endswith(('.csv', '.txt')) and f not in written_outputs:
                os.remove(os.path.join(args['tmpdir'], f))
    return predictions, run_name, jobid


if __name__ == '__main__':
    # Run the whole pipeline when executed as a script; the results are bound at module level.
    predictions, run_name, jobid = main()
