# Add validation logic here. The functions in this file are meant to be used together with the argument parser.
# Ideally, all (or at least most) arguments should have a validation function defined here.
# These functions can then be passed to add_argument() through its 'type' parameter.
# ex) self.parser_preprocess.add_argument(
#           "--inputs-dir",
#           dest="preprocess_inputs_dir",
#           type=validators.validate_directory)
# ====================================================================================
# Import all core validators (DO NOT MODIFY)
# import core.set_pythonpath as _set_pythonpath  # side-effect import configures PYTHONPATH
from core.core_validators import (
    get_dependencies_from_paths,
    create_directory_structure_for_dependencies,
    validate_file,
    validate_directory,
    validate_directory_given_filename,
    validate_preprocess_dir
)
from allele_validator import AlleleValidator
from enum import Enum, auto
from pathlib import Path
from io import StringIO
import csv
import pandas as pd


# ---------- Stage 1: MHC Class ----------
class MHCClass(Enum):
    """MHC class label (stage 1 of input classification)."""

    # Values are spelled out explicitly; they match what auto() assigns.
    MHCI = 1
    MHCII = 2
    UNKNOWN = 3


# ---------- Stage 2: Top-level Input Category ----------
class InputCategory(Enum):
    """Top-level kind of input payload (stage 2 of input classification)."""

    # Values are spelled out explicitly; they match what auto() assigns.
    NEOEPITOPES_STRING = 1
    MHC_SEQUENCE_TABLE = 2
    PEPTIDE_SEQUENCE_TABLE = 3
    BINDING_RESULT_URI = 4
    UNKNOWN = 5


# ---------- Stage 3: Atomic Table Type ----------
class AtomicTable(Enum):
    """Concrete table layout (stage 3 of input classification)."""

    # Values are spelled out explicitly; they match what auto() assigns.
    SEQUENCE_TABLE = 1        # ordinary sequence table
    SEQUENCE_WITH_MUTPOS = 2  # sequence table with mutation position
    MUT_VS_REF_TABLE = 3      # mutation vs reference peptide table
    UNKNOWN = 4


