#!/usr/bin/env python
# coding: utf-8

# Here we provide a framework to aggregate the individual results using pandas. This illustrates the basic principle of what we can do easily/quickly with pandas.  There is likely a bit more logic that needs to go around this, but it should be much faster than the current approach.

# Pseudocode/algorithm for aggregation:
# 
# * foreach prediction job in job_descriptions:
#   * read expected output files into a dict of results, keyed by method (e.g., smm, netchop, etc.) and result_type (e.g., peptide_table, etc.) - data from multiple jobs of the same method will be combined here
#     * above will require 'method' being added to job_descriptions
#   
# * result_type integration
#   * binding & immunogenicity (peptide table) - merge by peptide & allele
#   * basic processing, netctl, netctlpan (peptide table) - merge by peptide, allele, seq #, start
#   * netchop - merge by seq# & position - since this is currently the only method that outputs a residue table, no merge is really necessary, but we've worked it out below
#   
# * join with core peptide info
# 
# * remove duplicate rows; duplicate rows are removed, but we need to understand where duplicates are coming from in the first place
#   
# * merge warnings & errors
# 
# * fill missing values
# 
# 
# Questions for discussion:
# 
# * Are sequence numbers maintained in the split jobs?  E.g., if there are 10 sequences as input, will they all have the same sequence number in every job?  Assuming yes.
#   * If we update the code such that we also split by sequences, can we still maintain unique sequence numbers with the current approach?  If not, we need to rethink how we're doing this.
#   
# * Why are there so many duplicates?
# 
# * Should we add a 'method' to 'netmhcpan_allele_distance' results?  Allele distances should match between netmhcpan_el & ba, so currently we just put them all into one table
# 
# 

import os
import json
import re
import pandas as pd

# The job description file is read first; it is where the output locations are pulled from.
# Note: a '-mod' suffix on a job descriptions file indicates that the file was modified so that the
# expected output paths match the local system.

def save_json(result, output_path):
    """Write ``result`` to ``output_path`` as JSON indented by two spaces.

    The directory part of ``output_path`` is created first when it is
    non-empty and does not exist yet.
    """
    parent_dir, _ = os.path.split(output_path)
    if parent_dir:
        # exist_ok avoids failing when the directory is already present
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, 'w') as handle:
        json.dump(result, handle, indent=2)

def read_json_file(json_file_path):
    """Load and return the JSON document stored at ``json_file_path``.

    The file is decoded explicitly as UTF-8 (the standard JSON encoding)
    rather than with the platform default, so non-ASCII content is read
    the same way on every system.

    Raises whatever ``open`` or ``json.load`` raise, e.g. ``OSError`` for an
    unreadable path or ``json.JSONDecodeError`` for malformed content.
    """
    with open(json_file_path, 'r', encoding='utf-8') as r_file:
        return json.load(r_file)

# get unique_vals for string type columns, and field_ranges for number type for a table
def get_filter_ranges(table):
    """Summarise each column of ``table`` for use as filter metadata.

    Columns whose dtype compares equal to "object" contribute the list of
    their distinct values, each converted to a string with the text "nan"
    shown as "-". Every other column contributes its minimum and maximum,
    both converted to strings; when either one renders as "nan" a warning is
    recorded for that column instead.

    Returns a tuple ``(unique_vals, field_ranges, warnings)``.
    """
    column_names = list(table)
    unique_vals = {}
    field_ranges = {}
    warnings = []

    # columns are addressed by position so repeated names are still visited one by one
    for position, (name, dtype) in enumerate(zip(column_names, table.dtypes)):
        column = table.iloc[:, position]

        if dtype == "object":
            unique_vals[name] = [re.sub("^nan$", "-", str(value)) for value in column.unique()]
            continue

        lowest = str(column.min())
        highest = str(column.max())
        if lowest == 'nan' or highest == 'nan':
            warnings.append("can not calculate min & max for %s" % name)
            continue

        bounds = field_ranges.setdefault(name, {})
        bounds['min'] = lowest
        bounds['max'] = highest

    return unique_vals, field_ranges, warnings

def _cast_columns_to_float(table, fields):
    """Cast, in place, every column of ``table`` named in ``fields`` that is present to float64."""
    for field in fields:
        if field in table.columns:
            table[field] = table[field].astype('float64')


