#!/usr/bin/env python

# Functions for working with the nextgen-tools API
import copy
import json
import os
import pprint
import sys
from typing import Optional

import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

# define the retry strategy, as pulled from here
# https://findwork.dev/blog/advanced-usage-python-requests-timeouts-retries-hooks/#retry-on-failure
# Retry at most 5 times in total; besides connection-level failures, responses
# with the status codes below (rate limiting / transient server errors) are retried.
# NOTE(review): POST is in allowed_methods, so non-idempotent POSTs (job IDs,
# results) may be re-sent after a 5xx - confirm the API tolerates duplicates.
retry_strategy = Retry(
    total=5,
    status_forcelist=[429, 500, 502, 503, 504],
    allowed_methods=frozenset({'DELETE', 'GET', 'HEAD', 'OPTIONS', 'POST', 'PUT', 'TRACE'})
)
adapter = HTTPAdapter(max_retries=retry_strategy)
# default per-request timeout (seconds) used by make_get_request / make_post_request
DEFAULT_TIMEOUT=20
# shared session used for every API call in this module
http = requests.Session()
# response hook: call raise_for_status() on every response, so HTTP error
# statuses (4xx/5xx) raise requests.HTTPError instead of being returned
http.hooks = {
   'response': lambda r, *args, **kwargs: r.raise_for_status()
}
# route both schemes through the retrying adapter
http.mount("https://", adapter)
http.mount("http://", adapter)

# appending to the PYTHONPATH so we can import nxg_common:
# the parent of this script's directory is added to sys.path (nxg_common is
# presumably located there). The nxg_common imports below must stay after
# this sys.path change.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from nxg_common.column_info import get_column_info
from nxg_common.nxg_common import get_fully_qualified_field_name as fqn
from nxg_common.nxg_common import save_file_from_URI

# module-level pretty-printer; not referenced in this file's visible code
# (presumably kept for ad-hoc debugging)
pp = pprint.PrettyPrinter(indent=4)

