import sys
import os
import json
import warnings
import argparse
import textwrap
from os.path import isfile, getsize
from pathlib import Path

# Retrieve the project root path. ``resolve()`` makes it absolute and follows
# symlinks, consistent with ``script_dir`` below (which uses ``realpath``).
curr_dir = Path(__file__).resolve()
proj_dir = curr_dir.parent.parent


def _extend_sys_path(*entries):
    '''
    Append each entry to ``sys.path`` as a string, skipping entries already present.

    The import system ignores non-string ``sys.path`` entries, so ``Path``
    objects are converted with ``str()`` before being added.
    '''
    for entry in map(str, entries):
        if entry not in sys.path:
            sys.path.append(entry)


_extend_sys_path(
    proj_dir,
    proj_dir / 'lib' / 'pepx-database-interface' / 'pepx_database_interface',
)

# load database_functions first
from database_functions import Database, SQLiteDatabase


_extend_sys_path(
    proj_dir,
    proj_dir / 'lib' / 'nxg-tools' / 'nxg_cli',
)

# loading other package
from ArgumentParserBuilder import NGArgumentParserBuilder

# Directory containing this script (symlinks resolved); used to locate
# bundled data files such as example_commands.txt.
script_dir = os.path.dirname(os.path.realpath(__file__))

def get_usage_info():
    '''
    Return the contents of ``example_commands.txt`` as a single string.

    The file is looked up in ``script_dir`` (this script's directory), so the
    result does not depend on the current working directory.
    '''
    usage_path = os.path.join(script_dir, 'example_commands.txt')
    # ``with`` guarantees the file handle is closed once the text is read.
    with open(usage_path, 'r') as f:
        return f.read()

class ValidatePath(argparse.Action):
    '''
    argparse action that checks whether a path points to a SQLite database.

    It is only invoked when the user actually passes the flag. The value is
    stored on the namespace whether or not the check passes; a failed check
    only emits a warning (the process is deliberately not terminated).
    '''

    # First 16 bytes of every SQLite 3 database file.
    _SQLITE_MAGIC = b'SQLite format 3\x00'
    # Size in bytes of the SQLite database file header.
    _SQLITE_HEADER_SIZE = 100

    def __call__(self, parser, namespace, values, option_string=None):
        if self.isSQLite3(values):
            print("Path to SQLite DB has been validated.")
        else:
            warnings.warn('Failed to connect to PepX SQLite Database. Please verify the path to the SQLite DB.', stacklevel=2)
            # Intentionally not exiting here (kept lenient for the sake of
            # testing); the unvalidated path is still stored below.

        setattr(namespace, self.dest, values)

    def isSQLite3(self, filename):
        '''
        Return True if ``filename`` is an existing file at least as large as a
        SQLite header and starting with the SQLite 3 magic string, else False.
        '''
        if not isfile(filename):
            return False

        if getsize(filename) < self._SQLITE_HEADER_SIZE:
            return False

        with open(filename, 'rb') as fd:
            header = fd.read(len(self._SQLITE_MAGIC))

        # Compare raw bytes: decoding arbitrary binary data as UTF-8 could
        # raise UnicodeDecodeError instead of returning False.
        return header == self._SQLITE_MAGIC


