import unittest
import os
import sys
import pandas as pd
import json

# Make the project root importable no matter where the tests are launched from.
# CI/CD runs the tests from inside nxg_common, so the project root has to be
# put on sys.path explicitly for the `nxg_common` imports below to resolve.
# Layout walked here: <project root>/nxg_common/tests/<this file>
test_dir = os.path.split(os.path.abspath(__file__))[0]  # nxg_common/tests
nxg_common_dir = os.path.split(test_dir)[0]  # nxg_common
project_root = os.path.split(nxg_common_dir)[0]  # project root

# Prepend only when it is missing, so running from the project root (where it
# may already be on the path) and running from nxg_common both work.
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from nxg_common import set_python_env, api_results_json_to_df
from nxg_common.column_info import get_column_info

class TestPythonEnv(unittest.TestCase):
    """Unit tests for ``set_python_env``.

    Each test pins the list of shell command strings expected for one
    combination of the ``modules`` and ``virtualenv`` arguments.
    """

    # Placeholder virtualenv directory. It is created relative to the current
    # working directory because the expected command embeds this relative path.
    _VENV_DIR = 'test_venv'

    def _make_venv_dir(self):
        """Create the placeholder virtualenv directory and schedule its removal.

        The removal is registered with ``addCleanup`` so it also runs when the
        test's assertion fails. Without that, a failed test would leave the
        directory behind and the next ``os.mkdir`` would raise
        ``FileExistsError``.

        Returns:
            The relative path of the directory that was created.
        """
        os.mkdir(self._VENV_DIR)
        self.addCleanup(os.rmdir, self._VENV_DIR)
        return self._VENV_DIR

    def test_no_args(self):
        """No arguments yields an empty command list."""
        self.assertEqual(set_python_env(), [])

    def test_modules_only(self):
        """A comma-separated module string yields one ``module load`` per module."""
        self.assertEqual(
            set_python_env(modules="numpy,scipy"),
            ["module load numpy", "module load scipy"],
        )

    def test_virtualenv_only(self):
        """An existing virtualenv directory yields a single ``source`` command."""
        venv_dir = self._make_venv_dir()
        self.assertEqual(
            set_python_env(virtualenv=venv_dir),
            ["source test_venv/bin/activate"],
        )

    def test_modules_and_virtualenv(self):
        """Module loads are listed before the virtualenv activation."""
        venv_dir = self._make_venv_dir()
        self.assertEqual(
            set_python_env(modules="numpy", virtualenv=venv_dir),
            ["module load numpy", "source test_venv/bin/activate"],
        )


class TestApiResultsJsonToDf(unittest.TestCase):
    """Tests for ``api_results_json_to_df`` against a small JSON fixture on disk."""

    @classmethod
    def setUpClass(cls):
        """Write the mock API payload to a uniquely named temporary JSON file.

        A unique temp file is used instead of a fixed name in the current
        working directory, so the fixture cannot overwrite (and later delete)
        an unrelated file and concurrent test runs do not collide.
        """
        # Local import keeps this class self-contained with respect to the
        # file's top-level imports.
        import tempfile

        cls.test_data = {
            "data": {
                "results": [
                    {
                        "type": "table1",
                        "table_columns": [
                            {"name": "col1"},
                            {"name": "col2"}
                        ],
                        "table_data": [
                            [1, "A"],
                            [2, "B"]
                        ]
                    },
                    {
                        "type": "table2",
                        "table_columns": [
                            {"name": "colA"},
                            {"name": "colB"}
                        ],
                        "table_data": [
                            [10, "X"],
                            [20, "Y"]
                        ]
                    }
                ]
            }
        }
        fd, cls.test_json_file = tempfile.mkstemp(prefix="test_data_", suffix=".json")
        # os.fdopen takes ownership of the descriptor, so leaving the `with`
        # block closes it.
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(cls.test_data, f)

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary JSON fixture created in ``setUpClass``."""
        os.remove(cls.test_json_file)

    def test_api_results_json_to_df(self):
        """A requested table type present in the JSON is returned as a DataFrame."""
        table_types = ["table1"]

        result = api_results_json_to_df(self.test_json_file, table_types)

        # The result maps each requested table type to a DataFrame.
        self.assertIsInstance(result, dict)
        self.assertIn("table1", result)
        self.assertIsInstance(result["table1"], pd.DataFrame)

        # Column names come from "table_columns" and rows from "table_data".
        expected_df = pd.DataFrame(
            [[1, "A"], [2, "B"]],
            columns=["col1", "col2"]
        )
        pd.testing.assert_frame_equal(result["table1"], expected_df)

    def test_api_results_json_to_df_no_matching_tables(self):
        """Requesting only table types absent from the JSON yields an empty dict."""
        table_types = ["non_existent_table"]

        result = api_results_json_to_df(self.test_json_file, table_types)

        self.assertEqual(result, {})


class TestColumnInfo(unittest.TestCase):
    """Checks the ``display_name`` reported by ``get_column_info`` for selected columns."""

    def test_phbr_display_names(self):
        """PHBR columns keep the PHBR acronym in uppercase."""
        expected_names = (
            ("phbr.phbr", "PHBR score"),
            ("phbr.phbr-i", "PHBR-I"),
            ("phbr.phbr-ii", "PHBR-II"),
            # The mhci / mhcii variants share the PHBR-I / PHBR-II labels.
            ("phbr.phbr-mhci", "PHBR-I"),
            ("phbr.phbr-mhcii", "PHBR-II"),
        )
        for column_key, expected in expected_names:
            self.assertEqual(get_column_info(column_key)["display_name"], expected)

    def test_rsa_display_name(self):
        """The RSA column is labelled with the uppercase acronym."""
        self.assertEqual(get_column_info("discotope.rsa")["display_name"], "RSA")

    def test_plddt_display_name(self):
        """The pLDDTs column keeps its mixed-case spelling."""
        self.assertEqual(get_column_info("discotope.plddts")["display_name"], "pLDDTs")

    def test_hla_alleles_display_names(self):
        """HLA-A/B/C allele counts use lowercase 'number of' with an uppercase acronym."""
        expected_names = (
            ("phbr.#a", "number of HLA-A alleles"),
            ("phbr.#b", "number of HLA-B alleles"),
            ("phbr.#c", "number of HLA-C alleles"),
        )
        for column_key, expected in expected_names:
            self.assertEqual(get_column_info(column_key)["display_name"], expected)

    def test_other_display_names_are_lowercase(self):
        """Display names that are not acronyms are entirely lowercase."""
        expected_names = (
            ("pepmatch.peptide", "input sequence"),
            ("pepmatch.matched_sequence", "matched sequence"),
            ("pepmatch.protein_id", "protein id"),
            ("cluster.alignment", "alignment"),
            ("cluster.peptide", "peptide"),
            ("core.pdb_chain", "chain"),
            ("discotope.epitope", "epitope"),
            # Previously "immunogenicity Combined Score".
            ("immunogenicity.cd4episcore.combined_score", "immunogenicity combined score"),
        )
        for column_key, expected in expected_names:
            self.assertEqual(get_column_info(column_key)["display_name"], expected)

# Allow running this file directly as a script, in addition to via a test runner.
if __name__ == "__main__":
    unittest.main()