class NXGclient:
    """Client for the nextgen-tools (NXG) REST API.

    Pipelines, stages and objects retrieved through the ``get_*`` methods
    are cached per instance to reduce the number of API calls. Every getter
    returns a deep copy, so callers may modify the result freely without
    affecting the cache.
    """

    def __init__(self,
                 base_uri='https://api-nextgen-tools.iedb.org/api/v1',
                 token=None):
        """Create a new NXGclient object.

        :param base_uri: root URL of the API; endpoint paths are appended to it
        :param token: API token sent as ``Authorization: Token <token>`` by the
            POST methods (``post_job_ids`` / ``post_results``)
        """
        self.base_uri = base_uri

        # a token to be used for POST requests
        self.token = token

        # caches used to reduce the number of calls to the API
        # a dict of stages, keyed by stage id
        self.stages = dict()

        # a dict of pipelines keyed by pipeline id
        self.pipelines = dict()

        # a dict of objects keyed by object id
        self.objects = dict()

    def _get_cached(self, cache, key, path, use_cache):
        """Return a deep copy of ``cache[key]``, first fetching
        ``base_uri + path`` into the cache if the key is missing or
        ``use_cache`` is false."""
        if not use_cache or key not in cache:
            cache[key] = make_get_request(self.base_uri + path)
        # copy so that callers' changes do not leak into the cache
        return copy.deepcopy(cache[key])

    def _auth_headers(self):
        """Headers for an authenticated JSON POST.

        :raises TypeError: if the client was created without a token
        """
        if self.token is None:
            raise TypeError("NXGclient has no token; pass token=... to POST to the API")
        return {'Authorization': 'Token ' + self.token,
                'Content-Type': 'application/json'}

    def get_pipeline_stages(self, pipeline_id, use_cache=True):
        """Given a pipeline ID, return a list of its stage IDs (in the order
        listed by the pipeline). Unless otherwise specified, a cached version
        of the pipeline is used."""
        pipeline = self.get_pipeline(pipeline_id, use_cache=use_cache)
        return [stage['stage_id'] for stage in pipeline['stages']]

    def get_pipeline(self, pipeline_id, use_cache=True):
        """Given a pipeline ID, use the pipeline endpoint to retrieve details.
        Unless otherwise specified, a cached version of the pipeline
        will be returned if it exists."""
        return self._get_cached(self.pipelines, pipeline_id,
                                f'/pipeline/{pipeline_id}', use_cache)

    def get_pipeline_stage(self, stage_id, use_cache=True):
        """Given a stage ID, retrieve the stage information.
        Unless otherwise specified, a cached version of the stage
        will be returned if it exists."""
        return self._get_cached(self.stages, stage_id,
                                f'/stage/{stage_id}', use_cache)

    def get_object(self, object_id, use_cache=True):
        """Given an object ID (sequence, results, etc.), retrieve
        it from the API and return the full object. Unless
        otherwise specified, a cached version of the object will
        be returned if it exists."""
        return self._get_cached(self.objects, object_id,
                                f'/get_object/{object_id}', use_cache)

    def get_stage_inputs(self, stage_id):
        """Given a stage ID, return the stage's input data.

        The returned dict always contains ``input_parameters`` (the fetched
        input-parameters object). Depending on which IDs the stage's
        ``input_data`` carries, it may also contain:

        * ``input_datasets`` - the fetched ``input_datasets_id`` object
          (only set when the stage receives data from an upstream stage)
        * ``input_sequence_fasta_uri`` - URL of the input sequences as FASTA
        * ``vcf_download_uri`` - URL of the input VCF
        * ``neoepitopes_download_uri`` - URL of the input neoepitopes
        """
        stage_input = self.get_pipeline_stage(stage_id)['input_data']

        input_data = {
            'input_parameters': self.get_object(stage_input['input_parameters_id'])
        }

        # check for the same key that is read (the original checked
        # 'input_datasets' but read 'input_datasets_id')
        # TODO: update this to 'piped_data' or 'piped_input_data'
        if 'input_datasets_id' in stage_input:
            input_data['input_datasets'] = self.get_object(stage_input['input_datasets_id'])

        # set a pointer to the input sequence fasta uri
        if 'input_sequence_text_id' in stage_input:
            input_data['input_sequence_fasta_uri'] = (
                f"{self.base_uri}/sequence_list_fasta/{stage_input['input_sequence_text_id']}")

        # set a pointer to the input VCF download uri
        if 'input_vcf_text_id' in stage_input:
            input_data['vcf_download_uri'] = (
                f"{self.base_uri}/download_vcf/{stage_input['input_vcf_text_id']}")

        # set a pointer to the input neoepitopes download uri
        if 'input_neoepitopes_id' in stage_input:
            input_data['neoepitopes_download_uri'] = (
                f"{self.base_uri}/download_neoepitopes/{stage_input['input_neoepitopes_id']}")
            print("neoepitopes_download_uri: " + input_data['neoepitopes_download_uri'])

        # TODO: for 'filter' stages, add the input_stage_id information;
        # currently, we can't pull this information from the stage endpoint

        return input_data

    def get_stage_result_id(self, stage_id):
        """Given a stage ID, return its result_id."""
        return self.get_pipeline_stage(stage_id)['stage_result_id']

    def post_job_ids(self, stage_id: str, job_ids: list):
        """POST the cluster job IDs for a stage back to the database.

        :returns: the decoded JSON response (also printed)
        :raises TypeError: if the client has no token
        """
        full_url = self.base_uri + '/jobs'
        headers = self._auth_headers()

        post_data = json.dumps({'stage_id': stage_id,
                                'cluster_job_ids': job_ids})

        update_status = make_post_request(full_url, post_data, headers=headers)
        print(update_status)

        # TODO: Handle errors
        return update_status

    def post_results(self, result_id: str, results_object: dict,
                     result_status: str, warnings: Optional[list] = None,
                     errors: Optional[list] = None):
        """POST stage results back to the database.

        :param result_id: ID of the result record to update
        :param results_object: result payload, sent as ``result_data``
        :param result_status: status string sent as ``result_status``
        :param warnings: optional list of warnings (default: empty)
        :param errors: optional list of errors (default: empty)
        :returns: the decoded JSON response (also printed)
        :raises TypeError: if the client has no token
        """
        full_url = self.base_uri + '/results/' + result_id
        headers = self._auth_headers()

        post_data = json.dumps({
            'result_status': result_status,
            'result_data': results_object,
            'warnings': [] if warnings is None else warnings,
            'errors': [] if errors is None else errors,
        })

        update_status = make_post_request(full_url, post_data, headers=headers)
        print(update_status)

        # TODO: Handle the error properly below
        if isinstance(update_status, dict) and 'update_error' in update_status:
            print("Error POSTing results:")
            print(update_status['update_error'])
        return update_status

