#!/usr/bin/env python

# Functions common to many of the HPC job scripts

import subprocess
import time
import os
import sys
from requests import HTTPError, Timeout

# Append the parent directory of this script to sys.path so that the
# settings, nxg_common and nxg_client modules can be imported below.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

import settings
import nxg_common.nxg_common as nxg_common
import nxg_client.nxg_api_client as nxg_client

def submit_job(sbatch_cmd, max_retries=5, retry_delay=5):
    """Submit an sbatch command to the cluster and return the job ID.

    The command is run through the shell.  The job ID is taken to be the
    last whitespace-separated token of sbatch's stdout and must parse as an
    integer.  On a nonzero return code, empty output, or a non-integer
    token, the submission is retried.

    Args:
        sbatch_cmd: Full sbatch command line as a single string.  It is
            executed with ``shell=True``, so the caller is responsible for
            quoting; never pass untrusted text.
        max_retries: Maximum number of submission attempts.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        The job ID as a string, or None if every attempt failed.
    """

    for attempt in range(1, max_retries + 1):
        # keep a handle on the result so the error report below never
        # touches an unbound name when subprocess.run itself raises
        result = None
        try:
            print(sbatch_cmd)
            result = subprocess.run(sbatch_cmd,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                    shell=True)
            stdout = result.stdout.decode()
            print(f'response: {stdout}')

            # this will fail if a nonzero return code is encountered
            if result.returncode != 0:
                raise Exception(f'Nonzero return code received from sbatch: {result.returncode}')

            # this will fail (IndexError) if the stdout is empty
            candidate = stdout.split()[-1]

            # this will fail (ValueError) if the job ID is not an integer;
            # the ID is only accepted after it has passed this check
            int(candidate)

            print(f'Job submitted successfully on attempt {attempt}')
            return candidate

        # any failure of an attempt is reported and retried
        except Exception as e:
            print(f'Exception submitting job on attempt {attempt}')
            print(e)
            if result is not None:
                print(f'response code: {result.returncode}')
                print(f'response (stdout): {result.stdout.decode()}')
                print(f'response (stderr): {result.stderr.decode()}')

            # no point in sleeping after the last attempt
            if attempt < max_retries:
                print(f'sleeping {retry_delay} seconds before retry')
                time.sleep(retry_delay)

    return None

def cancel_jobs(job_id_list=None):
    """Cancel each of the given cluster job IDs with scancel.

    Args:
        job_id_list: Iterable of cluster job IDs.  None or an empty
            iterable is a no-op.  Entries that are None (jobs that never
            received a cluster ID) are skipped.

    The stdout of each scancel call is printed; nothing is returned.
    """

    for j in job_id_list or []:
        # a job whose submission failed has no cluster ID to cancel
        if j is None:
            continue
        scancel_cmd = f"scancel {j}"
        print(f"Attempting to cancel job {str(j)} with: {scancel_cmd}")
        result = subprocess.run(scancel_cmd, stdout=subprocess.PIPE, shell=True)
        print("response: " + result.stdout.decode())