def _left_join_method_table(base_table, method_table, method, join_cols):
    """Left-join ``method_table`` (columns prefixed with ``method``) onto the 'core.' join columns of ``base_table``."""
    left_join_cols = ['core.' + x for x in join_cols]
    right_join_cols = [method + '.' + x for x in join_cols]
    merged = pd.merge(base_table, method_table.add_prefix(method + '.'),
                      left_on=left_join_cols,
                      right_on=right_join_cols,
                      how='left')
    # the prefixed copies of the join columns duplicate the core ones, so remove them
    return merged.drop(labels=right_join_cols, axis=1)


def _row_sort_key(row):
    """Sort key for a table row that tolerates columns mixing strings (e.g. '-') with other values."""
    # Pairing each value with an "is a string" flag keeps the natural ordering
    # among values of the same kind, while a string and a non-string are ordered
    # by the flag instead of raising TypeError.
    return [(isinstance(value, str), value) for value in row]


def _collect_prediction_results(job_desc):
    """Read the expected outputs of all prediction jobs into ``{result_type: {method: data}}``.

    Returns ``(all_results, warnings, errors)`` where the last two are sets of
    the unique warnings and errors found in the output files.
    """
    all_results = dict()
    warnings = set()
    errors = set()

    for job in job_desc:
        # skip over non-prediction jobs
        if job['job_type'] != 'prediction':
            continue

        for output_path in job['expected_outputs']:
            output = read_json_file(output_path)
            for result in output['results']:
                result_type = result['type']
                if 'method' in result:
                    method = result['method']
                # netmhcpan_allele_distance results currently don't have a method,
                # so we add that here; since netmhcpan_el & netmhcpan_ba will always
                # output the same allele distances, we only ultimately need to keep
                # all unique rows
                elif result_type == 'netmhcpan_allele_distance':
                    method = 'binding.netmhcpan'
                else:
                    raise ValueError('No method information was found for a result')

                by_method = all_results.setdefault(result_type, dict())
                # TODO: deal with non-conforming results here
                if result_type == 'processing_plots':
                    # plots are kept as-is; only the first one seen per method is retained
                    by_method.setdefault(method, result)
                    continue

                df = pd.DataFrame(result['table_data'], columns=result['table_columns'])
                if method in by_method:
                    # results of this type already exist for this method, so append the rows
                    by_method[method] = pd.concat([by_method[method], df])
                else:
                    by_method[method] = df

            # keep all unique warnings & errors
            warnings.update(output.get('warnings', ()))
            errors.update(output.get('errors', ()))

    return all_results, warnings, errors