def api2cmd_input(tool_group, api_inputs=None, stage_id=None):
    """Given the tool group name, select the right function
    for mapping the api inputs to command line inputs.

    :param tool_group: one of the keys of ``action_map`` below
    :param api_inputs: stage input data (see ``NXGclient.get_stage_inputs``);
        defaults to an empty dict
    :param stage_id: if given, the converter fetches the stage inputs itself
    :returns: whatever the selected converter returns
    :raises KeyError: if ``tool_group`` has no converter
    """
    action_map = {
        'mhci': api2cmd_input_mhci,
        'mhcii': api2cmd_input_mhci,
        'peptide_variant_comparison': api2cmd_input_mhci,
        'phbr': api2cmd_input_phbr,
        'mutgen': api2cmd_input_mutgen,
        'cluster': api2cmd_input_cluster,
        'pepmatch': api2cmd_input_pepmatch
    }

    print(f"generating command line inputs for {tool_group}")

    try:
        converter = action_map[tool_group]
    except KeyError:
        # same exception type as before, but with a useful message
        raise KeyError(f"no command line converter for tool group {tool_group!r}; "
                       f"expected one of {sorted(action_map)}") from None

    # pass a fresh dict instead of sharing a mutable default between calls
    return converter({} if api_inputs is None else api_inputs, stage_id)

def api2cmd_input_phbr(api_inputs=None, stage_id=None):
    """Given a dict of stage input_data OR a stage_id,
    return the JSON necessary for the PHBR command line tool.

    :param api_inputs: stage input data (see ``NXGclient.get_stage_inputs``);
        must contain ``input_parameters`` and ``neoepitopes_download_uri``
    :param stage_id: if given, the inputs are fetched from the API (default
        server) instead and ``api_inputs`` is ignored
    :returns: a deep copy of the input parameters' ``data`` with
        ``input_neoepitopes`` set to the text of the file that
        ``save_file_from_URI`` saves for the neoepitopes URI
    """
    # TODO: this will retrieve from the main server, but there
    # should be a way to pull from any server
    if stage_id is not None:
        api_inputs = NXGclient().get_stage_inputs(stage_id)
    elif api_inputs is None:
        api_inputs = {}

    # start with a copy of the input parameters so the caller's dict is untouched
    cmd_input = copy.deepcopy(api_inputs['input_parameters']['data'])

    # retrieve the input neoepitopes file and embed its contents
    neoepitopes_file = save_file_from_URI(api_inputs['neoepitopes_download_uri'])
    with open(neoepitopes_file, "r") as f:
        cmd_input['input_neoepitopes'] = f.read()

    return cmd_input

def api2cmd_input_pepmatch(api_inputs=None, stage_id=None):
    """Given a dict of stage input_data OR a stage_id,
    return the JSON necessary for the PepMatch command line tool.

    :param api_inputs: stage input data (see ``NXGclient.get_stage_inputs``);
        must contain ``input_parameters`` and ``input_sequence_fasta_uri``
    :param stage_id: if given, the inputs are fetched from the API (default
        server) instead and ``api_inputs`` is ignored
    :returns: a deep copy of the input parameters' ``data`` with
        ``input_sequence_text`` set to the downloaded FASTA text
    """
    # convert from the api format:
    #  {
    #     "input_parameters":
    #     {
    #         "id": "41816486-ff48-4911-bbc4-4305f9ce8aeb",
    #         "type": "input_parameters",
    #         "format": null,
    #         "data":
    #         {
    #             "mismatch": 3,
    #             "proteome": "Human",
    #             "best_match": true
    #         }
    #     },
    #     "input_sequence_fasta_uri": "https://api-nextgen-tools-dev.iedb.org/api/v1/sequence_list_fasta/9ea35cad-1b50-42a0-90a1-b7c4fca32410"
    # }
    #
    # to the command line format
    # {
    #     "input_sequence_text": "DDEDSKQNIFHFLYR\nADPGPHLMGGGGRAK\nKAVELGVKLLHAFHT\nQLQNLGINPANIGLS\nHEVWFFGLQYVDSKG",
    #     "mismatch": 3,
    #     "proteome": "Human",
    #     "best_match": true
    # }

    # TODO: this will retrieve from the main server, but there
    # should be a way to pull from any server
    if stage_id is not None:
        api_inputs = NXGclient().get_stage_inputs(stage_id)
    elif api_inputs is None:
        api_inputs = {}

    # start with a copy of the input parameters so the caller's dict is untouched
    cmd_input = copy.deepcopy(api_inputs['input_parameters']['data'])

    # retrieve the input sequence fasta and add its text
    input_sequence_file = save_file_from_URI(api_inputs['input_sequence_fasta_uri'])
    with open(input_sequence_file, "r") as f:
        cmd_input['input_sequence_text'] = f.read()

    return cmd_input
 
