import sqlalchemy as db
from sqlalchemy.orm import sessionmaker
import json
import numpy as np
import pandas as pd
import sqlite3


class SQLiteDatabase:
	'''
	Query helper for a PepX SQLite database file.

	Values supplied by callers (data source, dataset id, peptides) are always
	passed to SQLite as bound parameters. Table and column names cannot be
	bound, so they are validated by `_checked_identifier` before being placed
	in the SQL text.
	'''

	# Number of peptides bound per statement. Older SQLite builds accept at
	# most 999 bound variables per statement, so long peptide lists are
	# queried in batches of this size (plus one variable for the dataset id).
	_MAX_PEPTIDES_PER_QUERY = 500

	def __init__(self, path) :
		'''
		Open a connection to the SQLite database at `path`.

		A connection failure is reported on stdout and not re-raised; in that
		case `_db_connection` keeps its placeholder value and later queries
		will fail.
		'''
		self._db_path = path
		self._db_connection = ''

		try:
			self._db_connection = sqlite3.connect(self._db_path)
			print("Successfully connected to the PepX sqlite database.")
		except sqlite3.Error as error:
			print("Error while connecting to sqlite", error)

	@staticmethod
	def _checked_identifier(name):
		'''
		Return `name` if it is safe to splice into SQL as a table/column name.

		Raises ValueError for anything that is not a plain identifier
		(letters, digits and underscores, not starting with a digit).
		'''
		name = str(name)
		if not name.isidentifier():
			raise ValueError('Invalid SQL identifier: %r' % (name,))
		return name

	def _fetch_all(self, query, parameters=()):
		'''
		Run `query` with the bound `parameters` and return all result rows.
		'''
		cursor = self._db_connection.cursor()
		try:
			cursor.execute(query, parameters)
			return cursor.fetchall()
		finally:
			# Release the cursor even when the statement fails.
			cursor.close()

	def get_data_sources(self, quantification=None):
		'''
		Return the sorted list of distinct `source` values in expression_dataset.

		When `quantification` is given, only rows whose
		`<quantification>_level_data` column is true are considered.
		Raises ValueError if that column name is not a plain identifier.
		'''
		query = 'select distinct source from expression_dataset'

		if quantification:
			flag_column = self._checked_identifier(quantification + '_level_data')
			query += ' where %s is True' % (flag_column)

		rows = self._fetch_all(query)
		return sorted([row[0] for row in rows])

	def get_datasets(self, data_source=None, quantification=None):
		'''
		Return the matching datasets as a JSON string.

		The result is a JSON array of objects with the keys dataset_id,
		source, title and n_samples. `data_source` filters on the source
		column, `quantification` on the `<quantification>_level_data` flag
		column; both filters are optional and are combined with AND.
		Raises ValueError if the flag column name is not a plain identifier.
		'''
		conditions = []
		parameters = []

		if data_source:
			conditions.append('source == ?')
			# Bound as text, matching the quoted literal the query used before.
			parameters.append(str(data_source))
		if quantification:
			flag_column = self._checked_identifier(quantification + '_level_data')
			conditions.append('%s is True' % (flag_column))

		query = 'select dataset_id, source, title, n_samples from expression_dataset'
		if conditions:
			query += ' where ' + ' and '.join(conditions)

		datasets = self._fetch_all(query, parameters)

		datasets_df = pd.DataFrame(datasets, columns=['dataset_id','source','title','n_samples'])
		# Replace 'NaN' values as 'None' so they serialize as JSON null
		datasets_df = datasets_df.replace({np.nan:None})
		ds_list = datasets_df.to_dict('records')
		result = json.dumps(ds_list, sort_keys=True)

		return result

	def get_peptide_results_df(self, table, dataset_id, peptides):
		'''
		Return the rows of `table` for one dataset and a set of peptides.

		The DataFrame columns are taken from `pragma table_info`. Peptides
		are de-duplicated and queried in batches of
		`_MAX_PEPTIDES_PER_QUERY`; when more than one batch is needed, rows
		are returned batch by batch. An empty `peptides` still runs one
		statement and yields an empty DataFrame.
		Raises ValueError if `table` is not a plain identifier.
		'''
		table = self._checked_identifier(table)
		columns = [info[1] for info in self._fetch_all('pragma table_info(%s);' % (table))]

		# De-duplicate so a peptide split across batches cannot repeat rows.
		unique_peptides = list(dict.fromkeys(peptides))
		batch_size = self._MAX_PEPTIDES_PER_QUERY
		rows = []

		# max(..., 1) keeps a single statement for empty input, so a missing
		# table is still reported by SQLite instead of being silently ignored.
		for start in range(0, max(len(unique_peptides), 1), batch_size):
			batch = unique_peptides[start:start + batch_size]
			placeholders = ','.join('?' * len(batch))
			query = 'select * from %s where dataset_id == ? and peptide in (%s);' % (table, placeholders)
			# dataset_id is bound as text, matching the quoted literal used before.
			rows.extend(self._fetch_all(query, [str(dataset_id)] + batch))

		return pd.DataFrame(rows, columns=columns)

	def peptide_lookup(self, peptides, dataset_id, quantification):
		'''
		Look up `peptides` in both peptide TPM tables for `quantification`.

		Returns a dict with the DataFrame from `peptide_<quantification>_tpms`
		under 'collapsed' and the one from
		`peptide_<quantification>_tpms_collapsed` under 'expanded'.
		'''
		# NOTE(review): the key names look swapped relative to the table
		# names; kept as-is because callers may rely on it - confirm.
		collapsed_gene_tpms_df = self.get_peptide_results_df('peptide_' + quantification + '_tpms', dataset_id, peptides)
		expanded_df = self.get_peptide_results_df('peptide_' + quantification + '_tpms_collapsed', dataset_id, peptides)
		return({'collapsed': collapsed_gene_tpms_df, 'expanded': expanded_df})