def _merge_peptide_tables(method_tables, core_peptide_file, has_consensus):
    """Join every per-method peptide table onto the core peptide table read from ``core_peptide_file``."""
    core_json = read_json_file(core_peptide_file)
    core_result = core_json['results'][0]
    final_peptide_table = pd.DataFrame(core_result['table_data'], columns=core_result['table_columns'])
    final_peptide_table = final_peptide_table.add_prefix('core.')

    # binding methods are joined by peptide & allele
    binding_methods = [m for m in method_tables if m.startswith('binding.')]
    for m in binding_methods:
        #TODO: we're dropping duplicates here, but we need to understand why they are occurring in the first place
        peptide_table = method_tables[m].drop_duplicates()
        _cast_columns_to_float(peptide_table, ['ic50', 'score', 'percentile'])
        final_peptide_table = _left_join_method_table(final_peptide_table, peptide_table, m, ['peptide', 'allele'])

    # add the median binding percentile column before the first binding-related column
    #TODO: discuss if we need median column if there's only 1 binding method
    binding_columns = [c for c in final_peptide_table.columns if c.startswith('binding.')]
    # guard: binding tables holding only the join columns contribute no binding column at all
    if binding_columns:
        percentile_columns = [c for c in binding_columns if c.endswith('.percentile')]
        first_binding_column_index = list(final_peptide_table.columns).index(binding_columns[0])
        final_peptide_table.insert(first_binding_column_index,
                                   'binding.median_percentile',
                                   final_peptide_table.filter(items=percentile_columns).median(axis=1).round(2))

        # add the consensus percentile rank column, if the consensus method was selected
        #TODO: read the flag from the job description file when it is not given
        if has_consensus:
            all_consensus_percentile_columns = ['binding.ann.percentile', 'binding.comblib_sidney2008.percentile', 'binding.smm.percentile']
            consensus_percentile_columns = [c for c in all_consensus_percentile_columns if c in final_peptide_table.columns]
            if consensus_percentile_columns:
                median_index = list(final_peptide_table.columns).index('binding.median_percentile')
                final_peptide_table.insert(median_index + 1,
                                           'binding.consensus_percentile',
                                           final_peptide_table.filter(items=consensus_percentile_columns).median(axis=1).round(2))

    # add the immunogenicity data, joined by peptide & allele
    if 'immunogenicity' in method_tables:
        print("found immunogenicity")
        m = 'immunogenicity'
        #TODO: we're dropping duplicates here, but we need to understand why they are occurring in the first place
        peptide_table = method_tables[m].drop_duplicates()
        _cast_columns_to_float(peptide_table, ['score'])
        final_peptide_table = _left_join_method_table(final_peptide_table, peptide_table, m, ['peptide', 'allele'])

    # add the basic processing, netctl & netctlpan data; iterating the fixed list
    # (rather than a set intersection) keeps the output column order stable between runs
    peptide_processing_methods = ['processing.basic_processing', 'processing.netctl', 'processing.netctlpan']
    processing_float_fields = ['tap_prediction_score', 'mhc_prediction', 'cleavage_prediction_score', 'combined_prediction_score', 'percentile_rank', 'predicted_mhc_binding_affinity', 'rescale_binding_affinity', 'c_terminal_cleavage_affinity', 'tap_transport_efficiency', 'predictions_score']
    for m in peptide_processing_methods:
        if m not in method_tables:
            continue
        #TODO: we're dropping duplicates here, but we need to understand why they are occurring in the first place
        peptide_table = method_tables[m].drop_duplicates()
        # sometimes the start & sequence_number columns need to be explicitly cast
        peptide_table = peptide_table.astype({'start': 'int64',
                                              'sequence_number': 'int64'})
        _cast_columns_to_float(peptide_table, processing_float_fields)
        final_peptide_table = _left_join_method_table(final_peptide_table, peptide_table, m,
                                                      ['sequence_number', 'start', 'peptide', 'allele'])

    return final_peptide_table


def _merge_residue_tables(method_tables):
    """Combine the per-method residue tables on sequence number, position and amino acid."""
    join_cols = ['sequence_number', 'position', 'amino_acid']
    core_join_cols = ['core.' + x for x in join_cols]
    final_residue_table = pd.DataFrame()

    # netchop is currently the only method producing a residue table, so the
    # merge branch below is not exercised yet
    residue_processing_methods = ['processing.netchop']
    for m in residue_processing_methods:
        if m not in method_tables:
            continue
        residue_table = method_tables[m]
        _cast_columns_to_float(residue_table, ['prediction_score'])
        # add prefix "method_name" to `prediction_score` and "core." to others
        residue_table = residue_table.set_index(['prediction_score']).add_prefix('core.').reset_index()
        residue_table = residue_table.set_index(core_join_cols).add_prefix(m + '.').reset_index()
        if final_residue_table.empty:
            final_residue_table = residue_table
        else:
            # both sides use the same key names, so merge with `on=` and keep the
            # keys; dropping them afterwards would remove the join columns entirely
            final_residue_table = pd.merge(final_residue_table, residue_table,
                                           on=core_join_cols,
                                           how='left')

    return final_residue_table.drop_duplicates()


def _drop_empty_columns_and_rows(table, warnings, keep_empty_row):
    """Remove columns that hold only missing/'-' values and, unless ``keep_empty_row``, rows whose non-core cells all do."""
    is_empty_column = (table.fillna('-') == '-').all()
    empty_columns = is_empty_column[is_empty_column]
    if not empty_columns.empty:
        warnings.add("The following empty columns were removed from the output: '%s'. In some cases, this may indicate a runtime error. Please check your results carefully and contact help@iedb.org if you need assistance." % ', '.join(empty_columns.keys()))
    table = table.drop(empty_columns.index, axis=1)

    if not keep_empty_row:
        # only the columns that do not start with "core." decide whether a row is empty
        columns_to_check = [col for col in table.columns if not col.startswith("core.")]
        checked = table[columns_to_check]
        # vectorized equivalent of testing every cell for NA or '-'
        empty_row_mask = (checked.isna() | (checked == '-')).all(axis=1)
        table = table[~empty_row_mask].reset_index(drop=True)

    return table


