import unittest
import os
import json
import nxg_api_client

# Directory holding this script and its JSON fixtures.
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))

# Base URI of the NXG API under test. Defaults to the dev deployment; set the
# NXG_API_BASE_URI environment variable to point the suite elsewhere.
# NOTE: pipeline_id / stage_id below and the expected_*.json fixtures refer to
# data on the dev deployment, so another server must hold the same records.
NXG_API_BASE_URI = os.environ.get('NXG_API_BASE_URI', 'https://api-nextgen-tools-dev.iedb.org/api/v1')

# we use the same client for all tests
# TODO: add basic test for creating client
# TODO: add test of cache
a = nxg_api_client.NXGclient(base_uri=NXG_API_BASE_URI)
# a pipeline ID to use for testing
pipeline_id = '8daf7324-d747-477b-9b2d-b6d70ff6cf41'
# a stage ID to use for testing
stage_id = '4837272b-6064-4be4-a949-efc80fce11ca'

class TestGetInputs(unittest.TestCase):
    '''Check live NXG API responses for a known pipeline/stage against JSON fixtures.

    Every test uses the shared module-level client ``a`` together with the
    module-level ``pipeline_id`` / ``stage_id``; expected values are read from
    JSON files stored next to this script (SCRIPT_DIR).
    '''

    @staticmethod
    def _load_expected(filename):
        '''Return the parsed contents of the JSON fixture *filename* in SCRIPT_DIR.'''
        with open(os.path.join(SCRIPT_DIR, filename), 'r', encoding='utf-8') as json_data:
            return json.load(json_data)

    def _assert_fields_equal(self, actual, expected, fields):
        '''Assert ``actual[f] == expected[f]`` for each name in *fields*.'''
        for f in fields:
            # subTest reports every mismatching field instead of stopping at the first
            with self.subTest(field=f):
                self.assertEqual(actual[f], expected[f])

    def _assert_same_messages(self, actual, expected):
        '''Assert a stage's warnings and errors match, ignoring order and duplicates.'''
        for kind in ('warnings', 'errors'):
            with self.subTest(stage_messages=kind):
                self.assertSetEqual(set(actual['stage_messages'][kind]),
                                    set(expected['stage_messages'][kind]))

    def test_get_pipeline(self):
        '''Test the return value of get pipeline against the expected'''
        pipeline = a.get_pipeline(pipeline_id)
        expected_pipeline = self._load_expected('expected_pipeline.json')

        # only the first stage is compared
        stage1 = pipeline['stages'][0]
        expected_stage1 = expected_pipeline['stages'][0]

        # test the pipeline level fields
        self._assert_fields_equal(
            pipeline, expected_pipeline,
            ['email', 'pipeline_date', 'pipeline_id', 'pipeline_spec_id', 'pipeline_title'])

        # test the stage level fields
        self._assert_fields_equal(
            stage1, expected_stage1,
            ['input_sequence_text', 'stage_display_name', 'stage_id', 'stage_number', 'stage_result_id',
             'stage_result_uri', 'stage_type', 'stage_url', 'tool_group'])

        # nested structures are compared as whole dicts
        self.assertDictEqual(stage1['input_parameters'], expected_stage1['input_parameters'])
        self.assertDictEqual(stage1['piped_data'], expected_stage1['piped_data'])
        self.assertDictEqual(stage1['table_state'], expected_stage1['table_state'])

        # message order is not significant
        self._assert_same_messages(stage1, expected_stage1)

    def test_get_stages(self):
        '''Test the return value of get_pipeline_stages'''
        stage_list = a.get_pipeline_stages(pipeline_id)
        self.assertListEqual(stage_list, [stage_id])

    def test_get_stage(self):
        '''Test the return value of get_pipeline_stage'''
        stage = a.get_pipeline_stage(stage_id)
        expected_stage = self._load_expected('expected_stage.json')

        # compare fields that should be equal
        self._assert_fields_equal(
            stage, expected_stage,
            ['stage_display_name', 'stage_id', 'stage_result_id', 'stage_result_uri', 'stage_status',
             'stage_type', 'tool_group'])

        # compare dicts that should be equal
        self.assertDictEqual(stage['input_data'], expected_stage['input_data'])
        self.assertDictEqual(stage['piped_data'], expected_stage['piped_data'])

        # compare messages as sets (order is not significant)
        self._assert_same_messages(stage, expected_stage)

    def test_get_stage_inputs(self):
        '''Test that get stage inputs returns the expected result'''
        stage_inputs = a.get_stage_inputs(stage_id)
        expected_inputs = self._load_expected('expected_inputs.json')
        self.assertDictEqual(stage_inputs, expected_inputs)

    def test_cmd_input_conversion(self):
        '''Test the input is properly converted to command line format'''
        # Whether this call is served from the client's cache depends on test
        # order: unittest's default loader sorts methods alphabetically, so this
        # test runs BEFORE test_get_stage_inputs and cannot rely on it.
        stage_inputs = a.get_stage_inputs(stage_id)
        cmd_inputs = nxg_api_client.api2cmd_input_mhci(stage_inputs)
        expected_cmd_inputs = self._load_expected('expected_cmd_inputs.json')
        self.assertDictEqual(cmd_inputs, expected_cmd_inputs)

class TestAPIConversion(unittest.TestCase):
    '''Offline tests of converting command-line tool output to the API format.'''

    def test_cmd2api(self):
        """Test that the cmd2api function is producing the expected output"""
        with open(os.path.join(SCRIPT_DIR, 'cmd_agg_output.json'), 'r', encoding='utf-8') as f:
            cmd_agg_output = json.load(f)
        with open(os.path.join(SCRIPT_DIR, 'expected_api_agg_output.json'), 'r', encoding='utf-8') as f:
            expected_api_agg_output = json.load(f)
        api_agg_output = nxg_api_client.cmd2api_output_mhci(cmd_agg_output)

        # Set the NXG_UPDATE_EXPECTED environment variable to dump the current
        # converter output to new_output.json (e.g. to refresh
        # expected_api_agg_output.json after an intentional change). This runs
        # before the assertions so the dump is produced even when they fail.
        if os.environ.get('NXG_UPDATE_EXPECTED'):
            with open(os.path.join(SCRIPT_DIR, 'new_output.json'), 'w', encoding='utf-8') as f:
                json.dump(api_agg_output, f)

        # first check the easy stuff: the converted output's errors and
        # warnings match the expected ones
        self.assertListEqual(api_agg_output['errors'], expected_api_agg_output['errors'])
        self.assertListEqual(api_agg_output['warnings'], expected_api_agg_output['warnings'])

        # TODO: as time permits, compare each result table column by column
        # (e.g. after sorting results by 'type' on both sides) to get more
        # targeted failure messages than the whole-dict comparison below.
        self.assertDictEqual(api_agg_output, expected_api_agg_output)
        
if __name__ == '__main__':
    # Executed as a script: load and run every TestCase defined in this module.
    unittest.main(module='__main__')