#!/usr/bin/env python
from __future__ import print_function
import os
import sys
import re
import json
import shutil
import tempfile
import logging
from subprocess import Popen, PIPE
logging.basicConfig(level=logging.WARNING, format='%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s', datefmt='%Y-%m-%d:%H:%M:%S',)
import string
import random

from allele_validator import AlleleValidator

# Make the nxg-tools package importable before the nxg_common import below.
# Prefer the directory named by the NXG_TOOLS_PATH environment variable; if it
# is unset or not an existing directory, fall back to ../method/nxg-tools
# relative to this file. If neither exists nothing is appended and the import
# relies on nxg_common already being on sys.path.
NXG_TOOLS_PATH = os.environ.get('NXG_TOOLS_PATH')
if NXG_TOOLS_PATH and os.path.isdir(NXG_TOOLS_PATH):
    logging.debug('load NXG_TOOLS_PATH: %s' % NXG_TOOLS_PATH)
    sys.path.append(NXG_TOOLS_PATH)
elif os.path.isdir(os.path.join(os.path.dirname(__file__), '..', 'method', 'nxg-tools')):
    sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'method', 'nxg-tools'))
# from nxg-tools package
from nxg_common.nxg_common import save_file_from_URI

# if the path to the TC1 executable is in an environment variable, use it
# otherwise, pull it from the path_config.py
TC1_EXECUTABLE_PATH = os.environ.get('TC1_EXECUTABLE_PATH')
if not TC1_EXECUTABLE_PATH:
    from path_config import TC1_EXECUTABLE_PATH

# if we get here and it is still not set, throw an error
if not TC1_EXECUTABLE_PATH:
    print('Path to the TC1 executable must be set.  See the README for details on rerunning the configure script', file=sys.stderr)
    sys.exit(1)

# The TC2 (MHC-II) executable is only used in tc2 mode. A missing setting
# leaves it as None rather than making this module unimportable, so tc1 jobs
# still work when path_config.py does not define TC2_EXECUTABLE_PATH.
TC2_EXECUTABLE_PATH = os.environ.get('TC2_EXECUTABLE_PATH')
if not TC2_EXECUTABLE_PATH:
    try:
        from path_config import TC2_EXECUTABLE_PATH
    except ImportError:
        TC2_EXECUTABLE_PATH = None

def generate_random_str(length):
    """Return `length` distinct characters picked at random from digits and ASCII letters.

    Characters are sampled without replacement, so `length` may not exceed 62.
    """
    alphabet = string.digits + string.ascii_letters
    picked = random.sample(alphabet, length)
    return ''.join(picked)


def eprint(*args, **kwargs):
    """Behave like print(), but write to sys.stderr (looked up at call time)."""
    target = sys.stderr
    print(*args, file=target, **kwargs)

def save_json(result, output_path):
    """Serialize `result` as JSON with a 2-space indent into `output_path`.

    Missing parent directories are created first. Returns the absolute path
    of the written file.
    """
    target = os.path.abspath(output_path)
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(target, 'w') as handle:
        json.dump(result, handle, indent=2)
    return target

def split_peptidea_peptideb_peptides(input_sequence_text):
    """Split comma-separated peptide-pair text into per-column tuples.

    Each non-blank line has the form 'peptideA,peptideB[,allele,...]'. The text
    is upper-cased. Every field is stripped of surrounding whitespace, which also
    removes carriage returns left by Windows line endings. Blank lines are
    skipped, and duplicate lines are dropped, keeping the first occurrence.

    Returns a list with one tuple per column; e.g. the two lines "a,b" and
    "c,d" give [('A', 'C'), ('B', 'D')]. As with zip(), every column is
    truncated to the shortest row. Empty or blank-only input returns [].
    """
    # dict keys keep first-seen order while removing duplicates
    unique_rows = {}
    for line in input_sequence_text.upper().splitlines():
        # a blank line would otherwise become a 1-field row and make zip()
        # truncate every column to a single element
        if not line.strip():
            continue
        row = tuple(field.strip() for field in line.split(','))
        unique_rows.setdefault(row, None)
    return list(zip(*unique_rows))

