# Add any validation logic here. This module is meant to be used together with the argument parser.
# For example, all or most arguments should have a validation function defined here.
# These functions can then be passed to add_argument() through its 'type' parameter.
# ex) self.parser_preprocess.add_argument(
#           "--inputs-dir",
#           dest="preprocess_inputs_dir",
#           type=validators.validate_directory)
import argparse
import pandas as pd
from pathlib import Path
import sys
import os

# Get the absolute path to the project root directory (the parent of the
# directory containing this file) and put it at the front of sys.path.
# NOTE: the order matters here -- the path is inserted before `paths` is
# imported, presumably because `paths` lives in the project root (confirm).
PROJECT_ROOT = Path(os.path.abspath(os.path.dirname(__file__))).parent
sys.path.insert(0, str(PROJECT_ROOT))
from paths import APP_ROOT

# Add the allele-validator directory to Python path so that the
# `allele_validator` module imported on the next line can be found.
sys.path.insert(0, str(APP_ROOT / 'libs' / 'allele-validator'))
from allele_validator import AlleleValidator


def validate_file(path_str):
    """
    Convert ``path_str`` to a ``Path`` and make sure it is an existing file.

    Intended to be used as the ``type`` callable of ``add_argument()``.

    Args:
        path_str: Path given on the command line

    Returns:
        ``Path`` object built from ``path_str``

    Raises:
        argparse.ArgumentTypeError: If the path is not a regular file
    """
    candidate = Path(path_str)
    if candidate.is_file():
        return candidate
    raise argparse.ArgumentTypeError(f"'{path_str}' is not a valid file.")

def validate_directory(path_str):
    """
    Convert ``path_str`` to a ``Path`` and make sure it is an existing directory.

    Intended to be used as the ``type`` callable of ``add_argument()``.

    Args:
        path_str: Path given on the command line

    Returns:
        ``Path`` object built from ``path_str``

    Raises:
        argparse.ArgumentTypeError: If the path is not a directory
    """
    candidate = Path(path_str)
    if candidate.is_dir():
        return candidate
    raise argparse.ArgumentTypeError(f"'{path_str}' is not a valid directory.")

def validate_directory_given_filename(path_str):
    """
    Validate a file path by checking that its parent directory exists.

    The file itself is not required to exist; only ``Path(path_str).parent``
    is checked. Intended to be used as the ``type`` callable of
    ``add_argument()``.

    Args:
        path_str: File path given on the command line

    Returns:
        ``Path`` object built from ``path_str``

    Raises:
        argparse.ArgumentTypeError: If the parent directory of the path is
            not an existing directory
    """
    path = Path(path_str)
    parent_dir = path.parent

    if not parent_dir.is_dir():
        # Report the directory that was actually checked; the file path
        # itself was never expected to be a directory.
        raise argparse.ArgumentTypeError(
            f"'{parent_dir}' is not a valid directory (parent of '{path_str}')."
        )
    return path

# ADD ADDITIONAL VALIDATION LOGIC HERE
# ------------------------------------
def validate_alleles(alleles, class_type, method: str):
    """
    Split alleles into those accepted and those rejected by AlleleValidator.

    Args:
        alleles: Allele names to check, either as a list of names or as a
            single comma-separated string. A string is split on commas
            first, because the validation results are paired with the
            alleles element by element and a bare string would otherwise be
            paired character by character.
        class_type: Object whose ``.value`` is 'i' or 'ii' (presumably an
            Enum member -- confirm against the caller). 'i' selects the
            'mhci' tools group and 'ii' selects 'mhcii'.
        method: Method name forwarded to ``AlleleValidator.validate_alleles``

    Returns:
        Tuple ``(valid_alleles, invalid_alleles)`` of two lists, each keeping
        the input order

    Raises:
        ValueError: If ``class_type.value`` is neither 'i' nor 'ii'
    """
    # Resolve the tools group first so that an invalid class type fails
    # before the validator is constructed.
    if class_type.value == 'i':
        tools_group = 'mhci'
    elif class_type.value == 'ii':
        tools_group = 'mhcii'
    else:
        raise ValueError(f"Invalid class type: {class_type}")

    if isinstance(alleles, str):
        alleles = [name.strip() for name in alleles.split(',') if name.strip()]

    av = AlleleValidator()
    vals = av.validate_alleles(alleles, method=method, tools_group=tools_group)

    # Partition in a single pass so the validation results are consumed once.
    valid_alleles = []
    invalid_alleles = []
    for allele, is_valid in zip(alleles, vals):
        if is_valid:
            valid_alleles.append(allele)
        else:
            invalid_alleles.append(allele)

    return valid_alleles, invalid_alleles