# ---------- InputManager ----------
class InputManager:
    """
    Classify an input payload in three stages.

    On construction the instance records:
      * ``has_mhci`` / ``has_mhcii`` - whether the "class_i" / "class_ii" keys are present,
      * ``category`` - the top-level InputCategory derived from the payload keys,
      * ``atomic_table`` - the AtomicTable type derived from the category and table header.
    """

    def __init__(self, input_data: dict):
        self.data = input_data
        # The stages run in order; stage 3 reads self.category set by stage 2.
        self.has_mhci, self.has_mhcii = self.detect_mhc_presence()
        self.category = self.detect_category()
        self.atomic_table = self.detect_atomic_table()

    # ---------- Stage 1 ----------
    def detect_mhc_presence(self):
        """Return ``(has_class_i, has_class_ii)`` based on the payload keys."""
        return "class_i" in self.data, "class_ii" in self.data

    @property
    def mhc_classes(self):
        """List the MHC classes present; ``[MHCClass.UNKNOWN]`` when neither is."""
        flagged = (
            (MHCClass.MHCI, self.has_mhci),
            (MHCClass.MHCII, self.has_mhcii),
        )
        present = [mhc_class for mhc_class, found in flagged if found]
        return present or [MHCClass.UNKNOWN]

    # ---------- Stage 2 ----------
    def detect_category(self) -> InputCategory:
        """
        Map the payload keys to an InputCategory.

        Precedence: "mhc_result_uri", then "input_neoepitopes", then the
        sequence TSV (with or without an accompanying peptide TSV).
        """
        payload = self.data
        if "mhc_result_uri" in payload:
            return InputCategory.BINDING_RESULT_URI
        if "input_neoepitopes" in payload:
            return InputCategory.NEOEPITOPES_STRING
        if "mhc_sequence_tsv" not in payload:
            # A peptide TSV on its own is not a recognised category.
            return InputCategory.UNKNOWN
        if "mhc_peptide_tsv" in payload:
            return InputCategory.PEPTIDE_SEQUENCE_TABLE
        return InputCategory.MHC_SEQUENCE_TABLE

    # ---------- Stage 3 ----------
    def detect_atomic_table(self) -> AtomicTable:
        """
        Derive the AtomicTable type for the already-detected category.

        Binding-result URIs and unknown categories yield AtomicTable.UNKNOWN
        (a binding result is convertible later); every other category is
        classified from the header of its sequence content.
        """
        if self.category == InputCategory.PEPTIDE_SEQUENCE_TABLE:
            source = self.data.get("mhc_sequence_tsv")
        elif self.category in (InputCategory.NEOEPITOPES_STRING, InputCategory.MHC_SEQUENCE_TABLE):
            # Prefer the sequence TSV; fall back to the neoepitopes value.
            source = self.data.get("mhc_sequence_tsv") or self.data.get("input_neoepitopes")
        else:
            return AtomicTable.UNKNOWN
        return self._detect_sequence_table_type(source)

    # ---------- Helper ----------
    def _detect_sequence_table_type(self, tsv_path_or_content) -> AtomicTable:
        """
        Classify a TSV by its header row.

        Args:
            tsv_path_or_content: Either TSV text (a str containing a newline or
                a tab) or a str/Path pointing at an existing file.

        Returns:
            MUT_VS_REF_TABLE when the header has both a mutant-peptide-like and
            a reference-peptide-like column; SEQUENCE_WITH_MUTPOS or
            SEQUENCE_TABLE when a "sequence"/"peptide" column exists (depending
            on whether a non-peptide "mut" column is present); UNKNOWN otherwise.

        Raises:
            ValueError: If the argument is neither TSV text nor an existing file.
        """
        candidate = tsv_path_or_content
        is_inline_text = isinstance(candidate, str) and ("\n" in candidate or "\t" in candidate)

        if is_inline_text:
            text = candidate
        elif isinstance(candidate, (str, Path)) and Path(candidate).is_file():
            with open(candidate, "r") as handle:
                text = handle.read()
        else:
            raise ValueError("Invalid TSV input: expected a file path or TSV-formatted string.")

        # Only the first row (the header) is inspected.
        header_row = next(csv.reader(StringIO(text), delimiter="\t"), [])
        columns = [name.strip().lower() for name in header_row]

        mut_peptide_found = any("mut" in col and "pep" in col for col in columns)
        ref_peptide_found = any("ref" in col and "pep" in col for col in columns)
        if mut_peptide_found and ref_peptide_found:
            return AtomicTable.MUT_VS_REF_TABLE

        if "sequence" not in columns and "peptide" not in columns:
            return AtomicTable.UNKNOWN

        # A "mut" column that is not a peptide column is taken as a mutation position.
        for col in columns:
            if "mut" in col and "pep" not in col:
                return AtomicTable.SEQUENCE_WITH_MUTPOS
        return AtomicTable.SEQUENCE_TABLE

    # ---------- Convenience ----------
    def describe(self):
        """Return the three classification results as a dict of enum member names."""
        class_names = [mhc_class.name for mhc_class in self.mhc_classes]
        return {
            "mhc_classes": class_names,
            "category": self.category.name,
            "atomic_table": self.atomic_table.name,
        }

def validate_alleles(alleles: str, class_type: MHCClass, method: str):
    """
    Split alleles into those accepted and those rejected by AlleleValidator.

    Args:
        alleles: Alleles to check. NOTE(review): annotated as ``str``, but the
            value is paired element-wise with the validator's results, so a
            sequence of allele names is presumably expected - confirm with callers.
        class_type: MHCClass.MHCI or MHCClass.MHCII; selects the 'mhci' or
            'mhcii' tools group passed to the validator.
        method: Method name forwarded to AlleleValidator.validate_alleles.

    Returns:
        Tuple ``(valid_alleles, invalid_alleles)``; both lists keep input order.

    Raises:
        ValueError: If class_type is neither MHCClass.MHCI nor MHCClass.MHCII.
    """
    # Resolve the tools group first so an invalid class fails before the
    # validator is constructed.
    if class_type == MHCClass.MHCI:
        tools_group = 'mhci'
    elif class_type == MHCClass.MHCII:
        tools_group = 'mhcii'
    else:
        raise ValueError(f"Invalid class type: {class_type}")

    av = AlleleValidator()
    vals = av.validate_alleles(alleles, method=method, tools_group=tools_group)

    # Partition in a single pass so the result is consumed only once.
    valid_alleles = []
    invalid_alleles = []
    for allele, is_valid in zip(alleles, vals):
        if is_valid:
            valid_alleles.append(allele)
        else:
            invalid_alleles.append(allele)

    return valid_alleles, invalid_alleles