def split_predictors(predictors, mode='tc1'):
    """Partition predictor configs into (icerfire, mhci, mhcii) lists for `mode`.

    A predictor whose 'method' is 'icerfire' is ICERFire regardless of its
    'tools_group'. Otherwise 'tools_group' selects 'mhci' or 'mhcii', and
    anything else counts as "other".

    - 'tc1': "other" predictors join the MHC-I list. MHC-II predictors are
      discarded with a printed warning. Raises ValueError when neither MHC-I
      nor ICERFire predictors remain.
    - 'tc2': "other" predictors join the MHC-II list. MHC-I and ICERFire
      predictors are discarded with a printed warning. Raises ValueError when
      no MHC-II predictor remains.
    - Any other mode returns the three groups unchanged and drops "other".
    """
    groups = {'icerfire': [], 'mhci': [], 'mhcii': [], 'other': []}
    for predictor in predictors:
        if predictor.get('method', '') == 'icerfire':
            bucket = 'icerfire'
        elif predictor.get('tools_group', '') in ('mhcii', 'mhci'):
            bucket = predictor['tools_group']
        else:
            bucket = 'other'
        groups[bucket].append(predictor)
    icerfire, mhci, mhcii, other = (groups[name] for name in ('icerfire', 'mhci', 'mhcii', 'other'))

    # TODO: check if mhci and mhcii predictors are valid, and if not, raise an error
    if mode == 'tc1':
        mhci = mhci + other
        if not (mhci or icerfire):
            raise ValueError('No mhci predictors found in the input data for tc1 mode')
        if mhcii:
            print('Warning: mhcii predictors are ignored in tc1 mode, only mhci predictors are used')
            mhcii = []
    elif mode == 'tc2':
        mhcii = mhcii + other
        if not mhcii:
            raise ValueError('No mhcii predictors found in the input data for tc2 mode')
        if mhci or icerfire:
            print('Warning: mhci and icerfire predictors are ignored in tc2 mode, only mhcii predictors are used')
            mhci, icerfire = [], []
    return icerfire, mhci, mhcii


def _resolve_allele(allele_validator, allele):
    """Return `allele` if the validator accepts it, else its IEDB label via synonym lookup (falsy when unknown)."""
    if allele_validator.validate_alleles(allele):
        return allele
    return allele_validator.convert_synonym_to_iedblabel(allele)