# Functions to handle conversion between formats.
# If we pass a stage ID, we'll pull the inputs from the stage;
# otherwise, we'll use what's sent to us in the api_inputs dictionary.
# TODO: determine if this belongs here or in another library
def api2cmd_input_cluster(api_inputs=None, stage_id=None):
    """Given a dict of stage input_data OR a stage_id,
    return the JSON necessary for the cluster command line tool.

    :param api_inputs: stage input data (see ``NXGclient.get_stage_inputs``);
        must contain ``input_parameters`` and ``input_sequence_fasta_uri``
    :param stage_id: if given, the inputs are fetched from the API (default
        server) instead and ``api_inputs`` is ignored
    :returns: a new dict with ``cluster_pct_identity``,
        ``peptide_length_range``, ``method`` (from the first predictor) and
        ``input_sequence_text`` (the downloaded FASTA text)
    """
    # Convert from the format used in the API:
    # {
    #     "input_parameters":
    #     {
    #         "id": "e05f11f6-6654-412e-b22a-b7d1d6fd1a56",
    #         "type": "input_parameters",
    #         "format": null,
    #         "data":
    #         {
    #             "predictors":
    #             [
    #                 {
    #                     "type": "cluster",
    #                     "method": "cluster-break"
    #                 }
    #             ],
    #             "cluster_pct_identity": 0.7,
    #             "peptide_length_range":
    #             [
    #                 0,
    #                 0
    #             ]
    #         }
    #     },
    #     "input_sequence_fasta_uri": "https://api-nextgen-tools-dev.iedb.org/api/v1/sequence_list_fasta/0caf1b6f-2c98-4e1f-a510-0df22bf1c4de"
    # }
    #
    # to the format used in the command line tool:
    # {
    #     "input_sequence_text": ">Mus Pep1\nLEQIHVLENSLVL\n>Mus Pep2\nFVEHIHVLENSLAFK\n>Mus Pep3\nGLYGREPDLSSDIKERFA\n>Mus Pep4\nEWFSILLASDKREKI",
    #     "method": "cluster-break",
    #     "cluster_pct_identity": 0.7,
    #     "peptide_length_range": [
    #         0,
    #         0
    #     ]
    # }

    # TODO: this will retrieve from the main server, but there
    # should be a way to pull from any server
    if stage_id is not None:
        api_inputs = NXGclient().get_stage_inputs(stage_id)
    elif api_inputs is None:
        api_inputs = {}

    api_data = api_inputs['input_parameters']['data']

    # start with the input parameters
    # NOTE: Here we are assuming only 1 predictor is passed; only the first
    # predictor's method is used (an empty list raises IndexError)
    cmd_input = {'cluster_pct_identity': api_data['cluster_pct_identity'],
                 'peptide_length_range': api_data['peptide_length_range'],
                 'method': api_data['predictors'][0]['method']}

    # retrieve the input sequence fasta and add its text
    input_sequence_file = save_file_from_URI(api_inputs['input_sequence_fasta_uri'])
    with open(input_sequence_file, "r") as f:
        cmd_input['input_sequence_text'] = f.read()

    return cmd_input
       
def api2cmd_input_mutgen(api_inputs=None, stage_id=None):
    """Given a dict of stage input_data OR a stage_id,
    return the JSON necessary for the mutgen command line tool.

    :param api_inputs: stage input data (see ``NXGclient.get_stage_inputs``);
        must contain ``input_parameters`` and ``vcf_download_uri``
    :param stage_id: if given, the inputs are fetched from the API (default
        server) instead and ``api_inputs`` is ignored
    :returns: a deep copy of the input parameters' ``data`` with
        ``vcf_download_uri`` added; ``api_inputs`` itself is not modified
    """
    # TODO: this will retrieve from the main server, but there
    # should be a way to pull from any server
    if stage_id is not None:
        api_inputs = NXGclient().get_stage_inputs(stage_id)
    elif api_inputs is None:
        api_inputs = {}

    # copy the input parameters; previously the caller's nested dict was
    # modified in place
    cmd_input = copy.deepcopy(api_inputs['input_parameters']['data'])

    # now add the VCF download uri
    cmd_input['vcf_download_uri'] = api_inputs['vcf_download_uri']

    return cmd_input
        
