#!/usr/bin/env python3

# This script runs after all of the jobs in the chain have completed,
# either successfully or after a failure.
# Given the path to the stage info JSON file (--stage_info_file):
#   - check that all jobs completed successfully
#   - if not, post the failure to the db
#   - if so, enrich the results and post them to the db

import argparse
import json
import os
import sys
import shutil
from requests import HTTPError, Timeout

# appending to the PYTHONPATH so we can import nxg_api_client
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

import nxg_client.nxg_api_client as nxg_api

# module-level state shared by the functions below; main() replaces
# stage_info and warnings_and_errors, load_aggregate_results() replaces
# aggregate_results
# TODO: use concept of classes here so we can get around specifying global
#       variables throughout
stage_info = {}
warnings_and_errors = {}
aggregate_results = {}

parser = argparse.ArgumentParser()

# add arguments to the parser
# the intent is to accept either a file that contains all of the stage
# information (stage_info_file) OR stage_id, base_uri, stage_dir, and
# cluster_jobs_desc file separately
# only the stage_info_file form is implemented so far and main() opens that
# file unconditionally, so the argument is mandatory: without it the script
# would fail later with a TypeError from open(None) instead of a usage error
parser.add_argument('--stage_info_file',
    help='Path to the JSON file containing all needed stage info',
    required=True)
# TODO: get other options working (and relax required=True above once
#       there is another way to supply the stage information)
# parser.add_argument('--stage_id', 
#     help='The stage id to submit for execution on the cluster',
#     required=False)
# parser.add_argument('--base_uri', 
#     help='The base URI of the AR API',
#     required=False,
#     default='https://api-nextgen-tools.iedb.org/api/v1')

# NOTE: the command line is parsed at module level; main() reads the result
args = parser.parse_args()

def check_expected_outputs():
    """Verify that every job in the global stage_info left its files behind.

    For each entry of stage_info['stage_jobs'] the completion flag file
    (flag_file_complete) must exist, and so must every path listed under the
    optional expected_outputs key.  Each missing path is printed and
    appended to warnings_and_errors['errors'].  Nothing is returned.
    """
    # TODO: add more checks, e.g., that outputs are of non-zero size,
    #       if warnings/error file exists for steps

    # the shared error list is only mutated, never rebound, so no global
    # declaration is needed here
    def report(message):
        warnings_and_errors['errors'].append(message)
        print(message)

    for job in stage_info['stage_jobs']:

        job_label = str(job['job_id'])

        # the completion flag file marks a job that ran to the end
        flag_path = job['flag_file_complete']
        if not os.path.exists(flag_path):
            report("Runtime error: Completion flag file missing for job " + job_label + ": " + flag_path)

        # jobs without an expected_outputs key have nothing more to check
        for expected_path in job.get('expected_outputs', ()):
            if os.path.exists(expected_path):
                continue
            report("Expected output missing for job id " + job_label + ": " + expected_path)

    return
        

def load_aggregate_results():
    """Find the aggregate/final job in the global stage_info and load its
    results into the global aggregate_results.

    The jobs in stage_info['stage_jobs'] are searched from last to first for
    the first one whose job_type is 'aggregate' or whose final_output flag
    is set to a true value.  The first file listed in that job's
    expected_outputs is parsed as JSON.

    Nothing is returned.  The failures below do not raise; each one is
    printed and appended to warnings_and_errors['errors'] so that the
    failure can still be posted to the database:
      - no aggregate/final job was found (or it lists no expected outputs)
      - the results file is missing or cannot be read
      - the results file does not contain valid JSON
    """

    # we need to specify we want to use the global aggregate
    # since we're assigning a value here
    global aggregate_results

    # only mutated, never rebound, so no global declaration is needed
    errors = warnings_and_errors['errors']

    # a list of files from the aggregated results output
    aggregate_results_files = []

    # assume there is only 1 job that either is job_type 'aggregate'
    # or is flagged with 'final_output' = True
    # start from the end of the list and quit when we find it
    for j in reversed(stage_info['stage_jobs']):
        # .get() so that a job lacking either key is skipped instead of
        # raising, and so that final_output = False does not count as flagged
        if j.get('job_type') == 'aggregate' or j.get('final_output'):
            print("Aggregate/final results expected from job: " + str(j.get('job_id')))
            aggregate_results_files = list(j.get('expected_outputs') or [])
            break

    # if we get here and there is no aggregate result, something went wrong
    if not aggregate_results_files:
        msg = "No aggregate job found!"
        # print as well as record, so the message shows up in the job log,
        # and stop here: there is no file to open
        print(msg)
        errors.append(msg)
        return

    # for now we assume there is only 1 aggregate result file
    # TODO: deal with potentially multiple aggregate result files
    results_file = aggregate_results_files[0]
    try:
        with open(results_file, 'r') as f:
            aggregate_results = json.load(f)
    except FileNotFoundError as fnf_error:
        print(f"File not found: {fnf_error}")
        errors.append(str(fnf_error))
    except OSError as os_error:
        # e.g. permission denied, or the path is a directory
        print(f"Error reading aggregate results file: {os_error}")
        errors.append(str(os_error))
    except (json.JSONDecodeError, UnicodeDecodeError) as parse_error:
        # a truncated or corrupt results file must be reported as a failure
        # rather than abort the script before anything is posted
        msg = f"Invalid JSON in aggregate results file {results_file}: {parse_error}"
        print(msg)
        errors.append(msg)
        
        
    