def get_peptide_list_with_sequencenumber(input_sequence_text, has_icerfire=False, mhci_predictors=None, mhcii_predictors=None):
    """Number the unique peptide pairs of the input and return a cleaned copy of the input.

    Each input line is 'peptideA,peptideB[,allele,...]' (parsed with
    split_peptidea_peptideb_peptides). A line is dropped when:

    - either peptide is outside the length window: 8-14 for MHC-I or ICERFire
      (also the default when no predictor is given), 11-30 when only MHC-II
      predictors are given; or
    - the third column holds alleles (at least one is recognised by
      AlleleValidator) and this line's allele is not recognised.

    Recognised allele synonyms are replaced by their IEDB label. Lines with
    fewer than two columns produce no rows.

    Returns:
        (peptide_table, cleaned_text).
        peptide_table is ``dict(warnings=[], results=[{"result_type": "peptide_table",
        "table_columns": [...], "table_data": [...]}])``. Its table_data rows
        are ``(sequence_number, peptideA, peptideB[, allele])``, one per unique
        (peptideA, peptideB) pair, numbered from 1 in input order. Each row
        uses the first kept line for that pair.
        cleaned_text contains every kept line, comma-joined and
        newline-separated, including lines that repeat a pair with a
        different allele.
    """
    # 8-14 for mhci/icerfire, 11-30 for mhcii only. Defaults to the mhci
    # window when no predictor is given, instead of leaving the bounds unset.
    if has_icerfire or mhci_predictors or not mhcii_predictors:
        min_length, max_length = 8, 14
    else:
        min_length, max_length = 11, 30

    sequence_list = split_peptidea_peptideb_peptides(input_sequence_text)
    # one tuple per unique input line; rows need both peptide columns
    rows = list(zip(*sequence_list)) if len(sequence_list) >= 2 else []

    table_columns = [
        "core.sequence_number",
        "core.peptide-peptideA",
        "core.peptide-peptideB",
    ]
    allele_validator = None
    if len(sequence_list) > 2:
        # check if there's allele within the input_sequence_text
        av = AlleleValidator()
        if any(av.validate_alleles(allele) or av.convert_synonym_to_iedblabel(allele) for allele in sequence_list[2]):
            table_columns.append("core.allele")
            allele_validator = av

    table_data = []
    kept_rows = []
    numbered_pairs = set()
    warned_pairs = set()
    resolved_alleles = {}  # raw allele -> IEDB label (or falsy), avoids repeated lookups
    # Filter each row on its own, so the peptides, the allele and the
    # sequence number always come from the same input line.
    for row in rows:
        peptidea, peptideb = row[0], row[1]
        pair = (peptidea, peptideb)
        if not (min_length <= len(peptidea) <= max_length and min_length <= len(peptideb) <= max_length):
            if pair not in warned_pairs:
                warned_pairs.add(pair)
                print(f"Warning: exclude peptide pair {peptidea} and {peptideb} from prediction which is not in the range of {min_length}-{max_length}")
            continue
        allele = None
        if allele_validator is not None:
            raw_allele = row[2]
            if raw_allele not in resolved_alleles:
                resolved_alleles[raw_allele] = _resolve_allele(allele_validator, raw_allele)
            allele = resolved_alleles[raw_allele]
            if not allele:
                print(f"Warning: exclude allele {raw_allele} from prediction which is not recognized")
                continue
            # rewrite the allele column with its IEDB label
            row = row[:2] + (allele,) + row[3:]
        kept_rows.append(row)
        if pair not in numbered_pairs:
            numbered_pairs.add(pair)
            entry = (len(table_data) + 1, peptidea, peptideb)
            table_data.append(entry if allele is None else entry + (allele,))

    input_sequence_text = '\n'.join(','.join(row) for row in kept_rows)
    results = [dict(result_type="peptide_table", table_columns=table_columns, table_data=table_data)]
    return dict(warnings=[], results=results), input_sequence_text


def _read_input_sequence_text(input_data, split_inputs_dir):
    """Pop the sequence-input keys from `input_data` and return the peptide-pair text.

    Sources are tried in order:
    1. a download URI derived from 'input_sequence_fasta_uri';
    2. inline 'input_sequence_text';
    3. a local 'input_sequence_text_file_path'.

    'peptide_list' is popped and discarded. Raises ValueError when no source
    is given.
    """
    input_sequence_text = input_data.pop('input_sequence_text', '')
    input_data.pop('peptide_list', '')
    input_sequence_fasta_uri = input_data.pop('input_sequence_fasta_uri', '') or ''
    input_sequence_text_uri = input_sequence_fasta_uri.replace('sequence_list_fasta', 'download_sequences')
    if input_sequence_text_uri:
        input_sequence_text_file_path = save_file_from_URI(input_sequence_text_uri, target_dir=split_inputs_dir)
        with open(input_sequence_text_file_path, 'r') as r_file:
            return r_file.read()
    if input_sequence_text:
        return input_sequence_text
    if 'input_sequence_text_file_path' in input_data:
        with open(input_data.pop('input_sequence_text_file_path'), 'r') as r_file:
            return r_file.read()
    raise ValueError('input_sequence_text must be provided')


def _start_split_job(executable_path, input_path, job_dir):
    """Launch `executable_path` in --split mode for one input.json; return the running Popen."""
    cmd = [executable_path, '-j', input_path, '--split', '--split-dir=%s' % os.path.join(job_dir, 'parameter_units'), '--keep-empty-row']
    logging.debug(' '.join(cmd))
    return Popen(cmd, stdout=PIPE)


