#!/usr/bin/env python3

# given a stage id (and maybe some other parameters):
# 1: download the stage inputs
# 2: convert to command line inputs
# 3: run the command-line tool with the --split option
# 4: iterate through the split jobs and submit to the cluster
# 5: submit an additional job to post the results to the database upon completion
#    this last job should run regardless of the upstream jobs so it can update
#    the result with an error message should the prediction fail

# This script will be run by a shell script under sbatch.
# It creates a temporary directory for the workflow (under the configured
# scratch directory, if any) unless an existing directory is passed with
# --stage-dir.

import argparse
import os
import sys
import json
import subprocess
import settings as s
from tempfile import mkdtemp
import signal

# appending to the PYTHONPATH so we can import nxg_api_client and nxg_common
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

import nxg_client.nxg_api_client as nxg_api
import nxg_common.nxg_common as nxg_common
import nxg_hpc_jobs.nxg_hpc_common as nxg_hpc

parser = argparse.ArgumentParser()

# TODO: remove all defaults - arguments will be set from the config file,
# unless they are overridden by the command line parameters

# add arguments to the parser
parser.add_argument('--stage-id',
    help='The stage id to submit for execution on the cluster',
    required=True)
parser.add_argument('--base-uri',
    help='The URI for the API',
    required=False)
parser.add_argument('--env-name',
    help='Name of the environment to use and pull settings from for all default values',
    required=False,
    default='prod')
parser.add_argument('--dev',
    help='A flag to change the URI to point to the dev server. This will override the --base-uri parameter.',
    action='store_true',
    default=False,
    required=False)
parser.add_argument('--scratch-dir',
    help='The top level directory to be used for creating temporary directories.  Default is temp directory defined by the shell.',
    required=False)
parser.add_argument('--stage-dir',
    help='The directory in which the job will be run.  If not specified, a temporary directory will be created',
    required=False)
parser.add_argument('--cmdline-path',
    help="Path to the pepmatch command line tool",
    required=False)
parser.add_argument('--token',
    help="API token for POSTing data",
    required=False)

# The following arguments deal with the python environment and may be
# incompatible if multiple are specified
parser.add_argument('--python-path',
    help="Path to the python binary to use for running the command line script",
    default="python",
    required=False)

parser.add_argument('--virtualenv',
    help="Path to the virtualenv to be activated before running the command line script",
    required=False,
    type=str)
parser.add_argument('--modules',
    help="List of modules that need to be loaded before running the command line script",
    required=False,
    type=str)

# the following arguments are for the python environment needed for the
# nxg-tools code
parser.add_argument('--virtualenv-nxgtools',
    help="Path to the virtualenv to be activated to run scripts in the nxg-tools project",
    required=False,
    type=str)
parser.add_argument('--modules-nxgtools',
    help="List of modules that need to be loaded to run scripts in the nxg-tools project",
    required=False,
    type=str)

# Parsed at import time; main() and signal_handler read these module-level args.
args = parser.parse_args()

def signal_handler(sig, frame):
    """Clean up and exit when the script receives a termination signal.

    Registered for SIGINT, SIGTERM and SIGHUP.  Cancels any cluster jobs
    already recorded in the global ``jobs_by_cluster_id``, posts an 'error'
    result for the stage through the global API client ``a``, and exits
    with status 1.

    ``main()`` creates these globals one after another, so a signal can
    arrive before some of them exist.  Each cleanup step is skipped when the
    globals it needs are missing.  A failure in one step is reported but does
    not prevent the remaining steps or the final exit.

    Args:
        sig: the signal number (unused).
        frame: the interrupted stack frame (unused).
    """
    print('Caught a signal. Attempting to exit cleanly')

    # Look the globals up defensively: referencing them directly raises
    # NameError if the signal arrives before main() has assigned them.
    module_globals = globals()
    jobs = module_globals.get('jobs_by_cluster_id')
    client = module_globals.get('a')
    current_stage_id = module_globals.get('stage_id')

    # cancel any jobs that were submitted
    if jobs:
        print("Attempting to cancel jobs")
        try:
            nxg_hpc.cancel_jobs(list(jobs.keys()))
        except Exception as err:  # best effort: still try to post and exit
            print(f"Failed to cancel jobs: {err}")

    # post an error back to the database
    if client is not None and current_stage_id is not None:
        print("Attempting to post an error message back to the database")
        try:
            result_id = client.get_stage_result_id(current_stage_id)
            client.post_results(result_id=result_id,
                                results_object={},
                                result_status='error',
                                warnings=[],
                                errors=[f"The following errors occurred during processing.  If these issues persist, please contact help@iedb.org for assistance and include stage reference number {current_stage_id}",
                                        'Initial job was terminated by an external signal'])
        except Exception as err:  # best effort: the exit below must still happen
            print(f"Failed to post an error message to the database: {err}")
    else:
        print("API client not initialized yet; cannot post an error message")

    print("Exiting with error code 1")
    # sys.exit rather than the site-provided exit(), which is not guaranteed
    # to exist (e.g. under `python -S`).
    sys.exit(1)
    