class NXGHPCstage:
    """One stage of work submitted to the cluster as a set of sbatch jobs.

    All jobs are submitted with a hold; submit_final_job adds a closing job
    that depends on every other job and then releases the hold on all of
    them.
    """

    def __init__(self,
        python_path='python',
        stage_dir=None,
        stage_id=None,
        api_client=None):
        """Create a new NXGHPCstage object

        Args:
            python_path: Python executable used to run finish_run.py in the
                final job.
            stage_dir: Directory that receives job output files and the
                completion flag file.
            stage_id: Identifier of the stage; used in job names and in the
                sbatch comment field.
            api_client: Client used to post job IDs.  When None, a new
                nxg_client.NXGclient() is created for this instance (the
                default is built here rather than in the signature so that
                no client is constructed at import time or shared between
                instances).
        """
        self.python_path = python_path
        self.stage_dir = stage_dir
        self.stage_id = stage_id
        if api_client is None:
            api_client = nxg_client.NXGclient()
        self.api_client = api_client
        # we will not know the jobs description at init time, but will add this information
        # later, upon submission (a list of job dicts, see submit_all_jobs)
        self.jobs_data = []

    @staticmethod
    def _wrap_and_hold(sbatch_cmd, job_cmds):
        """Append the '; '-joined job_cmds as --wrap and add the hold flag (-H)."""
        wrapped_cmd = "; ".join(job_cmds)
        return sbatch_cmd + ' --wrap "' + wrapped_cmd + '"' + ' -H'

    def submit_all_jobs(self, jobs_data, env_cmds, mem_mb_default=500, env_name='prod'):
        """Given the job description data, submit each job to the cluster.

        Args:
            jobs_data: List of job dicts.  Keys read here: 'job_id',
                'job_type', 'depends_on_job_ids', 'shell_cmd',
                'flag_file_complete' and, optionally, 'mem_mb'.
            env_cmds: List of shell commands run before each job's own
                command.
            mem_mb_default: Memory request in MB for jobs that do not set
                'mem_mb'.  Jobs of type 'aggregate' without 'mem_mb' get
                5000 instead.
            env_name: Not used by this method.

        Returns:
            The same jobs_data list, with 'cluster_job_id' set on every job
            (modified in place).  If a submission fails, all jobs submitted
            so far are canceled, no further jobs are submitted, and the
            failed and remaining jobs carry a 'cluster_job_id' of None.
        """

        self.jobs_data = jobs_data

        # catalog the jobs by job ID
        jobs_by_id = nxg_common.catalog_list_by(jobs_data, 'job_id')

        for j in jobs_data:

            # check if a specific memory requirement is requested for the job
            # otherwise, we use the default
            if 'mem_mb' in j:
                mem_mb = j['mem_mb']
            elif j['job_type'] == 'aggregate':
                # up the memory requirement to 5GB for aggregation
                mem_mb = 5000
            else:
                mem_mb = mem_mb_default

            # let's build up the command starting with sbatch
            sbatch_cmd = 'sbatch'

            sbatch_cmd += ' -o' + os.path.join(self.stage_dir, str(j['job_id']) + '.out')

            # create a name for the job based on the first 8 characters of the stage ID,
            # the job serial number within the stage, and the job type
            # {STAGE_ID_SHORT}_{SERIAL}_{JOB_TYPE}
            job_name = '_'.join([self.stage_id[:8], str(j['job_id']), j['job_type']])

            sbatch_cmd += f' --job-name {job_name}'

            # add the resource request
            #TODO: make this dynamic
            sbatch_cmd += f" --ntasks=1 --cpus-per-task=1 --mem {mem_mb}mb"

            # get the dependency string
            # The way this is implemented assumes that any jobs that
            # the current job is dependent upon have already been submitted
            dep_cluster_ids = [
                jobs_data[jobs_by_id[dep_job_id]]['cluster_job_id']
                for dep_job_id in j['depends_on_job_ids']
            ]
            dep_string = ':'.join(dep_cluster_ids)

            if dep_string:
                sbatch_cmd += ' --kill-on-invalid-dep=yes -d afterok:' + dep_string

            # now let's work on the commands to wrap
            # start with ensuring the job fails after any failed command
            job_cmds = ['set -Eeuo pipefail']

            job_cmds = job_cmds + env_cmds
            # NOTE(review): $HOSTNAME is not escaped while $SLURM_JOB_ID is;
            # inside the double-quoted --wrap string it is presumably expanded
            # by the submitting shell rather than on the compute node - confirm
            job_cmds.append('echo HOST: $HOSTNAME')
            job_cmds.append('echo SLURM_JOB_ID: \\$SLURM_JOB_ID')
            job_cmds.append(j['shell_cmd'])

            # add a flag file when the job completes
            job_cmds.append(f"touch {j['flag_file_complete']}")

            # encode the stage_id into the comment field
            sbatch_cmd += f" --comment='stage_id={self.stage_id};'"

            # wrap the commands and start the job with a hold
            sbatch_cmd = self._wrap_and_hold(sbatch_cmd, job_cmds)

            job_id = submit_job(sbatch_cmd)
            j['cluster_job_id'] = job_id

            # if there was an error submitting the job, cancel all jobs and
            # stop: jobs submitted after the cancel would be left held forever
            if job_id is None:
                print("Job submission error - canceling all jobs for this stage")
                self.cancel_all_jobs()
                break

        # jobs that were never submitted still get the key, so the shape of
        # the returned data does not depend on where a failure happened
        for j in jobs_data:
            j.setdefault('cluster_job_id', None)

        return jobs_data

    def submit_final_job(self, stage_info_file, env_cmds):
        """Submit a job that will run after all of the other jobs have completed
        If all jobs succeeded, results should be posted back to the database
        If any job failed, the stage status should be marked as such

        The final job runs finish_run.py (located next to this file) with
        --stage_info_file, depends on every job in self.jobs_data with
        'afterany', and on success the hold on all jobs is released.

        Args:
            stage_info_file: Path passed to finish_run.py.
            env_cmds: List of shell commands run before finish_run.py.

        Raises:
            RuntimeError: If any job in self.jobs_data has no cluster job
                ID, or if the final job could not be submitted (in which
                case all jobs of the stage are canceled first).
        """

        # every job must have been submitted before it can be depended upon
        all_job_ids = [j.get('cluster_job_id') for j in self.jobs_data]
        if any(cluster_id is None for cluster_id in all_job_ids):
            raise RuntimeError("Cannot submit the final job: "
                               "not every job of this stage has a cluster job ID")

        script_dir = os.path.dirname(os.path.realpath(__file__))

        dep_string = ':'.join(all_job_ids)

        # start to build the sbatch command
        sbatch_cmd = 'sbatch'
        sbatch_cmd += ' --export=NXG_AUTH_TOKEN'
        sbatch_cmd += ' -o' + os.path.join(self.stage_dir, 'final.out')
        sbatch_cmd += '  --kill-on-invalid-dep=yes -d afterany:' + dep_string
        sbatch_cmd += f" --comment='stage_id={self.stage_id};'"

        # now let's work on the commands to wrap
        # fail after any error
        job_cmds = ['set -Eeuo pipefail']
        job_cmds = job_cmds + env_cmds
        job_cmds.append('echo SLURM_JOB_ID: \\$SLURM_JOB_ID')
        finish_cmd = self.python_path + ' ' + os.path.join(script_dir, 'finish_run.py')
        finish_cmd += ' --stage_info_file ' + stage_info_file
        job_cmds.append(finish_cmd)

        # add a flag file for completion
        job_completion_file = os.path.join(self.stage_dir, 'job.complete')
        job_cmds.append(f"touch {job_completion_file}")

        # wrap the commands and start the job with a hold
        sbatch_cmd = self._wrap_and_hold(sbatch_cmd, job_cmds)

        final_job_id = submit_job(sbatch_cmd)

        # if there was an error submitting the job, cancel all jobs; there is
        # nothing left to release, so report the failure to the caller
        if final_job_id is None:
            print("Job submission error - canceling all jobs for this stage")
            self.cancel_all_jobs()
            raise RuntimeError("Submission of the final job failed; "
                               "all jobs for this stage were canceled")

        all_job_ids.append(final_job_id)

        # release the hold
        job_id_string = ','.join(all_job_ids)
        release_cmd = 'scontrol release ' + job_id_string
        print("releasing job hold: " + release_cmd)
        result = subprocess.run(release_cmd, stdout=subprocess.PIPE, shell=True)
        print("response: " + result.stdout.decode())

    def post_stage_job_ids(self):
        """Post the stage job IDs back to the database

        If posting raises HTTPError or Timeout, the error is printed and
        all jobs of the stage are canceled; the exception is not re-raised.
        """

        jobs_by_cluster_id = nxg_common.catalog_list_by(self.jobs_data, 'cluster_job_id')

        try:
            self.api_client.post_job_ids(self.stage_id, list(jobs_by_cluster_id.keys()))
        except (HTTPError, Timeout) as e:
            print(f"Error posting JOB IDs: {e}")
            print("Canceling all jobs")
            self.cancel_all_jobs()

    def cancel_all_jobs(self):
        """Cancel all jobs associated with this stage

        Only jobs that actually have a cluster job ID are canceled, so this
        is safe to call part-way through submission, when some jobs have no
        'cluster_job_id' yet or carry None after a failed submission.
        """

        print("Canceling all jobs associated with this stage")
        cluster_job_id_list = []
        for j in self.jobs_data:
            cluster_id = j.get('cluster_job_id')
            if cluster_id is not None and cluster_id not in cluster_job_id_list:
                cluster_job_id_list.append(cluster_id)

        if cluster_job_id_list:
            cancel_jobs(cluster_job_id_list)
        

        