def _wait_split_job(process):
    """Wait for a split job (draining its stdout), debug-log the output, and return its exit status."""
    stdoutdata, _ = process.communicate()
    logging.debug('Raw output:\n{}'.format(stdoutdata.decode()))
    return process.returncode


def split_sequences(json_filename, parameters_output_dir=None, split_inputs_dir=None, assume_valid=False):
    """Split a T-cell (tc1/tc2) job into job descriptions and save them.

    Steps:
    1. Read the job JSON and separate ICERFire from MHC-I/MHC-II predictors.
    2. Write the numbered peptide table to
       <base>/results/sequence_peptide_index.json.
    3. For MHC-I (or MHC-II without ICERFire), run the TC1/TC2 executable in
       --split mode once for the peptideA column and once for peptideB.
    4. Merge the two job lists and add a final pvc aggregate job.
    5. Optionally append the ICERFire prediction and aggregate jobs.
    6. Save everything to <base>/job_descriptions.json.

    <base> is the parent of parameters_output_dir.

    Args:
        json_filename: path of the input job JSON.
        parameters_output_dir: directory for the split parameters. It is
            removed and recreated if it exists. Defaults to
            job_<random>/splitted_parameters in the parent directory of this
            file's directory.
        split_inputs_dir: directory for generated peptide files. Defaults to
            parameters_output_dir.
        assume_valid: currently unused; kept for interface compatibility.

    Returns:
        Absolute path of the written job_descriptions.json.

    Raises:
        ValueError: for an invalid mode, missing predictors or sequence input,
            or an unset executable path for the selected mode.
        RuntimeError: if a --split subprocess exits with a non-zero status.
    """
    logging.debug(TC1_EXECUTABLE_PATH)
    with open(json_filename, 'r') as r_file:
        input_data = json.load(r_file)
    mode = input_data.get('mode', 'tc1')
    if mode not in ['tc1', 'tc2']:
        raise ValueError('mode must be either tc1 or tc2, got %s' % mode)
    icerfire_predictors, mhci_predictors, mhcii_predictors = split_predictors(input_data['predictors'], mode=mode)
    has_icerfire = bool(icerfire_predictors)
    if has_icerfire:
        # snapshot taken before the sequence keys are popped from input_data below
        icerfire_input_data = input_data.copy()
        icerfire_input_data['predictors'] = icerfire_predictors
    if mhci_predictors or mhcii_predictors:
        input_data = input_data.copy()
        # TODO: now it's either mhci or mhcii, but we can have both in the future
        input_data['predictors'] = mhci_predictors or mhcii_predictors

    # Pick the default before touching the filesystem: os.path.exists(None)
    # raises TypeError. An existing directory is recreated from scratch.
    if not parameters_output_dir:
        parameters_output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'job_%s/splitted_parameters' % generate_random_str(6)))
    elif os.path.exists(parameters_output_dir):
        shutil.rmtree(parameters_output_dir)
    # if not given, use parameters_output_dir for generated sequence files as well
    if not split_inputs_dir:
        split_inputs_dir = parameters_output_dir
    os.makedirs(parameters_output_dir, exist_ok=True)
    os.makedirs(split_inputs_dir, exist_ok=True)

    parameters_output_dir = os.path.abspath(parameters_output_dir)
    result_output_dir = os.path.abspath(os.path.join(parameters_output_dir, os.pardir, 'results'))
    sequence_peptide_index_path = os.path.join(result_output_dir, 'sequence_peptide_index.json')
    final_job_descriptions_json_file = os.path.abspath(os.path.join(parameters_output_dir, os.pardir, 'job_descriptions.json'))
    final_result_path = os.path.join(parameters_output_dir, 'aggregated_result.json')
    pvc_executable_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'run_pvc.py'))

    input_sequence_text = _read_input_sequence_text(input_data, split_inputs_dir)
    # Number the unique peptide pairs, normalise alleles to IEDB labels, and
    # drop out-of-range or unrecognised rows from the text used downstream.
    sequence_peptide_index, input_sequence_text = get_peptide_list_with_sequencenumber(input_sequence_text, has_icerfire, mhci_predictors, mhcii_predictors)
    save_json(sequence_peptide_index, sequence_peptide_index_path)

    final_job_descriptions = []
    # either mhci_predictors or mhcii_predictors for now
    # TODO: change it if we want both mhci and mhcii in one prediction being supported
    if mhci_predictors or (mhcii_predictors and not has_icerfire):
        if mhci_predictors:
            executable_name, executable_path = 'TC1', TC1_EXECUTABLE_PATH
        else:
            executable_name, executable_path = 'TC2', TC2_EXECUTABLE_PATH
        if not executable_path:
            raise ValueError('Path to the %s executable must be set.  See the README for details on rerunning the configure script' % executable_name)
        if not input_sequence_text:
            raise ValueError('input_sequence_text must contain at least one row with peptides')
        peptidea_peptideb_peptides = split_peptidea_peptideb_peptides(input_sequence_text)
        if len(peptidea_peptideb_peptides) < 2:
            raise ValueError('input_sequence_text must contain at least two columns with peptides')
        if len(peptidea_peptideb_peptides) >= 3 and peptidea_peptideb_peptides[2]:
            av = AlleleValidator()
            alleles_from_seq_input = peptidea_peptideb_peptides[2]
            if any(av.validate_alleles(allele) or av.convert_synonym_to_iedblabel(allele) for allele in alleles_from_seq_input):
                # sorted so the generated input.json is reproducible
                input_data['alleles'] = ','.join(sorted(set(alleles_from_seq_input)))

        # one split job per peptide column, each in its own directory
        job_dirs = []
        processes = []
        for side, peptides in zip(('peptidea', 'peptideb'), peptidea_peptideb_peptides[:2]):
            job_dir = os.path.abspath(os.path.join(parameters_output_dir, side))
            os.makedirs(job_dir, exist_ok=True)
            peptides_path = save_peptide_list(peptides, os.path.join(split_inputs_dir, '%s_peptides.txt' % side))
            input_data['peptide_file_path'] = os.path.abspath(peptides_path)
            input_path = save_json(input_data, os.path.join(job_dir, 'input.json'))
            job_dirs.append(job_dir)
            processes.append(_start_split_job(executable_path, input_path, job_dir))
        # Wait for BOTH runs (the list is built fully, not short-circuited)
        # before reading the job_descriptions.json files they produce.
        failed = [process for process in processes if _wait_split_job(process) != 0]
        if failed:
            raise RuntimeError('split command failed: %s' % '; '.join(' '.join(process.args) for process in failed))

        peptidea_dir, peptideb_dir = job_dirs
        final_job_descriptions = merge_jobs_and_add_aggregate(
            read_json_file(os.path.join(peptidea_dir, 'job_descriptions.json')),
            read_json_file(os.path.join(peptideb_dir, 'job_descriptions.json')),
            pvc_executable_path,
            final_result_path,
        )
    # TODO what if there's no peptidea and peptideb job
    if has_icerfire:
        # 'alleles' may be absent when the alleles come only from the sequence text
        alleles = input_data.get('alleles') or ''
        allele_list = alleles if isinstance(alleles, list) else alleles.split(',')
        allele_validation_errors, _ = run_validation_for_alleles(input_sequence_text, allele_list, 'icerfire', tool_group='pvc')
        if allele_validation_errors:
            print(f"Error: {allele_validation_errors}")
        else:
            icerfire_input_data['input_sequence_text'] = input_sequence_text
            final_job_descriptions = add_icerfire_and_add_aggregate(final_job_descriptions, pvc_executable_path, icerfire_input_data, parameters_output_dir, final_job_descriptions_json_file, final_result_path)
    save_json(final_job_descriptions, final_job_descriptions_json_file)

    print(f"job_descriptions_json_file: {final_job_descriptions_json_file}")
    return final_job_descriptions_json_file