def _build_split_cmd(cmdline_path, input_json_file, split_dir, env_cmds):
    """Return the shell command that runs the command-line tool in --split mode.

    The environment commands (if any) run first in the same shell, joined
    with '; ', so they apply to the tool invocation.
    """
    tool_cmd = (f"python {cmdline_path} -j {input_json_file} "
                f"--split --split-dir={split_dir} --assume-valid")
    # Only prefix when there is something to run: a command that starts with
    # a bare '; ' is a shell syntax error.
    if env_cmds:
        return '; '.join(list(env_cmds) + [tool_cmd])
    return tool_cmd


def split_inputs(input_json_file, stage_dir, env_cmds):
    """Split the stage inputs into cluster jobs using the command-line tool.

    Runs the tool at ``tool_settings['cmdline_path']`` (a module global set by
    ``main()``) with ``--split``, writing the split inputs under
    ``<stage_dir>/split_inputs``.  The tool's stdout is echoed.

    Args:
        input_json_file: path to the command-line input JSON file.
        stage_dir: the stage working directory.
        env_cmds: shell commands that set up the tool's environment (module
            loads, virtualenv activation, exports); may be empty.

    Returns:
        The path ``<stage_dir>/job_descriptions.json``, where the tool is
        expected to have written the job descriptions.

    Raises:
        subprocess.CalledProcessError: if the split command exits non-zero.
            The failure is reported here rather than later as a missing
            job-descriptions file.
    """
    cmdline_path = tool_settings['cmdline_path']
    split_dir = os.path.join(stage_dir, 'split_inputs')

    # NOTE(review): this runs whichever 'python' is on PATH after the env
    # commands, not settings['python_path'] — confirm this is intended.
    cmd = _build_split_cmd(cmdline_path, input_json_file, split_dir, env_cmds)

    print("executing: " + cmd)
    result = subprocess.run(cmd, stdout=subprocess.PIPE, shell=True)
    print(result.stdout.decode())
    # Check only after echoing, so the tool's output is visible on failure.
    result.check_returncode()

    return os.path.join(stage_dir, 'job_descriptions.json')