def aggregate_result_file(job_desc_file, output_dir, aggregate_output_dir, has_consensus=None, keep_empty_row=False):
    """Aggregate the outputs of all prediction jobs into one JSON result file.

    The expected output files of every job with ``job_type == 'prediction'``
    are read and grouped by result type and method. Peptide tables are joined
    onto the core peptide table, residue tables are combined, allele distances
    are de-duplicated, and processing plots are passed through. The aggregate
    is written to ``<aggregate_output_dir>/aggregated_result.json`` and that
    path is printed.

    Args:
        job_desc_file: path of the JSON job description file; a list of jobs,
            each with 'job_type' and 'expected_outputs'.
        output_dir: directory containing 'sequence_peptide_index.json', the
            core peptide table (only read when peptide tables were found).
        aggregate_output_dir: directory that receives 'aggregated_result.json'.
        has_consensus: when truthy, a 'binding.consensus_percentile' column is
            added as the row-wise median of the ann / comblib_sidney2008 / smm
            percentile columns that are present.
        keep_empty_row: when False (default), rows of the peptide and residue
            tables whose non-core columns are all missing or '-' are dropped.

    Returns:
        None.

    Raises:
        ValueError: if a result carries no 'method' and is not of type
            'netmhcpan_allele_distance'.
    """
    job_desc = read_json_file(job_desc_file)
    all_results, warnings, errors = _collect_prediction_results(job_desc)

    # all merged result tables go here; they are converted to "split" dicts
    # further down and placed into the 'results' slot of the aggregate
    merged_results = dict()

    if 'peptide_table' in all_results:
        core_peptide_file = os.path.join(output_dir, 'sequence_peptide_index.json')
        merged_results['peptide_table'] = _merge_peptide_tables(all_results['peptide_table'],
                                                                core_peptide_file,
                                                                has_consensus)

    if 'residue_table' in all_results:
        merged_results['residue_table'] = _merge_residue_tables(all_results['residue_table'])

    # binding.netmhcpan is hardcoded as the method for the allele distance table
    if 'netmhcpan_allele_distance' in all_results:
        merged_results['netmhcpan_allele_distance'] = all_results['netmhcpan_allele_distance']['binding.netmhcpan'].drop_duplicates().add_prefix('allele_distances.')

    # now put everything into an aggregated object that can be dumped as JSON
    final_results = list()
    for t, table in merged_results.items():
        fills_missing = t in ['peptide_table', 'residue_table']
        if fills_missing:
            table = _drop_empty_columns_and_rows(table, warnings, keep_empty_row)

        # get unique_vals for string type columns, and field_ranges for number type for each table
        unique_vals, field_ranges, filter_warnings = get_filter_ranges(table)
        warnings.update(filter_warnings)

        # replace na with "-" only after the filter information (e.g. min, max) was computed
        if fills_missing:
            table = table.fillna('-')

        split = table.to_dict(orient='split')
        final_results.append({
            'result_type': t,
            'table_columns': split['columns'],
            # rows can mix numbers with the '-' placeholder, hence the type-aware key
            'table_data': sorted(split['data'], key=_row_sort_key),
            'unique_vals': unique_vals,
            'field_ranges': field_ranges,
        })

    # add the processing plots onto the results
    for plot in all_results.get('processing_plots', {}).values():
        if 'type' in plot:
            plot['result_type'] = plot.pop('type')
        final_results.append(plot)

    aggregated_results = {
        # sorted so that the output does not depend on set iteration order
        'warnings': sorted(warnings, key=str),
        'errors': sorted(errors, key=str),
        'results': final_results
    }

    # save result to output path
    final_result_file_path = os.path.join(aggregate_output_dir, 'aggregated_result.json')
    save_json(aggregated_results, final_result_file_path)
    print('aggregated_result_path:%s' % final_result_file_path)

    return