# def validate_peptide_length_range(peptide_length_range: list, class_type: str):
#     # Check if first number is smaller than or equal to second number
#     if peptide_length_range[0] > peptide_length_range[1]:
#         raise ValueError(f"Minimum peptide length ({peptide_length_range[0]}) must be less than or equal to maximum peptide length ({peptide_length_range[1]})")

#     # Class I lengths should be between 8 and 11
#     if class_type.value == 'i':
#         if peptide_length_range[0] < 8:
#             raise ValueError(f"Minimum peptide length must be at least 8 for MHCI")
    
#     # Class II lengths should be 15
#     if class_type.value == 'ii':
#         if peptide_length_range[0] != 15 or peptide_length_range[1] != 15:
#             raise ValueError(f"For MHCII, peptide length range must be exactly [15, 15]")


def validate_sequence_table(sequence_table_df: pd.DataFrame, peptide_length_range: list) -> pd.DataFrame:
    """
    Drop rows whose 'sequence' value is shorter than the minimum peptide length.

    A warning listing every discarded sequence is printed to stdout. When the
    table has no 'sequence' column, or no sequence is too short, the input
    DataFrame is returned as is.

    Args:
        sequence_table_df: DataFrame containing the sequence table
        peptide_length_range: List containing [min_length, max_length] for
            peptides; only the first element is used

    Returns:
        DataFrame with sequences shorter than min_length removed
    """
    # Nothing to check without a 'sequence' column.
    if 'sequence' not in sequence_table_df.columns:
        return sequence_table_df

    shortest_allowed = peptide_length_range[0]
    sequence_lengths = sequence_table_df['sequence'].str.len()
    is_short = sequence_lengths < shortest_allowed

    if not is_short.any():
        return sequence_table_df

    discarded = sequence_table_df.loc[is_short, 'sequence'].tolist()
    print(f"Warning: The following sequences are shorter than {shortest_allowed} amino acids and will be discarded:")
    for short_seq in discarded:
        print(f"  - {short_seq}")

    return sequence_table_df[~is_short]

def validate_mhc2phbr_output(path_str: str) -> None:
    """
    Validate that the 'rank' column in the mhc2phbr output contains valid float values.

    Args:
        path_str: Path to the tab-separated mhc2phbr output file

    Raises:
        ValueError: If the 'rank' column is missing, contains the value '-',
            or contains any value that cannot be converted to float
    """
    df = pd.read_csv(path_str, sep='\t')

    if 'rank' not in df.columns:
        raise ValueError("Required column 'rank' not found in mhc2phbr output")

    # '-' is checked on its own, ahead of the float conversion, so that it
    # gets a dedicated message that includes a dump of the whole table.
    if (df['rank'] == '-').any():
        raise ValueError(f"Invalid value '-' found in 'rank' column\n{df}")

    # The converted column is discarded; the conversion is attempted only to
    # find out whether it fails.
    try:
        df['rank'].astype(float)
    except ValueError as e:
        # Chain explicitly so the original conversion error stays attached
        # as the direct cause.
        raise ValueError(f"Non-float values found in 'rank' column: {e}") from e