def validate_sequence_table(sequence_table_df: pd.DataFrame, peptide_length_range: list) -> pd.DataFrame:
    """
    Validate the sequence table and filter out sequences that are too short.

    Args:
        sequence_table_df: DataFrame containing the sequence table
        peptide_length_range: List containing [min_length, max_length] for peptides;
            only the first element (min_length) is used here

    Returns:
        The input DataFrame itself when it has no 'sequence' column or when no
        sequence is shorter than min_length; otherwise a filtered DataFrame
        with the too-short rows removed. A warning listing the discarded
        sequences is printed to stdout in that case.
    """
    # Function-scope import keeps the file's top-level import block untouched.
    import logging

    # Diagnostic dump goes to the debug log rather than stdout; the table is
    # only rendered if debug logging is enabled.
    logging.getLogger(__name__).debug("sequence_table_df:\n%s", sequence_table_df)

    if 'sequence' not in sequence_table_df.columns:
        return sequence_table_df

    min_length = peptide_length_range[0]
    too_short_mask = sequence_table_df['sequence'].str.len() < min_length
    if not too_short_mask.any():
        return sequence_table_df

    short_sequences = sequence_table_df.loc[too_short_mask, 'sequence'].tolist()
    print(f"Warning: The following sequences are shorter than {min_length} amino acids and will be discarded:")
    for seq in short_sequences:
        print(f"  - {seq}")
    return sequence_table_df[~too_short_mask]


def _parse_mutation_positions(raw_positions: str) -> list:
    """Parse a comma-separated position string into ints, raising a descriptive ValueError on bad tokens."""
    try:
        return [int(token.strip()) for token in raw_positions.split(',')]
    except ValueError as err:
        raise ValueError(
            f"Invalid mutation position value '{raw_positions}': "
            "expected 'all' or comma-separated integers"
        ) from err


def _check_mutation_position_in_range(position: int, sequence, seq_len: int) -> None:
    """Raise ValueError unless ``1 <= position <= seq_len`` (positions are 1-based)."""
    if position < 1 or position > seq_len:
        raise ValueError(f"Mutation position {position} is out of range for sequence {sequence} (length {seq_len})")


def validate_mutation_position(sequence_table_df: pd.DataFrame, mut_pos_col: str, mut_pos: int) -> None:
    """
    Validate that mutation positions fall inside their sequences.

    Exactly one source of positions is used:
      * a truthy ``mut_pos`` is checked against every row's 'sequence';
      * otherwise a truthy ``mut_pos_col`` names a column whose per-row value is
        either 'all' (case-insensitive, row skipped) or a comma-separated list
        of integer positions.

    Args:
        sequence_table_df: DataFrame with a 'sequence' column
        mut_pos_col: Name of the column holding mutation positions (may be None/empty)
        mut_pos: Single 1-based mutation position applied to all rows (may be None;
            a falsy value such as 0 is treated as "not provided")

    Raises:
        ValueError: If neither mut_pos nor mut_pos_col is provided, if
            mut_pos_col is not a column of the table, if a position cannot be
            parsed as an integer, or if a position is outside ``1..len(sequence)``.
    """
    # Function-scope import keeps the file's top-level import block untouched.
    import logging

    # Diagnostics go to the debug log instead of stdout.
    logging.getLogger(__name__).debug(
        "sequence_table_df:\n%s\nmut_pos_col: %s\nmut_pos: %s",
        sequence_table_df.head(), mut_pos_col, mut_pos,
    )

    if mut_pos:
        # One fixed position must fit every sequence in the table.
        for _, row in sequence_table_df.iterrows():
            sequence = row['sequence']
            _check_mutation_position_in_range(mut_pos, sequence, len(sequence))
    elif mut_pos_col:
        if mut_pos_col not in sequence_table_df.columns:
            raise ValueError(f"Required column '{mut_pos_col}' not found in sequence table")

        for _, row in sequence_table_df.iterrows():
            sequence = row['sequence']
            seq_len = len(sequence)
            raw_positions = str(row[mut_pos_col])

            # "all" means every position is mutated; nothing to range-check.
            if raw_positions.lower() == 'all':
                continue

            for position in _parse_mutation_positions(raw_positions):
                _check_mutation_position_in_range(position, sequence, seq_len)
    else:
        raise ValueError("No mutation position provided")