#!/usr/bin/env python3

# given a filter stage id (and maybe some other parameters):
# 1: download the table_state
# 2: download the results from the input stage
# 3: apply the filters
# 4: post the results back to the db

import argparse
import os
import sys
from tempfile import mkdtemp, mkstemp

# appending to the PYTHONPATH so we can import nxg_api_client
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

import nxg_client.nxg_api_client as nxg_api

# Command-line interface.  The arguments are parsed as soon as this module is
# loaded and the result is exposed as the module-level ``args`` namespace.
parser = argparse.ArgumentParser()

# the filter stage to run (mandatory)
parser.add_argument(
    '--stage_id',
    required=True,
    help='The filter stage id',
)

# the stage whose results get filtered (optional, debugging aid)
parser.add_argument(
    '--input_stage_id',
    help='The input stage ID.  This is only here for debugging as the input stage ID should be grabbed from the filter stage data',
)

# API endpoint; defaults to the production server
parser.add_argument(
    '--base_uri',
    default='https://api-nextgen-tools.iedb.org/api/v1',
    help='The URI for the API',
)

# boolean switch (False unless given) selecting the dev server
parser.add_argument(
    '--dev',
    action='store_true',
    help='A flag to change the URI to point to the dev server. This will override the --base_uri parameter.',
)

args = parser.parse_args()

# given a results object and filter specification, apply the filters
# and return the filtered results
def apply_filters(input_results, filter_spec):
    '''Given a list of input_results and a filter_spec (state), apply the filters
    and return the list of filtered results.

    input_results: iterable of table dicts.  Each table is looked up by its
        "type" and must provide 'table_columns' (a list of dicts with
        'source' and 'name') and 'table_data' (a list of rows indexable by
        column position).
    filter_spec: either a list of filter dicts, each with a 'table' (table
        type) and a 'columns' mapping, or a single such dict without 'table',
        which is applied to the 'peptide_table'.  'columns' maps
        "<source>.<name>" to a dict whose 'search' entry may hold:
            'min'    - keep rows whose value is >= min
            'max'    - keep rows whose value is <= max
            'search' - keep rows whose str() value matches this regular
                       expression (re.match, i.e. anchored at the start)
        Empty / None bounds and patterns are ignored; a bound of 0 or '0'
        is honoured.  Rows whose value is '-' never pass a min/max filter.

    Returns the list of table dicts.  The tables are the same objects that
    were passed in; their 'table_data' is replaced by the filtered rows.

    Raises TypeError if filter_spec is neither a dict nor a list, KeyError if
    a filter names a table type that is not in input_results, and ValueError
    if a filter names a column that does not exist (or a bound / cell value
    cannot be converted to float).
    '''
    # imported locally because the file-level imports do not include ``re``
    import re

    # TODO: currently this only supports peptide tables; ideally this would loop through
    # each table and the filter_spec could include details on filters for each of them
    tables_dict = {table["type"]: table for table in input_results}
    if isinstance(filter_spec, dict):
        # work on a copy so the caller's dict does not silently gain a 'table' key
        filter_spec_list = [{**filter_spec, 'table': 'peptide_table'}]
    elif isinstance(filter_spec, list):
        filter_spec_list = filter_spec
    else:
        raise TypeError("filter_spec should be a list of filters (table_state)")

    for spec in filter_spec_list:
        table = tables_dict[spec['table']]
        columns = table['table_columns']
        table_data = table['table_data']
        print(spec)
        for column_name, column_filter in spec['columns'].items():
            # locate the column; if several columns share the same
            # "<source>.<name>", the last one wins
            index = None
            for i, column in enumerate(columns):
                if column_name == '.'.join((column['source'], column['name'])):
                    index = i
            if index is None:
                raise ValueError('column \'%s\' does not exist' % column_name)

            search = column_filter['search']

            # ``x or x == 0`` accepts 0 and '0' while skipping None and ''
            lower = search.get('min')
            if lower or lower == 0:
                lower = float(lower)
                table_data = [row for row in table_data
                              if row[index] != '-' and float(row[index]) >= lower]

            upper = search.get('max')
            if upper or upper == 0:
                upper = float(upper)
                table_data = [row for row in table_data
                              if row[index] != '-' and float(row[index]) <= upper]

            # an empty pattern matches every row, so skipping it is equivalent
            pattern = search.get('search')
            if pattern:
                regex = re.compile(pattern)
                table_data = [row for row in table_data
                              if regex.match(str(row[index]))]

        table['table_data'] = table_data
    return list(tables_dict.values())

def main():
    '''Run the filter stage described by the command-line arguments.

    Steps: fetch the filter stage inputs, fetch the results of the input
    stage, apply the filters from the stage's table_state, and post the
    filtered results back with status 'done'.  The auth token for the post
    is read from the NXG_AUTH_TOKEN environment variable.

    Exits with an error message if no input stage id was supplied, since it
    cannot yet be derived from the filter stage data.
    '''
    # published as module-level globals, as this script has always done
    global stage_id
    global input_stage_id

    stage_id = args.stage_id
    input_stage_id = args.input_stage_id

    # --dev overrides whatever --base_uri says
    if args.dev:
        base_uri = 'https://api-nextgen-tools-dev.iedb.org/api/v1'
    else:
        base_uri = args.base_uri

    print("stage id: " + stage_id)
    print("base uri: " + base_uri)

    #TODO: if we weren't explicitly given an input stage id, we pull it
    # from the filter stage info; currently this doesn't work because
    # the input stage id is not returned
    if not input_stage_id:
        # fail early with a clear message instead of a NameError after the
        # first API call
        raise SystemExit("error: --input_stage_id is required; the input stage id "
                         "cannot yet be determined from the filter stage data")

    # create an api client
    client = nxg_api.NXGclient(base_uri=base_uri)

    # download the stage inputs
    print("fetching filter stage description")
    filter_stage_inputs = client.get_stage_inputs(stage_id)

    # get the results for the input stage
    print("fetching the results from the input stage")
    input_stage_result_id = client.get_stage_result_id(input_stage_id)
    input_stage_results = client.get_object(input_stage_result_id)

    # filter the results according to the filter stage
    print("applying filters")
    table_state = filter_stage_inputs['input_parameters']['data']['table_state']
    filtered_results = apply_filters(input_stage_results['data']['results'], table_state)

    # post the results
    print("posting filtered results")
    filter_stage_result_id = client.get_stage_result_id(stage_id)
    # TODO: deal with warnings and errors here and add them to the post?
    # pull the token from the environment
    token = os.environ.get('NXG_AUTH_TOKEN')
    client.post_results(result_id=filter_stage_result_id,
                        results_object={'results': filtered_results},
                        result_status='done',
                        token=token)
    
# only run the pipeline when executed as a script, not when imported
if __name__ == '__main__':
    main()