def api2cmd_input_mhci(api_inputs=None, stage_id=None):
    """Given a dict of stage input_data OR a stage_id,
    return the JSON necessary for the MHC-I (and MHC-II / peptide variant
    comparison) command line tool.

    :param api_inputs: stage input data (see ``NXGclient.get_stage_inputs``);
        must contain ``input_parameters`` and ``input_sequence_fasta_uri``
    :param stage_id: if given, the inputs are fetched from the API (default
        server) instead and ``api_inputs`` is ignored
    :returns: a deep copy of the input parameters' ``data`` with
        ``input_sequence_fasta_uri`` added; ``api_inputs`` is not modified
    """
    # TODO: this will retrieve from the main server, but there
    # should be a way to pull from any server
    if stage_id is not None:
        api_inputs = NXGclient().get_stage_inputs(stage_id)
    elif api_inputs is None:
        api_inputs = {}

    # copy the input parameters; previously the caller's nested dict was
    # modified in place
    cmd_input = copy.deepcopy(api_inputs['input_parameters']['data'])

    # now add the sequence fasta uri
    cmd_input['input_sequence_fasta_uri'] = api_inputs['input_sequence_fasta_uri']

    return cmd_input

def cmd2api_output_mhci(cmd_output):
    """Given the aggregated output from the command line tool,
    update it to the format needed by the API.

    ``cmd_output`` is deep-copied first, so the caller's object is never
    modified. For every entry of ``cmd_output['results']``:

    * a ``result_type`` key is renamed to ``type`` (unless ``type`` exists);
    * results without ``table_columns`` are otherwise left as they are;
    * each column name in ``table_columns`` is replaced by a copy of the
      column-info dict from ``get_column_info``, with any per-result
      ``column_overrides`` applied. For example ``"sequence_number"`` becomes::

          {
            "name": "sequence_number",
            "type": "int",
            "hidden": false,
            "source": "core",
            "sort_order": 0,
            "description": "Index of the input sequence among all input sequences.",
            "display_name": "seq #",
            "default_order": null,
            "row_sort_priority": null
          }

    * non-null ``row_sort_priority`` values are renumbered 0, 1, 2, ...
      keeping their relative order;
    * ``unique_vals`` / ``field_ranges`` entries (keys lower-cased) are
      attached to matching columns as ``value_limits``;
    * ``column_overrides``, ``unique_vals`` and ``field_ranges`` are removed
      from the result. Missing or null sections are treated as empty.

    :param cmd_output: dict with a ``results`` list
    :returns: the converted deep copy
    """
    api_output = copy.deepcopy(cmd_output)

    for r in api_output['results']:

        # change 'result_type' to 'type' if it is used
        if 'type' not in r and 'result_type' in r:
            r['type'] = r.pop('result_type')

        # skip this result if no table columns are included
        if 'table_columns' not in r:
            continue

        # api_output is already a deep copy, so the overrides can be popped
        # straight out of the result (this also removes them from it)
        column_overrides = r.pop('column_overrides', None) or {}

        # expand column names into column-info dicts; remember the index of
        # every column with a non-null row_sort_priority for renumbering
        full_table_columns = []
        row_sort_priorities = {}
        for c in r['table_columns']:
            # copy so the same-named column in another table stays independent
            c_info = get_column_info(c).copy()
            # raw column key, used below as a fallback for value_limits lookup
            c_info['key'] = c
            rsp = c_info.get('row_sort_priority')
            if rsp is not None:
                row_sort_priorities[len(full_table_columns)] = rsp

            # NOTE(review): priorities are captured before overrides are
            # applied, so an override of a non-null row_sort_priority is
            # replaced by the renumbered original below - confirm intended.
            if c in column_overrides:
                c_info.update(column_overrides[c])

            full_table_columns.append(c_info)

        # renumber the recorded priorities 0, 1, 2, ... in ascending order of
        # their original values (sorted() is stable: ties keep column order)
        ranked = sorted(row_sort_priorities.items(), key=lambda item: item[1])
        for new_priority, (col_idx, _) in enumerate(ranked):
            full_table_columns[col_idx]['row_sort_priority'] = new_priority

        # lower-case the keys of unique_vals / field_ranges and remove both
        # sections from the result.
        # TODO: this isn't the cleanest solution; ideally the names that
        # the standalone returns should match
        unique_vals = {k.lower(): v
                       for k, v in (r.pop('unique_vals', None) or {}).items()}
        field_ranges = {k.lower(): v
                        for k, v in (r.pop('field_ranges', None) or {}).items()}

        for f in full_table_columns:
            # get the fully-qualified field name
            field_name = fqn(f)
            # if no value limits exist for the fully-qualified name, fall
            # back to the raw column key (which is removed from the column)
            key = f.pop('key', None)
            if field_name not in unique_vals and field_name not in field_ranges and key:
                field_name = key

            value_limits = {}
            if field_name in unique_vals:
                value_limits['unique_values'] = unique_vals[field_name]
            if field_name in field_ranges:
                value_limits['min'] = float(field_ranges[field_name]['min'])
                value_limits['max'] = float(field_ranges[field_name]['max'])

            if value_limits:
                f['value_limits'] = value_limits
            else:
                print(f"WARNING: No value limits (unique_values, min, max) found for: {field_name}")

        r['table_columns'] = full_table_columns

    return api_output
    
    