def run_validation_for_alleles(input_sequence_text, alleles, method, lengths=None, tool_group='mhci'):
    """Check which alleles the given predictor method supports.

    Args:
        input_sequence_text: peptide-pair text ('peptideA,peptideB[,allele]'
            per line). If its third column has any recognised allele, the
            distinct alleles of that column replace `alleles`.
        alleles: iterable of allele names. Surrounding whitespace is stripped
            and empty or None entries (e.g. from a trailing comma) are ignored.
        method: predictor name passed to AlleleValidator.validate_alleles.
        lengths: accepted for interface compatibility; currently unused.
        tool_group: tools group used for non-ICERFire methods. 'icerfire'
            always uses 'pvc'.

    Returns:
        (errors, warnings) as lists of strings.
        - warnings: one entry per allele the validator rejects.
        - errors: non-empty when no allele is given or every allele is rejected.
    """
    av = AlleleValidator()
    errors = []
    warnings = []
    allele_list = [a.strip() for a in alleles if a and a.strip()]
    peptidea_peptideb_peptides = split_peptidea_peptideb_peptides(input_sequence_text)
    if len(peptidea_peptideb_peptides) >= 3 and peptidea_peptideb_peptides[2]:
        alleles_from_seq_input = peptidea_peptideb_peptides[2]
        if any(av.validate_alleles(allele) or av.convert_synonym_to_iedblabel(allele) for allele in alleles_from_seq_input):
            # sorted for a deterministic order of warnings
            allele_list = sorted(set(alleles_from_seq_input))

    if not allele_list:
        errors.append(f"no alleles provided for selected predictor {method}")
        return errors, warnings

    tools_group = 'pvc' if method == 'icerfire' else tool_group
    validations = av.validate_alleles(allele_list, method, tools_group=tools_group)
    unavailable_alleles = [allele for allele, validation in zip(allele_list, validations) if validation is False]
    logging.debug('unavailable_alleles for %s: %s', method, unavailable_alleles)
    warnings.extend('%s cannot predict for allele %s' % (method, allele) for allele in unavailable_alleles)
    if len(unavailable_alleles) == len(allele_list):
        errors.append(f"all selected alleles are not available for selected predictor {method}")

    return errors, warnings

