import unittest
import sys
import json
import itertools
import pandas as pd
from database_functions import Database


class TestPepxDatabase(unittest.TestCase):
    '''
    Tests for the ``Database`` class imported from ``database_functions``.

    All tests share a single ``Database`` instance created in
    ``setUpClass``. The expected values are literal snapshots; the tests
    presumably require the backing database to contain exactly this data
    (NOTE(review): confirm which database/environment these snapshots
    were taken from).
    '''

    # Expected result of get_data_sources() without a filter. The
    # 'gene' filter is expected to return the very same set.
    _ALL_DATA_SOURCES = ('Abelin', 'CCLE', 'GTEx', 'HPA', 'HeLa', 'TCGA')

    # Expected result of get_data_sources('transcript').
    _TRANSCRIPT_DATA_SOURCES = ('CCLE', 'GTEx', 'HPA', 'TCGA')

    # One hand-picked dataset per data source. Used both for the
    # unfiltered dataset listing and for the 'gene' listing, which are
    # expected to contain the same entries.
    _GENE_DATASETS_SUBSAMPLE = (
        {'dataset_id': 1, 'n_samples': None, 'source': 'HeLa', 'title': 'HeLa: cantarella'},
        {'dataset_id': 2, 'n_samples': None, 'source': 'Abelin', 'title': 'Abelin: B721.221'},
        {'dataset_id': 3, 'n_samples': None, 'source': 'TCGA', 'title': 'TCGA: PANCAN'},
        {'dataset_id': 1088, 'n_samples': None, 'source': 'CCLE', 'title': 'CCLE: BXPC3_PANCREAS'},
        {'dataset_id': 181, 'n_samples': None, 'source': 'HPA', 'title': 'HPA: middle cingulate cortex'},
        {'dataset_id': 63, 'n_samples': 104.0, 'source': 'GTEx', 'title': 'GTEx: Kidney'},
    )

    # Arguments shared by the two peptide_lookup tests. Only the first
    # peptide shows up in the expected results below.
    _LOOKUP_PEPTIDES = ('ALLVRYTKKVPQVS', 'APQVSTPTLVEAAR')
    _LOOKUP_DATASET_ID = 1353

    @classmethod
    def setUpClass(cls):
        '''
        This function will only run once, before any test of the class.

        It creates the ``Database`` instance shared by all tests. The
        credentials are taken from the ``PEPX_DB_USER`` and
        ``PEPX_DB_PASSWORD`` environment variables when they are set to
        a non-empty value; otherwise the historical built-in values are
        used so existing setups keep working unchanged.
        '''
        # Local import: the module does not import os at file level.
        import os

        # NOTE(review): the fallback password is a secret committed to
        # source control. Prefer supplying it via the environment and
        # consider rotating/removing the literal.
        cls.database = Database(
            password=os.environ.get('PEPX_DB_PASSWORD') or 'fizq8N5nmZ5vNeyq',
            user=os.environ.get('PEPX_DB_USER') or 'djangotools-pepx',
        )

    def setUp(self) -> None:
        '''
        This function is called right before calling each test function.

        It widens the pandas display options and points ``sys.stdout``
        at the real stdout. Every change is registered with
        ``addCleanup`` so the previous global state is restored once the
        test has finished and does not leak into other tests.
        '''
        display_options = (
            ('display.expand_frame_repr', False),
            ('display.max_columns', 999),
        )
        for option, value in display_options:
            # Remember the current value before overriding it.
            self.addCleanup(pd.set_option, option, pd.get_option(option))
            pd.set_option(option, value)

        # Allow print statements to show for debugging
        self.addCleanup(setattr, sys, 'stdout', sys.stdout)
        sys.stdout = sys.__stdout__

    def _assert_datasets_include(self, raw_json, expected_datasets):
        '''
        Decode ``raw_json`` (the JSON text returned by ``get_datasets``)
        and assert that every dict in ``expected_datasets`` is one of
        the decoded entries.

        Each expected dataset is checked in its own ``subTest`` so a
        single run reports all missing datasets, not just the first.
        '''
        datasets = json.loads(raw_json)
        for dataset in expected_datasets:
            with self.subTest(dataset_id=dataset['dataset_id']):
                self.assertIn(dataset, datasets)

    @staticmethod
    def _flatten_rows(frame):
        '''
        Return all cell values of ``frame`` as one flat list, row after
        row. ``frame`` must offer ``.values.tolist()`` yielding a list
        of rows (presumably a pandas DataFrame - confirm against
        ``Database.peptide_lookup``).
        '''
        return list(itertools.chain.from_iterable(frame.values.tolist()))

    def test_get_datasources(self):
        '''get_data_sources() without arguments lists all six sources.'''
        datasources = self.database.get_data_sources()

        # Compares the collections to make sure they have the same
        # elements regardless of their order.
        self.assertCountEqual(datasources, self._ALL_DATA_SOURCES)

    def test_get_gene_datasources(self):
        '''get_data_sources('gene') lists the same six sources.'''
        datasources = self.database.get_data_sources('gene')

        # Order-insensitive comparison.
        self.assertCountEqual(datasources, self._ALL_DATA_SOURCES)

    def test_get_transcript_datasources(self):
        '''get_data_sources('transcript') lists four of the sources.'''
        datasources = self.database.get_data_sources('transcript')

        # Order-insensitive comparison.
        self.assertCountEqual(datasources, self._TRANSCRIPT_DATA_SOURCES)

    def test_get_datasets(self):
        '''get_datasets() contains one known dataset from each source.'''
        self._assert_datasets_include(
            self.database.get_datasets(),
            self._GENE_DATASETS_SUBSAMPLE,
        )

    def test_get_datasets_given_source(self):
        '''get_datasets('Abelin', None) returns exactly one dataset.'''
        datasets = json.loads(self.database.get_datasets('Abelin', None))

        expected_datasets = [
            {'dataset_id': 2, 'n_samples': None, 'source': 'Abelin', 'title': 'Abelin: B721.221'},
        ]
        self.assertEqual(datasets, expected_datasets)

    def test_get_gene_datasets(self):
        '''get_datasets(None, 'gene') contains one known dataset per source.'''
        self._assert_datasets_include(
            self.database.get_datasets(None, 'gene'),
            self._GENE_DATASETS_SUBSAMPLE,
        )

    def test_get_gene_datasets_given_source(self):
        '''get_datasets('TCGA', 'gene') contains three known TCGA datasets.'''
        expected_datasets_subsample = [
            {'dataset_id': 8, 'n_samples': 1211.0, 'source': 'TCGA', 'title': 'TCGA: BRCA'},
            {'dataset_id': 18, 'n_samples': 421.0, 'source': 'TCGA', 'title': 'TCGA: LIHC'},
            {'dataset_id': 31, 'n_samples': 137.0, 'source': 'TCGA', 'title': 'TCGA: TGCT'},
        ]
        self._assert_datasets_include(
            self.database.get_datasets('TCGA', 'gene'),
            expected_datasets_subsample,
        )

    def test_get_transcript_datasets(self):
        '''get_datasets(None, 'transcript') contains one known dataset per source.'''
        expected_datasets = [
            {'dataset_id': 38, 'n_samples': 3326.0, 'source': 'GTEx', 'title': 'GTEx: Brain'},
            {'dataset_id': 70, 'n_samples': None, 'source': 'HPA', 'title': 'HPA: angular gyrus'},
            {'dataset_id': 400, 'n_samples': None, 'source': 'CCLE', 'title': 'CCLE: LN340_CENTRAL_NERVOUS_SYSTEM'},
            {'dataset_id': 1752, 'n_samples': None, 'source': 'TCGA', 'title': 'TCGA: PRAD - RSEM'},
        ]
        self._assert_datasets_include(
            self.database.get_datasets(None, 'transcript'),
            expected_datasets,
        )

    def test_get_transcript_datasets_given_source(self):
        '''get_datasets('TCGA', 'transcript') contains three known TCGA datasets.'''
        expected_datasets_subsample = [
            {'dataset_id': 1738, 'n_samples': None, 'source': 'TCGA', 'title': 'TCGA: ESCA - RSEM'},
            {'dataset_id': 1751, 'n_samples': None, 'source': 'TCGA', 'title': 'TCGA: PCPG - RSEM'},
            {'dataset_id': 1779, 'n_samples': None, 'source': 'TCGA', 'title': 'TCGA: LUAD - Kallisto'},
        ]
        self._assert_datasets_include(
            self.database.get_datasets('TCGA', 'transcript'),
            expected_datasets_subsample,
        )

    def test_gene_table_lookup(self):
        '''
        peptide_lookup(..., 'gene') returns the expected 'collapsed' and
        'expanded' tables, compared cell by cell in row order.
        '''
        peptide_dictionary = self.database.peptide_lookup(
            list(self._LOOKUP_PEPTIDES), self._LOOKUP_DATASET_ID, 'gene'
        )

        collapsed_result = self._flatten_rows(peptide_dictionary['collapsed'])
        expanded_result = self._flatten_rows(peptide_dictionary['expanded'])

        expected_collapsed_result = [
            1353, 'ALLVRYTKKVPQVS', 'ENSG00000163631', 'ALB',
            9, 7, 0.778, 1.0, 0.08, 0.08, 0.062,
        ]
        expected_expanded_result = [
            1353, 'ALLVRYTKKVPQVS',
            0.08, 0.08, 0.08, 0.062, 0.062, 0.062, 0.08, 0.08, 0.08,
            'ALB', 'ENSG00000163631',
            '0.08', '0.080', '0.062', '9', '7', '0.778', '1.000',
        ]

        self.assertListEqual(collapsed_result, expected_collapsed_result)
        self.assertListEqual(expanded_result, expected_expanded_result)

    def test_transcript_table_lookup(self):
        '''
        peptide_lookup(..., 'transcript') returns the expected
        'collapsed' table (one row per transcript) and 'expanded' table
        (a single row), compared cell by cell in row order.
        '''
        peptide_dictionary = self.database.peptide_lookup(
            list(self._LOOKUP_PEPTIDES), self._LOOKUP_DATASET_ID, 'transcript'
        )

        collapsed_result = self._flatten_rows(peptide_dictionary['collapsed'])
        expanded_result = self._flatten_rows(peptide_dictionary['expanded'])

        # Per expected 'collapsed' row: protein id, transcript id and the
        # number that appears in both trailing numeric columns.
        transcript_rows = [
            ('ENSP00000295897', 'ENST00000295897', 0.0),
            ('ENSP00000384695', 'ENST00000401494', 0.0),
            ('ENSP00000401820', 'ENST00000415165', 0.03),
            ('ENSP00000421027', 'ENST00000503124', 0.0),
            ('ENSP00000422784', 'ENST00000509063', 0.0),
            ('ENSP00000426179', 'ENST00000511370', 0.0),
            ('ENSP00000480485', 'ENST00000621628', 0.0),
        ]

        # The rows only differ in the three fields listed above, so the
        # flat expectation is generated instead of spelled out.
        expected_collapsed_result = []
        for protein_id, transcript_id, value in transcript_rows:
            expected_collapsed_result.extend([
                1353, 'ALLVRYTKKVPQVS', 'ENSG00000163631',
                protein_id, transcript_id, 'ALB', 1, value, value,
            ])

        expected_expanded_result = [
            1353, 'ALLVRYTKKVPQVS',
            0.03, 0.0, 0.03, 0.03, 0.0, 0.03, 1, 7,
            'ALB', 'ENSG00000163631',
            # Transcript ids, then protein ids, in the same order as the
            # 'collapsed' rows.
            ';'.join(transcript_id for _, transcript_id, _ in transcript_rows),
            ';'.join(protein_id for protein_id, _, _ in transcript_rows),
            '1;1;1;1;1;1;1',
            '0;0;0.03;0;0;0;0',
            '0.000;0.000;0.030;0.000;0.000;0.000;0.000',
        ]

        self.assertListEqual(collapsed_result, expected_collapsed_result)
        self.assertListEqual(expanded_result, expected_expanded_result)
        

# Script entry point: when this file is executed directly, hand control to
# the unittest runner so it discovers and runs the tests of this module.
if __name__ == "__main__":
    unittest.main(module=__name__)