def make_get_request(url, headers=None, timeout=DEFAULT_TIMEOUT):
    """Make an HTTP GET request through the shared session and return the
    response body decoded as JSON.

    :param url: full URL to request
    :param headers: optional extra HTTP headers (default: none)
    :param timeout: timeout in seconds, passed to requests; defaults to
        ``DEFAULT_TIMEOUT`` (20 seconds)
    :raises requests.HTTPError: from the session's ``raise_for_status`` hook
        for error statuses that are not retried; statuses that are retried
        and keep failing surface as a requests exception from the retry adapter
    """
    response = http.get(url, headers={} if headers is None else headers,
                        timeout=timeout)

    print(f"response status code: {response.status_code}")

    return response.json()

def make_post_request(url, post_data, headers=None, timeout=DEFAULT_TIMEOUT):
    """Make an HTTP POST request through the shared session and return the
    response body decoded as JSON.

    :param url: full URL to post to
    :param post_data: request body, sent as-is (callers pass a JSON string)
    :param headers: optional extra HTTP headers (default: none)
    :param timeout: timeout in seconds, passed to requests; defaults to
        ``DEFAULT_TIMEOUT`` (20 seconds)
    :raises requests.HTTPError: from the session's ``raise_for_status`` hook
        for error statuses that are not retried; statuses that are retried
        and keep failing surface as a requests exception from the retry adapter
    """
    response = http.post(url, data=post_data,
                         headers={} if headers is None else headers,
                         timeout=timeout)

    print(f"response status code: {response.status_code}")

    return response.json()


# just running some tests here
# TODO - create proper test cases
# TODO - add examples of all return formats
def main():
    """Ad-hoc smoke test against the dev API.

    Fetches a fixed pipeline, prints it, then for each of its stages prints
    the stage, its inputs and the MHC-I command line inputs built from them.
    """
    def show(label, payload):
        # print a heading line followed by the JSON-serialised payload
        print(label)
        print(json.dumps(payload))

    pipeline_id = '8daf7324-d747-477b-9b2d-b6d70ff6cf41'
    client = NXGclient(base_uri='https://api-nextgen-tools-dev.iedb.org/api/v1')

    pipeline = client.get_pipeline(pipeline_id)
    show("pipeline:", pipeline)
    print(pipeline['pipeline_date'])

    for stage_id in client.get_pipeline_stages(pipeline_id):
        print("stage_id: " + stage_id)
        show("stage:", client.get_pipeline_stage(stage_id))
        inputs = client.get_stage_inputs(stage_id)
        show("stage inputs:", inputs)
        show("cmd inputs:", api2cmd_input_mhci(inputs))

    # uncomment below to regenerate the expected API aggregated output
    # with open(os.path.join(SCRIPT_DIR, 'tests/cmd_agg_output.json'), 'r') as f:
    #     cmd_agg_output = json.load(f)
    #     api_agg_output = cmd2api_output_mhci(cmd_agg_output)
    #     print("api agg output:")
    #     print(json.dumps(api_agg_output))

# run the ad-hoc smoke test only when executed as a script, not on import
if __name__ == '__main__':
    main()