class PepXArgumentParser(NGArgumentParserBuilder):
    def __init__(self):
        '''
        Build the pepX argument parser.

        Sets the program metadata (usage, description, name, version) and
        registers the command-line arguments: input sequences/file, the
        split/aggregate options, the database path (checked by ValidatePath)
        and the pepX-specific options defined on this class.
        '''
        super().__init__()

        # Custom usage text: the example commands stored next to this script.
        # Read once; nothing below changes it.
        self.usage = textwrap.dedent(get_usage_info())

        # Program description
        self.description = textwrap.dedent(
        '''\
            The Peptide eXpression annotator (pepX) takes a peptide as input,
            identifies from which proteins the peptide can be derived, and
            returns an estimate of the expression level of those source proteins
            from selected public databases.
        '''
        )

        self.name = "run_pepx.py"
        self.version = '1.0'
        # Database handle; populated by get_database().
        self._pepx_database = ''


        # Add arguments
        self.sequences()

        self.input_file(
            help='A TXT or CSV file containing list of protein sequences.'
        )

        # Split/Aggregate arguments
        self.json_filename()
        self.split()
        self.split_dir()
        self.split_input_dir()
        self.aggregate()
        self.aggregate_input_dir()
        self.aggregate_result_dir()
        self.aggregate_output_format()
        self.job_description_file()
        self.assume_valid()

        # Set PepX Database
        # dest -> db_path by default.
        self.database_path(
            single_dash_option='p',
            double_dash_option='database_path',
            help='Sets PepX Database.',
            action=ValidatePath,
        )

        # Set other arguments
        self.quantitation_level()
        self.datasource()
        self.dataset_id()
        self.list_datasets()
        self.pg_summary()
        self.output_file(
            help='Path to the output file containing the result in CSV (default) format.'
        )
        # NOTE(review): the default is a one-element list, presumably to match
        # an nargs=1 option in the base builder — confirm.
        self.output_format(
            choices = ['json', 'tsv', 'csv'],
            default = ['csv'],
            help = 'Specify the output file format.'
        )

        
    ###########################################################
    # Optional Argument specifically for PepX
    ###########################################################
    def quantitation_level(self) -> None:
        '''
        Register -q/--quant_level.

        The value is lower-cased before validation, must be 'gene' or
        'transcript', and (because nargs=1) is stored as a one-element list.
        '''
        supported_levels = ['gene', 'transcript']
        self.parser.add_argument(
            '-q', '--quant_level',
            dest='quant_level',
            required=False,
            nargs=1,
            type=str.lower,
            choices=supported_levels,
            help='Quantitation level.',
        )


    def datasource(self) -> None:
        '''
        Register the -s/--datasource option.

        The valid datasources depend on the quantitation level, which has not
        been parsed yet when the parser is built, so the raw command line
        (``sys.argv``) is inspected for the -q/--quant_level value:

        * 'gene'       -> 'Abelin', 'CCLE', 'GTEX', 'HELA', 'TCGA'
        * 'transcript' -> 'CCLE', 'GTEX', 'HPA', 'TCGA'
        * otherwise    -> any datasource valid for either level
        '''
        gene_sources = ['Abelin', 'CCLE', 'GTEX', 'HELA', 'TCGA']
        transcript_sources = ['CCLE', 'GTEX', 'HPA', 'TCGA']

        quant_level = self._quant_level_from_argv(sys.argv[1:])

        if quant_level == 'gene':
            available_datasources = gene_sources
            help_text = '''\
            available datasource when quantitation level is set to 'gene'
            : 'Abelin', 'CCLE', 'GTEX', 'HELA', 'TCGA'
            '''
        elif quant_level == 'transcript':
            available_datasources = transcript_sources
            help_text = '''\
            available datasource when quantitation level is set to 'transcript'
            : 'CCLE', 'GTEX', 'HPA','TCGA'
            '''
        else:
            # No usable level on the command line. An empty ``choices`` list
            # would make argparse reject every value, so accept the union of
            # both levels (order-preserving, de-duplicated).
            available_datasources = list(dict.fromkeys(gene_sources + transcript_sources))
            help_text =  '''\
            available datasource when quantitation level is set to 'gene'
            : 'Abelin', 'CCLE', 'GTEX', 'HELA', 'TCGA'\n
            available datasource when quantitation level is set to 'transcript'
            : 'CCLE', 'GTEX', 'HPA','TCGA'
            '''

        self.parser.add_argument(
            '-s',
            '--datasource', 
            dest = 'datasource', 
            required = False,
            nargs = 1,
            type = str,
            choices=available_datasources,
            help = textwrap.dedent(help_text))

    @staticmethod
    def _quant_level_from_argv(argv):
        '''
        Return the lower-cased -q/--quant_level value found in ``argv``, or None.

        Recognises "-q gene", "-qgene", "-q=gene", "--quant_level gene" and
        "--quant_level=gene"; the last occurrence wins. Scanning stops at "--".
        Lower-casing mirrors the ``type=str.lower`` conversion that
        quantitation_level() applies to the parsed value.
        '''
        level = None
        tokens = iter(argv)
        for token in tokens:
            if token == '--':
                break
            if token in ('-q', '--quant_level'):
                level = next(tokens, None)
            elif token.startswith(('--quant_level=', '-q=')):
                level = token.split('=', 1)[1]
            elif token.startswith('-q') and not token.startswith('--'):
                level = token[2:]
        return level.lower() if level is not None else None
        
        
    def dataset_id(self) -> None:
        '''
        Register -d/--dataset_id, the dataset to query (stored as a
        one-element list because nargs=1).

        The help text points users at -l/--list-datasets to discover valid IDs.
        '''
        self.parser.add_argument(
            '-d', 
            '--dataset_id',
            dest = 'dataset_id', 
            required = False,
            nargs = 1,
            type = str,
            # The listing flag is registered as '--list-datasets' (hyphen);
            # the help must name it exactly or the suggested command fails.
            help = textwrap.dedent('''
            use the following command to search for available dataset
            (NOTE: quantitation level and datasource name are required)
            : %(prog)s -q gene -s CCLE --list-datasets/-l
            '''))
        
    
    def list_datasets(self) -> None:
        '''
        Register -l/--list-datasets.

        Given without a value the option stores ``True`` (``const``); when
        omitted it stays ``False`` (``default``). Because ``nargs='?'``, a
        following non-option token is consumed as its (string) value.
        '''
        self.parser.add_argument(
            '-l',
            '--list-datasets', 
            dest = 'list_datasets', 
            required = False,
            nargs = '?',
            type = str,
            default = False,
            const = True,
            help = textwrap.dedent('''
            list out available datasets for given quantitation level and datasource name.
            '''))
        
    def pg_summary(self) -> None:
        '''
        Register -u/--pg-summary (help: shows the expanded pepX result).

        Stores ``True`` when given without a value, ``False`` when omitted;
        with ``nargs='?'`` an explicit value is kept as a string.
        '''
        summary_help = textwrap.dedent('''
            It will show expanded result of pepx.
            ''')
        self.parser.add_argument(
            '-u', '--pg-summary',
            dest='pg_summary',
            required=False,
            nargs='?',
            type=str,
            const=True,
            default=False,
            help=summary_help,
        )
        
    
    ###########################################################
    # Helper Functions
    ###########################################################
    def get_sequences(self, namespace: argparse.Namespace) -> list[str]:
        '''
        Return the sequences to process.

        If both inline sequences and an input file are provided, the inline
        sequences take precedence over the input file sequences. Otherwise the
        input file is read one sequence per line; for .csv/.tsv files
        (case-insensitive) the first line is treated as a header and skipped.

        Calls ``self.parser.error`` (which exits) when neither is provided.
        '''
        sequences = namespace.sequences
        if sequences:
            return sequences

        # argparse stores every registered option on the namespace (None when
        # omitted), so test the value rather than the attribute's presence.
        input_file = getattr(namespace, 'input_file', None)
        file_path = getattr(input_file, 'name', None)
        if not file_path:
            raise self.parser.error(textwrap.dedent('''
            Either provide inline sequences or sequence file using the 'input_file' flag.
            '''))

        # Parse CSV or text file
        with open(file_path, 'r') as f:
            sequences = f.read().splitlines()

        # Exclude header for CSV/TSV file
        if os.path.splitext(file_path)[1].lower() in ('.csv', '.tsv'):
            sequences = sequences[1:]

        return sequences
    
    def get_database(self, namespace: argparse.Namespace) -> object:
        '''
        Open the pepX database and cache the handle on ``self._pepx_database``.

        Uses the SQLite file at ``namespace.db_path`` when it is set. Otherwise
        it warns and connects to the PepX-prod database using the credentials
        in the PEPX_DB_USER / PEPX_DB_PWD environment variables (empty strings
        when unset). If that connection fails, an error is printed to stderr
        and the process exits with status 1.

        Returns the database handle.
        '''
        db_path = namespace.db_path

        if db_path:
            self._pepx_database = SQLiteDatabase(path=db_path)
            return self._pepx_database

        warnings.warn("Database path was not set. By default, PepX-prod database will be used.", stacklevel=2)

        user = os.getenv('PEPX_DB_USER', '')
        password = os.getenv('PEPX_DB_PWD', '')
        try:
            self._pepx_database = Database(password=password, user=user)
        except Exception as err:
            # Catch only ordinary errors (not KeyboardInterrupt/SystemExit) and
            # exit non-zero so calling scripts can detect the failure.
            print('Failed to connect PepX-prod database. Please check to make sure the username and password is set correctly.', file=sys.stderr)
            print(f'Reason: {err}', file=sys.stderr)
            sys.exit(1)

        print("Successfully connected to the database.")
        return self._pepx_database

    def get_dataset(self,
                data_source: str,
                quantification: str,
                dataset_id: str) -> dict:
        '''
        Return the dataset record whose ``dataset_id`` matches ``dataset_id``.

        The candidates are the JSON-decoded result of
        ``self._pepx_database.get_datasets`` for the given datasource and
        quantification; IDs are compared as strings. Calls
        ``self.parser.error`` (which exits) when no record matches.
        Requires get_database() to have been called first.
        '''
        valid_datasets = json.loads(
            self._pepx_database.get_datasets(data_source=data_source, quantification=quantification)
        )

        for valid_dataset in valid_datasets:
            if dataset_id == str(valid_dataset['dataset_id']):
                return valid_dataset

        raise self.parser.error(textwrap.dedent('''
            %s is not a valid dataset id.
            Please check the available dataset id using the --list-datasets/-l.
            ''' % (dataset_id,)))


    def get_available_datasets(self, 
                           data_source: str,
                           quantification: str) -> None:
        '''
        Print one "ID: <dataset_id> (<title>)" line for every dataset that
        ``self._pepx_database.get_datasets`` returns (as JSON) for the given
        datasource and quantification.
        '''
        raw_listing = self._pepx_database.get_datasets(data_source=data_source, quantification=quantification)
        for entry in json.loads(raw_listing):
            print(f"ID: {entry['dataset_id']} ({entry['title']})")
    
    def check_required_arguments(self, namespace: argparse.Namespace) -> None:
        '''
        This function checks flags to make sure all required flags are set.

        For example, when the namespace carries an ``output_file`` attribute,
        ``output_format`` must be truthy; otherwise ``self.parser.error`` is
        called (argparse's error() prints the message and exits).
        '''

        # if user specifies output file, then output format must be set as well.
        # NOTE(review): hasattr() is True for any option registered via
        # add_argument (argparse stores None when it is omitted), so this
        # branch presumably runs even without an output file. Since
        # __init__ gives output_format a default of ['csv'], the error looks
        # reachable only if the format is explicitly emptied. Confirm whether
        # getattr(namespace, 'output_file', None) was intended.
        if hasattr(namespace, 'output_file'):
            if not namespace.output_format:
                raise self.parser.error(textwrap.dedent('''
                Please specifiy the format of the output file: 'csv', 'tsv', 'json'.
                '''))
            