def add_icerfire_and_add_aggregate(final_job_descriptions, pvc_executable_path, icerfire_input_data, parameters_output_dir, final_job_descriptions_json_file, final_result_path):
    """Append an ICERFire prediction job and a pvc aggregate job to the job list.

    The ICERFire input is saved as <parameters_output_dir>/icerfire_input.json.
    Its result is expected at <parameters_output_dir>/results/icerfire.json.
    The aggregate job depends on the ICERFire job. When the highest existing
    job_id is non-zero, it also depends on that job.

    Mutates and returns `final_job_descriptions`.
    """
    highest_existing_id = get_max_job_id(final_job_descriptions)
    aggregate_dependencies = [highest_existing_id] if highest_existing_id else []
    icerfire_job_id = highest_existing_id + 1

    # TODO: consider where's better to put it
    icerfire_results_dir = os.path.abspath(os.path.join(parameters_output_dir, 'results'))
    icerfire_result_prefix = '%s/icerfire' % icerfire_results_dir
    icerfire_result_path = '%s/icerfire.json' % icerfire_results_dir
    os.makedirs(icerfire_results_dir, exist_ok=True)
    save_json(icerfire_input_data, os.path.join(parameters_output_dir, 'icerfire_input.json'))

    prediction_cmd = '%s -j %s/icerfire_input.json -o %s -f json' % (pvc_executable_path, parameters_output_dir, icerfire_result_prefix)
    final_job_descriptions.append(dict(
        shell_cmd=prediction_cmd,
        job_id=icerfire_job_id,
        job_type="prediction",
        depends_on_job_ids=[],
        expected_outputs=[icerfire_result_path],
    ))
    aggregate_dependencies.append(icerfire_job_id)

    # job_id == -1 means no job is required to run
    if icerfire_job_id <= -1:
        return final_job_descriptions
    aggregate_cmd = '%s --aggregate --job-desc-file=%s --icerfire-result-path=%s --aggregate-output-path=%s' % (pvc_executable_path, final_job_descriptions_json_file, icerfire_result_path, final_result_path)
    final_job_descriptions.append(dict(
        shell_cmd=aggregate_cmd,
        job_id=icerfire_job_id + 1,
        job_type="aggregate",
        depends_on_job_ids=aggregate_dependencies,
        expected_outputs=[final_result_path],
    ))
    return final_job_descriptions