class Database:
	'''
	Query helper for the PepX PostgreSQL database, built on SQLAlchemy.

	Tables are reflected from the live database on first use and cached in
	`self.metadata`.
	'''

	def __init__( self, password, user='afrentzen', port=5432, database='pepX-prod-20231018'):
		'''
		Create the engine, open a connection and start an ORM session.
		'''
		# NOTE(review): user and password are concatenated into the URL
		# verbatim; credentials containing URL-reserved characters presumably
		# need escaping by the caller - confirm.
		self.engine = db.create_engine('postgresql://' + user + ':' + password + '@psql.liai.org:' + str(port) + '/' + database,
			pool_size=10,
			max_overflow=2,
			pool_recycle=300,
			pool_pre_ping=True,
			pool_use_lifo=True)
		self.connection = self.engine.connect()
		self.metadata = db.MetaData()
		Session = sessionmaker(bind=self.engine)
		self.session = Session()

	def _reflect_table(self, name):
		'''
		Return the reflected Table object for `name`.
		'''
		# `autoload_with` alone triggers reflection; the separate
		# `autoload=True` flag is deprecated in SQLAlchemy 1.4 and no longer
		# accepted in 2.0.
		return db.Table(name, self.metadata, autoload_with=self.engine)

	def get_data_sources(self, quantification=None):
		'''
		Return the sorted list of distinct `source` values in expression_dataset.

		When `quantification` is given, only rows whose
		`<quantification>_level_data` column is true are considered.
		'''
		table = self._reflect_table('expression_dataset')

		query = self.session.query(table.columns.source.distinct())
		if quantification:
			# `== True` builds a SQL expression here, it is not a Python comparison.
			query = query.filter(table.columns[quantification + '_level_data'] == True)  # noqa: E712

		return sorted([row[0] for row in query.all()])

	def get_datasets(self, data_source=None, quantification=None):
		'''
		Return the matching datasets as a JSON string.

		The result is a JSON array of objects with the keys dataset_id,
		source, title and n_samples. `data_source` filters on the source
		column, `quantification` on the `<quantification>_level_data` flag
		column; both filters are optional and are combined with AND.
		'''
		table = self._reflect_table('expression_dataset')

		# Successive .where() calls are joined with AND.
		query = table.select()
		if data_source:
			query = query.where(table.columns['source'] == data_source)
		if quantification:
			quant = '%s_level_data' %(quantification)
			query = query.where(table.columns[quant] == True)  # noqa: E712

		dataset_df = pd.read_sql_query(query, self.connection)

		# Replace 'NaN' values as 'None' so they serialize as JSON null
		dataset_df = dataset_df.replace({np.nan:None})
		ds_list = dataset_df[['dataset_id','source','title','n_samples']].to_dict('records')
		result = json.dumps(ds_list, sort_keys=True)
		return result

	def get_peptide_results_df(self, table, dataset_id, peptides):
		'''
		Return the rows of `table` for `dataset_id` whose peptide is in `peptides`.
		'''
		table = self._reflect_table(table)
		query = table.select().where(db.and_(table.columns['dataset_id']  == dataset_id, table.columns['peptide'].in_(peptides)))
		protein_counts_df = pd.read_sql_query(query, self.connection)

		return protein_counts_df

	def check_peptide_existance(self, table, peptides):
		'''
		Return the peptides that have no row in `table`.

		Input order and duplicates are preserved in the returned list.
		'''
		table = self._reflect_table(table)
		peptides = list(peptides)
		if not peptides:
			return []

		# One IN query instead of one round trip per peptide.
		# NOTE(review): membership is decided by exact Python string equality
		# with the values the database returns - confirm the peptide column
		# uses an ordinary case-sensitive text type.
		peptide_column = table.columns['peptide']
		found_rows = self.session.query(peptide_column).filter(peptide_column.in_(peptides)).distinct()
		found = {row[0] for row in found_rows}

		return [peptide for peptide in peptides if peptide not in found]

	def peptide_lookup(self, peptides, dataset_id, quantification):
		'''
		Look up `peptides` in both peptide TPM tables for `quantification`.

		Returns a dict with the DataFrame from `peptide_<quantification>_tpms`
		under 'collapsed', the one from
		`peptide_<quantification>_tpms_collapsed` under 'expanded', and the
		peptides absent from `peptide_<quantification>_tpms` under 'unmapped'.
		'''
		# NOTE(review): the 'collapsed'/'expanded' key names look swapped
		# relative to the table names; kept as-is because callers may rely
		# on it - confirm.
		unmapped_peptides = self.check_peptide_existance('peptide_' + quantification + '_tpms', peptides)
		collapsed_gene_tpms_df = self.get_peptide_results_df('peptide_' + quantification + '_tpms', dataset_id, peptides)
		expanded_df = self.get_peptide_results_df('peptide_' + quantification + '_tpms_collapsed', dataset_id, peptides)
		return({'collapsed': collapsed_gene_tpms_df, 'expanded': expanded_df, 'unmapped': unmapped_peptides})