def main():
    """Prepare a stage, split it into cluster jobs and submit them.

    Steps:
      1. load settings for the requested environment (applying --dev),
      2. create the stage directory if one was not given,
      3. fetch the stage inputs from the API and determine the tool group,
      4. convert the inputs to command-line form and split them into jobs,
      5. submit every job, post the job ids, and submit a final job that
         posts results back to the API.

    Several values are module globals so that ``signal_handler`` can cancel
    submitted jobs and post an error if the script is interrupted.
    """

    global stage_dir
    global stage_id
    global settings
    global tool_settings

    # A dict of jobs keyed by cluster job id.  It is a global so the signal
    # handler can check it; initialised first so it exists from here on.
    global jobs_by_cluster_id
    jobs_by_cluster_id = {}

    # Our API client; a global so the signal handler can make use of it.
    global a

    # pull values for several arguments
    stage_dir = args.stage_dir
    env_name = args.env_name

    print(f"Loading settings for environemnt: {env_name}")
    settings = s.load_global_settings(env_name, args)

    # override the URI if --dev is passed
    if args.dev:
        print("--dev flag passed - overriding the base URI with the dev URI")
        settings['base_uri'] = 'https://api-nextgen-tools-dev.iedb.org/api/v1'

    stage_id = args.stage_id

    base_uri = settings['base_uri']
    temp_dir = settings['scratch_dir']
    python_path = settings['python_path']
    token = settings['token']

    # TODO: push an error back to the database if this script fails, e.g.
    # when a required tool-specific path (such as the pepmatch proteomes
    # path) is not configured.

    # If no stage directory was given, create a temporary one.  If no scratch
    # dir was configured, temp_dir is presumably None, in which case mkdtemp
    # uses the default temp directory.
    if not stage_dir:
        stage_dir = mkdtemp(dir=temp_dir)
        # World-writable, presumably so cluster jobs running under other
        # users can write here — confirm.
        os.chmod(stage_dir, 0o777)

    print("stage id: " + stage_id)
    print("base uri: " + base_uri)
    print("stage dir: " + stage_dir)

    # create an api client
    a = nxg_api.NXGclient(base_uri=base_uri,
                          token=token)

    # download the stage inputs
    print("fetching stage inputs")
    api_inputs = a.get_stage_inputs(stage_id)

    # NOTE(review): relies on get_stage_inputs having populated
    # a.stages[stage_id] with the stage metadata — confirm.
    print("determining stage type from 'tool_group' parameter")
    if 'tool_group' in a.stages[stage_id]:
        stage_type = a.stages[stage_id]['tool_group']
        print(f"stage type: {stage_type}")
    else:
        print("Stage type could not be determined...exiting")
        sys.exit(1)

    print("Loading tool group settings")
    tool_settings = s.load_tool_group_settings(env_name, stage_type, args)

    # determine the python env for the tools
    print(f"determining python environment for {stage_type}")
    env_cmds = nxg_common.set_python_env(tool_settings['modules'], tool_settings['virtualenv'])

    # Export any additional environment variables defined for the tool.
    # Distinct loop names keep env_name (the environment name) intact.
    if 'env_vars' in tool_settings:
        for var_name, var_value in tool_settings['env_vars'].items():
            env_cmds.append(f"export {var_name}={var_value}")

    # determine the python env for nxg tools
    print("determining python environment for nxg-tools")
    env_cmds_nxgtools = nxg_common.set_python_env(settings['modules_nxgtools'], settings['virtualenv_nxgtools'])

    # convert to command line inputs
    print("converting to command line inputs")
    cmd_input = nxg_api.api2cmd_input(stage_type, api_inputs)
    cmd_input_file = os.path.join(stage_dir, 'cmd_input.json')
    with open(cmd_input_file, 'w', encoding='utf-8') as f:
        json.dump(cmd_input, f, ensure_ascii=False, indent=4)

    # split inputs using command line tool
    print("splitting inputs")
    jobs_desc_file = split_inputs(cmd_input_file, stage_dir, env_cmds)

    # read the jobs description
    print("reading job descriptions")
    jobs_desc = nxg_common.read_json_file(jobs_desc_file)

    # Prefix each job's shell command with the configured python binary and
    # record the flag file (<job_id>.complete) that will indicate completion.
    for j in jobs_desc:
        j['shell_cmd'] = python_path + " " + j['shell_cmd']
        j['flag_file_complete'] = os.path.join(stage_dir, str(j['job_id']) + '.complete')

    # submit to the cluster
    # create the NXGHPCstage object and link it to the api client
    nxg_stage = nxg_hpc.NXGHPCstage(python_path=python_path,
                                    stage_id=stage_id,
                                    stage_dir=stage_dir,
                                    api_client=a)

    # TODO: every job gets the tool-level memory default from the tool
    # settings; ideally this would be pulled from each job description.
    nxg_stage.submit_all_jobs(jobs_desc, env_cmds, mem_mb_default=tool_settings['mem_mb'])
    jobs_by_cluster_id = nxg_common.catalog_list_by(jobs_desc, 'cluster_job_id')

    print(jobs_by_cluster_id)

    # post the job IDs back to the database
    nxg_stage.post_stage_job_ids()

    # Serialize information about the stage so downstream jobs can use it
    # instead of pulling it again from the API.
    stage_info = {'stage_id': stage_id,
                  'result_id': a.get_stage_result_id(stage_id),
                  'base_uri': base_uri,
                  'stage_dir': stage_dir,
                  'stage_jobs': jobs_desc}
    stage_info_file = os.path.join(stage_dir, 'stage_info.json')
    with open(stage_info_file, 'w', encoding='utf-8') as f:
        json.dump(stage_info, f, ensure_ascii=False, indent=4)

    # Submit a final job to the cluster that will push results and clean up
    # the temporary directory (if the job succeeded); if the job failed, it
    # posts any relevant error messages and updates the status.
    nxg_stage.submit_final_job(stage_info_file, env_cmds_nxgtools)


# Route the termination signals we want to clean up after to signal_handler
# (registered in this order: SIGINT, SIGTERM, SIGHUP), then run the workflow.
# TODO: determine if there are other signals that can/should be caught
for _sig in (signal.SIGINT, signal.SIGTERM, signal.SIGHUP):
    signal.signal(_sig, signal_handler)

main()