def get_max_job_id(final_job_descriptions):
    """Return the largest 'job_id' among the descriptions, or 0 when there are none (empty or None)."""
    if not final_job_descriptions:
        return 0
    return max(job['job_id'] for job in final_job_descriptions)

def merge_jobs_and_add_aggregate(peptidea_job_descriptions, peptideb_job_descriptions ,pvc_executable_path, final_result_path):
    '''
    Combine the peptidea and peptideb job descriptions into one renumbered
    list, then append a final pvc aggregate job. That job depends on each
    side's aggregate job and writes final_result_path.
    '''
    merged = []
    side_outputs = []
    side_aggregate_ids = []
    for side_jobs in (peptidea_job_descriptions, peptideb_job_descriptions):
        output_path, aggregate_id = add_job_descriptions(side_jobs, merged)
        side_outputs.append(output_path)
        side_aggregate_ids.append(aggregate_id)
    peptidea_output_path, peptideb_output_path = side_outputs

    shell_cmd = " ".join((
        pvc_executable_path,
        "--aggregate",
        f"--peptidea-result-path={peptidea_output_path}",
        f"--peptideb-result-path={peptideb_output_path}",
        f"--aggregate-output-path={final_result_path}",
    ))
    merged.append({
        "shell_cmd": shell_cmd,
        "job_id": max(job['job_id'] for job in merged) + 1,
        "job_type": "aggregate",
        "depends_on_job_ids": side_aggregate_ids,
        "expected_outputs": [final_result_path,],
    })
    return merged

def add_job_descriptions(sub_job_descriptions, final_job_descriptions):
    """Renumber a sub-job list and append it to final_job_descriptions.

    New job_ids continue sequentially after the current maximum in
    final_job_descriptions (starting at 1 when it is empty). Each job's
    'depends_on_job_ids' is remapped to the new ids. A dependency may point
    at a job listed later in sub_job_descriptions. The caller's job dicts
    are copied, not modified.

    Returns:
        (output_path, aggregate_job_id): the first 'expected_outputs' entry
        and the new job_id of the (last) 'aggregate' job.

    Raises:
        ValueError: if sub_job_descriptions contains no aggregate job; in that
            case final_job_descriptions is left unchanged.
    """
    start_id = max((job['job_id'] for job in final_job_descriptions), default=0) + 1
    # Map every old id first, so forward dependencies resolve too.
    job_id_map = {job['job_id']: start_id + offset for offset, job in enumerate(sub_job_descriptions)}
    renumbered = []
    output_path = aggregate_job_id = None
    for offset, job in enumerate(sub_job_descriptions):
        new_job = dict(job)
        new_job['job_id'] = start_id + offset
        new_job['depends_on_job_ids'] = [job_id_map[old_id] for old_id in job.get('depends_on_job_ids', [])]
        if new_job['job_type'] == 'aggregate':
            output_path = new_job['expected_outputs'][0]
            aggregate_job_id = new_job['job_id']
        renumbered.append(new_job)
    if aggregate_job_id is None:
        raise ValueError('sub job descriptions contain no aggregate job')
    final_job_descriptions.extend(renumbered)
    return output_path, aggregate_job_id

def save_peptide_list(peptide_list, file_path):
    """Write the peptides to `file_path`, newline-separated with no trailing newline.

    Returns `file_path` unchanged so callers can use it inline.
    """
    with open(file_path, mode='w') as out:
        out.write('\n'.join(peptide_list))
    return file_path

def read_json_file(file_path):
    """Parse the JSON document stored at `file_path` and return the resulting object."""
    with open(file_path, 'r') as source:
        parsed = json.load(source)
    return parsed

# Executing this module directly is a no-op; it defines no command-line
# behavior. NOTE(review): presumably split_sequences() is invoked by a driver
# that imports this module — confirm.
if __name__ == '__main__':
    pass