def post_results_to_db():
    """Convert the aggregate results to API format and post them to the
    database.

    Reads the globals aggregate_results, stage_info (stage_id, base_uri,
    result_id) and warnings_and_errors.  The warning and error lists are
    extended in place, so entries added here remain visible through
    warnings_and_errors afterwards.

    The result status is 'done' unless at least one error has been
    recorded, in which case it is 'error' and a human-friendly header is
    placed ahead of the errors.  The API token is read from the
    NXG_AUTH_TOKEN environment variable.

    If posting raises HTTPError or Timeout, one retry is made with an empty
    results object, status 'error' and an additional error message.  If the
    retry fails as well, the failure is only printed.  Nothing is returned.
    """

    warnings = warnings_and_errors['warnings']
    errors = warnings_and_errors['errors']

    def prepend_error_header():
        # put the human-friendly header first, but only once: the retry
        # path below may run after the header has already been added
        header = f"The following errors occurred during processing.  If these issues persist, please contact help@iedb.org for assistance and include stage reference number {stage_info['stage_id']}"
        if not errors or errors[0] != header:
            errors.insert(0, header)

    api_results = {}
    results_object = {}

    # first convert the results to API format
    try:
        api_results = nxg_api.cmd2api_output_mhci(aggregate_results)
        results_object = { 'results': api_results['results'] }
    except KeyError as key_error:
        msg = f"Key error converting cmd results to API: {key_error} not found"
        print(msg)
        results_object = {}
        errors.append(msg)

    # add the warnings and errors from the command line tool
    if 'warnings' in api_results:
        warnings += api_results['warnings']
    if 'errors' in api_results:
        errors += api_results['errors']

    # if errors exist, change status to failed
    result_status = 'done'
    if errors:
        result_status = 'error'
        # let's also prepend the errors with something human friendly
        prepend_error_header()

    # pull the token from the environment
    token = os.environ.get('NXG_AUTH_TOKEN')

    # create an API client
    a = nxg_api.NXGclient(base_uri=stage_info['base_uri'],
                          token=token)

    # post the job IDs back to the database
    # if there is an http error here, or the max number of retries are exceeded
    # an error should be posted to the database
    # if we still cannot post an error to the database, an entry in the main log
    # file should be added and the team should be notified
    try:
        a.post_results(result_id=stage_info['result_id'],
                       results_object=results_object,
                       result_status=result_status,
                       warnings=warnings,
                       errors=errors)
    except (HTTPError, Timeout) as e:
        print(f"Error posting results back to database: {e}")
        print("Removing results, adding error message, and retrying")
        result_status = 'error'
        errors.append("The system encountered an error while transferring the results")
        prepend_error_header()
        try:
            a.post_results(result_id=stage_info['result_id'],
                           results_object={},
                           result_status=result_status,
                           warnings=warnings,
                           errors=errors)
        except (HTTPError, Timeout) as e2:
            print(f"Error posting error message back to the database: {e2}")
            # TODO notify team of issue / add to main log file
        else:
            print("Error message posted successfully after removing results")





def cleanup():
    """If the job succeeded, clean up the working directory.

    The directory is stage_info['stage_dir'].  It is removed only when
    warnings_and_errors['errors'] is empty; otherwise it is kept and a
    message saying so is printed.  Removal is best effort: a failure to
    delete the directory is printed and does not raise.  Nothing is
    returned.
    """
    stage_dir = stage_info['stage_dir']

    if warnings_and_errors['errors']:
        print("Job had 1 or more errors, keeping working directory: " + stage_dir)
        return

    print("Job completed - removing working directory: " + stage_dir)
    try:
        shutil.rmtree(stage_dir)
    except OSError as rm_error:
        # report the failure instead of swallowing it silently
        print(f"Error deleting directory {stage_dir}: {rm_error}")
        

def main():
    """Run the post-processing steps for one stage.

    Resets the global warnings_and_errors, loads the global stage_info from
    the JSON file named by --stage_info_file, then checks the expected
    outputs, loads the aggregate results and posts them to the database.
    """
    # both names are rebound below, so they must be declared global
    global stage_info, warnings_and_errors

    # every run starts with empty warning and error lists
    warnings_and_errors = {'warnings': [], 'errors': []}

    # unfreeze the stage info from the file given on the command line
    with open(args.stage_info_file, 'r') as handle:
        stage_info = json.load(handle)

    check_expected_outputs()
    load_aggregate_results()
    post_results_to_db()

    # clean up the working directory if the job succeeded and the run is
    # not flagged as DEBUG
    # (currently disabled)
    #cleanup()
    
# run the workflow only when executed as a script, so that importing this
# module does not post anything to the database
if __name__ == "__main__":
    main()
